feat(tool system): Tool system reengineering
This commit is contained in:
@@ -4,7 +4,8 @@ from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
from .client import MCPClient
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
if not self._client or not self._connected:
|
||||
raise Exception("MCP客户端未连接")
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
try:
|
||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
if self._client:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
|
||||
@@ -134,11 +134,40 @@ class MCPClient:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
if username and password:
|
||||
import base64
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
@@ -190,6 +219,8 @@ class MCPClient:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
@@ -251,8 +282,9 @@ class MCPClient:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
|
||||
@staticmethod
|
||||
async def _handle_notification(message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
@@ -327,7 +359,7 @@ class MCPClient:
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
@@ -372,10 +404,10 @@ class MCPClient:
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
@@ -424,9 +456,9 @@ class MCPClient:
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = time.time()
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
@@ -148,7 +148,7 @@ class MCPServiceManager:
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
last_health_check=datetime.now()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
@@ -410,7 +410,8 @@ class MCPServiceManager:
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.is_enabled == True
|
||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
||||
ToolConfig.tool_type == ToolType.MCP.value
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
@@ -531,7 +532,7 @@ class MCPServiceManager:
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.utcnow()
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
|
||||
Reference in New Issue
Block a user