feat(tool system): Tool system reengineering

This commit is contained in:
谢俊男
2025-12-25 17:30:20 +08:00
parent 3bcaead413
commit 04be3088a2
25 changed files with 1887 additions and 3325 deletions

View File

@@ -33,7 +33,6 @@ from . import (
emotion_config_controller,
prompt_optimizer_controller,
tool_controller,
tool_execution_controller,
)
from . import user_memory_controllers
@@ -71,6 +70,5 @@ manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router)
manager_router.include_router(tool_execution_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,585 +1,250 @@
"""工具管理API控制器"""
import base64
from typing import List, Optional, Dict, Any
"""工具控制器 - 简化统一的工具管理接口"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Body
from langfuse.api.core import jsonable_encoder
from sqlalchemy.exc import SQLAlchemyError
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field, PositiveInt, field_validator
from cryptography.fernet import Fernet
from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
)
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig
from app.core.logging_config import get_business_logger
from app.core.config import settings
from app.core.tools.config_manager import ConfigManager
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
logger = get_business_logger()
router = APIRouter(prefix="/tools", tags=["工具管理"])
router = APIRouter(prefix="/tools", tags=["Tool System"])
# ==================== 辅助函数 ====================
def get_tool_service(db: Session = Depends(get_db)) -> ToolService:
return ToolService(db)
def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""加密敏感参数"""
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
cipher = Fernet(cipher_key)
encrypted_params = {}
sensitive_keys = ['api_key', 'token', 'api_secret', 'password']
for key, value in parameters.items():
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
encrypted_params[key] = cipher.encrypt(str(value).encode()).decode()
else:
encrypted_params[key] = value
return encrypted_params
@router.get("/statistics", response_model=ApiResponse)
async def get_tool_statistics(
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具统计信息"""
try:
stats = service.get_tool_statistics(current_user.tenant_id)
return success(data=stats, msg="获取统计信息成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]:
"""解密敏感参数"""
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode())
cipher = Fernet(cipher_key)
decrypted_params = {}
sensitive_keys = ['api_key', 'token', 'secret', 'password']
for key, value in parameters.items():
if any(sensitive in key.lower() for sensitive in sensitive_keys) and value:
try:
decrypted_params[key] = cipher.decrypt(value.encode()).decode()
except Exception as e:
decrypted_params[key] = value
else:
decrypted_params[key] = value
return decrypted_params
def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str:
"""更新工具状态并返回新状态"""
if tool_config.tool_type == ToolType.BUILTIN:
if not tool_info or not tool_info.get('requires_config', False):
new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具
elif not builtin_config or not builtin_config.parameters:
new_status = ToolStatus.INACTIVE.value
else:
# 检查是否有必要的API密钥
has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value
else: # 自定义和MCP工具
new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value
# 更新数据库中的状态
if tool_config.status != new_status:
tool_config.status = new_status
return new_status
# ==================== 请求/响应模型 ====================
class ToolListResponse(BaseModel):
"""工具列表响应"""
id: str
name: str
description: str
tool_type: str
category: str
version: str = "1.0.0"
status: str # active inactive error loading
requires_config: bool = False
# is_configured: bool = False
class Config:
from_attributes = True
class BuiltinToolConfigRequest(BaseModel):
"""内置工具配置请求"""
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
class CustomToolCreateRequest(BaseModel):
"""自定义工具创建请求体模型,包含参数校验规则"""
name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填")
description: str = Field(None, description="工具描述")
base_url: str = Field(None, description="工具基础URL")
schema_url: str = Field(None, description="工具Schema URL")
schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容可选")
auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型")
auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典")
timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间1-300秒默认30")
# 自定义校验当auth_type为api_key时auth_config必须包含api_key字段
@field_validator("auth_config")
def validate_auth_config(cls, v, values):
auth_type = values.data.get("auth_type")
if auth_type == "api_key" and (not v or "api_key" not in v):
raise ValueError("认证类型为api_key时auth_config必须包含api_key字段")
if auth_type == "bearer_token" and (not v or "bearer_token" not in v):
raise ValueError("认证类型为bearer_token时auth_config必须包含bearer_token字段")
return v
class MCPToolCreateRequest(BaseModel):
"""MCP工具创建请求体模型适配MCP业务特性"""
# 基础必填字段(带长度/格式校验)
name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称")
description: str = Field(None, description="MCP工具描述")
# MCP核心字段服务端URL强制HTTP/HTTPS格式
server_url: str = Field(..., description="MCP服务端URL仅支持http/https协议")
# 连接配置:默认空字典,可自定义校验规则(根据实际业务调整)
connection_config: Dict[str, Any] = Field({},description="MCP连接配置如认证信息、超时、重试等默认空字典")
@field_validator("connection_config")
def validate_connection_config(cls, v):
# 示例1若包含timeout必须是1-300的整数
if "timeout" in v:
timeout = v["timeout"]
if not isinstance(timeout, int) or timeout < 1 or timeout > 300:
raise ValueError("connection_config.timeout必须是1-300的整数")
return v
# @field_validator("server_url")
# def validate_server_url_protocol(cls, v):
# if v.scheme != "https":
# raise ValueError("MCP服务端URL仅支持HTTPS协议安全要求")
# return v
# ==================== API端点 ====================
@router.get("", response_model=List[ToolListResponse])
@router.get("", response_model=ApiResponse)
async def list_tools(
name: Optional[str] = None,
tool_type: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
name: Optional[str] = Query(None),
tool_type: Optional[str] = Query(None),
status: Optional[str] = Query(None),
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具列表包含内置工具、自定义工具和MCP工具"""
"""获取工具列表"""
try:
# 初始化内置工具(如果需要)
config_manager = ConfigManager()
config_manager.ensure_builtin_tools_initialized(
current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus
# 确保内置工具已初始化
service.ensure_builtin_tools_initialized(current_user.tenant_id)
# 获取工具列表
tools = service.list_tools(
tenant_id=current_user.tenant_id,
name=name,
tool_type=ToolType(tool_type) if tool_type else None,
status=ToolStatus(status) if status else None
)
return success(data=tools, msg="获取工具列表成功")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
response_tools = []
query = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id
@router.get("/{tool_id}", response_model=ApiResponse)
async def get_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取工具详情"""
tool = service.get_tool_info(tool_id, current_user.tenant_id)
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=tool, msg="获取工具详情成功")
@router.post("", response_model=ApiResponse)
async def create_tool(
request: ToolCreateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""创建工具"""
try:
tool_id = service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
icon=request.icon,
description=request.description,
config=request.config
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type)
if name:
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
tools = query.all()
builtin_tools = config_manager.load_builtin_tools_config()
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
for tool_config in tools:
if tool_config.tool_type == ToolType.BUILTIN.value:
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
tool_info = configured_tools.get(builtin_config.tool_class)
status = _update_tool_status(tool_config, builtin_config, tool_info)
else:
status = _update_tool_status(tool_config)
response_tools.append(ToolListResponse(
id=str(tool_config.id),
name=tool_config.name,
description=tool_config.description,
tool_type=tool_config.tool_type,
category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type,
version="1.0.0",
status=status,
requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False,
))
return response_tools
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"获取工具列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}")
async def get_builtin_tool_detail(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
@router.put("/{tool_id}", response_model=ApiResponse)
async def update_tool(
tool_id: str,
request: ToolUpdateRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""获取内置工具详情"""
"""更新工具"""
try:
config_manager = ConfigManager()
builtin_tools = config_manager.load_builtin_tools_config()
configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()}
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id
).first()
builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first()
tool_info = configured_tools.get(builtin_config.tool_class)
is_configured = False
config_parameters = {}
if builtin_config and builtin_config.parameters:
is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token'))
# 不返回敏感信息,只返回非敏感配置
config_parameters = {k: v for k, v in builtin_config.parameters.items()
if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])}
return {
"id": tool_config.id,
"name": tool_config.name,
"description": tool_config.description,
"category": tool_info['category'],
"status": tool_config.tool_type,
"requires_config": tool_info['requires_config'],
"is_configured": is_configured,
"config_parameters": config_parameters
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/builtin/{tool_id}/configure")
async def configure_builtin_tool(
tool_id: str,
request: BuiltinToolConfigRequest = Body(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""配置内置工具参数(租户级别)"""
try:
# 查询工具配置
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id,
ToolConfig.tool_type == ToolType.BUILTIN
).first()
if not tool_config:
raise HTTPException(status_code=404, detail="工具不存在")
# 获取内置工具配置
builtin_config = db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if not builtin_config:
raise HTTPException(status_code=404, detail="内置工具配置不存在")
# 获取全局工具信息
config_manager = ConfigManager()
builtin_tools_config = config_manager.load_builtin_tools_config()
tool_info = None
for tool_key, info in builtin_tools_config.items():
if info['tool_class'] == builtin_config.tool_class:
tool_info = info
break
if not tool_info:
raise HTTPException(status_code=404, detail="工具信息不存在")
# 加密敏感参数
encrypted_params = _encrypt_sensitive_params(request.parameters)
# 更新配置
builtin_config.parameters = encrypted_params
# 更新状态
_update_tool_status(tool_config, builtin_config, tool_info)
db.commit()
return {
"success": True,
"message": f"工具 {tool_config.name} 配置成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"配置内置工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/builtin/{tool_id}/config")
async def get_builtin_tool_config(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""获取内置工具配置(用于使用)"""
try:
# 查询工具配置
tool_config = db.query(ToolConfig).filter(
ToolConfig.tenant_id == current_user.tenant_id,
ToolConfig.id == tool_id,
ToolConfig.tool_type == ToolType.BUILTIN
).first()
if not tool_config:
raise HTTPException(status_code=404, detail="工具不存在")
# 获取内置工具配置
builtin_config = db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if not builtin_config:
raise HTTPException(status_code=404, detail="内置工具配置不存在")
# 解密参数
decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {})
return {
"tool_id": tool_id,
"tool_class": builtin_config.tool_class,
"name": tool_config.name,
"parameters": decrypted_params,
"status": tool_config.status
}
except HTTPException:
raise
except Exception as e:
logger.error(f"获取工具配置失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/custom")
async def create_custom_tool(
request: CustomToolCreateRequest = Body(...),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""创建自定义工具"""
try:
config_data = jsonable_encoder(request.model_dump())
config_data["tool_type"] = "custom"
config_manager = ConfigManager()
is_valid, error_msg = config_manager.validate_config(config_data, "custom")
if not is_valid:
raise HTTPException(status_code=400, detail=error_msg)
# 创建数据库记录
tool_config = ToolConfig(
success_flag = service.update_tool(
tool_id=tool_id,
tenant_id=current_user.tenant_id,
name=request.name,
description=request.description,
tool_type=ToolType.CUSTOM,
tenant_id=current_user.tenant_id,
status=ToolStatus.ACTIVE.value,
config_data=config_data
icon=request.icon,
config=request.config,
is_enabled=request.config.get("is_enabled", None)
)
db.add(tool_config)
db.flush()
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具更新成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 创建CustomToolConfig记录
custom_config = CustomToolConfig(
id=tool_config.id,
base_url=request.base_url,
schema_url=request.schema_url,
schema_content=request.schema_content,
auth_type=request.auth_type,
auth_config=request.auth_config,
@router.delete("/{tool_id}", response_model=ApiResponse)
async def delete_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""删除工具"""
try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
return success(msg="工具删除成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool(
request: ToolExecuteRequest,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""执行工具"""
try:
result = await service.execute_tool(
tool_id=request.tool_id,
parameters=request.parameters,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
timeout=request.timeout
)
db.add(custom_config)
db.commit()
return {
"success": True,
"message": f"自定义工具 {request.name} 创建成功",
"tool_id": str(tool_config.id)
}
except HTTPException:
raise
return success(
data={
"success": result.success,
"data": result.data,
"error": result.error,
"execution_time": result.execution_time,
"token_usage": result.token_usage
},
msg="工具执行完成"
)
except Exception as e:
logger.error(f"创建自定义工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/mcp")
async def create_mcp_tool(
request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"),
@router.post("/parse_schema", response_model=ApiResponse)
async def parse_openapi_schema(
request: ParseSchemaRequest,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
service: ToolService = Depends(get_tool_service)
):
"""创建MCP工具"""
"""解析OpenAPI schema"""
try:
config_data = jsonable_encoder(request.model_dump())
config_data["tool_type"] = "mcp"
result = await service.parse_openapi_schema(request.schema_content, request.schema_url)
if result["success"] is False:
raise HTTPException(status_code=400, detail=result["message"])
return success(data=result, msg="Schema解析完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
config_manager = ConfigManager()
is_valid, error_msg = config_manager.validate_config(config_data, "mcp")
if not is_valid:
raise HTTPException(status_code=400, detail=error_msg)
@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse)
async def sync_mcp_tools(
tool_id: str,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""同步MCP工具列表"""
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if result["success"] is False:
raise HTTPException(status_code=404, detail=result["message"])
return success(data=result, msg="MCP工具列表同步完成")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# 创建数据库记录
try:
tool_config = ToolConfig(
name=request.name,
description=request.description,
tool_type=ToolType.MCP,
tenant_id=current_user.tenant_id,
status=ToolStatus.ACTIVE.value,
config_data=config_data
@router.post("/{tool_id}/test", response_model=ApiResponse)
async def test_tool_connection(
tool_id: str,
test_request: Optional[CustomToolTestRequest] = None,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""测试工具连接"""
try:
if test_request:
# 自定义工具测试
result = await service.test_custom_tool(
tool_id, current_user.tenant_id,
test_request.method, test_request.path, test_request.parameters
)
db.add(tool_config)
db.flush()
# 创建MCPToolConfig记录
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=request.server_url,
connection_config=request.connection_config
)
db.add(mcp_config)
db.commit()
except SQLAlchemyError as db_e:
db.rollback()
logger.error(f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id},工具名:{request.name}: {str(db_e)}",
exc_info=True)
raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败租户ID{current_user.tenant_id}"
f"工具名:{request.name}{str(db_e)}")
return {
"success": True,
"message": f"MCP工具 {request.name} 创建成功",
"tool_id": str(tool_config.id)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"创建MCP工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/{tool_id}")
async def delete_tool(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""删除工具仅限自定义和MCP工具"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
if tool.tool_type == ToolType.BUILTIN:
raise HTTPException(status_code=403, detail="内置工具不允许删除")
db.delete(tool)
db.commit()
return {
"success": True,
"message": f"工具 {tool.name} 删除成功"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"删除工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.put("/{tool_id}")
async def update_tool(
tool_id: str,
config_data: Optional[Dict[str, Any]] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""更新工具仅限自定义和MCP工具"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
if tool.tool_type == ToolType.BUILTIN:
raise HTTPException(status_code=403, detail="内置工具不允许修改")
if config_data is not None:
tool.config_data = config_data
# 更新状态
_update_tool_status(tool)
db.commit()
db.refresh(tool)
return {
"success": True,
"message": f"工具 {tool.name} 更新成功",
"status": tool.status
}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新工具失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/{tool_id}/toggle")
async def toggle_tool_status(
tool_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""切换工具活跃/非活跃状态"""
try:
tool = db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == current_user.tenant_id
).first()
if not tool:
raise HTTPException(status_code=404, detail="工具不存在")
# 在active和inactive之间切换
if tool.status == ToolStatus.ACTIVE.value:
tool.status = ToolStatus.INACTIVE.value
elif tool.status == ToolStatus.INACTIVE.value:
tool.status = ToolStatus.ACTIVE.value
else:
raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换")
db.commit()
db.refresh(tool)
return {
"success": True,
"message": f"工具 {tool.name} 状态已更新为 {tool.status}",
"status": tool.status
}
except HTTPException:
raise
# 普通连接测试
result = await service.test_connection(tool_id, current_user.tenant_id)
return success(data=result, msg="连接测试完成")
except Exception as e:
logger.error(f"切换工具状态失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))
@router.get("/enums/tool_types", response_model=ApiResponse)
async def get_tool_types():
"""获取工具类型枚举"""
return success(
data=[
{"value": ToolType.BUILTIN.value, "label": "内置工具"},
{"value": ToolType.CUSTOM.value, "label": "自定义工具"},
{"value": ToolType.MCP.value, "label": "MCP工具"}
],
msg="获取工具类型成功"
)
@router.get("/enums/status", response_model=ApiResponse)
async def get_tool_status():
"""获取工具状态枚举"""
return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功")
@router.get("/auth/types", response_model=ApiResponse)
async def get_auth_types():
"""获取认证类型枚举"""
return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功")

View File

@@ -1,430 +0,0 @@
"""工具执行API控制器"""
import uuid
from typing import Dict, Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Path, Query
from sqlalchemy.orm import Session
from pydantic import BaseModel, Field
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode
from app.core.tools.builtin import *
from app.core.logging_config import get_business_logger
logger = get_business_logger()
router = APIRouter(prefix="/tools/execution", tags=["工具执行"])
# ==================== 请求/响应模型 ====================
class ToolExecutionRequest(BaseModel):
"""工具执行请求"""
tool_id: str = Field(..., description="工具ID")
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)")
metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据")
class BatchExecutionRequest(BaseModel):
"""批量执行请求"""
executions: List[ToolExecutionRequest] = Field(..., description="执行列表")
max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数")
class ToolExecutionResponse(BaseModel):
"""工具执行响应"""
success: bool
execution_id: str
tool_id: str
data: Any = None
error: Optional[str] = None
error_code: Optional[str] = None
execution_time: float
token_usage: Optional[Dict[str, int]] = None
metadata: Dict[str, Any] = Field(default_factory=dict)
class ChainStepRequest(BaseModel):
"""链步骤请求"""
tool_id: str = Field(..., description="工具ID")
parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数")
condition: Optional[str] = Field(None, description="执行条件")
output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射")
error_handling: str = Field("stop", description="错误处理策略")
class ChainExecutionRequest(BaseModel):
"""链执行请求"""
name: str = Field(..., description="链名称")
description: str = Field("", description="链描述")
steps: List[ChainStepRequest] = Field(..., description="执行步骤")
execution_mode: str = Field("sequential", description="执行模式")
initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量")
global_timeout: Optional[float] = Field(None, description="全局超时")
class ExecutionHistoryResponse(BaseModel):
"""执行历史响应"""
execution_id: str
tool_id: str
status: str
started_at: Optional[str]
completed_at: Optional[str]
execution_time: Optional[float]
user_id: Optional[str]
workspace_id: Optional[str]
input_data: Optional[Dict[str, Any]]
output_data: Optional[Any]
error_message: Optional[str]
token_usage: Optional[Dict[str, int]]
class ToolConnectionTestResponse(BaseModel):
"""工具连接测试响应"""
success: bool
message: str
error: Optional[str] = None
details: Optional[Dict[str, Any]] = None
# ==================== 依赖注入 ====================
def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry:
"""获取工具注册表"""
registry = ToolRegistry(db)
# 注册内置工具类
registry.register_tool_class(DateTimeTool)
registry.register_tool_class(JsonTool)
registry.register_tool_class(BaiduSearchTool)
registry.register_tool_class(MinerUTool)
registry.register_tool_class(TextInTool)
return registry
def get_tool_executor(
db: Session = Depends(get_db),
registry: ToolRegistry = Depends(get_tool_registry)
) -> ToolExecutor:
"""获取工具执行器"""
return ToolExecutor(db, registry)
def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager:
"""获取链管理器"""
return ChainManager(executor)
# ==================== API端点 ====================
@router.post("/execute", response_model=ToolExecutionResponse)
async def execute_tool(
request: ToolExecutionRequest,
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""执行单个工具"""
try:
# 生成执行ID
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
# 执行工具
result = await executor.execute_tool(
tool_id=request.tool_id,
parameters=request.parameters,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
execution_id=execution_id,
timeout=request.timeout,
metadata=request.metadata
)
return ToolExecutionResponse(
success=result.success,
execution_id=execution_id,
tool_id=request.tool_id,
data=result.data,
error=result.error,
error_code=result.error_code,
execution_time=result.execution_time,
token_usage=result.token_usage,
metadata=result.metadata
)
except Exception as e:
logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/batch", response_model=List[ToolExecutionResponse])
async def execute_tools_batch(
request: BatchExecutionRequest,
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""批量执行工具"""
try:
# 准备执行配置
execution_configs = []
execution_ids = []
for exec_request in request.executions:
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
execution_ids.append(execution_id)
execution_configs.append({
"tool_id": exec_request.tool_id,
"parameters": exec_request.parameters,
"user_id": current_user.id,
"workspace_id": current_user.current_workspace_id,
"execution_id": execution_id,
"timeout": exec_request.timeout,
"metadata": exec_request.metadata
})
# 批量执行
results = await executor.execute_tools_batch(
execution_configs,
max_concurrency=request.max_concurrency
)
# 转换响应格式
responses = []
for i, result in enumerate(results):
responses.append(ToolExecutionResponse(
success=result.success,
execution_id=execution_ids[i],
tool_id=request.executions[i].tool_id,
data=result.data,
error=result.error,
error_code=result.error_code,
execution_time=result.execution_time,
token_usage=result.token_usage,
metadata=result.metadata
))
return responses
except Exception as e:
logger.error(f"批量执行失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/chain", response_model=Dict[str, Any])
async def execute_tool_chain(
request: ChainExecutionRequest,
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""执行工具链"""
try:
# 转换步骤格式
steps = []
for step_request in request.steps:
step = ChainStep(
tool_id=step_request.tool_id,
parameters=step_request.parameters,
condition=step_request.condition,
output_mapping=step_request.output_mapping,
error_handling=step_request.error_handling
)
steps.append(step)
# 创建链定义
chain_definition = ChainDefinition(
name=request.name,
description=request.description,
steps=steps,
execution_mode=ChainExecutionMode(request.execution_mode),
global_timeout=request.global_timeout
)
# 注册并执行链
chain_manager.register_chain(chain_definition)
result = await chain_manager.execute_chain(
chain_name=request.name,
initial_variables=request.initial_variables
)
return result
except Exception as e:
logger.error(f"工具链执行失败: {request.name}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/running", response_model=List[Dict[str, Any]])
async def get_running_executions(
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取正在运行的执行"""
try:
running_executions = executor.get_running_executions()
# 过滤当前工作空间的执行
workspace_executions = [
exec_info for exec_info in running_executions
if exec_info.get("workspace_id") == str(current_user.current_workspace_id)
]
return workspace_executions
except Exception as e:
logger.error(f"获取运行中执行失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any])
async def cancel_execution(
execution_id: str = Path(..., description="执行ID"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""取消工具执行"""
try:
success = await executor.cancel_execution(execution_id)
if success:
return {
"success": True,
"message": "执行已取消"
}
else:
raise HTTPException(status_code=404, detail="执行不存在或已完成")
except HTTPException:
raise
except Exception as e:
logger.error(f"取消执行失败: {execution_id}, 错误: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/history", response_model=List[ExecutionHistoryResponse])
async def get_execution_history(
tool_id: Optional[str] = Query(None, description="工具ID过滤"),
limit: int = Query(50, ge=1, le=200, description="返回数量限制"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取执行历史"""
try:
history = executor.get_execution_history(
tool_id=tool_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
limit=limit
)
# 转换响应格式
responses = []
for record in history:
responses.append(ExecutionHistoryResponse(
execution_id=record["execution_id"],
tool_id=record["tool_id"],
status=record["status"],
started_at=record["started_at"],
completed_at=record["completed_at"],
execution_time=record["execution_time"],
user_id=record["user_id"],
workspace_id=record["workspace_id"],
input_data=record["input_data"],
output_data=record["output_data"],
error_message=record["error_message"],
token_usage=record["token_usage"]
))
return responses
except Exception as e:
logger.error(f"获取执行历史失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/statistics", response_model=Dict[str, Any])
async def get_execution_statistics(
days: int = Query(7, ge=1, le=90, description="统计天数"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""获取执行统计"""
try:
stats = executor.get_execution_statistics(
workspace_id=current_user.current_workspace_id,
days=days
)
return {
"success": True,
"statistics": stats
}
except Exception as e:
logger.error(f"获取执行统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains/running", response_model=List[Dict[str, Any]])
async def get_running_chains(
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""获取正在运行的工具链"""
try:
running_chains = chain_manager.get_running_chains()
return running_chains
except Exception as e:
logger.error(f"获取运行中工具链失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/chains", response_model=List[Dict[str, Any]])
async def list_tool_chains(
current_user: User = Depends(get_current_user),
chain_manager: ChainManager = Depends(get_chain_manager)
):
"""列出工具链"""
try:
chains = chain_manager.list_chains()
return chains
except Exception as e:
logger.error(f"获取工具链列表失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse)
async def test_tool_connection(
tool_id: str = Path(..., description="工具ID"),
current_user: User = Depends(get_current_user),
executor: ToolExecutor = Depends(get_tool_executor)
):
"""测试工具连接"""
try:
result = await executor.test_tool_connection(
tool_id=tool_id,
user_id=current_user.id,
workspace_id=current_user.current_workspace_id
)
return ToolConnectionTestResponse(
success=result.get("success", False),
message=result.get("message", ""),
error=result.get("error"),
details=result.get("details")
)
except Exception as e:
logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
return ToolConnectionTestResponse(
success=False,
message="连接测试失败",
error=str(e)
)