feat(agent tool): agent tool bug fix

This commit is contained in:
谢俊男
2026-01-06 20:05:18 +08:00
parent 477404554e
commit 26947d85ae
8 changed files with 86 additions and 56 deletions

View File

@@ -433,7 +433,8 @@ async def chat(
config=agent_config, config=agent_config,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
): ):
yield event yield event
@@ -469,7 +470,8 @@ async def chat(
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT: elif app_type == AppType.MULTI_AGENT:
@@ -486,8 +488,8 @@ async def chat(
config=config, config=config,
web_search=payload.web_search, web_search=payload.web_search,
memory=payload.memory, memory=payload.memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event

View File

@@ -197,8 +197,8 @@ async def chat(
config=config, config=config,
web_search=web_search, web_search=web_search,
memory=memory, memory=memory,
storage_type=storage_type, storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id user_rag_memory_id=user_rag_memory_id
): ):
yield event yield event
@@ -214,7 +214,6 @@ async def chat(
# 多 Agent 非流式返回 # 多 Agent 非流式返回
result = await app_chat_service.multi_agent_chat( result = await app_chat_service.multi_agent_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串 user_id=end_user_id, # 转换为字符串
@@ -293,4 +292,4 @@ async def chat(
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -7,7 +7,6 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
@@ -156,6 +155,7 @@ class LangChainAgent:
store.delete_duplicate_sessions() store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}') # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id return session_id
async def term_memory_redis_read(self,end_user_end): async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}" end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)

View File

@@ -1,6 +1,6 @@
import datetime import datetime
import uuid import uuid
from typing import Optional, Any, List, Dict from typing import Optional, Any, List, Dict, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
class ToolConfig(BaseModel): class ToolConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
tool_id: str = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default_factory=dict, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置""" """工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具") enabled: bool = Field(default=False, description="是否启用该工具")
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置") config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
) )
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = Field( tools: List[ToolConfig] = Field(
default_factory=dict, default_factory=list,
description="工具配置key 为工具名称web_search, code_interpreter, image_generation 等)" description="Agent 可用的工具列表"
) )
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表") variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
# 工具配置 # 工具配置
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置") tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表")
# ---------- Output Schemas ---------- # ---------- Output Schemas ----------
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
variables: List[VariableDefinition] = [] variables: List[VariableDefinition] = []
# 工具配置 # 工具配置
tools: Dict[str, ToolConfig] = {} tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
is_active: bool is_active: bool
created_at: datetime.datetime created_at: datetime.datetime

View File

