Merge pull request #495 from SuanmoSuanyangTechnology/feature/agent-tool_xjn

fix(mcp)
This commit is contained in:
Ke Sun
2026-03-06 16:40:03 +08:00
committed by GitHub
4 changed files with 47 additions and 3 deletions

View File

@@ -97,6 +97,12 @@ async def create_tool(
):
"""创建工具"""
try:
# 将 MCP 来源字段合并进 config
if request.tool_type == ToolType.MCP:
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = service.create_tool(
name=request.name,
tool_type=request.tool_type,

View File

@@ -3,7 +3,7 @@ import uuid
from datetime import datetime
from enum import StrEnum
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
class MCPSourceChannel(StrEnum):
"""MCP来源渠道枚举"""
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
MODELSCOPE = "modelscope" # ModelScope
TOKENFLUX = "tokenflux" # TokenFlux
LANGENG = "langeng" # 蓝耕科技
AI_302 = "302ai" # 302.AI
MCP_ROUTER = "mcp_router" # MCP Router
SELF_HOSTED = "self_hosted" # 自建
class MCPToolConfig(Base):
"""MCP工具配置模型"""
__tablename__ = "mcp_tool_configs"
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
server_url = Column(String(1000), nullable=False) # MCP服务器URL
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
# 来源渠道
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
# 服务状态
last_health_check = Column(DateTime)

View File

@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
health_status: str = "unknown"
error_message: Optional[str] = None
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
source_channel: Optional[str] = Field(None, description="来源渠道")
market_id: Optional[str] = Field(None, description="渠道市场id")
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
class Config:
from_attributes = True
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
tool_type: ToolType
config: Dict[str, Any] = Field(default_factory=dict)
tags: List[str] = Field(default_factory=list)
source_channel: Optional[str] = Field(None, description="来源渠道仅MCP工具")
market_id: Optional[str] = Field(None, description="渠道市场id仅MCP工具")
market_config_id: Optional[str] = Field(None, description="渠道市场配置id仅MCP工具")
mcp_service_id: Optional[str] = Field(None, description="mcp服务id仅MCP工具")
class ToolUpdateRequest(BaseModel):

View File

@@ -85,7 +85,7 @@ class ToolService:
"""检查工具名称是否重复"""
query = self.db.query(ToolConfig).filter(
ToolConfig.name == name,
ToolConfig.tool_type == tool_type.value,
ToolConfig.tool_type == tool_type,
ToolConfig.tenant_id == tenant_id
)
if exclude_id:
@@ -965,7 +965,11 @@ class ToolService:
id=tool_config.id,
server_url=config.get("server_url"),
connection_config=config.get("connection_config", {}),
available_tools=config.get("available_tools", [])
available_tools=config.get("available_tools", []),
source_channel=config.get("source_channel", "self_hosted"),
market_id=config.get("market_id"),
market_config_id=config.get("market_config_id"),
mcp_service_id=config.get("mcp_service_id"),
)
self.db.add(mcp_config)
@@ -1018,6 +1022,14 @@ class ToolService:
mcp_config.server_url = config.get("server_url")
mcp_config.connection_config = config.get("connection_config", {})
mcp_config.available_tools = config.get("available_tools", [])
if config.get("source_channel") is not None:
mcp_config.source_channel = config.get("source_channel")
if config.get("market_id") is not None:
mcp_config.market_id = config.get("market_id")
if config.get("market_config_id") is not None:
mcp_config.market_config_id = config.get("market_config_id")
if config.get("mcp_service_id") is not None:
mcp_config.mcp_service_id = config.get("mcp_service_id")
@staticmethod
def _determine_initial_status(tool_info: Dict[str, Any]) -> str: