Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-07 19:35:39 +08:00
9 changed files with 621 additions and 1436 deletions

View File

@@ -215,8 +215,8 @@ async def sync_mcp_tools(
"""同步MCP工具列表""" """同步MCP工具列表"""
try: try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if result["success"] is False: if not result.get("success", False):
raise HTTPException(status_code=404, detail=result["message"]) raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
return success(data=result, msg="MCP工具列表同步完成") return success(data=result, msg="MCP工具列表同步完成")
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -78,13 +78,20 @@ class LangchainAdapter:
Args: Args:
tool: 内部工具实例 tool: 内部工具实例
operation: 特定操作(适用于有操作的工具) operation: 特定操作(适用于有操作的工具)或MCP工具名称
Returns: Returns:
Langchain兼容的工具包装器 Langchain兼容的工具包装器
""" """
try: try:
if operation and tool.name in ['datetime_tool', 'json_tool']: # 处理MCP工具的特定工具名称
if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation:
# 为MCP工具创建特定工具名称的实例
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
return wrapper
elif operation and tool.name in ['datetime_tool', 'json_tool']:
# 为特定操作创建工具 # 为特定操作创建工具
operation_tool = LangchainAdapter._create_operation_tool(tool, operation) operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
wrapper = LangchainToolWrapper(tool_instance=operation_tool) wrapper = LangchainToolWrapper(tool_instance=operation_tool)
@@ -106,6 +113,18 @@ class LangchainAdapter:
from app.core.tools.builtin.operation_tool import OperationTool from app.core.tools.builtin.operation_tool import OperationTool
return OperationTool(base_tool, operation) return OperationTool(base_tool, operation)
@staticmethod
def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool:
"""为MCP工具创建指定工具名称的实例"""
from app.core.tools.mcp.base import MCPTool
# 创建新的配置,指定具体工具名称
new_config = base_tool.config.copy()
new_config["tool_name"] = tool_name
# 创建新的MCP工具实例
return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config)
@staticmethod @staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]: def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
"""批量转换工具 """批量转换工具

View File

@@ -1,12 +1,20 @@
"""MCP工具模块""" """MCP 工具模块 - Model Context Protocol 支持"""
from app.core.tools.mcp.base import MCPTool # 主要类导出
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool from .base import MCPTool, MCPToolManager, MCPError
from app.core.tools.mcp.service_manager import MCPServiceManager from .client import SimpleMCPClient, MCPConnectionError
from .service_manager import MCPServiceManager
__all__ = [ __all__ = [
# 核心类
"MCPTool", "MCPTool",
"MCPClient", "MCPToolManager",
"MCPConnectionPool", "MCPError",
# 客户端类
"SimpleMCPClient",
"MCPConnectionError",
# 服务管理(简化版)
"MCPServiceManager" "MCPServiceManager"
] ]

View File

@@ -1,10 +1,9 @@
"""MCP工具基类""" """MCP工具基类 - 整合版本"""
import time import time
from typing import Dict, Any, List from typing import List, Dict, Any
from app.models.tool_model import ToolType from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
logger = get_business_logger() logger = get_business_logger()
@@ -14,215 +13,174 @@ class MCPTool(BaseTool):
"""MCP工具 - Model Context Protocol工具""" """MCP工具 - Model Context Protocol工具"""
def __init__(self, tool_id: str, config: Dict[str, Any]): def __init__(self, tool_id: str, config: Dict[str, Any]):
"""初始化MCP工具
Args:
tool_id: 工具ID
config: 工具配置
"""
super().__init__(tool_id, config) super().__init__(tool_id, config)
self.server_url = config.get("server_url", "") self.server_url = config.get("server_url", "")
self.connection_config = config.get("connection_config", {}) self.connection_config = config.get("connection_config", {})
self.tool_name = config.get("tool_name", "") # 特定工具名称
self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema
self.available_tools = config.get("available_tools", []) self.available_tools = config.get("available_tools", [])
self._client = None
self._connected = False
@property @property
def name(self) -> str: def name(self) -> str:
"""工具名称""" return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}"
return f"mcp_tool_{self.tool_id[:8]}"
@property @property
def description(self) -> str: def description(self) -> str:
"""工具描述""" if self.tool_schema.get("description"):
return f"MCP工具 - 连接到 {self.server_url}" return self.tool_schema["description"]
return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}"
@property @property
def tool_type(self) -> ToolType: def tool_type(self) -> ToolType:
"""工具类型"""
return ToolType.MCP return ToolType.MCP
@property @property
def parameters(self) -> List[ToolParameter]: def parameters(self) -> List[ToolParameter]:
"""工具参数定义""" """从 MCP 工具 schema 生成参数"""
params = [] if not self.tool_schema:
return [ToolParameter(
# 添加工具选择参数
if len(self.available_tools) > 1:
params.append(ToolParameter(
name="tool_name",
type=ParameterType.STRING,
description="要调用的MCP工具名称",
required=True,
enum=self.available_tools
))
# 添加通用参数
params.extend([
ToolParameter(
name="arguments", name="arguments",
type=ParameterType.OBJECT, type=ParameterType.OBJECT,
description="工具参数JSON对象", description="工具参数",
required=False, required=False,
default={} default={}
), )]
ToolParameter(
name="timeout", # 解析 MCP 工具的 inputSchema
type=ParameterType.INTEGER, input_schema = self.tool_schema.get("inputSchema", {})
description="超时时间(秒)", properties = input_schema.get("properties", {})
required=False, required_fields = input_schema.get("required", [])
default=30,
minimum=1, params = []
maximum=300 for param_name, param_def in properties.items():
) param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
])
params.append(ToolParameter(
name=param_name,
type=param_type,
description=param_def.get("description", f"参数: {param_name}"),
required=param_name in required_fields,
default=param_def.get("default"),
enum=param_def.get("enum"),
minimum=param_def.get("minimum"),
maximum=param_def.get("maximum")
))
return params return params
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
"""转换 JSON Schema 类型到 ParameterType"""
type_mapping = {
"string": ParameterType.STRING,
"integer": ParameterType.INTEGER,
"number": ParameterType.NUMBER,
"boolean": ParameterType.BOOLEAN,
"array": ParameterType.ARRAY,
"object": ParameterType.OBJECT
}
return type_mapping.get(json_type, ParameterType.STRING)
async def execute(self, **kwargs) -> ToolResult: async def execute(self, **kwargs) -> ToolResult:
"""执行MCP工具""" """执行 MCP 工具"""
start_time = time.time() start_time = time.time()
try: try:
# 确保连接 from .client import SimpleMCPClient
if not self._connected:
await self.connect()
# 确定要调用的工具 client = SimpleMCPClient(self.server_url, self.connection_config)
tool_name = kwargs.get("tool_name")
if not tool_name and len(self.available_tools) == 1:
tool_name = self.available_tools[0]
if not tool_name: async with client:
raise ValueError("必须指定要调用的MCP工具名称") # 使用指定的工具名称或默认第一个工具
tool_name_to_use = self.tool_name
if not tool_name_to_use and self.available_tools:
tool_name_to_use = self.available_tools[0]
if tool_name not in self.available_tools: if not tool_name_to_use:
raise ValueError(f"MCP工具不存在: {tool_name}") raise Exception("未指定工具名称且无可用工具")
# 获取参数 result = await client.call_tool(tool_name_to_use, kwargs)
arguments = kwargs.get("arguments", {})
timeout = kwargs.get("timeout", 30)
# 调用MCP工具 execution_time = time.time() - start_time
result = await self._call_mcp_tool(tool_name, arguments, timeout) return ToolResult.success_result(
data=result,
execution_time = time.time() - start_time execution_time=execution_time
return ToolResult.success_result( )
data=result,
execution_time=execution_time
)
except Exception as e: except Exception as e:
execution_time = time.time() - start_time execution_time = time.time() - start_time
logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}")
return ToolResult.error_result( return ToolResult.error_result(
error=str(e), error=str(e),
error_code="MCP_ERROR", error_code="MCP_EXECUTION_ERROR",
execution_time=execution_time execution_time=execution_time
) )
async def connect(self) -> bool:
"""连接到MCP服务器""" class MCPError(Exception):
"""MCP 错误基类"""
pass
class MCPToolManager:
"""MCP 工具管理器 - 简化版本"""
def __init__(self, db=None):
self.db = db
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
async def discover_tools(
self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> tuple[bool, List[Dict[str, Any]], str | None]:
"""发现 MCP 服务器上的工具"""
try: try:
from .client import MCPClient from .client import SimpleMCPClient
if self._connected: client = SimpleMCPClient(server_url, connection_config)
return True
self._client = MCPClient(self.server_url, self.connection_config) async with client:
tools = await client.list_tools()
if await self._client.connect(): # 缓存工具信息
self._connected = True self._tool_cache[server_url] = {
# 更新可用工具列表 "tools": tools,
await self._update_available_tools() "connection_config": connection_config,
logger.info(f"MCP服务器连接成功: {self.server_url}") "last_updated": time.time()
return True }
else:
logger.error(f"MCP服务器连接失败: {self.server_url}") logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
return False return True, tools, None
except Exception as e: except Exception as e:
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}") error_msg = f"发现工具失败: {e}"
self._connected = False logger.error(error_msg)
return False return False, [], error_msg
async def _update_available_tools(self): async def test_tool_connection(
"""更新可用工具列表""" self,
server_url: str,
connection_config: Dict[str, Any] = None
) -> Dict[str, Any]:
"""测试工具连接"""
try: try:
if self._client and self._connected: from .client import SimpleMCPClient
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool: client = SimpleMCPClient(server_url, connection_config)
"""断开MCP服务器连接"""
try:
if self._client:
await self._client.disconnect()
self._client = None
self._connected = False async with client:
logger.info(f"MCP服务器连接已断开: {self.server_url}") tools = await client.list_tools()
return True
except Exception as e: return {
logger.error(f"断开MCP服务器连接失败: {e}") "success": True,
return False "tools_count": len(tools),
"tools": [tool.get("name") for tool in tools],
def get_health_status(self) -> Dict[str, Any]: "message": "连接成功"
"""获取MCP服务健康状态""" }
return {
"connected": self._connected,
"server_url": self.server_url,
"available_tools": self.available_tools,
"last_check": time.time()
}
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
try:
if not self._connected:
await self.connect()
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []
except Exception as e:
logger.error(f"获取MCP工具列表失败: {e}")
return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
except Exception as e: except Exception as e:
return { return {
"success": False, "success": False,
"error": str(e) "error": str(e),
"message": "连接失败"
} }

View File

@@ -1,9 +1,8 @@
"""MCP客户端 - Model Context Protocol客户端实现""" """MCP客户端 - 简化版本"""
import asyncio import asyncio
import json import json
import time import time
from typing import Dict, Any, List, Optional, Callable from typing import Dict, Any, List
from urllib.parse import urlparse
import aiohttp import aiohttp
import websockets import websockets
from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosed
@@ -18,571 +17,22 @@ class MCPConnectionError(Exception):
pass pass
class MCPProtocolError(Exception): class SimpleMCPClient:
"""MCP协议错误""" """简化的 MCP 客户端"""
pass
class MCPClient:
"""MCP客户端 - 支持HTTP和WebSocket连接"""
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
"""初始化MCP客户端
Args:
server_url: MCP服务器URL
connection_config: 连接配置
"""
self.server_url = server_url self.server_url = server_url
self.connection_config = connection_config or {} self.connection_config = connection_config or {}
self.timeout = self.connection_config.get("timeout", 30)
# 解析URL确定连接类型 # 确定连接类型
parsed_url = urlparse(server_url) self.is_websocket = server_url.startswith(("ws://", "wss://"))
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
# 连接状态 # 连接状态
self._connected = False
self._websocket = None self._websocket = None
self._session = None self._session = None
# 请求管理
self._request_id = 0 self._request_id = 0
self._pending_requests: Dict[str, asyncio.Future] = {} self._pending_requests = {}
# 连接池配置
self.max_connections = self.connection_config.get("max_connections", 10)
self.connection_timeout = self.connection_config.get("timeout", 30)
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
self.retry_delay = self.connection_config.get("retry_delay", 1)
# 健康检查
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
self._health_check_task = None
self._last_health_check = None
# 事件回调
self._on_connect_callbacks: List[Callable] = []
self._on_disconnect_callbacks: List[Callable] = []
self._on_error_callbacks: List[Callable] = []
async def connect(self) -> bool:
"""连接到MCP服务器
Returns:
连接是否成功
"""
try:
if self._connected:
return True
logger.info(f"连接MCP服务器: {self.server_url}")
if self.connection_type == "websocket":
success = await self._connect_websocket()
else:
success = await self._connect_http()
if success:
self._connected = True
await self._start_health_check()
await self._notify_connect_callbacks()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return success
except Exception as e:
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
await self._notify_error_callbacks(e)
return False
async def disconnect(self) -> bool:
"""断开MCP服务器连接
Returns:
断开是否成功
"""
try:
if not self._connected:
return True
logger.info(f"断开MCP服务器连接: {self.server_url}")
# 停止健康检查
await self._stop_health_check()
# 取消所有待处理的请求
for future in self._pending_requests.values():
if not future.done():
future.cancel()
self._pending_requests.clear()
# 断开连接
if self.connection_type == "websocket" and self._websocket:
await self._websocket.close()
self._websocket = None
elif self._session:
await self._session.close()
self._session = None
self._connected = False
await self._notify_disconnect_callbacks()
logger.info(f"MCP服务器连接已断开: {self.server_url}")
return True
except Exception as e:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
auth_config = self.connection_config.get("auth_config", {})
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
extra_headers=extra_headers,
timeout=self.connection_timeout
)
# 启动消息监听
asyncio.create_task(self._websocket_message_handler())
# 发送初始化消息
init_message = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "ToolManagementSystem",
"version": "1.0.0"
}
}
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.connection_timeout
)
init_response = json.loads(response)
if init_response.get("error", None) is not None:
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
return True
except Exception as e:
logger.error(f"WebSocket连接失败: {e}")
return False
async def _connect_http(self) -> bool:
"""建立HTTP连接"""
try:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=headers
)
# 测试连接
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
async with self._session.get(test_url) as response:
if response.status == 200:
return True
else:
# 尝试根路径
async with self._session.get(self.server_url) as root_response:
return root_response.status < 400
except Exception as e:
logger.error(f"HTTP连接失败: {e}")
if self._session:
await self._session.close()
self._session = None
return False
async def _websocket_message_handler(self):
"""WebSocket消息处理器"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
await self._handle_message(json.loads(message))
except ConnectionClosed:
break
except json.JSONDecodeError as e:
logger.error(f"解析WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"WebSocket消息处理器异常: {e}")
finally:
self._connected = False
await self._notify_disconnect_callbacks()
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
# 这里可以根据需要处理特定的通知
# 例如:工具列表更新、服务器状态变化等
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
"""调用MCP工具
Args:
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间(秒)
Returns:
工具执行结果
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
except asyncio.TimeoutError:
raise MCPProtocolError(f"工具调用超时: {tool_name}")
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
"""获取可用工具列表
Args:
timeout: 超时时间(秒)
Returns:
工具列表
Raises:
MCPConnectionError: 连接错误
MCPProtocolError: 协议错误
"""
if not self._connected:
raise MCPConnectionError("MCP客户端未连接")
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "tools/list"
}
try:
response = await self._send_request(request_data, timeout)
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
except asyncio.TimeoutError:
raise MCPProtocolError("获取工具列表超时")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
if self.connection_type == "websocket":
request_id = str(request_data["id"])
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
# 创建Future等待响应
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
await self._pending_requests.pop(request_id, None)
raise
except Exception as e:
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送HTTP请求"""
if not self._session:
raise MCPConnectionError("HTTP会话未建立")
try:
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
async with self._session.post(
url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as response:
if response.status == 200:
return await response.json()
else:
async with self._session.post(
self.server_url,
json=request_data,
timeout=aiohttp.ClientTimeout(total=timeout)
) as root_response:
if root_response.status != 200:
error_text = await root_response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
"""执行健康检查
Returns:
健康状态信息
"""
try:
if not self._connected:
return {
"healthy": False,
"error": "未连接",
"timestamp": time.time()
}
# 发送ping请求
request_data = {
"jsonrpc": "2.0",
"id": self._get_next_request_id(),
"method": "ping"
}
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,
"response_time": response_time,
"timestamp": self._last_health_check,
"server_info": response.get("result", {})
}
except Exception as e:
return {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
"""注册连接回调"""
self._on_connect_callbacks.append(callback)
def on_disconnect(self, callback: Callable):
"""注册断开连接回调"""
self._on_disconnect_callbacks.append(callback)
def on_error(self, callback: Callable):
"""注册错误回调"""
self._on_error_callbacks.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
@property
def last_health_check(self) -> Optional[float]:
"""获取最后一次健康检查时间"""
return self._last_health_check
def get_connection_info(self) -> Dict[str, Any]:
"""获取连接信息"""
return {
"server_url": self.server_url,
"connection_type": self.connection_type,
"connected": self._connected,
"last_health_check": self._last_health_check,
"pending_requests": len(self._pending_requests),
"config": self.connection_config
}
async def __aenter__(self): async def __aenter__(self):
"""异步上下文管理器入口""" """异步上下文管理器入口"""
@@ -593,72 +43,267 @@ class MCPClient:
"""异步上下文管理器出口""" """异步上下文管理器出口"""
await self.disconnect() await self.disconnect()
async def connect(self):
class MCPConnectionPool: """建立连接"""
"""MCP连接池 - 管理多个MCP客户端连接""" try:
if self.is_websocket:
def __init__(self, max_connections: int = 10): await self._connect_websocket()
"""初始化连接池
Args:
max_connections: 最大连接数
"""
self.max_connections = max_connections
self._clients: Dict[str, MCPClient] = {}
self._lock = asyncio.Lock()
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
"""获取或创建MCP客户端
Args:
server_url: 服务器URL
connection_config: 连接配置
Returns:
MCP客户端实例
"""
async with self._lock:
if server_url in self._clients:
client = self._clients[server_url]
if client.is_connected:
return client
else:
# 尝试重连
if await client.connect():
return client
else:
# 移除失效的客户端
del self._clients[server_url]
# 检查连接数限制
if len(self._clients) >= self.max_connections:
# 移除最旧的连接
oldest_url = next(iter(self._clients))
await self._clients[oldest_url].disconnect()
del self._clients[oldest_url]
# 创建新客户端
client = MCPClient(server_url, connection_config)
if await client.connect():
self._clients[server_url] = client
return client
else: else:
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}") await self._connect_http()
except Exception as e:
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
raise MCPConnectionError(f"连接失败: {e}")
async def disconnect_all(self): async def disconnect(self):
"""断开所有连接""" """断开连接"""
async with self._lock: try:
for client in self._clients.values(): if self._websocket:
await client.disconnect() await self._websocket.close()
self._clients.clear() self._websocket = None
def get_pool_status(self) -> Dict[str, Any]: if self._session:
"""获取连接池状态""" await self._session.close()
return { self._session = None
"total_connections": len(self._clients),
"max_connections": self.max_connections, except Exception as e:
"connections": { logger.error(f"断开连接失败: {e}")
url: client.get_connection_info()
for url, client in self._clients.items() async def _connect_websocket(self):
"""WebSocket 连接"""
headers = self._build_headers()
self._websocket = await websockets.connect(
self.server_url,
extra_headers=headers,
timeout=self.timeout
)
# 启动消息处理
asyncio.create_task(self._handle_websocket_messages())
# 发送初始化消息
await self._send_initialize()
async def _connect_http(self):
"""HTTP 连接"""
headers = self._build_headers()
timeout = aiohttp.ClientTimeout(total=self.timeout)
self._session = aiohttp.ClientSession(
headers=headers,
timeout=timeout
)
# 对于 ModelScope MCP 服务,需要先发送初始化请求
if "modelscope.net" in self.server_url:
await self._initialize_modelscope_session()
async def _initialize_modelscope_session(self):
"""初始化 ModelScope MCP 会话"""
init_request = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
} }
} }
try:
async with self._session.post(
self.server_url,
json=init_request
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
init_response = await response.json()
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
# 获取 session ID
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
if session_id:
self._session.headers.update({"Mcp-Session-Id": session_id})
# 发送 initialized 通知
initialized_notification = {
"jsonrpc": "2.0",
"method": "notifications/initialized"
}
async with self._session.post(
self.server_url,
json=initialized_notification
) as notif_response:
pass
except aiohttp.ClientError as e:
raise MCPConnectionError(f"初始化连接失败: {e}")
def _build_headers(self) -> Dict[str, str]:
"""构建请求头"""
headers = {
"Content-Type": "application/json",
"Accept": "application/json, text/event-stream"
}
# 添加认证头
auth_config = self.connection_config.get("auth_config", {})
auth_type = self.connection_config.get("auth_type", "none")
if auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "api_key":
key = auth_config.get("api_key")
header_name = auth_config.get("key_name", "X-API-Key")
if key:
headers[header_name] = key
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
return headers
async def _send_initialize(self):
"""发送初始化消息"""
init_message = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "initialize",
"params": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {}},
"clientInfo": {
"name": "MemoryBear",
"version": "1.0.0"
}
}
}
await self._websocket.send(json.dumps(init_message))
# 等待初始化响应
response = await asyncio.wait_for(
self._websocket.recv(),
timeout=self.timeout
)
init_response = json.loads(response)
if "error" in init_response:
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
async def _handle_websocket_messages(self):
"""处理 WebSocket 消息"""
try:
while self._websocket and not self._websocket.closed:
try:
message = await self._websocket.recv()
data = json.loads(message)
# 处理响应
if "id" in data:
request_id = str(data["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(data)
except ConnectionClosed:
break
except Exception as e:
logger.error(f"处理WebSocket消息失败: {e}")
except Exception as e:
logger.error(f"WebSocket消息处理异常: {e}")
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
"""调用工具"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
return response.get("result", {})
async def list_tools(self) -> List[Dict[str, Any]]:
"""获取工具列表"""
request_data = {
"jsonrpc": "2.0",
"id": self._get_request_id(),
"method": "tools/list",
"params": {}
}
if self.is_websocket:
response = await self._send_websocket_request(request_data)
else:
response = await self._send_http_request(request_data)
if "error" in response:
error = response["error"]
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
result = response.get("result", {})
return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送WebSocket请求"""
request_id = str(request_data["id"])
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
await self._websocket.send(json.dumps(request_data))
response = await asyncio.wait_for(future, timeout=self.timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""发送HTTP请求"""
try:
async with self._session.post(
self.server_url,
json=request_data
) as response:
if response.status != 200:
error_text = await response.text()
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
return await response.json()
except aiohttp.ClientError as e:
raise MCPConnectionError(f"HTTP请求失败: {e}")
def _get_request_id(self) -> str:
"""获取请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"

View File

@@ -1,6 +1,4 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控""" """MCP服务管理器 - 简化版本"""
import asyncio
import time
import uuid import uuid
from typing import Dict, Any, List, Optional, Tuple from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime from datetime import datetime
@@ -8,136 +6,53 @@ from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool from app.core.tools.mcp.base import MCPToolManager
logger = get_business_logger() logger = get_business_logger()
class MCPServiceManager: class MCPServiceManager:
"""MCP服务管理器 - 管理MCP服务的生命周期""" """MCP服务管理器 - 简化版本,主要用于工具创建"""
def __init__(self, db: Session = None): def __init__(self, db: Session = None):
"""初始化MCP服务管理器
Args:
db: 数据库会话(可选)
"""
self.db = db self.db = db
if db: self.tool_manager = MCPToolManager(db) if db else None
self.connection_pool = MCPConnectionPool(max_connections=20)
else:
self.connection_pool = None
# 服务状态管理 async def create_mcp_tool(
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
# 配置
self.health_check_interval = 60 # 健康检查间隔(秒)
self.max_retry_attempts = 3 # 最大重试次数
self.retry_delay = 5 # 重试延迟(秒)
# 状态
self._running = False
self._manager_task = None
async def start(self):
"""启动服务管理器"""
if self._running:
return
self._running = True
logger.info("MCP服务管理器启动")
# 加载现有服务
await self._load_existing_services()
# 启动管理任务
self._manager_task = asyncio.create_task(self._management_loop())
async def stop(self):
"""停止服务管理器"""
if not self._running:
return
self._running = False
logger.info("MCP服务管理器停止")
# 停止管理任务
if self._manager_task:
self._manager_task.cancel()
try:
await self._manager_task
except asyncio.CancelledError:
pass
# 停止所有监控任务
for task in self._monitoring_tasks.values():
task.cancel()
if self._monitoring_tasks:
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
self._monitoring_tasks.clear()
# 断开所有连接
await self.connection_pool.disconnect_all()
async def register_service(
self, self,
server_url: str, server_url: str,
connection_config: Dict[str, Any], connection_config: Dict[str, Any],
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
tool_name: str,
service_name: str = None service_name: str = None
) -> Tuple[bool, str, Optional[str]]: ) -> Tuple[bool, str, Optional[str]]:
"""注册MCP服务 """创建单个MCP工具
Args: Args:
server_url: 服务器URL server_url: 服务器URL
connection_config: 连接配置 connection_config: 连接配置
tenant_id: 租户ID tenant_id: 租户ID
service_name: 服务名称(可选) tool_name: 具体工具名称
service_name: 服务名称
Returns: Returns:
(是否成功, 服务ID或错误信息, 错误详情) (是否成功, 工具ID或错误信息, 错误详情)
""" """
try: try:
# 检查服务是否已存在 if not service_name:
existing_service = self.db.query(MCPToolConfig).filter( service_name = f"mcp_{tool_name}"
MCPToolConfig.server_url == server_url
).first()
if existing_service:
return False, "服务已存在", f"URL {server_url} 已被注册"
# 测试连接
try:
client = MCPClient(server_url, connection_config)
if not await client.connect():
return False, "连接测试失败", "无法连接到MCP服务器"
# 获取可用工具
available_tools = await client.list_tools()
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
await client.disconnect()
except Exception as e:
return False, "连接测试失败", str(e)
# 创建工具配置 # 创建工具配置
if not service_name:
service_name = f"mcp_service_{server_url.split('/')[-1]}"
tool_config = ToolConfig( tool_config = ToolConfig(
name=service_name, name=service_name,
description=f"MCP服务 - {server_url}", description=f"MCP工具: {tool_name}",
tool_type=ToolType.MCP.value, tool_type=ToolType.MCP.value,
tenant_id=tenant_id, tenant_id=tenant_id,
version="1.0.0", status=ToolStatus.AVAILABLE.value,
config_data={ config_data={
"server_url": server_url, "server_url": server_url,
"connection_config": connection_config "connection_config": connection_config,
"tool_name": tool_name
} }
) )
@@ -149,460 +64,22 @@ class MCPServiceManager:
id=tool_config.id, id=tool_config.id,
server_url=server_url, server_url=server_url,
connection_config=connection_config, connection_config=connection_config,
available_tools=tool_names, available_tools=[tool_name],
health_status="healthy", health_status="unknown",
last_health_check=datetime.now() last_health_check=datetime.now()
) )
self.db.add(mcp_config) self.db.add(mcp_config)
self.db.commit() self.db.commit()
service_id = str(tool_config.id) logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
return True, str(tool_config.id), None
# 添加到内存管理
self._services[service_id] = {
"id": service_id,
"server_url": server_url,
"connection_config": connection_config,
"tenant_id": tenant_id,
"available_tools": tool_names,
"status": "healthy",
"last_health_check": time.time(),
"retry_count": 0,
"created_at": time.time()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
return True, service_id, None
except Exception as e: except Exception as e:
self.db.rollback() self.db.rollback()
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}") logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
return False, "注册失败", str(e) return False, "创建失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]: def get_tool_manager(self) -> MCPToolManager:
"""注销MCP服务 """获取工具管理器实例"""
return self.tool_manager
Args:
service_id: 服务ID
Returns:
(是否成功, 错误信息)
"""
try:
# 从数据库删除
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if not tool_config:
return False, "服务不存在"
self.db.delete(tool_config)
self.db.commit()
# 停止监控
await self._stop_service_monitoring(service_id)
# 从内存移除
if service_id in self._services:
del self._services[service_id]
logger.info(f"MCP服务注销成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def update_service(
self,
service_id: str,
connection_config: Dict[str, Any] = None,
enabled: bool = None
) -> Tuple[bool, str]:
"""更新MCP服务配置
Args:
service_id: 服务ID
connection_config: 新的连接配置
enabled: 是否启用
Returns:
(是否成功, 错误信息)
"""
try:
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if not mcp_config:
return False, "服务不存在"
tool_config = mcp_config.base_config
if connection_config is not None:
mcp_config.connection_config = connection_config
tool_config.config_data["connection_config"] = connection_config
if enabled is not None:
tool_config.is_enabled = enabled
self.db.commit()
# 更新内存状态
if service_id in self._services:
if connection_config is not None:
self._services[service_id]["connection_config"] = connection_config
# 如果配置有变化,重启监控
if connection_config is not None:
await self._restart_service_monitoring(service_id)
logger.info(f"MCP服务更新成功: {service_id}")
return True, ""
except Exception as e:
self.db.rollback()
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
"""获取服务状态
Args:
service_id: 服务ID
Returns:
服务状态信息
"""
if service_id not in self._services:
return None
service_info = self._services[service_id].copy()
# 添加实时健康检查
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
service_info["real_time_health"] = health_status
except Exception as e:
service_info["real_time_health"] = {
"healthy": False,
"error": str(e),
"timestamp": time.time()
}
return service_info
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
"""列出所有服务
Args:
tenant_id: 租户ID过滤
Returns:
服务列表
"""
services = []
for service_id, service_info in self._services.items():
if tenant_id and service_info["tenant_id"] != tenant_id:
continue
services.append(service_info.copy())
return services
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
"""获取服务的可用工具
Args:
service_id: 服务ID
Returns:
工具列表
"""
if service_id not in self._services:
return []
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
tools = await client.list_tools()
# 更新缓存的工具列表
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
# 更新数据库
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.available_tools = tool_names
self.db.commit()
return tools
except Exception as e:
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
return []
async def call_service_tool(
self,
service_id: str,
tool_name: str,
arguments: Dict[str, Any],
timeout: int = 30
) -> Dict[str, Any]:
"""调用服务工具
Args:
service_id: 服务ID
tool_name: 工具名称
arguments: 工具参数
timeout: 超时时间
Returns:
执行结果
"""
if service_id not in self._services:
raise ValueError(f"服务不存在: {service_id}")
service_info = self._services[service_id]
try:
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
result = await client.call_tool(tool_name, arguments, timeout)
# 更新服务状态为健康
service_info["status"] = "healthy"
service_info["last_health_check"] = time.time()
service_info["retry_count"] = 0
return result
except Exception as e:
# 更新服务状态为错误
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
raise
async def _load_existing_services(self):
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
tool_config = mcp_config.base_config
service_id = str(mcp_config.id)
self._services[service_id] = {
"id": service_id,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {},
"tenant_id": tool_config.tenant_id,
"available_tools": mcp_config.available_tools or [],
"status": mcp_config.health_status or "unknown",
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
"retry_count": 0,
"created_at": tool_config.created_at.timestamp()
}
# 启动监控
await self._start_service_monitoring(service_id)
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
except Exception as e:
logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
"""监控单个服务"""
try:
while self._running and service_id in self._services:
service_info = self._services[service_id]
try:
# 执行健康检查
client = await self.connection_pool.get_client(
service_info["server_url"],
service_info["connection_config"]
)
health_status = await client.health_check()
if health_status["healthy"]:
# 服务健康
service_info["status"] = "healthy"
service_info["retry_count"] = 0
# 更新工具列表
try:
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
service_info["available_tools"] = tool_names
except Exception as e:
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
else:
# 服务不健康
service_info["status"] = "unhealthy"
service_info["last_error"] = health_status.get("error", "健康检查失败")
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
# 更新数据库
await self._update_service_health_in_db(service_id, health_status)
except Exception as e:
# 监控异常
service_info["status"] = "error"
service_info["last_error"] = str(e)
service_info["retry_count"] += 1
service_info["last_health_check"] = time.time()
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
# 如果重试次数过多,暂停监控
if service_info["retry_count"] >= self.max_retry_attempts:
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
service_info["retry_count"] = 0 # 重置重试计数
# 等待下次检查
await asyncio.sleep(self.health_check_interval)
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
"""更新数据库中的服务健康状态"""
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == uuid.UUID(service_id)
).first()
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")
else:
mcp_config.error_message = None
self.db.commit()
except Exception as e:
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
"""获取管理器状态"""
return {
"running": self._running,
"total_services": len(self._services),
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
"monitoring_tasks": len(self._monitoring_tasks),
"connection_pool_status": self.connection_pool.get_pool_status()
}