@@ -2,14 +2,14 @@
Agent 配置格式转换器 Agent 配置格式转换器
用于将 Pydantic 模型转换为数据库存储格式 用于将 Pydantic 模型转换为数据库存储格式
""" """
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, Union
from app.schemas.app_schema import ( from app.schemas.app_schema import (
KnowledgeRetrievalConfig, KnowledgeRetrievalConfig,
MemoryConfig, MemoryConfig,
VariableDefinition, VariableDefinition,
ToolConfig, ToolConfig,
AgentConfigCreate, AgentConfigCreate,
AgentConfigUpdate, AgentConfigUpdate, ToolOldConfig,
) )
@@ -47,10 +47,7 @@ class AgentConfigConverter:
# 5. 工具配置 # 5. 工具配置
if hasattr(config, 'tools') and config.tools: if hasattr(config, 'tools') and config.tools:
result["tools"] = { result["tools"] = [tool.model_dump() for tool in config.tools]
name: tool.model_dump()
for name, tool in config.tools.items()
}
return result return result
@@ -60,7 +57,7 @@ class AgentConfigConverter:
knowledge_retrieval: Optional[Dict[str, Any]], knowledge_retrieval: Optional[Dict[str, Any]],
memory: Optional[Dict[str, Any]], memory: Optional[Dict[str, Any]],
variables: Optional[list], variables: Optional[list],
tools: Optional[Dict[str, Any]], tools: Optional[Union[list, Dict[str, Any]]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
将数据库存储格式转换为 Pydantic 对象 将数据库存储格式转换为 Pydantic 对象
@@ -113,9 +110,12 @@ class AgentConfigConverter:
# 5. 解析工具配置 # 5. 解析工具配置
if tools: if tools:
result["tools"] = { if isinstance(tools, list):
name: ToolConfig(**tool_data) result["tools"] = [ToolConfig(**tool_config) for tool_config in tools]
for name, tool_data in tools.items() else:
} result["tools"] = {
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
return result return result

View File

@@ -78,13 +78,17 @@ class AppChatService:
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools: if hasattr(config, 'tools') and config.tools:
for tool_id, tool_config in config.tools.items(): for tool_config in config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id)) tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance: if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具 # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None)) langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
@@ -219,6 +223,23 @@ class AppChatService:
# 准备工具列表 # 准备工具列表
tools = [] tools = []
# 获取工具服务
tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools:
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval: if knowledge_retrieval:
@@ -237,20 +258,20 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
web_tools = config.tools # web_tools = config.tools
web_search_choice = web_tools.get("web_search", {}) # web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False) # web_search_enable = web_search_choice.get("enabled", False)
if web_search == True: # if web_search == True:
if web_search_enable == True: # if web_search_enable == True:
search_tool = create_web_search_tool({}) # search_tool = create_web_search_tool({})
tools.append(search_tool) # tools.append(search_tool)
#
logger.debug( # logger.debug(
"已添加网络搜索工具", # "已添加网络搜索工具",
extra={ # extra={
"tool_count": len(tools) # "tool_count": len(tools)
} # }
) # )
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters

View File

@@ -307,7 +307,7 @@ class AppService:
knowledge_retrieval=storage_data.get("knowledge_retrieval"), knowledge_retrieval=storage_data.get("knowledge_retrieval"),
memory=storage_data.get("memory"), memory=storage_data.get("memory"),
variables=storage_data.get("variables", []), variables=storage_data.get("variables", []),
tools=storage_data.get("tools", {}), tools=storage_data.get("tools", []),
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -689,7 +689,7 @@ class AppService:
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None, knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
memory=source_config.memory.copy() if source_config.memory else None, memory=source_config.memory.copy() if source_config.memory else None,
variables=source_config.variables.copy() if source_config.variables else [], variables=source_config.variables.copy() if source_config.variables else [],
tools=source_config.tools.copy() if source_config.tools else {}, tools=source_config.tools.copy() if source_config.tools else [],
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -879,7 +879,7 @@ class AppService:
# if data.variables is not None: # if data.variables is not None:
agent_cfg.variables = storage_data.get("variables", []) agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None: # if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", {}) agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.updated_at = now agent_cfg.updated_at = now
@@ -966,7 +966,7 @@ class AppService:
"max_history": 10 "max_history": 10
}, },
variables=[], variables=[],
tools={}, tools=[],
is_active=True, is_active=True,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
@@ -1183,7 +1183,7 @@ class AppService:
"knowledge_retrieval": agent_cfg.knowledge_retrieval, "knowledge_retrieval": agent_cfg.knowledge_retrieval,
"memory": agent_cfg.memory, "memory": agent_cfg.memory,
"variables": agent_cfg.variables or [], "variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or {}, "tools": agent_cfg.tools or [],
} }
# config = AgentConfigConverter.from_storage_format(agent_cfg) # config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -298,16 +298,17 @@ class DraftRunService:
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items(): for tool_config in agent_config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id( ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id))) self.db, str(workspace_id)))
if tool_instance: if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具 # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool( langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具
@@ -507,16 +508,17 @@ class DraftRunService:
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_id, tool_config in agent_config.tools.items(): for tool_config in agent_config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_id, tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id( ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id))) self.db, str(workspace_id)))
if tool_instance: if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具 # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool( langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tool_config.get("config", {}).get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)
# 添加知识库检索工具 # 添加知识库检索工具