Merge pull request #115 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(mcp tool)
This commit is contained in:
@@ -232,7 +232,7 @@ class LangchainAdapter:
|
||||
# 添加验证约束
|
||||
if param.enum:
|
||||
# 枚举值约束
|
||||
field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$"
|
||||
|
||||
if param.minimum is not None:
|
||||
field_kwargs["ge"] = param.minimum
|
||||
@@ -241,7 +241,7 @@ class LangchainAdapter:
|
||||
field_kwargs["le"] = param.maximum
|
||||
|
||||
if param.pattern:
|
||||
field_kwargs["regex"] = param.pattern
|
||||
field_kwargs["pattern"] = param.pattern
|
||||
|
||||
fields[param.name] = Field(**field_kwargs)
|
||||
annotations[param.name] = python_type
|
||||
|
||||
@@ -27,20 +27,22 @@ class SimpleMCPClient:
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
self.is_sse = "/sse" in server_url.lower()
|
||||
|
||||
# 连接状态
|
||||
self._websocket = None
|
||||
self._session = None
|
||||
self._request_id = 0
|
||||
self._pending_requests = {}
|
||||
self._server_capabilities = {}
|
||||
self._endpoint_url = None # SSE endpoint URL
|
||||
self._sse_task = None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
|
||||
async def connect(self):
|
||||
@@ -57,47 +59,157 @@ class SimpleMCPClient:
|
||||
async def disconnect(self):
|
||||
"""断开连接"""
|
||||
try:
|
||||
if self._sse_task:
|
||||
self._sse_task.cancel()
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._session:
|
||||
await self._session.close()
|
||||
self._session = None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"断开连接失败: {e}")
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""WebSocket 连接"""
|
||||
headers = self._build_headers()
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
extra_headers=headers,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
# 启动消息处理
|
||||
asyncio.create_task(self._handle_websocket_messages())
|
||||
|
||||
# 发送初始化消息
|
||||
await self._send_initialize()
|
||||
|
||||
async def _connect_http(self):
|
||||
"""HTTP 连接"""
|
||||
headers = self._build_headers()
|
||||
timeout = aiohttp.ClientTimeout(total=self.timeout)
|
||||
self._session = aiohttp.ClientSession(headers=headers, timeout=timeout)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers=headers,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
# 对于 ModelScope MCP 服务,需要先发送初始化请求
|
||||
if "modelscope.net" in self.server_url:
|
||||
if self.is_sse:
|
||||
await self._initialize_sse_session()
|
||||
elif "modelscope.net" in self.server_url:
|
||||
await self._initialize_modelscope_session()
|
||||
|
||||
async def _initialize_sse_session(self):
|
||||
"""初始化 SSE MCP 会话 - 参考 Dify 实现"""
|
||||
try:
|
||||
# 建立 SSE 连接
|
||||
response = await self._session.get(
|
||||
self.server_url,
|
||||
headers={"Accept": "text/event-stream"}
|
||||
)
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
|
||||
|
||||
# 启动 SSE 读取任务
|
||||
self._sse_task = asyncio.create_task(self._read_sse_stream(response))
|
||||
|
||||
# 等待获取 endpoint URL
|
||||
for _ in range(10):
|
||||
if self._endpoint_url:
|
||||
break
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("未能获取 endpoint URL")
|
||||
|
||||
# 发送 initialize 请求到 endpoint
|
||||
init_request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
init_response = await self._send_sse_request(init_request)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
result = init_response.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"初始化连接失败: {e}")
|
||||
|
||||
async def _read_sse_stream(self, response):
|
||||
"""读取 SSE 流"""
|
||||
try:
|
||||
async for line in response.content:
|
||||
line = line.decode('utf-8').strip()
|
||||
|
||||
if line.startswith('event:'):
|
||||
continue
|
||||
|
||||
if line.startswith('data:'):
|
||||
data = line[5:].strip() # 去除 'data:' 后的空格
|
||||
if not data or data == '[DONE]':
|
||||
continue
|
||||
|
||||
try:
|
||||
# 处理 endpoint 事件(相对路径或绝对路径)
|
||||
if not self._endpoint_url:
|
||||
# 如果是相对路径,拼接成完整 URL
|
||||
if data.startswith('/'):
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
parsed = urlparse(self.server_url)
|
||||
self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}"
|
||||
else:
|
||||
self._endpoint_url = data
|
||||
logger.info(f"获取到 endpoint URL: {self._endpoint_url}")
|
||||
continue
|
||||
|
||||
# 处理 message 事件
|
||||
message = json.loads(data)
|
||||
request_id = message.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(message)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"SSE 流读取错误: {e}")
|
||||
|
||||
async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""通过 SSE endpoint 发送请求"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
request_id = request["id"]
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
async with self._session.post(self._endpoint_url, json=request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await asyncio.wait_for(future, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError("请求超时")
|
||||
|
||||
async def _send_sse_notification(self, notification: Dict[str, Any]):
|
||||
"""发送通知(无需响应)"""
|
||||
if not self._endpoint_url:
|
||||
raise MCPConnectionError("endpoint URL 未初始化")
|
||||
|
||||
async with self._session.post(self._endpoint_url, json=notification) as response:
|
||||
if response.status != 200:
|
||||
logger.warning(f"通知发送失败: {response.status}")
|
||||
|
||||
async def _initialize_modelscope_session(self):
|
||||
"""初始化 ModelScope MCP 会话"""
|
||||
init_request = {
|
||||
@@ -107,18 +219,12 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=init_request
|
||||
) as response:
|
||||
async with self._session.post(self.server_url, json=init_request) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}")
|
||||
@@ -127,21 +233,16 @@ class SimpleMCPClient:
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
|
||||
# 获取 session ID
|
||||
session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id")
|
||||
if session_id:
|
||||
self._session.headers.update({"Mcp-Session-Id": session_id})
|
||||
|
||||
# 发送 initialized 通知
|
||||
initialized_notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}
|
||||
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=initialized_notification
|
||||
) as notif_response:
|
||||
async with self._session.post(self.server_url, json=initialized_notification):
|
||||
pass
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
@@ -149,12 +250,18 @@ class SimpleMCPClient:
|
||||
|
||||
def _build_headers(self) -> Dict[str, str]:
|
||||
"""构建请求头"""
|
||||
# 基础 headers
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
}
|
||||
|
||||
# 添加认证头
|
||||
# 合并 connection_config 中的自定义 headers
|
||||
custom_headers = self.connection_config.get("headers", {})
|
||||
if custom_headers:
|
||||
headers.update(custom_headers)
|
||||
|
||||
# 处理认证配置(认证 headers 优先级更高)
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
|
||||
@@ -178,7 +285,7 @@ class SimpleMCPClient:
|
||||
return headers
|
||||
|
||||
async def _send_initialize(self):
|
||||
"""发送初始化消息"""
|
||||
"""发送初始化消息(WebSocket)"""
|
||||
init_message = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
@@ -186,124 +293,90 @@ class SimpleMCPClient:
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {}},
|
||||
"clientInfo": {
|
||||
"name": "MemoryBear",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
"clientInfo": {"name": "MemoryBear", "version": "1.0.0"}
|
||||
}
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(init_message))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
|
||||
# 等待初始化响应
|
||||
response = await asyncio.wait_for(
|
||||
self._websocket.recv(),
|
||||
timeout=self.timeout
|
||||
)
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"初始化失败: {response_data['error']}")
|
||||
|
||||
init_response = json.loads(response)
|
||||
if "error" in init_response:
|
||||
raise MCPConnectionError(f"初始化失败: {init_response['error']}")
|
||||
result = response_data.get("result", {})
|
||||
self._server_capabilities = result.get("capabilities", {})
|
||||
|
||||
await self._websocket.send(json.dumps({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
}))
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}")
|
||||
|
||||
result = response_data.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {"name": tool_name, "arguments": arguments}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
await self._websocket.send(json.dumps(request))
|
||||
response = await self._websocket.recv()
|
||||
response_data = json.loads(response)
|
||||
elif self.is_sse:
|
||||
response_data = await self._send_sse_request(request)
|
||||
else:
|
||||
async with self._session.post(self.server_url, json=request) as response:
|
||||
response_data = await response.json()
|
||||
|
||||
if "error" in response_data:
|
||||
error = response_data["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response_data.get("result", {})
|
||||
|
||||
def _get_request_id(self) -> int:
|
||||
"""生成请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def _handle_websocket_messages(self):
|
||||
"""处理 WebSocket 消息"""
|
||||
try:
|
||||
while self._websocket and not self._websocket.closed:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
data = json.loads(message)
|
||||
|
||||
# 处理响应
|
||||
if "id" in data:
|
||||
request_id = str(data["id"])
|
||||
if request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
|
||||
except ConnectionClosed:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"处理WebSocket消息失败: {e}")
|
||||
|
||||
async for message in self._websocket:
|
||||
data = json.loads(message)
|
||||
request_id = data.get("id")
|
||||
if request_id and request_id in self._pending_requests:
|
||||
future = self._pending_requests.pop(request_id)
|
||||
if not future.done():
|
||||
future.set_result(data)
|
||||
except ConnectionClosed:
|
||||
logger.info("WebSocket 连接已关闭")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket消息处理异常: {e}")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any:
|
||||
"""调用工具"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self) -> List[Dict[str, Any]]:
|
||||
"""获取工具列表"""
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": self._get_request_id(),
|
||||
"method": "tools/list",
|
||||
"params": {}
|
||||
}
|
||||
|
||||
if self.is_websocket:
|
||||
response = await self._send_websocket_request(request_data)
|
||||
else:
|
||||
response = await self._send_http_request(request_data)
|
||||
|
||||
if "error" in response:
|
||||
error = response["error"]
|
||||
raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
result = response.get("result", {})
|
||||
return result.get("tools", [])
|
||||
|
||||
async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送WebSocket请求"""
|
||||
request_id = str(request_data["id"])
|
||||
future = asyncio.Future()
|
||||
self._pending_requests[request_id] = future
|
||||
|
||||
try:
|
||||
await self._websocket.send(json.dumps(request_data))
|
||||
response = await asyncio.wait_for(future, timeout=self.timeout)
|
||||
return response
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""发送HTTP请求"""
|
||||
try:
|
||||
async with self._session.post(
|
||||
self.server_url,
|
||||
json=request_data
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
|
||||
|
||||
return await response.json()
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise MCPConnectionError(f"HTTP请求失败: {e}")
|
||||
|
||||
def _get_request_id(self) -> str:
|
||||
"""获取请求ID"""
|
||||
self._request_id += 1
|
||||
return f"req_{self._request_id}_{int(time.time() * 1000)}"
|
||||
logger.error(f"WebSocket 消息处理错误: {e}")
|
||||
|
||||
Reference in New Issue
Block a user