feat(apikey system): tool system development

This commit is contained in:
谢俊男
2025-12-20 15:24:28 +08:00
parent 3fbd4f206e
commit c26af11f76
39 changed files with 9338 additions and 4 deletions

View File

@@ -0,0 +1,12 @@
"""MCP工具模块"""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
__all__ = [
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPServiceManager"
]

View File

@@ -0,0 +1,258 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPTool(BaseTool):
"""MCP工具 - Model Context Protocol工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]):
"""初始化MCP工具
Args:
tool_id: 工具ID
config: 工具配置
"""
super().__init__(tool_id, config)
self.server_url = config.get("server_url", "")
self.connection_config = config.get("connection_config", {})
self.available_tools = config.get("available_tools", [])
self._client = None
self._connected = False
@property
def name(self) -> str:
"""工具名称"""
return f"mcp_tool_{self.tool_id[:8]}"
@property
def description(self) -> str:
"""工具描述"""
return f"MCP工具 - 连接到 {self.server_url}"
@property
def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.MCP
@property
def parameters(self) -> List[ToolParameter]:
"""工具参数定义"""
params = []
# 添加工具选择参数
if len(self.available_tools) > 1:
params.append(ToolParameter(
name="tool_name",
type=ParameterType.STRING,
description="要调用的MCP工具名称",
required=True,
enum=self.available_tools
))
# 添加通用参数
params.extend([
ToolParameter(
name="arguments",
type=ParameterType.OBJECT,
description="工具参数JSON对象",
required=False,
default={}
),
ToolParameter(
name="timeout",
type=ParameterType.INTEGER,
description="超时时间(秒)",
required=False,
default=30,
minimum=1,
maximum=300
)
])
return params
async def execute(self, **kwargs) -> ToolResult:
"""执行MCP工具"""
start_time = time.time()
try:
# 确保连接
if not self._connected:
await self.connect()
# 确定要调用的工具
tool_name = kwargs.get("tool_name")
if not tool_name and len(self.available_tools) == 1:
tool_name = self.available_tools[0]
if not tool_name:
raise ValueError("必须指定要调用的MCP工具名称")
if tool_name not in self.available_tools:
raise ValueError(f"MCP工具不存在: {tool_name}")
# 获取参数
arguments = kwargs.get("arguments", {})
timeout = kwargs.get("timeout", 30)
# 调用MCP工具
result = await self._call_mcp_tool(tool_name, arguments, timeout)
execution_time = time.time() - start_time
return ToolResult.success_result(
data=result,
execution_time=execution_time
)
except Exception as e:
execution_time = time.time() - start_time
return ToolResult.error_result(
error=str(e),
error_code="MCP_ERROR",
execution_time=execution_time
)
async def connect(self) -> bool:
"""连接到MCP服务器"""
try:
# 这里应该实现实际的MCP连接逻辑
# 为了简化,这里只是模拟连接
# 测试服务器连接
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
# 尝试获取服务器信息
async with session.get(f"{self.server_url}/info") as response:
if response.status == 200:
server_info = await response.json()
self.available_tools = server_info.get("tools", [])
self._connected = True
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
raise Exception(f"服务器响应错误: {response.status}")
except Exception as e:
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
self._connected = False
return False
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
# 这里应该实现实际的断开逻辑
self._client = None
self._connected = False
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def get_health_status(self) -> Dict[str, Any]:
"""获取MCP服务健康状态"""
return {
"connected": self._connected,
"server_url": self.server_url,
"available_tools": self.available_tools,
"last_check": time.time()
}
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
# 构建MCP请求
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
# 发送请求
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
result = await response.json()
# 检查MCP响应
if "error" in result:
error = result["error"]
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
return result.get("result", {})
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
try:
if not self._connected:
await self.connect()
# 获取工具列表
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/list"
}
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status == 200:
result = await response.json()
if "result" in result:
tools = result["result"].get("tools", [])
self.available_tools = [tool.get("name") for tool in tools]
return tools
return []
except Exception as e:
logger.error(f"获取MCP工具列表失败: {e}")
return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}

View File

@@ -0,0 +1,626 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPConnectionError(Exception):
"""MCP连接错误"""
pass
class MCPProtocolError(Exception):
"""MCP协议错误"""
pass
class MCPClient:
"""MCP客户端 - 支持HTTP和WebSocket连接"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""初始化MCP客户端
Args:
server_url: MCP服务器URL
connection_config: 连接配置
"""
self.server_url = server_url
self.connection_config = connection_config or {}
# 解析URL确定连接类型
parsed_url = urlparse(server_url)
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# 连接状态
self._connected = False
self._websocket = None
self._session = None
# 请求管理
self._request_id = 0
self._pending_requests: Dict[str, asyncio.Future] = {}
# 连接池配置
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# 健康检查
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# 事件回调
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
async def connect(self) -> bool:
"""连接到MCP服务器
Returns:
连接是否成功
"""
try:
if self._connected:
return True
logger.info(f"连接MCP服务器: {self.server_url}")
if self.connection_type == "websocket":
success = await self._connect_websocket()
else:
success = await self._connect_http()
if success:
self._connected = True
await self._start_health_check()
await self._notify_connect_callbacks()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return success
except Exception as e:
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
await self._notify_error_callbacks(e)
return False
async def disconnect(self) -> bool:
"""断开MCP服务器连接
Returns:
断开是否成功
"""
try:
if not self._connected:
return True
logger.info(f"断开MCP服务器连接: {self.server_url}")
# 停止健康检查
await self._stop_health_check()
# 取消所有待处理的请求
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
# 断开连接
if self.connection_type == "websocket" and self._websocket:
await self._websocket.close()
self._websocket = None
elif self._session:
await self._session.close()
self._session = None
self._connected = False
await self._notify_disconnect_callbacks()
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
self._websocket = await websockets.connect(
self.server_url,
extra_headers=extra_headers,
timeout=self.connection_timeout
)
# 启动消息监听
asyncio.create_task(self._websocket_message_handler())
# 发送初始化消息
init_message = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "ToolManagementSystem",
"version": "1.0.0"
}
}
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.connection_timeout
)
init_response = json.loads(response)
if "error" in init_response:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
except Exception as e:
logger.error(f"WebSocket连接失败: {e}")
return False
async def _connect_http(self) -> bool:
"""建立HTTP连接"""
try:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=headers
)
# 测试连接
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
async with self._session.get(test_url) as response:
if response.status == 200:
return True
else:
# 尝试根路径
async with self._session.get(self.server_url) as root_response:
return root_response.status < 400
except Exception as e:
logger.error(f"HTTP连接失败: {e}")
if self._session:
await self._session.close()
self._session = None
return False
async def _websocket_message_handler(self):
"""WebSocket消息处理器"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
await self._handle_message(json.loads(message))
except ConnectionClosed:
break
except json.JSONDecodeError as e:
logger.error(f"解析WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"WebSocket消息处理器异常: {e}")
finally:
self._connected = False
await self._notify_disconnect_callbacks()
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
# 这里可以根据需要处理特定的通知
# 例如:工具列表更新、服务器状态变化等
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
"""调用MCP工具
Args:
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间(秒)
Returns:
工具执行结果
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
try:
response = await self._send_request(request_data, timeout)
if "error" in response:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
except asyncio.TimeoutError:
raise MCPProtocolError(f"工具调用超时: {tool_name}")
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
"""获取可用工具列表
Args:
timeout: 超时时间(秒)
Returns:
工具列表
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/list"
}
try:
response = await self._send_request(request_data, timeout)
if not response["error"] is None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
except asyncio.TimeoutError:
raise MCPProtocolError("获取工具列表超时")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
request_id = str(request_data["id"])
if self.connection_type == "websocket":
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
# 创建Future等待响应
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送HTTP请求"""
if not self._session:
raise MCPConnectionError("HTTP会话未建立")
try:
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
async with self._session.post(
url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
"""执行健康检查
Returns:
健康状态信息
"""
try:
if not self._connected:
return {
"healthy": False,
"error": "未连接",
"timestamp": time.time()
}
# 发送ping请求
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "ping"
}
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = time.time() - start_time
self._last_health_check = time.time()
return {
"healthy": True,
"response_time": response_time,
"timestamp": self._last_health_check,
"server_info": response.get("result", {})
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
"""注册连接回调"""
self._on_connect_callbacks.append(callback)
def on_disconnect(self, callback: Callable):
"""注册断开连接回调"""
self._on_disconnect_callbacks.append(callback)
def on_error(self, callback: Callable):
"""注册错误回调"""
self._on_error_callbacks.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
@property
def last_health_check(self) -> Optional[float]:
"""获取最后一次健康检查时间"""
return self._last_health_check
def get_connection_info(self) -> Dict[str, Any]:
"""获取连接信息"""
return {
"server_url": self.server_url,
"connection_type": self.connection_type,
"connected": self._connected,
"last_health_check": self._last_health_check,
"pending_requests": len(self._pending_requests),
"config": self.connection_config
}
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
class MCPConnectionPool:
"""MCP连接池 - 管理多个MCP客户端连接"""
def __init__(self, max_connections: int = 10):
"""初始化连接池
Args:
max_connections: 最大连接数
"""
self.max_connections = max_connections
self._clients: Dict[str, MCPClient] = {}
self._lock = asyncio.Lock()
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
"""获取或创建MCP客户端
Args:
server_url: 服务器URL
connection_config: 连接配置
Returns:
MCP客户端实例
"""
async with self._lock:
if server_url in self._clients:
client = self._clients[server_url]
if client.is_connected:
return client
else:
# 尝试重连
if await client.connect():
return client
else:
# 移除失效的客户端
del self._clients[server_url]
# 检查连接数限制
if len(self._clients) >= self.max_connections:
# 移除最旧的连接
oldest_url = next(iter(self._clients))
await self._clients[oldest_url].disconnect()
del self._clients[oldest_url]
# 创建新客户端
client = MCPClient(server_url, connection_config)
if await client.connect():
self._clients[server_url] = client
return client
else:
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
async def disconnect_all(self):
"""断开所有连接"""
async with self._lock:
for client in self._clients.values():
await client.disconnect()
self._clients.clear()
def get_pool_status(self) -> Dict[str, Any]:
"""获取连接池状态"""
return {
"total_connections": len(self._clients),
"max_connections": self.max_connections,
"connections": {
url: client.get_connection_info()
for url, client in self._clients.items()
}
}

View File

@@ -0,0 +1,604 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
logger = get_business_logger()
class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期"""
def __init__(self, db: Session):
"""初始化MCP服务管理器
Args:
db: 数据库会话
"""
self.db = db
self.connection_pool = MCPConnectionPool(max_connections=20)
# 服务状态管理
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
# 配置
self.health_check_interval = 60 # 健康检查间隔(秒)
self.max_retry_attempts = 3 # 最大重试次数
self.retry_delay = 5 # 重试延迟(秒)
# 状态
self._running = False
self._manager_task = None
async def start(self):
"""启动服务管理器"""
if self._running:
return
self._running = True
logger.info("MCP服务管理器启动")
# 加载现有服务
await self._load_existing_services()
# 启动管理任务
self._manager_task = asyncio.create_task(self._management_loop())
async def stop(self):
"""停止服务管理器"""
if not self._running:
return
self._running = False
logger.info("MCP服务管理器停止")
# 停止管理任务
if self._manager_task:
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
# 停止所有监控任务
for task in self._monitoring_tasks.values():
task.cancel()
if self._monitoring_tasks:
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
self._monitoring_tasks.clear()
# 断开所有连接
await self.connection_pool.disconnect_all()
async def register_service(
self,
server_url: str,
connection_config: Dict[str, Any],
tenant_id: uuid.UUID,
service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
"""注册MCP服务
Args:
server_url: 服务器URL
connection_config: 连接配置
tenant_id: 租户ID
service_name: 服务名称(可选)
Returns:
(是否成功, 服务ID或错误信息, 错误详情)
"""
try:
# 检查服务是否已存在
existing_service = self.db.query(MCPToolConfig).filter(
MCPToolConfig.server_url == server_url
).first()
if existing_service:
return False, "服务已存在", f"URL {server_url} 已被注册"
# 测试连接
try:
client = MCPClient(server_url, connection_config)
if not await client.connect():
return False, "连接测试失败", "无法连接到MCP服务器"
# 获取可用工具
available_tools = await client.list_tools()
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
await client.disconnect()
except Exception as e:
return False, "连接测试失败", str(e)
# 创建工具配置
if not service_name:
service_name = f"mcp_service_{server_url.split('/')[-1]}"
tool_config = ToolConfig(
name=service_name,
description=f"MCP服务 - {server_url}",
tool_type=ToolType.MCP.value,
tenant_id=tenant_id,
version="1.0.0",
config_data={
"server_url": server_url,
"connection_config": connection_config
}
)
self.db.add(tool_config)
self.db.flush()
# 创建MCP特定配置
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=server_url,
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
last_health_check=datetime.utcnow()
)
self.db.add(mcp_config)
self.db.commit()
service_id = str(tool_config.id)
# 添加到内存管理
self._services[service_id] = {
"id": service_id,
"server_url": server_url,
"connection_config": connection_config,
"tenant_id": tenant_id,
"available_tools": tool_names,
"status": "healthy",
"last_health_check": time.time(),
"retry_count": 0,
"created_at": time.time()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
return True, service_id, None
except Exception as e:
self.db.rollback()
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
return False, "注册失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
"""注销MCP服务
Args:
service_id: 服务ID
Returns:
(是否成功, 错误信息)
"""
try:
# 从数据库删除
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if not tool_config:
return False, "服务不存在"
self.db.delete(tool_config)
self.db.commit()
# 停止监控
await self._stop_service_monitoring(service_id)
# 从内存移除
if service_id in self._services:
del self._services[service_id]
logger.info(f"MCP服务注销成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def update_service(
self,
service_id: str,
connection_config: Dict[str, Any] = None,
enabled: bool = None
) -> Tuple[bool, str]:
"""更新MCP服务配置
Args:
service_id: 服务ID
connection_config: 新的连接配置
enabled: 是否启用
Returns:
(是否成功, 错误信息)
"""
try:
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if not mcp_config:
return False, "服务不存在"
tool_config = mcp_config.base_config
if connection_config is not None:
mcp_config.connection_config = connection_config
tool_config.config_data["connection_config"] = connection_config
if enabled is not None:
tool_config.is_enabled = enabled
self.db.commit()
# 更新内存状态
if service_id in self._services:
if connection_config is not None:
self._services[service_id]["connection_config"] = connection_config
# 如果配置有变化,重启监控
if connection_config is not None:
await self._restart_service_monitoring(service_id)
logger.info(f"MCP服务更新成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
"""获取服务状态
Args:
service_id: 服务ID
Returns:
服务状态信息
"""
if service_id not in self._services:
return None
service_info = self._services[service_id].copy()
# 添加实时健康检查
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
service_info["real_time_health"] = health_status
except Exception as e:
service_info["real_time_health"] = {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
return service_info
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
"""列出所有服务
Args:
tenant_id: 租户ID过滤
Returns:
服务列表
"""
services = []
for service_id, service_info in self._services.items():
if tenant_id and service_info["tenant_id"] != tenant_id:
continue
services.append(service_info.copy())
return services
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的可用工具
Args:
service_id: 服务ID
Returns:
工具列表
"""
if service_id not in self._services:
return []
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
tools = await client.list_tools()
# 更新缓存的工具列表
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.available_tools = tool_names
self.db.commit()
return tools
except Exception as e:
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
return []
async def call_service_tool(
self,
service_id: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: int = 30
) -> Dict[str, Any]:
"""调用服务工具
Args:
service_id: 服务ID
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
执行结果
"""
if service_id not in self._services:
raise ValueError(f"服务不存在: {service_id}")
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
result = await client.call_tool(tool_name, arguments, timeout)
# 更新服务状态为健康
service_info["status"] = "healthy"
service_info["last_health_check"] = time.time()
service_info["retry_count"] = 0
return result
except Exception as e:
# 更新服务状态为错误
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
raise
async def _load_existing_services(self):
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.is_enabled == True
).all()
for mcp_config in mcp_configs:
tool_config = mcp_config.base_config
service_id = str(mcp_config.id)
self._services[service_id] = {
"id": service_id,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"tenant_id": tool_config.tenant_id,
"available_tools": mcp_config.available_tools or [],
"status": mcp_config.health_status or "unknown",
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
"retry_count": 0,
"created_at": tool_config.created_at.timestamp()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
except Exception as e:
logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
"""监控单个服务"""
try:
while self._running and service_id in self._services:
service_info = self._services[service_id]
try:
# 执行健康检查
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
if health_status["healthy"]:
# 服务健康
service_info["status"] = "healthy"
service_info["retry_count"] = 0
# 更新工具列表
try:
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
except Exception as e:
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
else:
# 服务不健康
service_info["status"] = "unhealthy"
service_info["last_error"] = health_status.get("error", "健康检查失败")
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
# 更新数据库
await self._update_service_health_in_db(service_id, health_status)
except Exception as e:
# 监控异常
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
# 如果重试次数过多,暂停监控
if service_info["retry_count"] >= self.max_retry_attempts:
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
service_info["retry_count"] = 0 # 重置重试计数
# 等待下次检查
await asyncio.sleep(self.health_check_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
"""更新数据库中的服务健康状态"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.utcnow()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")
else:
mcp_config.error_message = None
self.db.commit()
except Exception as e:
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {
"running": self._running,
"total_services": len(self._services),
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
"monitoring_tasks": len(self._monitoring_tasks),
"connection_pool_status": self.connection_pool.get_pool_status()
}