From 25ce86ae93e13d69fe6c7d0249b9f10530b78a5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Wed, 7 Jan 2026 18:59:28 +0800 Subject: [PATCH] feat(agent tool): mcp tool repair --- api/app/controllers/tool_controller.py | 4 +- api/app/core/tools/langchain_adapter.py | 23 +- api/app/core/tools/mcp/__init__.py | 20 +- api/app/core/tools/mcp/base.py | 292 ++++---- api/app/core/tools/mcp/client.py | 769 ++++++---------------- api/app/core/tools/mcp/service_manager.py | 573 +--------------- api/app/services/app_chat_service.py | 62 +- api/app/services/draft_run_service.py | 60 +- api/app/services/tool_service.py | 254 ++++--- 9 files changed, 621 insertions(+), 1436 deletions(-) diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 479686ef..a3624ea4 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -215,8 +215,8 @@ async def sync_mcp_tools( """同步MCP工具列表""" try: result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) - if result["success"] is False: - raise HTTPException(status_code=404, detail=result["message"]) + if not result.get("success", False): + raise HTTPException(status_code=400, detail=result.get("message", "同步失败")) return success(data=result, msg="MCP工具列表同步完成") except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index 89ccc205..f7aa0eb8 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -78,13 +78,20 @@ class LangchainAdapter: Args: tool: 内部工具实例 - operation: 特定操作(适用于有操作的工具) + operation: 特定操作(适用于有操作的工具)或MCP工具名称 Returns: Langchain兼容的工具包装器 """ try: - if operation and tool.name in ['datetime_tool', 'json_tool']: + # 处理MCP工具的特定工具名称 + if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation: + # 为MCP工具创建特定工具名称的实例 + mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation) + wrapper = LangchainToolWrapper(tool_instance=mcp_tool) + logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式") + return wrapper + elif operation and tool.name in ['datetime_tool', 'json_tool']: # 为特定操作创建工具 operation_tool = LangchainAdapter._create_operation_tool(tool, operation) wrapper = LangchainToolWrapper(tool_instance=operation_tool) @@ -106,6 +113,18 @@ class LangchainAdapter: from app.core.tools.builtin.operation_tool import OperationTool return OperationTool(base_tool, operation) + @staticmethod + def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool: + """为MCP工具创建指定工具名称的实例""" + from app.core.tools.mcp.base import MCPTool + + # 创建新的配置,指定具体工具名称 + new_config = base_tool.config.copy() + new_config["tool_name"] = tool_name + + # 创建新的MCP工具实例 + return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config) + @staticmethod def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: """批量转换工具 diff --git a/api/app/core/tools/mcp/__init__.py b/api/app/core/tools/mcp/__init__.py index 4c9519b3..b48aa096 100644 --- a/api/app/core/tools/mcp/__init__.py +++ b/api/app/core/tools/mcp/__init__.py @@ -1,12 +1,20 @@ -"""MCP工具模块""" +"""MCP 工具模块 - Model Context Protocol 支持""" -from app.core.tools.mcp.base import MCPTool -from app.core.tools.mcp.client import MCPClient, MCPConnectionPool -from app.core.tools.mcp.service_manager import MCPServiceManager +# 主要类导出 +from .base import MCPTool, MCPToolManager, MCPError +from .client import SimpleMCPClient, MCPConnectionError +from .service_manager import MCPServiceManager __all__ = [ + # 核心类 "MCPTool", - "MCPClient", - "MCPConnectionPool", + "MCPToolManager", + "MCPError", + + # 客户端类 + "SimpleMCPClient", + "MCPConnectionError", + + # 服务管理(简化版) "MCPServiceManager" ] \ No newline at end of file diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py index 3fa103ab..7784bfb9 100644 --- a/api/app/core/tools/mcp/base.py +++ b/api/app/core/tools/mcp/base.py @@ -1,10 +1,9 @@ -"""MCP工具基类""" +"""MCP工具基类 - 整合版本""" import time -from typing import Dict, Any, List +from typing import List, Dict, Any from app.models.tool_model import ToolType -from app.core.tools.base import BaseTool -from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -14,215 +13,174 @@ class MCPTool(BaseTool): """MCP工具 - Model Context Protocol工具""" def __init__(self, tool_id: str, config: Dict[str, Any]): - """初始化MCP工具 - - Args: - tool_id: 工具ID - config: 工具配置 - """ super().__init__(tool_id, config) self.server_url = config.get("server_url", "") self.connection_config = config.get("connection_config", {}) + self.tool_name = config.get("tool_name", "") # 特定工具名称 + self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema self.available_tools = config.get("available_tools", []) - self._client = None - self._connected = False @property def name(self) -> str: - """工具名称""" - return f"mcp_tool_{self.tool_id[:8]}" + return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}" @property def description(self) -> str: - """工具描述""" - return f"MCP工具 - 连接到 {self.server_url}" + if self.tool_schema.get("description"): + return self.tool_schema["description"] + return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}" @property def tool_type(self) -> ToolType: - """工具类型""" return ToolType.MCP @property def parameters(self) -> List[ToolParameter]: - """工具参数定义""" - params = [] - - # 添加工具选择参数 - if len(self.available_tools) > 1: - params.append(ToolParameter( - name="tool_name", - type=ParameterType.STRING, - description="要调用的MCP工具名称", - required=True, - enum=self.available_tools - )) - - # 添加通用参数 - params.extend([ - ToolParameter( + """从 MCP 工具 schema 生成参数""" + if not self.tool_schema: + return [ToolParameter( name="arguments", type=ParameterType.OBJECT, - description="工具参数(JSON对象)", + description="工具参数", required=False, default={} - ), - ToolParameter( - name="timeout", - type=ParameterType.INTEGER, - description="超时时间(秒)", - required=False, - default=30, - minimum=1, - maximum=300 - ) - ]) + )] + + # 解析 MCP 工具的 inputSchema + input_schema = self.tool_schema.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + params = [] + for param_name, param_def in properties.items(): + param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string")) + + params.append(ToolParameter( + name=param_name, + type=param_type, + description=param_def.get("description", f"参数: {param_name}"), + required=param_name in required_fields, + default=param_def.get("default"), + enum=param_def.get("enum"), + minimum=param_def.get("minimum"), + maximum=param_def.get("maximum") + )) return params + def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType: + """转换 JSON Schema 类型到 ParameterType""" + type_mapping = { + "string": ParameterType.STRING, + "integer": ParameterType.INTEGER, + "number": ParameterType.NUMBER, + "boolean": ParameterType.BOOLEAN, + "array": ParameterType.ARRAY, + "object": ParameterType.OBJECT + } + return type_mapping.get(json_type, ParameterType.STRING) + async def execute(self, **kwargs) -> ToolResult: - """执行MCP工具""" + """执行 MCP 工具""" start_time = time.time() try: - # 确保连接 - if not self._connected: - await self.connect() + from .client import SimpleMCPClient - # 确定要调用的工具 - tool_name = kwargs.get("tool_name") - if not tool_name and len(self.available_tools) == 1: - tool_name = self.available_tools[0] - - if not tool_name: - raise ValueError("必须指定要调用的MCP工具名称") - - if tool_name not in self.available_tools: - raise ValueError(f"MCP工具不存在: {tool_name}") - - # 获取参数 - arguments = kwargs.get("arguments", {}) - timeout = kwargs.get("timeout", 30) - - # 调用MCP工具 - result = await self._call_mcp_tool(tool_name, arguments, timeout) - - execution_time = time.time() - start_time - return ToolResult.success_result( - data=result, - execution_time=execution_time - ) + client = SimpleMCPClient(self.server_url, self.connection_config) + async with client: + # 使用指定的工具名称或默认第一个工具 + tool_name_to_use = self.tool_name + if not tool_name_to_use and self.available_tools: + tool_name_to_use = self.available_tools[0] + + if not tool_name_to_use: + raise Exception("未指定工具名称且无可用工具") + + result = await client.call_tool(tool_name_to_use, kwargs) + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + except Exception as e: execution_time = time.time() - start_time + logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}") return ToolResult.error_result( error=str(e), - error_code="MCP_ERROR", + error_code="MCP_EXECUTION_ERROR", execution_time=execution_time ) + + +class MCPError(Exception): + """MCP 错误基类""" + pass + + +class MCPToolManager: + """MCP 工具管理器 - 简化版本""" - async def connect(self) -> bool: - """连接到MCP服务器""" + def __init__(self, db=None): + self.db = db + self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info + + async def discover_tools( + self, + server_url: str, + connection_config: Dict[str, Any] = None + ) -> tuple[bool, List[Dict[str, Any]], str | None]: + """发现 MCP 服务器上的工具""" try: - from .client import MCPClient + from .client import SimpleMCPClient - if self._connected: - return True - - self._client = MCPClient(self.server_url, self.connection_config) - - if await self._client.connect(): - self._connected = True - # 更新可用工具列表 - await self._update_available_tools() - logger.info(f"MCP服务器连接成功: {self.server_url}") - return True - else: - logger.error(f"MCP服务器连接失败: {self.server_url}") - return False + client = SimpleMCPClient(server_url, connection_config) + async with client: + tools = await client.list_tools() + + # 缓存工具信息 + self._tool_cache[server_url] = { + "tools": tools, + "connection_config": connection_config, + "last_updated": time.time() + } + + logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}") + return True, tools, None + except Exception as e: - logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}") - self._connected = False - return False + error_msg = f"发现工具失败: {e}" + logger.error(error_msg) + return False, [], error_msg - async def _update_available_tools(self): - """更新可用工具列表""" + async def test_tool_connection( + self, + server_url: str, + connection_config: Dict[str, Any] = None + ) -> Dict[str, Any]: + """测试工具连接""" try: - if self._client and self._connected: - tools = await self._client.list_tools() - self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] - logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具") - except Exception as e: - logger.error(f"更新MCP工具列表失败: {e}") - - async def disconnect(self) -> bool: - """断开MCP服务器连接""" - try: - if self._client: - await self._client.disconnect() - self._client = None + from .client import SimpleMCPClient - self._connected = False - logger.info(f"MCP服务器连接已断开: {self.server_url}") - return True - - except Exception as e: - logger.error(f"断开MCP服务器连接失败: {e}") - return False - - def get_health_status(self) -> Dict[str, Any]: - """获取MCP服务健康状态""" - return { - "connected": self._connected, - "server_url": self.server_url, - "available_tools": self.available_tools, - "last_check": time.time() - } - - async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any: - """调用MCP工具""" - if not self._client or not self._connected: - raise Exception("MCP客户端未连接") - - try: - result = await self._client.call_tool(tool_name, arguments, timeout) - return result - except Exception as e: - logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}") - raise - - async def list_available_tools(self) -> List[Dict[str, Any]]: - """列出可用的MCP工具""" - try: - if not self._connected: - await self.connect() - - if self._client: - tools = await self._client.list_tools() - self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] - return tools - - return [] - - except Exception as e: - logger.error(f"获取MCP工具列表失败: {e}") - return [] - - def test_connection(self) -> Dict[str, Any]: - """测试MCP连接""" - try: - # 这里应该实现同步的连接测试 - # 为了简化,返回基本信息 - return { - "success": bool(self.server_url), - "server_url": self.server_url, - "connected": self._connected, - "available_tools_count": len(self.available_tools), - "message": "MCP配置有效" if self.server_url else "缺少服务器URL配置" - } + client = SimpleMCPClient(server_url, connection_config) + async with client: + tools = await client.list_tools() + + return { + "success": True, + "tools_count": len(tools), + "tools": [tool.get("name") for tool in tools], + "message": "连接成功" + } + except Exception as e: return { "success": False, - "error": str(e) + "error": str(e), + "message": "连接失败" } \ No newline at end of file diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index a1d2ecaa..2901b7ca 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -1,9 +1,8 @@ -"""MCP客户端 - Model Context Protocol客户端实现""" +"""MCP客户端 - 简化版本""" import asyncio import json import time -from typing import Dict, Any, List, Optional, Callable -from urllib.parse import urlparse +from typing import Dict, Any, List import aiohttp import websockets from websockets.exceptions import ConnectionClosed @@ -18,139 +17,156 @@ class MCPConnectionError(Exception): pass -class MCPProtocolError(Exception): - """MCP协议错误""" - pass - - -class MCPClient: - """MCP客户端 - 支持HTTP和WebSocket连接""" +class SimpleMCPClient: + """简化的 MCP 客户端""" def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): - """初始化MCP客户端 - - Args: - server_url: MCP服务器URL - connection_config: 连接配置 - """ self.server_url = server_url self.connection_config = connection_config or {} + self.timeout = self.connection_config.get("timeout", 30) - # 解析URL确定连接类型 - parsed_url = urlparse(server_url) - self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http" + # 确定连接类型 + self.is_websocket = server_url.startswith(("ws://", "wss://")) # 连接状态 - self._connected = False self._websocket = None self._session = None - - # 请求管理 self._request_id = 0 - self._pending_requests: Dict[str, asyncio.Future] = {} - - # 连接池配置 - self.max_connections = self.connection_config.get("max_connections", 10) - self.connection_timeout = self.connection_config.get("timeout", 30) - self.retry_attempts = self.connection_config.get("retry_attempts", 3) - self.retry_delay = self.connection_config.get("retry_delay", 1) - - # 健康检查 - self.health_check_interval = self.connection_config.get("health_check_interval", 60) - self._health_check_task = None - self._last_health_check = None - - # 事件回调 - self._on_connect_callbacks: List[Callable] = [] - self._on_disconnect_callbacks: List[Callable] = [] - self._on_error_callbacks: List[Callable] = [] + self._pending_requests = {} - async def connect(self) -> bool: - """连接到MCP服务器 - - Returns: - 连接是否成功 - """ + async def __aenter__(self): + """异步上下文管理器入口""" + await self.connect() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器出口""" + await self.disconnect() + + async def connect(self): + """建立连接""" try: - if self._connected: - return True - - logger.info(f"连接MCP服务器: {self.server_url}") - - if self.connection_type == "websocket": - success = await self._connect_websocket() + if self.is_websocket: + await self._connect_websocket() else: - success = await self._connect_http() - - if success: - self._connected = True - await self._start_health_check() - await self._notify_connect_callbacks() - logger.info(f"MCP服务器连接成功: {self.server_url}") - - return success - + await self._connect_http() except Exception as e: - logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}") - await self._notify_error_callbacks(e) - return False + logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}") + raise MCPConnectionError(f"连接失败: {e}") - async def disconnect(self) -> bool: - """断开MCP服务器连接 - - Returns: - 断开是否成功 - """ + async def disconnect(self): + """断开连接""" try: - if not self._connected: - return True - - logger.info(f"断开MCP服务器连接: {self.server_url}") - - # 停止健康检查 - await self._stop_health_check() - - # 取消所有待处理的请求 - for future in self._pending_requests.values(): - if not future.done(): - future.cancel() - self._pending_requests.clear() - - # 断开连接 - if self.connection_type == "websocket" and self._websocket: + if self._websocket: await self._websocket.close() self._websocket = None - elif self._session: + + if self._session: await self._session.close() self._session = None - - self._connected = False - await self._notify_disconnect_callbacks() - logger.info(f"MCP服务器连接已断开: {self.server_url}") - - return True - + except Exception as e: - logger.error(f"断开MCP服务器连接失败: {e}") - return False + logger.error(f"断开连接失败: {e}") - def _build_auth_headers(self) -> Dict[str, str]: - """构建认证头""" - headers = {} - auth_type = self.connection_config.get("auth_type", "none") + async def _connect_websocket(self): + """WebSocket 连接""" + headers = self._build_headers() + + self._websocket = await websockets.connect( + self.server_url, + extra_headers=headers, + timeout=self.timeout + ) + + # 启动消息处理 + asyncio.create_task(self._handle_websocket_messages()) + + # 发送初始化消息 + await self._send_initialize() + + async def _connect_http(self): + """HTTP 连接""" + headers = self._build_headers() + timeout = aiohttp.ClientTimeout(total=self.timeout) + + self._session = aiohttp.ClientSession( + headers=headers, + timeout=timeout + ) + + # 对于 ModelScope MCP 服务,需要先发送初始化请求 + if "modelscope.net" in self.server_url: + await self._initialize_modelscope_session() + + async def _initialize_modelscope_session(self): + """初始化 ModelScope MCP 会话""" + init_request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "MemoryBear", + "version": "1.0.0" + } + } + } + + try: + async with self._session.post( + self.server_url, + json=init_request + ) as response: + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") + + init_response = await response.json() + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + # 获取 session ID + session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") + if session_id: + self._session.headers.update({"Mcp-Session-Id": session_id}) + + # 发送 initialized 通知 + initialized_notification = { + "jsonrpc": "2.0", + "method": "notifications/initialized" + } + + async with self._session.post( + self.server_url, + json=initialized_notification + ) as notif_response: + pass + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"初始化连接失败: {e}") + + def _build_headers(self) -> Dict[str, str]: + """构建请求头""" + headers = { + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + } + + # 添加认证头 auth_config = self.connection_config.get("auth_config", {}) + auth_type = self.connection_config.get("auth_type", "none") - if auth_type == "api_key": - api_key = auth_config.get("api_key") - key_name = auth_config.get("key_name", "X-API-Key") - if api_key: - headers[key_name] = api_key - - elif auth_type == "bearer_token": + if auth_type == "bearer_token": token = auth_config.get("token") if token: headers["Authorization"] = f"Bearer {token}" - + elif auth_type == "api_key": + key = auth_config.get("api_key") + header_name = auth_config.get("key_name", "X-API-Key") + if key: + headers[header_name] = key elif auth_type == "basic_auth": username = auth_config.get("username") password = auth_config.get("password") @@ -161,160 +177,63 @@ class MCPClient: return headers - async def _connect_websocket(self) -> bool: - """建立WebSocket连接""" - try: - # WebSocket连接配置 - extra_headers = self.connection_config.get("headers", {}) - auth_headers = self._build_auth_headers() - extra_headers.update(auth_headers) - - self._websocket = await websockets.connect( - self.server_url, - extra_headers=extra_headers, - timeout=self.connection_timeout - ) - - # 启动消息监听 - asyncio.create_task(self._websocket_message_handler()) - - # 发送初始化消息 - init_message = { - "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "initialize", - "params": { - "protocolVersion": "2024-11-05", - "capabilities": { - "tools": {} - }, - "clientInfo": { - "name": "ToolManagementSystem", - "version": "1.0.0" - } + async def _send_initialize(self): + """发送初始化消息""" + init_message = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": { + "name": "MemoryBear", + "version": "1.0.0" } } - - await self._websocket.send(json.dumps(init_message)) - - # 等待初始化响应 - response = await asyncio.wait_for( - self._websocket.recv(), - timeout=self.connection_timeout - ) - - init_response = json.loads(response) - if init_response.get("error", None) is not None: - raise MCPProtocolError(f"初始化失败: {init_response['error']}") - - return True - - except Exception as e: - logger.error(f"WebSocket连接失败: {e}") - return False + } + + await self._websocket.send(json.dumps(init_message)) + + # 等待初始化响应 + response = await asyncio.wait_for( + self._websocket.recv(), + timeout=self.timeout + ) + + init_response = json.loads(response) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") - async def _connect_http(self) -> bool: - """建立HTTP连接""" - try: - # HTTP会话配置 - timeout = aiohttp.ClientTimeout(total=self.connection_timeout) - headers = self.connection_config.get("headers", {}) - auth_headers = self._build_auth_headers() - headers.update(auth_headers) - - self._session = aiohttp.ClientSession( - timeout=timeout, - headers=headers - ) - - # 测试连接 - test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health" - - async with self._session.get(test_url) as response: - if response.status == 200: - return True - else: - # 尝试根路径 - async with self._session.get(self.server_url) as root_response: - return root_response.status < 400 - - except Exception as e: - logger.error(f"HTTP连接失败: {e}") - if self._session: - await self._session.close() - self._session = None - return False - - async def _websocket_message_handler(self): - """WebSocket消息处理器""" + async def _handle_websocket_messages(self): + """处理 WebSocket 消息""" try: while self._websocket and not self._websocket.closed: try: message = await self._websocket.recv() - await self._handle_message(json.loads(message)) + data = json.loads(message) + + # 处理响应 + if "id" in data: + request_id = str(data["id"]) + if request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(data) + except ConnectionClosed: break - except json.JSONDecodeError as e: - logger.error(f"解析WebSocket消息失败: {e}") except Exception as e: logger.error(f"处理WebSocket消息失败: {e}") except Exception as e: - logger.error(f"WebSocket消息处理器异常: {e}") - finally: - self._connected = False - await self._notify_disconnect_callbacks() + logger.error(f"WebSocket消息处理异常: {e}") - async def _handle_message(self, message: Dict[str, Any]): - """处理收到的消息""" - try: - # 检查是否是响应消息 - if "id" in message: - request_id = str(message["id"]) - if request_id in self._pending_requests: - future = self._pending_requests.pop(request_id) - if not future.done(): - future.set_result(message) - - # 处理通知消息 - elif "method" in message: - await self._handle_notification(message) - - except Exception as e: - logger.error(f"处理消息失败: {e}") - - @staticmethod - async def _handle_notification(message: Dict[str, Any]): - """处理通知消息""" - method = message.get("method") - params = message.get("params", {}) - - logger.debug(f"收到MCP通知: {method}, 参数: {params}") - - # 这里可以根据需要处理特定的通知 - # 例如:工具列表更新、服务器状态变化等 - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]: - """调用MCP工具 - - Args: - tool_name: 工具名称 - arguments: 工具参数 - timeout: 超时时间(秒) - - Returns: - 工具执行结果 - - Raises: - MCPConnectionError: 连接错误 - MCPProtocolError: 协议错误 - """ - if not self._connected: - raise MCPConnectionError("MCP客户端未连接") - + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """调用工具""" request_data = { "jsonrpc": "2.0", - "id": self._get_next_request_id(), + "id": self._get_request_id(), "method": "tools/call", "params": { "name": tool_name, @@ -322,343 +241,69 @@ class MCPClient: } } - try: - response = await self._send_request(request_data, timeout) - - if response.get("error", None) is not None: - error = response["error"] - raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") - - return response.get("result", {}) - - except asyncio.TimeoutError: - raise MCPProtocolError(f"工具调用超时: {tool_name}") + if self.is_websocket: + response = await self._send_websocket_request(request_data) + else: + response = await self._send_http_request(request_data) + + if "error" in response: + error = response["error"] + raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response.get("result", {}) - async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]: - """获取可用工具列表 - - Args: - timeout: 超时时间(秒) - - Returns: - 工具列表 - - Raises: - MCPConnectionError: 连接错误 - MCPProtocolError: 协议错误 - """ - if not self._connected: - raise MCPConnectionError("MCP客户端未连接") - + async def list_tools(self) -> List[Dict[str, Any]]: + """获取工具列表""" request_data = { "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "tools/list" + "id": self._get_request_id(), + "method": "tools/list", + "params": {} } - try: - response = await self._send_request(request_data, timeout) - - if response.get("error", None) is not None: - error = response["error"] - raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}") - - result = response.get("result", {}) - return result.get("tools", []) - - except asyncio.TimeoutError: - raise MCPProtocolError("获取工具列表超时") - - async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: - """发送请求并等待响应 - - Args: - request_data: 请求数据 - timeout: 超时时间(秒) - - Returns: - 响应数据 - """ - if self.connection_type == "websocket": - request_id = str(request_data["id"]) - return await self._send_websocket_request(request_data, request_id, timeout) + if self.is_websocket: + response = await self._send_websocket_request(request_data) else: - return await self._send_http_request(request_data, timeout) - - async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]: - """发送WebSocket请求""" - if not self._websocket or self._websocket.closed: - raise MCPConnectionError("WebSocket连接已断开") + response = await self._send_http_request(request_data) - # 创建Future等待响应 + if "error" in response: + error = response["error"] + raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}") + + result = response.get("result", {}) + return result.get("tools", []) + + async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: + """发送WebSocket请求""" + request_id = str(request_data["id"]) future = asyncio.Future() self._pending_requests[request_id] = future try: - # 发送请求 await self._websocket.send(json.dumps(request_data)) - - # 等待响应 - response = await asyncio.wait_for(future, timeout=timeout) + response = await asyncio.wait_for(future, timeout=self.timeout) return response - except asyncio.TimeoutError: - await self._pending_requests.pop(request_id, None) + self._pending_requests.pop(request_id, None) raise - except Exception as e: - await self._pending_requests.pop(request_id, None) - raise MCPConnectionError(f"发送WebSocket请求失败: {e}") - async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: """发送HTTP请求""" - if not self._session: - raise MCPConnectionError("HTTP会话未建立") - try: - url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp" - async with self._session.post( - url, - json=request_data, - timeout=aiohttp.ClientTimeout(total=timeout) + self.server_url, + json=request_data ) as response: - if response.status == 200: - return await response.json() - else: - async with self._session.post( - self.server_url, - json=request_data, - timeout=aiohttp.ClientTimeout(total=timeout) - ) as root_response: - if root_response.status != 200: - error_text = await root_response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") + + return await response.json() except aiohttp.ClientError as e: raise MCPConnectionError(f"HTTP请求失败: {e}") - async def health_check(self) -> Dict[str, Any]: - """执行健康检查 - - Returns: - 健康状态信息 - """ - try: - if not self._connected: - return { - "healthy": False, - "error": "未连接", - "timestamp": time.time() - } - - # 发送ping请求 - request_data = { - "jsonrpc": "2.0", - "id": self._get_next_request_id(), - "method": "ping" - } - - start_time = time.time() - response = await self._send_request(request_data, timeout=5) - response_time = round((time.time() - start_time) * 1000) - - self._last_health_check = round(time.time() * 1000) - - return { - "healthy": True, - "response_time": response_time, - "timestamp": self._last_health_check, - "server_info": response.get("result", {}) - } - - except Exception as e: - return { - "healthy": False, - "error": str(e), - "timestamp": time.time() - } - - async def _start_health_check(self): - """启动健康检查任务""" - if self.health_check_interval > 0: - self._health_check_task = asyncio.create_task(self._health_check_loop()) - - async def _stop_health_check(self): - """停止健康检查任务""" - if self._health_check_task: - self._health_check_task.cancel() - try: - await self._health_check_task - except asyncio.CancelledError: - pass - self._health_check_task = None - - async def _health_check_loop(self): - """健康检查循环""" - try: - while self._connected: - await asyncio.sleep(self.health_check_interval) - - if self._connected: - health_status = await self.health_check() - if not health_status["healthy"]: - logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}") - # 可以在这里实现重连逻辑 - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"健康检查循环异常: {e}") - - def _get_next_request_id(self) -> str: - """获取下一个请求ID""" + def _get_request_id(self) -> str: + """获取请求ID""" self._request_id += 1 - return f"req_{self._request_id}_{int(time.time() * 1000)}" - - # 事件回调管理 - def on_connect(self, callback: Callable): - """注册连接回调""" - self._on_connect_callbacks.append(callback) - - def on_disconnect(self, callback: Callable): - """注册断开连接回调""" - self._on_disconnect_callbacks.append(callback) - - def on_error(self, callback: Callable): - """注册错误回调""" - self._on_error_callbacks.append(callback) - - async def _notify_connect_callbacks(self): - """通知连接回调""" - for callback in self._on_connect_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback() - else: - callback() - except Exception as e: - logger.error(f"连接回调执行失败: {e}") - - async def _notify_disconnect_callbacks(self): - """通知断开连接回调""" - for callback in self._on_disconnect_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback() - else: - callback() - except Exception as e: - logger.error(f"断开连接回调执行失败: {e}") - - async def _notify_error_callbacks(self, error: Exception): - """通知错误回调""" - for callback in self._on_error_callbacks: - try: - if asyncio.iscoroutinefunction(callback): - await callback(error) - else: - callback(error) - except Exception as e: - logger.error(f"错误回调执行失败: {e}") - - @property - def is_connected(self) -> bool: - """检查是否已连接""" - return self._connected - - @property - def last_health_check(self) -> Optional[float]: - """获取最后一次健康检查时间""" - return self._last_health_check - - def get_connection_info(self) -> Dict[str, Any]: - """获取连接信息""" - return { - "server_url": self.server_url, - "connection_type": self.connection_type, - "connected": self._connected, - "last_health_check": self._last_health_check, - "pending_requests": len(self._pending_requests), - "config": self.connection_config - } - - async def __aenter__(self): - """异步上下文管理器入口""" - await self.connect() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" - await self.disconnect() - - -class MCPConnectionPool: - """MCP连接池 - 管理多个MCP客户端连接""" - - def __init__(self, max_connections: int = 10): - """初始化连接池 - - Args: - max_connections: 最大连接数 - """ - self.max_connections = max_connections - self._clients: Dict[str, MCPClient] = {} - self._lock = asyncio.Lock() - - async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient: - """获取或创建MCP客户端 - - Args: - server_url: 服务器URL - connection_config: 连接配置 - - Returns: - MCP客户端实例 - """ - async with self._lock: - if server_url in self._clients: - client = self._clients[server_url] - if client.is_connected: - return client - else: - # 尝试重连 - if await client.connect(): - return client - else: - # 移除失效的客户端 - del self._clients[server_url] - - # 检查连接数限制 - if len(self._clients) >= self.max_connections: - # 移除最旧的连接 - oldest_url = next(iter(self._clients)) - await self._clients[oldest_url].disconnect() - del self._clients[oldest_url] - - # 创建新客户端 - client = MCPClient(server_url, connection_config) - if await client.connect(): - self._clients[server_url] = client - return client - else: - raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}") - - async def disconnect_all(self): - """断开所有连接""" - async with self._lock: - for client in self._clients.values(): - await client.disconnect() - self._clients.clear() - - def get_pool_status(self) -> Dict[str, Any]: - """获取连接池状态""" - return { - "total_connections": len(self._clients), - "max_connections": self.max_connections, - "connections": { - url: client.get_connection_info() - for url, client in self._clients.items() - } - } \ No newline at end of file + return f"req_{self._request_id}_{int(time.time() * 1000)}" \ No newline at end of file diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py index 01312444..2144999a 100644 --- a/api/app/core/tools/mcp/service_manager.py +++ b/api/app/core/tools/mcp/service_manager.py @@ -1,6 +1,4 @@ -"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控""" -import asyncio -import time +"""MCP服务管理器 - 简化版本""" import uuid from typing import Dict, Any, List, Optional, Tuple from datetime import datetime @@ -8,136 +6,53 @@ from sqlalchemy.orm import Session from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.core.logging_config import get_business_logger -from app.core.tools.mcp.client import MCPClient, MCPConnectionPool +from app.core.tools.mcp.base import MCPToolManager logger = get_business_logger() class MCPServiceManager: - """MCP服务管理器 - 管理MCP服务的生命周期""" + """MCP服务管理器 - 简化版本,主要用于工具创建""" def __init__(self, db: Session = None): - """初始化MCP服务管理器 - - Args: - db: 数据库会话(可选) - """ self.db = db - if db: - self.connection_pool = MCPConnectionPool(max_connections=20) - else: - self.connection_pool = None - - # 服务状态管理 - self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info - self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task - - # 配置 - self.health_check_interval = 60 # 健康检查间隔(秒) - self.max_retry_attempts = 3 # 最大重试次数 - self.retry_delay = 5 # 重试延迟(秒) - - # 状态 - self._running = False - self._manager_task = None + self.tool_manager = MCPToolManager(db) if db else None - async def start(self): - """启动服务管理器""" - if self._running: - return - - self._running = True - logger.info("MCP服务管理器启动") - - # 加载现有服务 - await self._load_existing_services() - - # 启动管理任务 - self._manager_task = asyncio.create_task(self._management_loop()) - - async def stop(self): - """停止服务管理器""" - if not self._running: - return - - self._running = False - logger.info("MCP服务管理器停止") - - # 停止管理任务 - if self._manager_task: - self._manager_task.cancel() - try: - await self._manager_task - except asyncio.CancelledError: - pass - - # 停止所有监控任务 - for task in self._monitoring_tasks.values(): - task.cancel() - - if self._monitoring_tasks: - await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True) - - self._monitoring_tasks.clear() - - # 断开所有连接 - await self.connection_pool.disconnect_all() - - async def register_service( + async def create_mcp_tool( self, server_url: str, connection_config: Dict[str, Any], tenant_id: uuid.UUID, + tool_name: str, service_name: str = None ) -> Tuple[bool, str, Optional[str]]: - """注册MCP服务 + """创建单个MCP工具 Args: server_url: 服务器URL connection_config: 连接配置 tenant_id: 租户ID - service_name: 服务名称(可选) + tool_name: 具体工具名称 + service_name: 服务名称 Returns: - (是否成功, 服务ID或错误信息, 错误详情) + (是否成功, 工具ID或错误信息, 错误详情) """ try: - # 检查服务是否已存在 - existing_service = self.db.query(MCPToolConfig).filter( - MCPToolConfig.server_url == server_url - ).first() - - if existing_service: - return False, "服务已存在", f"URL {server_url} 已被注册" - - # 测试连接 - try: - client = MCPClient(server_url, connection_config) - if not await client.connect(): - return False, "连接测试失败", "无法连接到MCP服务器" - - # 获取可用工具 - available_tools = await client.list_tools() - tool_names = [tool.get("name") for tool in available_tools if tool.get("name")] - - await client.disconnect() - - except Exception as e: - return False, "连接测试失败", str(e) + if not service_name: + service_name = f"mcp_{tool_name}" # 创建工具配置 - if not service_name: - service_name = f"mcp_service_{server_url.split('/')[-1]}" - tool_config = ToolConfig( name=service_name, - description=f"MCP服务 - {server_url}", + description=f"MCP工具: {tool_name}", tool_type=ToolType.MCP.value, tenant_id=tenant_id, - version="1.0.0", + status=ToolStatus.AVAILABLE.value, config_data={ "server_url": server_url, - "connection_config": connection_config + "connection_config": connection_config, + "tool_name": tool_name } ) @@ -149,460 +64,22 @@ class MCPServiceManager: id=tool_config.id, server_url=server_url, connection_config=connection_config, - available_tools=tool_names, - health_status="healthy", + available_tools=[tool_name], + health_status="unknown", last_health_check=datetime.now() ) self.db.add(mcp_config) self.db.commit() - service_id = str(tool_config.id) - - # 添加到内存管理 - self._services[service_id] = { - "id": service_id, - "server_url": server_url, - "connection_config": connection_config, - "tenant_id": tenant_id, - "available_tools": tool_names, - "status": "healthy", - "last_health_check": time.time(), - "retry_count": 0, - "created_at": time.time() - } - - # 启动监控 - await self._start_service_monitoring(service_id) - - logger.info(f"MCP服务注册成功: {service_id} ({server_url})") - return True, service_id, None + logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})") + return True, str(tool_config.id), None except Exception as e: self.db.rollback() - logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}") - return False, "注册失败", str(e) + logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}") + return False, "创建失败", str(e) - async def unregister_service(self, service_id: str) -> Tuple[bool, str]: - """注销MCP服务 - - Args: - service_id: 服务ID - - Returns: - (是否成功, 错误信息) - """ - try: - # 从数据库删除 - tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) - if not tool_config: - return False, "服务不存在" - - self.db.delete(tool_config) - self.db.commit() - - # 停止监控 - await self._stop_service_monitoring(service_id) - - # 从内存移除 - if service_id in self._services: - del self._services[service_id] - - logger.info(f"MCP服务注销成功: {service_id}") - return True, "" - - except Exception as e: - self.db.rollback() - logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}") - return False, str(e) - - async def update_service( - self, - service_id: str, - connection_config: Dict[str, Any] = None, - enabled: bool = None - ) -> Tuple[bool, str]: - """更新MCP服务配置 - - Args: - service_id: 服务ID - connection_config: 新的连接配置 - enabled: 是否启用 - - Returns: - (是否成功, 错误信息) - """ - try: - # 更新数据库 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if not mcp_config: - return False, "服务不存在" - - tool_config = mcp_config.base_config - - if connection_config is not None: - mcp_config.connection_config = connection_config - tool_config.config_data["connection_config"] = connection_config - - if enabled is not None: - tool_config.is_enabled = enabled - - self.db.commit() - - # 更新内存状态 - if service_id in self._services: - if connection_config is not None: - self._services[service_id]["connection_config"] = connection_config - - # 如果配置有变化,重启监控 - if connection_config is not None: - await self._restart_service_monitoring(service_id) - - logger.info(f"MCP服务更新成功: {service_id}") - return True, "" - - except Exception as e: - self.db.rollback() - logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}") - return False, str(e) - - async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]: - """获取服务状态 - - Args: - service_id: 服务ID - - Returns: - 服务状态信息 - """ - if service_id not in self._services: - return None - - service_info = self._services[service_id].copy() - - # 添加实时健康检查 - try: - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - health_status = await client.health_check() - service_info["real_time_health"] = health_status - - except Exception as e: - service_info["real_time_health"] = { - "healthy": False, - "error": str(e), - "timestamp": time.time() - } - - return service_info - - async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]: - """列出所有服务 - - Args: - tenant_id: 租户ID过滤 - - Returns: - 服务列表 - """ - services = [] - - for service_id, service_info in self._services.items(): - if tenant_id and service_info["tenant_id"] != tenant_id: - continue - - services.append(service_info.copy()) - - return services - - async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]: - """获取服务的可用工具 - - Args: - service_id: 服务ID - - Returns: - 工具列表 - """ - if service_id not in self._services: - return [] - - service_info = self._services[service_id] - - try: - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - tools = await client.list_tools() - - # 更新缓存的工具列表 - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - service_info["available_tools"] = tool_names - - # 更新数据库 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if mcp_config: - mcp_config.available_tools = tool_names - self.db.commit() - - return tools - - except Exception as e: - logger.error(f"获取服务工具失败: {service_id}, 错误: {e}") - return [] - - async def call_service_tool( - self, - service_id: str, - tool_name: str, - arguments: Dict[str, Any], - timeout: int = 30 - ) -> Dict[str, Any]: - """调用服务工具 - - Args: - service_id: 服务ID - tool_name: 工具名称 - arguments: 工具参数 - timeout: 超时时间 - - Returns: - 执行结果 - """ - if service_id not in self._services: - raise ValueError(f"服务不存在: {service_id}") - - service_info = self._services[service_id] - - try: - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - result = await client.call_tool(tool_name, arguments, timeout) - - # 更新服务状态为健康 - service_info["status"] = "healthy" - service_info["last_health_check"] = time.time() - service_info["retry_count"] = 0 - - return result - - except Exception as e: - # 更新服务状态为错误 - service_info["status"] = "error" - service_info["last_error"] = str(e) - service_info["retry_count"] += 1 - - logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}") - raise - - async def _load_existing_services(self): - """加载现有服务""" - try: - mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter( - ToolConfig.status == ToolStatus.AVAILABLE.value, - ToolConfig.tool_type == ToolType.MCP.value - ).all() - - for mcp_config in mcp_configs: - tool_config = mcp_config.base_config - service_id = str(mcp_config.id) - - self._services[service_id] = { - "id": service_id, - "server_url": mcp_config.server_url, - "connection_config": mcp_config.connection_config or {}, - "tenant_id": tool_config.tenant_id, - "available_tools": mcp_config.available_tools or [], - "status": mcp_config.health_status or "unknown", - "last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0, - "retry_count": 0, - "created_at": tool_config.created_at.timestamp() - } - - # 启动监控 - await self._start_service_monitoring(service_id) - - logger.info(f"加载了 {len(mcp_configs)} 个MCP服务") - - except Exception as e: - logger.error(f"加载现有服务失败: {e}") - - async def _start_service_monitoring(self, service_id: str): - """启动服务监控""" - if service_id in self._monitoring_tasks: - return - - task = asyncio.create_task(self._monitor_service(service_id)) - self._monitoring_tasks[service_id] = task - - async def _stop_service_monitoring(self, service_id: str): - """停止服务监控""" - if service_id in self._monitoring_tasks: - task = self._monitoring_tasks.pop(service_id) - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - async def _restart_service_monitoring(self, service_id: str): - """重启服务监控""" - await self._stop_service_monitoring(service_id) - await self._start_service_monitoring(service_id) - - async def _monitor_service(self, service_id: str): - """监控单个服务""" - try: - while self._running and service_id in self._services: - service_info = self._services[service_id] - - try: - # 执行健康检查 - client = await self.connection_pool.get_client( - service_info["server_url"], - service_info["connection_config"] - ) - - health_status = await client.health_check() - - if health_status["healthy"]: - # 服务健康 - service_info["status"] = "healthy" - service_info["retry_count"] = 0 - - # 更新工具列表 - try: - tools = await client.list_tools() - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - service_info["available_tools"] = tool_names - except Exception as e: - logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}") - - else: - # 服务不健康 - service_info["status"] = "unhealthy" - service_info["last_error"] = health_status.get("error", "健康检查失败") - service_info["retry_count"] += 1 - - service_info["last_health_check"] = time.time() - - # 更新数据库 - await self._update_service_health_in_db(service_id, health_status) - - except Exception as e: - # 监控异常 - service_info["status"] = "error" - service_info["last_error"] = str(e) - service_info["retry_count"] += 1 - service_info["last_health_check"] = time.time() - - logger.error(f"服务监控异常: {service_id}, 错误: {e}") - - # 如果重试次数过多,暂停监控 - if service_info["retry_count"] >= self.max_retry_attempts: - logger.warning(f"服务 {service_id} 重试次数过多,暂停监控") - await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间 - service_info["retry_count"] = 0 # 重置重试计数 - - # 等待下次检查 - await asyncio.sleep(self.health_check_interval) - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"服务监控任务异常: {service_id}, 错误: {e}") - - async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]): - """更新数据库中的服务健康状态""" - try: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == uuid.UUID(service_id) - ).first() - - if mcp_config: - mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy" - mcp_config.last_health_check = datetime.now() - - if not health_status["healthy"]: - mcp_config.error_message = health_status.get("error", "") - else: - mcp_config.error_message = None - - self.db.commit() - - except Exception as e: - logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}") - self.db.rollback() - - async def _management_loop(self): - """管理循环 - 处理服务清理等任务""" - try: - while self._running: - # 清理失效的服务 - await self._cleanup_failed_services() - - # 等待下次循环 - await asyncio.sleep(300) # 5分钟 - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"管理循环异常: {e}") - - async def _cleanup_failed_services(self): - """清理长期失效的服务""" - try: - current_time = time.time() - cleanup_threshold = 24 * 60 * 60 # 24小时 - - services_to_cleanup = [] - - for service_id, service_info in self._services.items(): - # 检查服务是否长期失效 - if (service_info["status"] in ["error", "unhealthy"] and - current_time - service_info["last_health_check"] > cleanup_threshold): - - services_to_cleanup.append(service_id) - - for service_id in services_to_cleanup: - logger.warning(f"清理长期失效的服务: {service_id}") - - # 停止监控但不删除数据库记录 - await self._stop_service_monitoring(service_id) - - # 标记为禁用 - tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) - if tool_config: - tool_config.is_enabled = False - self.db.commit() - - # 从内存移除 - del self._services[service_id] - - except Exception as e: - logger.error(f"清理失效服务失败: {e}") - - def get_manager_status(self) -> Dict[str, Any]: - """获取管理器状态""" - return { - "running": self._running, - "total_services": len(self._services), - "healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]), - "unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]), - "monitoring_tasks": len(self._monitoring_tasks), - "connection_pool_status": self.connection_pool.get_pool_status() - } \ No newline at end of file + def get_tool_manager(self) -> MCPToolManager: + """获取工具管理器实例""" + return self.tool_manager \ No newline at end of file diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 537eac8d..ec046013 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -77,7 +77,7 @@ class AppChatService: tool_service = ToolService(self.db) # 从配置中获取启用的工具 - if hasattr(config, 'tools') and config.tools: + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 @@ -109,20 +109,21 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - # web_tools = config.tools - # web_search_choice = web_tools.get("web_search", {}) - # web_search_enable = web_search_choice.get("enabled", False) - # if web_search == True: - # if web_search_enable == True: - # search_tool = create_web_search_tool({}) - # tools.append(search_tool) - # - # logger.debug( - # "已添加网络搜索工具", - # extra={ - # "tool_count": len(tools) - # } - # ) + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) # 获取模型参数 model_parameters = config.model_parameters @@ -226,7 +227,7 @@ class AppChatService: # 获取工具服务 tool_service = ToolService(self.db) - if hasattr(config, 'tools') and config.tools: + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): for tool_config in config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 @@ -258,20 +259,21 @@ class AppChatService: memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - # web_tools = config.tools - # web_search_choice = web_tools.get("web_search", {}) - # web_search_enable = web_search_choice.get("enabled", False) - # if web_search == True: - # if web_search_enable == True: - # search_tool = create_web_search_tool({}) - # tools.append(search_tool) - # - # logger.debug( - # "已添加网络搜索工具", - # extra={ - # "tool_count": len(tools) - # } - # ) + if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): + web_tools = config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) # 获取模型参数 model_parameters = config.model_parameters diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 9a1dbd32..cdbb213e 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -297,19 +297,35 @@ class DraftRunService: tool_service = ToolService(self.db) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_config in agent_config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), - ToolRepository.get_tenant_id_by_workspace_id( - self.db, str(workspace_id))) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) + if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): + if hasattr(agent_config, 'tools') and agent_config.tools: + for tool_config in agent_config.tools: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), + ToolRepository.get_tenant_id_by_workspace_id( + self.db, str(workspace_id))) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): + web_tools = agent_config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) # 添加知识库检索工具 if agent_config.knowledge_retrieval: @@ -507,7 +523,7 @@ class DraftRunService: tool_service = ToolService(self.db) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools: + if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): for tool_config in agent_config.tools: if tool_config.get("enabled", False): # 根据工具名称查找工具实例 @@ -520,6 +536,22 @@ class DraftRunService: # 转换为LangChain工具 langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) + elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): + web_tools = agent_config.tools + web_search_choice = web_tools.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search == True: + if web_search_enable == True: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + # 添加知识库检索工具 if agent_config.knowledge_retrieval: diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index ab5128fd..b3258f88 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -8,7 +8,7 @@ from datetime import datetime from sqlalchemy.orm import Session -from app.core.tools.mcp import MCPClient +from app.core.tools.mcp import MCPToolManager, SimpleMCPClient from app.repositories.tool_repository import ( ToolRepository, BuiltinToolRepository, CustomToolRepository, MCPToolRepository, ToolExecutionRepository @@ -42,6 +42,9 @@ class ToolService: def __init__(self, db: Session): self.db = db self._tool_cache: Dict[str, BaseTool] = {} + + # MCP管理器 + self.mcp_tool_manager = MCPToolManager(db) # 初始化仓储 self.tool_repo = ToolRepository() @@ -675,23 +678,85 @@ class ToolService: return [] async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: - """获取MCP工具的方法""" + """获取MCP工具的方法和参数""" mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) if not mcp_config: return [] available_tools = mcp_config.available_tools or [] if not available_tools: - return [] + # 如果没有工具列表,尝试同步 + try: + success, tools, _ = await self.mcp_tool_manager.discover_tools( + mcp_config.server_url, mcp_config.connection_config or {} + ) + if success: + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + mcp_config.available_tools = tool_names + self.db.commit() + available_tools = tool_names + except Exception as e: + logger.error(f"同步MCP工具列表失败: {e}") + return [] methods = [] - for tool_name in available_tools: - methods.append({ - "method_id": tool_name, - "name": tool_name, - "description": f"MCP工具: {tool_name}", - "parameters": [] # MCP工具参数需要动态获取 - }) + + # 获取工具详细信息 + try: + success, tools, _ = await self.mcp_tool_manager.discover_tools( + mcp_config.server_url, mcp_config.connection_config or {} + ) + + if success: + tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")} + + for tool_name in available_tools: + tool_info = tools_dict.get(tool_name, {}) + + # 解析工具参数 + parameters = [] + input_schema = tool_info.get("inputSchema", {}) + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + for param_name, param_def in properties.items(): + parameters.append({ + "name": param_name, + "type": param_def.get("type", "string"), + "description": param_def.get("description", ""), + "required": param_name in required_fields, + "default": param_def.get("default"), + "enum": param_def.get("enum"), + "minimum": param_def.get("minimum"), + "maximum": param_def.get("maximum") + }) + + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": tool_info.get("description", f"MCP工具: {tool_name}"), + "parameters": parameters + }) + else: + # 如果无法获取详细信息,返回基本信息 + for tool_name in available_tools: + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": f"MCP工具: {tool_name}", + "parameters": [] + }) + + except Exception as e: + logger.error(f"获取MCP工具详细信息失败: {e}") + # 返回基本信息 + for tool_name in available_tools: + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": f"MCP工具: {tool_name}", + "parameters": [] + }) return methods @@ -812,10 +877,14 @@ class ToolService: if not mcp_config: return None + # 从配置中获取特定工具名称 + tool_name = config.config_data.get("tool_name") + tool_config = { "server_url": mcp_config.server_url, "connection_config": mcp_config.connection_config or {}, - "available_tools": mcp_config.available_tools or [] + "available_tools": mcp_config.available_tools or [], + "tool_name": tool_name # 指定具体工具 } return MCPTool(str(config.id), tool_config) @@ -1071,71 +1140,59 @@ class ToolService: return {} async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]: - """测试MCP连接""" + """测试MCP连接并自动同步工具列表""" try: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == config.id - ).first() - + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) if not mcp_config: return {"success": False, "message": "MCP配置不存在"} - client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) - - if await client.connect(): - try: - # tools = await client.list_tools() - await client.disconnect() - - # 更新连接状态 + # 使用集成的MCP管理器测试连接 + test_result = await self.mcp_tool_manager.test_tool_connection( + mcp_config.server_url, mcp_config.connection_config or {} + ) + + if test_result["success"]: + # 连接成功,自动同步工具列表 + success, tools, error = await self.mcp_tool_manager.discover_tools( + mcp_config.server_url, mcp_config.connection_config or {} + ) + + if success: + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + + # 更新数据库 + mcp_config.available_tools = tool_names mcp_config.last_health_check = datetime.now() mcp_config.health_status = "healthy" mcp_config.error_message = None - - # 更新工具状态 - self._update_tool_status(config) + config.status = ToolStatus.AVAILABLE.value + self.db.commit() - + return { "success": True, - "message": "MCP连接成功", - # "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} - "details": {"server_url": mcp_config.server_url} + "message": "MCP连接成功并同步工具列表", + "details": { + "server_url": mcp_config.server_url, + "tools_count": len(tool_names), + "tools": tool_names + } } - except Exception as e: - await client.disconnect() - - # 更新错误状态 - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = str(e) - self._update_tool_status(config) - self.db.commit() - - return {"success": False, "message": f"MCP功能测试失败: {str(e)}"} + else: + return {"success": False, "message": f"同步工具失败: {error}"} else: - # 更新连接失败状态 + # 更新错误状态 mcp_config.last_health_check = datetime.now() mcp_config.health_status = "error" - mcp_config.error_message = "连接失败" - self._update_tool_status(config) + mcp_config.error_message = test_result.get("error", "连接失败") + config.status = ToolStatus.ERROR.value self.db.commit() - - return {"success": False, "message": "MCP连接失败"} - + + return test_result + except Exception as e: - # 更新异常状态 - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == config.id - ).first() - if mcp_config: - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = str(e) - self._update_tool_status(config) - self.db.commit() - - return {"success": False, "message": f"MCP测试异常: {str(e)}"} + logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}") + return {"success": False, "message": f"测试失败: {str(e)}"} @staticmethod async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]: @@ -1190,57 +1247,44 @@ class ToolService: # 创建MCP客户端 connection_config = mcp_config.connection_config or {} + client = SimpleMCPClient(mcp_config.server_url, connection_config) - client = MCPClient(mcp_config.server_url, connection_config) - - if await client.connect(): - try: - # 获取工具列表 - tools = await client.list_tools() - tool_names = [tool.get("name") for tool in tools if tool.get("name")] - - # 更新数据库 - mcp_config.available_tools = tool_names - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "healthy" - mcp_config.error_message = None - - # 更新工具状态 - config.status = ToolStatus.AVAILABLE.value - - self.db.commit() - - await client.disconnect() - - return { - "success": True, - "message": "工具列表同步成功", - "tools_count": len(tool_names), - "tools": tool_names - } - - except Exception as e: - await client.disconnect() - - # 更新错误状态 + async with client: + # 获取工具列表 + tools = await client.list_tools() + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + + # 更新数据库 + mcp_config.available_tools = tool_names + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "healthy" + mcp_config.error_message = None + + # 更新工具状态 + config.status = ToolStatus.AVAILABLE.value + + self.db.commit() + + return { + "success": True, + "message": "工具列表同步成功", + "tools_count": len(tool_names), + "tools": tool_names + } + + except Exception as e: + # 更新错误状态 + try: + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if mcp_config: mcp_config.last_health_check = datetime.now() mcp_config.health_status = "error" mcp_config.error_message = str(e) config.status = ToolStatus.ERROR.value self.db.commit() - - return {"success": False, "message": f"获取工具列表失败: {str(e)}"} - else: - # 连接失败 - mcp_config.last_health_check = datetime.now() - mcp_config.health_status = "error" - mcp_config.error_message = "连接失败" - config.status = ToolStatus.ERROR.value - self.db.commit() - - return {"success": False, "message": "MCP连接失败"} - - except Exception as e: + except: + pass + logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}") return {"success": False, "message": f"同步失败: {str(e)}"}