diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 79c87205..479686ef 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -60,6 +60,22 @@ async def list_tools( raise HTTPException(status_code=500, detail=str(e)) +@router.get("/{tool_id}/methods", response_model=ApiResponse) +async def get_tool_methods( + tool_id: str, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """获取工具的所有方法""" + try: + methods = await service.get_tool_methods(tool_id, current_user.tenant_id) + if methods is None: + raise HTTPException(status_code=404, detail="工具不存在") + return success(data=methods, msg="获取工具方法成功") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + @router.get("/{tool_id}", response_model=ApiResponse) async def get_tool( tool_id: str, @@ -159,7 +175,8 @@ async def execute_tool( workspace_id=current_user.current_workspace_id, timeout=request.timeout ) - + if not result.success: + raise HTTPException(status_code=400, detail=result["error"]) return success( data={ "success": result.success, diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 9cad3579..7b6fa8ef 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"] + enum=["format", "convert_timezone", "timestamp_to_datetime", "now"] ), ToolParameter( name="input_value", diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py index 12f1e688..f22e9370 100644 --- a/api/app/core/tools/builtin/json_tool.py +++ b/api/app/core/tools/builtin/json_tool.py @@ -29,8 +29,7 @@ class JsonTool(BuiltinTool): type=ParameterType.STRING, description="操作类型", required=True, - enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", - "extract", "insert", "replace", "delete", "parse"] + enum=["insert", "replace", "delete", "parse"] ), ToolParameter( name="input_data", diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 2e37f2b1..a1d2ecaa 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -204,7 +204,7 @@ class MCPClient: ) init_response = json.loads(response) - if "error" in init_response: + if init_response.get("error", None) is not None: raise MCPProtocolError(f"初始化失败: {init_response['error']}") return True @@ -325,7 +325,7 @@ class MCPClient: try: response = await self._send_request(request_data, timeout) - if "error" in response: + if response.get("error", None) is not None: error = response["error"] raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}") diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index dc78e761..3aa7b16e 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -1,10 +1,9 @@ """工具数据访问层""" import uuid -from typing import List, Optional, Dict, Any +from typing import List, Optional from sqlalchemy.orm import Session -from sqlalchemy import func, or_ +from sqlalchemy import func -from app.repositories.base_repository import BaseRepository from app.models.tool_model import ( ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolExecution, ToolType, ToolStatus @@ -14,6 +13,31 @@ from app.models.tool_model import ( class ToolRepository: """工具仓储类""" + @staticmethod + def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]: + """根据工作流ID获取tenant_id + + Args: + db: 数据库会话 + workflow_id: 工作流配置ID + + Returns: + tenant_id或None + """ + from app.models.app_model import App + from app.models.workflow_model import WorkflowConfig + from app.models.workspace_model import Workspace + + result = db.query(Workspace.tenant_id).join( + App, App.workspace_id == Workspace.id + ).join( + WorkflowConfig, WorkflowConfig.app_id == App.id + ).filter( + WorkflowConfig.id == workflow_id + ).first() + + return result[0] if result else None + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 783df81a..50cca957 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -297,6 +297,165 @@ class ToolService: self.db.commit() logger.info(f"租户 {tenant_id} 内置工具初始化完成") + async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]: + """获取工具的所有方法 + + Args: + tool_id: 工具ID + tenant_id: 租户ID + + Returns: + 方法列表或None + """ + config = self._get_tool_config(tool_id, tenant_id) + if not config: + return None + + try: + if config.tool_type == ToolType.BUILTIN.value: + return await self._get_builtin_tool_methods(config) + elif config.tool_type == ToolType.CUSTOM.value: + return await self._get_custom_tool_methods(config) + elif config.tool_type == ToolType.MCP.value: + return await self._get_mcp_tool_methods(config) + else: + return [] + + except Exception as e: + logger.error(f"获取工具方法失败: {tool_id}, {e}") + return [] + + async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取内置工具的方法""" + builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id) + if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS: + return [] + + # 获取工具实例 + tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + if not tool_instance: + return [] + + # 检查是否有operation参数 + operation_param = None + for param in tool_instance.parameters: + if param.name == "operation" and param.enum: + operation_param = param + break + + if operation_param: + # 有多个操作 + methods = [] + for operation in operation_param.enum: + methods.append({ + "method_id": f"{config.name}_{operation}", + "name": operation, + "description": f"{config.description} - {operation}", + "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + }) + return methods + else: + # 只有一个方法 + return [{ + "method_id": config.name, + "name": config.name, + "description": config.description, + "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + }] + + async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取自定义工具的方法""" + custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) + if not custom_config: + return [] + + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + parser = OpenAPISchemaParser() + + # 解析schema + if custom_config.schema_content: + success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json") + elif custom_config.schema_url: + success, schema, error = await parser.parse_from_url(custom_config.schema_url) + else: + return [] + + if not success: + return [] + + # 提取操作 + tool_info = parser.extract_tool_info(schema) + operations = tool_info.get("operations", {}) + + methods = [] + for operation_id, operation in operations.items(): + # 生成参数列表 + parameters = [] + + # 路径和查询参数 + for param_name, param_info in operation.get("parameters", {}).items(): + parameters.append({ + "name": param_name, + "type": param_info.get("type", "string"), + "description": param_info.get("description", ""), + "required": param_info.get("required", False), + "enum": param_info.get("enum"), + "default": param_info.get("default") + }) + + # 请求体参数 + request_body = operation.get("request_body") + if request_body: + schema_props = request_body.get("schema", {}).get("properties", {}) + required_props = request_body.get("schema", {}).get("required", []) + + for prop_name, prop_schema in schema_props.items(): + parameters.append({ + "name": prop_name, + "type": prop_schema.get("type", "string"), + "description": prop_schema.get("description", ""), + "required": prop_name in required_props, + "enum": prop_schema.get("enum"), + "default": prop_schema.get("default") + }) + + methods.append({ + "method_id": operation_id, + "name": operation.get("summary", operation_id), + "description": operation.get("description", ""), + "method": operation.get("method", "GET"), + "path": operation.get("path", "/"), + "parameters": parameters + }) + + return methods + + except Exception as e: + logger.error(f"解析自定义工具schema失败: {e}") + return [] + + async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: + """获取MCP工具的方法""" + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if not mcp_config: + return [] + + available_tools = mcp_config.available_tools or [] + if not available_tools: + return [] + + methods = [] + for tool_name in available_tools: + methods.append({ + "method_id": tool_name, + "name": tool_name, + "description": f"MCP工具: {tool_name}", + "parameters": [] # MCP工具参数需要动态获取 + }) + + return methods + def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]: """获取工具统计信息""" try: