Merge pull request #688 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(agent)
This commit is contained in:
@@ -196,6 +196,13 @@ class CitationConfig(BaseModel):
|
|||||||
enabled: bool = Field(default=False)
|
enabled: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Citation(BaseModel):
|
||||||
|
document_id: str
|
||||||
|
file_name: str
|
||||||
|
knowledge_id: str
|
||||||
|
score: float
|
||||||
|
|
||||||
|
|
||||||
class WebSearchConfig(BaseModel):
|
class WebSearchConfig(BaseModel):
|
||||||
"""联网搜索配置"""
|
"""联网搜索配置"""
|
||||||
enabled: bool = Field(default=False)
|
enabled: bool = Field(default=False)
|
||||||
|
|||||||
@@ -82,6 +82,12 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||||
|
|
||||||
|
# opening_statement:首轮对话注入开场白
|
||||||
|
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
|
||||||
|
system_prompt = self.agent_service._inject_opening_statement(
|
||||||
|
features_config, system_prompt, is_new_conversation
|
||||||
|
)
|
||||||
|
|
||||||
# 准备工具列表
|
# 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
@@ -93,7 +99,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||||
|
tools.extend(kb_tools)
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
memory_tools, memory_flag = self.agent_service.load_memory_config(
|
||||||
@@ -230,7 +237,7 @@ class AppChatService:
|
|||||||
}),
|
}),
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"suggested_questions": suggested_questions,
|
"suggested_questions": suggested_questions,
|
||||||
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
"citations": self.agent_service._filter_citations(features_config, citations_collector),
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending"
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
@@ -283,6 +290,12 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
system_prompt = system_prompt_rendered.get_text_content() or system_prompt
|
||||||
|
|
||||||
|
# opening_statement:首轮对话注入开场白
|
||||||
|
is_new_conversation = not self.conversation_service.get_messages(conversation_id, limit=1)
|
||||||
|
system_prompt = self.agent_service._inject_opening_statement(
|
||||||
|
features_config, system_prompt, is_new_conversation
|
||||||
|
)
|
||||||
|
|
||||||
# 准备工具列表
|
# 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
@@ -295,7 +308,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id))
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||||
|
tools.extend(kb_tools)
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
@@ -409,7 +423,7 @@ class AppChatService:
|
|||||||
logger.warning(f"TTS任务异常: {e}")
|
logger.warning(f"TTS任务异常: {e}")
|
||||||
audio_status = "failed"
|
audio_status = "failed"
|
||||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
end_data["citations"] = self.agent_service._filter_citations(features_config, citations_collector)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
human_meta = {
|
human_meta = {
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig, ModelType
|
from app.models import AgentConfig, ModelConfig, ModelType
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput
|
from app.schemas.app_schema import FileInput, Citation
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
|
||||||
from app.services import task_service
|
from app.services import task_service
|
||||||
@@ -190,13 +190,19 @@ def create_web_search_tool(web_search_config: Dict[str, Any]):
|
|||||||
return web_search_tool
|
return web_search_tool
|
||||||
|
|
||||||
|
|
||||||
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id, citations_collector: Optional[List[Citation]] = None):
|
||||||
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
"""从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
kb_config: 知识库配置
|
kb_config: 知识库配置
|
||||||
kb_ids: 知识库ID列表
|
kb_ids: 知识库ID列表
|
||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
|
citations_collector: 用于收集引用信息的列表(由外部传入,tool 执行时填充)
|
||||||
|
列表元素类型为 Citation,包含字段:
|
||||||
|
- document_id: 文档唯一标识
|
||||||
|
- file_name: 文件名
|
||||||
|
- knowledge_id: 知识库 ID
|
||||||
|
- score: 检索相关性得分
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索到的相关知识内容
|
检索到的相关知识内容
|
||||||
@@ -229,6 +235,21 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 收集引用信息
|
||||||
|
if citations_collector is not None:
|
||||||
|
seen_doc_ids = {c.get("document_id") for c in citations_collector}
|
||||||
|
for chunk in retrieve_chunks_result:
|
||||||
|
meta = chunk.metadata or {}
|
||||||
|
doc_id = meta.get("document_id") or meta.get("doc_id")
|
||||||
|
if doc_id and doc_id not in seen_doc_ids:
|
||||||
|
seen_doc_ids.add(doc_id)
|
||||||
|
citations_collector.append(Citation(
|
||||||
|
document_id=doc_id,
|
||||||
|
file_name=meta.get("file_name", ""),
|
||||||
|
knowledge_id=str(meta.get("knowledge_id", "")),
|
||||||
|
score=meta.get("score", 0)
|
||||||
|
))
|
||||||
|
|
||||||
return f"检索到以下相关信息:\n\n{context}"
|
return f"检索到以下相关信息:\n\n{context}"
|
||||||
else:
|
else:
|
||||||
logger.warning("知识库检索未找到结果")
|
logger.warning("知识库检索未找到结果")
|
||||||
@@ -320,26 +341,26 @@ class AgentRunService:
|
|||||||
self,
|
self,
|
||||||
knowledge_retrieval_config: dict | None,
|
knowledge_retrieval_config: dict | None,
|
||||||
user_id
|
user_id
|
||||||
) -> list:
|
) -> tuple[list, list]:
|
||||||
|
"""返回 (tools, citations_collector)"""
|
||||||
if not knowledge_retrieval_config:
|
if not knowledge_retrieval_config:
|
||||||
return []
|
return [], []
|
||||||
|
|
||||||
|
citations_collector = []
|
||||||
tools = []
|
tools = []
|
||||||
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", [])
|
||||||
kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id"))
|
kb_ids = [kb["kb_id"] for kb in knowledge_bases if kb.get("kb_id")]
|
||||||
if kb_ids:
|
if kb_ids:
|
||||||
# 创建知识库检索工具
|
kb_tool = create_knowledge_retrieval_tool(
|
||||||
kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id)
|
knowledge_retrieval_config, kb_ids, user_id,
|
||||||
|
citations_collector=citations_collector
|
||||||
|
)
|
||||||
tools.append(kb_tool)
|
tools.append(kb_tool)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"已添加知识库检索工具",
|
"已添加知识库检索工具",
|
||||||
extra={
|
extra={"kb_ids": kb_ids, "tool_count": len(tools)}
|
||||||
"kb_ids": kb_ids,
|
|
||||||
"tool_count": len(tools)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
return tools
|
return tools, citations_collector
|
||||||
|
|
||||||
def load_memory_config(
|
def load_memory_config(
|
||||||
self,
|
self,
|
||||||
@@ -441,12 +462,12 @@ class AgentRunService:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_citations(
|
def _filter_citations(
|
||||||
features_config: Dict[str, Any],
|
features_config: Dict[str, Any],
|
||||||
citations: List[Any]
|
citations: List[Citation]
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
"""根据 citation 开关决定是否返回引用来源"""
|
"""根据 citation 开关决定是否返回引用来源"""
|
||||||
citation_cfg = features_config.get("citation", {})
|
citation_cfg = features_config.get("citation", {})
|
||||||
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
||||||
return citations
|
return [cit.model_dump() for cit in citations]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -549,7 +570,8 @@ class AgentRunService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
|
||||||
|
tools.extend(kb_tools)
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
@@ -680,7 +702,7 @@ class AgentRunService:
|
|||||||
"suggested_questions": await self._generate_suggested_questions(
|
"suggested_questions": await self._generate_suggested_questions(
|
||||||
features_config, result["content"], api_key_config, effective_params
|
features_config, result["content"], api_key_config, effective_params
|
||||||
) if not sub_agent else [],
|
) if not sub_agent else [],
|
||||||
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
"citations": self._filter_citations(features_config, citations_collector),
|
||||||
"audio_url": audio_url,
|
"audio_url": audio_url,
|
||||||
"audio_status": "pending"
|
"audio_status": "pending"
|
||||||
}
|
}
|
||||||
@@ -790,7 +812,8 @@ class AgentRunService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id))
|
kb_tools, citations_collector = self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)
|
||||||
|
tools.extend(kb_tools)
|
||||||
|
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
@@ -943,7 +966,7 @@ class AgentRunService:
|
|||||||
logger.warning(f"TTS任务异常: {e}")
|
logger.warning(f"TTS任务异常: {e}")
|
||||||
audio_status = "failed"
|
audio_status = "failed"
|
||||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||||
end_data["citations"] = self._filter_citations(features_config, [])
|
end_data["citations"] = self._filter_citations(features_config, citations_collector)
|
||||||
yield self._format_sse_event("end", end_data)
|
yield self._format_sse_event("end", end_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
Reference in New Issue
Block a user