From a3e6f67ff707c169b04a22cd3c41d826197d56aa Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Wed, 11 Mar 2026 17:19:07 +0800 Subject: [PATCH] fix(tool): The MCP tool checks for duplicate additions from the main screen and performs a test before adding. --- api/app/controllers/tool_controller.py | 5 +- api/app/services/tool_service.py | 81 ++++++++++++++++++++++++-- 2 files changed, 81 insertions(+), 5 deletions(-) diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index ce5b15c0..10ca83af 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -14,6 +14,7 @@ from app.models import User from app.models.tool_model import ToolType, ToolStatus, AuthType from app.services.tool_service import ToolService from app.schemas.response_schema import ApiResponse +from app.core.exceptions import BusinessException router = APIRouter(prefix="/tools", tags=["Tool System"]) @@ -103,7 +104,7 @@ async def create_tool( val = getattr(request, key, None) if val is not None: request.config[key] = val - tool_id = service.create_tool( + tool_id = await service.create_tool( name=request.name, tool_type=request.tool_type, tenant_id=current_user.tenant_id, @@ -113,6 +114,8 @@ async def create_tool( tags=request.tags ) return success(data={"tool_id": tool_id}, msg="工具创建成功") + except BusinessException as e: + raise HTTPException(status_code=400, detail=e.message) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 4fe1e9e6..23def7f8 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -93,7 +93,44 @@ class ToolService: if query.first(): raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME) - def create_tool( + def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]): + """检查MCP工具是否重复:市场来源按market_id+market_config_id+mcp_service_id判断(名称无关),自建按name+tool_type判断""" + from app.models.tool_model import MCPSourceChannel + source_channel = config.get("source_channel") + is_market_source = ( + source_channel is not None + and source_channel != MCPSourceChannel.SELF_HOSTED + ) + if is_market_source: + exists = ( + self.db.query(ToolConfig) + .join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id) + .filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == tool_type, + MCPToolConfig.source_channel == source_channel, + MCPToolConfig.market_id == config.get("market_id"), + MCPToolConfig.market_config_id == config.get("market_config_id"), + MCPToolConfig.mcp_service_id == config.get("mcp_service_id"), + ) + .first() + ) + if exists: + raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME) + else: + exists = ( + self.db.query(ToolConfig) + .filter( + ToolConfig.name == name, + ToolConfig.tool_type == tool_type, + ToolConfig.tenant_id == tenant_id, + ) + .first() + ) + if exists: + raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME) + + async def create_tool( self, name: str, tool_type: ToolType, @@ -106,7 +143,19 @@ class ToolService: """创建工具""" if tool_type == ToolType.BUILTIN: raise ValueError("内置工具不允许创建") - self._check_name_duplicate(name, tool_type, tenant_id) + + cfg = config or {} + if tool_type == ToolType.MCP: + self._check_mcp_duplicate(name, tool_type, tenant_id, cfg) + # 创建前测试连接 + test_result = await self._test_mcp_connection_by_config(cfg) + if not test_result["success"]: + raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER) + # 将发现的工具列表写回 config + if "available_tools" in test_result: + cfg["available_tools"] = test_result["available_tools"] + else: + self._check_name_duplicate(name, tool_type, tenant_id) try: # 创建基础配置 @@ -117,19 +166,22 @@ class ToolService: tool_type=tool_type.value, tenant_id=tenant_id, status=ToolStatus.AVAILABLE.value, - config_data=config or {}, + config_data=cfg, tags=tags ) self.db.add(tool_config) self.db.flush() # 创建类型特定配置 - self._create_type_config(tool_config, config or {}) + self._create_type_config(tool_config, cfg) self.db.commit() logger.info(f"工具创建成功: {tool_config.id}") return str(tool_config.id) + except BusinessException: + self.db.rollback() + raise except Exception as e: self.db.rollback() logger.error(f"创建工具失败: {e}") @@ -1165,6 +1217,27 @@ class ToolService: logger.error(f"加载内置工具配置失败: {e}") return {} + async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]: + """根据配置参数直接测试MCP连接(创建前调用,无需已存在的工具记录)""" + server_url = config.get("server_url") + if not server_url: + return {"success": False, "message": "server_url不能为空"} + connection_config = config.get("connection_config") or {} + try: + test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config) + if not test_result["success"]: + return test_result + success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config) + if not success_flag: + return {"success": False, "message": f"获取工具列表失败: {error}"} + tool_list = [ + {tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}} + for tool in tools if tool.get("name") + ] + return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list} + except Exception as e: + return {"success": False, "message": f"连接测试异常: {str(e)}"} + async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]: """测试MCP连接并自动同步工具列表""" try: