Merge pull request #41 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

Feature/agent tool xjn
This commit is contained in:
Mark
2026-01-06 20:30:55 +08:00
committed by GitHub
15 changed files with 413 additions and 106 deletions

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
@@ -26,4 +27,9 @@ def get_workspace_list(
):
"""获取工作空间列表"""
workspace_list = HomePageService.get_workspace_list(db, current_user.tenant_id)
return success(data=workspace_list, msg="工作空间列表获取成功")
return success(data=workspace_list, msg="工作空间列表获取成功")
@router.get("/version", response_model=ApiResponse)
def get_system_version():
"""获取系统版本号"""
return success(data={"version": settings.SYSTEM_VERSION}, msg="系统版本获取成功")

View File

@@ -433,7 +433,8 @@ async def chat(
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -469,7 +470,8 @@ async def chat(
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -486,8 +488,8 @@ async def chat(
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event

View File

@@ -154,7 +154,8 @@ async def chat(
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -178,7 +179,8 @@ async def chat(
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -195,8 +197,8 @@ async def chat(
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -212,7 +214,6 @@ async def chat(
# 多 Agent 非流式返回
result = await app_chat_service.multi_agent_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
@@ -291,4 +292,4 @@ async def chat(
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
pass

View File

@@ -7,7 +7,6 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
@@ -97,8 +96,7 @@ class LangChainAgent:
"temperature": temperature,
"streaming": streaming,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
"tool_count": len(self.tools)
"tool_names": [tool.name for tool in self.tools] if self.tools else []
}
)
@@ -139,8 +137,11 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
"""
短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j
"""
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
@@ -154,6 +155,7 @@ class LangChainAgent:
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)

View File

@@ -164,6 +164,9 @@ class Settings:
TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60"))
TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10"))
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v1.0.0")
def get_memory_output_path(self, filename: str = "") -> str:
"""

View File

@@ -191,10 +191,14 @@ class BaseTool(ABC):
execution_time=execution_time
)
def to_langchain_tool(self):
"""转换为Langchain工具格式"""
def to_langchain_tool(self, operation: Optional[str] = None):
"""转换为Langchain工具格式
Args:
operation: 特定操作(适用于有操作的工具)
"""
from app.core.tools.langchain_adapter import LangchainAdapter
return LangchainAdapter.convert_tool(self)
return LangchainAdapter.convert_tool(self, operation)
def __repr__(self):
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,216 @@
"""操作工具 - 为特定操作创建的工具包装器"""
from typing import List
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.models import ToolType
class OperationTool(BaseTool):
"""操作工具 - 包装基础工具的特定操作"""
def __init__(self, base_tool: BaseTool, operation: str):
self.base_tool = base_tool
self.operation = operation
super().__init__(base_tool.tool_id, base_tool.config)
@property
def name(self) -> str:
return f"{self.base_tool.name}_{self.operation}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.BUILTIN
@property
def description(self) -> str:
return f"{self.base_tool.description} - {self.operation}"
@property
def parameters(self) -> List[ToolParameter]:
"""返回特定操作的参数"""
if self.base_tool.name == 'datetime_tool':
return self._get_datetime_params()
elif self.base_tool.name == 'json_tool':
return self._get_json_params()
else:
# 默认返回除operation外的所有参数
return [p for p in self.base_tool.parameters if p.name != "operation"]
def _get_datetime_params(self) -> List[ToolParameter]:
"""获取datetime_tool特定操作的参数"""
if self.operation == "now":
return [
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "format":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
)
]
elif self.operation == "convert_timezone":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="input_format",
type=ParameterType.STRING,
description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="from_timezone",
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
elif self.operation == "timestamp_to_datetime":
return [
ToolParameter(
name="input_value",
type=ParameterType.STRING,
description="输入值(时间字符串或时间戳)",
required=True
),
ToolParameter(
name="output_format",
type=ParameterType.STRING,
description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
required=False,
default="%Y-%m-%d %H:%M:%S"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="Asia/Shanghai"
)
]
else:
return []
def _get_json_params(self) -> List[ToolParameter]:
"""获取json_tool特定操作的参数"""
base_params = [
ToolParameter(
name="input_data",
type=ParameterType.STRING,
description="输入数据JSON字符串、YAML字符串或XML字符串",
required=True
)
]
if self.operation == "insert":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="new_value",
type=ParameterType.STRING,
description="新值用于insert操作",
required=True
)
]
elif self.operation == "replace":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
),
ToolParameter(
name="old_text",
type=ParameterType.STRING,
description="要替换的原文本用于replace操作",
required=True
),
ToolParameter(
name="new_text",
type=ParameterType.STRING,
description="替换后的新文本用于replace操作",
required=True
)
]
elif self.operation == "delete":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
elif self.operation == "parse":
return base_params + [
ToolParameter(
name="json_path",
type=ParameterType.STRING,
description="JSON路径表达式$.user.name或users[0].name",
required=True
)
]
else:
return base_params
async def execute(self, **kwargs) -> ToolResult:
"""执行特定操作"""
# 添加operation参数
kwargs["operation"] = self.operation
return await self.base_tool.execute(**kwargs)

View File

@@ -1,4 +1,5 @@
"""自定义工具基类"""
import json
import time
from typing import Dict, Any, List, Optional
import aiohttp
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
if not self.schema_content:
return operations
if isinstance(self.schema_content, str):
try:
self.schema_content = json.loads(self.schema_content)
except json.JSONDecodeError:
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
return operations
paths = self.schema_content.get("paths", {})

View File

@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
name=tool_instance.name,
description=tool_instance.description,
args_schema=args_schema,
_tool_instance=tool_instance,
tool_instance=tool_instance,
**kwargs
)
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
"""异步执行工具"""
try:
# 执行内部工具
result = await self._tool_instance.safe_execute(**kwargs)
result = await self.tool_instance.safe_execute(**kwargs)
# 转换结果为Langchain格式
return LangchainAdapter._format_result_for_langchain(result)
@@ -73,24 +73,39 @@ class LangchainAdapter:
"""Langchain适配器 - 负责工具格式转换和标准化"""
@staticmethod
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
"""将内部工具转换为Langchain工具
Args:
tool: 内部工具实例
operation: 特定操作(适用于有操作的工具)
Returns:
Langchain兼容的工具包装器
"""
try:
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
if operation and tool.name in ['datetime_tool', 'json_tool']:
# 为特定操作创建工具
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
else:
# 单个工具
wrapper = LangchainToolWrapper(tool_instance=tool)
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
return wrapper
except Exception as e:
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
raise
@staticmethod
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
"""为特定操作创建工具实例"""
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具
@@ -110,7 +125,7 @@ class LangchainAdapter:
except Exception as e:
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
return converted_tools
@staticmethod
@@ -169,9 +184,10 @@ class LangchainAdapter:
"ToolArgsSchema",
(BaseModel,),
{
"__module__": __name__,
"__annotations__": annotations,
**fields,
"Config": type("Config", (), {"extra": "forbid"})
"model_config": {"extra": "forbid"},
**fields
}
)

View File

@@ -16,14 +16,17 @@ logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session):
def __init__(self, db: Session = None):
"""初始化MCP服务管理器
Args:
db: 数据库会话
db: 数据库会话(可选)
"""
self.db = db
self.connection_pool = MCPConnectionPool(max_connections=20)
if db:
self.connection_pool = MCPConnectionPool(max_connections=20)
else:
self.connection_pool = None
# 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
@@ -592,7 +595,7 @@ class MCPServiceManager:
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {

View File

@@ -1,6 +1,6 @@
import datetime
import uuid
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List, Dict, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -36,6 +36,12 @@ class KnowledgeRetrievalConfig(BaseModel):
class ToolConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
tool_id: str = Field(default=None, description="工具ID")
operation: Optional[str] = Field(default_factory=dict, description="工具特定配置")
class ToolOldConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")
config: Optional[Dict[str, Any]] = Field(default_factory=dict, description="工具特定配置")
@@ -103,9 +109,9 @@ class AgentConfigCreate(BaseModel):
)
# 工具配置
tools: Dict[str, ToolConfig] = Field(
default_factory=dict,
description="工具配置key 为工具名称web_search, code_interpreter, image_generation 等)"
tools: List[ToolConfig] = Field(
default_factory=list,
description="Agent 可用的工具列表"
)
@@ -158,7 +164,7 @@ class AgentConfigUpdate(BaseModel):
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
# 工具配置
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
tools: Optional[List[ToolConfig]] = Field(default=None, description="工具列表")
# ---------- Output Schemas ----------
@@ -216,7 +222,7 @@ class AgentConfig(BaseModel):
variables: List[VariableDefinition] = []
# 工具配置
tools: Dict[str, ToolConfig] = {}
tools: Union[List[ToolConfig], Dict[str, ToolOldConfig]] = []
is_active: bool
created_at: datetime.datetime

View File

@@ -2,14 +2,14 @@
Agent 配置格式转换器
用于将 Pydantic 模型转换为数据库存储格式
"""
from typing import Dict, Any, Optional
from typing import Dict, Any, Optional, Union
from app.schemas.app_schema import (
KnowledgeRetrievalConfig,
MemoryConfig,
VariableDefinition,
ToolConfig,
AgentConfigCreate,
AgentConfigUpdate,
AgentConfigUpdate, ToolOldConfig,
)
@@ -47,10 +47,7 @@ class AgentConfigConverter:
# 5. 工具配置
if hasattr(config, 'tools') and config.tools:
result["tools"] = {
name: tool.model_dump()
for name, tool in config.tools.items()
}
result["tools"] = [tool.model_dump() for tool in config.tools]
return result
@@ -60,7 +57,7 @@ class AgentConfigConverter:
knowledge_retrieval: Optional[Dict[str, Any]],
memory: Optional[Dict[str, Any]],
variables: Optional[list],
tools: Optional[Dict[str, Any]],
tools: Optional[Union[list, Dict[str, Any]]],
) -> Dict[str, Any]:
"""
将数据库存储格式转换为 Pydantic 对象
@@ -113,9 +110,12 @@ class AgentConfigConverter:
# 5. 解析工具配置
if tools:
result["tools"] = {
name: ToolConfig(**tool_data)
for name, tool_data in tools.items()
}
if isinstance(tools, list):
result["tools"] = [ToolConfig(**tool_config) for tool_config in tools]
else:
result["tools"] = {
name: ToolOldConfig(**tool_data)
for name, tool_data in tools.items()
}
return result

View File

@@ -14,6 +14,10 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.db import get_db, get_db_context
from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig
from app.services.tool_service import ToolService
from app.repositories.tool_repository import ToolRepository
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
@@ -43,6 +47,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None
) -> Dict[str, Any]:
"""聊天(非流式)"""
@@ -68,6 +73,24 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
@@ -86,7 +109,7 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.tools
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
@@ -173,6 +196,7 @@ class AppChatService:
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
@@ -199,6 +223,23 @@ class AppChatService:
# 准备工具列表
tools = []
# 获取工具服务
tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools:
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, workspace_id))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
knowledge_retrieval = config.knowledge_retrieval
if knowledge_retrieval:
@@ -217,20 +258,20 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
# 获取模型参数
model_parameters = config.model_parameters

View File

@@ -307,7 +307,7 @@ class AppService:
knowledge_retrieval=storage_data.get("knowledge_retrieval"),
memory=storage_data.get("memory"),
variables=storage_data.get("variables", []),
tools=storage_data.get("tools", {}),
tools=storage_data.get("tools", []),
is_active=True,
created_at=now,
updated_at=now,
@@ -689,7 +689,7 @@ class AppService:
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
memory=source_config.memory.copy() if source_config.memory else None,
variables=source_config.variables.copy() if source_config.variables else [],
tools=source_config.tools.copy() if source_config.tools else {},
tools=source_config.tools.copy() if source_config.tools else [],
is_active=True,
created_at=now,
updated_at=now,
@@ -879,7 +879,7 @@ class AppService:
# if data.variables is not None:
agent_cfg.variables = storage_data.get("variables", [])
# if data.tools is not None:
agent_cfg.tools = storage_data.get("tools", {})
agent_cfg.tools = storage_data.get("tools", [])
agent_cfg.updated_at = now
@@ -966,7 +966,7 @@ class AppService:
"max_history": 10
},
variables=[],
tools={},
tools=[],
is_active=True,
created_at=now,
updated_at=now,
@@ -1183,7 +1183,7 @@ class AppService:
"knowledge_retrieval": agent_cfg.knowledge_retrieval,
"memory": agent_cfg.memory,
"variables": agent_cfg.variables or [],
"tools": agent_cfg.tools or {},
"tools": agent_cfg.tools or [],
}
# config = AgentConfigConverter.from_storage_format(agent_cfg)
default_model_config_id = agent_cfg.default_model_config_id

View File

@@ -10,19 +10,22 @@ import time
import uuid
from typing import Any, AsyncGenerator, Dict, List, Optional
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval
from app.models import AgentConfig, ModelApiKey, ModelConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message
from app.services.langchain_tool_server import Search
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_parameter_merger import ModelParameterMerger
from langchain.tools import tool
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.services.tool_service import ToolService
logger = get_business_logger()
class KnowledgeRetrievalInput(BaseModel):
@@ -291,24 +294,22 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -503,24 +504,22 @@ class DraftRunService:
# 4. 准备工具列表
tools = []
# 添加网络搜索工具
if web_search:
if agent_config.tools:
web_search_config = agent_config.tools.get("web_search", {})
web_search_enable = web_search_config.get("enabled", False)
tool_service = ToolService(self.db)
if web_search_enable:
logger.info("网络搜索已启用")
# 创建网络搜索工具
search_tool = create_web_search_tool(web_search_config)
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
# 添加知识库检索工具
if agent_config.knowledge_retrieval: