diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 1582d862..e34945eb 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -196,6 +196,13 @@ class CitationConfig(BaseModel): enabled: bool = Field(default=False) +class Citation(BaseModel): + document_id: str + file_name: str + knowledge_id: str + score: float + + class WebSearchConfig(BaseModel): """联网搜索配置""" enabled: bool = Field(default=False) diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 7f91dada..3dda6fc0 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -82,6 +82,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -93,7 +99,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) memory_flag = False if memory: memory_tools, memory_flag = self.agent_service.load_memory_config( @@ -230,7 +237,7 @@ class AppChatService: }), "elapsed_time": elapsed_time, "suggested_questions": suggested_questions, - "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), + "citations": self.agent_service._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -283,6 +290,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -295,7 +308,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -409,7 +423,7 @@ class AppChatService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self.agent_service._filter_citations(features_config, []) + end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector) # 保存消息 human_meta = { diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 09a202cd..ac34b4de 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context from app.models import AgentConfig, ModelConfig, ModelType from app.repositories.tool_repository import ToolRepository -from app.schemas.app_schema import FileInput +from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service @@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: kb_config: 知识库配置 kb_ids: 知识库ID列表 user_id: 用户ID + citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) + 列表元素类型为 Citation,包含字段: + - document_id: 文档唯一标识 + - file_name: 文件名 + - knowledge_id: 知识库 ID + - score: 检索相关性得分 Returns: 检索到的相关知识内容 @@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): } ) + # 收集引用信息 + if citations_collector is not None: + seen_doc_ids = {c.get("document_id") for c in citations_collector} + for chunk in retrieve_chunks_result: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + citations_collector.append(Citation( + document_id=doc_id, + file_name=meta.get("file_name", ""), + knowledge_id=str(meta.get("knowledge_id", "")), + score=meta.get("score", 0) + )) + return f"检索到以下相关信息:\n\n{context}" else: logger.warning("知识库检索未找到结果") @@ -320,26 +341,26 @@ class AgentRunService: self, knowledge_retrieval_config: dict | None, user_id - ) -> list: + ) -> tuple[list, list]: + """返回 (tools, citations_collector)""" if not knowledge_retrieval_config: - return [] + return [], [] + citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + kb_tool = create_knowledge_retrieval_tool( + knowledge_retrieval_config, kb_ids, user_id, + citations_collector=citations_collector + ) tools.append(kb_tool) - logger.debug( "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } + extra={"kb_ids": kb_ids, "tool_count": len(tools)} ) - return tools + return tools, citations_collector def load_memory_config( self, @@ -441,12 +462,12 @@ class AgentRunService: @staticmethod def _filter_citations( features_config: Dict[str, Any], - citations: List[Any] + citations: List[Citation] ) -> List[Any]: """根据 citation 开关决定是否返回引用来源""" citation_cfg = features_config.get("citation", {}) if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"): - return citations + return [cit.model_dump() for cit in citations] return [] async def run( @@ -549,7 +570,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -680,7 +702,7 @@ class AgentRunService: "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], - "citations": self._filter_citations(features_config, result.get("citations", [])), + "citations": self._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -790,7 +812,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -943,7 +966,7 @@ class AgentRunService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self._filter_citations(features_config, []) + end_data["citations"] = self._filter_citations(features_config, citations_collector) yield self._format_sse_event("end", end_data) logger.info(