From 26947d85aedab6d058f52518c0ee925bfd6d661d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Tue, 6 Jan 2026 20:05:18 +0800 Subject: [PATCH] feat(agent tool): agent tool bug fix --- .../controllers/public_share_controller.py | 10 ++-- .../controllers/service/app_api_controller.py | 7 +-- api/app/core/agent/langchain_agent.py | 2 +- api/app/schemas/app_schema.py | 18 ++++-- api/app/services/agent_config_converter.py | 22 ++++---- api/app/services/app_chat_service.py | 55 +++++++++++++------ api/app/services/app_service.py | 10 ++-- api/app/services/draft_run_service.py | 18 +++--- 8 files changed, 86 insertions(+), 56 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index adb199fb..02c73718 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -433,7 +433,8 @@ async def chat( config=agent_config, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ): yield event @@ -469,7 +470,8 @@ async def chat( web_search=payload.web_search, memory=payload.memory, storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + user_rag_memory_id=user_rag_memory_id, + workspace_id=workspace_id ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) elif app_type == AppType.MULTI_AGENT: @@ -486,8 +488,8 @@ async def chat( config=config, web_search=payload.web_search, memory=payload.memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index 180310a6..583b4700 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -197,8 +197,8 @@ async def chat( config=config, web_search=web_search, memory=memory, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): yield event @@ -214,7 +214,6 @@ async def chat( # 多 Agent 非流式返回 result = await app_chat_service.multi_agent_chat( - message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID user_id=end_user_id, # 转换为字符串 @@ -293,4 +292,4 @@ async def chat( from app.core.exceptions import BusinessException from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) - pass + diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index acd8cc5b..ef9a489f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -7,7 +7,6 @@ LangChain Agent 封装 - 支持流式输出 - 使用 RedBearLLM 支持多提供商 """ -import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence @@ -156,6 +155,7 @@ class LangChainAgent: store.delete_duplicate_sessions() # logger.info(f'Redis_Agent:{end_user_end};{session_id}') return session_id + async def term_memory_redis_read(self,end_user_end): end_user_end = f"Term_{end_user_end}" history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 81cd704d..d20570ce 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -1,6 +1,6 @@ import datetime import uuid -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List, Dict, Union from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator @@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel): class ToolConfig(BaseModel): + """工具配置""" + enabled: bool = Field(default=False, description="是否启用该工具") + tool_id: str = Field(default=None, description="工具ID") + operation: Optional[str] = Field(default_factory=dict, description="工具特定配置") + +class ToolOldConfig(BaseModel): """工具配置""" enabled: bool = Field(default=False, description="是否启用该工具") config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") @@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel): ) # 工具配置 - tools: Dict[str, ToolConfig] = Field( - default_factory=dict, - description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)" + tools: List[ToolConfig] = Field( + default_factory=list, + description="Agent 可用的工具列表" ) @@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel): variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") # 工具配置 - tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") + tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表") # ---------- Output Schemas ---------- @@ -216,7 +222,7 @@ class AgentConfig(BaseModel): variables: List[VariableDefinition] = [] # 工具配置 - tools: Dict[str, ToolConfig] = {} + tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = [] is_active: bool created_at: datetime.datetime diff --git a/api/app/services/agent_config_converter.py b/api/app/services/agent_config_converter.py index 3ab14157..eda4b5c4 100644 --- a/api/app/services/agent_config_converter.py +++ b/api/app/services/agent_config_converter.py @@ -2,14 +2,14 @@ Agent 配置格式转换器 用于将 Pydantic 模型转换为数据库存储格式 """ -from typing import Dict, Any, Optional +from typing import Dict, Any, Optional, Union from app.schemas.app_schema import ( KnowledgeRetrievalConfig, MemoryConfig, VariableDefinition, ToolConfig, AgentConfigCreate, - AgentConfigUpdate, + AgentConfigUpdate, ToolOldConfig, ) @@ -47,10 +47,7 @@ class AgentConfigConverter: # 5. 工具配置 if hasattr(config, 'tools') and config.tools: - result["tools"] = { - name: tool.model_dump() - for name, tool in config.tools.items() - } + result["tools"] = [tool.model_dump() for tool in config.tools] return result @@ -60,7 +57,7 @@ class AgentConfigConverter: knowledge_retrieval: Optional[Dict[str, Any]], memory: Optional[Dict[str, Any]], variables: Optional[list], - tools: Optional[Dict[str, Any]], + tools: Optional[Union[list, Dict[str, Any]]], ) -> Dict[str, Any]: """ 将数据库存储格式转换为 Pydantic 对象 @@ -113,9 +110,12 @@ class AgentConfigConverter: # 5. 解析工具配置 if tools: - result["tools"] = { - name: ToolConfig(**tool_data) - for name, tool_data in tools.items() - } + if isinstance(tools, list): + result["tools"] = [ToolConfig(**tool_config) for tool_config in tools] + else: + result["tools"] = { + name: ToolOldConfig(**tool_data) + for name, tool_data in tools.items() + } return result diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 917184c7..537eac8d 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -78,13 +78,17 @@ class AppChatService: # 从配置中获取启用的工具 if hasattr(config, 'tools') and config.tools: - for tool_id, tool_config in config.tools.items(): + for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id)) + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None)) + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) # 添加知识库检索工具 @@ -219,6 +223,23 @@ class AppChatService: # 准备工具列表 tools = [] + # 获取工具服务 + tool_service = ToolService(self.db) + + if hasattr(config, 'tools') and config.tools: + for tool_config in config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, workspace_id)) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + # 添加知识库检索工具 knowledge_retrieval = config.knowledge_retrieval if knowledge_retrieval: @@ -237,20 +258,20 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search == True: - if web_search_enable == True: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) + # web_tools = config.tools + # web_search_choice = web_tools.get("web_search", {}) + # web_search_enable = web_search_choice.get("enabled", False) + # if web_search == True: + # if web_search_enable == True: + # search_tool = create_web_search_tool({}) + # tools.append(search_tool) + # + # logger.debug( + # "已添加网络搜索工具", + # extra={ + # "tool_count": len(tools) + # } + # ) # 获取模型参数 model_parameters = config.model_parameters diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index 38097c4e..e15f68fe 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -307,7 +307,7 @@ class AppService: knowledge_retrieval=storage_data.get("knowledge_retrieval"), memory=storage_data.get("memory"), variables=storage_data.get("variables", []), - tools=storage_data.get("tools", {}), + tools=storage_data.get("tools", []), is_active=True, created_at=now, updated_at=now, @@ -689,7 +689,7 @@ class AppService: knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, memory=source_config.memory.copy() if source_config.memory else None, variables=source_config.variables.copy() if source_config.variables else [], - tools=source_config.tools.copy() if source_config.tools else {}, + tools=source_config.tools.copy() if source_config.tools else [], is_active=True, created_at=now, updated_at=now, @@ -879,7 +879,7 @@ class AppService: # if data.variables is not None: agent_cfg.variables = storage_data.get("variables", []) # if data.tools is not None: - agent_cfg.tools = storage_data.get("tools", {}) + agent_cfg.tools = storage_data.get("tools", []) agent_cfg.updated_at = now @@ -966,7 +966,7 @@ class AppService: "max_history": 10 }, variables=[], - tools={}, + tools=[], is_active=True, created_at=now, updated_at=now, @@ -1183,7 +1183,7 @@ class AppService: "knowledge_retrieval": agent_cfg.knowledge_retrieval, "memory": agent_cfg.memory, "variables": agent_cfg.variables or [], - "tools": agent_cfg.tools or {}, + "tools": agent_cfg.tools or [], } # config = AgentConfigConverter.from_storage_format(agent_cfg) default_model_config_id = agent_cfg.default_model_config_id diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 79aebcce..9a1dbd32 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -298,16 +298,17 @@ class DraftRunService: # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_id, tool_config in agent_config.tools.items(): + for tool_config in agent_config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_id, + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), ToolRepository.get_tenant_id_by_workspace_id( self.db, str(workspace_id))) if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool( - tool_config.get("config", {}).get("operation", None)) + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) # 添加知识库检索工具 @@ -507,16 +508,17 @@ class DraftRunService: # 从配置中获取启用的工具 if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_id, tool_config in agent_config.tools.items(): + for tool_config in agent_config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_id, + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), ToolRepository.get_tenant_id_by_workspace_id( self.db, str(workspace_id))) if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool( - tool_config.get("config", {}).get("operation", None)) + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) # 添加知识库检索工具