feat(tool system): add all methods for obtaining the tool
This commit is contained in:
@@ -60,6 +60,22 @@ async def list_tools(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{tool_id}/methods", response_model=ApiResponse)
|
||||||
|
async def get_tool_methods(
|
||||||
|
tool_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
service: ToolService = Depends(get_tool_service)
|
||||||
|
):
|
||||||
|
"""获取工具的所有方法"""
|
||||||
|
try:
|
||||||
|
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
|
||||||
|
if methods is None:
|
||||||
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{tool_id}", response_model=ApiResponse)
|
@router.get("/{tool_id}", response_model=ApiResponse)
|
||||||
async def get_tool(
|
async def get_tool(
|
||||||
tool_id: str,
|
tool_id: str,
|
||||||
@@ -159,7 +175,8 @@ async def execute_tool(
|
|||||||
workspace_id=current_user.current_workspace_id,
|
workspace_id=current_user.current_workspace_id,
|
||||||
timeout=request.timeout
|
timeout=request.timeout
|
||||||
)
|
)
|
||||||
|
if not result.success:
|
||||||
|
raise HTTPException(status_code=400, detail=result["error"])
|
||||||
return success(
|
return success(
|
||||||
data={
|
data={
|
||||||
"success": result.success,
|
"success": result.success,
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ class DateTimeTool(BuiltinTool):
|
|||||||
type=ParameterType.STRING,
|
type=ParameterType.STRING,
|
||||||
description="操作类型",
|
description="操作类型",
|
||||||
required=True,
|
required=True,
|
||||||
enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"]
|
enum=["format", "convert_timezone", "timestamp_to_datetime", "now"]
|
||||||
),
|
),
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name="input_value",
|
name="input_value",
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ class JsonTool(BuiltinTool):
|
|||||||
type=ParameterType.STRING,
|
type=ParameterType.STRING,
|
||||||
description="操作类型",
|
description="操作类型",
|
||||||
required=True,
|
required=True,
|
||||||
enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge",
|
enum=["insert", "replace", "delete", "parse"]
|
||||||
"extract", "insert", "replace", "delete", "parse"]
|
|
||||||
),
|
),
|
||||||
ToolParameter(
|
ToolParameter(
|
||||||
name="input_data",
|
name="input_data",
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ class MCPClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
init_response = json.loads(response)
|
init_response = json.loads(response)
|
||||||
if "error" in init_response:
|
if init_response.get("error", None) is not None:
|
||||||
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
raise MCPProtocolError(f"初始化失败: {init_response['error']}")
|
||||||
|
|
||||||
return True
|
return True
|
||||||
@@ -325,7 +325,7 @@ class MCPClient:
|
|||||||
try:
|
try:
|
||||||
response = await self._send_request(request_data, timeout)
|
response = await self._send_request(request_data, timeout)
|
||||||
|
|
||||||
if "error" in response:
|
if response.get("error", None) is not None:
|
||||||
error = response["error"]
|
error = response["error"]
|
||||||
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
raise MCPProtocolError(f"工具调用失败: {error.get('message', '未知错误')}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
"""工具数据访问层"""
|
"""工具数据访问层"""
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import func, or_
|
from sqlalchemy import func
|
||||||
|
|
||||||
from app.repositories.base_repository import BaseRepository
|
|
||||||
from app.models.tool_model import (
|
from app.models.tool_model import (
|
||||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||||
ToolExecution, ToolType, ToolStatus
|
ToolExecution, ToolType, ToolStatus
|
||||||
@@ -14,6 +13,31 @@ from app.models.tool_model import (
|
|||||||
class ToolRepository:
|
class ToolRepository:
|
||||||
"""工具仓储类"""
|
"""工具仓储类"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]:
|
||||||
|
"""根据工作流ID获取tenant_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
workflow_id: 工作流配置ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tenant_id或None
|
||||||
|
"""
|
||||||
|
from app.models.app_model import App
|
||||||
|
from app.models.workflow_model import WorkflowConfig
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
|
||||||
|
result = db.query(Workspace.tenant_id).join(
|
||||||
|
App, App.workspace_id == Workspace.id
|
||||||
|
).join(
|
||||||
|
WorkflowConfig, WorkflowConfig.app_id == App.id
|
||||||
|
).filter(
|
||||||
|
WorkflowConfig.id == workflow_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
return result[0] if result else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_by_tenant(
|
def find_by_tenant(
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -297,6 +297,165 @@ class ToolService:
|
|||||||
self.db.commit()
|
self.db.commit()
|
||||||
logger.info(f"租户 {tenant_id} 内置工具初始化完成")
|
logger.info(f"租户 {tenant_id} 内置工具初始化完成")
|
||||||
|
|
||||||
|
async def get_tool_methods(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""获取工具的所有方法
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_id: 工具ID
|
||||||
|
tenant_id: 租户ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
方法列表或None
|
||||||
|
"""
|
||||||
|
config = self._get_tool_config(tool_id, tenant_id)
|
||||||
|
if not config:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.tool_type == ToolType.BUILTIN.value:
|
||||||
|
return await self._get_builtin_tool_methods(config)
|
||||||
|
elif config.tool_type == ToolType.CUSTOM.value:
|
||||||
|
return await self._get_custom_tool_methods(config)
|
||||||
|
elif config.tool_type == ToolType.MCP.value:
|
||||||
|
return await self._get_mcp_tool_methods(config)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取工具方法失败: {tool_id}, {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _get_builtin_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
|
"""获取内置工具的方法"""
|
||||||
|
builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id)
|
||||||
|
if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 获取工具实例
|
||||||
|
tool_instance = self._get_tool_instance(str(config.id), config.tenant_id)
|
||||||
|
if not tool_instance:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 检查是否有operation参数
|
||||||
|
operation_param = None
|
||||||
|
for param in tool_instance.parameters:
|
||||||
|
if param.name == "operation" and param.enum:
|
||||||
|
operation_param = param
|
||||||
|
break
|
||||||
|
|
||||||
|
if operation_param:
|
||||||
|
# 有多个操作
|
||||||
|
methods = []
|
||||||
|
for operation in operation_param.enum:
|
||||||
|
methods.append({
|
||||||
|
"method_id": f"{config.name}_{operation}",
|
||||||
|
"name": operation,
|
||||||
|
"description": f"{config.description} - {operation}",
|
||||||
|
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||||
|
})
|
||||||
|
return methods
|
||||||
|
else:
|
||||||
|
# 只有一个方法
|
||||||
|
return [{
|
||||||
|
"method_id": config.name,
|
||||||
|
"name": config.name,
|
||||||
|
"description": config.description,
|
||||||
|
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||||
|
}]
|
||||||
|
|
||||||
|
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
|
"""获取自定义工具的方法"""
|
||||||
|
custom_config = self.custom_repo.find_by_tool_id(self.db, config.id)
|
||||||
|
if not custom_config:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.core.tools.custom.schema_parser import OpenAPISchemaParser
|
||||||
|
parser = OpenAPISchemaParser()
|
||||||
|
|
||||||
|
# 解析schema
|
||||||
|
if custom_config.schema_content:
|
||||||
|
success, schema, error = parser.parse_from_content(custom_config.schema_content, "application/json")
|
||||||
|
elif custom_config.schema_url:
|
||||||
|
success, schema, error = await parser.parse_from_url(custom_config.schema_url)
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 提取操作
|
||||||
|
tool_info = parser.extract_tool_info(schema)
|
||||||
|
operations = tool_info.get("operations", {})
|
||||||
|
|
||||||
|
methods = []
|
||||||
|
for operation_id, operation in operations.items():
|
||||||
|
# 生成参数列表
|
||||||
|
parameters = []
|
||||||
|
|
||||||
|
# 路径和查询参数
|
||||||
|
for param_name, param_info in operation.get("parameters", {}).items():
|
||||||
|
parameters.append({
|
||||||
|
"name": param_name,
|
||||||
|
"type": param_info.get("type", "string"),
|
||||||
|
"description": param_info.get("description", ""),
|
||||||
|
"required": param_info.get("required", False),
|
||||||
|
"enum": param_info.get("enum"),
|
||||||
|
"default": param_info.get("default")
|
||||||
|
})
|
||||||
|
|
||||||
|
# 请求体参数
|
||||||
|
request_body = operation.get("request_body")
|
||||||
|
if request_body:
|
||||||
|
schema_props = request_body.get("schema", {}).get("properties", {})
|
||||||
|
required_props = request_body.get("schema", {}).get("required", [])
|
||||||
|
|
||||||
|
for prop_name, prop_schema in schema_props.items():
|
||||||
|
parameters.append({
|
||||||
|
"name": prop_name,
|
||||||
|
"type": prop_schema.get("type", "string"),
|
||||||
|
"description": prop_schema.get("description", ""),
|
||||||
|
"required": prop_name in required_props,
|
||||||
|
"enum": prop_schema.get("enum"),
|
||||||
|
"default": prop_schema.get("default")
|
||||||
|
})
|
||||||
|
|
||||||
|
methods.append({
|
||||||
|
"method_id": operation_id,
|
||||||
|
"name": operation.get("summary", operation_id),
|
||||||
|
"description": operation.get("description", ""),
|
||||||
|
"method": operation.get("method", "GET"),
|
||||||
|
"path": operation.get("path", "/"),
|
||||||
|
"parameters": parameters
|
||||||
|
})
|
||||||
|
|
||||||
|
return methods
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析自定义工具schema失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _get_mcp_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
|
"""获取MCP工具的方法"""
|
||||||
|
mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id)
|
||||||
|
if not mcp_config:
|
||||||
|
return []
|
||||||
|
|
||||||
|
available_tools = mcp_config.available_tools or []
|
||||||
|
if not available_tools:
|
||||||
|
return []
|
||||||
|
|
||||||
|
methods = []
|
||||||
|
for tool_name in available_tools:
|
||||||
|
methods.append({
|
||||||
|
"method_id": tool_name,
|
||||||
|
"name": tool_name,
|
||||||
|
"description": f"MCP工具: {tool_name}",
|
||||||
|
"parameters": [] # MCP工具参数需要动态获取
|
||||||
|
})
|
||||||
|
|
||||||
|
return methods
|
||||||
|
|
||||||
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]:
|
||||||
"""获取工具统计信息"""
|
"""获取工具统计信息"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user