feat(agent tool): mcp tool repair

This commit is contained in:
谢俊男
2026-01-07 18:59:28 +08:00
parent 99b4a17f43
commit 25ce86ae93
9 changed files with 621 additions and 1436 deletions

View File

@@ -215,8 +215,8 @@ async def sync_mcp_tools(
"""同步MCP工具列表"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if result["success"] is False:
raise HTTPException(status_code=404, detail=result["message"])
if not result.get("success", False):
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
return success(data=result, msg="MCP工具列表同步完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -78,13 +78,20 @@ class LangchainAdapter:
Args:
tool: 内部工具实例
operation: 特定操作(适用于有操作的工具)
operation: 特定操作(适用于有操作的工具)或MCP工具名称
Returns:
Langchain兼容的工具包装器
"""
try:
if operation and tool.name in ['datetime_tool', 'json_tool']:
# 处理MCP工具的特定工具名称
if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation:
# 为MCP工具创建特定工具名称的实例
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
elif operation and tool.name in ['datetime_tool', 'json_tool']:
# 为特定操作创建工具
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
@@ -106,6 +113,18 @@ class LangchainAdapter:
from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation)
@staticmethod
def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool:
"""为MCP工具创建指定工具名称的实例"""
from app.core.tools.mcp.base import MCPTool
# 创建新的配置,指定具体工具名称
new_config = base_tool.config.copy()
new_config["tool_name"] = tool_name
# 创建新的MCP工具实例
return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config)
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具

View File

@@ -1,12 +1,20 @@
"""MCP工具模块"""
"""MCP 工具模块 - Model Context Protocol 支持"""
from app.core.tools.mcp.base import MCPTool
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.service_manager import MCPServiceManager
# 主要类导出
from .base import MCPTool, MCPToolManager, MCPError
from .client import SimpleMCPClient, MCPConnectionError
from .service_manager import MCPServiceManager
__all__ = [
# 核心类
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPToolManager",
"MCPError",
# 客户端类
"SimpleMCPClient",
"MCPConnectionError",
# 服务管理(简化版)
"MCPServiceManager"
]

View File

