feat(agent tool): mcp tool repair
This commit is contained in:
@@ -215,8 +215,8 @@ async def sync_mcp_tools(
|
||||
"""同步MCP工具列表"""
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=404, detail=result["message"])
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -78,13 +78,20 @@ class LangchainAdapter:
|
||||
|
||||
Args:
|
||||
tool: 内部工具实例
|
||||
operation: 特定操作(适用于有操作的工具)
|
||||
operation: 特定操作(适用于有操作的工具)或MCP工具名称
|
||||
|
||||
Returns:
|
||||
Langchain兼容的工具包装器
|
||||
"""
|
||||
try:
|
||||
if operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||
# 处理MCP工具的特定工具名称
|
||||
if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation:
|
||||
# 为MCP工具创建特定工具名称的实例
|
||||
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
|
||||
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||
return wrapper
|
||||
elif operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||
# 为特定操作创建工具
|
||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||
@@ -106,6 +113,18 @@ class LangchainAdapter:
|
||||
from app.core.tools.builtin.operation_tool import OperationTool
|
||||
return OperationTool(base_tool, operation)
|
||||
|
||||
@staticmethod
|
||||
def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool:
|
||||
"""为MCP工具创建指定工具名称的实例"""
|
||||
from app.core.tools.mcp.base import MCPTool
|
||||
|
||||
# 创建新的配置,指定具体工具名称
|
||||
new_config = base_tool.config.copy()
|
||||
new_config["tool_name"] = tool_name
|
||||
|
||||
# 创建新的MCP工具实例
|
||||
return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config)
|
||||
|
||||
@staticmethod
|
||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||
"""批量转换工具
|
||||
|
||||
@@ -1,12 +1,20 @@
|
||||
"""MCP工具模块"""
|
||||
"""MCP 工具模块 - Model Context Protocol 支持"""
|
||||
|
||||
from app.core.tools.mcp.base import MCPTool
|
||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
||||
from app.core.tools.mcp.service_manager import MCPServiceManager
|
||||
# 主要类导出
|
||||
from .base import MCPTool, MCPToolManager, MCPError
|
||||
from .client import SimpleMCPClient, MCPConnectionError
|
||||
from .service_manager import MCPServiceManager
|
||||
|
||||
__all__ = [
|
||||
# 核心类
|
||||
"MCPTool",
|
||||
"MCPClient",
|
||||
"MCPConnectionPool",
|
||||
"MCPToolManager",
|
||||
"MCPError",
|
||||
|
||||
# 客户端类
|
||||
"SimpleMCPClient",
|
||||
"MCPConnectionError",
|
||||
|
||||
# 服务管理(简化版)
|
||||
"MCPServiceManager"
|
||||
]
|
||||
@@ -1,10 +1,9 @@
|
||||
"""MCP工具基类"""
|
||||
"""MCP工具基类 - 整合版本"""
|
||||
import time
|
||||
from typing import Dict, Any, List
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -14,215 +13,174 @@ class MCPTool(BaseTool):
|
||||
"""MCP工具 - Model Context Protocol工具"""
|
||||
|
||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||
"""初始化MCP工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
config: 工具配置
|
||||
"""
|
||||
super().__init__(tool_id, config)
|
||||
self.server_url = config.get("server_url", "")
|
||||
self.connection_config = config.get("connection_config", {})
|
||||
self.tool_name = config.get("tool_name", "") # 特定工具名称
|
||||
self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema
|
||||
self.available_tools = config.get("available_tools", [])
|
||||
self._client = None
|
||||
self._connected = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""工具名称"""
|
||||
return f"mcp_tool_{self.tool_id[:8]}"
|
||||
return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""工具描述"""
|
||||
return f"MCP工具 - 连接到 {self.server_url}"
|
||||
if self.tool_schema.get("description"):
|
||||
return self.tool_schema["description"]
|
||||
return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}"
|
||||
|
||||
@property
|
||||
def tool_type(self) -> ToolType:
|
||||
"""工具类型"""
|
||||
return ToolType.MCP
|
||||
|
||||
@property
|
||||
def parameters(self) -> List[ToolParameter]:
|
||||
"""工具参数定义"""
|
||||
params = []
|
||||
|
||||
# 添加工具选择参数
|
||||
if len(self.available_tools) > 1:
|
||||
params.append(ToolParameter(
|
||||
name="tool_name",
|
||||
type=ParameterType.STRING,
|
||||
description="要调用的MCP工具名称",
|
||||
required=True,
|
||||
enum=self.available_tools
|
||||
))
|
||||
|
||||
# 添加通用参数
|
||||
params.extend([
|
||||
ToolParameter(
|
||||
"""从 MCP 工具 schema 生成参数"""
|
||||
if not self.tool_schema:
|
||||
return [ToolParameter(
|
||||
name="arguments",
|
||||
type=ParameterType.OBJECT,
|
||||
description="工具参数(JSON对象)",
|
||||
description="工具参数",
|
||||
required=False,
|
||||
default={}
|
||||
),
|
||||
ToolParameter(
|
||||
name="timeout",
|
||||
type=ParameterType.INTEGER,
|
||||
description="超时时间(秒)",
|
||||
required=False,
|
||||
default=30,
|
||||
minimum=1,
|
||||
maximum=300
|
||||
)
|
||||
])
|
||||
)]
|
||||
|
||||
# 解析 MCP 工具的 inputSchema
|
||||
input_schema = self.tool_schema.get("inputSchema", {})
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
params = []
|
||||
for param_name, param_def in properties.items():
|
||||
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
|
||||
|
||||
params.append(ToolParameter(
|
||||
name=param_name,
|
||||
type=param_type,
|
||||
description=param_def.get("description", f"参数: {param_name}"),
|
||||
required=param_name in required_fields,
|
||||
default=param_def.get("default"),
|
||||
enum=param_def.get("enum"),
|
||||
minimum=param_def.get("minimum"),
|
||||
maximum=param_def.get("maximum")
|
||||
))
|
||||
|
||||
return params
|
||||
|
||||
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
|
||||
"""转换 JSON Schema 类型到 ParameterType"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
"integer": ParameterType.INTEGER,
|
||||
"number": ParameterType.NUMBER,
|
||||
"boolean": ParameterType.BOOLEAN,
|
||||
"array": ParameterType.ARRAY,
|
||||
"object": ParameterType.OBJECT
|
||||
}
|
||||
return type_mapping.get(json_type, ParameterType.STRING)
|
||||
|
||||
async def execute(self, **kwargs) -> ToolResult:
|
||||
"""执行MCP工具"""
|
||||
"""执行 MCP 工具"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# 确保连接
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
# 确定要调用的工具
|
||||
tool_name = kwargs.get("tool_name")
|
||||
if not tool_name and len(self.available_tools) == 1:
|
||||
tool_name = self.available_tools[0]
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("必须指定要调用的MCP工具名称")
|
||||
|
||||
if tool_name not in self.available_tools:
|
||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
||||
|
||||
# 获取参数
|
||||
arguments = kwargs.get("arguments", {})
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
|
||||
# 调用MCP工具
|
||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
client = SimpleMCPClient(self.server_url, self.connection_config)
|
||||
|
||||
async with client:
|
||||
# 使用指定的工具名称或默认第一个工具
|
||||
tool_name_to_use = self.tool_name
|
||||
if not tool_name_to_use and self.available_tools:
|
||||
tool_name_to_use = self.available_tools[0]
|
||||
|
||||
if not tool_name_to_use:
|
||||
raise Exception("未指定工具名称且无可用工具")
|
||||
|
||||
result = await client.call_tool(tool_name_to_use, kwargs)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
return ToolResult.success_result(
|
||||
data=result,
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
execution_time = time.time() - start_time
|
||||
logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}")
|
||||
return ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="MCP_ERROR",
|
||||
error_code="MCP_EXECUTION_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
|
||||
class MCPError(Exception):
|
||||
"""MCP 错误基类"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPToolManager:
|
||||
"""MCP 工具管理器 - 简化版本"""
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
def __init__(self, db=None):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
|
||||
|
||||
async def discover_tools(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> tuple[bool, List[Dict[str, Any]], str | None]:
|
||||
"""发现 MCP 服务器上的工具"""
|
||||
try:
|
||||
from .client import MCPClient
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 缓存工具信息
|
||||
self._tool_cache[server_url] = {
|
||||
"tools": tools,
|
||||
"connection_config": connection_config,
|
||||
"last_updated": time.time()
|
||||
}
|
||||
|
||||
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
|
||||
return True, tools, None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
error_msg = f"发现工具失败: {e}"
|
||||
logger.error(error_msg)
|
||||
return False, [], error_msg
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
from .client import SimpleMCPClient
|
||||
|
||||
self._connected = False
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def get_health_status(self) -> Dict[str, Any]:
|
||||
"""获取MCP服务健康状态"""
|
||||
return {
|
||||
"connected": self._connected,
|
||||
"server_url": self.server_url,
|
||||
"available_tools": self.available_tools,
|
||||
"last_check": time.time()
|
||||
}
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
if not self._client or not self._connected:
|
||||
raise Exception("MCP客户端未连接")
|
||||
|
||||
try:
|
||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
try:
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
if self._client:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
def test_connection(self) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
try:
|
||||
# 这里应该实现同步的连接测试
|
||||
# 为了简化,返回基本信息
|
||||
return {
|
||||
"success": bool(self.server_url),
|
||||
"server_url": self.server_url,
|
||||
"connected": self._connected,
|
||||
"available_tools_count": len(self.available_tools),
|
||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
||||
}
|
||||
client = SimpleMCPClient(server_url, connection_config)
|
||||
|
||||
async with client:
|
||||
tools = await client.list_tools()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tools_count": len(tools),
|
||||
"tools": [tool.get("name") for tool in tools],
|
||||
"message": "连接成功"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
"error": str(e),
|
||||
"message": "连接失败"
|
||||
}
|
||||
@@ -1,9 +1,8 @@
|
||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
||||
"""MCP客户端 - 简化版本"""
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from urllib.parse import urlparse
|
||||
from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
@@ -18,139 +17,156 @@ class MCPConnectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MCPProtocolError(Exception):
|
||||
"""MCP协议错误"""
|
||||
pass
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
||||
class SimpleMCPClient:
|
||||
"""简化的 MCP 客户端"""
|
||||
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
"""初始化MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: MCP服务器URL
|
||||
connection_config: 连接配置
|
||||
"""
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
self.timeout = self.connection_config.get("timeout", 30)
|
||||
|
||||
# 解析URL确定连接类型
|
||||
parsed_url = urlparse(server_url)
|
||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
|
||||
# 连接状态
|
||||
self._connected = False
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
|
||||
# 请求管理
|
||||
self._request_id = 0
|
||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
||||
|
||||
# 连接池配置
|
||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
||||
|
||||
# 健康检查
|
||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
||||
self._health_check_task = None
|
||||
self._last_health_check = None
|
||||
|
||||
# 事件回调
|
||||
self._on_connect_callbacks: List[Callable] = []
|
||||
self._on_disconnect_callbacks: List[Callable] = []
|
||||
self._on_error_callbacks: List[Callable] = []
|
||||
self._pending_requests = {}
|
||||
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器
|
||||
|
||||
Returns:
|
||||
连接是否成功
|
||||
"""
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
"""建立连接"""
|
||||
try:
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
||||
|
||||
if self.connection_type == "websocket":
|
||||
success = await self._connect_websocket()
|
||||
if self.is_websocket:
|
||||
await self._connect_websocket()
|
||||
else:
|
||||
success = await self._connect_http()
|
||||
|
||||
if success:
|
||||
self._connected = True
|
||||
await self._start_health_check()
|
||||
await self._notify_connect_callbacks()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
|
||||
return success
|
||||
|
||||
await self._connect_http()
|
||||
except Exception as e:
|
||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
||||
await self._notify_error_callbacks(e)
|
||||
return False
|
||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||
raise MCPConnectionError(f"连接失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接
|
||||
|
||||
Returns:
|
||||
断开是否成功
|
||||
"""
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return True
|
||||
|
||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
||||
|
||||
# 停止健康检查
|
||||
await self._stop_health_check()
|
||||
|
||||
# 取消所有待处理的请求
|
||||
for future in self._pending_requests.values():
|
||||
if not future.done():
|
||||
future.cancel()
|
||||
self._pending_requests.clear()
|
||||
|
||||
# 断开连接
|
||||
if self.connection_type == "websocket" and self._websocket:
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
elif self._session:
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
async def _connect_websocket(self):
|
||||
"""WebSocket 连接"""
|
||||
headers = self._build_headers()
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 启动消息处理
|
||||
asyncio.create_task(self._handle_websocket_messages())
|
||||
|
||||
# 发送初始化消息
|
||||
await self._send_initialize()
|
||||
|
||||
async def _connect_http(self):
|
||||
"""HTTP 连接"""
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
||||
if "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=init_request
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
|
||||
init_response = await response.json()
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
# 获取 session ID
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
# 发送 initialized 通知
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=initialized_notification
|
||||
) as notif_response:
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
if auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
elif auth_type == "api_key":
|
||||
key = auth_config.get("api_key")
|
||||
header_name = auth_config.get("key_name", "X-API-Key")
|
||||
if key:
|
||||
headers[header_name] = key
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
@@ -161,160 +177,63 @@ class MCPClient:
|
||||
|
||||
return headers
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=extra_headers,
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
# 启动消息监听
|
||||
asyncio.create_task(self._websocket_message_handler())
|
||||
|
||||
# 发送初始化消息
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "ToolManagementSystem",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息"""
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.connection_timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if init_response.get("error", None) is not None:
|
||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket连接失败: {e}")
|
||||
return False
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
async def _connect_http(self) -> bool:
|
||||
"""建立HTTP连接"""
|
||||
try:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
headers=headers
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
||||
|
||||
async with self._session.get(test_url) as response:
|
||||
if response.status == 200:
|
||||
return True
|
||||
else:
|
||||
# 尝试根路径
|
||||
async with self._session.get(self.server_url) as root_response:
|
||||
return root_response.status < 400
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"HTTP连接失败: {e}")
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
return False
|
||||
|
||||
async def _websocket_message_handler(self):
|
||||
"""WebSocket消息处理器"""
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
await self._handle_message(json.loads(message))
|
||||
data = json.loads(message)
|
||||
|
||||
# 处理响应
|
||||
if "id" in data:
|
||||
request_id = str(data["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"解析WebSocket消息失败: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
||||
finally:
|
||||
self._connected = False
|
||||
await self._notify_disconnect_callbacks()
|
||||
logger.error(f"WebSocket消息处理异常: {e}")
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""处理收到的消息"""
|
||||
try:
|
||||
# 检查是否是响应消息
|
||||
if "id" in message:
|
||||
request_id = str(message["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
|
||||
# 处理通知消息
|
||||
elif "method" in message:
|
||||
await self._handle_notification(message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _handle_notification(message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
|
||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
||||
|
||||
# 这里可以根据需要处理特定的通知
|
||||
# 例如:工具列表更新、服务器状态变化等
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
||||
"""调用MCP工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
@@ -322,343 +241,69 @@ class MCPClient:
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取可用工具列表
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
|
||||
Raises:
|
||||
MCPConnectionError: 连接错误
|
||||
MCPProtocolError: 协议错误
|
||||
"""
|
||||
if not self._connected:
|
||||
raise MCPConnectionError("MCP客户端未连接")
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "tools/list"
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPProtocolError("获取工具列表超时")
|
||||
|
||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
"""发送请求并等待响应
|
||||
|
||||
Args:
|
||||
request_data: 请求数据
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
if self.connection_type == "websocket":
|
||||
request_id = str(request_data["id"])
|
||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
return await self._send_http_request(request_data, timeout)
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
if not self._websocket or self._websocket.closed:
|
||||
raise MCPConnectionError("WebSocket连接已断开")
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
# 创建Future等待响应
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
request_id = str(request_data["id"])
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
|
||||
# 等待响应
|
||||
response = await asyncio.wait_for(future, timeout=timeout)
|
||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
if not self._session:
|
||||
raise MCPConnectionError("HTTP会话未建立")
|
||||
|
||||
try:
|
||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
||||
|
||||
async with self._session.post(
|
||||
url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
self.server_url,
|
||||
json=request_data
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
else:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data,
|
||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
||||
) as root_response:
|
||||
if root_response.status != 200:
|
||||
error_text = await root_response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""执行健康检查
|
||||
|
||||
Returns:
|
||||
健康状态信息
|
||||
"""
|
||||
try:
|
||||
if not self._connected:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": "未连接",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
# 发送ping请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_next_request_id(),
|
||||
"method": "ping"
|
||||
}
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
"response_time": response_time,
|
||||
"timestamp": self._last_health_check,
|
||||
"server_info": response.get("result", {})
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
async def _start_health_check(self):
|
||||
"""启动健康检查任务"""
|
||||
if self.health_check_interval > 0:
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
|
||||
async def _stop_health_check(self):
|
||||
"""停止健康检查任务"""
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._health_check_task = None
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""健康检查循环"""
|
||||
try:
|
||||
while self._connected:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
if self._connected:
|
||||
health_status = await self.health_check()
|
||||
if not health_status["healthy"]:
|
||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
||||
# 可以在这里实现重连逻辑
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"健康检查循环异常: {e}")
|
||||
|
||||
def _get_next_request_id(self) -> str:
|
||||
"""获取下一个请求ID"""
|
||||
def _get_request_id(self) -> str:
|
||||
"""获取请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
|
||||
# 事件回调管理
|
||||
def on_connect(self, callback: Callable):
|
||||
"""注册连接回调"""
|
||||
self._on_connect_callbacks.append(callback)
|
||||
|
||||
def on_disconnect(self, callback: Callable):
|
||||
"""注册断开连接回调"""
|
||||
self._on_disconnect_callbacks.append(callback)
|
||||
|
||||
def on_error(self, callback: Callable):
|
||||
"""注册错误回调"""
|
||||
self._on_error_callbacks.append(callback)
|
||||
|
||||
async def _notify_connect_callbacks(self):
|
||||
"""通知连接回调"""
|
||||
for callback in self._on_connect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_disconnect_callbacks(self):
|
||||
"""通知断开连接回调"""
|
||||
for callback in self._on_disconnect_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback()
|
||||
else:
|
||||
callback()
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接回调执行失败: {e}")
|
||||
|
||||
async def _notify_error_callbacks(self, error: Exception):
|
||||
"""通知错误回调"""
|
||||
for callback in self._on_error_callbacks:
|
||||
try:
|
||||
if asyncio.iscoroutinefunction(callback):
|
||||
await callback(error)
|
||||
else:
|
||||
callback(error)
|
||||
except Exception as e:
|
||||
logger.error(f"错误回调执行失败: {e}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""检查是否已连接"""
|
||||
return self._connected
|
||||
|
||||
@property
|
||||
def last_health_check(self) -> Optional[float]:
|
||||
"""获取最后一次健康检查时间"""
|
||||
return self._last_health_check
|
||||
|
||||
def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""获取连接信息"""
|
||||
return {
|
||||
"server_url": self.server_url,
|
||||
"connection_type": self.connection_type,
|
||||
"connected": self._connected,
|
||||
"last_health_check": self._last_health_check,
|
||||
"pending_requests": len(self._pending_requests),
|
||||
"config": self.connection_config
|
||||
}
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
|
||||
class MCPConnectionPool:
|
||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
||||
|
||||
def __init__(self, max_connections: int = 10):
|
||||
"""初始化连接池
|
||||
|
||||
Args:
|
||||
max_connections: 最大连接数
|
||||
"""
|
||||
self.max_connections = max_connections
|
||||
self._clients: Dict[str, MCPClient] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
||||
"""获取或创建MCP客户端
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
|
||||
Returns:
|
||||
MCP客户端实例
|
||||
"""
|
||||
async with self._lock:
|
||||
if server_url in self._clients:
|
||||
client = self._clients[server_url]
|
||||
if client.is_connected:
|
||||
return client
|
||||
else:
|
||||
# 尝试重连
|
||||
if await client.connect():
|
||||
return client
|
||||
else:
|
||||
# 移除失效的客户端
|
||||
del self._clients[server_url]
|
||||
|
||||
# 检查连接数限制
|
||||
if len(self._clients) >= self.max_connections:
|
||||
# 移除最旧的连接
|
||||
oldest_url = next(iter(self._clients))
|
||||
await self._clients[oldest_url].disconnect()
|
||||
del self._clients[oldest_url]
|
||||
|
||||
# 创建新客户端
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if await client.connect():
|
||||
self._clients[server_url] = client
|
||||
return client
|
||||
else:
|
||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
||||
|
||||
async def disconnect_all(self):
|
||||
"""断开所有连接"""
|
||||
async with self._lock:
|
||||
for client in self._clients.values():
|
||||
await client.disconnect()
|
||||
self._clients.clear()
|
||||
|
||||
def get_pool_status(self) -> Dict[str, Any]:
|
||||
"""获取连接池状态"""
|
||||
return {
|
||||
"total_connections": len(self._clients),
|
||||
"max_connections": self.max_connections,
|
||||
"connections": {
|
||||
url: client.get_connection_info()
|
||||
for url, client in self._clients.items()
|
||||
}
|
||||
}
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
@@ -1,6 +1,4 @@
|
||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
||||
import asyncio
|
||||
import time
|
||||
"""MCP服务管理器 - 简化版本"""
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
@@ -8,136 +6,53 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
||||
from app.core.tools.mcp.base import MCPToolManager
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class MCPServiceManager:
|
||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
||||
"""MCP服务管理器 - 简化版本,主要用于工具创建"""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
"""初始化MCP服务管理器
|
||||
|
||||
Args:
|
||||
db: 数据库会话(可选)
|
||||
"""
|
||||
self.db = db
|
||||
if db:
|
||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
||||
else:
|
||||
self.connection_pool = None
|
||||
|
||||
# 服务状态管理
|
||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
||||
|
||||
# 配置
|
||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
||||
self.max_retry_attempts = 3 # 最大重试次数
|
||||
self.retry_delay = 5 # 重试延迟(秒)
|
||||
|
||||
# 状态
|
||||
self._running = False
|
||||
self._manager_task = None
|
||||
self.tool_manager = MCPToolManager(db) if db else None
|
||||
|
||||
async def start(self):
|
||||
"""启动服务管理器"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._running = True
|
||||
logger.info("MCP服务管理器启动")
|
||||
|
||||
# 加载现有服务
|
||||
await self._load_existing_services()
|
||||
|
||||
# 启动管理任务
|
||||
self._manager_task = asyncio.create_task(self._management_loop())
|
||||
|
||||
async def stop(self):
|
||||
"""停止服务管理器"""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._running = False
|
||||
logger.info("MCP服务管理器停止")
|
||||
|
||||
# 停止管理任务
|
||||
if self._manager_task:
|
||||
self._manager_task.cancel()
|
||||
try:
|
||||
await self._manager_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# 停止所有监控任务
|
||||
for task in self._monitoring_tasks.values():
|
||||
task.cancel()
|
||||
|
||||
if self._monitoring_tasks:
|
||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
||||
|
||||
self._monitoring_tasks.clear()
|
||||
|
||||
# 断开所有连接
|
||||
await self.connection_pool.disconnect_all()
|
||||
|
||||
async def register_service(
|
||||
async def create_mcp_tool(
|
||||
self,
|
||||
server_url: str,
|
||||
connection_config: Dict[str, Any],
|
||||
tenant_id: uuid.UUID,
|
||||
tool_name: str,
|
||||
service_name: str = None
|
||||
) -> Tuple[bool, str, Optional[str]]:
|
||||
"""注册MCP服务
|
||||
"""创建单个MCP工具
|
||||
|
||||
Args:
|
||||
server_url: 服务器URL
|
||||
connection_config: 连接配置
|
||||
tenant_id: 租户ID
|
||||
service_name: 服务名称(可选)
|
||||
tool_name: 具体工具名称
|
||||
service_name: 服务名称
|
||||
|
||||
Returns:
|
||||
(是否成功, 服务ID或错误信息, 错误详情)
|
||||
(是否成功, 工具ID或错误信息, 错误详情)
|
||||
"""
|
||||
try:
|
||||
# 检查服务是否已存在
|
||||
existing_service = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.server_url == server_url
|
||||
).first()
|
||||
|
||||
if existing_service:
|
||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
||||
|
||||
# 测试连接
|
||||
try:
|
||||
client = MCPClient(server_url, connection_config)
|
||||
if not await client.connect():
|
||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
||||
|
||||
# 获取可用工具
|
||||
available_tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
except Exception as e:
|
||||
return False, "连接测试失败", str(e)
|
||||
if not service_name:
|
||||
service_name = f"mcp_{tool_name}"
|
||||
|
||||
# 创建工具配置
|
||||
if not service_name:
|
||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
||||
|
||||
tool_config = ToolConfig(
|
||||
name=service_name,
|
||||
description=f"MCP服务 - {server_url}",
|
||||
description=f"MCP工具: {tool_name}",
|
||||
tool_type=ToolType.MCP.value,
|
||||
tenant_id=tenant_id,
|
||||
version="1.0.0",
|
||||
status=ToolStatus.AVAILABLE.value,
|
||||
config_data={
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config
|
||||
"connection_config": connection_config,
|
||||
"tool_name": tool_name
|
||||
}
|
||||
)
|
||||
|
||||
@@ -149,460 +64,22 @@ class MCPServiceManager:
|
||||
id=tool_config.id,
|
||||
server_url=server_url,
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
available_tools=[tool_name],
|
||||
health_status="unknown",
|
||||
last_health_check=datetime.now()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
self.db.commit()
|
||||
|
||||
service_id = str(tool_config.id)
|
||||
|
||||
# 添加到内存管理
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": server_url,
|
||||
"connection_config": connection_config,
|
||||
"tenant_id": tenant_id,
|
||||
"available_tools": tool_names,
|
||||
"status": "healthy",
|
||||
"last_health_check": time.time(),
|
||||
"retry_count": 0,
|
||||
"created_at": time.time()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
||||
return True, service_id, None
|
||||
logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
|
||||
return True, str(tool_config.id), None
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
||||
return False, "注册失败", str(e)
|
||||
logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
|
||||
return False, "创建失败", str(e)
|
||||
|
||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
||||
"""注销MCP服务
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 从数据库删除
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if not tool_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 停止监控
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 从内存移除
|
||||
if service_id in self._services:
|
||||
del self._services[service_id]
|
||||
|
||||
logger.info(f"MCP服务注销成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def update_service(
|
||||
self,
|
||||
service_id: str,
|
||||
connection_config: Dict[str, Any] = None,
|
||||
enabled: bool = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""更新MCP服务配置
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
connection_config: 新的连接配置
|
||||
enabled: 是否启用
|
||||
|
||||
Returns:
|
||||
(是否成功, 错误信息)
|
||||
"""
|
||||
try:
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return False, "服务不存在"
|
||||
|
||||
tool_config = mcp_config.base_config
|
||||
|
||||
if connection_config is not None:
|
||||
mcp_config.connection_config = connection_config
|
||||
tool_config.config_data["connection_config"] = connection_config
|
||||
|
||||
if enabled is not None:
|
||||
tool_config.is_enabled = enabled
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新内存状态
|
||||
if service_id in self._services:
|
||||
if connection_config is not None:
|
||||
self._services[service_id]["connection_config"] = connection_config
|
||||
|
||||
# 如果配置有变化,重启监控
|
||||
if connection_config is not None:
|
||||
await self._restart_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"MCP服务更新成功: {service_id}")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
|
||||
return False, str(e)
|
||||
|
||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取服务状态
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
服务状态信息
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return None
|
||||
|
||||
service_info = self._services[service_id].copy()
|
||||
|
||||
# 添加实时健康检查
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
service_info["real_time_health"] = health_status
|
||||
|
||||
except Exception as e:
|
||||
service_info["real_time_health"] = {
|
||||
"healthy": False,
|
||||
"error": str(e),
|
||||
"timestamp": time.time()
|
||||
}
|
||||
|
||||
return service_info
|
||||
|
||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
|
||||
"""列出所有服务
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
|
||||
Returns:
|
||||
服务列表
|
||||
"""
|
||||
services = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
if tenant_id and service_info["tenant_id"] != tenant_id:
|
||||
continue
|
||||
|
||||
services.append(service_info.copy())
|
||||
|
||||
return services
|
||||
|
||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
|
||||
"""获取服务的可用工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
|
||||
Returns:
|
||||
工具列表
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
return []
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 更新缓存的工具列表
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
|
||||
# 更新数据库
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.available_tools = tool_names
|
||||
self.db.commit()
|
||||
|
||||
return tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def call_service_tool(
|
||||
self,
|
||||
service_id: str,
|
||||
tool_name: str,
|
||||
arguments: Dict[str, Any],
|
||||
timeout: int = 30
|
||||
) -> Dict[str, Any]:
|
||||
"""调用服务工具
|
||||
|
||||
Args:
|
||||
service_id: 服务ID
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
timeout: 超时时间
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if service_id not in self._services:
|
||||
raise ValueError(f"服务不存在: {service_id}")
|
||||
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
result = await client.call_tool(tool_name, arguments, timeout)
|
||||
|
||||
# 更新服务状态为健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["last_health_check"] = time.time()
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 更新服务状态为错误
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def _load_existing_services(self):
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
||||
ToolConfig.tool_type == ToolType.MCP.value
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
tool_config = mcp_config.base_config
|
||||
service_id = str(mcp_config.id)
|
||||
|
||||
self._services[service_id] = {
|
||||
"id": service_id,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"tenant_id": tool_config.tenant_id,
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"status": mcp_config.health_status or "unknown",
|
||||
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
|
||||
"retry_count": 0,
|
||||
"created_at": tool_config.created_at.timestamp()
|
||||
}
|
||||
|
||||
# 启动监控
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载现有服务失败: {e}")
|
||||
|
||||
async def _start_service_monitoring(self, service_id: str):
|
||||
"""启动服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
return
|
||||
|
||||
task = asyncio.create_task(self._monitor_service(service_id))
|
||||
self._monitoring_tasks[service_id] = task
|
||||
|
||||
async def _stop_service_monitoring(self, service_id: str):
|
||||
"""停止服务监控"""
|
||||
if service_id in self._monitoring_tasks:
|
||||
task = self._monitoring_tasks.pop(service_id)
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
async def _restart_service_monitoring(self, service_id: str):
|
||||
"""重启服务监控"""
|
||||
await self._stop_service_monitoring(service_id)
|
||||
await self._start_service_monitoring(service_id)
|
||||
|
||||
async def _monitor_service(self, service_id: str):
|
||||
"""监控单个服务"""
|
||||
try:
|
||||
while self._running and service_id in self._services:
|
||||
service_info = self._services[service_id]
|
||||
|
||||
try:
|
||||
# 执行健康检查
|
||||
client = await self.connection_pool.get_client(
|
||||
service_info["server_url"],
|
||||
service_info["connection_config"]
|
||||
)
|
||||
|
||||
health_status = await client.health_check()
|
||||
|
||||
if health_status["healthy"]:
|
||||
# 服务健康
|
||||
service_info["status"] = "healthy"
|
||||
service_info["retry_count"] = 0
|
||||
|
||||
# 更新工具列表
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
service_info["available_tools"] = tool_names
|
||||
except Exception as e:
|
||||
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
|
||||
|
||||
else:
|
||||
# 服务不健康
|
||||
service_info["status"] = "unhealthy"
|
||||
service_info["last_error"] = health_status.get("error", "健康检查失败")
|
||||
service_info["retry_count"] += 1
|
||||
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
# 更新数据库
|
||||
await self._update_service_health_in_db(service_id, health_status)
|
||||
|
||||
except Exception as e:
|
||||
# 监控异常
|
||||
service_info["status"] = "error"
|
||||
service_info["last_error"] = str(e)
|
||||
service_info["retry_count"] += 1
|
||||
service_info["last_health_check"] = time.time()
|
||||
|
||||
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
|
||||
|
||||
# 如果重试次数过多,暂停监控
|
||||
if service_info["retry_count"] >= self.max_retry_attempts:
|
||||
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
|
||||
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
|
||||
service_info["retry_count"] = 0 # 重置重试计数
|
||||
|
||||
# 等待下次检查
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
|
||||
|
||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
|
||||
"""更新数据库中的服务健康状态"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == uuid.UUID(service_id)
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
else:
|
||||
mcp_config.error_message = None
|
||||
|
||||
self.db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
|
||||
self.db.rollback()
|
||||
|
||||
async def _management_loop(self):
|
||||
"""管理循环 - 处理服务清理等任务"""
|
||||
try:
|
||||
while self._running:
|
||||
# 清理失效的服务
|
||||
await self._cleanup_failed_services()
|
||||
|
||||
# 等待下次循环
|
||||
await asyncio.sleep(300) # 5分钟
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"管理循环异常: {e}")
|
||||
|
||||
async def _cleanup_failed_services(self):
|
||||
"""清理长期失效的服务"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
cleanup_threshold = 24 * 60 * 60 # 24小时
|
||||
|
||||
services_to_cleanup = []
|
||||
|
||||
for service_id, service_info in self._services.items():
|
||||
# 检查服务是否长期失效
|
||||
if (service_info["status"] in ["error", "unhealthy"] and
|
||||
current_time - service_info["last_health_check"] > cleanup_threshold):
|
||||
|
||||
services_to_cleanup.append(service_id)
|
||||
|
||||
for service_id in services_to_cleanup:
|
||||
logger.warning(f"清理长期失效的服务: {service_id}")
|
||||
|
||||
# 停止监控但不删除数据库记录
|
||||
await self._stop_service_monitoring(service_id)
|
||||
|
||||
# 标记为禁用
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
||||
if tool_config:
|
||||
tool_config.is_enabled = False
|
||||
self.db.commit()
|
||||
|
||||
# 从内存移除
|
||||
del self._services[service_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理失效服务失败: {e}")
|
||||
|
||||
def get_manager_status(self) -> Dict[str, Any]:
|
||||
"""获取管理器状态"""
|
||||
return {
|
||||
"running": self._running,
|
||||
"total_services": len(self._services),
|
||||
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
|
||||
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
|
||||
"monitoring_tasks": len(self._monitoring_tasks),
|
||||
"connection_pool_status": self.connection_pool.get_pool_status()
|
||||
}
|
||||
def get_tool_manager(self) -> MCPToolManager:
|
||||
"""获取工具管理器实例"""
|
||||
return self.tool_manager
|
||||
@@ -77,7 +77,7 @@ class AppChatService:
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
@@ -109,20 +109,21 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
# web_tools = config.tools
|
||||
# web_search_choice = web_tools.get("web_search", {})
|
||||
# web_search_enable = web_search_choice.get("enabled", False)
|
||||
# if web_search == True:
|
||||
# if web_search_enable == True:
|
||||
# search_tool = create_web_search_tool({})
|
||||
# tools.append(search_tool)
|
||||
#
|
||||
# logger.debug(
|
||||
# "已添加网络搜索工具",
|
||||
# extra={
|
||||
# "tool_count": len(tools)
|
||||
# }
|
||||
# )
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
@@ -226,7 +227,7 @@ class AppChatService:
|
||||
# 获取工具服务
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
if hasattr(config, 'tools') and config.tools:
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||
for tool_config in config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
@@ -258,20 +259,21 @@ class AppChatService:
|
||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||
tools.append(memory_tool)
|
||||
|
||||
# web_tools = config.tools
|
||||
# web_search_choice = web_tools.get("web_search", {})
|
||||
# web_search_enable = web_search_choice.get("enabled", False)
|
||||
# if web_search == True:
|
||||
# if web_search_enable == True:
|
||||
# search_tool = create_web_search_tool({})
|
||||
# tools.append(search_tool)
|
||||
#
|
||||
# logger.debug(
|
||||
# "已添加网络搜索工具",
|
||||
# extra={
|
||||
# "tool_count": len(tools)
|
||||
# }
|
||||
# )
|
||||
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||
web_tools = config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 获取模型参数
|
||||
model_parameters = config.model_parameters
|
||||
|
||||
@@ -297,19 +297,35 @@ class DraftRunService:
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_config in agent_config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
for tool_config in agent_config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||
ToolRepository.get_tenant_id_by_workspace_id(
|
||||
self.db, str(workspace_id)))
|
||||
if tool_instance:
|
||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||
continue
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
@@ -507,7 +523,7 @@ class DraftRunService:
|
||||
tool_service = ToolService(self.db)
|
||||
|
||||
# 从配置中获取启用的工具
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
for tool_config in agent_config.tools:
|
||||
if tool_config.get("enabled", False):
|
||||
# 根据工具名称查找工具实例
|
||||
@@ -520,6 +536,22 @@ class DraftRunService:
|
||||
# 转换为LangChain工具
|
||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||
tools.append(langchain_tool)
|
||||
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||
web_tools = agent_config.tools
|
||||
web_search_choice = web_tools.get("web_search", {})
|
||||
web_search_enable = web_search_choice.get("enabled", False)
|
||||
if web_search == True:
|
||||
if web_search_enable == True:
|
||||
search_tool = create_web_search_tool({})
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.debug(
|
||||
"已添加网络搜索工具",
|
||||
extra={
|
||||
"tool_count": len(tools)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 添加知识库检索工具
|
||||
if agent_config.knowledge_retrieval:
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.tools.mcp import MCPClient
|
||||
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
|
||||
from app.repositories.tool_repository import (
|
||||
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
||||
MCPToolRepository, ToolExecutionRepository
|
||||
@@ -42,6 +42,9 @@ class ToolService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self._tool_cache: Dict[str, BaseTool] = {}
|
||||
|
||||
# MCP管理器
|
||||
self.mcp_tool_manager = MCPToolManager(db)
|
||||
|
||||
# 初始化仓储
|
||||
self.tool_repo = ToolRepository()
|
||||
@@ -675,23 +678,85 @@ class ToolService:
|
||||
return []
|
||||
|
||||
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||
"""获取MCP工具的方法"""
|
||||
"""获取MCP工具的方法和参数"""
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if not mcp_config:
|
||||
return []
|
||||
|
||||
available_tools = mcp_config.available_tools or []
|
||||
if not available_tools:
|
||||
return []
|
||||
# 如果没有工具列表,尝试同步
|
||||
try:
|
||||
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
if success:
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
mcp_config.available_tools = tool_names
|
||||
self.db.commit()
|
||||
available_tools = tool_names
|
||||
except Exception as e:
|
||||
logger.error(f"同步MCP工具列表失败: {e}")
|
||||
return []
|
||||
|
||||
methods = []
|
||||
for tool_name in available_tools:
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": f"MCP工具: {tool_name}",
|
||||
"parameters": [] # MCP工具参数需要动态获取
|
||||
})
|
||||
|
||||
# 获取工具详细信息
|
||||
try:
|
||||
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
if success:
|
||||
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
|
||||
|
||||
for tool_name in available_tools:
|
||||
tool_info = tools_dict.get(tool_name, {})
|
||||
|
||||
# 解析工具参数
|
||||
parameters = []
|
||||
input_schema = tool_info.get("inputSchema", {})
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
for param_name, param_def in properties.items():
|
||||
parameters.append({
|
||||
"name": param_name,
|
||||
"type": param_def.get("type", "string"),
|
||||
"description": param_def.get("description", ""),
|
||||
"required": param_name in required_fields,
|
||||
"default": param_def.get("default"),
|
||||
"enum": param_def.get("enum"),
|
||||
"minimum": param_def.get("minimum"),
|
||||
"maximum": param_def.get("maximum")
|
||||
})
|
||||
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
|
||||
"parameters": parameters
|
||||
})
|
||||
else:
|
||||
# 如果无法获取详细信息,返回基本信息
|
||||
for tool_name in available_tools:
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": f"MCP工具: {tool_name}",
|
||||
"parameters": []
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取MCP工具详细信息失败: {e}")
|
||||
# 返回基本信息
|
||||
for tool_name in available_tools:
|
||||
methods.append({
|
||||
"method_id": tool_name,
|
||||
"name": tool_name,
|
||||
"description": f"MCP工具: {tool_name}",
|
||||
"parameters": []
|
||||
})
|
||||
|
||||
return methods
|
||||
|
||||
@@ -812,10 +877,14 @@ class ToolService:
|
||||
if not mcp_config:
|
||||
return None
|
||||
|
||||
# 从配置中获取特定工具名称
|
||||
tool_name = config.config_data.get("tool_name")
|
||||
|
||||
tool_config = {
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config or {},
|
||||
"available_tools": mcp_config.available_tools or []
|
||||
"available_tools": mcp_config.available_tools or [],
|
||||
"tool_name": tool_name # 指定具体工具
|
||||
}
|
||||
|
||||
return MCPTool(str(config.id), tool_config)
|
||||
@@ -1071,71 +1140,59 @@ class ToolService:
|
||||
return {}
|
||||
|
||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||
"""测试MCP连接"""
|
||||
"""测试MCP连接并自动同步工具列表"""
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == config.id
|
||||
).first()
|
||||
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
# tools = await client.list_tools()
|
||||
await client.disconnect()
|
||||
|
||||
# 更新连接状态
|
||||
# 使用集成的MCP管理器测试连接
|
||||
test_result = await self.mcp_tool_manager.test_tool_connection(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
if test_result["success"]:
|
||||
# 连接成功,自动同步工具列表
|
||||
success, tools, error = await self.mcp_tool_manager.discover_tools(
|
||||
mcp_config.server_url, mcp_config.connection_config or {}
|
||||
)
|
||||
|
||||
if success:
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
# 更新工具状态
|
||||
self._update_tool_status(config)
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
self.db.commit()
|
||||
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功",
|
||||
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
|
||||
"details": {"server_url": mcp_config.server_url}
|
||||
"message": "MCP连接成功并同步工具列表",
|
||||
"details": {
|
||||
"server_url": mcp_config.server_url,
|
||||
"tools_count": len(tool_names),
|
||||
"tools": tool_names
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
await client.disconnect()
|
||||
|
||||
# 更新错误状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
|
||||
else:
|
||||
return {"success": False, "message": f"同步工具失败: {error}"}
|
||||
else:
|
||||
# 更新连接失败状态
|
||||
# 更新错误状态
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = "连接失败"
|
||||
self._update_tool_status(config)
|
||||
mcp_config.error_message = test_result.get("error", "连接失败")
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
|
||||
|
||||
return test_result
|
||||
|
||||
except Exception as e:
|
||||
# 更新异常状态
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == config.id
|
||||
).first()
|
||||
if mcp_config:
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
self._update_tool_status(config)
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
|
||||
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
|
||||
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
|
||||
@@ -1190,57 +1247,44 @@ class ToolService:
|
||||
|
||||
# 创建MCP客户端
|
||||
connection_config = mcp_config.connection_config or {}
|
||||
client = SimpleMCPClient(mcp_config.server_url, connection_config)
|
||||
|
||||
client = MCPClient(mcp_config.server_url, connection_config)
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
# 获取工具列表
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
# 更新工具状态
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
self.db.commit()
|
||||
|
||||
await client.disconnect()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "工具列表同步成功",
|
||||
"tools_count": len(tool_names),
|
||||
"tools": tool_names
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await client.disconnect()
|
||||
|
||||
# 更新错误状态
|
||||
async with client:
|
||||
# 获取工具列表
|
||||
tools = await client.list_tools()
|
||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
|
||||
# 更新数据库
|
||||
mcp_config.available_tools = tool_names
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "healthy"
|
||||
mcp_config.error_message = None
|
||||
|
||||
# 更新工具状态
|
||||
config.status = ToolStatus.AVAILABLE.value
|
||||
|
||||
self.db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "工具列表同步成功",
|
||||
"tools_count": len(tool_names),
|
||||
"tools": tool_names
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 更新错误状态
|
||||
try:
|
||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||
if mcp_config:
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = str(e)
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
|
||||
else:
|
||||
# 连接失败
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
mcp_config.health_status = "error"
|
||||
mcp_config.error_message = "连接失败"
|
||||
config.status = ToolStatus.ERROR.value
|
||||
self.db.commit()
|
||||
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
|
||||
except Exception as e:
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
||||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user