View File

@@ -77,7 +77,7 @@ class AppChatService:
tool_service = ToolService(self.db) tool_service = ToolService(self.db)
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(config, 'tools') and config.tools: if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools: for tool_config in config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
@@ -109,20 +109,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
# web_tools = config.tools if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
# web_search_choice = web_tools.get("web_search", {}) web_tools = config.tools
# web_search_enable = web_search_choice.get("enabled", False) web_search_choice = web_tools.get("web_search", {})
# if web_search == True: web_search_enable = web_search_choice.get("enabled", False)
# if web_search_enable == True: if web_search == True:
# search_tool = create_web_search_tool({}) if web_search_enable == True:
# tools.append(search_tool) search_tool = create_web_search_tool({})
# tools.append(search_tool)
# logger.debug(
# "已添加网络搜索工具", logger.debug(
# extra={ "已添加网络搜索工具",
# "tool_count": len(tools) extra={
# } "tool_count": len(tools)
# ) }
)
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters
@@ -226,7 +227,7 @@ class AppChatService:
# 获取工具服务 # 获取工具服务
tool_service = ToolService(self.db) tool_service = ToolService(self.db)
if hasattr(config, 'tools') and config.tools: if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
for tool_config in config.tools: for tool_config in config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
@@ -258,20 +259,21 @@ class AppChatService:
memory_tool = create_long_term_memory_tool(memory_config, user_id) memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool) tools.append(memory_tool)
# web_tools = config.tools if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
# web_search_choice = web_tools.get("web_search", {}) web_tools = config.tools
# web_search_enable = web_search_choice.get("enabled", False) web_search_choice = web_tools.get("web_search", {})
# if web_search == True: web_search_enable = web_search_choice.get("enabled", False)
# if web_search_enable == True: if web_search == True:
# search_tool = create_web_search_tool({}) if web_search_enable == True:
# tools.append(search_tool) search_tool = create_web_search_tool({})
# tools.append(search_tool)
# logger.debug(
# "已添加网络搜索工具", logger.debug(
# extra={ "已添加网络搜索工具",
# "tool_count": len(tools) extra={
# } "tool_count": len(tools)
# ) }
)
# 获取模型参数 # 获取模型参数
model_parameters = config.model_parameters model_parameters = config.model_parameters