@@ -1,10 +1,9 @@
"""MCP工具基类"""
"""MCP工具基类 - 整合版本"""
import time
from typing import Dict, Any, List
from typing import List, Dict, Any
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -14,215 +13,174 @@ class MCPTool(BaseTool):
"""MCP工具 - Model Context Protocol工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
"""初始化MCP工具
Args:
tool_id: 工具ID
config: 工具配置
"""
super().__init__(tool_id, config)
self.server_url = config.get("server_url", "")
self.connection_config = config.get("connection_config", {})
self.tool_name = config.get("tool_name", "") # 特定工具名称
self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema
self.available_tools = config.get("available_tools", [])
self._client = None
self._connected = False
@property
def name(self) -> str:
"""工具名称"""
return f"mcp_tool_{self.tool_id[:8]}"
return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}"
@property
def description(self) -> str:
"""工具描述"""
return f"MCP工具 - 连接到 {self.server_url}"
if self.tool_schema.get("description"):
return self.tool_schema["description"]
return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.MCP
@property
def parameters(self) -> List[ToolParameter]:
"""工具参数定义"""
params = []
# 添加工具选择参数
if len(self.available_tools) > 1:
params.append(ToolParameter(
name="tool_name",
type=ParameterType.STRING,
description="要调用的MCP工具名称",
required=True,
enum=self.available_tools
))
# 添加通用参数
params.extend([
ToolParameter(
"""从 MCP 工具 schema 生成参数"""
if not self.tool_schema:
return [ToolParameter(
name="arguments",
type=ParameterType.OBJECT,
description="工具参数JSON对象",
description="工具参数",
required=False,
default={}
),
ToolParameter(
name="timeout",
type=ParameterType.INTEGER,
description="超时时间(秒)",
required=False,
default=30,
minimum=1,
maximum=300
)
])
)]
# 解析 MCP 工具的 inputSchema
input_schema = self.tool_schema.get("inputSchema", {})
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
params = []
for param_name, param_def in properties.items():
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
params.append(ToolParameter(
name=param_name,
type=param_type,
description=param_def.get("description", f"参数: {param_name}"),
required=param_name in required_fields,
default=param_def.get("default"),
enum=param_def.get("enum"),
minimum=param_def.get("minimum"),
maximum=param_def.get("maximum")
))
return params
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
"""转换 JSON Schema 类型到 ParameterType"""
type_mapping = {
"string": ParameterType.STRING,
"integer": ParameterType.INTEGER,
"number": ParameterType.NUMBER,
"boolean": ParameterType.BOOLEAN,
"array": ParameterType.ARRAY,
"object": ParameterType.OBJECT
}
return type_mapping.get(json_type, ParameterType.STRING)
async def execute(self, **kwargs) -> ToolResult:
"""执行MCP工具"""
"""执行 MCP 工具"""
start_time = time.time()
try:
# 确保连接
if not self._connected:
await self.connect()
from .client import SimpleMCPClient
# 确定要调用的工具
tool_name = kwargs.get("tool_name")
if not tool_name and len(self.available_tools) == 1:
tool_name = self.available_tools[0]
if not tool_name:
raise ValueError("必须指定要调用的MCP工具名称")
if tool_name not in self.available_tools:
raise ValueError(f"MCP工具不存在: {tool_name}")
# 获取参数
arguments = kwargs.get("arguments", {})
timeout = kwargs.get("timeout", 30)
# 调用MCP工具
result = await self._call_mcp_tool(tool_name, arguments, timeout)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
client = SimpleMCPClient(self.server_url, self.connection_config)
async with client:
# 使用指定的工具名称或默认第一个工具
tool_name_to_use = self.tool_name
if not tool_name_to_use and self.available_tools:
tool_name_to_use = self.available_tools[0]
if not tool_name_to_use:
raise Exception("未指定工具名称且无可用工具")
result = await client.call_tool(tool_name_to_use, kwargs)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}")
return ToolResult.error_result(
error=str(e),
error_code="MCP_ERROR",
error_code="MCP_EXECUTION_ERROR",
execution_time=execution_time
)
class MCPError(Exception):
"""MCP 错误基类"""
pass
class MCPToolManager:
"""MCP 工具管理器 - 简化版本"""
async def connect(self) -> bool:
"""连接到MCP服务器"""
def __init__(self, db=None):
self.db = db
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
async def discover_tools(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> tuple[bool, List[Dict[str, Any]], str | None]:
"""发现 MCP 服务器上的工具"""
try:
from .client import MCPClient
from .client import SimpleMCPClient
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# 更新可用工具列表
await self._update_available_tools()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
logger.error(f"MCP服务器连接失败: {self.server_url}")
return False
client = SimpleMCPClient(server_url, connection_config)
async with client:
tools = await client.list_tools()
# 缓存工具信息
self._tool_cache[server_url] = {
"tools": tools,
"connection_config": connection_config,
"last_updated": time.time()
}
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
return True, tools, None
except Exception as e:
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
self._connected = False
return False
error_msg = f"发现工具失败: {e}"
logger.error(error_msg)
return False, [], error_msg
async def _update_available_tools(self):
"""更新可用工具列表"""
async def test_tool_connection(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> Dict[str, Any]:
"""测试工具连接"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
await self._client.disconnect()
self._client = None
from .client import SimpleMCPClient
self._connected = False
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def get_health_status(self) -> Dict[str, Any]:
"""获取MCP服务健康状态"""
return {
"connected": self._connected,
"server_url": self.server_url,
"available_tools": self.available_tools,
"last_check": time.time()
}
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
try:
if not self._connected:
await self.connect()
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []
except Exception as e:
logger.error(f"获取MCP工具列表失败: {e}")
return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
client = SimpleMCPClient(server_url, connection_config)
async with client:
tools = await client.list_tools()
return {
"success": True,
"tools_count": len(tools),
"tools": [tool.get("name") for tool in tools],
"message": "连接成功"
}
except Exception as e:
return {
"success": False,
"error": str(e)
"error": str(e),
"message": "连接失败"
}

View File

@@ -1,9 +1,8 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
"""MCP客户端 - 简化版本"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
from typing import Dict, Any, List
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
@@ -18,139 +17,156 @@ class MCPConnectionError(Exception):
pass
class MCPProtocolError(Exception):
"""MCP协议错误"""
pass
class MCPClient:
"""MCP客户端 - 支持HTTP和WebSocket连接"""
class SimpleMCPClient:
"""简化的 MCP 客户端"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""初始化MCP客户端
Args:
server_url: MCP服务器URL
connection_config: 连接配置
"""
self.server_url = server_url
self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30)
# 解析URL确定连接类型
parsed_url = urlparse(server_url)
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# 确定连接类型
self.is_websocket = server_url.startswith(("ws://", "wss://"))
# 连接状态
self._connected = False
self._websocket = None
self._session = None
# 请求管理
self._request_id = 0
self._pending_requests: Dict[str, asyncio.Future] = {}
# 连接池配置
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# 健康检查
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# 事件回调
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
self._pending_requests = {}
async def connect(self) -> bool:
"""连接到MCP服务器
Returns:
连接是否成功
"""
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
async def connect(self):
"""建立连接"""
try:
if self._connected:
return True
logger.info(f"连接MCP服务器: {self.server_url}")
if self.connection_type == "websocket":
success = await self._connect_websocket()
if self.is_websocket:
await self._connect_websocket()
else:
success = await self._connect_http()
if success:
self._connected = True
await self._start_health_check()
await self._notify_connect_callbacks()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return success
await self._connect_http()
except Exception as e:
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
await self._notify_error_callbacks(e)
return False
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
raise MCPConnectionError(f"连接失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接
Returns:
断开是否成功
"""
async def disconnect(self):
"""断开连接"""
try:
if not self._connected:
return True
logger.info(f"断开MCP服务器连接: {self.server_url}")
# 停止健康检查
await self._stop_health_check()
# 取消所有待处理的请求
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
# 断开连接
if self.connection_type == "websocket" and self._websocket:
if self._websocket:
await self._websocket.close()
self._websocket = None
elif self._session:
if self._session:
await self._session.close()
self._session = None
self._connected = False
await self._notify_disconnect_callbacks()
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
logger.error(f"断开连接失败: {e}")
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
async def _connect_websocket(self):
"""WebSocket 连接"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize()
async def _connect_http(self):
"""HTTP 连接"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(
headers=headers,
timeout=timeout
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
}
}
try:
async with self._session.post(
self.server_url,
json=init_request
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
init_response = await response.json()
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(
self.server_url,
json=initialized_notification
) as notif_response:
pass
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 添加认证头
auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none")
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
if auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "api_key":
key = auth_config.get("api_key")
header_name = auth_config.get("key_name", "X-API-Key")
if key:
headers[header_name] = key
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
@@ -161,160 +177,63 @@ class MCPClient:
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
extra_headers=extra_headers,
timeout=self.connection_timeout
)
# 启动消息监听
asyncio.create_task(self._websocket_message_handler())
# 发送初始化消息
init_message = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "ToolManagementSystem",
"version": "1.0.0"
}
async def _send_initialize(self):
"""发送初始化消息"""
init_message = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.connection_timeout
)
init_response = json.loads(response)
if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
except Exception as e:
logger.error(f"WebSocket连接失败: {e}")
return False
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.timeout
)
init_response = json.loads(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
async def _connect_http(self) -> bool:
"""建立HTTP连接"""
try:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=headers
)
# 测试连接
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
async with self._session.get(test_url) as response:
if response.status == 200:
return True
else:
# 尝试根路径
async with self._session.get(self.server_url) as root_response:
return root_response.status < 400
except Exception as e:
logger.error(f"HTTP连接失败: {e}")
if self._session:
await self._session.close()
self._session = None
return False
async def _websocket_message_handler(self):
"""WebSocket消息处理器"""
async def _handle_websocket_messages(self):
"""处理 WebSocket 消息"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
await self._handle_message(json.loads(message))
data = json.loads(message)
# 处理响应
if "id" in data:
request_id = str(data["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except json.JSONDecodeError as e:
logger.error(f"解析WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}")
finally:
self._connected = False
await self._notify_disconnect_callbacks()
logger.error(f"WebSocket消息处理异常: {e}")
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
# 这里可以根据需要处理特定的通知
# 例如:工具列表更新、服务器状态变化等
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
"""调用MCP工具
Args:
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间(秒)
Returns:
工具执行结果
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
@@ -322,343 +241,69 @@ class MCPClient:
}
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
except asyncio.TimeoutError:
raise MCPProtocolError(f"工具调用超时: {tool_name}")
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
"""获取可用工具列表
Args:
timeout: 超时时间(秒)
Returns:
工具列表
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/list"
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
except asyncio.TimeoutError:
raise MCPProtocolError("获取工具列表超时")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
if self.connection_type == "websocket":
request_id = str(request_data["id"])
return await self._send_websocket_request(request_data, request_id, timeout)
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
response = await self._send_http_request(request_data)
# 创建Future等待响应
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
await self._pending_requests.pop(request_id, None)
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
if not self._session:
raise MCPConnectionError("HTTP会话未建立")
try:
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
async with self._session.post(
url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
self.server_url,
json=request_data
) as response:
if response.status == 200:
return await response.json()
else:
async with self._session.post(
self.server_url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as root_response:
if root_response.status != 200:
error_text = await root_response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
"""执行健康检查
Returns:
健康状态信息
"""
try:
if not self._connected:
return {
"healthy": False,
"error": "未连接",
"timestamp": time.time()
}
# 发送ping请求
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "ping"
}
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,
"response_time": response_time,
"timestamp": self._last_health_check,
"server_info": response.get("result", {})
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
"""注册连接回调"""
self._on_connect_callbacks.append(callback)
def on_disconnect(self, callback: Callable):
"""注册断开连接回调"""
self._on_disconnect_callbacks.append(callback)
def on_error(self, callback: Callable):
"""注册错误回调"""
self._on_error_callbacks.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
@property
def last_health_check(self) -> Optional[float]:
"""获取最后一次健康检查时间"""
return self._last_health_check
def get_connection_info(self) -> Dict[str, Any]:
"""获取连接信息"""
return {
"server_url": self.server_url,
"connection_type": self.connection_type,
"connected": self._connected,
"last_health_check": self._last_health_check,
"pending_requests": len(self._pending_requests),
"config": self.connection_config
}
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
class MCPConnectionPool:
"""MCP连接池 - 管理多个MCP客户端连接"""
def __init__(self, max_connections: int = 10):
"""初始化连接池
Args:
max_connections: 最大连接数
"""
self.max_connections = max_connections
self._clients: Dict[str, MCPClient] = {}
self._lock = asyncio.Lock()
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
"""获取或创建MCP客户端
Args:
server_url: 服务器URL
connection_config: 连接配置
Returns:
MCP客户端实例
"""
async with self._lock:
if server_url in self._clients:
client = self._clients[server_url]
if client.is_connected:
return client
else:
# 尝试重连
if await client.connect():
return client
else:
# 移除失效的客户端
del self._clients[server_url]
# 检查连接数限制
if len(self._clients) >= self.max_connections:
# 移除最旧的连接
oldest_url = next(iter(self._clients))
await self._clients[oldest_url].disconnect()
del self._clients[oldest_url]
# 创建新客户端
client = MCPClient(server_url, connection_config)
if await client.connect():
self._clients[server_url] = client
return client
else:
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
async def disconnect_all(self):
"""断开所有连接"""
async with self._lock:
for client in self._clients.values():
await client.disconnect()
self._clients.clear()
def get_pool_status(self) -> Dict[str, Any]:
"""获取连接池状态"""
return {
"total_connections": len(self._clients),
"max_connections": self.max_connections,
"connections": {
url: client.get_connection_info()
for url, client in self._clients.items()
}
}
return f"req_{self._request_id}_{int(time.time() * 1000)}"

View File

@@ -1,6 +1,4 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
"""MCP服务管理器 - 简化版本"""
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
@@ -8,136 +6,53 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
from app.core.tools.mcp.base import MCPToolManager
logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
"""MCP服务管理器 - 简化版本,主要用于工具创建"""
def __init__(self, db: Session = None):
"""初始化MCP服务管理器
Args:
db: 数据库会话(可选)
"""
self.db = db
if db:
self.connection_pool = MCPConnectionPool(max_connections=20)
else:
self.connection_pool = None
# 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
# 配置
self.health_check_interval = 60 # 健康检查间隔(秒)
self.max_retry_attempts = 3 # 最大重试次数
self.retry_delay = 5 # 重试延迟(秒)
# 状态
self._running = False
self._manager_task = None
self.tool_manager = MCPToolManager(db) if db else None
async def start(self):
"""启动服务管理器"""
if self._running:
return
self._running = True
logger.info("MCP服务管理器启动")
# 加载现有服务
await self._load_existing_services()
# 启动管理任务
self._manager_task = asyncio.create_task(self._management_loop())
async def stop(self):
"""停止服务管理器"""
if not self._running:
return
self._running = False
logger.info("MCP服务管理器停止")
# 停止管理任务
if self._manager_task:
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
# 停止所有监控任务
for task in self._monitoring_tasks.values():
task.cancel()
if self._monitoring_tasks:
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
self._monitoring_tasks.clear()
# 断开所有连接
await self.connection_pool.disconnect_all()
async def register_service(
async def create_mcp_tool(
self,
server_url: str,
connection_config: Dict[str, Any],
tenant_id: uuid.UUID,
tool_name: str,
service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
"""注册MCP服务
"""创建单个MCP工具
Args:
server_url: 服务器URL
connection_config: 连接配置
tenant_id: 租户ID
service_name: 服务名称(可选)
tool_name: 具体工具名称
service_name: 服务名称
Returns:
(是否成功, 服务ID或错误信息, 错误详情)
(是否成功, 工具ID或错误信息, 错误详情)
"""
try:
# 检查服务是否已存在
existing_service = self.db.query(MCPToolConfig).filter(
MCPToolConfig.server_url == server_url
).first()
if existing_service:
return False, "服务已存在", f"URL {server_url} 已被注册"
# 测试连接
try:
client = MCPClient(server_url, connection_config)
if not await client.connect():
return False, "连接测试失败", "无法连接到MCP服务器"
# 获取可用工具
available_tools = await client.list_tools()
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
await client.disconnect()
except Exception as e:
return False, "连接测试失败", str(e)
if not service_name:
service_name = f"mcp_{tool_name}"
# 创建工具配置
if not service_name:
service_name = f"mcp_service_{server_url.split('/')[-1]}"
tool_config = ToolConfig(
name=service_name,
description=f"MCP服务 - {server_url}",
description=f"MCP工具: {tool_name}",
tool_type=ToolType.MCP.value,
tenant_id=tenant_id,
version="1.0.0",
status=ToolStatus.AVAILABLE.value,
config_data={
"server_url": server_url,
"connection_config": connection_config
"connection_config": connection_config,
"tool_name": tool_name
}
)
@@ -149,460 +64,22 @@ class MCPServiceManager:
id=tool_config.id,
server_url=server_url,
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
available_tools=[tool_name],
health_status="unknown",
last_health_check=datetime.now()
)
self.db.add(mcp_config)
self.db.commit()
service_id = str(tool_config.id)
# 添加到内存管理
self._services[service_id] = {
"id": service_id,
"server_url": server_url,
"connection_config": connection_config,
"tenant_id": tenant_id,
"available_tools": tool_names,
"status": "healthy",
"last_health_check": time.time(),
"retry_count": 0,
"created_at": time.time()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
return True, service_id, None
logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
return True, str(tool_config.id), None
except Exception as e:
self.db.rollback()
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
return False, "注册失败", str(e)
logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
return False, "创建失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
"""注销MCP服务
Args:
service_id: 服务ID
Returns:
(是否成功, 错误信息)
"""
try:
# 从数据库删除
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if not tool_config:
return False, "服务不存在"
self.db.delete(tool_config)
self.db.commit()
# 停止监控
await self._stop_service_monitoring(service_id)
# 从内存移除
if service_id in self._services:
del self._services[service_id]
logger.info(f"MCP服务注销成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def update_service(
self,
service_id: str,
connection_config: Dict[str, Any] = None,
enabled: bool = None
) -> Tuple[bool, str]:
"""更新MCP服务配置
Args:
service_id: 服务ID
connection_config: 新的连接配置
enabled: 是否启用
Returns:
(是否成功, 错误信息)
"""
try:
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if not mcp_config:
return False, "服务不存在"
tool_config = mcp_config.base_config
if connection_config is not None:
mcp_config.connection_config = connection_config
tool_config.config_data["connection_config"] = connection_config
if enabled is not None:
tool_config.is_enabled = enabled
self.db.commit()
# 更新内存状态
if service_id in self._services:
if connection_config is not None:
self._services[service_id]["connection_config"] = connection_config
# 如果配置有变化,重启监控
if connection_config is not None:
await self._restart_service_monitoring(service_id)
logger.info(f"MCP服务更新成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
"""获取服务状态
Args:
service_id: 服务ID
Returns:
服务状态信息
"""
if service_id not in self._services:
return None
service_info = self._services[service_id].copy()
# 添加实时健康检查
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
service_info["real_time_health"] = health_status
except Exception as e:
service_info["real_time_health"] = {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
return service_info
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
"""列出所有服务
Args:
tenant_id: 租户ID过滤
Returns:
服务列表
"""
services = []
for service_id, service_info in self._services.items():
if tenant_id and service_info["tenant_id"] != tenant_id:
continue
services.append(service_info.copy())
return services
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的可用工具
Args:
service_id: 服务ID
Returns:
工具列表
"""
if service_id not in self._services:
return []
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
tools = await client.list_tools()
# 更新缓存的工具列表
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.available_tools = tool_names
self.db.commit()
return tools
except Exception as e:
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
return []
async def call_service_tool(
self,
service_id: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: int = 30
) -> Dict[str, Any]:
"""调用服务工具
Args:
service_id: 服务ID
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
执行结果
"""
if service_id not in self._services:
raise ValueError(f"服务不存在: {service_id}")
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
result = await client.call_tool(tool_name, arguments, timeout)
# 更新服务状态为健康
service_info["status"] = "healthy"
service_info["last_health_check"] = time.time()
service_info["retry_count"] = 0
return result
except Exception as e:
# 更新服务状态为错误
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
raise
async def _load_existing_services(self):
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
tool_config = mcp_config.base_config
service_id = str(mcp_config.id)
self._services[service_id] = {
"id": service_id,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"tenant_id": tool_config.tenant_id,
"available_tools": mcp_config.available_tools or [],
"status": mcp_config.health_status or "unknown",
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
"retry_count": 0,
"created_at": tool_config.created_at.timestamp()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
except Exception as e:
logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
"""监控单个服务"""
try:
while self._running and service_id in self._services:
service_info = self._services[service_id]
try:
# 执行健康检查
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
if health_status["healthy"]:
# 服务健康
service_info["status"] = "healthy"
service_info["retry_count"] = 0
# 更新工具列表
try:
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
except Exception as e:
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
else:
# 服务不健康
service_info["status"] = "unhealthy"
service_info["last_error"] = health_status.get("error", "健康检查失败")
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
# 更新数据库
await self._update_service_health_in_db(service_id, health_status)
except Exception as e:
# 监控异常
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
# 如果重试次数过多,暂停监控
if service_info["retry_count"] >= self.max_retry_attempts:
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
service_info["retry_count"] = 0 # 重置重试计数
# 等待下次检查
await asyncio.sleep(self.health_check_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
"""更新数据库中的服务健康状态"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")
else:
mcp_config.error_message = None
self.db.commit()
except Exception as e:
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {
"running": self._running,
"total_services": len(self._services),
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
"monitoring_tasks": len(self._monitoring_tasks),
"connection_pool_status": self.connection_pool.get_pool_status()
}
def get_tool_manager(self) -> MCPToolManager:
"""获取工具管理器实例"""
return self.tool_manager

View File

@@ -77,7 +77,7 @@ class AppChatService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools:
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -109,20 +109,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters
@@ -226,7 +227,7 @@ class AppChatService:
# 获取工具服务
tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools:
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -258,20 +259,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
# web_tools = config.tools
# web_search_choice = web_tools.get("web_search", {})
# web_search_enable = web_search_choice.get("enabled", False)
# if web_search == True:
# if web_search_enable == True:
# search_tool = create_web_search_tool({})
# tools.append(search_tool)
#
# logger.debug(
# "已添加网络搜索工具",
# extra={
# "tool_count": len(tools)
# }
# )
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
web_tools = config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 获取模型参数
model_parameters = config.model_parameters

View File

@@ -297,19 +297,35 @@ class DraftRunService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
if hasattr(agent_config, 'tools') and agent_config.tools:
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
ToolRepository.get_tenant_id_by_workspace_id(
self.db, str(workspace_id)))
if tool_instance:
if tool_instance.name == "baidu_search_tool" and not web_search:
continue
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:
@@ -507,7 +523,7 @@ class DraftRunService:
tool_service = ToolService(self.db)
# 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools:
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
for tool_config in agent_config.tools:
if tool_config.get("enabled", False):
# 根据工具名称查找工具实例
@@ -520,6 +536,22 @@ class DraftRunService:
# 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具
if agent_config.knowledge_retrieval:

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from sqlalchemy.orm import Session
from app.core.tools.mcp import MCPClient
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository,
MCPToolRepository, ToolExecutionRepository
@@ -42,6 +42,9 @@ class ToolService:
def __init__(self, db: Session):
self.db = db
self._tool_cache: Dict[str, BaseTool] = {}
# MCP管理器
self.mcp_tool_manager = MCPToolManager(db)
# 初始化仓储
self.tool_repo = ToolRepository()
@@ -675,23 +678,85 @@ class ToolService:
return []
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取MCP工具的方法"""
"""获取MCP工具的方法和参数"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return []
available_tools = mcp_config.available_tools or []
if not available_tools:
return []
# 如果没有工具列表,尝试同步
try:
success, tools, _ = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
mcp_config.available_tools = tool_names
self.db.commit()
available_tools = tool_names
except Exception as e:
logger.error(f"同步MCP工具列表失败: {e}")
return []
methods = []
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": [] # MCP工具参数需要动态获取
})
# 获取工具详细信息
try:
success, tools, _ = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
for tool_name in available_tools:
tool_info = tools_dict.get(tool_name, {})
# 解析工具参数
parameters = []
input_schema = tool_info.get("inputSchema", {})
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
for param_name, param_def in properties.items():
parameters.append({
"name": param_name,
"type": param_def.get("type", "string"),
"description": param_def.get("description", ""),
"required": param_name in required_fields,
"default": param_def.get("default"),
"enum": param_def.get("enum"),
"minimum": param_def.get("minimum"),
"maximum": param_def.get("maximum")
})
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
"parameters": parameters
})
else:
# 如果无法获取详细信息,返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
except Exception as e:
logger.error(f"获取MCP工具详细信息失败: {e}")
# 返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
return methods
@@ -812,10 +877,14 @@ class ToolService:
if not mcp_config:
return None
# 从配置中获取特定工具名称
tool_name = config.config_data.get("tool_name")
tool_config = {
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"available_tools": mcp_config.available_tools or []
"available_tools": mcp_config.available_tools or [],
"tool_name": tool_name # 指定具体工具
}
return MCPTool(str(config.id), tool_config)
@@ -1071,71 +1140,59 @@ class ToolService:
return {}
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试MCP连接"""
"""测试MCP连接并自动同步工具列表"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
if await client.connect():
try:
# tools = await client.list_tools()
await client.disconnect()
# 更新连接状态
# 使用集成的MCP管理器测试连接
test_result = await self.mcp_tool_manager.test_tool_connection(
mcp_config.server_url, mcp_config.connection_config or {}
)
if test_result["success"]:
# 连接成功,自动同步工具列表
success, tools, error = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
self._update_tool_status(config)
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
return {
"success": True,
"message": "MCP连接成功",
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
"details": {"server_url": mcp_config.server_url}
"message": "MCP连接成功并同步工具列表",
"details": {
"server_url": mcp_config.server_url,
"tools_count": len(tool_names),
"tools": tool_names
}
}
except Exception as e:
await client.disconnect()
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
else:
return {"success": False, "message": f"同步工具失败: {error}"}
else:
# 更新连接失败状态
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
self._update_tool_status(config)
mcp_config.error_message = test_result.get("error", "连接失败")
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
return test_result
except Exception as e:
# 更新异常状态
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == config.id
).first()
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
return {"success": False, "message": f"测试失败: {str(e)}"}
@staticmethod
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
@@ -1190,57 +1247,44 @@ class ToolService:
# 创建MCP客户端
connection_config = mcp_config.connection_config or {}
client = SimpleMCPClient(mcp_config.server_url, connection_config)
client = MCPClient(mcp_config.server_url, connection_config)
if await client.connect():
try:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
await client.disconnect()
return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
await client.disconnect()
# 更新错误状态
async with client:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态
config.status = ToolStatus.AVAILABLE.value
self.db.commit()
return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
# 更新错误状态
try:
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
else:
# 连接失败
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
except Exception as e:
except:
pass
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
return {"success": False, "message": f"同步失败: {str(e)}"}