feat(agent): Opening remarks and document citation function
This commit is contained in:
@@ -196,6 +196,13 @@ class CitationConfig(BaseModel):
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
|
||||
class Citation(BaseModel):
|
||||
document_id: str
|
||||
file_name: str
|
||||
knowledge_id: str
|
||||
score: float
|
||||
|
||||
|
||||
class WebSearchConfig(BaseModel):
|
||||
"""联网搜索配置"""
|
||||
enabled: bool = Field(default=False)
|
||||
|
||||
@@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||
from app.services import task_service
|
||||
@@ -190,7 +190,7 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
|
||||
return web_search_tool
|
||||
|
||||
|
||||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None):
|
||||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
|
||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||
|
||||
Args:
|
||||
@@ -198,6 +198,11 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec
|
||||
kb_ids: 知识库ID列表
|
||||
user_id: 用户ID
|
||||
citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充)
|
||||
列表元素类型为 Citation,包含字段:
|
||||
- document_id: 文档唯一标识
|
||||
- file_name: 文件名
|
||||
- knowledge_id: 知识库 ID
|
||||
- score: 检索相关性得分
|
||||
|
||||
Returns:
|
||||
检索到的相关知识内容
|
||||
@@ -238,12 +243,12 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec
|
||||
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||
if doc_id and doc_id not in seen_doc_ids:
|
||||
seen_doc_ids.add(doc_id)
|
||||
citations_collector.append({
|
||||
"document_id": doc_id,
|
||||
"file_name": meta.get("file_name", ""),
|
||||
"knowledge_id": str(meta.get("knowledge_id", "")),
|
||||
"score": meta.get("score", 0),
|
||||
})
|
||||
citations_collector.append(Citation(
|
||||
document_id=doc_id,
|
||||
file_name=meta.get("file_name", ""),
|
||||
knowledge_id=str(meta.get("knowledge_id", "")),
|
||||
score=meta.get("score", 0)
|
||||
))
|
||||
|
||||
return f"检索到以下相关信息:\n\n{context}"
|
||||
else:
|
||||
@@ -344,7 +349,7 @@ class AgentRunService:
|
||||
citations_collector = []
|
||||
tools = []
|
||||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
||||
kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
|
||||
if kb_ids:
|
||||
kb_tool = create_knowledge_retrieval_tool(
|
||||
knowledge_retrieval_config, kb_ids, user_id,
|
||||
@@ -457,12 +462,12 @@ class AgentRunService:
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
features_config: Dict[str, Any],
|
||||
citations: List[Any]
|
||||
citations: List[Citation]
|
||||
) -> List[Any]:
|
||||
"""根据 citation 开关决定是否返回引用来源"""
|
||||
citation_cfg = features_config.get("citation", {})
|
||||
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
||||
return citations
|
||||
return [cit.model_dump() for cit in citations]
|
||||
return []
|
||||
|
||||
async def run(
|
||||
|
||||
Reference in New Issue
Block a user