diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 7f91dada..3dda6fc0 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -82,6 +82,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -93,7 +99,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) memory_flag = False if memory: memory_tools, memory_flag = self.agent_service.load_memory_config( @@ -230,7 +237,7 @@ class AppChatService: }), "elapsed_time": elapsed_time, "suggested_questions": suggested_questions, - "citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), + "citations": self.agent_service._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -283,6 +290,12 @@ class AppChatService: ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt + # opening_statement:首轮对话注入开场白 + is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1) + system_prompt = self.agent_service._inject_opening_statement( + features_config, system_prompt, is_new_conversation + ) + # 准备工具列表 tools = [] @@ -295,7 +308,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -409,7 +423,7 @@ class AppChatService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self.agent_service._filter_citations(features_config, []) + end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector) # 保存消息 human_meta = { diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 09a202cd..d71d5c24 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -190,13 +190,14 @@ def create_web_search_tool(web_search_config: Dict[str, Any]): return web_search_tool -def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): +def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: kb_config: 知识库配置 kb_ids: 知识库ID列表 user_id: 用户ID + citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充) Returns: 检索到的相关知识内容 @@ -229,6 +230,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): } ) + # 收集引用信息 + if citations_collector is not None: + seen_doc_ids = {c.get("document_id") for c in citations_collector} + for chunk in retrieve_chunks_result: + meta = chunk.metadata or {} + doc_id = meta.get("document_id") or meta.get("doc_id") + if doc_id and doc_id not in seen_doc_ids: + seen_doc_ids.add(doc_id) + citations_collector.append({ + "document_id": doc_id, + "file_name": meta.get("file_name", ""), + "knowledge_id": str(meta.get("knowledge_id", "")), + "score": meta.get("score", 0), + }) + return f"检索到以下相关信息:\n\n{context}" else: logger.warning("知识库检索未找到结果") @@ -320,26 +336,26 @@ class AgentRunService: self, knowledge_retrieval_config: dict | None, user_id - ) -> list: + ) -> tuple[list, list]: + """返回 (tools, citations_collector)""" if not knowledge_retrieval_config: - return [] + return [], [] + citations_collector = [] tools = [] knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + kb_tool = create_knowledge_retrieval_tool( + knowledge_retrieval_config, kb_ids, user_id, + citations_collector=citations_collector + ) tools.append(kb_tool) - logger.debug( "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } + extra={"kb_ids": kb_ids, "tool_count": len(tools)} ) - return tools + return tools, citations_collector def load_memory_config( self, @@ -549,7 +565,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False if memory: @@ -680,7 +697,7 @@ class AgentRunService: "suggested_questions": await self._generate_suggested_questions( features_config, result["content"], api_key_config, effective_params ) if not sub_agent else [], - "citations": self._filter_citations(features_config, result.get("citations", [])), + "citations": self._filter_citations(features_config, citations_collector), "audio_url": audio_url, "audio_status": "pending" } @@ -790,7 +807,8 @@ class AgentRunService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) + kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id) + tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -943,7 +961,7 @@ class AgentRunService: logger.warning(f"TTS任务异常: {e}") audio_status = "failed" end_data["audio_status"] = audio_status if stream_audio_url else None - end_data["citations"] = self._filter_citations(features_config, []) + end_data["citations"] = self._filter_citations(features_config, citations_collector) yield self._format_sse_event("end", end_data) logger.info(