326 lines
12 KiB
Python
326 lines
12 KiB
Python
"""工具控制器 - 简化统一的工具管理接口"""
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.error_codes import BizCode
|
||
from app.schemas.tool_schema import (
|
||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||
CustomToolTestRequest, ToolActiveUpdate
|
||
)
|
||
|
||
from app.core.response_utils import success
|
||
from app.db import get_db
|
||
from app.dependencies import get_current_user
|
||
from app.models import User
|
||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||
from app.services.tool_service import ToolService
|
||
from app.schemas.response_schema import ApiResponse
|
||
from app.core.exceptions import BusinessException
|
||
|
||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||
|
||
|
||
def get_tool_service(db: Session = Depends(get_db)) -> ToolService:
|
||
return ToolService(db)
|
||
|
||
|
||
@router.get("/statistics", response_model=ApiResponse)
|
||
async def get_tool_statistics(
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""获取工具统计信息"""
|
||
try:
|
||
stats = service.get_tool_statistics(current_user.tenant_id)
|
||
return success(data=stats, msg="获取统计信息成功")
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("", response_model=ApiResponse)
|
||
async def list_tools(
|
||
name: Optional[str] = Query(None),
|
||
tool_type: Optional[str] = Query(None),
|
||
status: Optional[str] = Query(None),
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""获取工具列表"""
|
||
try:
|
||
# 确保内置工具已初始化
|
||
service.ensure_builtin_tools_initialized(current_user.tenant_id)
|
||
|
||
# 获取工具列表
|
||
tools = service.list_tools(
|
||
tenant_id=current_user.tenant_id,
|
||
name=name,
|
||
tool_type=ToolType(tool_type) if tool_type else None,
|
||
status=ToolStatus(status) if status else None
|
||
)
|
||
return success(data=tools, msg="获取工具列表成功")
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/{tool_id}/methods", response_model=ApiResponse)
|
||
async def get_tool_methods(
|
||
tool_id: str,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""获取工具的所有方法"""
|
||
try:
|
||
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
|
||
if methods is None:
|
||
raise HTTPException(status_code=404, detail="工具不存在")
|
||
return success(data=methods, msg="获取工具方法成功")
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/{tool_id}", response_model=ApiResponse)
|
||
async def get_tool(
|
||
tool_id: str,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""获取工具详情"""
|
||
tool = service.get_tool_info(tool_id, current_user.tenant_id)
|
||
if not tool:
|
||
raise HTTPException(status_code=404, detail="工具不存在")
|
||
return success(data=tool, msg="获取工具详情成功")
|
||
|
||
|
||
@router.post("", response_model=ApiResponse)
|
||
async def create_tool(
|
||
request: ToolCreateRequest,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""创建工具"""
|
||
try:
|
||
# 将 MCP 来源字段合并进 config
|
||
if request.tool_type == ToolType.MCP:
|
||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||
val = getattr(request, key, None)
|
||
if val is not None:
|
||
request.config[key] = val
|
||
tool_id = await service.create_tool(
|
||
name=request.name,
|
||
tool_type=request.tool_type,
|
||
tenant_id=current_user.tenant_id,
|
||
icon=request.icon,
|
||
description=request.description,
|
||
config=request.config,
|
||
tags=request.tags
|
||
)
|
||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||
except BusinessException as e:
|
||
raise HTTPException(status_code=400, detail=e.message)
|
||
except ValueError as e:
|
||
raise HTTPException(status_code=400, detail=str(e))
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.put("/{tool_id}", response_model=ApiResponse)
|
||
async def update_tool(
|
||
tool_id: str,
|
||
request: ToolUpdateRequest,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""更新工具"""
|
||
try:
|
||
success_flag = service.update_tool(
|
||
tool_id=tool_id,
|
||
tenant_id=current_user.tenant_id,
|
||
name=request.name,
|
||
description=request.description,
|
||
icon=request.icon,
|
||
config=request.config,
|
||
is_enabled=request.config.get("is_enabled", None),
|
||
tags=request.tags
|
||
)
|
||
if not success_flag:
|
||
raise HTTPException(status_code=404, detail="工具不存在")
|
||
return success(msg="工具更新成功")
|
||
except ValueError as e:
|
||
raise HTTPException(status_code=400, detail=str(e))
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.delete("/{tool_id}", response_model=ApiResponse)
|
||
async def delete_tool(
|
||
tool_id: str,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""删除工具(逻辑删除,is_active=False)"""
|
||
try:
|
||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||
if not success_flag:
|
||
raise HTTPException(status_code=404, detail="工具不存在")
|
||
return success(msg="工具删除成功")
|
||
except ValueError as e:
|
||
raise HTTPException(status_code=400, detail=str(e))
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||
async def set_tool_active(
|
||
tool_id: str,
|
||
request: ToolActiveUpdate,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""设置工具可用状态(启用/禁用)
|
||
|
||
- is_active=true: 启用工具
|
||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||
"""
|
||
try:
|
||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||
if not success_flag:
|
||
raise HTTPException(status_code=404, detail="工具不存在")
|
||
action = "启用" if request.is_active else "禁用"
|
||
return success(msg=f"工具已{action}")
|
||
except ValueError as e:
|
||
raise HTTPException(status_code=400, detail=str(e))
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/execution/execute", response_model=ApiResponse)
|
||
async def execute_tool(
|
||
request: ToolExecuteRequest,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""执行工具"""
|
||
try:
|
||
result = await service.execute_tool(
|
||
tool_id=request.tool_id,
|
||
parameters=request.parameters,
|
||
tenant_id=current_user.tenant_id,
|
||
user_id=current_user.id,
|
||
workspace_id=current_user.current_workspace_id,
|
||
timeout=request.timeout
|
||
)
|
||
if not result.success:
|
||
raise HTTPException(status_code=400, detail=result["error"])
|
||
return success(
|
||
data={
|
||
"success": result.success,
|
||
"data": result.data,
|
||
"error": result.error,
|
||
"execution_time": result.execution_time,
|
||
"token_usage": result.token_usage
|
||
},
|
||
msg="工具执行完成"
|
||
)
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.post("/parse_schema", response_model=ApiResponse)
|
||
async def parse_openapi_schema(
|
||
request: ParseSchemaRequest,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""解析OpenAPI schema"""
|
||
try:
|
||
result = await service.parse_openapi_schema(request.schema_content, request.schema_url)
|
||
if result["success"] is False:
|
||
raise HTTPException(status_code=400, detail=result["message"])
|
||
return success(data=result, msg="Schema解析完成")
|
||
except HTTPException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse)
|
||
async def sync_mcp_tools(
|
||
tool_id: str,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""同步MCP工具列表"""
|
||
try:
|
||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||
if not result.get("success", False):
|
||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||
return success(data=result, msg="MCP工具列表同步完成")
|
||
except BusinessException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
@router.post("/{tool_id}/test", response_model=ApiResponse)
|
||
async def test_tool_connection(
|
||
tool_id: str,
|
||
test_request: Optional[CustomToolTestRequest] = None,
|
||
current_user: User = Depends(get_current_user),
|
||
service: ToolService = Depends(get_tool_service)
|
||
):
|
||
"""测试工具连接"""
|
||
try:
|
||
if test_request:
|
||
# 自定义工具测试
|
||
result = await service.test_custom_tool(
|
||
tool_id, current_user.tenant_id,
|
||
test_request.method, test_request.path, test_request.parameters
|
||
)
|
||
else:
|
||
# 普通连接测试
|
||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||
if result["success"] is False:
|
||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||
return success(data=result, msg="连接测试完成")
|
||
except BusinessException:
|
||
raise
|
||
except Exception as e:
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@router.get("/enums/tool_types", response_model=ApiResponse)
|
||
async def get_tool_types():
|
||
"""获取工具类型枚举"""
|
||
return success(
|
||
data=[
|
||
{"value": ToolType.BUILTIN.value, "label": "内置工具"},
|
||
{"value": ToolType.CUSTOM.value, "label": "自定义工具"},
|
||
{"value": ToolType.MCP.value, "label": "MCP工具"}
|
||
],
|
||
msg="获取工具类型成功"
|
||
)
|
||
|
||
|
||
@router.get("/enums/status", response_model=ApiResponse)
|
||
async def get_tool_status():
|
||
"""获取工具状态枚举"""
|
||
return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功")
|
||
|
||
|
||
@router.get("/auth/types", response_model=ApiResponse)
|
||
async def get_auth_types():
|
||
"""获取认证类型枚举"""
|
||
return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功")
|