Merge pull request #56 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
feat(agent tool)
This commit is contained in:
@@ -21,24 +21,35 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
|
||||
# 内部工具实例
|
||||
tool_instance: BaseTool = Field(..., description="内部工具实例")
|
||||
# 特定操作(用于自定义工具)
|
||||
operation: Optional[str] = Field(None, description="特定操作")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, tool_instance: BaseTool, **kwargs):
|
||||
def __init__(self, tool_instance: BaseTool, operation: Optional[str] = None, **kwargs):
|
||||
"""初始化Langchain工具包装器
|
||||
|
||||
Args:
|
||||
tool_instance: 内部工具实例
|
||||
operation: 特定操作(用于自定义工具)
|
||||
"""
|
||||
# 动态创建参数schema
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
|
||||
args_schema = LangchainAdapter._create_pydantic_schema(
|
||||
tool_instance.parameters, operation
|
||||
)
|
||||
|
||||
# 构建工具名称
|
||||
tool_name = tool_instance.name
|
||||
if operation:
|
||||
tool_name = f"{tool_instance.name}_{operation}"
|
||||
|
||||
super().__init__(
|
||||
name=tool_instance.name,
|
||||
name=tool_name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
tool_instance=tool_instance,
|
||||
operation=operation,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -58,6 +69,10 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
) -> str:
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 如果有特定操作,添加到参数中
|
||||
if self.operation:
|
||||
kwargs["operation"] = self.operation
|
||||
|
||||
# 执行内部工具
|
||||
result = await self.tool_instance.safe_execute(**kwargs)
|
||||
|
||||
@@ -85,16 +100,21 @@ class LangchainAdapter:
|
||||
"""
|
||||
try:
|
||||
# 处理MCP工具的特定工具名称
|
||||
if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation:
|
||||
if hasattr(tool, 'tool_type') and tool.tool_type.value == "mcp" and operation:
|
||||
# 为MCP工具创建特定工具名称的实例
|
||||
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
|
||||
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
elif operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||
# 为特定操作创建工具
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
elif operation and LangchainAdapter._tool_supports_operations(tool):
|
||||
# 为支持多操作的工具创建特定操作实例
|
||||
if tool.tool_type.value == "custom":
|
||||
# 自定义工具直接传递operation参数
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool, operation=operation)
|
||||
else:
|
||||
# 内置工具使用OperationTool包装
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
else:
|
||||
@@ -108,23 +128,42 @@ class LangchainAdapter:
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||
"""为特定操作创建工具实例"""
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
def _tool_supports_operations(tool: BaseTool) -> bool:
|
||||
"""检查工具是否支持多操作"""
|
||||
# 内置工具中支持操作的工具
|
||||
builtin_operation_tools = ['datetime_tool', 'json_tool']
|
||||
|
||||
# 检查内置工具
|
||||
if tool.tool_type.value == "builtin" and tool.name in builtin_operation_tools:
|
||||
return True
|
||||
|
||||
# 检查自定义工具(自定义工具通过解析OpenAPI schema支持多操作)
|
||||
if tool.tool_type.value == "custom":
|
||||
# 检查工具是否有多个操作
|
||||
if hasattr(tool, '_parsed_operations') and len(tool._parsed_operations) > 1:
|
||||
return True
|
||||
# 或者检查参数中是否有operation参数
|
||||
for param in tool.parameters:
|
||||
if param.name == "operation" and param.enum:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool:
|
||||
"""为MCP工具创建指定工具名称的实例"""
|
||||
from app.core.tools.mcp.base import MCPTool
|
||||
|
||||
# 创建新的配置,指定具体工具名称
|
||||
new_config = base_tool.config.copy()
|
||||
new_config["tool_name"] = tool_name
|
||||
|
||||
# 创建新的MCP工具实例
|
||||
return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config)
|
||||
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||
"""为特定操作创建工具实例"""
|
||||
if base_tool.tool_type.value == "builtin":
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
else:
|
||||
raise ValueError(f"不支持的工具类型: {base_tool.tool_type.value}")
|
||||
|
||||
@staticmethod
|
||||
def _create_mcp_tool_with_name(mcp_tool: BaseTool, tool_name: str) -> BaseTool:
|
||||
"""为MCP工具创建指定工具名称的实例"""
|
||||
mcp_tool.set_current_tool(tool_name)
|
||||
return mcp_tool
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
@@ -148,11 +187,15 @@ class LangchainAdapter:
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
|
||||
def _create_pydantic_schema(
|
||||
parameters: List[ToolParameter],
|
||||
operation: Optional[str] = None
|
||||
) -> Type[BaseModel]:
|
||||
"""根据工具参数创建Pydantic schema
|
||||
|
||||
Args:
|
||||
parameters: 工具参数列表
|
||||
operation: 特定操作(用于过滤参数)
|
||||
|
||||
Returns:
|
||||
Pydantic模型类
|
||||
@@ -161,7 +204,12 @@ class LangchainAdapter:
|
||||
fields = {}
|
||||
annotations = {}
|
||||
|
||||
for param in parameters:
|
||||
# 如果指定了operation,过滤掉operation参数
|
||||
filtered_params = parameters
|
||||
if operation:
|
||||
filtered_params = [p for p in parameters if p.name != "operation"]
|
||||
|
||||
for param in filtered_params:
|
||||
# 确定Python类型
|
||||
python_type = LangchainAdapter._get_python_type(param.type)
|
||||
|
||||
|
||||
@@ -16,19 +16,15 @@ class MCPTool(BaseTool):
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.tool_name = config.get("tool_name", "") # 特定工具名称
|
||||
self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}"
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
if self.tool_schema.get("description"):
|
||||
return self.tool_schema["description"]
|
||||
return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}"
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
@@ -36,20 +32,36 @@ class MCPTool(BaseTool):
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""从 MCP 工具 schema 生成参数"""
|
||||
if not self.tool_schema:
|
||||
return [ToolParameter(
|
||||
"""根据工具名称返回对应参数"""
|
||||
# 如果有指定的工具名称,从 available_tools 中获取参数
|
||||
tool_name = getattr(self, '_current_tool_name', None)
|
||||
if tool_name and self.available_tools:
|
||||
for tool_info in self.available_tools:
|
||||
if tool_info.get("tool_name") == tool_name:
|
||||
arguments = tool_info.get("arguments", {})
|
||||
return self._generate_parameters_from_schema(arguments)
|
||||
|
||||
# 默认返回通用参数
|
||||
return [
|
||||
ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要执行的工具名称",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数",
|
||||
required=False,
|
||||
default={}
|
||||
)]
|
||||
|
||||
# 解析 MCP 工具的 inputSchema
|
||||
input_schema = self.tool_schema.get("inputSchema", {})
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_parameters_from_schema(self, arguments: Dict[str, Any]) -> List[ToolParameter]:
|
||||
"""从参数schema生成参数列表"""
|
||||
properties = arguments.get("properties", {})
|
||||
required_fields = arguments.get("required", [])
|
||||
|
||||
params = []
|
||||
for param_name, param_def in properties.items():
|
||||
@@ -69,7 +81,7 @@ class MCPTool(BaseTool):
|
||||
return params
|
||||
|
||||
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
|
||||
"""转换 JSON Schema 类型到 ParameterType"""
|
||||
"""转换JSON Schema类型到ParameterType"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
"integer": ParameterType.INTEGER,
|
||||
@@ -80,25 +92,27 @@ class MCPTool(BaseTool):
|
||||
}
|
||||
return type_mapping.get(json_type, ParameterType.STRING)
|
||||
|
||||
def set_current_tool(self, tool_name: str):
|
||||
"""设置当前工具名称,用于获取特定参数"""
|
||||
self._current_tool_name = tool_name
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行 MCP 工具"""
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name:
|
||||
raise Exception("未指定工具名称")
|
||||
|
||||
arguments = kwargs.get("arguments", {})
|
||||
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
client = SimpleMCPClient(self.server_url, self.connection_config)
|
||||
|
||||
async with client:
|
||||
# 使用指定的工具名称或默认第一个工具
|
||||
tool_name_to_use = self.tool_name
|
||||
if not tool_name_to_use and self.available_tools:
|
||||
tool_name_to_use = self.available_tools[0]
|
||||
|
||||
if not tool_name_to_use:
|
||||
raise Exception("未指定工具名称且无可用工具")
|
||||
|
||||
result = await client.call_tool(tool_name_to_use, kwargs)
|
||||
result = await client.call_tool(tool_name, arguments)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
@@ -108,7 +122,7 @@ class MCPTool(BaseTool):
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}")
|
||||
logger.error(f"MCP工具执行失败: {kwargs.get('tool_name', 'unknown')}, 错误: {e}")
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_EXECUTION_ERROR",
|
||||
|
||||
@@ -154,7 +154,7 @@ class MCPToolConfigSchema(BaseModel):
|
||||
last_health_check: Optional[datetime] = None
|
||||
health_status: str = "unknown"
|
||||
error_message: Optional[str] = None
|
||||
available_tools: List[str] = Field(default_factory=list)
|
||||
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
@@ -42,7 +42,7 @@ class ToolService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, BaseTool] = {}
|
||||
|
||||
|
||||
# MCP管理器
|
||||
self.mcp_tool_manager = MCPToolManager(db)
|
||||
|
||||
@@ -691,34 +691,35 @@ class ToolService:
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
if success:
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
mcp_config.available_tools = tool_names
|
||||
# 转换为新格式
|
||||
tool_list = []
|
||||
for tool in tools:
|
||||
if tool.get("name"):
|
||||
tool_list.append({
|
||||
tool["name"]: {
|
||||
"description": tool.get("description", ""),
|
||||
"inputSchema": tool.get("inputSchema", {})
|
||||
}
|
||||
})
|
||||
mcp_config.available_tools = tool_list
|
||||
self.db.commit()
|
||||
available_tools = tool_names
|
||||
available_tools = tool_list
|
||||
except Exception as e:
|
||||
logger.error(f"同步MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
methods = []
|
||||
|
||||
# 获取工具详细信息
|
||||
try:
|
||||
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
if success:
|
||||
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
|
||||
|
||||
for tool_name in available_tools:
|
||||
tool_info = tools_dict.get(tool_name, {})
|
||||
|
||||
|
||||
# 处理新格式的available_tools
|
||||
for tool_item in available_tools:
|
||||
if isinstance(tool_item, dict):
|
||||
for tool_name, tool_data in tool_item.items():
|
||||
# 解析工具参数
|
||||
parameters = []
|
||||
input_schema = tool_info.get("inputSchema", {})
|
||||
input_schema = tool_data.get("inputSchema", {})
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
|
||||
for param_name, param_def in properties.items():
|
||||
parameters.append({
|
||||
"name": param_name,
|
||||
@@ -730,27 +731,16 @@ class ToolService:
|
||||
"minimum": param_def.get("minimum"),
|
||||
"maximum": param_def.get("maximum")
|
||||
})
|
||||
|
||||
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
|
||||
"description": tool_data.get("description", f"MCP工具: {tool_name}"),
|
||||
"parameters": parameters
|
||||
})
|
||||
else:
|
||||
# 如果无法获取详细信息,返回基本信息
|
||||
for tool_name in available_tools:
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": f"MCP工具: {tool_name}",
|
||||
"parameters": []
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具详细信息失败: {e}")
|
||||
# 返回基本信息
|
||||
for tool_name in available_tools:
|
||||
# 兼容旧格式(字符串)
|
||||
tool_name = str(tool_item)
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
@@ -877,14 +867,10 @@ class ToolService:
|
||||
if not mcp_config:
|
||||
return None
|
||||
|
||||
# 从配置中获取特定工具名称
|
||||
tool_name = config.config_data.get("tool_name")
|
||||
|
||||
tool_config = {
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"tool_name": tool_name # 指定具体工具
|
||||
"available_tools": mcp_config.available_tools or []
|
||||
}
|
||||
|
||||
return MCPTool(str(config.id), tool_config)
|
||||
@@ -897,10 +883,18 @@ class ToolService:
|
||||
if config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if mcp_config:
|
||||
# 处理available_tools显示格式
|
||||
available_tools_display = []
|
||||
for tool_item in (mcp_config.available_tools or []):
|
||||
if isinstance(tool_item, dict):
|
||||
available_tools_display.extend(list(tool_item.keys()))
|
||||
else:
|
||||
available_tools_display.append(str(tool_item))
|
||||
|
||||
config_data.update({
|
||||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||||
"health_status": mcp_config.health_status,
|
||||
"available_tools": mcp_config.available_tools or []
|
||||
"available_tools": available_tools_display
|
||||
})
|
||||
|
||||
return ToolInfo(
|
||||
@@ -1150,25 +1144,36 @@ class ToolService:
|
||||
test_result = await self.mcp_tool_manager.test_tool_connection(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
|
||||
if test_result["success"]:
|
||||
# 连接成功,自动同步工具列表
|
||||
success, tools, error = await self.mcp_tool_manager.discover_tools(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
|
||||
if success:
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
# 转换为新格式
|
||||
tool_list = []
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
if tool.get("name"):
|
||||
tool_names.append(tool["name"])
|
||||
tool_list.append({
|
||||
tool["name"]: {
|
||||
"description": tool.get("description", ""),
|
||||
"inputSchema": tool.get("inputSchema", {})
|
||||
}
|
||||
})
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.available_tools = tool_list
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功并同步工具列表",
|
||||
@@ -1187,9 +1192,9 @@ class ToolService:
|
||||
mcp_config.error_message = test_result.get("error", "连接失败")
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
|
||||
return test_result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
@@ -1248,30 +1253,42 @@ class ToolService:
|
||||
# 创建MCP客户端
|
||||
connection_config = mcp_config.connection_config or {}
|
||||
client = SimpleMCPClient(mcp_config.server_url, connection_config)
|
||||
|
||||
|
||||
async with client:
|
||||
# 获取工具列表
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
|
||||
# 转换为新格式
|
||||
tool_list = []
|
||||
tool_names = []
|
||||
for tool in tools:
|
||||
if tool.get("name"):
|
||||
tool_names.append(tool["name"])
|
||||
tool_list.append({
|
||||
tool["name"]: {
|
||||
"description": tool.get("description", ""),
|
||||
"inputSchema": tool.get("inputSchema", {})
|
||||
}
|
||||
})
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.available_tools = tool_list
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
|
||||
# 更新工具状态
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "工具列表同步成功",
|
||||
"tools_count": len(tool_names),
|
||||
"tools": tool_names
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 更新错误状态
|
||||
try:
|
||||
@@ -1284,7 +1301,7 @@ class ToolService:
|
||||
self.db.commit()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
||||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user