feat(apikey system): tool system development
This commit is contained in:
12
api/app/core/tools/mcp/__init__.py
Normal file
12
api/app/core/tools/mcp/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""MCP工具模块"""
|
||||
|
||||
from .base import MCPTool
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
"MCPTool",
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPServiceManager"
|
||||
]
|
||||
258
api/app/core/tools/mcp/base.py
Normal file
258
api/app/core/tools/mcp/base.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""MCP工具基类"""
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化MCP工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
self._client = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数(JSON对象)",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
])
|
||||
|
||||
return params
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
"""获取MCP服务健康状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"server_url": self.server_url,
|
||||
"available_tools": self.available_tools,
|
||||
"last_check": time.time()
|
||||
}
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
try:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
626
api/app/core/tools/mcp/client.py
Normal file
626
api/app/core/tools/mcp/client.py
Normal file
@@ -0,0 +1,626 @@
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPConnectionError(Exception):
|
||||
"""MCP连接错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPProtocolError(Exception):
|
||||
"""MCP协议错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
try:
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
else:
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
elif self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
|
||||
async def _connect_http(self) -> bool:
|
||||
"""建立HTTP连接"""
|
||||
try:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
||||
|
||||
async with self._session.get(test_url) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
else:
|
||||
# 尝试根路径
|
||||
async with self._session.get(self.server_url) as root_response:
|
||||
return root_response.status < 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTTP连接失败: {e}")
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
return False
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
try:
|
||||
# 检查是否是响应消息
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
# 处理通知消息
|
||||
elif "method" in message:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
||||
|
||||
# 这里可以根据需要处理特定的通知
|
||||
# 例如:工具列表更新、服务器状态变化等
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
request_id = str(request_data["id"])
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
|
||||
# 创建Future等待响应
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
|
||||
self._last_health_check = time.time()
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
|
||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
||||
|
||||
def __init__(self, max_connections: int = 10):
|
||||
"""初始化连接池
|
||||
|
||||
Args:
|
||||
max_connections: 最大连接数
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self._clients: Dict[str, MCPClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
||||
"""获取或创建MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
|
||||
Returns:
|
||||
MCP客户端实例
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_url in self._clients:
|
||||
client = self._clients[server_url]
|
||||
if client.is_connected:
|
||||
return client
|
||||
else:
|
||||
# 尝试重连
|
||||
if await client.connect():
|
||||
return client
|
||||
else:
|
||||
# 移除失效的客户端
|
||||
del self._clients[server_url]
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self._clients) >= self.max_connections:
|
||||
# 移除最旧的连接
|
||||
oldest_url = next(iter(self._clients))
|
||||
await self._clients[oldest_url].disconnect()
|
||||
del self._clients[oldest_url]
|
||||
|
||||
# 创建新客户端
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if await client.connect():
|
||||
self._clients[server_url] = client
|
||||
return client
|
||||
else:
|
||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""断开所有连接"""
|
||||
async with self._lock:
|
||||
for client in self._clients.values():
|
||||
await client.disconnect()
|
||||
self._clients.clear()
|
||||
|
||||
def get_pool_status(self) -> Dict[str, Any]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
"total_connections": len(self._clients),
|
||||
"max_connections": self.max_connections,
|
||||
"connections": {
|
||||
url: client.get_connection_info()
|
||||
for url, client in self._clients.items()
|
||||
}
|
||||
}
|
||||
604
api/app/core/tools/mcp/service_manager.py
Normal file
604
api/app/core/tools/mcp/service_manager.py
Normal file
@@ -0,0 +1,604 @@
|
||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
||||
|
||||
# 配置
|
||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
||||
self.max_retry_attempts = 3 # 最大重试次数
|
||||
self.retry_delay = 5 # 重试延迟(秒)
|
||||
|
||||
# 状态
|
||||
self._running = False
|
||||
self._manager_task = None
|
||||
|
||||
async def start(self):
|
||||
"""启动服务管理器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("MCP服务管理器启动")
|
||||
|
||||
# 加载现有服务
|
||||
await self._load_existing_services()
|
||||
|
||||
# 启动管理任务
|
||||
self._manager_task = asyncio.create_task(self._management_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("MCP服务管理器停止")
|
||||
|
||||
# 停止管理任务
|
||||
if self._manager_task:
|
||||
self._manager_task.cancel()
|
||||
try:
|
||||
await self._manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有监控任务
|
||||
for task in self._monitoring_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self._monitoring_tasks:
|
||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
||||
|
||||
self._monitoring_tasks.clear()
|
||||
|
||||
# 断开所有连接
|
||||
await self.connection_pool.disconnect_all()
|
||||
|
||||
async def register_service(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
service_name: str = None
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""注册MCP服务
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
tenant_id: 租户ID
|
||||
service_name: 服务名称(可选)
|
||||
|
||||
Returns:
|
||||
(是否成功, 服务ID或错误信息, 错误详情)
|
||||
"""
|
||||
try:
|
||||
# 检查服务是否已存在
|
||||
existing_service = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.server_url == server_url
|
||||
).first()
|
||||
|
||||
if existing_service:
|
||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
||||
|
||||
# 测试连接
|
||||
try:
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if not await client.connect():
|
||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
||||
|
||||
# 获取可用工具
|
||||
available_tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
return False, "连接测试失败", str(e)
|
||||
|
||||
# 创建工具配置
|
||||
if not service_name:
|
||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
||||
|
||||
tool_config = ToolConfig(
|
||||
name=service_name,
|
||||
description=f"MCP服务 - {server_url}",
|
||||
tool_type=ToolType.MCP.value,
|
||||
tenant_id=tenant_id,
|
||||
version="1.0.0",
|
||||
config_data={
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config
|
||||
}
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush()
|
||||
|
||||
# 创建MCP特定配置
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=server_url,
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
self.db.commit()
|
||||
|
||||
service_id = str(tool_config.id)
|
||||
|
||||
# 添加到内存管理
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config,
|
||||
"tenant_id": tenant_id,
|
||||
"available_tools": tool_names,
|
||||
"status": "healthy",
|
||||
"last_health_check": time.time(),
|
||||
"retry_count": 0,
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
||||
return True, service_id, None
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
||||
return False, "注册失败", str(e)
|
||||
|
||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
||||
"""注销MCP服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 从数据库删除
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if not tool_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 停止监控
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 从内存移除
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
|
||||
logger.info(f"MCP服务注销成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def update_service(
|
||||
self,
|
||||
service_id: str,
|
||||
connection_config: Dict[str, Any] = None,
|
||||
enabled: bool = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""更新MCP服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
connection_config: 新的连接配置
|
||||
enabled: 是否启用
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
tool_config = mcp_config.base_config
|
||||
|
||||
if connection_config is not None:
|
||||
mcp_config.connection_config = connection_config
|
||||
tool_config.config_data["connection_config"] = connection_config
|
||||
|
||||
if enabled is not None:
|
||||
tool_config.is_enabled = enabled
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新内存状态
|
||||
if service_id in self._services:
|
||||
if connection_config is not None:
|
||||
self._services[service_id]["connection_config"] = connection_config
|
||||
|
||||
# 如果配置有变化,重启监控
|
||||
if connection_config is not None:
|
||||
await self._restart_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务更新成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取服务状态
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务状态信息
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return None
|
||||
|
||||
service_info = self._services[service_id].copy()
|
||||
|
||||
# 添加实时健康检查
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
service_info["real_time_health"] = health_status
|
||||
|
||||
except Exception as e:
|
||||
service_info["real_time_health"] = {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
return service_info
|
||||
|
||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
|
||||
"""列出所有服务
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
|
||||
Returns:
|
||||
服务列表
|
||||
"""
|
||||
services = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
if tenant_id and service_info["tenant_id"] != tenant_id:
|
||||
continue
|
||||
|
||||
services.append(service_info.copy())
|
||||
|
||||
return services
|
||||
|
||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取服务的可用工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return []
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 更新缓存的工具列表
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.available_tools = tool_names
|
||||
self.db.commit()
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def call_service_tool(
|
||||
self,
|
||||
service_id: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""调用服务工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
raise ValueError(f"服务不存在: {service_id}")
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
result = await client.call_tool(tool_name, arguments, timeout)
|
||||
|
||||
# 更新服务状态为健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["last_health_check"] = time.time()
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 更新服务状态为错误
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def _load_existing_services(self):
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.is_enabled == True
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
tool_config = mcp_config.base_config
|
||||
service_id = str(mcp_config.id)
|
||||
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"tenant_id": tool_config.tenant_id,
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"status": mcp_config.health_status or "unknown",
|
||||
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
|
||||
"retry_count": 0,
|
||||
"created_at": tool_config.created_at.timestamp()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载现有服务失败: {e}")
|
||||
|
||||
async def _start_service_monitoring(self, service_id: str):
|
||||
"""启动服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._monitor_service(service_id))
|
||||
self._monitoring_tasks[service_id] = task
|
||||
|
||||
async def _stop_service_monitoring(self, service_id: str):
|
||||
"""停止服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
task = self._monitoring_tasks.pop(service_id)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _restart_service_monitoring(self, service_id: str):
|
||||
"""重启服务监控"""
|
||||
await self._stop_service_monitoring(service_id)
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
async def _monitor_service(self, service_id: str):
|
||||
"""监控单个服务"""
|
||||
try:
|
||||
while self._running and service_id in self._services:
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
# 执行健康检查
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
|
||||
if health_status["healthy"]:
|
||||
# 服务健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
# 更新工具列表
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
except Exception as e:
|
||||
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
|
||||
|
||||
else:
|
||||
# 服务不健康
|
||||
service_info["status"] = "unhealthy"
|
||||
service_info["last_error"] = health_status.get("error", "健康检查失败")
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
# 更新数据库
|
||||
await self._update_service_health_in_db(service_id, health_status)
|
||||
|
||||
except Exception as e:
|
||||
# 监控异常
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
|
||||
|
||||
# 如果重试次数过多,暂停监控
|
||||
if service_info["retry_count"] >= self.max_retry_attempts:
|
||||
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
|
||||
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
|
||||
service_info["retry_count"] = 0 # 重置重试计数
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
|
||||
|
||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
|
||||
"""更新数据库中的服务健康状态"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.utcnow()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
else:
|
||||
mcp_config.error_message = None
|
||||
|
||||
self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
|
||||
self.db.rollback()
|
||||
|
||||
async def _management_loop(self):
|
||||
"""管理循环 - 处理服务清理等任务"""
|
||||
try:
|
||||
while self._running:
|
||||
# 清理失效的服务
|
||||
await self._cleanup_failed_services()
|
||||
|
||||
# 等待下次循环
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"管理循环异常: {e}")
|
||||
|
||||
async def _cleanup_failed_services(self):
|
||||
"""清理长期失效的服务"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
cleanup_threshold = 24 * 60 * 60 # 24小时
|
||||
|
||||
services_to_cleanup = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
# 检查服务是否长期失效
|
||||
if (service_info["status"] in ["error", "unhealthy"] and
|
||||
current_time - service_info["last_health_check"] > cleanup_threshold):
|
||||
|
||||
services_to_cleanup.append(service_id)
|
||||
|
||||
for service_id in services_to_cleanup:
|
||||
logger.warning(f"清理长期失效的服务: {service_id}")
|
||||
|
||||
# 停止监控但不删除数据库记录
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 标记为禁用
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if tool_config:
|
||||
tool_config.is_enabled = False
|
||||
self.db.commit()
|
||||
|
||||
# 从内存移除
|
||||
del self._services[service_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"total_services": len(self._services),
|
||||
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
|
||||
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
|
||||
"monitoring_tasks": len(self._monitoring_tasks),
|
||||
"connection_pool_status": self.connection_pool.get_pool_status()
|
||||
}
|
||||
Reference in New Issue
Block a user