feat(agent): add opening-statement injection and knowledge-base document citations

This commit is contained in:
Timebomb2018
2026-03-26 10:06:05 +08:00
parent 30b5db1e98
commit b7a03a844f
2 changed files with 51 additions and 19 deletions

View File

@@ -82,6 +82,12 @@ class AppChatService:
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# opening_statement:首轮对话注入开场白
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
system_prompt = self.agent_service._inject_opening_statement(
features_config, system_prompt, is_new_conversation
)
# 准备工具列表 # 准备工具列表
tools = [] tools = []
@@ -93,7 +99,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
memory_flag = False memory_flag = False
if memory: if memory:
memory_tools, memory_flag = self.agent_service.load_memory_config( memory_tools, memory_flag = self.agent_service.load_memory_config(
@@ -230,7 +237,7 @@ class AppChatService:
}), }),
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"suggested_questions": suggested_questions, "suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])), "citations": self.agent_service._filter_citations(features_config, citations_collector),
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending" "audio_status": "pending"
} }
@@ -283,6 +290,12 @@ class AppChatService:
) )
system_prompt = system_prompt_rendered.get_text_content() or system_prompt system_prompt = system_prompt_rendered.get_text_content() or system_prompt
# opening_statement:首轮对话注入开场白
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
system_prompt = self.agent_service._inject_opening_statement(
features_config, system_prompt, is_new_conversation
)
# 准备工具列表 # 准备工具列表
tools = [] tools = []
@@ -295,7 +308,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
@@ -409,7 +423,7 @@ class AppChatService:
logger.warning(f"TTS任务异常: {e}") logger.warning(f"TTS任务异常: {e}")
audio_status = "failed" audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self.agent_service._filter_citations(features_config, []) end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector)
# 保存消息 # 保存消息
human_meta = { human_meta = {

View File

@@ -190,13 +190,14 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
return web_search_tool return web_search_tool
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: list = None):
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
Args: Args:
kb_config: 知识库配置 kb_config: 知识库配置
kb_ids: 知识库ID列表 kb_ids: 知识库ID列表
user_id: 用户ID user_id: 用户ID
citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充)
Returns: Returns:
检索到的相关知识内容 检索到的相关知识内容
@@ -229,6 +230,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
} }
) )
# 收集引用信息
if citations_collector is not None:
seen_doc_ids = {c.get("document_id") for c in citations_collector}
for chunk in retrieve_chunks_result:
meta = chunk.metadata or {}
doc_id = meta.get("document_id") or meta.get("doc_id")
if doc_id and doc_id not in seen_doc_ids:
seen_doc_ids.add(doc_id)
citations_collector.append({
"document_id": doc_id,
"file_name": meta.get("file_name", ""),
"knowledge_id": str(meta.get("knowledge_id", "")),
"score": meta.get("score", 0),
})
return f"检索到以下相关信息:\n\n{context}" return f"检索到以下相关信息:\n\n{context}"
else: else:
logger.warning("知识库检索未找到结果") logger.warning("知识库检索未找到结果")
@@ -320,26 +336,26 @@ class AgentRunService:
self, self,
knowledge_retrieval_config: dict | None, knowledge_retrieval_config: dict | None,
user_id user_id
) -> list: ) -> tuple[list, list]:
"""返回 (tools, citations_collector)"""
if not knowledge_retrieval_config: if not knowledge_retrieval_config:
return [] return [], []
citations_collector = []
tools = [] tools = []
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
if kb_ids: if kb_ids:
# 创建知识库检索工具 kb_tool = create_knowledge_retrieval_tool(
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) knowledge_retrieval_config, kb_ids, user_id,
citations_collector=citations_collector
)
tools.append(kb_tool) tools.append(kb_tool)
logger.debug( logger.debug(
"已添加知识库检索工具", "已添加知识库检索工具",
extra={ extra={"kb_ids": kb_ids, "tool_count": len(tools)}
"kb_ids": kb_ids,
"tool_count": len(tools)
}
) )
return tools return tools, citations_collector
def load_memory_config( def load_memory_config(
self, self,
@@ -549,7 +565,8 @@ class AgentRunService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
if memory: if memory:
@@ -680,7 +697,7 @@ class AgentRunService:
"suggested_questions": await self._generate_suggested_questions( "suggested_questions": await self._generate_suggested_questions(
features_config, result["content"], api_key_config, effective_params features_config, result["content"], api_key_config, effective_params
) if not sub_agent else [], ) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])), "citations": self._filter_citations(features_config, citations_collector),
"audio_url": audio_url, "audio_url": audio_url,
"audio_status": "pending" "audio_status": "pending"
} }
@@ -790,7 +807,8 @@ class AgentRunService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
@@ -943,7 +961,7 @@ class AgentRunService:
logger.warning(f"TTS任务异常: {e}") logger.warning(f"TTS任务异常: {e}")
audio_status = "failed" audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self._filter_citations(features_config, []) end_data["citations"] = self._filter_citations(features_config, citations_collector)
yield self._format_sse_event("end", end_data) yield self._format_sse_event("end", end_data)
logger.info( logger.info(