feat(agent tool): agent tool bug fix
This commit is contained in:
@@ -2,14 +2,14 @@
|
||||
Agent 配置格式转换器
|
||||
用于将 Pydantic 模型转换为数据库存储格式
|
||||
"""
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, Union
|
||||
from app.schemas.app_schema import (
|
||||
KnowledgeRetrievalConfig,
|
||||
MemoryConfig,
|
||||
VariableDefinition,
|
||||
ToolConfig,
|
||||
AgentConfigCreate,
|
||||
AgentConfigUpdate,
|
||||
AgentConfigUpdate, ToolOldConfig,
|
||||
)
|
||||
|
||||
|
||||
@@ -47,10 +47,7 @@ class AgentConfigConverter:
|
||||
|
||||
# 5. 工具配置
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
result["tools"] = {
|
||||
name: tool.model_dump()
|
||||
for name, tool in config.tools.items()
|
||||
}
|
||||
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||
|
||||
return result
|
||||
|
||||
@@ -60,7 +57,7 @@ class AgentConfigConverter:
|
||||
knowledge_retrieval: Optional[Dict[str, Any]],
|
||||
memory: Optional[Dict[str, Any]],
|
||||
variables: Optional[list],
|
||||
tools: Optional[Dict[str, Any]],
|
||||
tools: Optional[Union[list, Dict[str, Any]]],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
将数据库存储格式转换为 Pydantic 对象
|
||||
@@ -113,9 +110,12 @@ class AgentConfigConverter:
|
||||
|
||||
# 5. 解析工具配置
|
||||
if tools:
|
||||
result["tools"] = {
|
||||
name: ToolConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
if isinstance(tools, list):
|
||||
result["tools"] = [ToolConfig(**tool_config) for tool_config in tools]
|
||||
else:
|
||||
result["tools"] = {
|
||||
name: ToolOldConfig(**tool_data)
|
||||
for name, tool_data in tools.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
@@ -78,13 +78,17 @@ class AppChatService:
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
for tool_id, tool_config in config.tools.items():
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
@@ -219,6 +223,23 @@ class AppChatService:
|
||||
# 准备工具列表
|
||||
tools = []
|
||||
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, workspace_id))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
knowledge_retrieval = config.knowledge_retrieval
|
||||
if knowledge_retrieval:
|
||||
@@ -237,20 +258,20 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
# web_tools = config.tools
|
||||
# web_search_choice = web_tools.get("web_search", {})
|
||||
# web_search_enable = web_search_choice.get("enabled", False)
|
||||
# if web_search == True:
|
||||
# if web_search_enable == True:
|
||||
# search_tool = create_web_search_tool({})
|
||||
# tools.append(search_tool)
|
||||
#
|
||||
# logger.debug(
|
||||
# "已添加网络搜索工具",
|
||||
# extra={
|
||||
# "tool_count": len(tools)
|
||||
# }
|
||||
# )
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -307,7 +307,7 @@ class AppService:
|
||||
knowledge_retrieval=storage_data.get("knowledge_retrieval"),
|
||||
memory=storage_data.get("memory"),
|
||||
variables=storage_data.get("variables", []),
|
||||
tools=storage_data.get("tools", {}),
|
||||
tools=storage_data.get("tools", []),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -689,7 +689,7 @@ class AppService:
|
||||
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
|
||||
memory=source_config.memory.copy() if source_config.memory else None,
|
||||
variables=source_config.variables.copy() if source_config.variables else [],
|
||||
tools=source_config.tools.copy() if source_config.tools else {},
|
||||
tools=source_config.tools.copy() if source_config.tools else [],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -879,7 +879,7 @@ class AppService:
|
||||
# if data.variables is not None:
|
||||
agent_cfg.variables = storage_data.get("variables", [])
|
||||
# if data.tools is not None:
|
||||
agent_cfg.tools = storage_data.get("tools", {})
|
||||
agent_cfg.tools = storage_data.get("tools", [])
|
||||
|
||||
agent_cfg.updated_at = now
|
||||
|
||||
@@ -966,7 +966,7 @@ class AppService:
|
||||
"max_history": 10
|
||||
},
|
||||
variables=[],
|
||||
tools={},
|
||||
tools=[],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -1183,7 +1183,7 @@ class AppService:
|
||||
"knowledge_retrieval": agent_cfg.knowledge_retrieval,
|
||||
"memory": agent_cfg.memory,
|
||||
"variables": agent_cfg.variables or [],
|
||||
"tools": agent_cfg.tools or {},
|
||||
"tools": agent_cfg.tools or [],
|
||||
}
|
||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||
default_model_config_id = agent_cfg.default_model_config_id
|
||||
|
||||
@@ -298,16 +298,17 @@ class DraftRunService:
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_id, tool_config in agent_config.tools.items():
|
||||
for tool_config in agent_config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(
|
||||
tool_config.get("config", {}).get("operation", None))
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
@@ -507,16 +508,17 @@ class DraftRunService:
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_id, tool_config in agent_config.tools.items():
|
||||
for tool_config in agent_config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(
|
||||
tool_config.get("config", {}).get("operation", None))
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
|
||||
# 添加知识库检索工具
|
||||
|
||||
Reference in New Issue
Block a user