feat(tool system): Tool system reengineering

This commit is contained in:
谢俊男
2025-12-25 17:30:20 +08:00
parent 3bcaead413
commit 04be3088a2
25 changed files with 1887 additions and 3325 deletions

View File

@@ -4,7 +4,8 @@ from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
async def connect(self) -> bool:
"""连接到MCP服务器"""
try:
# 这里应该实现实际的MCP连接逻辑
# 为了简化,这里只是模拟连接
from .client import MCPClient
# 测试服务器连接
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
# 尝试获取服务器信息
async with session.get(f"{self.server_url}/info") as response:
if response.status == 200:
server_info = await response.json()
self.available_tools = server_info.get("tools", [])
self._connected = True
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
raise Exception(f"服务器响应错误: {response.status}")
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# 更新可用工具列表
await self._update_available_tools()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
logger.error(f"MCP服务器连接失败: {self.server_url}")
return False
except Exception as e:
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
self._connected = False
return False
async def _update_available_tools(self):
"""更新可用工具列表"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
# 这里应该实现实际的断开逻辑
await self._client.disconnect()
self._client = None
self._connected = False
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
# 构建MCP请求
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
# 发送请求
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
result = await response.json()
# 检查MCP响应
if "error" in result:
error = result["error"]
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
return result.get("result", {})
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
if not self._connected:
await self.connect()
# 获取工具列表
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/list"
}
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status == 200:
result = await response.json()
if "result" in result:
tools = result["result"].get("tools", [])
self.available_tools = [tool.get("name") for tool in tools]
return tools
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []

View File

@@ -134,11 +134,40 @@ class MCPClient:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
auth_config = self.connection_config.get("auth_config", {})
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
@@ -190,6 +219,8 @@ class MCPClient:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
@@ -251,8 +282,9 @@ class MCPClient:
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
@@ -327,7 +359,7 @@ class MCPClient:
try:
response = await self._send_request(request_data, timeout)
if not response["error"] is None:
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
@@ -372,10 +404,10 @@ class MCPClient:
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
@@ -424,9 +456,9 @@ class MCPClient:
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = time.time() - start_time
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = time.time()
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,

View File

@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
@@ -148,7 +148,7 @@ class MCPServiceManager:
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
last_health_check=datetime.utcnow()
last_health_check=datetime.now()
)
self.db.add(mcp_config)
@@ -410,7 +410,8 @@ class MCPServiceManager:
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.is_enabled == True
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
@@ -531,7 +532,7 @@ class MCPServiceManager:
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.utcnow()
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")