feat(agent tool): add agent tool plugin
This commit is contained in:
@@ -191,10 +191,14 @@ class BaseTool(ABC):
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def to_langchain_tool(self):
|
||||
"""转换为Langchain工具格式"""
|
||||
def to_langchain_tool(self, operation: Optional[str] = None):
|
||||
"""转换为Langchain工具格式
|
||||
|
||||
Args:
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
"""
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
return LangchainAdapter.convert_tool(self)
|
||||
return LangchainAdapter.convert_tool(self, operation)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"
|
||||
216
api/app/core/tools/builtin/operation_tool.py
Normal file
216
api/app/core/tools/builtin/operation_tool.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""操作工具 - 为特定操作创建的工具包装器"""
|
||||
from typing import List
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.models import ToolType
|
||||
|
||||
|
||||
class OperationTool(BaseTool):
|
||||
"""操作工具 - 包装基础工具的特定操作"""
|
||||
|
||||
def __init__(self, base_tool: BaseTool, operation: str):
|
||||
self.base_tool = base_tool
|
||||
self.operation = operation
|
||||
super().__init__(base_tool.tool_id, base_tool.config)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"{self.base_tool.name}_{self.operation}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.BUILTIN
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return f"{self.base_tool.description} - {self.operation}"
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""返回特定操作的参数"""
|
||||
if self.base_tool.name == 'datetime_tool':
|
||||
return self._get_datetime_params()
|
||||
elif self.base_tool.name == 'json_tool':
|
||||
return self._get_json_params()
|
||||
else:
|
||||
# 默认返回除operation外的所有参数
|
||||
return [p for p in self.base_tool.parameters if p.name != "operation"]
|
||||
|
||||
def _get_datetime_params(self) -> List[ToolParameter]:
|
||||
"""获取datetime_tool特定操作的参数"""
|
||||
if self.operation == "now":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "format":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
]
|
||||
elif self.operation == "convert_timezone":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="input_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="from_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
elif self.operation == "timestamp_to_datetime":
|
||||
return [
|
||||
ToolParameter(
|
||||
name="input_value",
|
||||
type=ParameterType.STRING,
|
||||
description="输入值(时间字符串或时间戳)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="output_format",
|
||||
type=ParameterType.STRING,
|
||||
description="输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
required=False,
|
||||
default="%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="Asia/Shanghai"
|
||||
)
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def _get_json_params(self) -> List[ToolParameter]:
|
||||
"""获取json_tool特定操作的参数"""
|
||||
base_params = [
|
||||
ToolParameter(
|
||||
name="input_data",
|
||||
type=ParameterType.STRING,
|
||||
description="输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
|
||||
if self.operation == "insert":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_value",
|
||||
type=ParameterType.STRING,
|
||||
description="新值(用于insert操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "replace":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="old_text",
|
||||
type=ParameterType.STRING,
|
||||
description="要替换的原文本(用于replace操作)",
|
||||
required=True
|
||||
),
|
||||
ToolParameter(
|
||||
name="new_text",
|
||||
type=ParameterType.STRING,
|
||||
description="替换后的新文本(用于replace操作)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "delete":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
elif self.operation == "parse":
|
||||
return base_params + [
|
||||
ToolParameter(
|
||||
name="json_path",
|
||||
type=ParameterType.STRING,
|
||||
description="JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
required=True
|
||||
)
|
||||
]
|
||||
else:
|
||||
return base_params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行特定操作"""
|
||||
# 添加operation参数
|
||||
kwargs["operation"] = self.operation
|
||||
return await self.base_tool.execute(**kwargs)
|
||||
@@ -1,4 +1,5 @@
|
||||
"""自定义工具基类"""
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
import aiohttp
|
||||
@@ -135,6 +136,13 @@ class CustomTool(BaseTool):
|
||||
|
||||
if not self.schema_content:
|
||||
return operations
|
||||
|
||||
if isinstance(self.schema_content, str):
|
||||
try:
|
||||
self.schema_content = json.loads(self.schema_content)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无效的OpenAPI schema: {self.schema_content}")
|
||||
return operations
|
||||
|
||||
paths = self.schema_content.get("paths", {})
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
name=tool_instance.name,
|
||||
description=tool_instance.description,
|
||||
args_schema=args_schema,
|
||||
_tool_instance=tool_instance,
|
||||
tool_instance=tool_instance,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class LangchainToolWrapper(LangchainBaseTool):
|
||||
"""异步执行工具"""
|
||||
try:
|
||||
# 执行内部工具
|
||||
result = await self._tool_instance.safe_execute(**kwargs)
|
||||
result = await self.tool_instance.safe_execute(**kwargs)
|
||||
|
||||
# 转换结果为Langchain格式
|
||||
return LangchainAdapter._format_result_for_langchain(result)
|
||||
@@ -73,24 +73,39 @@ class LangchainAdapter:
|
||||
"""Langchain适配器 - 负责工具格式转换和标准化"""
|
||||
|
||||
@staticmethod
|
||||
def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
|
||||
def convert_tool(tool: BaseTool, operation: Optional[str] = None) -> LangchainToolWrapper:
|
||||
"""将内部工具转换为Langchain工具
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
if operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||
# 为特定操作创建工具
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
logger.debug(f"工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
else:
|
||||
# 单个工具
|
||||
wrapper = LangchainToolWrapper(tool_instance=tool)
|
||||
logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
|
||||
return wrapper
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _create_operation_tool(base_tool: BaseTool, operation: str) -> BaseTool:
|
||||
"""为特定操作创建工具实例"""
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
@@ -110,7 +125,7 @@ class LangchainAdapter:
|
||||
except Exception as e:
|
||||
logger.error(f"跳过工具转换: {tool.name}, 错误: {e}")
|
||||
|
||||
logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
|
||||
logger.info(f"批量转换完成: {len(converted_tools)} 个工具")
|
||||
return converted_tools
|
||||
|
||||
@staticmethod
|
||||
@@ -169,9 +184,10 @@ class LangchainAdapter:
|
||||
"ToolArgsSchema",
|
||||
(BaseModel,),
|
||||
{
|
||||
"__module__": __name__,
|
||||
"__annotations__": annotations,
|
||||
**fields,
|
||||
"Config": type("Config", (), {"extra": "forbid"})
|
||||
"model_config": {"extra": "forbid"},
|
||||
**fields
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -16,14 +16,17 @@ logger = get_business_logger()
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
def __init__(self, db: Session = None):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
db: 数据库会话(可选)
|
||||
"""
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
if db:
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
else:
|
||||
self.connection_pool = None
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
@@ -592,7 +595,7 @@ class MCPServiceManager:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user