Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
@@ -215,8 +215,8 @@ async def sync_mcp_tools(
|
|||||||
"""同步MCP工具列表"""
|
"""同步MCP工具列表"""
|
||||||
try:
|
try:
|
||||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||||
if result["success"] is False:
|
if not result.get("success", False):
|
||||||
raise HTTPException(status_code=404, detail=result["message"])
|
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||||
return success(data=result, msg="MCP工具列表同步完成")
|
return success(data=result, msg="MCP工具列表同步完成")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|||||||
@@ -78,13 +78,20 @@ class LangchainAdapter:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
tool: 内部工具实例
|
tool: 内部工具实例
|
||||||
operation: 特定操作(适用于有操作的工具)
|
operation: 特定操作(适用于有操作的工具)或MCP工具名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Langchain兼容的工具包装器
|
Langchain兼容的工具包装器
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if operation and tool.name in ['datetime_tool', 'json_tool']:
|
# 处理MCP工具的特定工具名称
|
||||||
|
if hasattr(tool, 'tool_type') and tool.tool_type.value == 'mcp' and operation:
|
||||||
|
# 为MCP工具创建特定工具名称的实例
|
||||||
|
mcp_tool = LangchainAdapter._create_mcp_tool_with_name(tool, operation)
|
||||||
|
wrapper = LangchainToolWrapper(tool_instance=mcp_tool)
|
||||||
|
logger.debug(f"MCP工具转换成功: {tool.name}_{operation} -> Langchain格式")
|
||||||
|
return wrapper
|
||||||
|
elif operation and tool.name in ['datetime_tool', 'json_tool']:
|
||||||
# 为特定操作创建工具
|
# 为特定操作创建工具
|
||||||
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
operation_tool = LangchainAdapter._create_operation_tool(tool, operation)
|
||||||
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
wrapper = LangchainToolWrapper(tool_instance=operation_tool)
|
||||||
@@ -106,6 +113,18 @@ class LangchainAdapter:
|
|||||||
from app.core.tools.builtin.operation_tool import OperationTool
|
from app.core.tools.builtin.operation_tool import OperationTool
|
||||||
return OperationTool(base_tool, operation)
|
return OperationTool(base_tool, operation)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _create_mcp_tool_with_name(base_tool: BaseTool, tool_name: str) -> BaseTool:
|
||||||
|
"""为MCP工具创建指定工具名称的实例"""
|
||||||
|
from app.core.tools.mcp.base import MCPTool
|
||||||
|
|
||||||
|
# 创建新的配置,指定具体工具名称
|
||||||
|
new_config = base_tool.config.copy()
|
||||||
|
new_config["tool_name"] = tool_name
|
||||||
|
|
||||||
|
# 创建新的MCP工具实例
|
||||||
|
return MCPTool(f"{base_tool.tool_id}_{tool_name}", new_config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
|
||||||
"""批量转换工具
|
"""批量转换工具
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
"""MCP工具模块"""
|
"""MCP 工具模块 - Model Context Protocol 支持"""
|
||||||
|
|
||||||
from app.core.tools.mcp.base import MCPTool
|
# 主要类导出
|
||||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
from .base import MCPTool, MCPToolManager, MCPError
|
||||||
from app.core.tools.mcp.service_manager import MCPServiceManager
|
from .client import SimpleMCPClient, MCPConnectionError
|
||||||
|
from .service_manager import MCPServiceManager
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# 核心类
|
||||||
"MCPTool",
|
"MCPTool",
|
||||||
"MCPClient",
|
"MCPToolManager",
|
||||||
"MCPConnectionPool",
|
"MCPError",
|
||||||
|
|
||||||
|
# 客户端类
|
||||||
|
"SimpleMCPClient",
|
||||||
|
"MCPConnectionError",
|
||||||
|
|
||||||
|
# 服务管理(简化版)
|
||||||
"MCPServiceManager"
|
"MCPServiceManager"
|
||||||
]
|
]
|
||||||
@@ -1,10 +1,9 @@
|
|||||||
"""MCP工具基类"""
|
"""MCP工具基类 - 整合版本"""
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any, List
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
from app.models.tool_model import ToolType
|
from app.models.tool_model import ToolType
|
||||||
from app.core.tools.base import BaseTool
|
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -14,215 +13,174 @@ class MCPTool(BaseTool):
|
|||||||
"""MCP工具 - Model Context Protocol工具"""
|
"""MCP工具 - Model Context Protocol工具"""
|
||||||
|
|
||||||
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
||||||
"""初始化MCP工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_id: 工具ID
|
|
||||||
config: 工具配置
|
|
||||||
"""
|
|
||||||
super().__init__(tool_id, config)
|
super().__init__(tool_id, config)
|
||||||
self.server_url = config.get("server_url", "")
|
self.server_url = config.get("server_url", "")
|
||||||
self.connection_config = config.get("connection_config", {})
|
self.connection_config = config.get("connection_config", {})
|
||||||
|
self.tool_name = config.get("tool_name", "") # 特定工具名称
|
||||||
|
self.tool_schema = config.get("tool_schema", {}) # 工具参数 schema
|
||||||
self.available_tools = config.get("available_tools", [])
|
self.available_tools = config.get("available_tools", [])
|
||||||
self._client = None
|
|
||||||
self._connected = False
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""工具名称"""
|
return f"mcp_{self.tool_name}" if self.tool_name else f"mcp_tool_{self.tool_id[:8]}"
|
||||||
return f"mcp_tool_{self.tool_id[:8]}"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def description(self) -> str:
|
def description(self) -> str:
|
||||||
"""工具描述"""
|
if self.tool_schema.get("description"):
|
||||||
return f"MCP工具 - 连接到 {self.server_url}"
|
return self.tool_schema["description"]
|
||||||
|
return f"MCP工具: {self.tool_name}" if self.tool_name else f"MCP工具 - 连接到 {self.server_url}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tool_type(self) -> ToolType:
|
def tool_type(self) -> ToolType:
|
||||||
"""工具类型"""
|
|
||||||
return ToolType.MCP
|
return ToolType.MCP
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self) -> List[ToolParameter]:
|
def parameters(self) -> List[ToolParameter]:
|
||||||
"""工具参数定义"""
|
"""从 MCP 工具 schema 生成参数"""
|
||||||
params = []
|
if not self.tool_schema:
|
||||||
|
return [ToolParameter(
|
||||||
# 添加工具选择参数
|
|
||||||
if len(self.available_tools) > 1:
|
|
||||||
params.append(ToolParameter(
|
|
||||||
name="tool_name",
|
|
||||||
type=ParameterType.STRING,
|
|
||||||
description="要调用的MCP工具名称",
|
|
||||||
required=True,
|
|
||||||
enum=self.available_tools
|
|
||||||
))
|
|
||||||
|
|
||||||
# 添加通用参数
|
|
||||||
params.extend([
|
|
||||||
ToolParameter(
|
|
||||||
name="arguments",
|
name="arguments",
|
||||||
type=ParameterType.OBJECT,
|
type=ParameterType.OBJECT,
|
||||||
description="工具参数(JSON对象)",
|
description="工具参数",
|
||||||
required=False,
|
required=False,
|
||||||
default={}
|
default={}
|
||||||
),
|
)]
|
||||||
ToolParameter(
|
|
||||||
name="timeout",
|
# 解析 MCP 工具的 inputSchema
|
||||||
type=ParameterType.INTEGER,
|
input_schema = self.tool_schema.get("inputSchema", {})
|
||||||
description="超时时间(秒)",
|
properties = input_schema.get("properties", {})
|
||||||
required=False,
|
required_fields = input_schema.get("required", [])
|
||||||
default=30,
|
|
||||||
minimum=1,
|
params = []
|
||||||
maximum=300
|
for param_name, param_def in properties.items():
|
||||||
)
|
param_type = self._convert_json_type_to_parameter_type(param_def.get("type", "string"))
|
||||||
])
|
|
||||||
|
params.append(ToolParameter(
|
||||||
|
name=param_name,
|
||||||
|
type=param_type,
|
||||||
|
description=param_def.get("description", f"参数: {param_name}"),
|
||||||
|
required=param_name in required_fields,
|
||||||
|
default=param_def.get("default"),
|
||||||
|
enum=param_def.get("enum"),
|
||||||
|
minimum=param_def.get("minimum"),
|
||||||
|
maximum=param_def.get("maximum")
|
||||||
|
))
|
||||||
|
|
||||||
return params
|
return params
|
||||||
|
|
||||||
|
def _convert_json_type_to_parameter_type(self, json_type: str) -> ParameterType:
|
||||||
|
"""转换 JSON Schema 类型到 ParameterType"""
|
||||||
|
type_mapping = {
|
||||||
|
"string": ParameterType.STRING,
|
||||||
|
"integer": ParameterType.INTEGER,
|
||||||
|
"number": ParameterType.NUMBER,
|
||||||
|
"boolean": ParameterType.BOOLEAN,
|
||||||
|
"array": ParameterType.ARRAY,
|
||||||
|
"object": ParameterType.OBJECT
|
||||||
|
}
|
||||||
|
return type_mapping.get(json_type, ParameterType.STRING)
|
||||||
|
|
||||||
async def execute(self, **kwargs) -> ToolResult:
|
async def execute(self, **kwargs) -> ToolResult:
|
||||||
"""执行MCP工具"""
|
"""执行 MCP 工具"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 确保连接
|
from .client import SimpleMCPClient
|
||||||
if not self._connected:
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
# 确定要调用的工具
|
client = SimpleMCPClient(self.server_url, self.connection_config)
|
||||||
tool_name = kwargs.get("tool_name")
|
|
||||||
if not tool_name and len(self.available_tools) == 1:
|
|
||||||
tool_name = self.available_tools[0]
|
|
||||||
|
|
||||||
if not tool_name:
|
async with client:
|
||||||
raise ValueError("必须指定要调用的MCP工具名称")
|
# 使用指定的工具名称或默认第一个工具
|
||||||
|
tool_name_to_use = self.tool_name
|
||||||
|
if not tool_name_to_use and self.available_tools:
|
||||||
|
tool_name_to_use = self.available_tools[0]
|
||||||
|
|
||||||
if tool_name not in self.available_tools:
|
if not tool_name_to_use:
|
||||||
raise ValueError(f"MCP工具不存在: {tool_name}")
|
raise Exception("未指定工具名称且无可用工具")
|
||||||
|
|
||||||
# 获取参数
|
result = await client.call_tool(tool_name_to_use, kwargs)
|
||||||
arguments = kwargs.get("arguments", {})
|
|
||||||
timeout = kwargs.get("timeout", 30)
|
|
||||||
|
|
||||||
# 调用MCP工具
|
execution_time = time.time() - start_time
|
||||||
result = await self._call_mcp_tool(tool_name, arguments, timeout)
|
return ToolResult.success_result(
|
||||||
|
data=result,
|
||||||
execution_time = time.time() - start_time
|
execution_time=execution_time
|
||||||
return ToolResult.success_result(
|
)
|
||||||
data=result,
|
|
||||||
execution_time=execution_time
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
|
logger.error(f"MCP工具执行失败: {self.tool_name or 'unknown'}, 错误: {e}")
|
||||||
return ToolResult.error_result(
|
return ToolResult.error_result(
|
||||||
error=str(e),
|
error=str(e),
|
||||||
error_code="MCP_ERROR",
|
error_code="MCP_EXECUTION_ERROR",
|
||||||
execution_time=execution_time
|
execution_time=execution_time
|
||||||
)
|
)
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
|
||||||
"""连接到MCP服务器"""
|
class MCPError(Exception):
|
||||||
|
"""MCP 错误基类"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MCPToolManager:
|
||||||
|
"""MCP 工具管理器 - 简化版本"""
|
||||||
|
|
||||||
|
def __init__(self, db=None):
|
||||||
|
self.db = db
|
||||||
|
self._tool_cache: Dict[str, Dict[str, Any]] = {} # server_url -> tools_info
|
||||||
|
|
||||||
|
async def discover_tools(
|
||||||
|
self,
|
||||||
|
server_url: str,
|
||||||
|
connection_config: Dict[str, Any] = None
|
||||||
|
) -> tuple[bool, List[Dict[str, Any]], str | None]:
|
||||||
|
"""发现 MCP 服务器上的工具"""
|
||||||
try:
|
try:
|
||||||
from .client import MCPClient
|
from .client import SimpleMCPClient
|
||||||
|
|
||||||
if self._connected:
|
client = SimpleMCPClient(server_url, connection_config)
|
||||||
return True
|
|
||||||
|
|
||||||
self._client = MCPClient(self.server_url, self.connection_config)
|
async with client:
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
if await self._client.connect():
|
# 缓存工具信息
|
||||||
self._connected = True
|
self._tool_cache[server_url] = {
|
||||||
# 更新可用工具列表
|
"tools": tools,
|
||||||
await self._update_available_tools()
|
"connection_config": connection_config,
|
||||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
"last_updated": time.time()
|
||||||
return True
|
}
|
||||||
else:
|
|
||||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
logger.info(f"发现 {len(tools)} 个MCP工具: {server_url}")
|
||||||
return False
|
return True, tools, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
error_msg = f"发现工具失败: {e}"
|
||||||
self._connected = False
|
logger.error(error_msg)
|
||||||
return False
|
return False, [], error_msg
|
||||||
|
|
||||||
async def _update_available_tools(self):
|
async def test_tool_connection(
|
||||||
"""更新可用工具列表"""
|
self,
|
||||||
|
server_url: str,
|
||||||
|
connection_config: Dict[str, Any] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""测试工具连接"""
|
||||||
try:
|
try:
|
||||||
if self._client and self._connected:
|
from .client import SimpleMCPClient
|
||||||
tools = await self._client.list_tools()
|
|
||||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
|
||||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新MCP工具列表失败: {e}")
|
|
||||||
|
|
||||||
async def disconnect(self) -> bool:
|
client = SimpleMCPClient(server_url, connection_config)
|
||||||
"""断开MCP服务器连接"""
|
|
||||||
try:
|
|
||||||
if self._client:
|
|
||||||
await self._client.disconnect()
|
|
||||||
self._client = None
|
|
||||||
|
|
||||||
self._connected = False
|
async with client:
|
||||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
tools = await client.list_tools()
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
return {
|
||||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
"success": True,
|
||||||
return False
|
"tools_count": len(tools),
|
||||||
|
"tools": [tool.get("name") for tool in tools],
|
||||||
def get_health_status(self) -> Dict[str, Any]:
|
"message": "连接成功"
|
||||||
"""获取MCP服务健康状态"""
|
}
|
||||||
return {
|
|
||||||
"connected": self._connected,
|
|
||||||
"server_url": self.server_url,
|
|
||||||
"available_tools": self.available_tools,
|
|
||||||
"last_check": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
|
||||||
"""调用MCP工具"""
|
|
||||||
if not self._client or not self._connected:
|
|
||||||
raise Exception("MCP客户端未连接")
|
|
||||||
|
|
||||||
try:
|
|
||||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
|
||||||
"""列出可用的MCP工具"""
|
|
||||||
try:
|
|
||||||
if not self._connected:
|
|
||||||
await self.connect()
|
|
||||||
|
|
||||||
if self._client:
|
|
||||||
tools = await self._client.list_tools()
|
|
||||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
|
||||||
return tools
|
|
||||||
|
|
||||||
return []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取MCP工具列表失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def test_connection(self) -> Dict[str, Any]:
|
|
||||||
"""测试MCP连接"""
|
|
||||||
try:
|
|
||||||
# 这里应该实现同步的连接测试
|
|
||||||
# 为了简化,返回基本信息
|
|
||||||
return {
|
|
||||||
"success": bool(self.server_url),
|
|
||||||
"server_url": self.server_url,
|
|
||||||
"connected": self._connected,
|
|
||||||
"available_tools_count": len(self.available_tools),
|
|
||||||
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e)
|
"error": str(e),
|
||||||
|
"message": "连接失败"
|
||||||
}
|
}
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
"""MCP客户端 - Model Context Protocol客户端实现"""
|
"""MCP客户端 - 简化版本"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Any, List, Optional, Callable
|
from typing import Dict, Any, List
|
||||||
from urllib.parse import urlparse
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import websockets
|
import websockets
|
||||||
from websockets.exceptions import ConnectionClosed
|
from websockets.exceptions import ConnectionClosed
|
||||||
@@ -18,571 +17,22 @@ class MCPConnectionError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class MCPProtocolError(Exception):
|
class SimpleMCPClient:
|
||||||
"""MCP协议错误"""
|
"""简化的 MCP 客户端"""
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
|
||||||
"""MCP客户端 - 支持HTTP和WebSocket连接"""
|
|
||||||
|
|
||||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||||
"""初始化MCP客户端
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_url: MCP服务器URL
|
|
||||||
connection_config: 连接配置
|
|
||||||
"""
|
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.connection_config = connection_config or {}
|
self.connection_config = connection_config or {}
|
||||||
|
self.timeout = self.connection_config.get("timeout", 30)
|
||||||
|
|
||||||
# 解析URL确定连接类型
|
# 确定连接类型
|
||||||
parsed_url = urlparse(server_url)
|
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||||
self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http"
|
|
||||||
|
|
||||||
# 连接状态
|
# 连接状态
|
||||||
self._connected = False
|
|
||||||
self._websocket = None
|
self._websocket = None
|
||||||
self._session = None
|
self._session = None
|
||||||
|
|
||||||
# 请求管理
|
|
||||||
self._request_id = 0
|
self._request_id = 0
|
||||||
self._pending_requests: Dict[str, asyncio.Future] = {}
|
self._pending_requests = {}
|
||||||
|
|
||||||
# 连接池配置
|
|
||||||
self.max_connections = self.connection_config.get("max_connections", 10)
|
|
||||||
self.connection_timeout = self.connection_config.get("timeout", 30)
|
|
||||||
self.retry_attempts = self.connection_config.get("retry_attempts", 3)
|
|
||||||
self.retry_delay = self.connection_config.get("retry_delay", 1)
|
|
||||||
|
|
||||||
# 健康检查
|
|
||||||
self.health_check_interval = self.connection_config.get("health_check_interval", 60)
|
|
||||||
self._health_check_task = None
|
|
||||||
self._last_health_check = None
|
|
||||||
|
|
||||||
# 事件回调
|
|
||||||
self._on_connect_callbacks: List[Callable] = []
|
|
||||||
self._on_disconnect_callbacks: List[Callable] = []
|
|
||||||
self._on_error_callbacks: List[Callable] = []
|
|
||||||
|
|
||||||
async def connect(self) -> bool:
|
|
||||||
"""连接到MCP服务器
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
连接是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if self._connected:
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info(f"连接MCP服务器: {self.server_url}")
|
|
||||||
|
|
||||||
if self.connection_type == "websocket":
|
|
||||||
success = await self._connect_websocket()
|
|
||||||
else:
|
|
||||||
success = await self._connect_http()
|
|
||||||
|
|
||||||
if success:
|
|
||||||
self._connected = True
|
|
||||||
await self._start_health_check()
|
|
||||||
await self._notify_connect_callbacks()
|
|
||||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
|
||||||
|
|
||||||
return success
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
|
|
||||||
await self._notify_error_callbacks(e)
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def disconnect(self) -> bool:
|
|
||||||
"""断开MCP服务器连接
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
断开是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not self._connected:
|
|
||||||
return True
|
|
||||||
|
|
||||||
logger.info(f"断开MCP服务器连接: {self.server_url}")
|
|
||||||
|
|
||||||
# 停止健康检查
|
|
||||||
await self._stop_health_check()
|
|
||||||
|
|
||||||
# 取消所有待处理的请求
|
|
||||||
for future in self._pending_requests.values():
|
|
||||||
if not future.done():
|
|
||||||
future.cancel()
|
|
||||||
self._pending_requests.clear()
|
|
||||||
|
|
||||||
# 断开连接
|
|
||||||
if self.connection_type == "websocket" and self._websocket:
|
|
||||||
await self._websocket.close()
|
|
||||||
self._websocket = None
|
|
||||||
elif self._session:
|
|
||||||
await self._session.close()
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
self._connected = False
|
|
||||||
await self._notify_disconnect_callbacks()
|
|
||||||
logger.info(f"MCP服务器连接已断开: {self.server_url}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _build_auth_headers(self) -> Dict[str, str]:
|
|
||||||
"""构建认证头"""
|
|
||||||
headers = {}
|
|
||||||
auth_type = self.connection_config.get("auth_type", "none")
|
|
||||||
auth_config = self.connection_config.get("auth_config", {})
|
|
||||||
|
|
||||||
if auth_type == "api_key":
|
|
||||||
api_key = auth_config.get("api_key")
|
|
||||||
key_name = auth_config.get("key_name", "X-API-Key")
|
|
||||||
if api_key:
|
|
||||||
headers[key_name] = api_key
|
|
||||||
|
|
||||||
elif auth_type == "bearer_token":
|
|
||||||
token = auth_config.get("token")
|
|
||||||
if token:
|
|
||||||
headers["Authorization"] = f"Bearer {token}"
|
|
||||||
|
|
||||||
elif auth_type == "basic_auth":
|
|
||||||
username = auth_config.get("username")
|
|
||||||
password = auth_config.get("password")
|
|
||||||
if username and password:
|
|
||||||
import base64
|
|
||||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
|
||||||
headers["Authorization"] = f"Basic {credentials}"
|
|
||||||
|
|
||||||
return headers
|
|
||||||
|
|
||||||
async def _connect_websocket(self) -> bool:
|
|
||||||
"""建立WebSocket连接"""
|
|
||||||
try:
|
|
||||||
# WebSocket连接配置
|
|
||||||
extra_headers = self.connection_config.get("headers", {})
|
|
||||||
auth_headers = self._build_auth_headers()
|
|
||||||
extra_headers.update(auth_headers)
|
|
||||||
|
|
||||||
self._websocket = await websockets.connect(
|
|
||||||
self.server_url,
|
|
||||||
extra_headers=extra_headers,
|
|
||||||
timeout=self.connection_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
# 启动消息监听
|
|
||||||
asyncio.create_task(self._websocket_message_handler())
|
|
||||||
|
|
||||||
# 发送初始化消息
|
|
||||||
init_message = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_next_request_id(),
|
|
||||||
"method": "initialize",
|
|
||||||
"params": {
|
|
||||||
"protocolVersion": "2024-11-05",
|
|
||||||
"capabilities": {
|
|
||||||
"tools": {}
|
|
||||||
},
|
|
||||||
"clientInfo": {
|
|
||||||
"name": "ToolManagementSystem",
|
|
||||||
"version": "1.0.0"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await self._websocket.send(json.dumps(init_message))
|
|
||||||
|
|
||||||
# 等待初始化响应
|
|
||||||
response = await asyncio.wait_for(
|
|
||||||
self._websocket.recv(),
|
|
||||||
timeout=self.connection_timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
init_response = json.loads(response)
|
|
||||||
if init_response.get("error", None) is not None:
|
|
||||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"WebSocket连接失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _connect_http(self) -> bool:
|
|
||||||
"""建立HTTP连接"""
|
|
||||||
try:
|
|
||||||
# HTTP会话配置
|
|
||||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
|
||||||
headers = self.connection_config.get("headers", {})
|
|
||||||
auth_headers = self._build_auth_headers()
|
|
||||||
headers.update(auth_headers)
|
|
||||||
|
|
||||||
self._session = aiohttp.ClientSession(
|
|
||||||
timeout=timeout,
|
|
||||||
headers=headers
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
|
|
||||||
|
|
||||||
async with self._session.get(test_url) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
# 尝试根路径
|
|
||||||
async with self._session.get(self.server_url) as root_response:
|
|
||||||
return root_response.status < 400
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"HTTP连接失败: {e}")
|
|
||||||
if self._session:
|
|
||||||
await self._session.close()
|
|
||||||
self._session = None
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _websocket_message_handler(self):
|
|
||||||
"""WebSocket消息处理器"""
|
|
||||||
try:
|
|
||||||
while self._websocket and not self._websocket.closed:
|
|
||||||
try:
|
|
||||||
message = await self._websocket.recv()
|
|
||||||
await self._handle_message(json.loads(message))
|
|
||||||
except ConnectionClosed:
|
|
||||||
break
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"解析WebSocket消息失败: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理WebSocket消息失败: {e}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"WebSocket消息处理器异常: {e}")
|
|
||||||
finally:
|
|
||||||
self._connected = False
|
|
||||||
await self._notify_disconnect_callbacks()
|
|
||||||
|
|
||||||
async def _handle_message(self, message: Dict[str, Any]):
|
|
||||||
"""处理收到的消息"""
|
|
||||||
try:
|
|
||||||
# 检查是否是响应消息
|
|
||||||
if "id" in message:
|
|
||||||
request_id = str(message["id"])
|
|
||||||
if request_id in self._pending_requests:
|
|
||||||
future = self._pending_requests.pop(request_id)
|
|
||||||
if not future.done():
|
|
||||||
future.set_result(message)
|
|
||||||
|
|
||||||
# 处理通知消息
|
|
||||||
elif "method" in message:
|
|
||||||
await self._handle_notification(message)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理消息失败: {e}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _handle_notification(message: Dict[str, Any]):
|
|
||||||
"""处理通知消息"""
|
|
||||||
method = message.get("method")
|
|
||||||
params = message.get("params", {})
|
|
||||||
|
|
||||||
logger.debug(f"收到MCP通知: {method}, 参数: {params}")
|
|
||||||
|
|
||||||
# 这里可以根据需要处理特定的通知
|
|
||||||
# 例如:工具列表更新、服务器状态变化等
|
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
|
|
||||||
"""调用MCP工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: 工具名称
|
|
||||||
arguments: 工具参数
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
工具执行结果
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
MCPConnectionError: 连接错误
|
|
||||||
MCPProtocolError: 协议错误
|
|
||||||
"""
|
|
||||||
if not self._connected:
|
|
||||||
raise MCPConnectionError("MCP客户端未连接")
|
|
||||||
|
|
||||||
request_data = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_next_request_id(),
|
|
||||||
"method": "tools/call",
|
|
||||||
"params": {
|
|
||||||
"name": tool_name,
|
|
||||||
"arguments": arguments
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self._send_request(request_data, timeout)
|
|
||||||
|
|
||||||
if response.get("error", None) is not None:
|
|
||||||
error = response["error"]
|
|
||||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
|
||||||
|
|
||||||
return response.get("result", {})
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise MCPProtocolError(f"工具调用超时: {tool_name}")
|
|
||||||
|
|
||||||
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
|
|
||||||
"""获取可用工具列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
工具列表
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
MCPConnectionError: 连接错误
|
|
||||||
MCPProtocolError: 协议错误
|
|
||||||
"""
|
|
||||||
if not self._connected:
|
|
||||||
raise MCPConnectionError("MCP客户端未连接")
|
|
||||||
|
|
||||||
request_data = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_next_request_id(),
|
|
||||||
"method": "tools/list"
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self._send_request(request_data, timeout)
|
|
||||||
|
|
||||||
if response.get("error", None) is not None:
|
|
||||||
error = response["error"]
|
|
||||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
|
||||||
|
|
||||||
result = response.get("result", {})
|
|
||||||
return result.get("tools", [])
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
raise MCPProtocolError("获取工具列表超时")
|
|
||||||
|
|
||||||
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
|
||||||
"""发送请求并等待响应
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request_data: 请求数据
|
|
||||||
timeout: 超时时间(秒)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
响应数据
|
|
||||||
"""
|
|
||||||
if self.connection_type == "websocket":
|
|
||||||
request_id = str(request_data["id"])
|
|
||||||
return await self._send_websocket_request(request_data, request_id, timeout)
|
|
||||||
else:
|
|
||||||
return await self._send_http_request(request_data, timeout)
|
|
||||||
|
|
||||||
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
|
|
||||||
"""发送WebSocket请求"""
|
|
||||||
if not self._websocket or self._websocket.closed:
|
|
||||||
raise MCPConnectionError("WebSocket连接已断开")
|
|
||||||
|
|
||||||
# 创建Future等待响应
|
|
||||||
future = asyncio.Future()
|
|
||||||
self._pending_requests[request_id] = future
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 发送请求
|
|
||||||
await self._websocket.send(json.dumps(request_data))
|
|
||||||
|
|
||||||
# 等待响应
|
|
||||||
response = await asyncio.wait_for(future, timeout=timeout)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
await self._pending_requests.pop(request_id, None)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
await self._pending_requests.pop(request_id, None)
|
|
||||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
|
||||||
|
|
||||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
|
||||||
"""发送HTTP请求"""
|
|
||||||
if not self._session:
|
|
||||||
raise MCPConnectionError("HTTP会话未建立")
|
|
||||||
|
|
||||||
try:
|
|
||||||
url = f"{self.server_url}/mcp" if not self.server_url.endswith('/') else f"{self.server_url}mcp"
|
|
||||||
|
|
||||||
async with self._session.post(
|
|
||||||
url,
|
|
||||||
json=request_data,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
|
||||||
) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
return await response.json()
|
|
||||||
else:
|
|
||||||
async with self._session.post(
|
|
||||||
self.server_url,
|
|
||||||
json=request_data,
|
|
||||||
timeout=aiohttp.ClientTimeout(total=timeout)
|
|
||||||
) as root_response:
|
|
||||||
if root_response.status != 200:
|
|
||||||
error_text = await root_response.text()
|
|
||||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
|
||||||
|
|
||||||
return await response.json()
|
|
||||||
|
|
||||||
except aiohttp.ClientError as e:
|
|
||||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
|
||||||
|
|
||||||
async def health_check(self) -> Dict[str, Any]:
|
|
||||||
"""执行健康检查
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
健康状态信息
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if not self._connected:
|
|
||||||
return {
|
|
||||||
"healthy": False,
|
|
||||||
"error": "未连接",
|
|
||||||
"timestamp": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送ping请求
|
|
||||||
request_data = {
|
|
||||||
"jsonrpc": "2.0",
|
|
||||||
"id": self._get_next_request_id(),
|
|
||||||
"method": "ping"
|
|
||||||
}
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
response = await self._send_request(request_data, timeout=5)
|
|
||||||
response_time = round((time.time() - start_time) * 1000)
|
|
||||||
|
|
||||||
self._last_health_check = round(time.time() * 1000)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"healthy": True,
|
|
||||||
"response_time": response_time,
|
|
||||||
"timestamp": self._last_health_check,
|
|
||||||
"server_info": response.get("result", {})
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return {
|
|
||||||
"healthy": False,
|
|
||||||
"error": str(e),
|
|
||||||
"timestamp": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _start_health_check(self):
|
|
||||||
"""启动健康检查任务"""
|
|
||||||
if self.health_check_interval > 0:
|
|
||||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
|
||||||
|
|
||||||
async def _stop_health_check(self):
|
|
||||||
"""停止健康检查任务"""
|
|
||||||
if self._health_check_task:
|
|
||||||
self._health_check_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._health_check_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
self._health_check_task = None
|
|
||||||
|
|
||||||
async def _health_check_loop(self):
|
|
||||||
"""健康检查循环"""
|
|
||||||
try:
|
|
||||||
while self._connected:
|
|
||||||
await asyncio.sleep(self.health_check_interval)
|
|
||||||
|
|
||||||
if self._connected:
|
|
||||||
health_status = await self.health_check()
|
|
||||||
if not health_status["healthy"]:
|
|
||||||
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
|
|
||||||
# 可以在这里实现重连逻辑
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"健康检查循环异常: {e}")
|
|
||||||
|
|
||||||
def _get_next_request_id(self) -> str:
|
|
||||||
"""获取下一个请求ID"""
|
|
||||||
self._request_id += 1
|
|
||||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
|
||||||
|
|
||||||
# 事件回调管理
|
|
||||||
def on_connect(self, callback: Callable):
|
|
||||||
"""注册连接回调"""
|
|
||||||
self._on_connect_callbacks.append(callback)
|
|
||||||
|
|
||||||
def on_disconnect(self, callback: Callable):
|
|
||||||
"""注册断开连接回调"""
|
|
||||||
self._on_disconnect_callbacks.append(callback)
|
|
||||||
|
|
||||||
def on_error(self, callback: Callable):
|
|
||||||
"""注册错误回调"""
|
|
||||||
self._on_error_callbacks.append(callback)
|
|
||||||
|
|
||||||
async def _notify_connect_callbacks(self):
|
|
||||||
"""通知连接回调"""
|
|
||||||
for callback in self._on_connect_callbacks:
|
|
||||||
try:
|
|
||||||
if asyncio.iscoroutinefunction(callback):
|
|
||||||
await callback()
|
|
||||||
else:
|
|
||||||
callback()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"连接回调执行失败: {e}")
|
|
||||||
|
|
||||||
async def _notify_disconnect_callbacks(self):
|
|
||||||
"""通知断开连接回调"""
|
|
||||||
for callback in self._on_disconnect_callbacks:
|
|
||||||
try:
|
|
||||||
if asyncio.iscoroutinefunction(callback):
|
|
||||||
await callback()
|
|
||||||
else:
|
|
||||||
callback()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"断开连接回调执行失败: {e}")
|
|
||||||
|
|
||||||
async def _notify_error_callbacks(self, error: Exception):
|
|
||||||
"""通知错误回调"""
|
|
||||||
for callback in self._on_error_callbacks:
|
|
||||||
try:
|
|
||||||
if asyncio.iscoroutinefunction(callback):
|
|
||||||
await callback(error)
|
|
||||||
else:
|
|
||||||
callback(error)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"错误回调执行失败: {e}")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_connected(self) -> bool:
|
|
||||||
"""检查是否已连接"""
|
|
||||||
return self._connected
|
|
||||||
|
|
||||||
@property
|
|
||||||
def last_health_check(self) -> Optional[float]:
|
|
||||||
"""获取最后一次健康检查时间"""
|
|
||||||
return self._last_health_check
|
|
||||||
|
|
||||||
def get_connection_info(self) -> Dict[str, Any]:
|
|
||||||
"""获取连接信息"""
|
|
||||||
return {
|
|
||||||
"server_url": self.server_url,
|
|
||||||
"connection_type": self.connection_type,
|
|
||||||
"connected": self._connected,
|
|
||||||
"last_health_check": self._last_health_check,
|
|
||||||
"pending_requests": len(self._pending_requests),
|
|
||||||
"config": self.connection_config
|
|
||||||
}
|
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""异步上下文管理器入口"""
|
"""异步上下文管理器入口"""
|
||||||
@@ -593,72 +43,267 @@ class MCPClient:
|
|||||||
"""异步上下文管理器出口"""
|
"""异步上下文管理器出口"""
|
||||||
await self.disconnect()
|
await self.disconnect()
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
class MCPConnectionPool:
|
"""建立连接"""
|
||||||
"""MCP连接池 - 管理多个MCP客户端连接"""
|
try:
|
||||||
|
if self.is_websocket:
|
||||||
def __init__(self, max_connections: int = 10):
|
await self._connect_websocket()
|
||||||
"""初始化连接池
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_connections: 最大连接数
|
|
||||||
"""
|
|
||||||
self.max_connections = max_connections
|
|
||||||
self._clients: Dict[str, MCPClient] = {}
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
|
|
||||||
"""获取或创建MCP客户端
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server_url: 服务器URL
|
|
||||||
connection_config: 连接配置
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MCP客户端实例
|
|
||||||
"""
|
|
||||||
async with self._lock:
|
|
||||||
if server_url in self._clients:
|
|
||||||
client = self._clients[server_url]
|
|
||||||
if client.is_connected:
|
|
||||||
return client
|
|
||||||
else:
|
|
||||||
# 尝试重连
|
|
||||||
if await client.connect():
|
|
||||||
return client
|
|
||||||
else:
|
|
||||||
# 移除失效的客户端
|
|
||||||
del self._clients[server_url]
|
|
||||||
|
|
||||||
# 检查连接数限制
|
|
||||||
if len(self._clients) >= self.max_connections:
|
|
||||||
# 移除最旧的连接
|
|
||||||
oldest_url = next(iter(self._clients))
|
|
||||||
await self._clients[oldest_url].disconnect()
|
|
||||||
del self._clients[oldest_url]
|
|
||||||
|
|
||||||
# 创建新客户端
|
|
||||||
client = MCPClient(server_url, connection_config)
|
|
||||||
if await client.connect():
|
|
||||||
self._clients[server_url] = client
|
|
||||||
return client
|
|
||||||
else:
|
else:
|
||||||
raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
|
await self._connect_http()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||||
|
raise MCPConnectionError(f"连接失败: {e}")
|
||||||
|
|
||||||
async def disconnect_all(self):
|
async def disconnect(self):
|
||||||
"""断开所有连接"""
|
"""断开连接"""
|
||||||
async with self._lock:
|
try:
|
||||||
for client in self._clients.values():
|
if self._websocket:
|
||||||
await client.disconnect()
|
await self._websocket.close()
|
||||||
self._clients.clear()
|
self._websocket = None
|
||||||
|
|
||||||
def get_pool_status(self) -> Dict[str, Any]:
|
if self._session:
|
||||||
"""获取连接池状态"""
|
await self._session.close()
|
||||||
return {
|
self._session = None
|
||||||
"total_connections": len(self._clients),
|
|
||||||
"max_connections": self.max_connections,
|
except Exception as e:
|
||||||
"connections": {
|
logger.error(f"断开连接失败: {e}")
|
||||||
url: client.get_connection_info()
|
|
||||||
for url, client in self._clients.items()
|
async def _connect_websocket(self):
|
||||||
|
"""WebSocket 连接"""
|
||||||
|
headers = self._build_headers()
|
||||||
|
|
||||||
|
self._websocket = await websockets.connect(
|
||||||
|
self.server_url,
|
||||||
|
extra_headers=headers,
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# 启动消息处理
|
||||||
|
asyncio.create_task(self._handle_websocket_messages())
|
||||||
|
|
||||||
|
# 发送初始化消息
|
||||||
|
await self._send_initialize()
|
||||||
|
|
||||||
|
async def _connect_http(self):
|
||||||
|
"""HTTP 连接"""
|
||||||
|
headers = self._build_headers()
|
||||||
|
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||||
|
|
||||||
|
self._session = aiohttp.ClientSession(
|
||||||
|
headers=headers,
|
||||||
|
timeout=timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
||||||
|
if "modelscope.net" in self.server_url:
|
||||||
|
await self._initialize_modelscope_session()
|
||||||
|
|
||||||
|
async def _initialize_modelscope_session(self):
|
||||||
|
"""初始化 ModelScope MCP 会话"""
|
||||||
|
init_request = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "MemoryBear",
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self._session.post(
|
||||||
|
self.server_url,
|
||||||
|
json=init_request
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
|
init_response = await response.json()
|
||||||
|
if "error" in init_response:
|
||||||
|
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||||
|
|
||||||
|
# 获取 session ID
|
||||||
|
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||||
|
if session_id:
|
||||||
|
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||||
|
|
||||||
|
# 发送 initialized 通知
|
||||||
|
initialized_notification = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/initialized"
|
||||||
|
}
|
||||||
|
|
||||||
|
async with self._session.post(
|
||||||
|
self.server_url,
|
||||||
|
json=initialized_notification
|
||||||
|
) as notif_response:
|
||||||
|
pass
|
||||||
|
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||||
|
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
"""构建请求头"""
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加认证头
|
||||||
|
auth_config = self.connection_config.get("auth_config", {})
|
||||||
|
auth_type = self.connection_config.get("auth_type", "none")
|
||||||
|
|
||||||
|
if auth_type == "bearer_token":
|
||||||
|
token = auth_config.get("token")
|
||||||
|
if token:
|
||||||
|
headers["Authorization"] = f"Bearer {token}"
|
||||||
|
elif auth_type == "api_key":
|
||||||
|
key = auth_config.get("api_key")
|
||||||
|
header_name = auth_config.get("key_name", "X-API-Key")
|
||||||
|
if key:
|
||||||
|
headers[header_name] = key
|
||||||
|
elif auth_type == "basic_auth":
|
||||||
|
username = auth_config.get("username")
|
||||||
|
password = auth_config.get("password")
|
||||||
|
if username and password:
|
||||||
|
import base64
|
||||||
|
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
|
headers["Authorization"] = f"Basic {credentials}"
|
||||||
|
|
||||||
|
return headers
|
||||||
|
|
||||||
|
async def _send_initialize(self):
|
||||||
|
"""发送初始化消息"""
|
||||||
|
init_message = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"tools": {}},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "MemoryBear",
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await self._websocket.send(json.dumps(init_message))
|
||||||
|
|
||||||
|
# 等待初始化响应
|
||||||
|
response = await asyncio.wait_for(
|
||||||
|
self._websocket.recv(),
|
||||||
|
timeout=self.timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
init_response = json.loads(response)
|
||||||
|
if "error" in init_response:
|
||||||
|
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||||
|
|
||||||
|
async def _handle_websocket_messages(self):
|
||||||
|
"""处理 WebSocket 消息"""
|
||||||
|
try:
|
||||||
|
while self._websocket and not self._websocket.closed:
|
||||||
|
try:
|
||||||
|
message = await self._websocket.recv()
|
||||||
|
data = json.loads(message)
|
||||||
|
|
||||||
|
# 处理响应
|
||||||
|
if "id" in data:
|
||||||
|
request_id = str(data["id"])
|
||||||
|
if request_id in self._pending_requests:
|
||||||
|
future = self._pending_requests.pop(request_id)
|
||||||
|
if not future.done():
|
||||||
|
future.set_result(data)
|
||||||
|
|
||||||
|
except ConnectionClosed:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理WebSocket消息失败: {e}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"WebSocket消息处理异常: {e}")
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||||
|
"""调用工具"""
|
||||||
|
request_data = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_websocket:
|
||||||
|
response = await self._send_websocket_request(request_data)
|
||||||
|
else:
|
||||||
|
response = await self._send_http_request(request_data)
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
error = response["error"]
|
||||||
|
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
return response.get("result", {})
|
||||||
|
|
||||||
|
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||||
|
"""获取工具列表"""
|
||||||
|
request_data = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": self._get_request_id(),
|
||||||
|
"method": "tools/list",
|
||||||
|
"params": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.is_websocket:
|
||||||
|
response = await self._send_websocket_request(request_data)
|
||||||
|
else:
|
||||||
|
response = await self._send_http_request(request_data)
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
error = response["error"]
|
||||||
|
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
result = response.get("result", {})
|
||||||
|
return result.get("tools", [])
|
||||||
|
|
||||||
|
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""发送WebSocket请求"""
|
||||||
|
request_id = str(request_data["id"])
|
||||||
|
future = asyncio.Future()
|
||||||
|
self._pending_requests[request_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._websocket.send(json.dumps(request_data))
|
||||||
|
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||||
|
return response
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending_requests.pop(request_id, None)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""发送HTTP请求"""
|
||||||
|
try:
|
||||||
|
async with self._session.post(
|
||||||
|
self.server_url,
|
||||||
|
json=request_data
|
||||||
|
) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
error_text = await response.text()
|
||||||
|
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||||
|
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||||
|
|
||||||
|
def _get_request_id(self) -> str:
|
||||||
|
"""获取请求ID"""
|
||||||
|
self._request_id += 1
|
||||||
|
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
|
"""MCP服务管理器 - 简化版本"""
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -8,136 +6,53 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.tools.mcp.client import MCPClient, MCPConnectionPool
|
from app.core.tools.mcp.base import MCPToolManager
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
class MCPServiceManager:
|
class MCPServiceManager:
|
||||||
"""MCP服务管理器 - 管理MCP服务的生命周期"""
|
"""MCP服务管理器 - 简化版本,主要用于工具创建"""
|
||||||
|
|
||||||
def __init__(self, db: Session = None):
|
def __init__(self, db: Session = None):
|
||||||
"""初始化MCP服务管理器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db: 数据库会话(可选)
|
|
||||||
"""
|
|
||||||
self.db = db
|
self.db = db
|
||||||
if db:
|
self.tool_manager = MCPToolManager(db) if db else None
|
||||||
self.connection_pool = MCPConnectionPool(max_connections=20)
|
|
||||||
else:
|
|
||||||
self.connection_pool = None
|
|
||||||
|
|
||||||
# 服务状态管理
|
async def create_mcp_tool(
|
||||||
self._services: Dict[str, Dict[str, Any]] = {} # service_id -> service_info
|
|
||||||
self._monitoring_tasks: Dict[str, asyncio.Task] = {} # service_id -> monitoring_task
|
|
||||||
|
|
||||||
# 配置
|
|
||||||
self.health_check_interval = 60 # 健康检查间隔(秒)
|
|
||||||
self.max_retry_attempts = 3 # 最大重试次数
|
|
||||||
self.retry_delay = 5 # 重试延迟(秒)
|
|
||||||
|
|
||||||
# 状态
|
|
||||||
self._running = False
|
|
||||||
self._manager_task = None
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动服务管理器"""
|
|
||||||
if self._running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._running = True
|
|
||||||
logger.info("MCP服务管理器启动")
|
|
||||||
|
|
||||||
# 加载现有服务
|
|
||||||
await self._load_existing_services()
|
|
||||||
|
|
||||||
# 启动管理任务
|
|
||||||
self._manager_task = asyncio.create_task(self._management_loop())
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务管理器"""
|
|
||||||
if not self._running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._running = False
|
|
||||||
logger.info("MCP服务管理器停止")
|
|
||||||
|
|
||||||
# 停止管理任务
|
|
||||||
if self._manager_task:
|
|
||||||
self._manager_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._manager_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 停止所有监控任务
|
|
||||||
for task in self._monitoring_tasks.values():
|
|
||||||
task.cancel()
|
|
||||||
|
|
||||||
if self._monitoring_tasks:
|
|
||||||
await asyncio.gather(*self._monitoring_tasks.values(), return_exceptions=True)
|
|
||||||
|
|
||||||
self._monitoring_tasks.clear()
|
|
||||||
|
|
||||||
# 断开所有连接
|
|
||||||
await self.connection_pool.disconnect_all()
|
|
||||||
|
|
||||||
async def register_service(
|
|
||||||
self,
|
self,
|
||||||
server_url: str,
|
server_url: str,
|
||||||
connection_config: Dict[str, Any],
|
connection_config: Dict[str, Any],
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
|
tool_name: str,
|
||||||
service_name: str = None
|
service_name: str = None
|
||||||
) -> Tuple[bool, str, Optional[str]]:
|
) -> Tuple[bool, str, Optional[str]]:
|
||||||
"""注册MCP服务
|
"""创建单个MCP工具
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
server_url: 服务器URL
|
server_url: 服务器URL
|
||||||
connection_config: 连接配置
|
connection_config: 连接配置
|
||||||
tenant_id: 租户ID
|
tenant_id: 租户ID
|
||||||
service_name: 服务名称(可选)
|
tool_name: 具体工具名称
|
||||||
|
service_name: 服务名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(是否成功, 服务ID或错误信息, 错误详情)
|
(是否成功, 工具ID或错误信息, 错误详情)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 检查服务是否已存在
|
if not service_name:
|
||||||
existing_service = self.db.query(MCPToolConfig).filter(
|
service_name = f"mcp_{tool_name}"
|
||||||
MCPToolConfig.server_url == server_url
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if existing_service:
|
|
||||||
return False, "服务已存在", f"URL {server_url} 已被注册"
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
try:
|
|
||||||
client = MCPClient(server_url, connection_config)
|
|
||||||
if not await client.connect():
|
|
||||||
return False, "连接测试失败", "无法连接到MCP服务器"
|
|
||||||
|
|
||||||
# 获取可用工具
|
|
||||||
available_tools = await client.list_tools()
|
|
||||||
tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
|
|
||||||
|
|
||||||
await client.disconnect()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
return False, "连接测试失败", str(e)
|
|
||||||
|
|
||||||
# 创建工具配置
|
# 创建工具配置
|
||||||
if not service_name:
|
|
||||||
service_name = f"mcp_service_{server_url.split('/')[-1]}"
|
|
||||||
|
|
||||||
tool_config = ToolConfig(
|
tool_config = ToolConfig(
|
||||||
name=service_name,
|
name=service_name,
|
||||||
description=f"MCP服务 - {server_url}",
|
description=f"MCP工具: {tool_name}",
|
||||||
tool_type=ToolType.MCP.value,
|
tool_type=ToolType.MCP.value,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
version="1.0.0",
|
status=ToolStatus.AVAILABLE.value,
|
||||||
config_data={
|
config_data={
|
||||||
"server_url": server_url,
|
"server_url": server_url,
|
||||||
"connection_config": connection_config
|
"connection_config": connection_config,
|
||||||
|
"tool_name": tool_name
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -149,460 +64,22 @@ class MCPServiceManager:
|
|||||||
id=tool_config.id,
|
id=tool_config.id,
|
||||||
server_url=server_url,
|
server_url=server_url,
|
||||||
connection_config=connection_config,
|
connection_config=connection_config,
|
||||||
available_tools=tool_names,
|
available_tools=[tool_name],
|
||||||
health_status="healthy",
|
health_status="unknown",
|
||||||
last_health_check=datetime.now()
|
last_health_check=datetime.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.db.add(mcp_config)
|
self.db.add(mcp_config)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
service_id = str(tool_config.id)
|
logger.info(f"MCP工具创建成功: {tool_config.id} ({tool_name})")
|
||||||
|
return True, str(tool_config.id), None
|
||||||
# 添加到内存管理
|
|
||||||
self._services[service_id] = {
|
|
||||||
"id": service_id,
|
|
||||||
"server_url": server_url,
|
|
||||||
"connection_config": connection_config,
|
|
||||||
"tenant_id": tenant_id,
|
|
||||||
"available_tools": tool_names,
|
|
||||||
"status": "healthy",
|
|
||||||
"last_health_check": time.time(),
|
|
||||||
"retry_count": 0,
|
|
||||||
"created_at": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动监控
|
|
||||||
await self._start_service_monitoring(service_id)
|
|
||||||
|
|
||||||
logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
|
|
||||||
return True, service_id, None
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.db.rollback()
|
self.db.rollback()
|
||||||
logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
|
logger.error(f"创建MCP工具失败: {tool_name}, 错误: {e}")
|
||||||
return False, "注册失败", str(e)
|
return False, "创建失败", str(e)
|
||||||
|
|
||||||
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
|
def get_tool_manager(self) -> MCPToolManager:
|
||||||
"""注销MCP服务
|
"""获取工具管理器实例"""
|
||||||
|
return self.tool_manager
|
||||||
Args:
|
|
||||||
service_id: 服务ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(是否成功, 错误信息)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 从数据库删除
|
|
||||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
|
||||||
if not tool_config:
|
|
||||||
return False, "服务不存在"
|
|
||||||
|
|
||||||
self.db.delete(tool_config)
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
# 停止监控
|
|
||||||
await self._stop_service_monitoring(service_id)
|
|
||||||
|
|
||||||
# 从内存移除
|
|
||||||
if service_id in self._services:
|
|
||||||
del self._services[service_id]
|
|
||||||
|
|
||||||
logger.info(f"MCP服务注销成功: {service_id}")
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.db.rollback()
|
|
||||||
logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
async def update_service(
|
|
||||||
self,
|
|
||||||
service_id: str,
|
|
||||||
connection_config: Dict[str, Any] = None,
|
|
||||||
enabled: bool = None
|
|
||||||
) -> Tuple[bool, str]:
|
|
||||||
"""更新MCP服务配置
|
|
||||||
|
|
||||||
Args:
|
|
||||||
service_id: 服务ID
|
|
||||||
connection_config: 新的连接配置
|
|
||||||
enabled: 是否启用
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(是否成功, 错误信息)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 更新数据库
|
|
||||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
|
||||||
MCPToolConfig.id == uuid.UUID(service_id)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not mcp_config:
|
|
||||||
return False, "服务不存在"
|
|
||||||
|
|
||||||
tool_config = mcp_config.base_config
|
|
||||||
|
|
||||||
if connection_config is not None:
|
|
||||||
mcp_config.connection_config = connection_config
|
|
||||||
tool_config.config_data["connection_config"] = connection_config
|
|
||||||
|
|
||||||
if enabled is not None:
|
|
||||||
tool_config.is_enabled = enabled
|
|
||||||
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
# 更新内存状态
|
|
||||||
if service_id in self._services:
|
|
||||||
if connection_config is not None:
|
|
||||||
self._services[service_id]["connection_config"] = connection_config
|
|
||||||
|
|
||||||
# 如果配置有变化,重启监控
|
|
||||||
if connection_config is not None:
|
|
||||||
await self._restart_service_monitoring(service_id)
|
|
||||||
|
|
||||||
logger.info(f"MCP服务更新成功: {service_id}")
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.db.rollback()
|
|
||||||
logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
|
|
||||||
return False, str(e)
|
|
||||||
|
|
||||||
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""获取服务状态
|
|
||||||
|
|
||||||
Args:
|
|
||||||
service_id: 服务ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
服务状态信息
|
|
||||||
"""
|
|
||||||
if service_id not in self._services:
|
|
||||||
return None
|
|
||||||
|
|
||||||
service_info = self._services[service_id].copy()
|
|
||||||
|
|
||||||
# 添加实时健康检查
|
|
||||||
try:
|
|
||||||
client = await self.connection_pool.get_client(
|
|
||||||
service_info["server_url"],
|
|
||||||
service_info["connection_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
health_status = await client.health_check()
|
|
||||||
service_info["real_time_health"] = health_status
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
service_info["real_time_health"] = {
|
|
||||||
"healthy": False,
|
|
||||||
"error": str(e),
|
|
||||||
"timestamp": time.time()
|
|
||||||
}
|
|
||||||
|
|
||||||
return service_info
|
|
||||||
|
|
||||||
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
|
|
||||||
"""列出所有服务
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: 租户ID过滤
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
服务列表
|
|
||||||
"""
|
|
||||||
services = []
|
|
||||||
|
|
||||||
for service_id, service_info in self._services.items():
|
|
||||||
if tenant_id and service_info["tenant_id"] != tenant_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
services.append(service_info.copy())
|
|
||||||
|
|
||||||
return services
|
|
||||||
|
|
||||||
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
|
|
||||||
"""获取服务的可用工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
service_id: 服务ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
工具列表
|
|
||||||
"""
|
|
||||||
if service_id not in self._services:
|
|
||||||
return []
|
|
||||||
|
|
||||||
service_info = self._services[service_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = await self.connection_pool.get_client(
|
|
||||||
service_info["server_url"],
|
|
||||||
service_info["connection_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
# 更新缓存的工具列表
|
|
||||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
|
||||||
service_info["available_tools"] = tool_names
|
|
||||||
|
|
||||||
# 更新数据库
|
|
||||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
|
||||||
MCPToolConfig.id == uuid.UUID(service_id)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if mcp_config:
|
|
||||||
mcp_config.available_tools = tool_names
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def call_service_tool(
|
|
||||||
self,
|
|
||||||
service_id: str,
|
|
||||||
tool_name: str,
|
|
||||||
arguments: Dict[str, Any],
|
|
||||||
timeout: int = 30
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""调用服务工具
|
|
||||||
|
|
||||||
Args:
|
|
||||||
service_id: 服务ID
|
|
||||||
tool_name: 工具名称
|
|
||||||
arguments: 工具参数
|
|
||||||
timeout: 超时时间
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
执行结果
|
|
||||||
"""
|
|
||||||
if service_id not in self._services:
|
|
||||||
raise ValueError(f"服务不存在: {service_id}")
|
|
||||||
|
|
||||||
service_info = self._services[service_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
client = await self.connection_pool.get_client(
|
|
||||||
service_info["server_url"],
|
|
||||||
service_info["connection_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await client.call_tool(tool_name, arguments, timeout)
|
|
||||||
|
|
||||||
# 更新服务状态为健康
|
|
||||||
service_info["status"] = "healthy"
|
|
||||||
service_info["last_health_check"] = time.time()
|
|
||||||
service_info["retry_count"] = 0
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 更新服务状态为错误
|
|
||||||
service_info["status"] = "error"
|
|
||||||
service_info["last_error"] = str(e)
|
|
||||||
service_info["retry_count"] += 1
|
|
||||||
|
|
||||||
logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _load_existing_services(self):
|
|
||||||
"""加载现有服务"""
|
|
||||||
try:
|
|
||||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
|
||||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
|
||||||
ToolConfig.tool_type == ToolType.MCP.value
|
|
||||||
).all()
|
|
||||||
|
|
||||||
for mcp_config in mcp_configs:
|
|
||||||
tool_config = mcp_config.base_config
|
|
||||||
service_id = str(mcp_config.id)
|
|
||||||
|
|
||||||
self._services[service_id] = {
|
|
||||||
"id": service_id,
|
|
||||||
"server_url": mcp_config.server_url,
|
|
||||||
"connection_config": mcp_config.connection_config or {},
|
|
||||||
"tenant_id": tool_config.tenant_id,
|
|
||||||
"available_tools": mcp_config.available_tools or [],
|
|
||||||
"status": mcp_config.health_status or "unknown",
|
|
||||||
"last_health_check": mcp_config.last_health_check.timestamp() if mcp_config.last_health_check else 0,
|
|
||||||
"retry_count": 0,
|
|
||||||
"created_at": tool_config.created_at.timestamp()
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动监控
|
|
||||||
await self._start_service_monitoring(service_id)
|
|
||||||
|
|
||||||
logger.info(f"加载了 {len(mcp_configs)} 个MCP服务")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载现有服务失败: {e}")
|
|
||||||
|
|
||||||
async def _start_service_monitoring(self, service_id: str):
|
|
||||||
"""启动服务监控"""
|
|
||||||
if service_id in self._monitoring_tasks:
|
|
||||||
return
|
|
||||||
|
|
||||||
task = asyncio.create_task(self._monitor_service(service_id))
|
|
||||||
self._monitoring_tasks[service_id] = task
|
|
||||||
|
|
||||||
async def _stop_service_monitoring(self, service_id: str):
|
|
||||||
"""停止服务监控"""
|
|
||||||
if service_id in self._monitoring_tasks:
|
|
||||||
task = self._monitoring_tasks.pop(service_id)
|
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def _restart_service_monitoring(self, service_id: str):
|
|
||||||
"""重启服务监控"""
|
|
||||||
await self._stop_service_monitoring(service_id)
|
|
||||||
await self._start_service_monitoring(service_id)
|
|
||||||
|
|
||||||
async def _monitor_service(self, service_id: str):
|
|
||||||
"""监控单个服务"""
|
|
||||||
try:
|
|
||||||
while self._running and service_id in self._services:
|
|
||||||
service_info = self._services[service_id]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 执行健康检查
|
|
||||||
client = await self.connection_pool.get_client(
|
|
||||||
service_info["server_url"],
|
|
||||||
service_info["connection_config"]
|
|
||||||
)
|
|
||||||
|
|
||||||
health_status = await client.health_check()
|
|
||||||
|
|
||||||
if health_status["healthy"]:
|
|
||||||
# 服务健康
|
|
||||||
service_info["status"] = "healthy"
|
|
||||||
service_info["retry_count"] = 0
|
|
||||||
|
|
||||||
# 更新工具列表
|
|
||||||
try:
|
|
||||||
tools = await client.list_tools()
|
|
||||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
|
||||||
service_info["available_tools"] = tool_names
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 服务不健康
|
|
||||||
service_info["status"] = "unhealthy"
|
|
||||||
service_info["last_error"] = health_status.get("error", "健康检查失败")
|
|
||||||
service_info["retry_count"] += 1
|
|
||||||
|
|
||||||
service_info["last_health_check"] = time.time()
|
|
||||||
|
|
||||||
# 更新数据库
|
|
||||||
await self._update_service_health_in_db(service_id, health_status)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 监控异常
|
|
||||||
service_info["status"] = "error"
|
|
||||||
service_info["last_error"] = str(e)
|
|
||||||
service_info["retry_count"] += 1
|
|
||||||
service_info["last_health_check"] = time.time()
|
|
||||||
|
|
||||||
logger.error(f"服务监控异常: {service_id}, 错误: {e}")
|
|
||||||
|
|
||||||
# 如果重试次数过多,暂停监控
|
|
||||||
if service_info["retry_count"] >= self.max_retry_attempts:
|
|
||||||
logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
|
|
||||||
await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间
|
|
||||||
service_info["retry_count"] = 0 # 重置重试计数
|
|
||||||
|
|
||||||
# 等待下次检查
|
|
||||||
await asyncio.sleep(self.health_check_interval)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
|
|
||||||
|
|
||||||
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
|
|
||||||
"""更新数据库中的服务健康状态"""
|
|
||||||
try:
|
|
||||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
|
||||||
MCPToolConfig.id == uuid.UUID(service_id)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if mcp_config:
|
|
||||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
|
||||||
mcp_config.last_health_check = datetime.now()
|
|
||||||
|
|
||||||
if not health_status["healthy"]:
|
|
||||||
mcp_config.error_message = health_status.get("error", "")
|
|
||||||
else:
|
|
||||||
mcp_config.error_message = None
|
|
||||||
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
|
|
||||||
self.db.rollback()
|
|
||||||
|
|
||||||
async def _management_loop(self):
|
|
||||||
"""管理循环 - 处理服务清理等任务"""
|
|
||||||
try:
|
|
||||||
while self._running:
|
|
||||||
# 清理失效的服务
|
|
||||||
await self._cleanup_failed_services()
|
|
||||||
|
|
||||||
# 等待下次循环
|
|
||||||
await asyncio.sleep(300) # 5分钟
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"管理循环异常: {e}")
|
|
||||||
|
|
||||||
async def _cleanup_failed_services(self):
|
|
||||||
"""清理长期失效的服务"""
|
|
||||||
try:
|
|
||||||
current_time = time.time()
|
|
||||||
cleanup_threshold = 24 * 60 * 60 # 24小时
|
|
||||||
|
|
||||||
services_to_cleanup = []
|
|
||||||
|
|
||||||
for service_id, service_info in self._services.items():
|
|
||||||
# 检查服务是否长期失效
|
|
||||||
if (service_info["status"] in ["error", "unhealthy"] and
|
|
||||||
current_time - service_info["last_health_check"] > cleanup_threshold):
|
|
||||||
|
|
||||||
services_to_cleanup.append(service_id)
|
|
||||||
|
|
||||||
for service_id in services_to_cleanup:
|
|
||||||
logger.warning(f"清理长期失效的服务: {service_id}")
|
|
||||||
|
|
||||||
# 停止监控但不删除数据库记录
|
|
||||||
await self._stop_service_monitoring(service_id)
|
|
||||||
|
|
||||||
# 标记为禁用
|
|
||||||
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
|
|
||||||
if tool_config:
|
|
||||||
tool_config.is_enabled = False
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
# 从内存移除
|
|
||||||
del self._services[service_id]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"清理失效服务失败: {e}")
|
|
||||||
|
|
||||||
def get_manager_status(self) -> Dict[str, Any]:
|
|
||||||
"""获取管理器状态"""
|
|
||||||
return {
|
|
||||||
"running": self._running,
|
|
||||||
"total_services": len(self._services),
|
|
||||||
"healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]),
|
|
||||||
"unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]),
|
|
||||||
"monitoring_tasks": len(self._monitoring_tasks),
|
|
||||||
"connection_pool_status": self.connection_pool.get_pool_status()
|
|
||||||
}
|
|
||||||
@@ -77,7 +77,7 @@ class AppChatService:
|
|||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||||
for tool_config in config.tools:
|
for tool_config in config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
@@ -109,20 +109,21 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
# web_tools = config.tools
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||||
# web_search_choice = web_tools.get("web_search", {})
|
web_tools = config.tools
|
||||||
# web_search_enable = web_search_choice.get("enabled", False)
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
# if web_search == True:
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
# if web_search_enable == True:
|
if web_search == True:
|
||||||
# search_tool = create_web_search_tool({})
|
if web_search_enable == True:
|
||||||
# tools.append(search_tool)
|
search_tool = create_web_search_tool({})
|
||||||
#
|
tools.append(search_tool)
|
||||||
# logger.debug(
|
|
||||||
# "已添加网络搜索工具",
|
logger.debug(
|
||||||
# extra={
|
"已添加网络搜索工具",
|
||||||
# "tool_count": len(tools)
|
extra={
|
||||||
# }
|
"tool_count": len(tools)
|
||||||
# )
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
@@ -226,7 +227,7 @@ class AppChatService:
|
|||||||
# 获取工具服务
|
# 获取工具服务
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
if hasattr(config, 'tools') and config.tools:
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list):
|
||||||
for tool_config in config.tools:
|
for tool_config in config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
@@ -258,20 +259,21 @@ class AppChatService:
|
|||||||
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
memory_tool = create_long_term_memory_tool(memory_config, user_id)
|
||||||
tools.append(memory_tool)
|
tools.append(memory_tool)
|
||||||
|
|
||||||
# web_tools = config.tools
|
if hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict):
|
||||||
# web_search_choice = web_tools.get("web_search", {})
|
web_tools = config.tools
|
||||||
# web_search_enable = web_search_choice.get("enabled", False)
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
# if web_search == True:
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
# if web_search_enable == True:
|
if web_search == True:
|
||||||
# search_tool = create_web_search_tool({})
|
if web_search_enable == True:
|
||||||
# tools.append(search_tool)
|
search_tool = create_web_search_tool({})
|
||||||
#
|
tools.append(search_tool)
|
||||||
# logger.debug(
|
|
||||||
# "已添加网络搜索工具",
|
logger.debug(
|
||||||
# extra={
|
"已添加网络搜索工具",
|
||||||
# "tool_count": len(tools)
|
extra={
|
||||||
# }
|
"tool_count": len(tools)
|
||||||
# )
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 获取模型参数
|
# 获取模型参数
|
||||||
model_parameters = config.model_parameters
|
model_parameters = config.model_parameters
|
||||||
|
|||||||
@@ -297,19 +297,35 @@ class DraftRunService:
|
|||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list):
|
||||||
for tool_config in agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools:
|
||||||
if tool_config.get("enabled", False):
|
for tool_config in agent_config.tools:
|
||||||
# 根据工具名称查找工具实例
|
if tool_config.get("enabled", False):
|
||||||
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
# 根据工具名称查找工具实例
|
||||||
ToolRepository.get_tenant_id_by_workspace_id(
|
tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""),
|
||||||
self.db, str(workspace_id)))
|
ToolRepository.get_tenant_id_by_workspace_id(
|
||||||
if tool_instance:
|
self.db, str(workspace_id)))
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
if tool_instance:
|
||||||
continue
|
if tool_instance.name == "baidu_search_tool" and not web_search:
|
||||||
# 转换为LangChain工具
|
continue
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
# 转换为LangChain工具
|
||||||
tools.append(langchain_tool)
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
|
tools.append(langchain_tool)
|
||||||
|
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||||
|
web_tools = agent_config.tools
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search == True:
|
||||||
|
if web_search_enable == True:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
@@ -507,7 +523,7 @@ class DraftRunService:
|
|||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
# 从配置中获取启用的工具
|
# 从配置中获取启用的工具
|
||||||
if hasattr(agent_config, 'tools') and agent_config.tools:
|
if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||||
for tool_config in agent_config.tools:
|
for tool_config in agent_config.tools:
|
||||||
if tool_config.get("enabled", False):
|
if tool_config.get("enabled", False):
|
||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
@@ -520,6 +536,22 @@ class DraftRunService:
|
|||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
|
elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict):
|
||||||
|
web_tools = agent_config.tools
|
||||||
|
web_search_choice = web_tools.get("web_search", {})
|
||||||
|
web_search_enable = web_search_choice.get("enabled", False)
|
||||||
|
if web_search == True:
|
||||||
|
if web_search_enable == True:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"已添加网络搜索工具",
|
||||||
|
extra={
|
||||||
|
"tool_count": len(tools)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# 添加知识库检索工具
|
# 添加知识库检索工具
|
||||||
if agent_config.knowledge_retrieval:
|
if agent_config.knowledge_retrieval:
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.tools.mcp import MCPClient
|
from app.core.tools.mcp import MCPToolManager, SimpleMCPClient
|
||||||
from app.repositories.tool_repository import (
|
from app.repositories.tool_repository import (
|
||||||
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
ToolRepository, BuiltinToolRepository, CustomToolRepository,
|
||||||
MCPToolRepository, ToolExecutionRepository
|
MCPToolRepository, ToolExecutionRepository
|
||||||
@@ -43,6 +43,9 @@ class ToolService:
|
|||||||
self.db = db
|
self.db = db
|
||||||
self._tool_cache: Dict[str, BaseTool] = {}
|
self._tool_cache: Dict[str, BaseTool] = {}
|
||||||
|
|
||||||
|
# MCP管理器
|
||||||
|
self.mcp_tool_manager = MCPToolManager(db)
|
||||||
|
|
||||||
# 初始化仓储
|
# 初始化仓储
|
||||||
self.tool_repo = ToolRepository()
|
self.tool_repo = ToolRepository()
|
||||||
self.builtin_repo = BuiltinToolRepository()
|
self.builtin_repo = BuiltinToolRepository()
|
||||||
@@ -675,23 +678,85 @@ class ToolService:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
"""获取MCP工具的方法"""
|
"""获取MCP工具的方法和参数"""
|
||||||
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||||
if not mcp_config:
|
if not mcp_config:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
available_tools = mcp_config.available_tools or []
|
available_tools = mcp_config.available_tools or []
|
||||||
if not available_tools:
|
if not available_tools:
|
||||||
return []
|
# 如果没有工具列表,尝试同步
|
||||||
|
try:
|
||||||
|
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||||||
|
mcp_config.server_url, mcp_config.connection_config or {}
|
||||||
|
)
|
||||||
|
if success:
|
||||||
|
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||||
|
mcp_config.available_tools = tool_names
|
||||||
|
self.db.commit()
|
||||||
|
available_tools = tool_names
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"同步MCP工具列表失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
methods = []
|
methods = []
|
||||||
for tool_name in available_tools:
|
|
||||||
methods.append({
|
# 获取工具详细信息
|
||||||
"method_id": tool_name,
|
try:
|
||||||
"name": tool_name,
|
success, tools, _ = await self.mcp_tool_manager.discover_tools(
|
||||||
"description": f"MCP工具: {tool_name}",
|
mcp_config.server_url, mcp_config.connection_config or {}
|
||||||
"parameters": [] # MCP工具参数需要动态获取
|
)
|
||||||
})
|
|
||||||
|
if success:
|
||||||
|
tools_dict = {tool.get("name"): tool for tool in tools if tool.get("name")}
|
||||||
|
|
||||||
|
for tool_name in available_tools:
|
||||||
|
tool_info = tools_dict.get(tool_name, {})
|
||||||
|
|
||||||
|
# 解析工具参数
|
||||||
|
parameters = []
|
||||||
|
input_schema = tool_info.get("inputSchema", {})
|
||||||
|
properties = input_schema.get("properties", {})
|
||||||
|
required_fields = input_schema.get("required", [])
|
||||||
|
|
||||||
|
for param_name, param_def in properties.items():
|
||||||
|
parameters.append({
|
||||||
|
"name": param_name,
|
||||||
|
"type": param_def.get("type", "string"),
|
||||||
|
"description": param_def.get("description", ""),
|
||||||
|
"required": param_name in required_fields,
|
||||||
|
"default": param_def.get("default"),
|
||||||
|
"enum": param_def.get("enum"),
|
||||||
|
"minimum": param_def.get("minimum"),
|
||||||
|
"maximum": param_def.get("maximum")
|
||||||
|
})
|
||||||
|
|
||||||
|
methods.append({
|
||||||
|
"method_id": tool_name,
|
||||||
|
"name": tool_name,
|
||||||
|
"description": tool_info.get("description", f"MCP工具: {tool_name}"),
|
||||||
|
"parameters": parameters
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# 如果无法获取详细信息,返回基本信息
|
||||||
|
for tool_name in available_tools:
|
||||||
|
methods.append({
|
||||||
|
"method_id": tool_name,
|
||||||
|
"name": tool_name,
|
||||||
|
"description": f"MCP工具: {tool_name}",
|
||||||
|
"parameters": []
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取MCP工具详细信息失败: {e}")
|
||||||
|
# 返回基本信息
|
||||||
|
for tool_name in available_tools:
|
||||||
|
methods.append({
|
||||||
|
"method_id": tool_name,
|
||||||
|
"name": tool_name,
|
||||||
|
"description": f"MCP工具: {tool_name}",
|
||||||
|
"parameters": []
|
||||||
|
})
|
||||||
|
|
||||||
return methods
|
return methods
|
||||||
|
|
||||||
@@ -812,10 +877,14 @@ class ToolService:
|
|||||||
if not mcp_config:
|
if not mcp_config:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# 从配置中获取特定工具名称
|
||||||
|
tool_name = config.config_data.get("tool_name")
|
||||||
|
|
||||||
tool_config = {
|
tool_config = {
|
||||||
"server_url": mcp_config.server_url,
|
"server_url": mcp_config.server_url,
|
||||||
"connection_config": mcp_config.connection_config or {},
|
"connection_config": mcp_config.connection_config or {},
|
||||||
"available_tools": mcp_config.available_tools or []
|
"available_tools": mcp_config.available_tools or [],
|
||||||
|
"tool_name": tool_name # 指定具体工具
|
||||||
}
|
}
|
||||||
|
|
||||||
return MCPTool(str(config.id), tool_config)
|
return MCPTool(str(config.id), tool_config)
|
||||||
@@ -1071,71 +1140,59 @@ class ToolService:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||||
"""测试MCP连接"""
|
"""测试MCP连接并自动同步工具列表"""
|
||||||
try:
|
try:
|
||||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||||
MCPToolConfig.id == config.id
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not mcp_config:
|
if not mcp_config:
|
||||||
return {"success": False, "message": "MCP配置不存在"}
|
return {"success": False, "message": "MCP配置不存在"}
|
||||||
|
|
||||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
# 使用集成的MCP管理器测试连接
|
||||||
|
test_result = await self.mcp_tool_manager.test_tool_connection(
|
||||||
|
mcp_config.server_url, mcp_config.connection_config or {}
|
||||||
|
)
|
||||||
|
|
||||||
if await client.connect():
|
if test_result["success"]:
|
||||||
try:
|
# 连接成功,自动同步工具列表
|
||||||
# tools = await client.list_tools()
|
success, tools, error = await self.mcp_tool_manager.discover_tools(
|
||||||
await client.disconnect()
|
mcp_config.server_url, mcp_config.connection_config or {}
|
||||||
|
)
|
||||||
|
|
||||||
# 更新连接状态
|
if success:
|
||||||
|
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||||
|
|
||||||
|
# 更新数据库
|
||||||
|
mcp_config.available_tools = tool_names
|
||||||
mcp_config.last_health_check = datetime.now()
|
mcp_config.last_health_check = datetime.now()
|
||||||
mcp_config.health_status = "healthy"
|
mcp_config.health_status = "healthy"
|
||||||
mcp_config.error_message = None
|
mcp_config.error_message = None
|
||||||
|
config.status = ToolStatus.AVAILABLE.value
|
||||||
|
|
||||||
# 更新工具状态
|
|
||||||
self._update_tool_status(config)
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": "MCP连接成功",
|
"message": "MCP连接成功并同步工具列表",
|
||||||
# "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)}
|
"details": {
|
||||||
"details": {"server_url": mcp_config.server_url}
|
"server_url": mcp_config.server_url,
|
||||||
|
"tools_count": len(tool_names),
|
||||||
|
"tools": tool_names
|
||||||
|
}
|
||||||
}
|
}
|
||||||
except Exception as e:
|
else:
|
||||||
await client.disconnect()
|
return {"success": False, "message": f"同步工具失败: {error}"}
|
||||||
|
|
||||||
# 更新错误状态
|
|
||||||
mcp_config.last_health_check = datetime.now()
|
|
||||||
mcp_config.health_status = "error"
|
|
||||||
mcp_config.error_message = str(e)
|
|
||||||
self._update_tool_status(config)
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
return {"success": False, "message": f"MCP功能测试失败: {str(e)}"}
|
|
||||||
else:
|
else:
|
||||||
# 更新连接失败状态
|
# 更新错误状态
|
||||||
mcp_config.last_health_check = datetime.now()
|
mcp_config.last_health_check = datetime.now()
|
||||||
mcp_config.health_status = "error"
|
mcp_config.health_status = "error"
|
||||||
mcp_config.error_message = "连接失败"
|
mcp_config.error_message = test_result.get("error", "连接失败")
|
||||||
self._update_tool_status(config)
|
config.status = ToolStatus.ERROR.value
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
|
||||||
return {"success": False, "message": "MCP连接失败"}
|
return test_result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 更新异常状态
|
logger.error(f"测试MCP连接失败: {config.id}, 错误: {e}")
|
||||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
return {"success": False, "message": f"测试失败: {str(e)}"}
|
||||||
MCPToolConfig.id == config.id
|
|
||||||
).first()
|
|
||||||
if mcp_config:
|
|
||||||
mcp_config.last_health_check = datetime.now()
|
|
||||||
mcp_config.health_status = "error"
|
|
||||||
mcp_config.error_message = str(e)
|
|
||||||
self._update_tool_status(config)
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
return {"success": False, "message": f"MCP测试异常: {str(e)}"}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
|
async def parse_openapi_schema(schema_data: str = None, schema_url: str = None) -> Dict[str, Any]:
|
||||||
@@ -1190,57 +1247,44 @@ class ToolService:
|
|||||||
|
|
||||||
# 创建MCP客户端
|
# 创建MCP客户端
|
||||||
connection_config = mcp_config.connection_config or {}
|
connection_config = mcp_config.connection_config or {}
|
||||||
|
client = SimpleMCPClient(mcp_config.server_url, connection_config)
|
||||||
|
|
||||||
client = MCPClient(mcp_config.server_url, connection_config)
|
async with client:
|
||||||
|
# 获取工具列表
|
||||||
|
tools = await client.list_tools()
|
||||||
|
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
||||||
|
|
||||||
if await client.connect():
|
# 更新数据库
|
||||||
try:
|
mcp_config.available_tools = tool_names
|
||||||
# 获取工具列表
|
mcp_config.last_health_check = datetime.now()
|
||||||
tools = await client.list_tools()
|
mcp_config.health_status = "healthy"
|
||||||
tool_names = [tool.get("name") for tool in tools if tool.get("name")]
|
mcp_config.error_message = None
|
||||||
|
|
||||||
# 更新数据库
|
# 更新工具状态
|
||||||
mcp_config.available_tools = tool_names
|
config.status = ToolStatus.AVAILABLE.value
|
||||||
mcp_config.last_health_check = datetime.now()
|
|
||||||
mcp_config.health_status = "healthy"
|
|
||||||
mcp_config.error_message = None
|
|
||||||
|
|
||||||
# 更新工具状态
|
self.db.commit()
|
||||||
config.status = ToolStatus.AVAILABLE.value
|
|
||||||
|
|
||||||
self.db.commit()
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": "工具列表同步成功",
|
||||||
|
"tools_count": len(tool_names),
|
||||||
|
"tools": tool_names
|
||||||
|
}
|
||||||
|
|
||||||
await client.disconnect()
|
except Exception as e:
|
||||||
|
# 更新错误状态
|
||||||
return {
|
try:
|
||||||
"success": True,
|
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||||
"message": "工具列表同步成功",
|
if mcp_config:
|
||||||
"tools_count": len(tool_names),
|
|
||||||
"tools": tool_names
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await client.disconnect()
|
|
||||||
|
|
||||||
# 更新错误状态
|
|
||||||
mcp_config.last_health_check = datetime.now()
|
mcp_config.last_health_check = datetime.now()
|
||||||
mcp_config.health_status = "error"
|
mcp_config.health_status = "error"
|
||||||
mcp_config.error_message = str(e)
|
mcp_config.error_message = str(e)
|
||||||
config.status = ToolStatus.ERROR.value
|
config.status = ToolStatus.ERROR.value
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
return {"success": False, "message": f"获取工具列表失败: {str(e)}"}
|
|
||||||
else:
|
|
||||||
# 连接失败
|
|
||||||
mcp_config.last_health_check = datetime.now()
|
|
||||||
mcp_config.health_status = "error"
|
|
||||||
mcp_config.error_message = "连接失败"
|
|
||||||
config.status = ToolStatus.ERROR.value
|
|
||||||
self.db.commit()
|
|
||||||
|
|
||||||
return {"success": False, "message": "MCP连接失败"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}")
|
||||||
return {"success": False, "message": f"同步失败: {str(e)}"}
|
return {"success": False, "message": f"同步失败: {str(e)}"}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user