Merge pull request #495 from SuanmoSuanyangTechnology/feature/agent-tool_xjn
fix(mcp)
This commit is contained in:
@@ -97,6 +97,12 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean
|
||||
from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean, text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -163,6 +163,17 @@ class CustomToolConfig(Base):
|
||||
return f"<CustomToolConfig(id={self.id}, auth_type={self.auth_type})>"
|
||||
|
||||
|
||||
class MCPSourceChannel(StrEnum):
|
||||
"""MCP来源渠道枚举"""
|
||||
ALIYUN_BAILIAN = "aliyun_bailian" # 阿里云百炼
|
||||
MODELSCOPE = "modelscope" # ModelScope
|
||||
TOKENFLUX = "tokenflux" # TokenFlux
|
||||
LANGENG = "langeng" # 蓝耕科技
|
||||
AI_302 = "302ai" # 302.AI
|
||||
MCP_ROUTER = "mcp_router" # MCP Router
|
||||
SELF_HOSTED = "self_hosted" # 自建
|
||||
|
||||
|
||||
class MCPToolConfig(Base):
|
||||
"""MCP工具配置模型"""
|
||||
__tablename__ = "mcp_tool_configs"
|
||||
@@ -170,6 +181,13 @@ class MCPToolConfig(Base):
|
||||
id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True)
|
||||
server_url = Column(String(1000), nullable=False) # MCP服务器URL
|
||||
connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息)
|
||||
|
||||
# 来源渠道
|
||||
source_channel = Column(String(50), default=MCPSourceChannel.SELF_HOSTED,
|
||||
server_default=text(f"'{MCPSourceChannel.SELF_HOSTED}'"), nullable=False, comment="来源渠道")
|
||||
market_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场id")
|
||||
market_config_id = Column(UUID(as_uuid=True), nullable=True, comment="渠道市场配置id")
|
||||
mcp_service_id = Column(String(255), nullable=True, comment="mcp服务id")
|
||||
|
||||
# 服务状态
|
||||
last_health_check = Column(DateTime)
|
||||
|
||||
@@ -155,6 +155,10 @@ class MCPToolConfigSchema(BaseModel):
|
||||
health_status: str = "unknown"
|
||||
error_message: Optional[str] = None
|
||||
available_tools: List[Dict[str, Dict[str, Any]]] = Field(default_factory=list, description="工具列表,格式: [{'tool_name': str, 'arguments': dict}]")
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id")
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -192,6 +196,10 @@ class ToolCreateRequest(BaseModel):
|
||||
tool_type: ToolType
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
source_channel: Optional[str] = Field(None, description="来源渠道(仅MCP工具)")
|
||||
market_id: Optional[str] = Field(None, description="渠道市场id(仅MCP工具)")
|
||||
market_config_id: Optional[str] = Field(None, description="渠道市场配置id(仅MCP工具)")
|
||||
mcp_service_id: Optional[str] = Field(None, description="mcp服务id(仅MCP工具)")
|
||||
|
||||
|
||||
class ToolUpdateRequest(BaseModel):
|
||||
|
||||
@@ -85,7 +85,7 @@ class ToolService:
|
||||
"""检查工具名称是否重复"""
|
||||
query = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.name == name,
|
||||
ToolConfig.tool_type == tool_type.value,
|
||||
ToolConfig.tool_type == tool_type,
|
||||
ToolConfig.tenant_id == tenant_id
|
||||
)
|
||||
if exclude_id:
|
||||
@@ -965,7 +965,11 @@ class ToolService:
|
||||
id=tool_config.id,
|
||||
server_url=config.get("server_url"),
|
||||
connection_config=config.get("connection_config", {}),
|
||||
available_tools=config.get("available_tools", [])
|
||||
available_tools=config.get("available_tools", []),
|
||||
source_channel=config.get("source_channel", "self_hosted"),
|
||||
market_id=config.get("market_id"),
|
||||
market_config_id=config.get("market_config_id"),
|
||||
mcp_service_id=config.get("mcp_service_id"),
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
@@ -1018,6 +1022,14 @@ class ToolService:
|
||||
mcp_config.server_url = config.get("server_url")
|
||||
mcp_config.connection_config = config.get("connection_config", {})
|
||||
mcp_config.available_tools = config.get("available_tools", [])
|
||||
if config.get("source_channel") is not None:
|
||||
mcp_config.source_channel = config.get("source_channel")
|
||||
if config.get("market_id") is not None:
|
||||
mcp_config.market_id = config.get("market_id")
|
||||
if config.get("market_config_id") is not None:
|
||||
mcp_config.market_config_id = config.get("market_config_id")
|
||||
if config.get("mcp_service_id") is not None:
|
||||
mcp_config.mcp_service_id = config.get("mcp_service_id")
|
||||
|
||||
@staticmethod
|
||||
def _determine_initial_status(tool_info: Dict[str, Any]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user