View File

@@ -297,19 +297,35 @@ class DraftRunService:
tool_service = ToolService(self.db) tool_service = ToolService(self.db)
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
for tool_config in agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools:
if tool_config.get("enabled", False): for tool_config in agent_config.tools:
# 根据工具名称查找工具实例 if tool_config.get("enabled", False):
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), # 根据工具名称查找工具实例
ToolRepository.get_tenant_id_by_workspace_id( tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
self.db, str(workspace_id))) ToolRepository.get_tenant_id_by_workspace_id(
if tool_instance: self.db, str(workspace_id)))
if tool_instance.name == "baidu_search_tool" and not web_search: if tool_instance:
continue if tool_instance.name == "baidu_search_tool" and not web_search:
# 转换为LangChain工具 continue
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) # 转换为LangChain工具
tools.append(langchain_tool) langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具 # 添加知识库检索工具
if agent_config.knowledge_retrieval: if agent_config.knowledge_retrieval:
@@ -507,7 +523,7 @@ class DraftRunService:
tool_service = ToolService(self.db) tool_service = ToolService(self.db)
# 从配置中获取启用的工具 # 从配置中获取启用的工具
if hasattr(agent_config, 'tools') and agent_config.tools: if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
for tool_config in agent_config.tools: for tool_config in agent_config.tools:
if tool_config.get("enabled", False): if tool_config.get("enabled", False):
# 根据工具名称查找工具实例 # 根据工具名称查找工具实例
@@ -520,6 +536,22 @@ class DraftRunService:
# 转换为LangChain工具 # 转换为LangChain工具
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
tools.append(langchain_tool) tools.append(langchain_tool)
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
web_tools = agent_config.tools
web_search_choice = web_tools.get("web_search", {})
web_search_enable = web_search_choice.get("enabled", False)
if web_search == True:
if web_search_enable == True:
search_tool = create_web_search_tool({})
tools.append(search_tool)
logger.debug(
"已添加网络搜索工具",
extra={
"tool_count": len(tools)
}
)
# 添加知识库检索工具 # 添加知识库检索工具
if agent_config.knowledge_retrieval: if agent_config.knowledge_retrieval:

View File

@@ -8,7 +8,7 @@ from datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.tools.mcp import MCPClient from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
from app.repositories.tool_repository import ( from app.repositories.tool_repository import (
ToolRepository, BuiltinToolRepository, CustomToolRepository, ToolRepository, BuiltinToolRepository, CustomToolRepository,
MCPToolRepository, ToolExecutionRepository MCPToolRepository, ToolExecutionRepository
@@ -43,6 +43,9 @@ class ToolService:
self.db = db self.db = db
self._tool_cache: Dict[str, BaseTool] = {} self._tool_cache: Dict[str, BaseTool] = {}
# MCP管理器
self.mcp_tool_manager = MCPToolManager(db)
# 初始化仓储 # 初始化仓储
self.tool_repo = ToolRepository() self.tool_repo = ToolRepository()
self.builtin_repo = BuiltinToolRepository() self.builtin_repo = BuiltinToolRepository()
@@ -675,23 +678,85 @@ class ToolService:
return [] return []
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
"""获取MCP工具的方法""" """获取MCP工具的方法和参数"""
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
if not mcp_config: if not mcp_config:
return [] return []
available_tools = mcp_config.available_tools or [] available_tools = mcp_config.available_tools or []
if not available_tools: if not available_tools:
return [] # 如果没有工具列表,尝试同步
try:
success, tools, _ = await self.mcp_tool_manager.discover_tools(
mcp_config.server_url, mcp_config.connection_config or {}
)
if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
mcp_config.available_tools = tool_names
self.db.commit()
available_tools = tool_names
except Exception as e:
logger.error(f"同步MCP工具列表失败: {e}")
return []
methods = [] methods = []
for tool_name in available_tools:
methods.append({ # 获取工具详细信息
"method_id": tool_name, try:
"name": tool_name, success, tools, _ = await self.mcp_tool_manager.discover_tools(
"description": f"MCP工具: {tool_name}", mcp_config.server_url, mcp_config.connection_config or {}
"parameters": [] # MCP工具参数需要动态获取 )
})
if success:
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
for tool_name in available_tools:
tool_info = tools_dict.get(tool_name, {})
# 解析工具参数
parameters = []
input_schema = tool_info.get("inputSchema", {})
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
for param_name, param_def in properties.items():
parameters.append({
"name": param_name,
"type": param_def.get("type", "string"),
"description": param_def.get("description", ""),
"required": param_name in required_fields,
"default": param_def.get("default"),
"enum": param_def.get("enum"),
"minimum": param_def.get("minimum"),
"maximum": param_def.get("maximum")
})
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
"parameters": parameters
})
else:
# 如果无法获取详细信息,返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
except Exception as e:
logger.error(f"获取MCP工具详细信息失败: {e}")
# 返回基本信息
for tool_name in available_tools:
methods.append({
"method_id": tool_name,
"name": tool_name,
"description": f"MCP工具: {tool_name}",
"parameters": []
})
return methods return methods
@@ -812,10 +877,14 @@ class ToolService:
if not mcp_config: if not mcp_config:
return None return None
# 从配置中获取特定工具名称
tool_name = config.config_data.get("tool_name")
tool_config = { tool_config = {
"server_url": mcp_config.server_url, "server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config or {}, "connection_config": mcp_config.connection_config or {},
"available_tools": mcp_config.available_tools or [] "available_tools": mcp_config.available_tools or [],
"tool_name": tool_name # 指定具体工具
} }
return MCPTool(str(config.id), tool_config) return MCPTool(str(config.id), tool_config)
@@ -1071,71 +1140,59 @@ class ToolService:
return {} return {}
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]: async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
"""测试MCP连接""" """测试MCP连接并自动同步工具列表"""
try: try:
mcp_config = self.db.query(MCPToolConfig).filter( mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
MCPToolConfig.id == config.id
).first()
if not mcp_config: if not mcp_config:
return {"success": False, "message": "MCP配置不存在"} return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) # 使用集成的MCP管理器测试连接
test_result = await self.mcp_tool_manager.test_tool_connection(
mcp_config.server_url, mcp_config.connection_config or {}
)
if await client.connect(): if test_result["success"]:
try: # 连接成功,自动同步工具列表
# tools = await client.list_tools() success, tools, error = await self.mcp_tool_manager.discover_tools(
await client.disconnect() mcp_config.server_url, mcp_config.connection_config or {}
)
# 更新连接状态 if success:
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
# 更新数据库
mcp_config.available_tools = tool_names
mcp_config.last_health_check = datetime.now() mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy" mcp_config.health_status = "healthy"
mcp_config.error_message = None mcp_config.error_message = None
config.status = ToolStatus.AVAILABLE.value
# 更新工具状态
self._update_tool_status(config)
self.db.commit() self.db.commit()
return { return {
"success": True, "success": True,
"message": "MCP连接成功", "message": "MCP连接成功并同步工具列表",
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} "details": {
"details": {"server_url": mcp_config.server_url} "server_url": mcp_config.server_url,
"tools_count": len(tool_names),
"tools": tool_names
}
} }
except Exception as e: else:
await client.disconnect() return {"success": False, "message": f"同步工具失败: {error}"}
# 更新错误状态
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
else: else:
# 更新连接失败状态 # 更新错误状态
mcp_config.last_health_check = datetime.now() mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error" mcp_config.health_status = "error"
mcp_config.error_message = "连接失败" mcp_config.error_message = test_result.get("error", "连接失败")
self._update_tool_status(config) config.status = ToolStatus.ERROR.value
self.db.commit() self.db.commit()
return {"success": False, "message": "MCP连接失败"} return test_result
except Exception as e: except Exception as e:
# 更新异常状态 logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
mcp_config = self.db.query(MCPToolConfig).filter( return {"success": False, "message": f"测试失败: {str(e)}"}
MCPToolConfig.id == config.id
).first()
if mcp_config:
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = str(e)
self._update_tool_status(config)
self.db.commit()
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
@staticmethod @staticmethod
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]: async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
@@ -1190,57 +1247,44 @@ class ToolService:
# 创建MCP客户端 # 创建MCP客户端
connection_config = mcp_config.connection_config or {} connection_config = mcp_config.connection_config or {}
client = SimpleMCPClient(mcp_config.server_url, connection_config)
client = MCPClient(mcp_config.server_url, connection_config) async with client:
# 获取工具列表
tools = await client.list_tools()
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
if await client.connect(): # 更新数据库
try: mcp_config.available_tools = tool_names
# 获取工具列表 mcp_config.last_health_check = datetime.now()
tools = await client.list_tools() mcp_config.health_status = "healthy"
tool_names = [tool.get("name") for tool in tools if tool.get("name")] mcp_config.error_message = None
# 更新数据库 # 更新工具状态
mcp_config.available_tools = tool_names config.status = ToolStatus.AVAILABLE.value
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "healthy"
mcp_config.error_message = None
# 更新工具状态 self.db.commit()
config.status = ToolStatus.AVAILABLE.value
self.db.commit() return {
"success": True,
"message": "工具列表同步成功",
"tools_count": len(tool_names),
"tools": tool_names
}
await client.disconnect() except Exception as e:
# 更新错误状态
return { try:
"success": True, mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
"message": "工具列表同步成功", if mcp_config:
"tools_count": len(tool_names),
"tools": tool_names
}
except Exception as e:
await client.disconnect()
# 更新错误状态
mcp_config.last_health_check = datetime.now() mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error" mcp_config.health_status = "error"
mcp_config.error_message = str(e) mcp_config.error_message = str(e)
config.status = ToolStatus.ERROR.value config.status = ToolStatus.ERROR.value
self.db.commit() self.db.commit()
except:
pass
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
else:
# 连接失败
mcp_config.last_health_check = datetime.now()
mcp_config.health_status = "error"
mcp_config.error_message = "连接失败"
config.status = ToolStatus.ERROR.value
self.db.commit()
return {"success": False, "message": "MCP连接失败"}
except Exception as e:
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}") logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
return {"success": False, "message": f"同步失败: {str(e)}"} return {"success": False, "message": f"同步失败: {str(e)}"}