diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 1582d862..e34945eb 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -196,6 +196,13 @@ class CitationConfig(BaseModel): enabled: bool = Field(default=False) +class Citation(BaseModel): + document_id: str + file_name: str + knowledge_id: str + score: float + + class WebSearchConfig(BaseModel): """联网搜索配置""" enabled: bool = Field(default=False) diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index d71d5c24..ac34b4de 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository -from app.schemas.app_schema import FileInput +from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service @@ -190,7 +190,7 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: @@ -198,6 +198,11 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec kb_ids: 知识库ID列表 user_id: 用户ID citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) + 列表元素类型为 Citation,包含字段: + - document_id: 文档唯一标识 + - file_name: 文件名 + - knowledge_id: 知识库 ID + - score: 检索相关性得分 Returns: 检索到的相关知识内容 @@ -238,12 +243,12 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collec doc_id = meta.get("document_id") or meta.get("doc_id") if doc_id and doc_id not in seen_doc_ids: seen_doc_ids.add(doc_id) - citations_collector.append({ - "document_id": doc_id, - "file_name": meta.get("file_name", ""), - "knowledge_id": str(meta.get("knowledge_id", "")), - "score": meta.get("score", 0), - }) + citations_collector.append(Citation( + document_id=doc_id, + file_name=meta.get("file_name", ""), + knowledge_id=str(meta.get("knowledge_id", "")), + score=meta.get("score", 0) + )) return f"检索到以下相关信息:\n\n{context}" else: @@ -344,7 +349,7 @@ class AgentRunService: citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: kb_tool = create_knowledge_retrieval_tool( knowledge_retrieval_config, kb_ids, user_id, @@ -457,12 +462,12 @@ class AgentRunService: @staticmethod def _filter_citations( features_config: Dict[str, Any], - citations: List[Any] + citations: List[Citation] ) -> List[Any]: """根据 citation 开关决定是否返回引用来源""" citation_cfg = features_config.get("citation", {}) if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): - return citations + return [cit.model_dump() for cit in citations] return [] async def run(