diff --git a/api/app/core/tools/langchain_adapter.py b/api/app/core/tools/langchain_adapter.py index ea5fdb96..51415732 100644 --- a/api/app/core/tools/langchain_adapter.py +++ b/api/app/core/tools/langchain_adapter.py @@ -232,7 +232,7 @@ class LangchainAdapter: # 添加验证约束 if param.enum: # 枚举值约束 - field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" + field_kwargs["pattern"] = f"^({'|'.join(map(str, param.enum))})$" if param.minimum is not None: field_kwargs["ge"] = param.minimum @@ -241,7 +241,7 @@ class LangchainAdapter: field_kwargs["le"] = param.maximum if param.pattern: - field_kwargs["regex"] = param.pattern + field_kwargs["pattern"] = param.pattern fields[param.name] = Field(**field_kwargs) annotations[param.name] = python_type diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 2901b7ca..e513a147 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -27,20 +27,22 @@ class SimpleMCPClient: # 确定连接类型 self.is_websocket = server_url.startswith(("ws://", "wss://")) + self.is_sse = "/sse" in server_url.lower() # 连接状态 self._websocket = None self._session = None self._request_id = 0 self._pending_requests = {} + self._server_capabilities = {} + self._endpoint_url = None # SSE endpoint URL + self._sse_task = None async def __aenter__(self): - """异步上下文管理器入口""" await self.connect() return self async def __aexit__(self, exc_type, exc_val, exc_tb): - """异步上下文管理器出口""" await self.disconnect() async def connect(self): @@ -57,47 +59,157 @@ class SimpleMCPClient: async def disconnect(self): """断开连接""" try: + if self._sse_task: + self._sse_task.cancel() if self._websocket: await self._websocket.close() self._websocket = None - if self._session: await self._session.close() self._session = None - except Exception as e: logger.error(f"断开连接失败: {e}") async def _connect_websocket(self): """WebSocket 连接""" headers = self._build_headers() - self._websocket = await websockets.connect( self.server_url, extra_headers=headers, timeout=self.timeout ) - - # 启动消息处理 asyncio.create_task(self._handle_websocket_messages()) - - # 发送初始化消息 await self._send_initialize() async def _connect_http(self): """HTTP 连接""" headers = self._build_headers() timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(headers=headers, timeout=timeout) - self._session = aiohttp.ClientSession( - headers=headers, - timeout=timeout - ) - - # 对于 ModelScope MCP 服务,需要先发送初始化请求 - if "modelscope.net" in self.server_url: + if self.is_sse: + await self._initialize_sse_session() + elif "modelscope.net" in self.server_url: await self._initialize_modelscope_session() + async def _initialize_sse_session(self): + """初始化 SSE MCP 会话 - 参考 Dify 实现""" + try: + # 建立 SSE 连接 + response = await self._session.get( + self.server_url, + headers={"Accept": "text/event-stream"} + ) + + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") + + # 启动 SSE 读取任务 + self._sse_task = asyncio.create_task(self._read_sse_stream(response)) + + # 等待获取 endpoint URL + for _ in range(10): + if self._endpoint_url: + break + await asyncio.sleep(1) + + if not self._endpoint_url: + raise MCPConnectionError("未能获取 endpoint URL") + + # 发送 initialize 请求到 endpoint + init_request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {}}, + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} + } + } + + init_response = await self._send_sse_request(init_request) + if "error" in init_response: + raise MCPConnectionError(f"初始化失败: {init_response['error']}") + + result = init_response.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + # 发送 initialized 通知 + await self._send_sse_notification({"jsonrpc": "2.0", "method": "notifications/initialized"}) + + except aiohttp.ClientError as e: + raise MCPConnectionError(f"初始化连接失败: {e}") + + async def _read_sse_stream(self, response): + """读取 SSE 流""" + try: + async for line in response.content: + line = line.decode('utf-8').strip() + + if line.startswith('event:'): + continue + + if line.startswith('data:'): + data = line[5:].strip() # 去除 'data:' 后的空格 + if not data or data == '[DONE]': + continue + + try: + # 处理 endpoint 事件(相对路径或绝对路径) + if not self._endpoint_url: + # 如果是相对路径,拼接成完整 URL + if data.startswith('/'): + from urllib.parse import urlparse, urlunparse + parsed = urlparse(self.server_url) + self._endpoint_url = f"{parsed.scheme}://{parsed.netloc}{data}" + else: + self._endpoint_url = data + logger.info(f"获取到 endpoint URL: {self._endpoint_url}") + continue + + # 处理 message 事件 + message = json.loads(data) + request_id = message.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(message) + except json.JSONDecodeError: + continue + except Exception as e: + logger.error(f"SSE 流读取错误: {e}") + + async def _send_sse_request(self, request: Dict[str, Any]) -> Dict[str, Any]: + """通过 SSE endpoint 发送请求""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + request_id = request["id"] + future = asyncio.Future() + self._pending_requests[request_id] = future + + try: + async with self._session.post(self._endpoint_url, json=request) as response: + if response.status != 200: + error_text = await response.text() + raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") + + return await asyncio.wait_for(future, timeout=self.timeout) + except asyncio.TimeoutError: + self._pending_requests.pop(request_id, None) + raise MCPConnectionError("请求超时") + + async def _send_sse_notification(self, notification: Dict[str, Any]): + """发送通知(无需响应)""" + if not self._endpoint_url: + raise MCPConnectionError("endpoint URL 未初始化") + + async with self._session.post(self._endpoint_url, json=notification) as response: + if response.status != 200: + logger.warning(f"通知发送失败: {response.status}") + async def _initialize_modelscope_session(self): """初始化 ModelScope MCP 会话""" init_request = { @@ -107,18 +219,12 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } try: - async with self._session.post( - self.server_url, - json=init_request - ) as response: + async with self._session.post(self.server_url, json=init_request) as response: if response.status != 200: error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") @@ -127,21 +233,16 @@ class SimpleMCPClient: if "error" in init_response: raise MCPConnectionError(f"初始化失败: {init_response['error']}") - # 获取 session ID session_id = response.headers.get("Mcp-Session-Id") or response.headers.get("mcp-session-id") if session_id: self._session.headers.update({"Mcp-Session-Id": session_id}) - # 发送 initialized 通知 initialized_notification = { "jsonrpc": "2.0", "method": "notifications/initialized" } - async with self._session.post( - self.server_url, - json=initialized_notification - ) as notif_response: + async with self._session.post(self.server_url, json=initialized_notification): pass except aiohttp.ClientError as e: @@ -149,12 +250,18 @@ class SimpleMCPClient: def _build_headers(self) -> Dict[str, str]: """构建请求头""" + # 基础 headers headers = { "Content-Type": "application/json", "Accept": "application/json, text/event-stream" } - # 添加认证头 + # 合并 connection_config 中的自定义 headers + custom_headers = self.connection_config.get("headers", {}) + if custom_headers: + headers.update(custom_headers) + + # 处理认证配置(认证 headers 优先级更高) auth_config = self.connection_config.get("auth_config", {}) auth_type = self.connection_config.get("auth_type", "none") @@ -178,7 +285,7 @@ class SimpleMCPClient: return headers async def _send_initialize(self): - """发送初始化消息""" + """发送初始化消息(WebSocket)""" init_message = { "jsonrpc": "2.0", "id": self._get_request_id(), @@ -186,124 +293,90 @@ class SimpleMCPClient: "params": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {}}, - "clientInfo": { - "name": "MemoryBear", - "version": "1.0.0" - } + "clientInfo": {"name": "MemoryBear", "version": "1.0.0"} } } await self._websocket.send(json.dumps(init_message)) + response = await self._websocket.recv() + response_data = json.loads(response) - # 等待初始化响应 - response = await asyncio.wait_for( - self._websocket.recv(), - timeout=self.timeout - ) + if "error" in response_data: + raise MCPConnectionError(f"初始化失败: {response_data['error']}") - init_response = json.loads(response) - if "error" in init_response: - raise MCPConnectionError(f"初始化失败: {init_response['error']}") + result = response_data.get("result", {}) + self._server_capabilities = result.get("capabilities", {}) + + await self._websocket.send(json.dumps({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + })) + + async def list_tools(self) -> List[Dict[str, Any]]: + """获取工具列表""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/list" + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + raise MCPConnectionError(f"获取工具列表失败: {response_data['error']}") + + result = response_data.get("result", {}) + return result.get("tools", []) + + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: + """调用工具""" + request = { + "jsonrpc": "2.0", + "id": self._get_request_id(), + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments} + } + + if self.is_websocket: + await self._websocket.send(json.dumps(request)) + response = await self._websocket.recv() + response_data = json.loads(response) + elif self.is_sse: + response_data = await self._send_sse_request(request) + else: + async with self._session.post(self.server_url, json=request) as response: + response_data = await response.json() + + if "error" in response_data: + error = response_data["error"] + raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") + + return response_data.get("result", {}) + + def _get_request_id(self) -> int: + """生成请求 ID""" + self._request_id += 1 + return self._request_id async def _handle_websocket_messages(self): """处理 WebSocket 消息""" try: - while self._websocket and not self._websocket.closed: - try: - message = await self._websocket.recv() - data = json.loads(message) - - # 处理响应 - if "id" in data: - request_id = str(data["id"]) - if request_id in self._pending_requests: - future = self._pending_requests.pop(request_id) - if not future.done(): - future.set_result(data) - - except ConnectionClosed: - break - except Exception as e: - logger.error(f"处理WebSocket消息失败: {e}") - + async for message in self._websocket: + data = json.loads(message) + request_id = data.get("id") + if request_id and request_id in self._pending_requests: + future = self._pending_requests.pop(request_id) + if not future.done(): + future.set_result(data) + except ConnectionClosed: + logger.info("WebSocket 连接已关闭") except Exception as e: - logger.error(f"WebSocket消息处理异常: {e}") - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Any: - """调用工具""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments - } - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"工具调用失败: {error.get('message', '未知错误')}") - - return response.get("result", {}) - - async def list_tools(self) -> List[Dict[str, Any]]: - """获取工具列表""" - request_data = { - "jsonrpc": "2.0", - "id": self._get_request_id(), - "method": "tools/list", - "params": {} - } - - if self.is_websocket: - response = await self._send_websocket_request(request_data) - else: - response = await self._send_http_request(request_data) - - if "error" in response: - error = response["error"] - raise MCPConnectionError(f"获取工具列表失败: {error.get('message', '未知错误')}") - - result = response.get("result", {}) - return result.get("tools", []) - - async def _send_websocket_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送WebSocket请求""" - request_id = str(request_data["id"]) - future = asyncio.Future() - self._pending_requests[request_id] = future - - try: - await self._websocket.send(json.dumps(request_data)) - response = await asyncio.wait_for(future, timeout=self.timeout) - return response - except asyncio.TimeoutError: - self._pending_requests.pop(request_id, None) - raise - - async def _send_http_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]: - """发送HTTP请求""" - try: - async with self._session.post( - self.server_url, - json=request_data - ) as response: - if response.status != 200: - error_text = await response.text() - raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}") - - return await response.json() - - except aiohttp.ClientError as e: - raise MCPConnectionError(f"HTTP请求失败: {e}") - - def _get_request_id(self) -> str: - """获取请求ID""" - self._request_id += 1 - return f"req_{self._request_id}_{int(time.time() * 1000)}" \ No newline at end of file + logger.error(f"WebSocket 消息处理错误: {e}")