feat(agent tool): agent tool bug fix
This commit is contained in:
@@ -433,7 +433,8 @@ async def chat(
|
|||||||
config=agent_config,
|
config=agent_config,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -469,7 +470,8 @@ async def chat(
|
|||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
)
|
)
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
elif app_type == AppType.MULTI_AGENT:
|
elif app_type == AppType.MULTI_AGENT:
|
||||||
@@ -486,8 +488,8 @@ async def chat(
|
|||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
|||||||
@@ -197,8 +197,8 @@ async def chat(
|
|||||||
config=config,
|
config=config,
|
||||||
web_search=web_search,
|
web_search=web_search,
|
||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id
|
user_rag_memory_id=user_rag_memory_id
|
||||||
):
|
):
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
@@ -214,7 +214,6 @@ async def chat(
|
|||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.multi_agent_chat(
|
result = await app_chat_service.multi_agent_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
@@ -293,4 +292,4 @@ async def chat(
|
|||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
pass
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
@@ -156,6 +155,7 @@ class LangChainAgent:
|
|||||||
store.delete_duplicate_sessions()
|
store.delete_duplicate_sessions()
|
||||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
async def term_memory_redis_read(self,end_user_end):
|
async def term_memory_redis_read(self,end_user_end):
|
||||||
end_user_end = f"Term_{end_user_end}"
|
end_user_end = f"Term_{end_user_end}"
|
||||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional, Any, List, Dict
|
from typing import Optional, Any, List, Dict, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||||
|
|
||||||
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ToolConfig(BaseModel):
|
class ToolConfig(BaseModel):
|
||||||
|
"""工具配置"""
|
||||||
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
|
tool_id: str = Field(default=None, description="工具ID")
|
||||||
|
operation: Optional[str] = Field(default_factory=dict, description="工具特定配置")
|
||||||
|
|
||||||
|
class ToolOldConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
|
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
|
||||||
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Dict[str, ToolConfig] = Field(
|
tools: List[ToolConfig] = Field(
|
||||||
default_factory=dict,
|
default_factory=list,
|
||||||
description="工具配置,key 为工具名称(web_search, code_interpreter, image_generation 等)"
|
description="Agent 可用的工具列表"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
|
|||||||
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
|
|||||||
variables: List[VariableDefinition] = []
|
variables: List[VariableDefinition] = []
|
||||||
|
|
||||||
# 工具配置
|
# 工具配置
|
||||||
tools: Dict[str, ToolConfig] = {}
|
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
|
||||||
|
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
|
|||||||
@@ -2,14 +2,14 @@
|
|||||||
Agent 配置格式转换器
|
Agent 配置格式转换器
|
||||||
用于将 Pydantic 模型转换为数据库存储格式
|
用于将 Pydantic 模型转换为数据库存储格式
|
||||||
"""
|
"""
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, Union
|
||||||
from app.schemas.app_schema import (
|
from app.schemas.app_schema import (
|
||||||
KnowledgeRetrievalConfig,
|
KnowledgeRetrievalConfig,
|
||||||
MemoryConfig,
|
MemoryConfig,
|
||||||
VariableDefinition,
|
VariableDefinition,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
AgentConfigCreate,
|
AgentConfigCreate,
|
||||||
AgentConfigUpdate,
|
AgentConfigUpdate, ToolOldConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -47,10 +47,7 @@ class AgentConfigConverter:
|
|||||||
|
|
||||||
# 5. 工具配置
|
# 5. 工具配置
|
||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools:
|
||||||
result["tools"] = {
|
result["tools"] = [tool.model_dump() for tool in config.tools]
|
||||||
name: tool.model_dump()
|
|
||||||
for name, tool in config.tools.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -60,7 +57,7 @@ class AgentConfigConverter:
|
|||||||
knowledge_retrieval: Optional[Dict[str, Any]],
|
knowledge_retrieval: Optional[Dict[str, Any]],
|
||||||
memory: Optional[Dict[str, Any]],
|
memory: Optional[Dict[str, Any]],
|
||||||
variables: Optional[list],
|
variables: Optional[list],
|
||||||
tools: Optional[Dict[str, Any]],
|
tools: Optional[Union[list, Dict[str, Any]]],
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
将数据库存储格式转换为 Pydantic 对象
|
将数据库存储格式转换为 Pydantic 对象
|
||||||
@@ -113,9 +110,12 @@ class AgentConfigConverter:
|
|||||||
|
|
||||||
# 5. 解析工具配置
|
# 5. 解析工具配置
|
||||||
if tools:
|
if tools:
|
||||||
result["tools"] = {
|
if isinstance(tools, list):
|
||||||
name: ToolConfig(**tool_data)
|
result["tools"] = [ToolConfig(**tool_config) for tool_config in tools]
|
||||||
for name, tool_data in tools.items()
|
else:
|
||||||
}
|
result["tools"] = {
|
||||||
|
name: ToolOldConfig(**tool_data)
|
||||||
|
for name, tool_data in tools.items()
|
||||||
|
}
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -78,13 +78,17 @@ class AppChatService:
|
|||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools:
|
||||||
for tool_id, tool_config in config.tools.items():
|
for tool_config in config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_id, ToolRepository.get_tenant_id_by_workspace_id(self.db, workspace_id))
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||||
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
|
self.db, workspace_id))
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
|
continue
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("config", {}).get("operation", None))
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
@@ -219,6 +223,23 @@ class AppChatService:
|
|||||||
# 准备工具列表
|
# 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
|
# 获取工具服务
|
||||||
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
|
if hasattr(config, 'tools') and config.tools:
|
||||||
|
for tool_config in config.tools:
|
||||||
|
if tool_config.get("enabled", False):
|
||||||
|
# 根据工具名称查找工具实例
|
||||||
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||||
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
|
self.db, workspace_id))
|
||||||
|
if tool_instance:
|
||||||
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
|
continue
|
||||||
|
# 转换为LangChain工具
|
||||||
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
knowledge_retrieval = config.knowledge_retrieval
|
knowledge_retrieval = config.knowledge_retrieval
|
||||||
if knowledge_retrieval:
|
if knowledge_retrieval:
|
||||||
@@ -237,20 +258,20 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
web_tools = config.tools
|
# web_tools = config.tools
|
||||||
web_search_choice = web_tools.get("web_search", {})
|
# web_search_choice = web_tools.get("web_search", {})
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
# web_search_enable = web_search_choice.get("enabled", False)
|
||||||
if web_search == True:
|
# if web_search == True:
|
||||||
if web_search_enable == True:
|
# if web_search_enable == True:
|
||||||
search_tool = create_web_search_tool({})
|
# search_tool = create_web_search_tool({})
|
||||||
tools.append(search_tool)
|
# tools.append(search_tool)
|
||||||
|
#
|
||||||
logger.debug(
|
# logger.debug(
|
||||||
"已添加网络搜索工具",
|
# "已添加网络搜索工具",
|
||||||
extra={
|
# extra={
|
||||||
"tool_count": len(tools)
|
# "tool_count": len(tools)
|
||||||
}
|
# }
|
||||||
)
|
# )
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|||||||
@@ -307,7 +307,7 @@ class AppService:
|
|||||||
knowledge_retrieval=storage_data.get("knowledge_retrieval"),
|
knowledge_retrieval=storage_data.get("knowledge_retrieval"),
|
||||||
memory=storage_data.get("memory"),
|
memory=storage_data.get("memory"),
|
||||||
variables=storage_data.get("variables", []),
|
variables=storage_data.get("variables", []),
|
||||||
tools=storage_data.get("tools", {}),
|
tools=storage_data.get("tools", []),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -689,7 +689,7 @@ class AppService:
|
|||||||
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
|
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
|
||||||
memory=source_config.memory.copy() if source_config.memory else None,
|
memory=source_config.memory.copy() if source_config.memory else None,
|
||||||
variables=source_config.variables.copy() if source_config.variables else [],
|
variables=source_config.variables.copy() if source_config.variables else [],
|
||||||
tools=source_config.tools.copy() if source_config.tools else {},
|
tools=source_config.tools.copy() if source_config.tools else [],
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -879,7 +879,7 @@ class AppService:
|
|||||||
# if data.variables is not None:
|
# if data.variables is not None:
|
||||||
agent_cfg.variables = storage_data.get("variables", [])
|
agent_cfg.variables = storage_data.get("variables", [])
|
||||||
# if data.tools is not None:
|
# if data.tools is not None:
|
||||||
agent_cfg.tools = storage_data.get("tools", {})
|
agent_cfg.tools = storage_data.get("tools", [])
|
||||||
|
|
||||||
agent_cfg.updated_at = now
|
agent_cfg.updated_at = now
|
||||||
|
|
||||||
@@ -966,7 +966,7 @@ class AppService:
|
|||||||
"max_history": 10
|
"max_history": 10
|
||||||
},
|
},
|
||||||
variables=[],
|
variables=[],
|
||||||
tools={},
|
tools=[],
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -1183,7 +1183,7 @@ class AppService:
|
|||||||
"knowledge_retrieval": agent_cfg.knowledge_retrieval,
|
"knowledge_retrieval": agent_cfg.knowledge_retrieval,
|
||||||
"memory": agent_cfg.memory,
|
"memory": agent_cfg.memory,
|
||||||
"variables": agent_cfg.variables or [],
|
"variables": agent_cfg.variables or [],
|
||||||
"tools": agent_cfg.tools or {},
|
"tools": agent_cfg.tools or [],
|
||||||
}
|
}
|
||||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||||
default_model_config_id = agent_cfg.default_model_config_id
|
default_model_config_id = agent_cfg.default_model_config_id
|
||||||
|
|||||||
@@ -298,16 +298,17 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
for tool_id, tool_config in agent_config.tools.items():
|
for tool_config in agent_config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
self.db, str(workspace_id)))
|
self.db, str(workspace_id)))
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
|
continue
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tool_config.get("config", {}).get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
@@ -507,16 +508,17 @@ class DraftRunService:
|
|||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
for tool_id, tool_config in agent_config.tools.items():
|
for tool_config in agent_config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service._get_tool_instance(tool_id,
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
self.db, str(workspace_id)))
|
self.db, str(workspace_id)))
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
|
continue
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tool_config.get("config", {}).get("operation", None))
|
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
|
|||||||
Reference in New Issue
Block a user