Files
MemoryBear/api/app/controllers/tool_controller.py

326 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""工具控制器 - 简化统一的工具管理接口"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
CustomToolTestRequest, ToolActiveUpdate
)
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
from app.core.exceptions import BusinessException
router = APIRouter(prefix="/tools", tags=["Tool System"])
def get_tool_service(db: Session = Depends(get_db)) -> ToolService:
return ToolService(db)
@router.get("/statistics", response_model=ApiResponse)
async def get_tool_statistics(
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具统计信息"""
try:
stats = service.get_tool_statistics(current_user.tenant_id)
return success(data=stats, msg="获取统计信息成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("", response_model=ApiResponse)
async def list_tools(
name: Optional[str] = Query(None),
tool_type: Optional[str] = Query(None),
status: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具列表"""
try:
# 确保内置工具已初始化
service.ensure_builtin_tools_initialized(current_user.tenant_id)
# 获取工具列表
tools = service.list_tools(
tenant_id=current_user.tenant_id,
name=name,
tool_type=ToolType(tool_type) if tool_type else None,
status=ToolStatus(status) if status else None
)
return success(data=tools, msg="获取工具列表成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}/methods", response_model=ApiResponse)
async def get_tool_methods(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具的所有方法"""
try:
methods = await service.get_tool_methods(tool_id, current_user.tenant_id)
if methods is None:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=methods, msg="获取工具方法成功")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具详情"""
tool = service.get_tool_info(tool_id, current_user.tenant_id)
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=tool, msg="获取工具详情成功")
@router.post("", response_model=ApiResponse)
async def create_tool(
request: ToolCreateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""创建工具"""
try:
# 将 MCP 来源字段合并进 config
if request.tool_type == ToolType.MCP:
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = await service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
icon=request.icon,
description=request.description,
config=request.config,
tags=request.tags
)
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except BusinessException as e:
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{tool_id}", response_model=ApiResponse)
async def update_tool(
tool_id: str,
request: ToolUpdateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""更新工具"""
try:
success_flag = service.update_tool(
tool_id=tool_id,
tenant_id=current_user.tenant_id,
name=request.name,
description=request.description,
icon=request.icon,
config=request.config,
is_enabled=request.config.get("is_enabled", None),
tags=request.tags
)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具更新成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_id}", response_model=ApiResponse)
async def delete_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""删除工具逻辑删除is_active=False"""
try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具删除成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.patch("/{tool_id}/active", response_model=ApiResponse)
async def set_tool_active(
tool_id: str,
request: ToolActiveUpdate,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""设置工具可用状态(启用/禁用)
- is_active=true: 启用工具
- is_active=false: 禁用工具(等同于删除,但可恢复)
"""
try:
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
action = "启用" if request.is_active else "禁用"
return success(msg=f"工具已{action}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool(
request: ToolExecuteRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""执行工具"""
try:
result = await service.execute_tool(
tool_id=request.tool_id,
parameters=request.parameters,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
timeout=request.timeout
)
if not result.success:
raise HTTPException(status_code=400, detail=result["error"])
return success(
data={
"success": result.success,
"data": result.data,
"error": result.error,
"execution_time": result.execution_time,
"token_usage": result.token_usage
},
msg="工具执行完成"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/parse_schema", response_model=ApiResponse)
async def parse_openapi_schema(
request: ParseSchemaRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""解析OpenAPI schema"""
try:
result = await service.parse_openapi_schema(request.schema_content, request.schema_url)
if result["success"] is False:
raise HTTPException(status_code=400, detail=result["message"])
return success(data=result, msg="Schema解析完成")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse)
async def sync_mcp_tools(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""同步MCP工具列表"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if not result.get("success", False):
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
return success(data=result, msg="MCP工具列表同步完成")
except BusinessException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/test", response_model=ApiResponse)
async def test_tool_connection(
tool_id: str,
test_request: Optional[CustomToolTestRequest] = None,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""测试工具连接"""
try:
if test_request:
# 自定义工具测试
result = await service.test_custom_tool(
tool_id, current_user.tenant_id,
test_request.method, test_request.path, test_request.parameters
)
else:
# 普通连接测试
result = await service.test_connection(tool_id, current_user.tenant_id)
if result["success"] is False:
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
return success(data=result, msg="连接测试完成")
except BusinessException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.get("/enums/tool_types", response_model=ApiResponse)
async def get_tool_types():
"""获取工具类型枚举"""
return success(
data=[
{"value": ToolType.BUILTIN.value, "label": "内置工具"},
{"value": ToolType.CUSTOM.value, "label": "自定义工具"},
{"value": ToolType.MCP.value, "label": "MCP工具"}
],
msg="获取工具类型成功"
)
@router.get("/enums/status", response_model=ApiResponse)
async def get_tool_status():
"""获取工具状态枚举"""
return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功")
@router.get("/auth/types", response_model=ApiResponse)
async def get_auth_types():
"""获取认证类型枚举"""
return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功")