- Refactor the extraction logic in tool service to flatten request body parameters into independent arguments exposed to the model.
1654 lines
68 KiB
Python
1654 lines
68 KiB
Python
"""工具服务 - 统一的工具管理和执行服务"""
|
||
import json
|
||
import uuid
|
||
import time
|
||
import importlib
|
||
from typing import Dict, Any, List, Optional
|
||
from datetime import datetime
|
||
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.error_codes import BizCode
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
|
||
from app.repositories.tool_repository import (
|
||
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
||
MCPToolRepository, ToolExecutionRepository
|
||
)
|
||
|
||
from app.models.tool_model import (
|
||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||
ToolExecution, ToolType, ToolStatus, ExecutionStatus, AuthType
|
||
)
|
||
from app.schemas.tool_schema import ToolInfo, ToolResult
|
||
from app.core.logging_config import get_business_logger
|
||
from app.core.tools.base import BaseTool
|
||
from app.core.tools.custom.base import CustomTool
|
||
from app.core.tools.mcp.base import MCPTool
|
||
|
||
logger = get_business_logger()
|
||
|
||
# 内置工具映射
|
||
BUILTIN_TOOLS = {
|
||
"DateTimeTool": "app.core.tools.builtin.datetime_tool",
|
||
"JsonTool": "app.core.tools.builtin.json_tool",
|
||
"BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool",
|
||
"MinerUTool": "app.core.tools.builtin.mineru_tool",
|
||
"TextInTool": "app.core.tools.builtin.textin_tool",
|
||
"OpenClawTool": "app.core.tools.builtin.openclaw_tool",
|
||
}
|
||
|
||
|
||
class ToolService:
|
||
"""统一工具服务 - 管理工具的完整生命周期"""
|
||
|
||
def __init__(self, db: Session):
|
||
self.db = db
|
||
self._tool_cache: Dict[str, BaseTool] = {}
|
||
|
||
# MCP管理器
|
||
self.mcp_tool_manager = MCPToolManager(db)
|
||
|
||
# 初始化仓储
|
||
self.tool_repo = ToolRepository()
|
||
self.builtin_repo = BuiltinToolRepository()
|
||
self.custom_repo = CustomToolRepository()
|
||
self.mcp_repo = MCPToolRepository()
|
||
self.execution_repo = ToolExecutionRepository()
|
||
|
||
def list_tools(
|
||
self,
|
||
tenant_id: uuid.UUID,
|
||
name: Optional[str] = None,
|
||
tool_type: Optional[ToolType] = None,
|
||
status: Optional[ToolStatus] = None
|
||
) -> List[ToolInfo]:
|
||
"""获取工具列表"""
|
||
try:
|
||
configs = self.tool_repo.find_by_tenant(
|
||
db=self.db,
|
||
tenant_id=tenant_id,
|
||
name=name,
|
||
tool_type=tool_type,
|
||
status=status
|
||
)
|
||
return [self._config_to_info(config) for config in configs]
|
||
except Exception as e:
|
||
logger.error(f"获取工具列表失败: {e}")
|
||
return []
|
||
|
||
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
|
||
"""获取工具详情"""
|
||
config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||
return self._config_to_info(config) if config else None
|
||
|
||
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
|
||
"""检查工具名称是否重复"""
|
||
query = self.db.query(ToolConfig).filter(
|
||
ToolConfig.name == name,
|
||
ToolConfig.tool_type == tool_type,
|
||
ToolConfig.tenant_id == tenant_id
|
||
)
|
||
if exclude_id:
|
||
query = query.filter(ToolConfig.id != exclude_id)
|
||
if query.first():
|
||
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||
|
||
def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]):
|
||
"""检查MCP工具是否重复:市场来源按market_id+market_config_id+mcp_service_id判断(名称无关),自建按name+tool_type判断"""
|
||
from app.models.tool_model import MCPSourceChannel
|
||
source_channel = config.get("source_channel")
|
||
is_market_source = (
|
||
source_channel is not None
|
||
and source_channel != MCPSourceChannel.SELF_HOSTED
|
||
)
|
||
if is_market_source:
|
||
exists = (
|
||
self.db.query(ToolConfig)
|
||
.join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id)
|
||
.filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.tool_type == tool_type,
|
||
MCPToolConfig.source_channel == source_channel,
|
||
MCPToolConfig.market_id == config.get("market_id"),
|
||
MCPToolConfig.market_config_id == config.get("market_config_id"),
|
||
MCPToolConfig.mcp_service_id == config.get("mcp_service_id"),
|
||
)
|
||
.first()
|
||
)
|
||
if exists:
|
||
raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME)
|
||
else:
|
||
exists = (
|
||
self.db.query(ToolConfig)
|
||
.filter(
|
||
ToolConfig.name == name,
|
||
ToolConfig.tool_type == tool_type,
|
||
ToolConfig.tenant_id == tenant_id,
|
||
)
|
||
.first()
|
||
)
|
||
if exists:
|
||
raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||
|
||
async def create_tool(
|
||
self,
|
||
name: str,
|
||
tool_type: ToolType,
|
||
tenant_id: uuid.UUID,
|
||
icon: Optional[str] = None,
|
||
description: Optional[str] = None,
|
||
config: Optional[Dict[str, Any]] = None,
|
||
tags: Optional[List[str]] = None
|
||
) -> str:
|
||
"""创建工具"""
|
||
if tool_type == ToolType.BUILTIN:
|
||
raise ValueError("内置工具不允许创建")
|
||
|
||
cfg = config or {}
|
||
if tool_type == ToolType.MCP:
|
||
self._check_mcp_duplicate(name, tool_type, tenant_id, cfg)
|
||
# 创建前测试连接
|
||
test_result = await self._test_mcp_connection_by_config(cfg)
|
||
if not test_result["success"]:
|
||
raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER)
|
||
# 将发现的工具列表写回 config
|
||
if "available_tools" in test_result:
|
||
cfg["available_tools"] = test_result["available_tools"]
|
||
else:
|
||
self._check_name_duplicate(name, tool_type, tenant_id)
|
||
|
||
try:
|
||
# 创建基础配置
|
||
tool_config = ToolConfig(
|
||
name=name,
|
||
description=description,
|
||
icon=icon,
|
||
tool_type=tool_type.value,
|
||
tenant_id=tenant_id,
|
||
status=ToolStatus.AVAILABLE.value,
|
||
config_data=cfg,
|
||
tags=tags
|
||
)
|
||
self.db.add(tool_config)
|
||
self.db.flush()
|
||
|
||
# 创建类型特定配置
|
||
self._create_type_config(tool_config, cfg)
|
||
|
||
self.db.commit()
|
||
logger.info(f"工具创建成功: {tool_config.id}")
|
||
return str(tool_config.id)
|
||
|
||
except BusinessException:
|
||
self.db.rollback()
|
||
raise
|
||
except Exception as e:
|
||
self.db.rollback()
|
||
logger.error(f"创建工具失败: {e}")
|
||
raise
|
||
|
||
def update_tool(
|
||
self,
|
||
tool_id: str,
|
||
tenant_id: uuid.UUID,
|
||
name: Optional[str] = None,
|
||
description: Optional[str] = None,
|
||
icon: Optional[str] = None,
|
||
config: Optional[Dict[str, Any]] = None,
|
||
is_enabled: Optional[bool] = None,
|
||
tags: Optional[List[str]] = None
|
||
) -> bool:
|
||
"""更新工具"""
|
||
config_obj = self._get_tool_config(tool_id, tenant_id)
|
||
if not config_obj:
|
||
return False
|
||
|
||
if config_obj.tool_type == ToolType.BUILTIN.value:
|
||
if name or description or icon:
|
||
raise ValueError("内置工具不允许修改名称、描述和图标")
|
||
try:
|
||
if name:
|
||
self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id)
|
||
config_obj.name = name
|
||
if description:
|
||
config_obj.description = description
|
||
if icon:
|
||
config_obj.icon = icon
|
||
if tags:
|
||
config_obj.tags = tags
|
||
if config:
|
||
config_obj.config_data = config.copy()
|
||
|
||
# 同步到类型表
|
||
self._sync_type_config(config_obj, config, is_enabled)
|
||
|
||
# 更新状态逻辑
|
||
self._update_tool_status(config_obj)
|
||
|
||
# 清除缓存
|
||
self._clear_tool_cache(tool_id)
|
||
|
||
self.db.commit()
|
||
return True
|
||
|
||
except Exception as e:
|
||
self.db.rollback()
|
||
logger.error(f"更新工具失败: {tool_id}, {e}")
|
||
return False
|
||
|
||
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
|
||
"""删除工具(逻辑删除)"""
|
||
config = self._get_tool_config(tool_id, tenant_id)
|
||
if not config:
|
||
return False
|
||
|
||
if config.tool_type == ToolType.BUILTIN.value:
|
||
raise ValueError("内置工具不允许删除")
|
||
|
||
try:
|
||
config.is_active = False
|
||
self._clear_tool_cache(tool_id)
|
||
self.db.commit()
|
||
return True
|
||
except Exception as e:
|
||
self.db.rollback()
|
||
logger.error(f"删除工具失败: {tool_id}, {e}")
|
||
return False
|
||
|
||
def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool:
|
||
"""设置工具可用状态(启用/禁用)"""
|
||
# 直接查询,包含 is_active=False 的记录
|
||
config = self.db.query(ToolConfig).filter(
|
||
ToolConfig.id == uuid.UUID(tool_id),
|
||
ToolConfig.tenant_id == tenant_id
|
||
).first()
|
||
if not config:
|
||
return False
|
||
if config.tool_type == ToolType.BUILTIN.value:
|
||
raise ValueError("内置工具不允许修改可用状态")
|
||
try:
|
||
config.is_active = is_active
|
||
self._clear_tool_cache(tool_id)
|
||
self.db.commit()
|
||
return True
|
||
except Exception as e:
|
||
self.db.rollback()
|
||
logger.error(f"设置工具状态失败: {tool_id}, {e}")
|
||
return False
|
||
|
||
async def execute_tool(
|
||
self,
|
||
tool_id: str,
|
||
parameters: Dict[str, Any],
|
||
tenant_id: uuid.UUID,
|
||
user_id: Optional[uuid.UUID] = None,
|
||
workspace_id: Optional[uuid.UUID] = None,
|
||
timeout: float = 60.0
|
||
) -> ToolResult:
|
||
"""执行工具"""
|
||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||
start_time = time.time()
|
||
|
||
try:
|
||
# 获取工具实例
|
||
tool = self.get_tool_instance(tool_id, tenant_id)
|
||
if not tool:
|
||
return ToolResult.error_result(
|
||
error=f"工具不存在: {tool_id}",
|
||
execution_time=time.time() - start_time
|
||
)
|
||
|
||
# 记录执行开始
|
||
self._record_execution_start(
|
||
execution_id, tool_id, parameters, user_id, workspace_id
|
||
)
|
||
|
||
# 执行工具
|
||
result = await tool.safe_execute(**parameters)
|
||
|
||
# 记录执行完成
|
||
self._record_execution_complete(execution_id, result)
|
||
|
||
return result
|
||
|
||
except Exception as e:
|
||
execution_time = time.time() - start_time
|
||
error_result = ToolResult.error_result(
|
||
error=str(e),
|
||
execution_time=execution_time
|
||
)
|
||
self._record_execution_complete(execution_id, error_result)
|
||
return error_result
|
||
|
||
async def test_connection(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||
"""测试工具连接"""
|
||
try:
|
||
config = self._get_tool_config(tool_id, tenant_id)
|
||
if not config:
|
||
return {"success": False, "message": "工具不存在"}
|
||
|
||
if config.tool_type == ToolType.MCP.value:
|
||
return await self._test_mcp_connection(config)
|
||
elif config.tool_type == ToolType.CUSTOM.value:
|
||
return await self._test_custom_connection(config)
|
||
elif config.tool_type == ToolType.BUILTIN.value:
|
||
return await self._test_builtin_connection(config)
|
||
else:
|
||
return {"success": True, "message": "未知工具类型"}
|
||
|
||
except Exception as e:
|
||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||
|
||
def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID):
|
||
"""确保内置工具已初始化(支持增量补充新工具)"""
|
||
builtin_config = self._load_builtin_config()
|
||
if not builtin_config:
|
||
return
|
||
|
||
existing_classes = self.builtin_repo.get_existing_tool_classes(self.db, tenant_id)
|
||
|
||
added = False
|
||
for tool_key, tool_info in builtin_config.items():
|
||
if tool_info['tool_class'] in existing_classes:
|
||
continue
|
||
try:
|
||
initial_status = self._determine_initial_status(tool_info)
|
||
tool_config = ToolConfig(
|
||
name=tool_info['name'],
|
||
description=tool_info['description'],
|
||
tool_type=ToolType.BUILTIN.value,
|
||
tenant_id=tenant_id,
|
||
status=initial_status,
|
||
config_data={"tool_class": tool_info['tool_class'],
|
||
"requires_config": tool_info.get('requires_config', False),
|
||
"is_enabled": False},
|
||
version=tool_info["version"]
|
||
)
|
||
self.db.add(tool_config)
|
||
self.db.flush()
|
||
|
||
builtin_config_obj = BuiltinToolConfig(
|
||
id=tool_config.id,
|
||
tool_class=tool_info['tool_class'],
|
||
parameters={},
|
||
requires_config=tool_info.get('requires_config', False)
|
||
)
|
||
self.db.add(builtin_config_obj)
|
||
added = True
|
||
|
||
except Exception as e:
|
||
logger.error(f"初始化内置工具失败: {tool_key}, {e}")
|
||
|
||
if added:
|
||
self.db.commit()
|
||
logger.info(f"租户 {tenant_id} 内置工具增量初始化完成")
|
||
|
||
async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]:
|
||
"""获取工具的所有方法
|
||
|
||
Args:
|
||
tool_id: 工具ID
|
||
tenant_id: 租户ID
|
||
|
||
Returns:
|
||
方法列表或None
|
||
"""
|
||
config = self._get_tool_config_all(tool_id, tenant_id)
|
||
if not config:
|
||
return None
|
||
|
||
try:
|
||
if config.tool_type == ToolType.BUILTIN.value:
|
||
return await self._get_builtin_tool_methods(config)
|
||
elif config.tool_type == ToolType.CUSTOM.value:
|
||
return await self._get_custom_tool_methods(config)
|
||
elif config.tool_type == ToolType.MCP.value:
|
||
return await self._get_mcp_tool_methods(config)
|
||
else:
|
||
return []
|
||
|
||
except Exception as e:
|
||
logger.error(f"获取工具方法失败: {tool_id}, {e}")
|
||
return []
|
||
|
||
async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||
"""获取内置工具的方法"""
|
||
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
|
||
return []
|
||
|
||
# 获取工具实例
|
||
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||
if not tool_instance:
|
||
return []
|
||
|
||
# 检查是否有operation参数
|
||
operation_param = None
|
||
for param in tool_instance.parameters:
|
||
if param.name == "operation" and param.enum:
|
||
operation_param = param
|
||
break
|
||
|
||
if operation_param:
|
||
# 有多个操作,为每个操作生成具体参数
|
||
methods = []
|
||
for operation in operation_param.enum:
|
||
# 获取该操作的具体参数
|
||
operation_params = self._get_operation_specific_params(tool_instance, operation)
|
||
methods.append({
|
||
"method_id": f"{config.name}_{operation}",
|
||
"name": operation,
|
||
"description": f"{config.description} - {operation}",
|
||
"parameters": operation_params
|
||
})
|
||
return methods
|
||
else:
|
||
# 只有一个方法
|
||
return [{
|
||
"method_id": config.name,
|
||
"name": config.name,
|
||
"description": config.description,
|
||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||
}]
|
||
|
||
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
|
||
"""获取特定操作的参数列表"""
|
||
# 对于datetime_tool,根据操作类型返回相关参数
|
||
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
|
||
return self._get_datetime_tool_params(operation)
|
||
# 对于json_tool,根据操作类型返回相关参数
|
||
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
|
||
return self._get_json_tool_params(operation)
|
||
# 对于openclaw_tool,根据操作类型返回不同描述的参数
|
||
elif hasattr(tool_instance, 'name') and tool_instance.name == 'openclaw_tool':
|
||
return self._get_openclaw_tool_params(operation)
|
||
|
||
# 其他工具的默认处理:返回除operation外的所有参数
|
||
return [{
|
||
"name": param.name,
|
||
"type": param.type.value,
|
||
"description": param.description,
|
||
"required": param.required,
|
||
"default": param.default,
|
||
"enum": param.enum,
|
||
"minimum": param.minimum,
|
||
"maximum": param.maximum,
|
||
"pattern": param.pattern
|
||
} for param in tool_instance.parameters if param.name != "operation"]
|
||
|
||
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||
"""获取datetime_tool特定操作的参数"""
|
||
if operation == "now":
|
||
return [
|
||
{
|
||
"name": "to_timezone",
|
||
"type": "string",
|
||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
},
|
||
{
|
||
"name": "output_format",
|
||
"type": "string",
|
||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
}
|
||
]
|
||
elif operation == "format":
|
||
return [
|
||
{
|
||
"name": "input_value",
|
||
"type": "string",
|
||
"description": "输入值(时间字符串或时间戳)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "input_format",
|
||
"type": "string",
|
||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "output_format",
|
||
"type": "string",
|
||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
}
|
||
]
|
||
elif operation == "convert_timezone":
|
||
return [
|
||
{
|
||
"name": "input_value",
|
||
"type": "string",
|
||
"description": "输入值(时间字符串或时间戳)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "input_format",
|
||
"type": "string",
|
||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "output_format",
|
||
"type": "string",
|
||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "from_timezone",
|
||
"type": "string",
|
||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
},
|
||
{
|
||
"name": "to_timezone",
|
||
"type": "string",
|
||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
}
|
||
]
|
||
elif operation == "timestamp_to_datetime":
|
||
return [
|
||
{
|
||
"name": "input_value",
|
||
"type": "string",
|
||
"description": "输入值(时间字符串或时间戳)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "output_format",
|
||
"type": "string",
|
||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "to_timezone",
|
||
"type": "string",
|
||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
}
|
||
]
|
||
elif operation == "datetime_to_timestamp":
|
||
return [
|
||
{
|
||
"name": "input_value",
|
||
"type": "string",
|
||
"description": "输入值(时间字符串,如:2026-04-07 10:30:25)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "input_format",
|
||
"type": "string",
|
||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "from_timezone",
|
||
"type": "string",
|
||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
}
|
||
]
|
||
else:
|
||
# 默认返回所有参数(除了operation)
|
||
return [
|
||
{
|
||
"name": "input_value",
|
||
"type": "string",
|
||
"description": "输入值(时间字符串或时间戳)",
|
||
"required": False
|
||
},
|
||
{
|
||
"name": "input_format",
|
||
"type": "string",
|
||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "output_format",
|
||
"type": "string",
|
||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||
"required": False,
|
||
"default": "%Y-%m-%d %H:%M:%S"
|
||
},
|
||
{
|
||
"name": "from_timezone",
|
||
"type": "string",
|
||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
},
|
||
{
|
||
"name": "to_timezone",
|
||
"type": "string",
|
||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||
"required": False,
|
||
"default": "Asia/Shanghai"
|
||
},
|
||
{
|
||
"name": "calculation",
|
||
"type": "string",
|
||
"description": "时间计算表达式(如:+1d, -2h, +30m)",
|
||
"required": False
|
||
}
|
||
]
|
||
|
||
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||
"""获取json_tool特定操作的参数"""
|
||
base_params = [
|
||
{
|
||
"name": "input_data",
|
||
"type": "string",
|
||
"description": "输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||
"required": True
|
||
}
|
||
]
|
||
|
||
if operation == "insert":
|
||
return base_params + [
|
||
{
|
||
"name": "json_path",
|
||
"type": "string",
|
||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "new_value",
|
||
"type": "string",
|
||
"description": "新值(用于insert操作)",
|
||
"required": True
|
||
}
|
||
]
|
||
elif operation == "replace":
|
||
return base_params + [
|
||
{
|
||
"name": "json_path",
|
||
"type": "string",
|
||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "old_text",
|
||
"type": "string",
|
||
"description": "要替换的原文本(用于replace操作)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "new_text",
|
||
"type": "string",
|
||
"description": "替换后的新文本(用于replace操作)",
|
||
"required": True
|
||
}
|
||
]
|
||
elif operation == "delete":
|
||
return base_params + [
|
||
{
|
||
"name": "json_path",
|
||
"type": "string",
|
||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||
"required": True
|
||
}
|
||
]
|
||
elif operation == "parse":
|
||
return base_params + [
|
||
{
|
||
"name": "json_path",
|
||
"type": "string",
|
||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||
"required": True
|
||
}
|
||
]
|
||
|
||
return base_params
|
||
|
||
@staticmethod
|
||
def _get_openclaw_tool_params(operation: str) -> List[Dict[str, Any]]:
|
||
"""获取 openclaw_tool 特定操作的参数"""
|
||
if operation == "print_task":
|
||
return [
|
||
{
|
||
"name": "message",
|
||
"type": "string",
|
||
"description": "发送给 OpenClaw 的打印任务描述,将用户的原始消息原封不动地传递给 OpenClaw,禁止改写、补充或润色用户的原文",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "image_url",
|
||
"type": "string",
|
||
"description": "可选,附带的设计图片或参考图,OpenClaw 可据此生成 3D 模型",
|
||
"required": False
|
||
}
|
||
]
|
||
elif operation == "device_query":
|
||
return [
|
||
{
|
||
"name": "message",
|
||
"type": "string",
|
||
"description": "发送给 OpenClaw 的设备查询指令",
|
||
"required": True
|
||
}
|
||
]
|
||
elif operation == "image_understand":
|
||
return [
|
||
{
|
||
"name": "message",
|
||
"type": "string",
|
||
"description": "发送给 OpenClaw 的图片理解任务,应描述需要对图片做什么(如描述内容、提取文字、分析信息)",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "image_url",
|
||
"type": "string",
|
||
"description": "要分析的图片 URL 或 base64 data URI",
|
||
"required": False
|
||
}
|
||
]
|
||
else:
|
||
# general 及其他
|
||
return [
|
||
{
|
||
"name": "message",
|
||
"type": "string",
|
||
"description": "发送给 OpenClaw Agent 的任务描述,应包含完整的任务需求",
|
||
"required": True
|
||
},
|
||
{
|
||
"name": "image_url",
|
||
"type": "string",
|
||
"description": "可选,附带的图片 URL 或 base64 data URI",
|
||
"required": False
|
||
}
|
||
]
|
||
|
||
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||
"""获取自定义工具的方法"""
|
||
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
|
||
if not custom_config:
|
||
return []
|
||
|
||
try:
|
||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||
parser = OpenAPISchemaParser()
|
||
|
||
# 解析schema
|
||
if custom_config.schema_content:
|
||
success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json")
|
||
elif custom_config.schema_url:
|
||
success, schema, error = await parser.parse_from_url(custom_config.schema_url)
|
||
else:
|
||
return []
|
||
|
||
if not success:
|
||
return []
|
||
|
||
# 提取操作
|
||
tool_info = parser.extract_tool_info(schema)
|
||
operations = tool_info.get("operations", {})
|
||
|
||
methods = []
|
||
for operation_id, operation in operations.items():
|
||
# 生成参数列表
|
||
parameters = []
|
||
|
||
# 路径和查询参数
|
||
for param_name, param_info in operation.get("parameters", {}).items():
|
||
parameters.append({
|
||
"name": param_name,
|
||
"type": param_info.get("type", "string"),
|
||
"description": param_info.get("description", ""),
|
||
"required": param_info.get("required", False),
|
||
"enum": param_info.get("enum"),
|
||
"default": param_info.get("default")
|
||
})
|
||
|
||
# 请求体参数 — _extract_request_body 返回 {"schema": {...}, "required": bool, ...}
|
||
request_body = operation.get("request_body")
|
||
if request_body:
|
||
body_schema = request_body.get("schema", {})
|
||
schema_props = body_schema.get("properties", {})
|
||
required_props = body_schema.get("required", [])
|
||
|
||
for prop_name, prop_schema in schema_props.items():
|
||
parameters.append({
|
||
"name": prop_name,
|
||
"type": prop_schema.get("type", "string"),
|
||
"description": prop_schema.get("description", ""),
|
||
"required": prop_name in required_props,
|
||
"enum": prop_schema.get("enum"),
|
||
"default": prop_schema.get("default")
|
||
})
|
||
|
||
methods.append({
|
||
"method_id": operation_id,
|
||
"name": operation.get("summary", operation_id),
|
||
"description": operation.get("description", ""),
|
||
"method": operation.get("method", "GET"),
|
||
"path": operation.get("path", "/"),
|
||
"parameters": parameters
|
||
})
|
||
|
||
return methods
|
||
|
||
except Exception as e:
|
||
logger.error(f"解析自定义工具schema失败: {e}")
|
||
return []
|
||
|
||
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||
"""获取MCP工具的方法和参数"""
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
if not mcp_config:
|
||
return []
|
||
|
||
available_tools = mcp_config.available_tools or []
|
||
if not available_tools:
|
||
# 如果没有工具列表,尝试同步
|
||
try:
|
||
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||
mcp_config.server_url, mcp_config.connection_config or {}
|
||
)
|
||
if success:
|
||
# 转换为新格式
|
||
tool_list = []
|
||
for tool in tools:
|
||
if tool.get("name"):
|
||
tool_list.append({
|
||
tool["name"]: {
|
||
"description": tool.get("description", ""),
|
||
"inputSchema": tool.get("inputSchema", {})
|
||
}
|
||
})
|
||
mcp_config.available_tools = tool_list
|
||
self.db.commit()
|
||
available_tools = tool_list
|
||
except Exception as e:
|
||
logger.error(f"同步MCP工具列表失败: {e}")
|
||
return []
|
||
|
||
methods = []
|
||
|
||
# 处理新格式的available_tools
|
||
for tool_item in available_tools:
|
||
if isinstance(tool_item, dict):
|
||
for tool_name, tool_data in tool_item.items():
|
||
# 解析工具参数
|
||
parameters = []
|
||
input_schema = tool_data.get("inputSchema", {})
|
||
properties = input_schema.get("properties", {})
|
||
required_fields = input_schema.get("required", [])
|
||
|
||
for param_name, param_def in properties.items():
|
||
parameters.append({
|
||
"name": param_name,
|
||
"type": param_def.get("type", "string"),
|
||
"description": param_def.get("description", ""),
|
||
"required": param_name in required_fields,
|
||
"default": param_def.get("default"),
|
||
"enum": param_def.get("enum"),
|
||
"minimum": param_def.get("minimum"),
|
||
"maximum": param_def.get("maximum")
|
||
})
|
||
|
||
methods.append({
|
||
"method_id": tool_name,
|
||
"name": tool_name,
|
||
"description": tool_data.get("description", f"MCP工具: {tool_name}"),
|
||
"parameters": parameters
|
||
})
|
||
else:
|
||
# 兼容旧格式(字符串)
|
||
tool_name = str(tool_item)
|
||
methods.append({
|
||
"method_id": tool_name,
|
||
"name": tool_name,
|
||
"description": f"MCP工具: {tool_name}",
|
||
"parameters": []
|
||
})
|
||
|
||
return methods
|
||
|
||
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||
"""获取工具统计信息"""
|
||
try:
|
||
# 总数统计
|
||
total_tools = self.tool_repo.count_by_tenant(self.db, tenant_id)
|
||
|
||
# 状态统计
|
||
status_counts = self.tool_repo.get_status_statistics(self.db, tenant_id)
|
||
|
||
# 类型统计
|
||
type_counts = self.tool_repo.get_type_statistics(self.db, tenant_id)
|
||
|
||
# 启用/禁用统计
|
||
enabled_count = self.tool_repo.count_enabled_by_tenant(self.db, tenant_id)
|
||
disabled_count = total_tools - enabled_count
|
||
|
||
return {
|
||
"total_tools": total_tools,
|
||
"status_counts": [
|
||
{"status": status, "count": count}
|
||
for status, count in status_counts
|
||
],
|
||
"type_counts": {
|
||
tool_type: count for tool_type, count in type_counts
|
||
},
|
||
"enabled_count": enabled_count,
|
||
"disabled_count": disabled_count
|
||
}
|
||
except Exception as e:
|
||
logger.error(f"获取工具统计失败: {e}")
|
||
return {
|
||
"total_tools": 0,
|
||
"status_counts": [],
|
||
"type_counts": {},
|
||
"enabled_count": 0,
|
||
"disabled_count": 0
|
||
}
|
||
|
||
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||
"""获取工具配置(仅返回 is_active=True)"""
|
||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||
|
||
def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||
"""获取工具配置(返回所有)"""
|
||
return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||
|
||
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||
"""获取工具实例(仅返回 is_active=True 的工具)"""
|
||
if tool_id in self._tool_cache:
|
||
return self._tool_cache[tool_id]
|
||
|
||
config = self._get_tool_config(tool_id, tenant_id)
|
||
if not config or not config.is_active:
|
||
return None
|
||
|
||
try:
|
||
tool = self._create_tool_instance(config)
|
||
if tool:
|
||
self._tool_cache[tool_id] = tool
|
||
return tool
|
||
except Exception as e:
|
||
logger.error(f"创建工具实例失败: {tool_id}, {e}")
|
||
return None
|
||
|
||
def _create_tool_instance(self, config: ToolConfig) -> Optional[BaseTool]:
|
||
"""创建工具实例"""
|
||
if config.tool_type == ToolType.BUILTIN.value:
|
||
return self._create_builtin_instance(config)
|
||
elif config.tool_type == ToolType.CUSTOM.value:
|
||
return self._create_custom_instance(config)
|
||
elif config.tool_type == ToolType.MCP.value:
|
||
return self._create_mcp_instance(config)
|
||
return None
|
||
|
||
def _create_builtin_instance(self, config: ToolConfig) -> Optional[BaseTool]:
|
||
"""创建内置工具实例"""
|
||
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||
|
||
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
|
||
return None
|
||
|
||
try:
|
||
module_path = BUILTIN_TOOLS[builtin_config.tool_class]
|
||
module = importlib.import_module(module_path)
|
||
tool_class = getattr(module, builtin_config.tool_class)
|
||
|
||
tool_config = {
|
||
**config.config_data,
|
||
"parameters": builtin_config.parameters,
|
||
}
|
||
|
||
return tool_class(str(config.id), tool_config)
|
||
except Exception as e:
|
||
logger.error(f"创建内置工具实例失败: {builtin_config.tool_class}, {e}")
|
||
return None
|
||
|
||
def _create_custom_instance(self, config: ToolConfig) -> Optional[CustomTool]:
|
||
"""创建自定义工具实例"""
|
||
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
|
||
|
||
if not custom_config:
|
||
return None
|
||
|
||
tool_config = {
|
||
"base_url": custom_config.base_url,
|
||
"auth_type": custom_config.auth_type,
|
||
"auth_config": custom_config.auth_config or {},
|
||
"timeout": custom_config.timeout or 30,
|
||
"schema_content": custom_config.schema_content,
|
||
"schema_url": custom_config.schema_url
|
||
}
|
||
|
||
return CustomTool(str(config.id), tool_config)
|
||
|
||
def _create_mcp_instance(self, config: ToolConfig) -> Optional[MCPTool]:
|
||
"""创建MCP工具实例"""
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
|
||
if not mcp_config:
|
||
return None
|
||
|
||
tool_config = {
|
||
"server_url": mcp_config.server_url,
|
||
"connection_config": mcp_config.connection_config or {},
|
||
"available_tools": mcp_config.available_tools or []
|
||
}
|
||
|
||
return MCPTool(str(config.id), tool_config)
|
||
|
||
def _config_to_info(self, config: ToolConfig) -> ToolInfo:
|
||
"""配置转换为信息对象"""
|
||
config_data = config.config_data or {}
|
||
|
||
# 对于MCP工具,从MCPToolConfig获取额外信息
|
||
if config.tool_type == ToolType.MCP.value:
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
if mcp_config:
|
||
# 处理available_tools显示格式
|
||
available_tools_display = []
|
||
for tool_item in (mcp_config.available_tools or []):
|
||
if isinstance(tool_item, dict):
|
||
available_tools_display.extend(list(tool_item.keys()))
|
||
else:
|
||
available_tools_display.append(str(tool_item))
|
||
|
||
config_data.update({
|
||
"last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None,
|
||
"health_status": mcp_config.health_status,
|
||
"available_tools": available_tools_display,
|
||
"source_channel": mcp_config.source_channel,
|
||
"market_id": mcp_config.market_id,
|
||
"market_config_id": mcp_config.market_config_id,
|
||
"mcp_service_id": mcp_config.mcp_service_id
|
||
})
|
||
|
||
return ToolInfo(
|
||
id=str(config.id),
|
||
name=config.name,
|
||
description=config.description or "",
|
||
icon=config.icon,
|
||
tool_type=ToolType(config.tool_type),
|
||
version=config.version or "1.0.0",
|
||
status=ToolStatus(config.status),
|
||
tags=config.tags or [],
|
||
tenant_id=str(config.tenant_id) if config.tenant_id else None,
|
||
config_data=config_data,
|
||
is_active=config.is_active,
|
||
created_at=config.created_at
|
||
)
|
||
|
||
def _create_type_config(self, tool_config: ToolConfig, config: Dict[str, Any]):
|
||
"""创建类型特定配置"""
|
||
if tool_config.tool_type == ToolType.CUSTOM.value:
|
||
# 从 schema 中解析 base_url
|
||
base_url = config.get("base_url")
|
||
if not base_url and (config.get("schema_content") or config.get("schema_url")):
|
||
try:
|
||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||
parser = OpenAPISchemaParser()
|
||
|
||
if config.get("schema_content"):
|
||
success, schema, _ = parser.parse_from_content(config["schema_content"], "application/json")
|
||
else:
|
||
success, schema, _ = parser.parse_from_url(config["schema_url"])
|
||
|
||
if success:
|
||
tool_info = parser.extract_tool_info(schema)
|
||
servers = tool_info.get("servers", [])
|
||
base_url = servers[0].get("url") if servers else ""
|
||
except Exception as e:
|
||
logger.error(f"解析schema获取base_url失败: {e}")
|
||
|
||
custom_config = CustomToolConfig(
|
||
id=tool_config.id,
|
||
base_url=base_url,
|
||
auth_type=config.get("auth_type", "none"),
|
||
auth_config=config.get("auth_config", {}),
|
||
timeout=config.get("timeout", 30),
|
||
schema_content=config.get("schema_content"),
|
||
schema_url=config.get("schema_url")
|
||
)
|
||
self.db.add(custom_config)
|
||
|
||
elif tool_config.tool_type == ToolType.MCP.value:
|
||
mcp_config = MCPToolConfig(
|
||
id=tool_config.id,
|
||
server_url=config.get("server_url"),
|
||
connection_config=config.get("connection_config", {}),
|
||
available_tools=config.get("available_tools", []),
|
||
source_channel=config.get("source_channel", "self_hosted"),
|
||
market_id=config.get("market_id"),
|
||
market_config_id=config.get("market_config_id"),
|
||
mcp_service_id=config.get("mcp_service_id"),
|
||
)
|
||
self.db.add(mcp_config)
|
||
|
||
def _sync_type_config(self, tool_config: ToolConfig, config: Dict[str, Any], is_enabled: bool):
|
||
"""同步到类型特定表"""
|
||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||
BuiltinToolConfig.id == tool_config.id
|
||
).first()
|
||
if builtin_config:
|
||
builtin_config.parameters = config.get("parameters", {})
|
||
if is_enabled is not None:
|
||
builtin_config.is_enabled = is_enabled
|
||
|
||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||
custom_config = self.db.query(CustomToolConfig).filter(
|
||
CustomToolConfig.id == tool_config.id
|
||
).first()
|
||
if custom_config:
|
||
base_url = config.get("base_url")
|
||
if not base_url and (config.get("schema_content") or config.get("schema_url")):
|
||
try:
|
||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||
parser = OpenAPISchemaParser()
|
||
|
||
if config.get("schema_content"):
|
||
success, schema, _ = parser.parse_from_content(config["schema_content"],
|
||
"application/json")
|
||
else:
|
||
success, schema, _ = parser.parse_from_url(config["schema_url"])
|
||
|
||
if success:
|
||
tool_info = parser.extract_tool_info(schema)
|
||
servers = tool_info.get("servers", [])
|
||
base_url = servers[0].get("url") if servers else ""
|
||
except Exception as e:
|
||
logger.error(f"解析schema获取base_url失败: {e}")
|
||
custom_config.base_url = base_url
|
||
custom_config.auth_type = config.get("auth_type", "none")
|
||
custom_config.auth_config = config.get("auth_config", {})
|
||
custom_config.timeout = config.get("timeout", 30)
|
||
custom_config.schema_content = config.get("schema_content")
|
||
custom_config.schema_url = config.get("schema_url")
|
||
|
||
elif tool_config.tool_type == ToolType.MCP.value:
|
||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||
MCPToolConfig.id == tool_config.id
|
||
).first()
|
||
if mcp_config:
|
||
mcp_config.server_url = config.get("server_url")
|
||
mcp_config.connection_config = config.get("connection_config", {})
|
||
mcp_config.available_tools = config.get("available_tools", [])
|
||
if config.get("source_channel") is not None:
|
||
mcp_config.source_channel = config.get("source_channel")
|
||
if config.get("market_id") is not None:
|
||
mcp_config.market_id = config.get("market_id")
|
||
if config.get("market_config_id") is not None:
|
||
mcp_config.market_config_id = config.get("market_config_id")
|
||
if config.get("mcp_service_id") is not None:
|
||
mcp_config.mcp_service_id = config.get("mcp_service_id")
|
||
|
||
@staticmethod
|
||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||
"""确定工具初始状态"""
|
||
if tool_info.get('requires_config', False):
|
||
return ToolStatus.UNCONFIGURED
|
||
else:
|
||
return ToolStatus.AVAILABLE
|
||
|
||
def _update_tool_status(self, tool_config: ToolConfig):
|
||
"""更新工具状态逻辑"""
|
||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||
BuiltinToolConfig.id == tool_config.id
|
||
).first()
|
||
|
||
if builtin_config:
|
||
if builtin_config.requires_config:
|
||
# 需要配置的工具
|
||
if self._is_tool_configured(builtin_config):
|
||
if tool_config.config_data.get("is_enabled", None):
|
||
tool_config.status = ToolStatus.AVAILABLE.value
|
||
else:
|
||
tool_config.status = ToolStatus.CONFIGURED_DISABLED.value
|
||
else:
|
||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||
else:
|
||
# 不需要配置的工具
|
||
tool_config.status = ToolStatus.AVAILABLE.value
|
||
|
||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||
custom_config = self.db.query(CustomToolConfig).filter(
|
||
CustomToolConfig.id == tool_config.id
|
||
).first()
|
||
|
||
if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url):
|
||
tool_config.status = ToolStatus.AVAILABLE.value
|
||
else:
|
||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||
|
||
elif tool_config.tool_type == ToolType.MCP.value:
|
||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||
MCPToolConfig.id == tool_config.id
|
||
).first()
|
||
|
||
if mcp_config:
|
||
if mcp_config.health_status == "healthy":
|
||
tool_config.status = ToolStatus.AVAILABLE.value
|
||
elif mcp_config.health_status == "error":
|
||
tool_config.status = ToolStatus.ERROR.value
|
||
else:
|
||
tool_config.status = ToolStatus.UNCONFIGURED.value
|
||
|
||
def _is_tool_configured(self, builtin_config: BuiltinToolConfig) -> bool:
|
||
"""检查工具是否已配置"""
|
||
# 从配置文件获取必需参数
|
||
builtin_config_data = self._load_builtin_config()
|
||
required_params = {}
|
||
for key, value in builtin_config_data.items():
|
||
if builtin_config.tool_class == value["tool_class"]:
|
||
required_params = value.get('parameters', {})
|
||
break
|
||
|
||
# 检查所有必需参数是否已配置
|
||
for param_name, param_info in required_params.items():
|
||
if param_info.get('required', False):
|
||
if not builtin_config.parameters.get(param_name):
|
||
return False
|
||
return True
|
||
|
||
def _clear_tool_cache(self, tool_id: str):
|
||
"""清除工具缓存"""
|
||
if tool_id in self._tool_cache:
|
||
del self._tool_cache[tool_id]
|
||
|
||
def _record_execution_start(
|
||
self,
|
||
execution_id: str,
|
||
tool_id: str,
|
||
parameters: Dict[str, Any],
|
||
user_id: Optional[uuid.UUID],
|
||
workspace_id: Optional[uuid.UUID]
|
||
):
|
||
"""记录执行开始"""
|
||
try:
|
||
execution = ToolExecution(
|
||
execution_id=execution_id,
|
||
tool_config_id=uuid.UUID(tool_id),
|
||
status=ExecutionStatus.RUNNING.value,
|
||
input_data=parameters,
|
||
started_at=datetime.now(),
|
||
user_id=user_id,
|
||
workspace_id=workspace_id
|
||
)
|
||
self.db.add(execution)
|
||
self.db.commit()
|
||
except Exception as e:
|
||
logger.error(f"记录执行开始失败: {execution_id}, {e}")
|
||
|
||
def _record_execution_complete(self, execution_id: str, result: ToolResult):
|
||
"""记录执行完成"""
|
||
try:
|
||
execution = self.db.query(ToolExecution).filter(
|
||
ToolExecution.execution_id == execution_id
|
||
).first()
|
||
|
||
if execution:
|
||
execution.status = ExecutionStatus.COMPLETED.value if result.success else ExecutionStatus.FAILED.value
|
||
execution.output_data = result.data if result.success else None
|
||
execution.error_message = result.error if not result.success else None
|
||
execution.completed_at = datetime.now()
|
||
execution.execution_time = result.execution_time
|
||
execution.token_usage = result.token_usage
|
||
self.db.commit()
|
||
except Exception as e:
|
||
logger.error(f"记录执行完成失败: {execution_id}, {e}")
|
||
|
||
@staticmethod
|
||
def _load_builtin_config() -> Dict[str, Any]:
|
||
"""加载内置工具配置"""
|
||
import json
|
||
from pathlib import Path
|
||
|
||
config_file = Path(__file__).parent.parent / "core" / "tools" / "configs" / "builtin_tools.json"
|
||
try:
|
||
with open(config_file, 'r', encoding='utf-8') as f:
|
||
return json.load(f)
|
||
except Exception as e:
|
||
logger.error(f"加载内置工具配置失败: {e}")
|
||
return {}
|
||
|
||
async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""根据配置参数直接测试MCP连接(创建前调用,无需已存在的工具记录)"""
|
||
server_url = config.get("server_url")
|
||
if not server_url:
|
||
return {"success": False, "message": "server_url不能为空"}
|
||
connection_config = config.get("connection_config") or {}
|
||
try:
|
||
test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config)
|
||
if not test_result["success"]:
|
||
return test_result
|
||
success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config)
|
||
if not success_flag:
|
||
return {"success": False, "message": f"获取工具列表失败: {error}"}
|
||
tool_list = [
|
||
{tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}}
|
||
for tool in tools if tool.get("name")
|
||
]
|
||
return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list}
|
||
except Exception as e:
|
||
return {"success": False, "message": f"连接测试异常: {str(e)}"}
|
||
|
||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||
"""测试MCP连接并自动同步工具列表"""
|
||
try:
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
if not mcp_config:
|
||
return {"success": False, "message": "MCP配置不存在"}
|
||
|
||
# 使用集成的MCP管理器测试连接
|
||
test_result = await self.mcp_tool_manager.test_tool_connection(
|
||
mcp_config.server_url, mcp_config.connection_config or {}
|
||
)
|
||
|
||
if test_result["success"]:
|
||
# 连接成功,自动同步工具列表
|
||
success, tools, error = await self.mcp_tool_manager.discover_tools(
|
||
mcp_config.server_url, mcp_config.connection_config or {}
|
||
)
|
||
|
||
if success:
|
||
# 转换为新格式
|
||
tool_list = []
|
||
tool_names = []
|
||
for tool in tools:
|
||
if tool.get("name"):
|
||
tool_names.append(tool["name"])
|
||
tool_list.append({
|
||
tool["name"]: {
|
||
"description": tool.get("description", ""),
|
||
"inputSchema": tool.get("inputSchema", {})
|
||
}
|
||
})
|
||
|
||
# 更新数据库
|
||
mcp_config.available_tools = tool_list
|
||
mcp_config.last_health_check = datetime.now()
|
||
mcp_config.health_status = "healthy"
|
||
mcp_config.error_message = None
|
||
config.status = ToolStatus.AVAILABLE.value
|
||
|
||
self.db.commit()
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "MCP连接成功并同步工具列表",
|
||
"details": {
|
||
"server_url": mcp_config.server_url,
|
||
"tools_count": len(tool_names),
|
||
"tools": tool_names
|
||
}
|
||
}
|
||
else:
|
||
return {"success": False, "message": f"同步工具失败: {error}"}
|
||
else:
|
||
# 更新错误状态
|
||
mcp_config.last_health_check = datetime.now()
|
||
mcp_config.health_status = "error"
|
||
mcp_config.error_message = test_result.get("error", "连接失败")
|
||
config.status = ToolStatus.ERROR.value
|
||
self.db.commit()
|
||
|
||
return test_result
|
||
|
||
except Exception as e:
|
||
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
|
||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||
|
||
@staticmethod
|
||
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
|
||
"""解析OpenAPI schema获取接口信息"""
|
||
try:
|
||
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||
|
||
parser = OpenAPISchemaParser()
|
||
|
||
if schema_data:
|
||
success, schema, error = parser.parse_from_content(schema_data, "application/json")
|
||
elif schema_url:
|
||
success, schema, error = await parser.parse_from_url(schema_url)
|
||
else:
|
||
return {"success": False, "message": "schema_data或schema_url必须提供一个"}
|
||
|
||
if not success:
|
||
return {"success": False, "message": error}
|
||
|
||
# 提取工具信息
|
||
tool_info = parser.extract_tool_info(schema)
|
||
|
||
# 获取base_url
|
||
servers = tool_info.get("servers", [])
|
||
base_url = servers[0].get("url") if servers else ""
|
||
|
||
return {
|
||
"success": True,
|
||
"data": {
|
||
"title": tool_info["name"],
|
||
"description": tool_info["description"],
|
||
"version": tool_info["version"],
|
||
"base_url": base_url,
|
||
"operations": list(tool_info["operations"].values())
|
||
}
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"解析OpenAPI schema失败: {e}")
|
||
return {"success": False, "message": f"解析失败: {str(e)}"}
|
||
|
||
async def sync_mcp_tools(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||
"""同步MCP工具列表到数据库"""
|
||
try:
|
||
config = self._get_tool_config(tool_id, tenant_id)
|
||
if not config or config.tool_type != ToolType.MCP.value:
|
||
return {"success": False, "message": "工具不存在或不是MCP工具"}
|
||
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
if not mcp_config:
|
||
return {"success": False, "message": "MCP配置不存在"}
|
||
|
||
# 创建MCP客户端
|
||
connection_config = mcp_config.connection_config or {}
|
||
client = SimpleMCPClient(mcp_config.server_url, connection_config)
|
||
|
||
async with client:
|
||
# 获取工具列表
|
||
tools = await client.list_tools()
|
||
|
||
# 转换为新格式
|
||
tool_list = []
|
||
tool_names = []
|
||
for tool in tools:
|
||
if tool.get("name"):
|
||
tool_names.append(tool["name"])
|
||
tool_list.append({
|
||
tool["name"]: {
|
||
"description": tool.get("description", ""),
|
||
"inputSchema": tool.get("inputSchema", {})
|
||
}
|
||
})
|
||
|
||
# 更新数据库
|
||
mcp_config.available_tools = tool_list
|
||
mcp_config.last_health_check = datetime.now()
|
||
mcp_config.health_status = "healthy"
|
||
mcp_config.error_message = None
|
||
|
||
# 更新工具状态
|
||
config.status = ToolStatus.AVAILABLE.value
|
||
|
||
self.db.commit()
|
||
|
||
return {
|
||
"success": True,
|
||
"message": "工具列表同步成功",
|
||
"tools_count": len(tool_names),
|
||
"tools": tool_names
|
||
}
|
||
|
||
except Exception as e:
|
||
# 更新错误状态
|
||
try:
|
||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||
if mcp_config:
|
||
mcp_config.last_health_check = datetime.now()
|
||
mcp_config.health_status = "error"
|
||
mcp_config.error_message = str(e)
|
||
config.status = ToolStatus.ERROR.value
|
||
self.db.commit()
|
||
except:
|
||
pass
|
||
|
||
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||
|
||
async def _test_custom_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||
"""测试自定义工具连接(基础连接测试)"""
|
||
try:
|
||
custom_config = self.db.query(CustomToolConfig).filter(
|
||
CustomToolConfig.id == config.id
|
||
).first()
|
||
|
||
if not custom_config or not custom_config.base_url:
|
||
return {"success": False, "message": "自定义工具配置不完整"}
|
||
|
||
import aiohttp
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(
|
||
custom_config.base_url,
|
||
timeout=aiohttp.ClientTimeout(total=10)
|
||
) as response:
|
||
if response.status == 200:
|
||
return {"success": True, "message": "自定义工具连接成功"}
|
||
else:
|
||
return {"success": False, "message": f"连接失败,状态码: {response.status}"}
|
||
|
||
except Exception as e:
|
||
return {"success": False, "message": f"自定义工具测试失败: {str(e)}"}
|
||
|
||
async def test_custom_tool(
|
||
self,
|
||
tool_id: str,
|
||
tenant_id: uuid.UUID,
|
||
method: str,
|
||
path: str,
|
||
parameters: Dict[str, Any]
|
||
) -> Dict[str, Any]:
|
||
"""测试自定义工具API调用"""
|
||
try:
|
||
config = self._get_tool_config(tool_id, tenant_id)
|
||
if not config or config.tool_type != ToolType.CUSTOM.value:
|
||
return {"success": False, "message": "工具不存在或不是自定义工具"}
|
||
|
||
custom_config = self.db.query(CustomToolConfig).filter(
|
||
CustomToolConfig.id == config.id
|
||
).first()
|
||
|
||
if not custom_config or not custom_config.base_url:
|
||
return {"success": False, "message": "自定义工具配置不完整"}
|
||
|
||
# 构建完整URL
|
||
url = custom_config.base_url.rstrip('/') + '/' + path.lstrip('/')
|
||
|
||
# 构建请求头
|
||
headers = {"Content-Type": "application/json"}
|
||
|
||
# 添加认证头
|
||
if custom_config.auth_type != AuthType.NONE.value:
|
||
auth_config = custom_config.auth_config or {}
|
||
if custom_config.auth_type == AuthType.API_KEY.value:
|
||
key_name = auth_config.get("key_name", "X-API-Key")
|
||
api_key = auth_config.get("api_key")
|
||
if api_key:
|
||
headers[key_name] = api_key
|
||
elif custom_config.auth_type == AuthType.BEARER_TOKEN.value:
|
||
token = auth_config.get("token")
|
||
if token:
|
||
headers["Authorization"] = f"Bearer {token}"
|
||
elif custom_config.auth_type == AuthType.BASIC_AUTH.value:
|
||
import base64
|
||
username = auth_config.get("username", "")
|
||
password = auth_config.get("password", "")
|
||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||
headers["Authorization"] = f"Basic {credentials}"
|
||
|
||
import aiohttp
|
||
async with aiohttp.ClientSession() as session:
|
||
# 根据方法发送请求
|
||
if method.upper() == "GET":
|
||
async with session.get(
|
||
url,
|
||
params=parameters,
|
||
headers=headers,
|
||
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
|
||
) as response:
|
||
result_data = await response.text()
|
||
return {
|
||
"success": True,
|
||
"message": "测试成功",
|
||
"status_code": response.status,
|
||
"response_data": result_data[:1000] # 限制返回数据长度
|
||
}
|
||
else:
|
||
async with session.request(
|
||
method.upper(),
|
||
url,
|
||
json=parameters,
|
||
headers=headers,
|
||
timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30)
|
||
) as response:
|
||
result_data = await response.text()
|
||
return {
|
||
"success": True,
|
||
"message": "测试成功",
|
||
"status_code": response.status,
|
||
"response_data": result_data[:1000] # 限制返回数据长度
|
||
}
|
||
|
||
except Exception as e:
|
||
logger.error(f"测试自定义工具API失败: {tool_id}, 错误: {e}")
|
||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||
|
||
async def _test_builtin_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||
"""测试内置工具连接"""
|
||
try:
|
||
# 获取工具实例
|
||
tool_instance = self.get_tool_instance(str(config.id), config.tenant_id)
|
||
if not tool_instance:
|
||
return {"success": False, "message": "无法创建工具实例"}
|
||
|
||
# 检查工具是否有test_connection方法
|
||
if hasattr(tool_instance, 'test_connection'):
|
||
result = await tool_instance.test_connection()
|
||
return result
|
||
else:
|
||
# 检查是否需要配置
|
||
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||
if builtin_config and builtin_config.requires_config:
|
||
# 检查必需参数是否已配置
|
||
if self._is_tool_configured(builtin_config):
|
||
return {"success": True, "message": "内置工具已正确配置"}
|
||
else:
|
||
return {"success": False, "message": "工具缺少必需配置参数"}
|
||
else:
|
||
return {"success": True, "message": "内置工具无需连接测试"}
|
||
|
||
except Exception as e:
|
||
logger.error(f"测试内置工具失败: {config.id}, 错误: {e}")
|
||
return {"success": False, "message": f"测试失败: {str(e)}"}
|