Merge branch 'release/v0.2.7' of github.com:SuanmoSuanyangTechnology/MemoryBear into release/v0.2.7
This commit is contained in:
@@ -14,6 +14,7 @@ from app.models import User
|
|||||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
|
||||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||||
|
|
||||||
@@ -103,7 +104,7 @@ async def create_tool(
|
|||||||
val = getattr(request, key, None)
|
val = getattr(request, key, None)
|
||||||
if val is not None:
|
if val is not None:
|
||||||
request.config[key] = val
|
request.config[key] = val
|
||||||
tool_id = service.create_tool(
|
tool_id = await service.create_tool(
|
||||||
name=request.name,
|
name=request.name,
|
||||||
tool_type=request.tool_type,
|
tool_type=request.tool_type,
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
@@ -113,6 +114,8 @@ async def create_tool(
|
|||||||
tags=request.tags
|
tags=request.tags
|
||||||
)
|
)
|
||||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||||
|
except BusinessException as e:
|
||||||
|
raise HTTPException(status_code=400, detail=e.message)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -93,7 +93,44 @@ class ToolService:
|
|||||||
if query.first():
|
if query.first():
|
||||||
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
def create_tool(
|
def _check_mcp_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, config: Dict[str, Any]):
|
||||||
|
"""检查MCP工具是否重复:市场来源按market_id+market_config_id+mcp_service_id判断(名称无关),自建按name+tool_type判断"""
|
||||||
|
from app.models.tool_model import MCPSourceChannel
|
||||||
|
source_channel = config.get("source_channel")
|
||||||
|
is_market_source = (
|
||||||
|
source_channel is not None
|
||||||
|
and source_channel != MCPSourceChannel.SELF_HOSTED
|
||||||
|
)
|
||||||
|
if is_market_source:
|
||||||
|
exists = (
|
||||||
|
self.db.query(ToolConfig)
|
||||||
|
.join(MCPToolConfig, MCPToolConfig.id == ToolConfig.id)
|
||||||
|
.filter(
|
||||||
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
ToolConfig.tool_type == tool_type,
|
||||||
|
MCPToolConfig.source_channel == source_channel,
|
||||||
|
MCPToolConfig.market_id == config.get("market_id"),
|
||||||
|
MCPToolConfig.market_config_id == config.get("market_config_id"),
|
||||||
|
MCPToolConfig.mcp_service_id == config.get("mcp_service_id"),
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if exists:
|
||||||
|
raise BusinessException(f"该MCP服务已添加", BizCode.DUPLICATE_NAME)
|
||||||
|
else:
|
||||||
|
exists = (
|
||||||
|
self.db.query(ToolConfig)
|
||||||
|
.filter(
|
||||||
|
ToolConfig.name == name,
|
||||||
|
ToolConfig.tool_type == tool_type,
|
||||||
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
if exists:
|
||||||
|
raise BusinessException(f"工具 '{name}' 已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
|
async def create_tool(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
tool_type: ToolType,
|
tool_type: ToolType,
|
||||||
@@ -106,7 +143,19 @@ class ToolService:
|
|||||||
"""创建工具"""
|
"""创建工具"""
|
||||||
if tool_type == ToolType.BUILTIN:
|
if tool_type == ToolType.BUILTIN:
|
||||||
raise ValueError("内置工具不允许创建")
|
raise ValueError("内置工具不允许创建")
|
||||||
self._check_name_duplicate(name, tool_type, tenant_id)
|
|
||||||
|
cfg = config or {}
|
||||||
|
if tool_type == ToolType.MCP:
|
||||||
|
self._check_mcp_duplicate(name, tool_type, tenant_id, cfg)
|
||||||
|
# 创建前测试连接
|
||||||
|
test_result = await self._test_mcp_connection_by_config(cfg)
|
||||||
|
if not test_result["success"]:
|
||||||
|
raise BusinessException(f"MCP连接测试失败: {test_result['message']}", BizCode.INVALID_PARAMETER)
|
||||||
|
# 将发现的工具列表写回 config
|
||||||
|
if "available_tools" in test_result:
|
||||||
|
cfg["available_tools"] = test_result["available_tools"]
|
||||||
|
else:
|
||||||
|
self._check_name_duplicate(name, tool_type, tenant_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 创建基础配置
|
# 创建基础配置
|
||||||
@@ -117,19 +166,22 @@ class ToolService:
|
|||||||
tool_type=tool_type.value,
|
tool_type=tool_type.value,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
status=ToolStatus.AVAILABLE.value,
|
status=ToolStatus.AVAILABLE.value,
|
||||||
config_data=config or {},
|
config_data=cfg,
|
||||||
tags=tags
|
tags=tags
|
||||||
)
|
)
|
||||||
self.db.add(tool_config)
|
self.db.add(tool_config)
|
||||||
self.db.flush()
|
self.db.flush()
|
||||||
|
|
||||||
# 创建类型特定配置
|
# 创建类型特定配置
|
||||||
self._create_type_config(tool_config, config or {})
|
self._create_type_config(tool_config, cfg)
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
logger.info(f"工具创建成功: {tool_config.id}")
|
logger.info(f"工具创建成功: {tool_config.id}")
|
||||||
return str(tool_config.id)
|
return str(tool_config.id)
|
||||||
|
|
||||||
|
except BusinessException:
|
||||||
|
self.db.rollback()
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.db.rollback()
|
self.db.rollback()
|
||||||
logger.error(f"创建工具失败: {e}")
|
logger.error(f"创建工具失败: {e}")
|
||||||
@@ -1165,6 +1217,27 @@ class ToolService:
|
|||||||
logger.error(f"加载内置工具配置失败: {e}")
|
logger.error(f"加载内置工具配置失败: {e}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
async def _test_mcp_connection_by_config(self, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""根据配置参数直接测试MCP连接(创建前调用,无需已存在的工具记录)"""
|
||||||
|
server_url = config.get("server_url")
|
||||||
|
if not server_url:
|
||||||
|
return {"success": False, "message": "server_url不能为空"}
|
||||||
|
connection_config = config.get("connection_config") or {}
|
||||||
|
try:
|
||||||
|
test_result = await self.mcp_tool_manager.test_tool_connection(server_url, connection_config)
|
||||||
|
if not test_result["success"]:
|
||||||
|
return test_result
|
||||||
|
success_flag, tools, error = await self.mcp_tool_manager.discover_tools(server_url, connection_config)
|
||||||
|
if not success_flag:
|
||||||
|
return {"success": False, "message": f"获取工具列表失败: {error}"}
|
||||||
|
tool_list = [
|
||||||
|
{tool["name"]: {"description": tool.get("description", ""), "inputSchema": tool.get("inputSchema", {})}}
|
||||||
|
for tool in tools if tool.get("name")
|
||||||
|
]
|
||||||
|
return {"success": True, "message": "MCP连接测试成功", "available_tools": tool_list}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "message": f"连接测试异常: {str(e)}"}
|
||||||
|
|
||||||
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]:
|
||||||
"""测试MCP连接并自动同步工具列表"""
|
"""测试MCP连接并自动同步工具列表"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user