From 391cd602a2feef16fd5cdf125412b2a4f344d91d Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Fri, 6 Mar 2026 16:32:33 +0800 Subject: [PATCH] fix(mcp): MCP tool binds the information of the tool marketplace --- api/app/controllers/tool_controller.py | 6 ++++++ api/app/models/tool_model.py | 20 +++++++++++++++++++- api/app/schemas/tool_schema.py | 8 ++++++++ api/app/services/tool_service.py | 16 ++++++++++++++-- 4 files changed, 47 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index a3624ea4..ce5b15c0 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -97,6 +97,12 @@ async def create_tool( ): """创建工具""" try: + # 将 MCP 来源字段合并进 config + if request.tool_type == ToolType.MCP: + for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"): + val = getattr(request, key, None) + if val is not None: + request.config[key] = val tool_id = service.create_tool( name=request.name, tool_type=request.tool_type, diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py index ccd28693..98448bc5 100644 --- a/api/app/models/tool_model.py +++ b/api/app/models/tool_model.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from enum import StrEnum -from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean +from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -163,6 +163,17 @@ class CustomToolConfig(Base): return f"" +class MCPSourceChannel(StrEnum): + """MCP来源渠道枚举""" + ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼 + MODELSCOPE = "modelscope" # ModelScope + TOKENFLUX = "tokenflux" # TokenFlux + LANGENG = "langeng" # 蓝耕科技 + AI_302 = "302ai" # 302.AI + MCP_ROUTER = "mcp_router" # MCP Router + SELF_HOSTED = "self_hosted" # 自建 + + class MCPToolConfig(Base): """MCP工具配置模型""" __tablename__ = "mcp_tool_configs" @@ -170,6 +181,13 @@ class MCPToolConfig(Base): id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) server_url = Column(String(1000), nullable=False) # MCP服务器URL connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息) + + # 来源渠道 + source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED, + server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道") + market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id") + market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id") + mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id") # 服务状态 last_health_check = Column(DateTime) diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py index 48afe2c3..2ba86c2c 100644 --- a/api/app/schemas/tool_schema.py +++ b/api/app/schemas/tool_schema.py @@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel): health_status: str = "unknown" error_message: Optional[str] = None available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]") + source_channel: Optional[str] = Field(None, description="来源渠道") + market_id: Optional[str] = Field(None, description="渠道市场id") + market_config_id: Optional[str] = Field(None, description="渠道市场配置id") + mcp_service_id: Optional[str] = Field(None, description="mcp服务id") class Config: from_attributes = True @@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel): tool_type: ToolType config: Dict[str, Any] = Field(default_factory=dict) tags: List[str] = Field(default_factory=list) + source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)") + market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)") + market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)") + mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)") class ToolUpdateRequest(BaseModel): diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index f6e2ccce..60ac1a38 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -85,7 +85,7 @@ class ToolService: """检查工具名称是否重复""" query = self.db.query(ToolConfig).filter( ToolConfig.name == name, - ToolConfig.tool_type == tool_type.value, + ToolConfig.tool_type == tool_type, ToolConfig.tenant_id == tenant_id ) if exclude_id: @@ -965,7 +965,11 @@ class ToolService: id=tool_config.id, server_url=config.get("server_url"), connection_config=config.get("connection_config", {}), - available_tools=config.get("available_tools", []) + available_tools=config.get("available_tools", []), + source_channel=config.get("source_channel", "self_hosted"), + market_id=config.get("market_id"), + market_config_id=config.get("market_config_id"), + mcp_service_id=config.get("mcp_service_id"), ) self.db.add(mcp_config) @@ -1018,6 +1022,14 @@ class ToolService: mcp_config.server_url = config.get("server_url") mcp_config.connection_config = config.get("connection_config", {}) mcp_config.available_tools = config.get("available_tools", []) + if config.get("source_channel") is not None: + mcp_config.source_channel = config.get("source_channel") + if config.get("market_id") is not None: + mcp_config.market_id = config.get("market_id") + if config.get("market_config_id") is not None: + mcp_config.market_config_id = config.get("market_config_id") + if config.get("mcp_service_id") is not None: + mcp_config.mcp_service_id = config.get("mcp_service_id") @staticmethod def _determine_initial_status(tool_info: Dict[str, Any]) -> str: