From 04be3088a270bccc74745a4d83a1e851b1305a3d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Thu, 25 Dec 2025 17:30:20 +0800 Subject: [PATCH] feat(tool system): Tool system reengineering --- api/app/controllers/__init__.py | 2 - api/app/controllers/tool_controller.py | 747 ++++--------- .../controllers/tool_execution_controller.py | 430 -------- api/app/core/tools/__init__.py | 10 +- api/app/core/tools/base.py | 108 +- api/app/core/tools/builtin/base.py | 3 +- api/app/core/tools/builtin/datetime_tool.py | 48 +- api/app/core/tools/builtin/json_tool.py | 53 +- api/app/core/tools/builtin/textin_tool.py | 20 +- api/app/core/tools/chain_manager.py | 485 --------- api/app/core/tools/config_manager.py | 264 ----- api/app/core/tools/configs/builtin_tools.json | 3 +- api/app/core/tools/custom/auth_manager.py | 39 +- api/app/core/tools/custom/base.py | 14 +- api/app/core/tools/custom/schema_parser.py | 42 +- api/app/core/tools/executor.py | 501 --------- api/app/core/tools/mcp/base.py | 111 +- api/app/core/tools/mcp/client.py | 46 +- api/app/core/tools/mcp/service_manager.py | 9 +- api/app/core/tools/registry.py | 436 -------- api/app/models/tool_model.py | 74 +- api/app/repositories/tool_repository.py | 157 +++ api/app/schemas/tool_schema.py | 259 +++++ api/app/services/tool_service.py | 977 ++++++++++++++++++ api/test_tool_system.py | 374 ------- 25 files changed, 1887 insertions(+), 3325 deletions(-) delete mode 100644 api/app/controllers/tool_execution_controller.py delete mode 100644 api/app/core/tools/chain_manager.py delete mode 100644 api/app/core/tools/config_manager.py delete mode 100644 api/app/core/tools/executor.py delete mode 100644 api/app/core/tools/registry.py create mode 100644 api/app/repositories/tool_repository.py create mode 100644 api/app/schemas/tool_schema.py create mode 100644 api/app/services/tool_service.py delete mode 100644 api/test_tool_system.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index c72072eb..13e66ea7 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -33,7 +33,6 @@ from . import ( emotion_config_controller, prompt_optimizer_controller, tool_controller, - tool_execution_controller, ) from . import user_memory_controllers @@ -71,6 +70,5 @@ manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(tool_controller.router) -manager_router.include_router(tool_execution_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py index 433392d2..dc304c50 100644 --- a/api/app/controllers/tool_controller.py +++ b/api/app/controllers/tool_controller.py @@ -1,585 +1,250 @@ -"""工具管理API控制器""" -import base64 -from typing import List, Optional, Dict, Any +"""工具控制器 - 简化统一的工具管理接口""" +from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Body -from langfuse.api.core import jsonable_encoder -from sqlalchemy.exc import SQLAlchemyError +from fastapi import APIRouter, Depends, HTTPException, Query from sqlalchemy.orm import Session -from pydantic import BaseModel, Field, PositiveInt, field_validator -from cryptography.fernet import Fernet +from app.schemas.tool_schema import ( + ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest +) +from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models import User -from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig -from app.core.logging_config import get_business_logger -from app.core.config import settings -from app.core.tools.config_manager import ConfigManager +from app.models.tool_model import ToolType, ToolStatus, AuthType +from app.services.tool_service import ToolService +from app.schemas.response_schema import ApiResponse -logger = get_business_logger() - -router = APIRouter(prefix="/tools", tags=["工具管理"]) +router = APIRouter(prefix="/tools", tags=["Tool System"]) -# ==================== 辅助函数 ==================== +def get_tool_service(db: Session = Depends(get_db)) -> ToolService: + return ToolService(db) -def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: - """加密敏感参数""" - cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) - cipher = Fernet(cipher_key) - - encrypted_params = {} - sensitive_keys = ['api_key', 'token', 'api_secret', 'password'] - - for key, value in parameters.items(): - if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: - encrypted_params[key] = cipher.encrypt(str(value).encode()).decode() - else: - encrypted_params[key] = value - - return encrypted_params +@router.get("/statistics", response_model=ApiResponse) +async def get_tool_statistics( + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """获取工具统计信息""" + try: + stats = service.get_tool_statistics(current_user.tenant_id) + return success(data=stats, msg="获取统计信息成功") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) -def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: - """解密敏感参数""" - cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) - cipher = Fernet(cipher_key) - - decrypted_params = {} - sensitive_keys = ['api_key', 'token', 'secret', 'password'] - - for key, value in parameters.items(): - if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: - try: - decrypted_params[key] = cipher.decrypt(value.encode()).decode() - except Exception as e: - decrypted_params[key] = value - else: - decrypted_params[key] = value - - return decrypted_params - - -def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str: - """更新工具状态并返回新状态""" - if tool_config.tool_type == ToolType.BUILTIN: - if not tool_info or not tool_info.get('requires_config', False): - new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具 - elif not builtin_config or not builtin_config.parameters: - new_status = ToolStatus.INACTIVE.value - else: - # 检查是否有必要的API密钥 - has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) - new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value - else: # 自定义和MCP工具 - new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value - - # 更新数据库中的状态 - if tool_config.status != new_status: - tool_config.status = new_status - - return new_status - - -# ==================== 请求/响应模型 ==================== - -class ToolListResponse(BaseModel): - """工具列表响应""" - id: str - name: str - description: str - tool_type: str - category: str - version: str = "1.0.0" - status: str # active inactive error loading - requires_config: bool = False - # is_configured: bool = False - - class Config: - from_attributes = True - -class BuiltinToolConfigRequest(BaseModel): - """内置工具配置请求""" - parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") - - -class CustomToolCreateRequest(BaseModel): - """自定义工具创建请求体模型,包含参数校验规则""" - name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填") - description: str = Field(None, description="工具描述") - base_url: str = Field(None, description="工具基础URL") - schema_url: str = Field(None, description="工具Schema URL") - schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选") - auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型") - auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典") - timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30") - - # 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段 - @field_validator("auth_config") - def validate_auth_config(cls, v, values): - auth_type = values.data.get("auth_type") - if auth_type == "api_key" and (not v or "api_key" not in v): - raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段") - if auth_type == "bearer_token" and (not v or "bearer_token" not in v): - raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段") - return v - -class MCPToolCreateRequest(BaseModel): - """MCP工具创建请求体模型,适配MCP业务特性""" - # 基础必填字段(带长度/格式校验) - name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称") - description: str = Field(None, description="MCP工具描述") - # MCP核心字段:服务端URL(强制HTTP/HTTPS格式) - server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议") - # 连接配置:默认空字典,可自定义校验规则(根据实际业务调整) - connection_config: Dict[str, Any] = Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典") - - @field_validator("connection_config") - def validate_connection_config(cls, v): - # 示例1:若包含timeout,必须是1-300的整数 - if "timeout" in v: - timeout = v["timeout"] - if not isinstance(timeout, int) or timeout < 1 or timeout > 300: - raise ValueError("connection_config.timeout必须是1-300的整数") - return v - - # @field_validator("server_url") - # def validate_server_url_protocol(cls, v): - # if v.scheme != "https": - # raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)") - # return v - - -# ==================== API端点 ==================== -@router.get("", response_model=List[ToolListResponse]) +@router.get("", response_model=ApiResponse) async def list_tools( - name: Optional[str] = None, - tool_type: Optional[str] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + name: Optional[str] = Query(None), + tool_type: Optional[str] = Query(None), + status: Optional[str] = Query(None), + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) ): - """获取工具列表(包含内置工具、自定义工具和MCP工具)""" + """获取工具列表""" try: - # 初始化内置工具(如果需要) - config_manager = ConfigManager() - config_manager.ensure_builtin_tools_initialized( - current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus + # 确保内置工具已初始化 + service.ensure_builtin_tools_initialized(current_user.tenant_id) + + # 获取工具列表 + tools = service.list_tools( + tenant_id=current_user.tenant_id, + name=name, + tool_type=ToolType(tool_type) if tool_type else None, + status=ToolStatus(status) if status else None ) + return success(data=tools, msg="获取工具列表成功") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) - response_tools = [] - query = db.query(ToolConfig).filter( - ToolConfig.tenant_id == current_user.tenant_id +@router.get("/{tool_id}", response_model=ApiResponse) +async def get_tool( + tool_id: str, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """获取工具详情""" + tool = service.get_tool_info(tool_id, current_user.tenant_id) + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + return success(data=tool, msg="获取工具详情成功") + + +@router.post("", response_model=ApiResponse) +async def create_tool( + request: ToolCreateRequest, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """创建工具""" + try: + tool_id = service.create_tool( + name=request.name, + tool_type=request.tool_type, + tenant_id=current_user.tenant_id, + icon=request.icon, + description=request.description, + config=request.config ) - if tool_type: - query = query.filter(ToolConfig.tool_type == tool_type) - - if name: - query = query.filter(ToolConfig.name.ilike(f"%{name}%")) - - tools = query.all() - builtin_tools = config_manager.load_builtin_tools_config() - configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} - - for tool_config in tools: - if tool_config.tool_type == ToolType.BUILTIN.value: - builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() - tool_info = configured_tools.get(builtin_config.tool_class) - status = _update_tool_status(tool_config, builtin_config, tool_info) - else: - status = _update_tool_status(tool_config) - - response_tools.append(ToolListResponse( - id=str(tool_config.id), - name=tool_config.name, - description=tool_config.description, - tool_type=tool_config.tool_type, - category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type, - version="1.0.0", - status=status, - requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False, - )) - - return response_tools + return success(data={"tool_id": tool_id}, msg="工具创建成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) except Exception as e: - logger.error(f"获取工具列表失败: {e}") raise HTTPException(status_code=500, detail=str(e)) -@router.get("/builtin/{tool_id}") -async def get_builtin_tool_detail( - tool_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) +@router.put("/{tool_id}", response_model=ApiResponse) +async def update_tool( + tool_id: str, + request: ToolUpdateRequest, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) ): - """获取内置工具详情""" + """更新工具""" try: - config_manager = ConfigManager() - builtin_tools = config_manager.load_builtin_tools_config() - configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} - tool_config = db.query(ToolConfig).filter( - ToolConfig.tenant_id == current_user.tenant_id, - ToolConfig.id == tool_id - ).first() - builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() - tool_info = configured_tools.get(builtin_config.tool_class) - - is_configured = False - config_parameters = {} - - if builtin_config and builtin_config.parameters: - is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) - # 不返回敏感信息,只返回非敏感配置 - config_parameters = {k: v for k, v in builtin_config.parameters.items() - if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])} - - return { - "id": tool_config.id, - "name": tool_config.name, - "description": tool_config.description, - "category": tool_info['category'], - "status": tool_config.tool_type, - "requires_config": tool_info['requires_config'], - "is_configured": is_configured, - "config_parameters": config_parameters - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/builtin/{tool_id}/configure") -async def configure_builtin_tool( - tool_id: str, - request: BuiltinToolConfigRequest = Body(...), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """配置内置工具参数(租户级别)""" - try: - # 查询工具配置 - tool_config = db.query(ToolConfig).filter( - ToolConfig.tenant_id == current_user.tenant_id, - ToolConfig.id == tool_id, - ToolConfig.tool_type == ToolType.BUILTIN - ).first() - - if not tool_config: - raise HTTPException(status_code=404, detail="工具不存在") - - # 获取内置工具配置 - builtin_config = db.query(BuiltinToolConfig).filter( - BuiltinToolConfig.id == tool_config.id - ).first() - - if not builtin_config: - raise HTTPException(status_code=404, detail="内置工具配置不存在") - - # 获取全局工具信息 - config_manager = ConfigManager() - builtin_tools_config = config_manager.load_builtin_tools_config() - tool_info = None - for tool_key, info in builtin_tools_config.items(): - if info['tool_class'] == builtin_config.tool_class: - tool_info = info - break - - if not tool_info: - raise HTTPException(status_code=404, detail="工具信息不存在") - - # 加密敏感参数 - encrypted_params = _encrypt_sensitive_params(request.parameters) - - # 更新配置 - builtin_config.parameters = encrypted_params - - # 更新状态 - _update_tool_status(tool_config, builtin_config, tool_info) - - db.commit() - - return { - "success": True, - "message": f"工具 {tool_config.name} 配置成功" - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"配置内置工具失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/builtin/{tool_id}/config") -async def get_builtin_tool_config( - tool_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """获取内置工具配置(用于使用)""" - try: - # 查询工具配置 - tool_config = db.query(ToolConfig).filter( - ToolConfig.tenant_id == current_user.tenant_id, - ToolConfig.id == tool_id, - ToolConfig.tool_type == ToolType.BUILTIN - ).first() - - if not tool_config: - raise HTTPException(status_code=404, detail="工具不存在") - - # 获取内置工具配置 - builtin_config = db.query(BuiltinToolConfig).filter( - BuiltinToolConfig.id == tool_config.id - ).first() - - if not builtin_config: - raise HTTPException(status_code=404, detail="内置工具配置不存在") - - # 解密参数 - decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {}) - - return { - "tool_id": tool_id, - "tool_class": builtin_config.tool_class, - "name": tool_config.name, - "parameters": decrypted_params, - "status": tool_config.status - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"获取工具配置失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/custom") -async def create_custom_tool( - request: CustomToolCreateRequest = Body(...), - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """创建自定义工具""" - try: - config_data = jsonable_encoder(request.model_dump()) - config_data["tool_type"] = "custom" - - config_manager = ConfigManager() - is_valid, error_msg = config_manager.validate_config(config_data, "custom") - if not is_valid: - raise HTTPException(status_code=400, detail=error_msg) - - # 创建数据库记录 - tool_config = ToolConfig( + success_flag = service.update_tool( + tool_id=tool_id, + tenant_id=current_user.tenant_id, name=request.name, description=request.description, - tool_type=ToolType.CUSTOM, - tenant_id=current_user.tenant_id, - status=ToolStatus.ACTIVE.value, - config_data=config_data + icon=request.icon, + config=request.config, + is_enabled=request.config.get("is_enabled", None) ) - db.add(tool_config) - db.flush() + if not success_flag: + raise HTTPException(status_code=404, detail="工具不存在") + return success(msg="工具更新成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) - # 创建CustomToolConfig记录 - custom_config = CustomToolConfig( - id=tool_config.id, - base_url=request.base_url, - schema_url=request.schema_url, - schema_content=request.schema_content, - auth_type=request.auth_type, - auth_config=request.auth_config, + +@router.delete("/{tool_id}", response_model=ApiResponse) +async def delete_tool( + tool_id: str, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """删除工具""" + try: + success_flag = service.delete_tool(tool_id, current_user.tenant_id) + if not success_flag: + raise HTTPException(status_code=404, detail="工具不存在") + return success(msg="工具删除成功") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/execution/execute", response_model=ApiResponse) +async def execute_tool( + request: ToolExecuteRequest, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """执行工具""" + try: + result = await service.execute_tool( + tool_id=request.tool_id, + parameters=request.parameters, + tenant_id=current_user.tenant_id, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, timeout=request.timeout ) - db.add(custom_config) - db.commit() - - return { - "success": True, - "message": f"自定义工具 {request.name} 创建成功", - "tool_id": str(tool_config.id) - } - - except HTTPException: - raise + return success( + data={ + "success": result.success, + "data": result.data, + "error": result.error, + "execution_time": result.execution_time, + "token_usage": result.token_usage + }, + msg="工具执行完成" + ) except Exception as e: - logger.error(f"创建自定义工具失败: {e}") raise HTTPException(status_code=500, detail=str(e)) -@router.post("/mcp") -async def create_mcp_tool( - request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"), + +@router.post("/parse_schema", response_model=ApiResponse) +async def parse_openapi_schema( + request: ParseSchemaRequest, current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) + service: ToolService = Depends(get_tool_service) ): - """创建MCP工具""" + """解析OpenAPI schema""" try: - config_data = jsonable_encoder(request.model_dump()) - config_data["tool_type"] = "mcp" + result = await service.parse_openapi_schema(request.schema_content, request.schema_url) + if result["success"] is False: + raise HTTPException(status_code=400, detail=result["message"]) + return success(data=result, msg="Schema解析完成") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) - config_manager = ConfigManager() - is_valid, error_msg = config_manager.validate_config(config_data, "mcp") - if not is_valid: - raise HTTPException(status_code=400, detail=error_msg) +@router.post("/{tool_id}/sync_mcp_tools", response_model=ApiResponse) +async def sync_mcp_tools( + tool_id: str, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """同步MCP工具列表""" + try: + result = await service.sync_mcp_tools(tool_id, current_user.tenant_id) + if result["success"] is False: + raise HTTPException(status_code=404, detail=result["message"]) + return success(data=result, msg="MCP工具列表同步完成") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) - # 创建数据库记录 - try: - tool_config = ToolConfig( - name=request.name, - description=request.description, - tool_type=ToolType.MCP, - tenant_id=current_user.tenant_id, - status=ToolStatus.ACTIVE.value, - config_data=config_data +@router.post("/{tool_id}/test", response_model=ApiResponse) +async def test_tool_connection( + tool_id: str, + test_request: Optional[CustomToolTestRequest] = None, + current_user: User = Depends(get_current_user), + service: ToolService = Depends(get_tool_service) +): + """测试工具连接""" + try: + if test_request: + # 自定义工具测试 + result = await service.test_custom_tool( + tool_id, current_user.tenant_id, + test_request.method, test_request.path, test_request.parameters ) - db.add(tool_config) - db.flush() - - # 创建MCPToolConfig记录 - mcp_config = MCPToolConfig( - id=tool_config.id, - server_url=request.server_url, - connection_config=request.connection_config - ) - db.add(mcp_config) - - db.commit() - except SQLAlchemyError as db_e: - db.rollback() - logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}", - exc_info=True) - raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id}," - f"工具名:{request.name}):{str(db_e)}") - - return { - "success": True, - "message": f"MCP工具 {request.name} 创建成功", - "tool_id": str(tool_config.id) - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"创建MCP工具失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - -@router.delete("/{tool_id}") -async def delete_tool( - tool_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """删除工具(仅限自定义和MCP工具)""" - try: - tool = db.query(ToolConfig).filter( - ToolConfig.id == tool_id, - ToolConfig.tenant_id == current_user.tenant_id - ).first() - - if not tool: - raise HTTPException(status_code=404, detail="工具不存在") - - if tool.tool_type == ToolType.BUILTIN: - raise HTTPException(status_code=403, detail="内置工具不允许删除") - - db.delete(tool) - db.commit() - - return { - "success": True, - "message": f"工具 {tool.name} 删除成功" - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"删除工具失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.put("/{tool_id}") -async def update_tool( - tool_id: str, - config_data: Optional[Dict[str, Any]] = None, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """更新工具(仅限自定义和MCP工具)""" - try: - tool = db.query(ToolConfig).filter( - ToolConfig.id == tool_id, - ToolConfig.tenant_id == current_user.tenant_id - ).first() - - if not tool: - raise HTTPException(status_code=404, detail="工具不存在") - - if tool.tool_type == ToolType.BUILTIN: - raise HTTPException(status_code=403, detail="内置工具不允许修改") - - if config_data is not None: - tool.config_data = config_data - # 更新状态 - _update_tool_status(tool) - - db.commit() - db.refresh(tool) - - return { - "success": True, - "message": f"工具 {tool.name} 更新成功", - "status": tool.status - } - - except HTTPException: - raise - except Exception as e: - logger.error(f"更新工具失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/{tool_id}/toggle") -async def toggle_tool_status( - tool_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db) -): - """切换工具活跃/非活跃状态""" - try: - tool = db.query(ToolConfig).filter( - ToolConfig.id == tool_id, - ToolConfig.tenant_id == current_user.tenant_id - ).first() - - if not tool: - raise HTTPException(status_code=404, detail="工具不存在") - - # 在active和inactive之间切换 - if tool.status == ToolStatus.ACTIVE.value: - tool.status = ToolStatus.INACTIVE.value - elif tool.status == ToolStatus.INACTIVE.value: - tool.status = ToolStatus.ACTIVE.value else: - raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换") - - db.commit() - db.refresh(tool) - - return { - "success": True, - "message": f"工具 {tool.name} 状态已更新为 {tool.status}", - "status": tool.status - } - - except HTTPException: - raise + # 普通连接测试 + result = await service.test_connection(tool_id, current_user.tenant_id) + return success(data=result, msg="连接测试完成") except Exception as e: - logger.error(f"切换工具状态失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/enums/tool_types", response_model=ApiResponse) +async def get_tool_types(): + """获取工具类型枚举""" + return success( + data=[ + {"value": ToolType.BUILTIN.value, "label": "内置工具"}, + {"value": ToolType.CUSTOM.value, "label": "自定义工具"}, + {"value": ToolType.MCP.value, "label": "MCP工具"} + ], + msg="获取工具类型成功" + ) + + +@router.get("/enums/status", response_model=ApiResponse) +async def get_tool_status(): + """获取工具状态枚举""" + return success(data=ToolStatus.get_all_statuses_with_labels(), msg="获取工具状态成功") + + +@router.get("/auth/types", response_model=ApiResponse) +async def get_auth_types(): + """获取认证类型枚举""" + return success(data=AuthType.get_all_types_with_labels(), msg="获取认证类型成功") diff --git a/api/app/controllers/tool_execution_controller.py b/api/app/controllers/tool_execution_controller.py deleted file mode 100644 index 486eb7cf..00000000 --- a/api/app/controllers/tool_execution_controller.py +++ /dev/null @@ -1,430 +0,0 @@ -"""工具执行API控制器""" -import uuid -from typing import Dict, Any, List, Optional -from fastapi import APIRouter, Depends, HTTPException, Path, Query -from sqlalchemy.orm import Session -from pydantic import BaseModel, Field - -from app.db import get_db -from app.dependencies import get_current_user -from app.models import User -from app.core.tools.registry import ToolRegistry -from app.core.tools.executor import ToolExecutor -from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode -from app.core.tools.builtin import * -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - -router = APIRouter(prefix="/tools/execution", tags=["工具执行"]) - - -# ==================== 请求/响应模型 ==================== - -class ToolExecutionRequest(BaseModel): - """工具执行请求""" - tool_id: str = Field(..., description="工具ID") - parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") - timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)") - metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据") - - -class BatchExecutionRequest(BaseModel): - """批量执行请求""" - executions: List[ToolExecutionRequest] = Field(..., description="执行列表") - max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数") - - -class ToolExecutionResponse(BaseModel): - """工具执行响应""" - success: bool - execution_id: str - tool_id: str - data: Any = None - error: Optional[str] = None - error_code: Optional[str] = None - execution_time: float - token_usage: Optional[Dict[str, int]] = None - metadata: Dict[str, Any] = Field(default_factory=dict) - - -class ChainStepRequest(BaseModel): - """链步骤请求""" - tool_id: str = Field(..., description="工具ID") - parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") - condition: Optional[str] = Field(None, description="执行条件") - output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射") - error_handling: str = Field("stop", description="错误处理策略") - - -class ChainExecutionRequest(BaseModel): - """链执行请求""" - name: str = Field(..., description="链名称") - description: str = Field("", description="链描述") - steps: List[ChainStepRequest] = Field(..., description="执行步骤") - execution_mode: str = Field("sequential", description="执行模式") - initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量") - global_timeout: Optional[float] = Field(None, description="全局超时") - - -class ExecutionHistoryResponse(BaseModel): - """执行历史响应""" - execution_id: str - tool_id: str - status: str - started_at: Optional[str] - completed_at: Optional[str] - execution_time: Optional[float] - user_id: Optional[str] - workspace_id: Optional[str] - input_data: Optional[Dict[str, Any]] - output_data: Optional[Any] - error_message: Optional[str] - token_usage: Optional[Dict[str, int]] - - -class ToolConnectionTestResponse(BaseModel): - """工具连接测试响应""" - success: bool - message: str - error: Optional[str] = None - details: Optional[Dict[str, Any]] = None - - -# ==================== 依赖注入 ==================== - -def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry: - """获取工具注册表""" - registry = ToolRegistry(db) - - # 注册内置工具类 - registry.register_tool_class(DateTimeTool) - registry.register_tool_class(JsonTool) - registry.register_tool_class(BaiduSearchTool) - registry.register_tool_class(MinerUTool) - registry.register_tool_class(TextInTool) - - return registry - - -def get_tool_executor( - db: Session = Depends(get_db), - registry: ToolRegistry = Depends(get_tool_registry) -) -> ToolExecutor: - """获取工具执行器""" - return ToolExecutor(db, registry) - - -def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager: - """获取链管理器""" - return ChainManager(executor) - - -# ==================== API端点 ==================== - -@router.post("/execute", response_model=ToolExecutionResponse) -async def execute_tool( - request: ToolExecutionRequest, - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """执行单个工具""" - try: - # 生成执行ID - execution_id = f"exec_{uuid.uuid4().hex[:16]}" - - # 执行工具 - result = await executor.execute_tool( - tool_id=request.tool_id, - parameters=request.parameters, - user_id=current_user.id, - workspace_id=current_user.current_workspace_id, - execution_id=execution_id, - timeout=request.timeout, - metadata=request.metadata - ) - - return ToolExecutionResponse( - success=result.success, - execution_id=execution_id, - tool_id=request.tool_id, - data=result.data, - error=result.error, - error_code=result.error_code, - execution_time=result.execution_time, - token_usage=result.token_usage, - metadata=result.metadata - ) - - except Exception as e: - logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/batch", response_model=List[ToolExecutionResponse]) -async def execute_tools_batch( - request: BatchExecutionRequest, - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """批量执行工具""" - try: - # 准备执行配置 - execution_configs = [] - execution_ids = [] - - for exec_request in request.executions: - execution_id = f"exec_{uuid.uuid4().hex[:16]}" - execution_ids.append(execution_id) - - execution_configs.append({ - "tool_id": exec_request.tool_id, - "parameters": exec_request.parameters, - "user_id": current_user.id, - "workspace_id": current_user.current_workspace_id, - "execution_id": execution_id, - "timeout": exec_request.timeout, - "metadata": exec_request.metadata - }) - - # 批量执行 - results = await executor.execute_tools_batch( - execution_configs, - max_concurrency=request.max_concurrency - ) - - # 转换响应格式 - responses = [] - for i, result in enumerate(results): - responses.append(ToolExecutionResponse( - success=result.success, - execution_id=execution_ids[i], - tool_id=request.executions[i].tool_id, - data=result.data, - error=result.error, - error_code=result.error_code, - execution_time=result.execution_time, - token_usage=result.token_usage, - metadata=result.metadata - )) - - return responses - - except Exception as e: - logger.error(f"批量执行失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/chain", response_model=Dict[str, Any]) -async def execute_tool_chain( - request: ChainExecutionRequest, - current_user: User = Depends(get_current_user), - chain_manager: ChainManager = Depends(get_chain_manager) -): - """执行工具链""" - try: - # 转换步骤格式 - steps = [] - for step_request in request.steps: - step = ChainStep( - tool_id=step_request.tool_id, - parameters=step_request.parameters, - condition=step_request.condition, - output_mapping=step_request.output_mapping, - error_handling=step_request.error_handling - ) - steps.append(step) - - # 创建链定义 - chain_definition = ChainDefinition( - name=request.name, - description=request.description, - steps=steps, - execution_mode=ChainExecutionMode(request.execution_mode), - global_timeout=request.global_timeout - ) - - # 注册并执行链 - chain_manager.register_chain(chain_definition) - - result = await chain_manager.execute_chain( - chain_name=request.name, - initial_variables=request.initial_variables - ) - - return result - - except Exception as e: - logger.error(f"工具链执行失败: {request.name}, 错误: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/running", response_model=List[Dict[str, Any]]) -async def get_running_executions( - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """获取正在运行的执行""" - try: - running_executions = executor.get_running_executions() - - # 过滤当前工作空间的执行 - workspace_executions = [ - exec_info for exec_info in running_executions - if exec_info.get("workspace_id") == str(current_user.current_workspace_id) - ] - - return workspace_executions - - except Exception as e: - logger.error(f"获取运行中执行失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any]) -async def cancel_execution( - execution_id: str = Path(..., description="执行ID"), - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """取消工具执行""" - try: - success = await executor.cancel_execution(execution_id) - - if success: - return { - "success": True, - "message": "执行已取消" - } - else: - raise HTTPException(status_code=404, detail="执行不存在或已完成") - - except HTTPException: - raise - except Exception as e: - logger.error(f"取消执行失败: {execution_id}, 错误: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/history", response_model=List[ExecutionHistoryResponse]) -async def get_execution_history( - tool_id: Optional[str] = Query(None, description="工具ID过滤"), - limit: int = Query(50, ge=1, le=200, description="返回数量限制"), - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """获取执行历史""" - try: - history = executor.get_execution_history( - tool_id=tool_id, - user_id=current_user.id, - workspace_id=current_user.current_workspace_id, - limit=limit - ) - - # 转换响应格式 - responses = [] - for record in history: - responses.append(ExecutionHistoryResponse( - execution_id=record["execution_id"], - tool_id=record["tool_id"], - status=record["status"], - started_at=record["started_at"], - completed_at=record["completed_at"], - execution_time=record["execution_time"], - user_id=record["user_id"], - workspace_id=record["workspace_id"], - input_data=record["input_data"], - output_data=record["output_data"], - error_message=record["error_message"], - token_usage=record["token_usage"] - )) - - return responses - - except Exception as e: - logger.error(f"获取执行历史失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/statistics", response_model=Dict[str, Any]) -async def get_execution_statistics( - days: int = Query(7, ge=1, le=90, description="统计天数"), - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """获取执行统计""" - try: - stats = executor.get_execution_statistics( - workspace_id=current_user.current_workspace_id, - days=days - ) - - return { - "success": True, - "statistics": stats - } - - except Exception as e: - logger.error(f"获取执行统计失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/chains/running", response_model=List[Dict[str, Any]]) -async def get_running_chains( - current_user: User = Depends(get_current_user), - chain_manager: ChainManager = Depends(get_chain_manager) -): - """获取正在运行的工具链""" - try: - running_chains = chain_manager.get_running_chains() - return running_chains - - except Exception as e: - logger.error(f"获取运行中工具链失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.get("/chains", response_model=List[Dict[str, Any]]) -async def list_tool_chains( - current_user: User = Depends(get_current_user), - chain_manager: ChainManager = Depends(get_chain_manager) -): - """列出工具链""" - try: - chains = chain_manager.list_chains() - return chains - - except Exception as e: - logger.error(f"获取工具链列表失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse) -async def test_tool_connection( - tool_id: str = Path(..., description="工具ID"), - current_user: User = Depends(get_current_user), - executor: ToolExecutor = Depends(get_tool_executor) -): - """测试工具连接""" - try: - result = await executor.test_tool_connection( - tool_id=tool_id, - user_id=current_user.id, - workspace_id=current_user.current_workspace_id - ) - - return ToolConnectionTestResponse( - success=result.get("success", False), - message=result.get("message", ""), - error=result.get("error"), - details=result.get("details") - ) - - except Exception as e: - logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}") - return ToolConnectionTestResponse( - success=False, - message="连接测试失败", - error=str(e) - ) \ No newline at end of file diff --git a/api/app/core/tools/__init__.py b/api/app/core/tools/__init__.py index 109bac13..714dc851 100644 --- a/api/app/core/tools/__init__.py +++ b/api/app/core/tools/__init__.py @@ -1,11 +1,7 @@ """工具管理核心模块""" from .base import BaseTool, ToolResult, ToolParameter -from .registry import ToolRegistry -from .executor import ToolExecutor from .langchain_adapter import LangchainAdapter -from .config_manager import ConfigManager -from .chain_manager import ChainManager # 可选导入,避免导入错误 try: @@ -22,11 +18,7 @@ __all__ = [ "BaseTool", "ToolResult", "ToolParameter", - "ToolRegistry", - "ToolExecutor", - "LangchainAdapter", - "ConfigManager", - "ChainManager" + "LangchainAdapter" ] # 只有在成功导入时才添加到__all__ diff --git a/api/app/core/tools/base.py b/api/app/core/tools/base.py index d674af76..c9771ef0 100644 --- a/api/app/core/tools/base.py +++ b/api/app/core/tools/base.py @@ -1,98 +1,10 @@ """工具基础接口定义""" import time from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union -from pydantic import BaseModel, Field -from enum import Enum +from typing import Any, Dict, List, Optional from app.models.tool_model import ToolType, ToolStatus - - -class ParameterType(str, Enum): - """参数类型枚举""" - STRING = "string" - INTEGER = "integer" - NUMBER = "number" - BOOLEAN = "boolean" - ARRAY = "array" - OBJECT = "object" - - -class ToolParameter(BaseModel): - """工具参数定义""" - name: str = Field(..., description="参数名称") - type: ParameterType = Field(..., description="参数类型") - description: str = Field("", description="参数描述") - required: bool = Field(False, description="是否必需") - default: Any = Field(None, description="默认值") - enum: Optional[List[Any]] = Field(None, description="枚举值") - minimum: Optional[Union[int, float]] = Field(None, description="最小值") - maximum: Optional[Union[int, float]] = Field(None, description="最大值") - pattern: Optional[str] = Field(None, description="正则表达式模式") - - class Config: - use_enum_values = True - - -class ToolResult(BaseModel): - """工具执行结果""" - success: bool = Field(..., description="执行是否成功") - data: Any = Field(None, description="返回数据") - error: Optional[str] = Field(None, description="错误信息") - error_code: Optional[str] = Field(None, description="错误代码") - execution_time: float = Field(..., description="执行时间(秒)") - token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况") - metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据") - - @classmethod - def success_result( - cls, - data: Any, - execution_time: float, - token_usage: Optional[Dict[str, int]] = None, - metadata: Optional[Dict[str, Any]] = None - ) -> "ToolResult": - """创建成功结果""" - return cls( - success=True, - data=data, - execution_time=execution_time, - token_usage=token_usage, - metadata=metadata or {} - ) - - @classmethod - def error_result( - cls, - error: str, - execution_time: float, - error_code: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None - ) -> "ToolResult": - """创建错误结果""" - return cls( - success=False, - error=error, - error_code=error_code, - execution_time=execution_time, - metadata=metadata or {} - ) - - -class ToolInfo(BaseModel): - """工具信息""" - id: str = Field(..., description="工具ID") - name: str = Field(..., description="工具名称") - description: str = Field(..., description="工具描述") - tool_type: ToolType = Field(..., description="工具类型") - version: str = Field("1.0.0", description="工具版本") - parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") - status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态") - tags: List[str] = Field(default_factory=list, description="工具标签") - tenant_id: Optional[str] = Field(None, description="租户ID") - - class Config: - use_enum_values = True +from app.schemas.tool_schema import ToolParameter, ParameterType, ToolResult class BaseTool(ABC): @@ -107,7 +19,7 @@ class BaseTool(ABC): """ self.tool_id = tool_id self.config = config - self._status = ToolStatus.ACTIVE + self._status = ToolStatus.AVAILABLE @property @abstractmethod @@ -153,20 +65,6 @@ class BaseTool(ABC): """工具标签""" return self.config.get("tags", []) - def get_info(self) -> ToolInfo: - """获取工具信息""" - return ToolInfo( - id=self.tool_id, - name=self.name, - description=self.description, - tool_type=self.tool_type, - version=self.version, - parameters=self.parameters, - status=self.status, - tags=self.tags, - tenant_id=self.config.get("tenant_id") - ) - def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]: """验证参数 diff --git a/api/app/core/tools/builtin/base.py b/api/app/core/tools/builtin/base.py index 532d0869..781b5ffc 100644 --- a/api/app/core/tools/builtin/base.py +++ b/api/app/core/tools/builtin/base.py @@ -3,7 +3,8 @@ from abc import ABC, abstractmethod from typing import Dict, Any, List from app.models.tool_model import ToolType -from app.core.tools.base import BaseTool, ToolResult, ToolParameter +from app.core.tools.base import BaseTool +from app.schemas.tool_schema import ToolResult, ToolParameter class BuiltinTool(BaseTool, ABC): diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py index 475ce7be..9e5ab9f6 100644 --- a/api/app/core/tools/builtin/datetime_tool.py +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -4,7 +4,7 @@ from datetime import datetime, timezone, timedelta from typing import List import pytz -from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType from .base import BuiltinTool @@ -54,14 +54,14 @@ class DateTimeTool(BuiltinTool): type=ParameterType.STRING, description="源时区(如:UTC, Asia/Shanghai)", required=False, - default="UTC" + default="Asia/Shanghai" ), ToolParameter( name="to_timezone", type=ParameterType.STRING, description="目标时区(如:UTC, Asia/Shanghai)", required=False, - default="UTC" + default="Asia/Shanghai" ), ToolParameter( name="calculation", @@ -106,10 +106,11 @@ class DateTimeTool(BuiltinTool): error_code="DATETIME_ERROR", execution_time=execution_time ) - - def _get_current_time(self, kwargs) -> dict: + + @staticmethod + def _get_current_time(kwargs) -> dict: """获取当前时间""" - timezone_str = kwargs.get("to_timezone", "UTC") + timezone_str = kwargs.get("to_timezone", "Asia/Shanghai") output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") if timezone_str == "UTC": @@ -118,15 +119,20 @@ class DateTimeTool(BuiltinTool): tz = pytz.timezone(timezone_str) now = datetime.now(tz) + + utc_now = datetime.now(timezone.utc) return { "datetime": now.strftime(output_format), "timestamp": int(now.timestamp()), "timezone": timezone_str, - "iso_format": now.isoformat() + "iso_format": now.isoformat(), + "timestamp_ms": int(now.timestamp() * 1000), + "utc_datetime": utc_now.strftime(output_format) } - - def _format_datetime(self, kwargs) -> dict: + + @staticmethod + def _format_datetime(kwargs) -> dict: """格式化时间""" input_value = kwargs.get("input_value") input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") @@ -144,8 +150,9 @@ class DateTimeTool(BuiltinTool): "timestamp": int(dt.timestamp()), "iso_format": dt.isoformat() } - - def _convert_timezone(self, kwargs) -> dict: + + @staticmethod + def _convert_timezone(kwargs) -> dict: """时区转换""" input_value = kwargs.get("input_value") input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") @@ -184,8 +191,9 @@ class DateTimeTool(BuiltinTool): "converted_timezone": to_timezone, "timestamp": int(converted_dt.timestamp()) } - - def _timestamp_to_datetime(self, kwargs) -> dict: + + @staticmethod + def _timestamp_to_datetime(kwargs) -> dict: """时间戳转日期时间""" input_value = kwargs.get("input_value") output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") @@ -196,6 +204,8 @@ class DateTimeTool(BuiltinTool): # 转换时间戳 timestamp = float(input_value) + if timestamp > 1e12: + timestamp = timestamp / 1000 # 设置时区 if timezone_str == "UTC": @@ -211,8 +221,9 @@ class DateTimeTool(BuiltinTool): "timezone": timezone_str, "iso_format": dt.isoformat() } - - def _datetime_to_timestamp(self, kwargs) -> dict: + + @staticmethod + def _datetime_to_timestamp(kwargs) -> dict: """日期时间转时间戳""" input_value = kwargs.get("input_value") input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") @@ -240,7 +251,7 @@ class DateTimeTool(BuiltinTool): "timestamp": int(dt.timestamp()), "iso_format": dt.isoformat() } - + def _calculate_datetime(self, kwargs) -> dict: """时间计算""" input_value = kwargs.get("input_value") @@ -278,8 +289,9 @@ class DateTimeTool(BuiltinTool): "timezone": timezone_str, "timestamp": int(calculated_dt.timestamp()) } - - def _parse_time_delta(self, calculation: str) -> timedelta: + + @staticmethod + def _parse_time_delta(calculation: str) -> timedelta: """解析时间计算表达式""" import re diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py index 135d252a..d2b73bba 100644 --- a/api/app/core/tools/builtin/json_tool.py +++ b/api/app/core/tools/builtin/json_tool.py @@ -121,8 +121,9 @@ class JsonTool(BuiltinTool): error_code="JSON_ERROR", execution_time=execution_time ) - - def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _format_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: """格式化JSON""" indent = kwargs.get("indent", 2) ensure_ascii = kwargs.get("ensure_ascii", False) @@ -151,12 +152,13 @@ class JsonTool(BuiltinTool): "sort_keys": sort_keys } } - - def _minify_json(self, input_data: str) -> Dict[str, Any]: + + @staticmethod + def _minify_json(input_data: str) -> Dict[str, Any]: """压缩JSON""" # 解析并压缩 data = json.loads(input_data) - minified = json.dumps(data, separators=(',', ':')) + minified = json.dumps(data, ensure_ascii=False, separators=(',', ':')) return { "original_size": len(input_data), @@ -165,7 +167,7 @@ class JsonTool(BuiltinTool): "minified_json": minified, "is_valid": True } - + def _validate_json(self, input_data: str) -> Dict[str, Any]: """验证JSON""" try: @@ -190,17 +192,19 @@ class JsonTool(BuiltinTool): "size": len(input_data) } - def _convert_json(self, input_data: str) -> Dict[str, Any]: + @staticmethod + def _convert_json(input_data: str) -> Dict[str, Any]: """JSON转义""" data = json.loads(input_data) - converted = json.dumps(data, ensure_ascii=False) + converted = json.dumps(data, ensure_ascii=True, separators=(',', ':')) return { "converted_json": converted, "is_valid": True } - - def _json_to_yaml(self, input_data: str) -> Dict[str, Any]: + + @staticmethod + def _json_to_yaml(input_data: str) -> Dict[str, Any]: """JSON转YAML""" data = json.loads(input_data) yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2) @@ -212,8 +216,9 @@ class JsonTool(BuiltinTool): "converted_size": len(yaml_output), "converted_data": yaml_output } - - def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _yaml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: """YAML转JSON""" indent = kwargs.get("indent", 2) ensure_ascii = kwargs.get("ensure_ascii", False) @@ -228,10 +233,11 @@ class JsonTool(BuiltinTool): "converted_size": len(json_output), "converted_data": json_output } - - def _json_to_xml(self, input_data: str) -> Dict[str, Any]: + + @staticmethod + def _json_to_xml(input_data: str) -> Dict[str, Any]: """JSON转XML""" - data = json.loads(input_data) + json_data = json.loads(input_data) def dict_to_xml(data, root_name="root"): """递归转换字典为XML""" @@ -267,7 +273,7 @@ class JsonTool(BuiltinTool): root.text = str(data) return root - xml_element = dict_to_xml(data) + xml_element = dict_to_xml(json_data) xml_string = ET.tostring(xml_element, encoding='unicode') # 格式化XML @@ -284,8 +290,9 @@ class JsonTool(BuiltinTool): "converted_size": len(formatted_xml), "converted_data": formatted_xml } - - def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _xml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: """XML转JSON""" indent = kwargs.get("indent", 2) @@ -328,8 +335,9 @@ class JsonTool(BuiltinTool): "converted_size": len(json_output), "converted_data": json_output } - - def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _merge_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: """合并JSON""" merge_data = kwargs.get("merge_data") if not merge_data: @@ -364,8 +372,9 @@ class JsonTool(BuiltinTool): "result_size": len(merged_json), "merged_data": merged_json } - - def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _extract_json_path( input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: """提取JSON路径""" json_path = kwargs.get("json_path") if not json_path: diff --git a/api/app/core/tools/builtin/textin_tool.py b/api/app/core/tools/builtin/textin_tool.py index ec3e214e..e5218416 100644 --- a/api/app/core/tools/builtin/textin_tool.py +++ b/api/app/core/tools/builtin/textin_tool.py @@ -275,8 +275,9 @@ class TextInTool(BuiltinTool): "total_confidence": result.get("confidence", 0), "processing_time": result.get("processing_time", 0) } - - def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + + @staticmethod + def _format_formula_result( result: Dict[str, Any], output_format: str) -> Dict[str, Any]: """格式化公式识别结果""" formulas = result.get("formulas", []) @@ -288,8 +289,9 @@ class TextInTool(BuiltinTool): "total_confidence": result.get("confidence", 0), "processing_time": result.get("processing_time", 0) } - - def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + + @staticmethod + def _format_table_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]: """格式化表格识别结果""" tables = result.get("tables", []) @@ -301,8 +303,9 @@ class TextInTool(BuiltinTool): "total_confidence": result.get("confidence", 0), "processing_time": result.get("processing_time", 0) } - - def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + + @staticmethod + def _format_document_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]: """格式化文档识别结果""" return { "recognition_mode": "document", @@ -314,8 +317,9 @@ class TextInTool(BuiltinTool): "total_confidence": result.get("confidence", 0), "processing_time": result.get("processing_time", 0) } - - def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + + @staticmethod + def _group_lines_to_paragraphs(lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """将行分组为段落""" paragraphs = [] current_paragraph = [] diff --git a/api/app/core/tools/chain_manager.py b/api/app/core/tools/chain_manager.py deleted file mode 100644 index 713baa39..00000000 --- a/api/app/core/tools/chain_manager.py +++ /dev/null @@ -1,485 +0,0 @@ -"""工具链管理器 - 支持langchain的工具链模式""" -from typing import List, Dict, Any, Optional -from dataclasses import dataclass -from enum import Enum - -from app.core.tools.base import ToolResult -from app.core.tools.executor import ToolExecutor -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ChainExecutionMode(str, Enum): - """链执行模式""" - SEQUENTIAL = "sequential" # 顺序执行 - PARALLEL = "parallel" # 并行执行 - CONDITIONAL = "conditional" # 条件执行 - - -@dataclass -class ChainStep: - """链步骤定义""" - tool_id: str - parameters: Dict[str, Any] - condition: Optional[str] = None # 执行条件 - output_mapping: Optional[Dict[str, str]] = None # 输出映射 - error_handling: str = "stop" # 错误处理:stop, continue, retry - - -@dataclass -class ChainDefinition: - """工具链定义""" - name: str - description: str - steps: List[ChainStep] - execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL - global_timeout: Optional[float] = None - retry_policy: Optional[Dict[str, Any]] = None - - -class ChainExecutionContext: - """链执行上下文""" - - def __init__(self, chain_id: str): - self.chain_id = chain_id - self.variables: Dict[str, Any] = {} - self.step_results: Dict[int, ToolResult] = {} - self.current_step = 0 - self.is_completed = False - self.is_failed = False - self.error_message: Optional[str] = None - - -class ChainManager: - """工具链管理器 - 支持langchain的工具链模式""" - - def __init__(self, executor: ToolExecutor): - """初始化工具链管理器 - - Args: - executor: 工具执行器 - """ - self.executor = executor - self._chains: Dict[str, ChainDefinition] = {} - self._running_chains: Dict[str, ChainExecutionContext] = {} - - def register_chain(self, chain: ChainDefinition) -> bool: - """注册工具链 - - Args: - chain: 工具链定义 - - Returns: - 注册是否成功 - """ - try: - # 验证工具链定义 - validation_result = self._validate_chain(chain) - if not validation_result[0]: - logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}") - return False - - self._chains[chain.name] = chain - logger.info(f"工具链注册成功: {chain.name}") - return True - - except Exception as e: - logger.error(f"工具链注册失败: {chain.name}, 错误: {e}") - return False - - def unregister_chain(self, chain_name: str) -> bool: - """注销工具链 - - Args: - chain_name: 工具链名称 - - Returns: - 注销是否成功 - """ - if chain_name in self._chains: - del self._chains[chain_name] - logger.info(f"工具链注销成功: {chain_name}") - return True - - return False - - def list_chains(self) -> List[Dict[str, Any]]: - """列出所有工具链 - - Returns: - 工具链信息列表 - """ - chains = [] - for name, chain in self._chains.items(): - chains.append({ - "name": name, - "description": chain.description, - "step_count": len(chain.steps), - "execution_mode": chain.execution_mode.value, - "global_timeout": chain.global_timeout - }) - - return chains - - async def execute_chain( - self, - chain_name: str, - initial_variables: Optional[Dict[str, Any]] = None, - chain_id: Optional[str] = None - ) -> Dict[str, Any] | None: - """执行工具链 - - Args: - chain_name: 工具链名称 - initial_variables: 初始变量 - chain_id: 链执行ID(可选) - - Returns: - 执行结果 - """ - if chain_name not in self._chains: - return { - "success": False, - "error": f"工具链不存在: {chain_name}", - "chain_id": chain_id - } - - chain = self._chains[chain_name] - - # 生成链ID - if not chain_id: - import uuid - chain_id = f"chain_{uuid.uuid4().hex[:16]}" - - # 创建执行上下文 - context = ChainExecutionContext(chain_id) - context.variables = initial_variables or {} - self._running_chains[chain_id] = context - - try: - logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})") - - # 根据执行模式执行 - if chain.execution_mode == ChainExecutionMode.SEQUENTIAL: - result = await self._execute_sequential(chain, context) - elif chain.execution_mode == ChainExecutionMode.PARALLEL: - result = await self._execute_parallel(chain, context) - elif chain.execution_mode == ChainExecutionMode.CONDITIONAL: - result = await self._execute_conditional(chain, context) - else: - raise ValueError(f"不支持的执行模式: {chain.execution_mode}") - - logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})") - return result - - except Exception as e: - logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}") - return { - "success": False, - "error": str(e), - "chain_id": chain_id, - "completed_steps": context.current_step, - "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} - } - - finally: - # 清理执行上下文 - if chain_id in self._running_chains: - del self._running_chains[chain_id] - - async def _execute_sequential( - self, - chain: ChainDefinition, - context: ChainExecutionContext - ) -> Dict[str, Any]: - """顺序执行工具链""" - for i, step in enumerate(chain.steps): - context.current_step = i - - # 检查执行条件 - if step.condition and not self._evaluate_condition(step.condition, context): - logger.debug(f"跳过步骤 {i}: 条件不满足") - continue - - # 准备参数 - parameters = self._prepare_parameters(step.parameters, context) - - # 执行工具 - try: - result = await self.executor.execute_tool( - tool_id=step.tool_id, - parameters=parameters - ) - - context.step_results[i] = result - - # 处理输出映射 - if step.output_mapping and result.success: - self._apply_output_mapping(step.output_mapping, result.data, context) - - # 处理执行失败 - if not result.success: - if step.error_handling == "stop": - context.is_failed = True - context.error_message = result.error - break - elif step.error_handling == "continue": - logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}") - continue - elif step.error_handling == "retry": - # 简单重试逻辑 - retry_result = await self.executor.execute_tool( - tool_id=step.tool_id, - parameters=parameters - ) - context.step_results[i] = retry_result - if not retry_result.success and step.error_handling == "stop": - context.is_failed = True - context.error_message = retry_result.error - break - - except Exception as e: - logger.error(f"步骤 {i} 执行异常: {e}") - if step.error_handling == "stop": - context.is_failed = True - context.error_message = str(e) - break - - context.is_completed = not context.is_failed - - return { - "success": context.is_completed, - "error": context.error_message, - "chain_id": context.chain_id, - "completed_steps": context.current_step + 1, - "total_steps": len(chain.steps), - "final_variables": context.variables, - "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} - } - - async def _execute_parallel( - self, - chain: ChainDefinition, - context: ChainExecutionContext - ) -> Dict[str, Any]: - """并行执行工具链""" - # 准备所有步骤的执行配置 - execution_configs = [] - - for i, step in enumerate(chain.steps): - # 检查执行条件 - if step.condition and not self._evaluate_condition(step.condition, context): - continue - - parameters = self._prepare_parameters(step.parameters, context) - execution_configs.append({ - "step_index": i, - "tool_id": step.tool_id, - "parameters": parameters - }) - - # 并行执行所有步骤 - try: - results = await self.executor.execute_tools_batch(execution_configs) - - # 处理结果 - for i, result in enumerate(results): - step_index = execution_configs[i]["step_index"] - context.step_results[step_index] = result - - # 处理输出映射 - step = chain.steps[step_index] - if step.output_mapping and result.success: - self._apply_output_mapping(step.output_mapping, result.data, context) - - # 检查是否有失败的步骤 - failed_steps = [i for i, result in context.step_results.items() if not result.success] - - context.is_completed = len(failed_steps) == 0 - if failed_steps: - context.error_message = f"步骤 {failed_steps} 执行失败" - - except Exception as e: - context.is_failed = True - context.error_message = str(e) - - return { - "success": context.is_completed, - "error": context.error_message, - "chain_id": context.chain_id, - "completed_steps": len(context.step_results), - "total_steps": len(chain.steps), - "final_variables": context.variables, - "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} - } - - async def _execute_conditional( - self, - chain: ChainDefinition, - context: ChainExecutionContext - ) -> Dict[str, Any]: - """条件执行工具链""" - # 条件执行类似于顺序执行,但更严格地检查条件 - return await self._execute_sequential(chain, context) - - def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]: - """验证工具链定义 - - Args: - chain: 工具链定义 - - Returns: - (是否有效, 错误信息) - """ - if not chain.name: - return False, "工具链名称不能为空" - - if not chain.steps: - return False, "工具链必须包含至少一个步骤" - - for i, step in enumerate(chain.steps): - if not step.tool_id: - return False, f"步骤 {i} 缺少工具ID" - - if step.error_handling not in ["stop", "continue", "retry"]: - return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}" - - return True, None - - def _prepare_parameters( - self, - parameters: Dict[str, Any], - context: ChainExecutionContext - ) -> Dict[str, Any]: - """准备参数(支持变量替换) - - Args: - parameters: 原始参数 - context: 执行上下文 - - Returns: - 处理后的参数 - """ - prepared = {} - - for key, value in parameters.items(): - if isinstance(value, str) and value.startswith("${") and value.endswith("}"): - # 变量替换 - var_name = value[2:-1] - if var_name in context.variables: - prepared[key] = context.variables[var_name] - else: - prepared[key] = value # 保持原值 - else: - prepared[key] = value - - return prepared - - def _evaluate_condition( - self, - condition: str, - context: ChainExecutionContext - ) -> bool: - """评估执行条件 - - Args: - condition: 条件表达式 - context: 执行上下文 - - Returns: - 条件是否满足 - """ - try: - # 简单的条件评估(可以扩展为更复杂的表达式解析) - # 支持格式:variable == value, variable != value, variable > value 等 - - if "==" in condition: - var_name, expected_value = condition.split("==", 1) - var_name = var_name.strip() - expected_value = expected_value.strip().strip('"\'') - - return str(context.variables.get(var_name, "")) == expected_value - - elif "!=" in condition: - var_name, expected_value = condition.split("!=", 1) - var_name = var_name.strip() - expected_value = expected_value.strip().strip('"\'') - - return str(context.variables.get(var_name, "")) != expected_value - - elif condition in context.variables: - # 简单的布尔检查 - return bool(context.variables[condition]) - - else: - # 默认为真 - return True - - except Exception as e: - logger.error(f"条件评估失败: {condition}, 错误: {e}") - return False - - def _apply_output_mapping( - self, - mapping: Dict[str, str], - output_data: Any, - context: ChainExecutionContext - ): - """应用输出映射 - - Args: - mapping: 输出映射配置 - output_data: 输出数据 - context: 执行上下文 - """ - try: - if isinstance(output_data, dict): - for source_key, target_var in mapping.items(): - if source_key in output_data: - context.variables[target_var] = output_data[source_key] - else: - # 如果输出不是字典,将整个输出映射到指定变量 - if "result" in mapping: - context.variables[mapping["result"]] = output_data - - except Exception as e: - logger.error(f"输出映射失败: {e}") - - def _serialize_result(self, result: ToolResult) -> Dict[str, Any]: - """序列化工具结果 - - Args: - result: 工具结果 - - Returns: - 序列化的结果 - """ - return { - "success": result.success, - "data": result.data, - "error": result.error, - "error_code": result.error_code, - "execution_time": result.execution_time, - "token_usage": result.token_usage, - "metadata": result.metadata - } - - def get_running_chains(self) -> List[Dict[str, Any]]: - """获取正在运行的工具链 - - Returns: - 运行中的工具链列表 - """ - chains = [] - for chain_id, context in self._running_chains.items(): - chains.append({ - "chain_id": chain_id, - "current_step": context.current_step, - "is_completed": context.is_completed, - "is_failed": context.is_failed, - "variables_count": len(context.variables), - "completed_steps": len(context.step_results) - }) - - return chains \ No newline at end of file diff --git a/api/app/core/tools/config_manager.py b/api/app/core/tools/config_manager.py deleted file mode 100644 index fb8d1fff..00000000 --- a/api/app/core/tools/config_manager.py +++ /dev/null @@ -1,264 +0,0 @@ -"""工具配置管理器 - 管理工具配置的加载和验证""" -import json -from pathlib import Path -from typing import Dict, Any, Optional -from pydantic import BaseModel, ValidationError - -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ToolConfigSchema(BaseModel): - """工具配置基础Schema""" - name: str - description: str - tool_type: str - version: str = "1.0.0" - enabled: bool = True - parameters: Dict[str, Any] = {} - tags: list[str] = [] - - class Config: - extra = "allow" - - -class BuiltinToolConfigSchema(ToolConfigSchema): - """内置工具配置Schema""" - tool_class: str - tool_type: str = "builtin" - - -class CustomToolConfigSchema(ToolConfigSchema): - """自定义工具配置Schema""" - schema_url: Optional[str] = None - schema_content: Optional[Dict[str, Any]] = None - auth_type: str = "none" - auth_config: Dict[str, Any] = {} - base_url: Optional[str] = None - timeout: int = 30 - tool_type: str = "custom" - - -class MCPToolConfigSchema(ToolConfigSchema): - """MCP工具配置Schema""" - server_url: str - connection_config: Dict[str, Any] = {} - available_tools: list[str] = [] - tool_type: str = "mcp" - - -class ConfigManager: - """工具配置管理器""" - - def __init__(self, config_dir: Optional[str] = None): - """初始化配置管理器 - - Args: - config_dir: 配置文件目录,默认使用系统配置 - """ - self.config_dir = Path(config_dir or self._get_default_config_dir()) - self.config_dir.mkdir(parents=True, exist_ok=True) - - logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}") - - def _get_default_config_dir(self) -> str: - """获取默认配置目录""" - # 获取tools目录下的configs子目录 - tools_dir = Path(__file__).parent - return str(tools_dir / "configs") - - def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]: - """加载内置工具配置 - - Returns: - 内置工具配置字典 - """ - configs = {} - builtin_dir = self.config_dir / "builtin" - - if not builtin_dir.exists(): - logger.info("内置工具配置目录不存在,创建默认配置") - self._create_default_builtin_configs(builtin_dir) - - for config_file in builtin_dir.glob("*.json"): - try: - config_data = self._load_config_file(config_file) - config = BuiltinToolConfigSchema(**config_data) - configs[config.name] = config - logger.debug(f"加载内置工具配置: {config.name}") - except Exception as e: - logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}") - - return configs - - def load_builtin_tools_config(self) -> Dict[str, Any]: - """加载全局内置工具配置(兼容原有接口) - - Returns: - 内置工具配置字典 - """ - config_file = self.config_dir / "builtin_tools.json" - try: - with open(config_file, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"加载内置工具配置失败: {e}") - return {} - - def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum): - """确保内置工具已初始化到数据库 - - Args: - tenant_id: 租户ID - db_session: 数据库会话 - tool_config_model: ToolConfig模型类 - builtin_tool_config_model: BuiltinToolConfig模型类 - tool_type_enum: ToolType枚举 - tool_status_enum: ToolStatus枚举 - """ - # 检查是否已初始化 - existing_count = db_session.query(tool_config_model).filter( - tool_config_model.tenant_id == tenant_id, - tool_config_model.tool_type == tool_type_enum.BUILTIN - ).count() - - if existing_count > 0: - return # 已初始化 - - # 加载全局配置 - builtin_tools = self.load_builtin_tools_config() - - # 为租户创建内置工具记录 - for tool_key, tool_info in builtin_tools.items(): - # 设置初始状态 - initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value - - tool_config = tool_config_model( - name=tool_info['name'], - description=tool_info['description'], - tool_type=tool_type_enum.BUILTIN, - tenant_id=tenant_id, - status=initial_status - ) - db_session.add(tool_config) - db_session.flush() - - builtin_config = builtin_tool_config_model( - id=tool_config.id, - tool_class=tool_info['tool_class'], - parameters={} - ) - db_session.add(builtin_config) - - db_session.commit() - logger.info(f"租户 {tenant_id} 的内置工具初始化完成") - - def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool: - """保存工具配置 - - Args: - config: 工具配置 - tool_type: 工具类型 - - Returns: - 保存是否成功 - """ - try: - config_dir = self.config_dir / tool_type - config_dir.mkdir(parents=True, exist_ok=True) - - config_file = config_dir / f"{config.name}.json" - config_data = config.model_dump() - - with open(config_file, 'w', encoding='utf-8') as f: - json.dump(config_data, f, indent=2, ensure_ascii=False) - - logger.info(f"工具配置保存成功: {config.name} ({tool_type})") - return True - - except Exception as e: - logger.error(f"工具配置保存失败: {config.name}, 错误: {e}") - return False - - def delete_tool_config(self, tool_name: str, tool_type: str) -> bool: - """删除工具配置 - - Args: - tool_name: 工具名称 - tool_type: 工具类型 - - Returns: - 删除是否成功 - """ - try: - config_file = self.config_dir / tool_type / f"{tool_name}.json" - - if config_file.exists(): - config_file.unlink() - logger.info(f"工具配置删除成功: {tool_name} ({tool_type})") - return True - else: - logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})") - return False - - except Exception as e: - logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}") - return False - - def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]: - """验证工具配置 - - Args: - config_data: 配置数据 - tool_type: 工具类型 - - Returns: - (是否有效, 错误信息) - """ - try: - schema_map = { - "builtin": BuiltinToolConfigSchema, - "custom": CustomToolConfigSchema, - "mcp": MCPToolConfigSchema - } - - schema_class = schema_map.get(tool_type) - if not schema_class: - return False, f"不支持的工具类型: {tool_type}" - - # 验证配置 - schema_class(**config_data) - return True, None - - except ValidationError as e: - error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()]) - return False, f"配置验证失败: {error_msg}" - except Exception as e: - return False, f"配置验证异常: {str(e)}" - - def _load_config_file(self, config_file: Path) -> Dict[str, Any]: - """加载配置文件 - - Args: - config_file: 配置文件路径 - - Returns: - 配置数据字典 - """ - try: - with open(config_file, 'r', encoding='utf-8') as f: - return json.load(f) - except Exception as e: - logger.error(f"加载配置文件失败: {config_file}, 错误: {e}") - raise - - def _create_default_builtin_configs(self, builtin_dir: Path): - """创建默认内置工具配置 - - Args: - builtin_dir: 内置工具配置目录 - """ - builtin_dir.mkdir(parents=True, exist_ok=True) - logger.info(f"内置工具配置目录已创建: {builtin_dir}") - # 配置文件已经通过其他方式创建,这里只需要确保目录存在 \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin_tools.json b/api/app/core/tools/configs/builtin_tools.json index ed0b87b1..c758a54a 100644 --- a/api/app/core/tools/configs/builtin_tools.json +++ b/api/app/core/tools/configs/builtin_tools.json @@ -54,7 +54,8 @@ "enabled": true, "parameters": { "api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}, - "api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true} + "api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}, + "base_url": {"type": "string", "description": "API地址", "default": "https://api.textin.com/v1"} } } } \ No newline at end of file diff --git a/api/app/core/tools/custom/auth_manager.py b/api/app/core/tools/custom/auth_manager.py index 5d457f11..9eb416a2 100644 --- a/api/app/core/tools/custom/auth_manager.py +++ b/api/app/core/tools/custom/auth_manager.py @@ -2,7 +2,6 @@ import base64 import hashlib import hmac -import time from typing import Dict, Any, Tuple from urllib.parse import quote import aiohttp @@ -51,8 +50,9 @@ class AuthManager: except Exception as e: return False, f"验证认证配置时出错: {e}" - - def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + + @staticmethod + def _validate_api_key_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]: """验证API Key认证配置 Args: @@ -79,8 +79,9 @@ class AuthManager: return False, "API Key位置必须是 header、query 或 cookie" return True, "" - - def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + + @staticmethod + def _validate_bearer_token_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]: """验证Bearer Token认证配置 Args: @@ -135,9 +136,9 @@ class AuthManager: except Exception as e: logger.error(f"应用认证时出错: {e}") return url, headers, params - + + @staticmethod def _apply_api_key_auth( - self, auth_config: Dict[str, Any], url: str, headers: Dict[str, str], @@ -176,9 +177,9 @@ class AuthManager: headers["Cookie"] = cookie_value return url, headers, params - + + @staticmethod def _apply_bearer_token_auth( - self, auth_config: Dict[str, Any], url: str, headers: Dict[str, str], @@ -260,8 +261,9 @@ class AuthManager: except Exception as e: logger.error(f"解密认证配置失败: {e}") return encrypted_config - - def _encrypt_string(self, value: str, key: str) -> str: + + @staticmethod + def _encrypt_string(value: str, key: str) -> str: """加密字符串 Args: @@ -289,8 +291,9 @@ class AuthManager: except Exception as e: logger.error(f"加密字符串失败: {e}") return value - - def _decrypt_string(self, encrypted_value: str, key: str) -> str: + + @staticmethod + def _decrypt_string(encrypted_value: str, key: str) -> str: """解密字符串 Args: @@ -471,8 +474,9 @@ class AuthManager: "error": f"测试认证时出错: {e}", "auth_type": auth_type.value } - - def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]: + + @staticmethod + def get_auth_config_template(auth_type: AuthType) -> Dict[str, Any]: """获取认证配置模板 Args: @@ -498,8 +502,9 @@ class AuthManager: } return templates.get(auth_type, {}) - - def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def mask_sensitive_config(auth_config: Dict[str, Any]) -> Dict[str, Any]: """遮蔽认证配置中的敏感信息 Args: diff --git a/api/app/core/tools/custom/base.py b/api/app/core/tools/custom/base.py index eda6769b..0d656a8e 100644 --- a/api/app/core/tools/custom/base.py +++ b/api/app/core/tools/custom/base.py @@ -5,7 +5,8 @@ import aiohttp from urllib.parse import urljoin from app.models.tool_model import ToolType, AuthType -from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.core.tools.base import BaseTool +from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -173,8 +174,9 @@ class CustomTool(BaseTool): } return operations - - def _convert_openapi_type(self, openapi_type: str) -> ParameterType: + + @staticmethod + def _convert_openapi_type(openapi_type: str) -> ParameterType: """转换OpenAPI类型到内部类型""" type_mapping = { "string": ParameterType.STRING, @@ -239,8 +241,9 @@ class CustomTool(BaseTool): headers["Authorization"] = f"Bearer {token}" return headers - - def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + + @staticmethod + def _build_request_data(operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]: """构建请求数据""" if operation["method"] in ["POST", "PUT", "PATCH"]: request_body = operation.get("request_body") @@ -284,6 +287,7 @@ class CustomTool(BaseTool): try: return await response.json() except Exception as e: + logger.error(f"解析HTTP响应JSON失败: {str(e)}") return await response.text() @classmethod diff --git a/api/app/core/tools/custom/schema_parser.py b/api/app/core/tools/custom/schema_parser.py index 21ac28b6..a22e2cfa 100644 --- a/api/app/core/tools/custom/schema_parser.py +++ b/api/app/core/tools/custom/schema_parser.py @@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger logger = get_business_logger() +# 为了兼容性,创建别名 +# SchemaParser = OpenAPISchemaParser = None + class OpenAPISchemaParser: """OpenAPI Schema解析器 - 解析OpenAPI 3.0规范""" @@ -88,8 +91,9 @@ class OpenAPISchemaParser: except Exception as e: logger.error(f"解析schema内容失败: {e}") return False, {}, str(e) - - def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]: + + @staticmethod + def _parse_content(content: str, content_type: str) -> Optional[Dict[str, Any]]: """解析内容为字典 Args: @@ -101,7 +105,7 @@ class OpenAPISchemaParser: """ try: # 根据内容类型解析 - if 'json' in content_type: + if 'application/json' in content_type: return json.loads(content) elif 'yaml' in content_type or 'yml' in content_type: return yaml.safe_load(content) @@ -228,8 +232,9 @@ class OpenAPISchemaParser: } return operations - - def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _extract_parameters(operation: Dict[str, Any]) -> Dict[str, Any]: """提取操作参数 Args: @@ -266,8 +271,9 @@ class OpenAPISchemaParser: } return parameters - - def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]: + + @staticmethod + def _extract_request_body(operation: Dict[str, Any]) -> Optional[Dict[str, Any]]: """提取请求体信息 Args: @@ -298,8 +304,9 @@ class OpenAPISchemaParser: "schema": schema, "content_types": list(content.keys()) } - - def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]: + + @staticmethod + def _extract_responses(operation: Dict[str, Any]) -> Dict[str, Any]: """提取响应信息 Args: @@ -331,8 +338,9 @@ class OpenAPISchemaParser: } return responses - - def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]: + + @staticmethod + def generate_tool_parameters(operations: Dict[str, Any]) -> List[Dict[str, Any]]: """生成工具参数定义 Args: @@ -396,7 +404,7 @@ class OpenAPISchemaParser: parameters.extend(all_params.values()) return parameters - + def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]: """验证操作参数 @@ -447,8 +455,9 @@ class OpenAPISchemaParser: errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}") return len(errors) == 0, errors - - def _validate_parameter_type(self, value: Any, expected_type: str) -> bool: + + @staticmethod + def _validate_parameter_type(value: Any, expected_type: str) -> bool: """验证参数类型 Args: @@ -474,4 +483,7 @@ class OpenAPISchemaParser: if expected_python_type: return isinstance(value, expected_python_type) - return True \ No newline at end of file + return True + +# 为了兼容性,创建别名 +SchemaParser = OpenAPISchemaParser \ No newline at end of file diff --git a/api/app/core/tools/executor.py b/api/app/core/tools/executor.py deleted file mode 100644 index c0ba87fb..00000000 --- a/api/app/core/tools/executor.py +++ /dev/null @@ -1,501 +0,0 @@ -"""工具执行器 - 负责工具的实际调用和执行管理""" -import asyncio -import uuid -import time -from typing import Dict, Any, List, Optional -from datetime import datetime -from sqlalchemy.orm import Session - -from app.models.tool_model import ToolExecution, ExecutionStatus -from app.core.tools.base import BaseTool, ToolResult -from app.core.tools.registry import ToolRegistry -from app.core.logging_config import get_business_logger - -logger = get_business_logger() - - -class ExecutionContext: - """执行上下文""" - - def __init__( - self, - execution_id: str, - tool_id: str, - user_id: Optional[uuid.UUID] = None, - workspace_id: Optional[uuid.UUID] = None, - timeout: Optional[float] = None, - metadata: Optional[Dict[str, Any]] = None - ): - self.execution_id = execution_id - self.tool_id = tool_id - self.user_id = user_id - self.workspace_id = workspace_id - self.timeout = timeout or 60.0 # 默认60秒超时 - self.metadata = metadata or {} - self.started_at = datetime.now() - self.completed_at: Optional[datetime] = None - self.status = ExecutionStatus.PENDING - - -class ToolExecutor: - """工具执行器 - 使用langchain标准接口执行工具""" - - def __init__(self, db: Session, registry: ToolRegistry): - """初始化工具执行器 - - Args: - db: 数据库会话 - registry: 工具注册表 - """ - self.db = db - self.registry = registry - self._running_executions: Dict[str, ExecutionContext] = {} - self._execution_lock = asyncio.Lock() - - async def execute_tool( - self, - tool_id: str, - parameters: Dict[str, Any], - user_id: Optional[uuid.UUID] = None, - workspace_id: Optional[uuid.UUID] = None, - execution_id: Optional[str] = None, - timeout: Optional[float] = None, - metadata: Optional[Dict[str, Any]] = None - ) -> ToolResult: - """执行工具 - - Args: - tool_id: 工具ID - parameters: 工具参数 - user_id: 用户ID - workspace_id: 工作空间ID - execution_id: 执行ID(可选,自动生成) - timeout: 超时时间(秒) - metadata: 额外元数据 - - Returns: - 工具执行结果 - """ - # 生成执行ID - if not execution_id: - execution_id = f"exec_{uuid.uuid4().hex[:16]}" - - # 创建执行上下文 - context = ExecutionContext( - execution_id=execution_id, - tool_id=tool_id, - user_id=user_id, - workspace_id=workspace_id, - timeout=timeout, - metadata=metadata - ) - - try: - # 获取工具实例 - tool = self.registry.get_tool(tool_id) - if not tool: - return ToolResult.error_result( - error=f"工具不存在: {tool_id}", - error_code="TOOL_NOT_FOUND", - execution_time=0.0 - ) - - # 记录执行开始 - await self._record_execution_start(context, parameters) - - # 执行工具 - result = await self._execute_with_timeout(tool, parameters, context) - - # 记录执行完成 - await self._record_execution_complete(context, result) - - return result - - except Exception as e: - logger.error(f"工具执行异常: {execution_id}, 错误: {e}") - - # 记录执行失败 - error_result = ToolResult.error_result( - error=str(e), - error_code="EXECUTION_ERROR", - execution_time=time.time() - context.started_at.timestamp() - ) - await self._record_execution_complete(context, error_result) - - return error_result - - finally: - # 清理执行上下文 - async with self._execution_lock: - if execution_id in self._running_executions: - del self._running_executions[execution_id] - - async def execute_tools_batch( - self, - tool_executions: List[Dict[str, Any]], - max_concurrency: int = 5 - ) -> List[ToolResult]: - """批量执行工具 - - Args: - tool_executions: 工具执行配置列表,每个包含tool_id和parameters - max_concurrency: 最大并发数 - - Returns: - 执行结果列表 - """ - semaphore = asyncio.Semaphore(max_concurrency) - - async def execute_single(exec_config: Dict[str, Any]) -> ToolResult: - async with semaphore: - return await self.execute_tool( - tool_id=exec_config["tool_id"], - parameters=exec_config.get("parameters", {}), - user_id=exec_config.get("user_id"), - workspace_id=exec_config.get("workspace_id"), - timeout=exec_config.get("timeout"), - metadata=exec_config.get("metadata") - ) - - # 并发执行所有工具 - tasks = [execute_single(config) for config in tool_executions] - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 处理异常结果 - processed_results = [] - for i, result in enumerate(results): - if isinstance(result, Exception): - processed_results.append( - ToolResult.error_result( - error=str(result), - error_code="BATCH_EXECUTION_ERROR", - execution_time=0.0 - ) - ) - else: - processed_results.append(result) - - return processed_results - - async def cancel_execution(self, execution_id: str) -> bool: - """取消工具执行 - - Args: - execution_id: 执行ID - - Returns: - 是否成功取消 - """ - async with self._execution_lock: - if execution_id not in self._running_executions: - return False - - context = self._running_executions[execution_id] - context.status = ExecutionStatus.FAILED - - # 更新数据库记录 - execution_record = self.db.query(ToolExecution).filter( - ToolExecution.execution_id == execution_id - ).first() - - if execution_record: - execution_record.status = ExecutionStatus.FAILED.value - execution_record.error_message = "执行被取消" - execution_record.completed_at = datetime.now() - self.db.commit() - - logger.info(f"工具执行已取消: {execution_id}") - return True - - def get_running_executions(self) -> List[Dict[str, Any]]: - """获取正在运行的执行列表 - - Returns: - 执行信息列表 - """ - executions = [] - for execution_id, context in self._running_executions.items(): - executions.append({ - "execution_id": execution_id, - "tool_id": context.tool_id, - "user_id": str(context.user_id) if context.user_id else None, - "workspace_id": str(context.workspace_id) if context.workspace_id else None, - "started_at": context.started_at.isoformat(), - "status": context.status.value, - "elapsed_time": (datetime.now() - context.started_at).total_seconds() - }) - - return executions - - async def _execute_with_timeout( - self, - tool: BaseTool, - parameters: Dict[str, Any], - context: ExecutionContext - ) -> ToolResult: - """带超时的工具执行 - - Args: - tool: 工具实例 - parameters: 参数 - context: 执行上下文 - - Returns: - 执行结果 - """ - async with self._execution_lock: - self._running_executions[context.execution_id] = context - context.status = ExecutionStatus.RUNNING - - try: - # 使用asyncio.wait_for实现超时控制 - result = await asyncio.wait_for( - tool.safe_execute(**parameters), - timeout=context.timeout - ) - - context.status = ExecutionStatus.COMPLETED - return result - - except asyncio.TimeoutError: - context.status = ExecutionStatus.TIMEOUT - return ToolResult.error_result( - error=f"工具执行超时({context.timeout}秒)", - error_code="EXECUTION_TIMEOUT", - execution_time=context.timeout - ) - - except Exception as e: - context.status = ExecutionStatus.FAILED - raise - - async def _record_execution_start( - self, - context: ExecutionContext, - parameters: Dict[str, Any] - ): - """记录执行开始""" - try: - execution_record = ToolExecution( - execution_id=context.execution_id, - tool_config_id=uuid.UUID(context.tool_id), - status=ExecutionStatus.RUNNING.value, - input_data=parameters, - started_at=context.started_at, - user_id=context.user_id, - workspace_id=context.workspace_id - ) - - self.db.add(execution_record) - self.db.commit() - - logger.debug(f"执行记录已创建: {context.execution_id}") - - except Exception as e: - logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}") - - async def _record_execution_complete( - self, - context: ExecutionContext, - result: ToolResult - ): - """记录执行完成""" - try: - context.completed_at = datetime.now() - - execution_record = self.db.query(ToolExecution).filter( - ToolExecution.execution_id == context.execution_id - ).first() - - if execution_record: - execution_record.status = ( - ExecutionStatus.COMPLETED.value if result.success - else ExecutionStatus.FAILED.value - ) - execution_record.output_data = result.data if result.success else None - execution_record.error_message = result.error if not result.success else None - execution_record.completed_at = context.completed_at - execution_record.execution_time = result.execution_time - execution_record.token_usage = result.token_usage - - self.db.commit() - - logger.debug(f"执行记录已更新: {context.execution_id}") - - except Exception as e: - logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}") - - def get_execution_history( - self, - tool_id: Optional[str] = None, - user_id: Optional[uuid.UUID] = None, - workspace_id: Optional[uuid.UUID] = None, - limit: int = 100 - ) -> List[Dict[str, Any]]: - """获取执行历史 - - Args: - tool_id: 工具ID过滤 - user_id: 用户ID过滤 - workspace_id: 工作空间ID过滤 - limit: 返回数量限制 - - Returns: - 执行历史列表 - """ - try: - query = self.db.query(ToolExecution).order_by( - ToolExecution.started_at.desc() - ) - - if tool_id: - query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id)) - - if user_id: - query = query.filter(ToolExecution.user_id == user_id) - - if workspace_id: - query = query.filter(ToolExecution.workspace_id == workspace_id) - - executions = query.limit(limit).all() - - history = [] - for execution in executions: - history.append({ - "execution_id": execution.execution_id, - "tool_id": str(execution.tool_config_id), - "status": execution.status, - "started_at": execution.started_at.isoformat() if execution.started_at else None, - "completed_at": execution.completed_at.isoformat() if execution.completed_at else None, - "execution_time": execution.execution_time, - "user_id": str(execution.user_id) if execution.user_id else None, - "workspace_id": str(execution.workspace_id) if execution.workspace_id else None, - "input_data": execution.input_data, - "output_data": execution.output_data, - "error_message": execution.error_message, - "token_usage": execution.token_usage - }) - - return history - - except Exception as e: - logger.error(f"获取执行历史失败, 错误: {e}") - return [] - - def get_execution_statistics( - self, - workspace_id: Optional[uuid.UUID] = None, - days: int = 7 - ) -> Dict[str, Any]: - """获取执行统计信息 - - Args: - workspace_id: 工作空间ID - days: 统计天数 - - Returns: - 统计信息 - """ - try: - from datetime import timedelta - - start_date = datetime.now() - timedelta(days=days) - - query = self.db.query(ToolExecution).filter( - ToolExecution.started_at >= start_date - ) - - if workspace_id: - query = query.filter(ToolExecution.workspace_id == workspace_id) - - executions = query.all() - - # 统计数据 - total_executions = len(executions) - successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value]) - failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value]) - - # 平均执行时间 - completed_executions = [e for e in executions if e.execution_time is not None] - avg_execution_time = ( - sum(e.execution_time for e in completed_executions) / len(completed_executions) - if completed_executions else 0 - ) - - # 按工具统计 - tool_stats = {} - for execution in executions: - tool_id = str(execution.tool_config_id) - if tool_id not in tool_stats: - tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0} - - tool_stats[tool_id]["total"] += 1 - if execution.status == ExecutionStatus.COMPLETED.value: - tool_stats[tool_id]["successful"] += 1 - elif execution.status == ExecutionStatus.FAILED.value: - tool_stats[tool_id]["failed"] += 1 - - return { - "period_days": days, - "total_executions": total_executions, - "successful_executions": successful_executions, - "failed_executions": failed_executions, - "success_rate": successful_executions / total_executions if total_executions > 0 else 0, - "average_execution_time": avg_execution_time, - "tool_statistics": tool_stats - } - - except Exception as e: - logger.error(f"获取执行统计失败, 错误: {e}") - return {} - - async def test_tool_connection( - self, - tool_id: str, - user_id: Optional[uuid.UUID] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """测试工具连接""" - try: - from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig - from .mcp.client import MCPClient - - tool_config = self.db.query(ToolConfig).filter( - ToolConfig.id == uuid.UUID(tool_id) - ).first() - - if not tool_config: - return {"success": False, "message": "工具不存在"} - - if tool_config.tool_type == ToolType.MCP.value: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == tool_config.id - ).first() - - if not mcp_config: - return {"success": False, "message": "MCP配置不存在"} - - client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) - - if await client.connect(): - try: - tools = await client.list_tools() - await client.disconnect() - return { - "success": True, - "message": "MCP连接成功", - "details": {"server_url": mcp_config.server_url, "tools": len(tools)} - } - except: - await client.disconnect() - return {"success": False, "message": "MCP功能测试失败"} - else: - return {"success": False, "message": "MCP连接失败"} - else: - tool = self.registry.get_tool(tool_id) - if tool and hasattr(tool, 'test_connection'): - result = tool.test_connection() - return {"success": result.get("success", False), "message": result.get("message", "")} - return {"success": True, "message": "工具无需连接测试"} - except Exception as e: - return {"success": False, "message": "测试失败", "error": str(e)} \ No newline at end of file diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py index 241069cd..ca77f528 100644 --- a/api/app/core/tools/mcp/base.py +++ b/api/app/core/tools/mcp/base.py @@ -4,7 +4,8 @@ from typing import Dict, Any, List import aiohttp from app.models.tool_model import ToolType -from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.core.tools.base import BaseTool +from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType from app.core.logging_config import get_business_logger logger = get_business_logger() @@ -123,33 +124,43 @@ class MCPTool(BaseTool): async def connect(self) -> bool: """连接到MCP服务器""" try: - # 这里应该实现实际的MCP连接逻辑 - # 为了简化,这里只是模拟连接 + from .client import MCPClient - # 测试服务器连接 - timeout = aiohttp.ClientTimeout(total=10) - async with aiohttp.ClientSession(timeout=timeout) as session: - # 尝试获取服务器信息 - async with session.get(f"{self.server_url}/info") as response: - if response.status == 200: - server_info = await response.json() - self.available_tools = server_info.get("tools", []) - self._connected = True - logger.info(f"MCP服务器连接成功: {self.server_url}") - return True - else: - raise Exception(f"服务器响应错误: {response.status}") + if self._connected: + return True + + self._client = MCPClient(self.server_url, self.connection_config) + + if await self._client.connect(): + self._connected = True + # 更新可用工具列表 + await self._update_available_tools() + logger.info(f"MCP服务器连接成功: {self.server_url}") + return True + else: + logger.error(f"MCP服务器连接失败: {self.server_url}") + return False except Exception as e: - logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}") + logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}") self._connected = False return False + async def _update_available_tools(self): + """更新可用工具列表""" + try: + if self._client and self._connected: + tools = await self._client.list_tools() + self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] + logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具") + except Exception as e: + logger.error(f"更新MCP工具列表失败: {e}") + async def disconnect(self) -> bool: """断开MCP服务器连接""" try: if self._client: - # 这里应该实现实际的断开逻辑 + await self._client.disconnect() self._client = None self._connected = False @@ -171,38 +182,15 @@ class MCPTool(BaseTool): async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any: """调用MCP工具""" - # 构建MCP请求 - request_data = { - "jsonrpc": "2.0", - "id": f"req_{int(time.time() * 1000)}", - "method": "tools/call", - "params": { - "name": tool_name, - "arguments": arguments - } - } + if not self._client or not self._connected: + raise Exception("MCP客户端未连接") - # 发送请求 - client_timeout = aiohttp.ClientTimeout(total=timeout) - async with aiohttp.ClientSession(timeout=client_timeout) as session: - async with session.post( - f"{self.server_url}/mcp", - json=request_data, - headers={"Content-Type": "application/json"} - ) as response: - - if response.status != 200: - error_text = await response.text() - raise Exception(f"MCP请求失败 {response.status}: {error_text}") - - result = await response.json() - - # 检查MCP响应 - if "error" in result: - error = result["error"] - raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}") - - return result.get("result", {}) + try: + result = await self._client.call_tool(tool_name, arguments, timeout) + return result + except Exception as e: + logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}") + raise async def list_available_tools(self) -> List[Dict[str, Any]]: """列出可用的MCP工具""" @@ -210,27 +198,10 @@ class MCPTool(BaseTool): if not self._connected: await self.connect() - # 获取工具列表 - request_data = { - "jsonrpc": "2.0", - "id": f"req_{int(time.time() * 1000)}", - "method": "tools/list" - } - - timeout = aiohttp.ClientTimeout(total=10) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post( - f"{self.server_url}/mcp", - json=request_data, - headers={"Content-Type": "application/json"} - ) as response: - - if response.status == 200: - result = await response.json() - if "result" in result: - tools = result["result"].get("tools", []) - self.available_tools = [tool.get("name") for tool in tools] - return tools + if self._client: + tools = await self._client.list_tools() + self.available_tools = [tool.get("name") for tool in tools if tool.get("name")] + return tools return [] diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 3be2e9bf..997e6e84 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -134,11 +134,40 @@ class MCPClient: logger.error(f"断开MCP服务器连接失败: {e}") return False + def _build_auth_headers(self) -> Dict[str, str]: + """构建认证头""" + headers = {} + auth_type = self.connection_config.get("auth_type", "none") + auth_config = self.connection_config.get("auth_config", {}) + + if auth_type == "api_key": + api_key = auth_config.get("api_key") + key_name = auth_config.get("key_name", "X-API-Key") + if api_key: + headers[key_name] = api_key + + elif auth_type == "bearer_token": + token = auth_config.get("token") + if token: + headers["Authorization"] = f"Bearer {token}" + + elif auth_type == "basic_auth": + username = auth_config.get("username") + password = auth_config.get("password") + if username and password: + import base64 + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + headers["Authorization"] = f"Basic {credentials}" + + return headers + async def _connect_websocket(self) -> bool: """建立WebSocket连接""" try: # WebSocket连接配置 extra_headers = self.connection_config.get("headers", {}) + auth_headers = self._build_auth_headers() + extra_headers.update(auth_headers) self._websocket = await websockets.connect( self.server_url, @@ -190,6 +219,8 @@ class MCPClient: # HTTP会话配置 timeout = aiohttp.ClientTimeout(total=self.connection_timeout) headers = self.connection_config.get("headers", {}) + auth_headers = self._build_auth_headers() + headers.update(auth_headers) self._session = aiohttp.ClientSession( timeout=timeout, @@ -251,8 +282,9 @@ class MCPClient: except Exception as e: logger.error(f"处理消息失败: {e}") - - async def _handle_notification(self, message: Dict[str, Any]): + + @staticmethod + async def _handle_notification(message: Dict[str, Any]): """处理通知消息""" method = message.get("method") params = message.get("params", {}) @@ -327,7 +359,7 @@ class MCPClient: try: response = await self._send_request(request_data, timeout) - if not response["error"] is None: + if response.get("error", None) is not None: error = response["error"] raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}") @@ -372,10 +404,10 @@ class MCPClient: return response except asyncio.TimeoutError: - self._pending_requests.pop(request_id, None) + await self._pending_requests.pop(request_id, None) raise except Exception as e: - self._pending_requests.pop(request_id, None) + await self._pending_requests.pop(request_id, None) raise MCPConnectionError(f"发送WebSocket请求失败: {e}") async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]: @@ -424,9 +456,9 @@ class MCPClient: start_time = time.time() response = await self._send_request(request_data, timeout=5) - response_time = time.time() - start_time + response_time = round((time.time() - start_time) * 1000) - self._last_health_check = time.time() + self._last_health_check = round(time.time() * 1000) return { "healthy": True, diff --git a/api/app/core/tools/mcp/service_manager.py b/api/app/core/tools/mcp/service_manager.py index 53b83ddd..51d01535 100644 --- a/api/app/core/tools/mcp/service_manager.py +++ b/api/app/core/tools/mcp/service_manager.py @@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple from datetime import datetime from sqlalchemy.orm import Session -from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType +from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus from app.core.logging_config import get_business_logger from .client import MCPClient, MCPConnectionPool @@ -148,7 +148,7 @@ class MCPServiceManager: connection_config=connection_config, available_tools=tool_names, health_status="healthy", - last_health_check=datetime.utcnow() + last_health_check=datetime.now() ) self.db.add(mcp_config) @@ -410,7 +410,8 @@ class MCPServiceManager: """加载现有服务""" try: mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter( - ToolConfig.is_enabled == True + ToolConfig.status == ToolStatus.AVAILABLE.value, + ToolConfig.tool_type == ToolType.MCP.value ).all() for mcp_config in mcp_configs: @@ -531,7 +532,7 @@ class MCPServiceManager: if mcp_config: mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy" - mcp_config.last_health_check = datetime.utcnow() + mcp_config.last_health_check = datetime.now() if not health_status["healthy"]: mcp_config.error_message = health_status.get("error", "") diff --git a/api/app/core/tools/registry.py b/api/app/core/tools/registry.py deleted file mode 100644 index b56c1bf7..00000000 --- a/api/app/core/tools/registry.py +++ /dev/null @@ -1,436 +0,0 @@ -"""工具注册表 - 管理所有工具的元数据和状态""" -import uuid -import asyncio -from typing import Dict, List, Optional, Type, Any - -from sqlalchemy.orm import Session -from sqlalchemy import and_, or_ - -from app.models.tool_model import ( - ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, - ToolType, ToolStatus, ToolExecution, ExecutionStatus -) -from app.core.logging_config import get_business_logger -from .base import BaseTool, ToolInfo -from .custom.base import CustomTool -from .mcp.base import MCPTool - -logger = get_business_logger() - - -class ToolRegistry: - """工具注册表 - 管理所有工具的元数据和实例""" - - def __init__(self, db: Session): - """初始化工具注册表 - - Args: - db: 数据库会话 - """ - self.db = db - self._tools: Dict[str, BaseTool] = {} # 工具实例缓存 - self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表 - self._lock = asyncio.Lock() # 异步锁 - - def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None): - """注册工具类 - - Args: - tool_class: 工具类 - class_name: 类名(可选,默认使用类的__name__) - """ - class_name = class_name or tool_class.__name__ - self._tool_classes[class_name] = tool_class - logger.info(f"工具类已注册: {class_name}") - - async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool: - """注册工具实例到系统 - - Args: - tool: 工具实例 - tenant_id: 租户ID(内置工具可以为None,表示全局工具) - - Returns: - 注册是否成功 - """ - async with self._lock: - try: - # 检查工具是否已存在 - if tenant_id: - existing_config = self.db.query(ToolConfig).filter( - and_( - ToolConfig.name == tool.name, - ToolConfig.tenant_id == tenant_id, - ToolConfig.tool_type == tool.tool_type.value - ) - ).first() - else: - # 全局工具(内置工具) - existing_config = self.db.query(ToolConfig).filter( - and_( - ToolConfig.name == tool.name, - ToolConfig.tenant_id.is_(None), - ToolConfig.tool_type == tool.tool_type.value - ) - ).first() - - if existing_config: - logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})") - return False - - # 创建工具配置 - tool_config = ToolConfig( - name=tool.name, - description=tool.description, - tool_type=tool.tool_type.value, - tenant_id=tenant_id, - version=tool.version, - tags=tool.tags, - config_data=tool.config - ) - - self.db.add(tool_config) - self.db.flush() # 获取ID - - # 根据工具类型创建特定配置 - if tool.tool_type == ToolType.BUILTIN: - builtin_config = BuiltinToolConfig( - id=tool_config.id, - tool_class=tool.__class__.__name__, - parameters=tool.config.get("parameters", {}) - ) - self.db.add(builtin_config) - - elif tool.tool_type == ToolType.CUSTOM: - custom_config = CustomToolConfig( - id=tool_config.id, - schema_url=tool.config.get("schema_url"), - schema_content=tool.config.get("schema_content"), - auth_type=tool.config.get("auth_type", "none"), - auth_config=tool.config.get("auth_config", {}), - base_url=tool.config.get("base_url"), - timeout=tool.config.get("timeout", 30) - ) - self.db.add(custom_config) - - elif tool.tool_type == ToolType.MCP: - mcp_config = MCPToolConfig( - id=tool_config.id, - server_url=tool.config.get("server_url"), - connection_config=tool.config.get("connection_config", {}), - available_tools=tool.config.get("available_tools", []) - ) - self.db.add(mcp_config) - - self.db.commit() - - # 缓存工具实例 - tool.tool_id = str(tool_config.id) - self._tools[str(tool_config.id)] = tool - - logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})") - return True - - except Exception as e: - self.db.rollback() - logger.error(f"工具注册失败: {tool.name}, 错误: {e}") - return False - - async def unregister_tool(self, tool_id: str) -> bool: - """从系统注销工具 - - Args: - tool_id: 工具ID - - Returns: - 注销是否成功 - """ - async with self._lock: - try: - # 检查工具是否存在 - tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) - if not tool_config: - logger.warning(f"工具不存在: {tool_id}") - return False - - # 检查是否有正在执行的任务 - running_executions = self.db.query(ToolExecution).filter( - and_( - ToolExecution.tool_config_id == uuid.UUID(tool_id), - ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]) - ) - ).count() - - if running_executions > 0: - logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}") - return False - - # 删除工具配置(级联删除相关记录) - self.db.delete(tool_config) - self.db.commit() - - # 从缓存中移除 - if tool_id in self._tools: - del self._tools[tool_id] - - logger.info(f"工具注销成功: {tool_id}") - return True - - except Exception as e: - self.db.rollback() - logger.error(f"工具注销失败: {tool_id}, 错误: {e}") - return False - - def get_tool(self, tool_id: str) -> Optional[BaseTool]: - """获取工具实例 - - Args: - tool_id: 工具ID - - Returns: - 工具实例,如果不存在返回None - """ - # 先从缓存获取 - if tool_id in self._tools: - return self._tools[tool_id] - - # 从数据库加载 - try: - tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) - if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value: - return None - - # 根据工具类型加载实例 - tool_instance = self._load_tool_instance(tool_config) - if tool_instance: - self._tools[tool_id] = tool_instance - return tool_instance - - except Exception as e: - logger.error(f"加载工具失败: {tool_id}, 错误: {e}") - - return None - - def list_tools( - self, - tenant_id: Optional[uuid.UUID] = None, - tool_type: Optional[ToolType] = None, - status: Optional[ToolStatus] = None, - tags: Optional[List[str]] = None - ) -> List[ToolInfo]: - """列出工具 - - Args: - tenant_id: 租户ID过滤 - tool_type: 工具类型过滤 - status: 工具状态过滤 - tags: 标签过滤 - - Returns: - 工具信息列表 - """ - try: - query = self.db.query(ToolConfig) - - # 应用过滤条件 - if tenant_id: - # 返回全局工具(tenant_id为空)和该租户的工具 - query = query.filter( - or_( - ToolConfig.tenant_id == tenant_id, - ToolConfig.tenant_id.is_(None) - ) - ) - - if tool_type: - query = query.filter(ToolConfig.tool_type == tool_type.value) - - if status == ToolStatus.ACTIVE: - query = query.filter(ToolConfig.is_enabled == True) - elif status == ToolStatus.INACTIVE: - query = query.filter(ToolConfig.is_enabled == False) - - if tags: - for tag in tags: - query = query.filter(ToolConfig.tags.contains([tag])) - - tool_configs = query.all() - - # 转换为ToolInfo - tool_infos = [] - for config in tool_configs: - tool_info = ToolInfo( - id=str(config.id), - name=config.name, - description=config.description or "", - tool_type=ToolType(config.tool_type), - version=config.version, - status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE, - tags=config.tags or [], - tenant_id=str(config.tenant_id) if config.tenant_id else None - ) - - # 尝试获取参数信息 - tool_instance = self.get_tool(str(config.id)) - if tool_instance: - tool_info.parameters = tool_instance.parameters - - tool_infos.append(tool_info) - - return tool_infos - - except Exception as e: - logger.error(f"列出工具失败, 错误: {e}") - return [] - - async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool: - """更新工具状态 - - Args: - tool_id: 工具ID - status: 新状态 - - Returns: - 更新是否成功 - """ - try: - tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) - if not tool_config: - logger.warning(f"工具不存在: {tool_id}") - return False - - # 更新状态 - if status == ToolStatus.ACTIVE: - tool_config.is_enabled = True - elif status == ToolStatus.INACTIVE: - tool_config.is_enabled = False - - self.db.commit() - - # 更新缓存中的工具状态 - if tool_id in self._tools: - self._tools[tool_id].status = status - - logger.info(f"工具状态更新成功: {tool_id} -> {status}") - return True - - except Exception as e: - self.db.rollback() - logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}") - return False - - def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]: - """从配置加载工具实例 - - Args: - tool_config: 工具配置 - - Returns: - 工具实例 - """ - try: - if tool_config.tool_type == ToolType.BUILTIN.value: - # 加载内置工具 - builtin_config = self.db.query(BuiltinToolConfig).filter( - BuiltinToolConfig.id == tool_config.id - ).first() - - if builtin_config and builtin_config.tool_class in self._tool_classes: - tool_class = self._tool_classes[builtin_config.tool_class] - config = { - **tool_config.config_data, - "parameters": builtin_config.parameters, - "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, - "version": tool_config.version, - "tags": tool_config.tags - } - return tool_class(str(tool_config.id), config) - - elif tool_config.tool_type == ToolType.CUSTOM.value: - # 加载自定义工具 - try: - custom_config = self.db.query(CustomToolConfig).filter( - CustomToolConfig.id == tool_config.id - ).first() - - if custom_config: - config = { - **tool_config.config_data, - "schema_url": custom_config.schema_url, - "schema_content": custom_config.schema_content, - "auth_type": custom_config.auth_type, - "auth_config": custom_config.auth_config, - "base_url": custom_config.base_url, - "timeout": custom_config.timeout, - "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, - "version": tool_config.version, - "tags": tool_config.tags - } - return CustomTool(str(tool_config.id), config) - except ImportError as e: - logger.error(f"无法导入自定义工具模块: {e}") - - elif tool_config.tool_type == ToolType.MCP.value: - # 加载MCP工具 - try: - mcp_config = self.db.query(MCPToolConfig).filter( - MCPToolConfig.id == tool_config.id - ).first() - - if mcp_config: - config = { - **tool_config.config_data, - "server_url": mcp_config.server_url, - "connection_config": mcp_config.connection_config, - "available_tools": mcp_config.available_tools, - "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, - "version": tool_config.version, - "tags": tool_config.tags - } - return MCPTool(str(tool_config.id), config) - except ImportError as e: - logger.error(f"无法导入MCP工具模块: {e}") - - except Exception as e: - logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}") - - return None - - def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: - """获取工具统计信息 - - Args: - tenant_id: 租户ID - - Returns: - 统计信息字典 - """ - try: - query = self.db.query(ToolConfig) - if tenant_id: - query = query.filter(ToolConfig.tenant_id == tenant_id) - - total_tools = query.count() - active_tools = query.filter(ToolConfig.is_enabled == True).count() - - # 按类型统计 - type_stats = {} - for tool_type in ToolType: - count = query.filter(ToolConfig.tool_type == tool_type.value).count() - type_stats[tool_type.value] = count - - return { - "total_tools": total_tools, - "active_tools": active_tools, - "inactive_tools": total_tools - active_tools, - "by_type": type_stats - } - - except Exception as e: - logger.error(f"获取工具统计失败, 错误: {e}") - return {} - - def clear_cache(self): - """清空工具缓存""" - self._tools.clear() - logger.info("工具缓存已清空") \ No newline at end of file diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py index ac719317..f170148f 100644 --- a/api/app/models/tool_model.py +++ b/api/app/models/tool_model.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime from enum import StrEnum -from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float +from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -19,10 +19,40 @@ class ToolType(StrEnum): class ToolStatus(StrEnum): """工具状态枚举""" - ACTIVE = "active" - INACTIVE = "inactive" - ERROR = "error" - LOADING = "loading" + AVAILABLE = "available" # 可用(已配置且已启用) + UNCONFIGURED = "unconfigured" # 未配置 + CONFIGURED_DISABLED = "configured_disabled" # 已配置未启用 + ERROR = "error" # 错误状态 + + @classmethod + def get_all_statuses(cls): + """获取所有工具状态""" + return [status.value for status in cls] + + @classmethod + def get_all_statuses_with_labels(cls): + """获取所有工具状态及其文本描述""" + return [ + {"value": cls.AVAILABLE.value, "label": "可用"}, + {"value": cls.UNCONFIGURED.value, "label": "未配置"}, + {"value": cls.CONFIGURED_DISABLED.value, "label": "已配置未启用"}, + {"value": cls.ERROR.value, "label": "错误状态"} + ] + + @classmethod + def is_valid_status(cls, status): + """检查状态是否有效""" + return status in cls._value2member_map_ + + @classmethod + def get_active_statuses(cls): + """获取所有活跃状态""" + return [cls.AVAILABLE.value] + + @classmethod + def get_inactive_statuses(cls): + """获取所有非活跃状态""" + return [cls.UNCONFIGURED.value, cls.CONFIGURED_DISABLED.value, cls.ERROR.value] class AuthType(StrEnum): @@ -30,6 +60,27 @@ class AuthType(StrEnum): NONE = "none" API_KEY = "api_key" BEARER_TOKEN = "bearer_token" + BASIC_AUTH = "basic_auth" + + @classmethod + def get_all_types(cls): + """获取所有认证类型""" + return [auth_type.value for auth_type in cls] + + @classmethod + def get_all_types_with_labels(cls): + """获取所有认证类型及其文本描述""" + return [ + {"value": cls.NONE.value, "label": "无需认证"}, + {"value": cls.API_KEY.value, "label": "API Key"}, + {"value": cls.BEARER_TOKEN.value, "label": "Bearer Token"}, + {"value": cls.BASIC_AUTH.value, "label": "Basic Auth"} + ] + + @classmethod + def is_valid_types(cls, auth_type): + """检查认证类型是否有效""" + return auth_type in cls._value2member_map_ class ExecutionStatus(StrEnum): @@ -48,13 +99,14 @@ class ToolConfig(Base): id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) name = Column(String(255), nullable=False, index=True) description = Column(Text) + icon = Column(String(255)) # 工具图标 tool_type = Column(String(50), nullable=False, index=True) tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户 - status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态 + status = Column(String(50), default=ToolStatus.UNCONFIGURED.value, nullable=False, index=True) # 工具状态 # 工具特定配置(JSON格式存储) config_data = Column(JSON, default=dict) - + # 元数据 version = Column(String(50), default="1.0.0") tags = Column(JSON, default=list) # 标签列表 @@ -78,12 +130,14 @@ class BuiltinToolConfig(Base): id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) tool_class = Column(String(255), nullable=False) # 工具类名 parameters = Column(JSON, default=dict) # 工具参数配置 - + is_enabled = Column(Boolean, default=False, nullable=False) # 启用开关 + requires_config = Column(Boolean, default=False, nullable=False) # 是否需要配置 + # 关联关系 base_config = relationship("ToolConfig", foreign_keys=[id]) def __repr__(self): - return f"" + return f"" class CustomToolConfig(Base): @@ -115,7 +169,7 @@ class MCPToolConfig(Base): id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) server_url = Column(String(1000), nullable=False) # MCP服务器URL - connection_config = Column(JSON, default=dict) # 连接配置 + connection_config = Column(JSON, default=dict) # 连接配置(包含认证信息) # 服务状态 last_health_check = Column(DateTime) diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py new file mode 100644 index 00000000..bc8db683 --- /dev/null +++ b/api/app/repositories/tool_repository.py @@ -0,0 +1,157 @@ +"""工具数据访问层""" +import uuid +from typing import List, Optional, Dict, Any +from sqlalchemy.orm import Session +from sqlalchemy import func, or_ + +from app.repositories.base_repository import BaseRepository +from app.models.tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolExecution, ToolType, ToolStatus +) + + +class ToolRepository: + """工具仓储类""" + + @staticmethod + def find_by_tenant( + db: Session, + tenant_id: uuid.UUID, + name: Optional[str] = None, + tool_type: Optional[ToolType] = None, + status: Optional[ToolStatus] = None, + is_enabled: Optional[bool] = None + ) -> List[ToolConfig]: + """根据租户查找工具""" + query = db.query(ToolConfig).filter( + ToolConfig.tenant_id == tenant_id + ) + + if name: + query = query.filter(ToolConfig.name.ilike(f"%{name}%")) + if tool_type: + query = query.filter(ToolConfig.tool_type == tool_type.value) + if status: + query = query.filter(ToolConfig.status == status.value) + if is_enabled is not None: + query = query.filter(ToolConfig.is_enabled == is_enabled) + + return query.all() + + @staticmethod + def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: + """根据ID和租户查找工具""" + return db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == tenant_id + ).first() + + @staticmethod + def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: + """统计租户工具数量""" + return db.query(ToolConfig).filter( + ToolConfig.tenant_id == tenant_id + ).count() + + @staticmethod + def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: + """获取状态统计""" + return db.query( + ToolConfig.status, + func.count(ToolConfig.id).label('count') + ).filter( + ToolConfig.tenant_id == tenant_id + ).group_by(ToolConfig.status).all() + + @staticmethod + def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: + """获取类型统计""" + return db.query( + ToolConfig.tool_type, + func.count(ToolConfig.id).label('count') + ).filter( + ToolConfig.tenant_id == tenant_id + ).group_by(ToolConfig.tool_type).all() + + @staticmethod + def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: + """统计租户启用的工具数量""" + return db.query(ToolConfig).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.is_enabled == True + ).count() + + @staticmethod + def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool: + """检查租户是否已有内置工具""" + return db.query(ToolConfig).filter( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == ToolType.BUILTIN.value + ).count() > 0 + + +class BuiltinToolRepository: + """内置工具仓储类""" + + @staticmethod + def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]: + """根据工具ID查找内置工具配置""" + return db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_id + ).first() + + +class CustomToolRepository: + """自定义工具仓储类""" + + @staticmethod + def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]: + """根据工具ID查找自定义工具配置""" + return db.query(CustomToolConfig).filter( + CustomToolConfig.id == tool_id + ).first() + + +class MCPToolRepository: + """MCP工具仓储类""" + + @staticmethod + def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]: + """根据工具ID查找MCP工具配置""" + return db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_id + ).first() + + @staticmethod + def find_error_connections(db: Session) -> List[MCPToolConfig]: + """查找连接错误的MCP工具""" + return db.query(MCPToolConfig).filter( + MCPToolConfig.connection_status == "error" + ).all() + + +class ToolExecutionRepository: + """工具执行仓储类""" + + @staticmethod + def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]: + """根据执行ID查找执行记录""" + return db.query(ToolExecution).filter( + ToolExecution.execution_id == execution_id + ).first() + + @staticmethod + def find_by_tool_and_tenant( + db: Session, + tool_id: uuid.UUID, + tenant_id: uuid.UUID, + limit: int = 100 + ) -> List[ToolExecution]: + """根据工具和租户查找执行记录""" + return db.query(ToolExecution).join( + ToolConfig, ToolExecution.tool_config_id == ToolConfig.id + ).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == tenant_id + ).order_by(ToolExecution.started_at.desc()).limit(limit).all() \ No newline at end of file diff --git a/api/app/schemas/tool_schema.py b/api/app/schemas/tool_schema.py new file mode 100644 index 00000000..15167375 --- /dev/null +++ b/api/app/schemas/tool_schema.py @@ -0,0 +1,259 @@ +"""工具相关的数据模式定义""" +from typing import Dict, Any, List, Optional +from pydantic import BaseModel, Field, field_serializer +from datetime import datetime +from enum import Enum + +from app.core.api_key_utils import datetime_to_timestamp +from app.models.tool_model import ToolType, ToolStatus, AuthType + + +class ParameterType(str, Enum): + """参数类型枚举""" + STRING = "string" + INTEGER = "integer" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" + + +class ToolParameter(BaseModel): + """工具参数定义""" + name: str = Field(..., description="参数名称") + type: ParameterType = Field(..., description="参数类型") + description: str = Field("", description="参数描述") + required: bool = Field(False, description="是否必需") + default: Any = Field(None, description="默认值") + enum: Optional[List[Any]] = Field(None, description="枚举值") + minimum: Optional[float] = Field(None, description="最小值") + maximum: Optional[float] = Field(None, description="最大值") + pattern: Optional[str] = Field(None, description="正则表达式模式") + + class Config: + use_enum_values = True + + +class ToolResult(BaseModel): + """工具执行结果""" + success: bool = Field(..., description="执行是否成功") + data: Any = Field(None, description="返回数据") + error: Optional[str] = Field(None, description="错误信息") + error_code: Optional[str] = Field(None, description="错误代码") + execution_time: float = Field(..., description="执行时间(秒)") + token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况") + metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据") + + @classmethod + def success_result( + cls, + data: Any, + execution_time: float, + token_usage: Optional[Dict[str, int]] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建成功结果""" + return cls( + success=True, + data=data, + execution_time=execution_time, + token_usage=token_usage, + metadata=metadata or {} + ) + + @classmethod + def error_result( + cls, + error: str, + execution_time: float, + error_code: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建错误结果""" + return cls( + success=False, + error=error, + error_code=error_code, + execution_time=execution_time, + metadata=metadata or {} + ) + + +class ToolInfo(BaseModel): + """工具信息""" + id: str = Field(..., description="工具ID") + name: str = Field(..., description="工具名称") + description: str = Field(..., description="工具描述") + icon: Optional[str] = Field(None, description="工具图标") + tool_type: ToolType = Field(..., description="工具类型") + version: str = Field("1.0.0", description="工具版本") + parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") + config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置") + status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态") + tags: List[str] = Field(default_factory=list, description="工具标签") + tenant_id: Optional[str] = Field(None, description="租户ID") + created_at: datetime = Field(..., description="创建时间") + + class Config: + use_enum_values = True + + @field_serializer('created_at') + @classmethod + def serialize_datetime(cls, v): + """将datetime转换为时间戳""" + return datetime_to_timestamp(v) + + +class ToolConfigSchema(BaseModel): + """工具配置基础模式""" + id: str + name: str + description: Optional[str] = None + icon: Optional[str] = None + tool_type: ToolType + status: ToolStatus + config_data: Dict[str, Any] = Field(default_factory=dict) + version: str = "1.0.0" + tags: List[str] = Field(default_factory=list) + tenant_id: str + created_at: datetime + updated_at: datetime + + class Config: + from_attributes = True + + +class BuiltinToolConfigSchema(BaseModel): + """内置工具配置模式""" + tool_class: str + parameters: Dict[str, Any] = Field(default_factory=dict) + is_enabled: bool + requires_config: bool = False + + class Config: + from_attributes = True + + +class CustomToolConfigSchema(BaseModel): + """自定义工具配置模式""" + base_url: Optional[str] = None + auth_type: AuthType = AuthType.NONE + auth_config: Dict[str, Any] = Field(default_factory=dict) + timeout: int = 30 + schema_content: Optional[Dict[str, Any]] = None + schema_url: Optional[str] = None + + class Config: + from_attributes = True + + +class MCPToolConfigSchema(BaseModel): + """MCP工具配置模式""" + server_url: str + connection_config: Dict[str, Any] = Field(default_factory=dict) + last_health_check: Optional[datetime] = None + health_status: str = "unknown" + error_message: Optional[str] = None + available_tools: List[str] = Field(default_factory=list) + + class Config: + from_attributes = True + + +class ToolDetailSchema(ToolConfigSchema): + """工具详情模式(包含类型特定配置)""" + builtin_config: Optional[BuiltinToolConfigSchema] = None + custom_config: Optional[CustomToolConfigSchema] = None + mcp_config: Optional[MCPToolConfigSchema] = None + + +class ToolExecutionSchema(BaseModel): + """工具执行记录模式""" + id: str + execution_id: str + status: str + input_data: Optional[Dict[str, Any]] = None + output_data: Optional[Dict[str, Any]] = None + error_message: Optional[str] = None + started_at: datetime + completed_at: Optional[datetime] = None + execution_time: Optional[float] = None + token_usage: Optional[Dict[str, int]] = None + + class Config: + from_attributes = True + + +class ToolCreateRequest(BaseModel): + """创建工具请求""" + name: str = Field(..., min_length=1, max_length=255) + description: Optional[str] = Field(None, max_length=1000) + icon: Optional[str] = Field(None, max_length=255) + tool_type: ToolType + config: Dict[str, Any] = Field(default_factory=dict) + + +class ToolUpdateRequest(BaseModel): + """更新工具请求""" + name: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = Field(None, max_length=1000) + icon: Optional[str] = Field(None, max_length=255) + config: Optional[Dict[str, Any]] = None + is_enabled: Optional[bool] = None + + +class ToolExecuteRequest(BaseModel): + """执行工具请求""" + tool_id: str + parameters: Dict[str, Any] = Field(default_factory=dict) + timeout: Optional[float] = Field(60.0, gt=0, le=300) + + +class CustomToolCreateRequest(BaseModel): + """创建自定义工具请求""" + name: str = Field(..., min_length=1, max_length=255) + description: Optional[str] = Field(None, max_length=1000) + icon: Optional[str] = Field(None, max_length=255) + auth_type: AuthType = Field(AuthType.NONE, description="认证类型") + auth_config: Dict[str, Any] = Field(default_factory=dict, description="认证配置") + timeout: int = Field(30, ge=1, le=300, description="超时时间") + schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容") + schema_url: Optional[str] = Field(None, description="OpenAPI schema URL") + + +class ParseSchemaRequest(BaseModel): + """解析Schema请求""" + schema_content: Optional[Dict[str, Any]] = Field(None, description="OpenAPI schema内容") + schema_url: Optional[str] = Field(None, description="OpenAPI schema URL") + + +class ToolListQuery(BaseModel): + """工具列表查询参数""" + name: Optional[str] = None + tool_type: Optional[ToolType] = None + status: Optional[ToolStatus] = None + is_enabled: Optional[bool] = None + page: int = Field(1, ge=1) + page_size: int = Field(20, ge=1, le=100) + + +class ToolStatusCount(BaseModel): + """工具状态统计""" + status: ToolStatus + count: int + + +class ToolStatistics(BaseModel): + """工具统计信息""" + total_tools: int + status_counts: List[ToolStatusCount] + type_counts: Dict[str, int] + enabled_count: int + disabled_count: int + + +class CustomToolTestRequest(BaseModel): + """自定义工具测试请求""" + method: str = Field(..., description="HTTP方法") + path: str = Field(..., description="API路径") + parameters: Dict[str, Any] = Field(default_factory=dict, description="请求参数") \ No newline at end of file diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py new file mode 100644 index 00000000..9142a9ba --- /dev/null +++ b/api/app/services/tool_service.py @@ -0,0 +1,977 @@ +"""工具服务 - 统一的工具管理和执行服务""" +import json +import uuid +import time +import importlib +from typing import Dict, Any, List, Optional +from datetime import datetime + +from sqlalchemy.orm import Session + +from app.core.tools.mcp import MCPClient +from app.repositories.tool_repository import ( + ToolRepository, BuiltinToolRepository, CustomToolRepository, + MCPToolRepository, ToolExecutionRepository +) + +from app.models.tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolExecution, ToolType, ToolStatus, ExecutionStatus, AuthType +) +from app.schemas.tool_schema import ToolInfo, ToolResult +from app.core.logging_config import get_business_logger +from app.core.tools.base import BaseTool +from app.core.tools.custom.base import CustomTool +from app.core.tools.mcp.base import MCPTool + +logger = get_business_logger() + +# 内置工具映射 +BUILTIN_TOOLS = { + "DateTimeTool": "app.core.tools.builtin.datetime_tool", + "JsonTool": "app.core.tools.builtin.json_tool", + "BaiduSearchTool": "app.core.tools.builtin.baidu_search_tool", + "MinerUTool": "app.core.tools.builtin.mineru_tool", + "TextInTool": "app.core.tools.builtin.textin_tool" +} + + +class ToolService: + """统一工具服务 - 管理工具的完整生命周期""" + + def __init__(self, db: Session): + self.db = db + self._tool_cache: Dict[str, BaseTool] = {} + + # 初始化仓储 + self.tool_repo = ToolRepository() + self.builtin_repo = BuiltinToolRepository() + self.custom_repo = CustomToolRepository() + self.mcp_repo = MCPToolRepository() + self.execution_repo = ToolExecutionRepository() + + def list_tools( + self, + tenant_id: uuid.UUID, + name: Optional[str] = None, + tool_type: Optional[ToolType] = None, + status: Optional[ToolStatus] = None + ) -> List[ToolInfo]: + """获取工具列表""" + try: + configs = self.tool_repo.find_by_tenant( + db=self.db, + tenant_id=tenant_id, + name=name, + tool_type=tool_type, + status=status + ) + return [self._config_to_info(config) for config in configs] + except Exception as e: + logger.error(f"获取工具列表失败: {e}") + return [] + + def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]: + """获取工具详情""" + config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) + return self._config_to_info(config) if config else None + + def create_tool( + self, + name: str, + tool_type: ToolType, + tenant_id: uuid.UUID, + icon: Optional[str] = None, + description: Optional[str] = None, + config: Optional[Dict[str, Any]] = None + ) -> str: + """创建工具""" + if tool_type == ToolType.BUILTIN: + raise ValueError("内置工具不允许创建") + + try: + # 创建基础配置 + tool_config = ToolConfig( + name=name, + description=description, + icon=icon, + tool_type=tool_type.value, + tenant_id=tenant_id, + status=ToolStatus.AVAILABLE.value, + config_data=config or {} + ) + self.db.add(tool_config) + self.db.flush() + + # 创建类型特定配置 + self._create_type_config(tool_config, config or {}) + + self.db.commit() + logger.info(f"工具创建成功: {tool_config.id}") + return str(tool_config.id) + + except Exception as e: + self.db.rollback() + logger.error(f"创建工具失败: {e}") + raise + + def update_tool( + self, + tool_id: str, + tenant_id: uuid.UUID, + name: Optional[str] = None, + description: Optional[str] = None, + icon: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, + is_enabled: Optional[bool] = None + ) -> bool: + """更新工具""" + config_obj = self._get_tool_config(tool_id, tenant_id) + if not config_obj: + return False + + if config_obj.tool_type == ToolType.BUILTIN.value: + if name or description or icon: + raise ValueError("内置工具不允许修改名称、描述和图标") + try: + if name: + config_obj.name = name + if description: + config_obj.description = description + if icon: + config_obj.icon = icon + if config: + config_obj.config_data = config.copy() + + # 同步到类型表 + self._sync_type_config(config_obj, config, is_enabled) + + # 更新状态逻辑 + self._update_tool_status(config_obj) + + # 清除缓存 + self._clear_tool_cache(tool_id) + + self.db.commit() + return True + + except Exception as e: + self.db.rollback() + logger.error(f"更新工具失败: {tool_id}, {e}") + return False + + def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool: + """删除工具""" + config = self._get_tool_config(tool_id, tenant_id) + if not config: + return False + + if config.tool_type == ToolType.BUILTIN.value: + raise ValueError("内置工具不允许删除") + + try: + # 删除关联表记录 + if config.tool_type == ToolType.CUSTOM.value: + self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete() + elif config.tool_type == ToolType.MCP.value: + self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete() + + # 删除主表记录(ToolExecution会通过cascade自动删除) + self.db.delete(config) + self._clear_tool_cache(tool_id) + self.db.commit() + return True + except Exception as e: + self.db.rollback() + logger.error(f"删除工具失败: {tool_id}, {e}") + return False + + async def execute_tool( + self, + tool_id: str, + parameters: Dict[str, Any], + tenant_id: uuid.UUID, + user_id: Optional[uuid.UUID] = None, + workspace_id: Optional[uuid.UUID] = None, + timeout: float = 60.0 + ) -> ToolResult: + """执行工具""" + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + start_time = time.time() + + try: + # 获取工具实例 + tool = self._get_tool_instance(tool_id, tenant_id) + if not tool: + return ToolResult.error_result( + error=f"工具不存在: {tool_id}", + execution_time=time.time() - start_time + ) + + # 记录执行开始 + self._record_execution_start( + execution_id, tool_id, parameters, user_id, workspace_id + ) + + # 执行工具 + result = await tool.safe_execute(**parameters) + + # 记录执行完成 + self._record_execution_complete(execution_id, result) + + return result + + except Exception as e: + execution_time = time.time() - start_time + error_result = ToolResult.error_result( + error=str(e), + execution_time=execution_time + ) + self._record_execution_complete(execution_id, error_result) + return error_result + + async def test_connection(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]: + """测试工具连接""" + try: + config = self._get_tool_config(tool_id, tenant_id) + if not config: + return {"success": False, "message": "工具不存在"} + + if config.tool_type == ToolType.MCP.value: + return await self._test_mcp_connection(config) + elif config.tool_type == ToolType.CUSTOM.value: + return await self._test_custom_connection(config) + elif config.tool_type == ToolType.BUILTIN.value: + return await self._test_builtin_connection(config) + else: + return {"success": True, "message": "未知工具类型"} + + except Exception as e: + return {"success": False, "message": f"测试失败: {str(e)}"} + + def ensure_builtin_tools_initialized(self, tenant_id: uuid.UUID): + """确保内置工具已初始化""" + existing = self.tool_repo.exists_builtin_for_tenant(self.db, tenant_id) + + if existing: + return + + # 从配置文件加载内置工具定义 + builtin_config = self._load_builtin_config() + + for tool_key, tool_info in builtin_config.items(): + try: + # 创建工具配置 + initial_status = self._determine_initial_status(tool_info) + tool_config = ToolConfig( + name=tool_info['name'], + description=tool_info['description'], + tool_type=ToolType.BUILTIN.value, + tenant_id=tenant_id, + status=initial_status, + config_data={"tool_class": tool_info['tool_class'], + "requires_config": tool_info.get('requires_config', False), + "is_enabled": False}, + version=tool_info["version"] + ) + self.db.add(tool_config) + self.db.flush() + + # 创建内置工具配置 + builtin_config_obj = BuiltinToolConfig( + id=tool_config.id, + tool_class=tool_info['tool_class'], + parameters={}, + requires_config=tool_info.get('requires_config', False) + ) + self.db.add(builtin_config_obj) + + except Exception as e: + logger.error(f"初始化内置工具失败: {tool_key}, {e}") + + self.db.commit() + logger.info(f"租户 {tenant_id} 内置工具初始化完成") + + def get_tool_statistics(self, tenant_id: uuid.UUID) -> Dict[str, Any]: + """获取工具统计信息""" + try: + # 总数统计 + total_tools = self.tool_repo.count_by_tenant(self.db, tenant_id) + + # 状态统计 + status_counts = self.tool_repo.get_status_statistics(self.db, tenant_id) + + # 类型统计 + type_counts = self.tool_repo.get_type_statistics(self.db, tenant_id) + + # 启用/禁用统计 + enabled_count = self.tool_repo.count_enabled_by_tenant(self.db, tenant_id) + disabled_count = total_tools - enabled_count + + return { + "total_tools": total_tools, + "status_counts": [ + {"status": status, "count": count} + for status, count in status_counts + ], + "type_counts": { + tool_type: count for tool_type, count in type_counts + }, + "enabled_count": enabled_count, + "disabled_count": disabled_count + } + except Exception as e: + logger.error(f"获取工具统计失败: {e}") + return { + "total_tools": 0, + "status_counts": [], + "type_counts": {}, + "enabled_count": 0, + "disabled_count": 0 + } + + def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]: + """获取工具配置""" + return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) + + def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: + """获取工具实例""" + if tool_id in self._tool_cache: + return self._tool_cache[tool_id] + + config = self._get_tool_config(tool_id, tenant_id) + if not config: + return None + + try: + tool = self._create_tool_instance(config) + if tool: + self._tool_cache[tool_id] = tool + return tool + except Exception as e: + logger.error(f"创建工具实例失败: {tool_id}, {e}") + return None + + def _create_tool_instance(self, config: ToolConfig) -> Optional[BaseTool]: + """创建工具实例""" + if config.tool_type == ToolType.BUILTIN.value: + return self._create_builtin_instance(config) + elif config.tool_type == ToolType.CUSTOM.value: + return self._create_custom_instance(config) + elif config.tool_type == ToolType.MCP.value: + return self._create_mcp_instance(config) + return None + + def _create_builtin_instance(self, config: ToolConfig) -> Optional[BaseTool]: + """创建内置工具实例""" + builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id) + + if not builtin_config or builtin_config.tool_class not in BUILTIN_TOOLS: + return None + + try: + module_path = BUILTIN_TOOLS[builtin_config.tool_class] + module = importlib.import_module(module_path) + tool_class = getattr(module, builtin_config.tool_class) + + tool_config = { + **config.config_data, + "parameters": builtin_config.parameters, + } + + return tool_class(str(config.id), tool_config) + except Exception as e: + logger.error(f"创建内置工具实例失败: {builtin_config.tool_class}, {e}") + return None + + def _create_custom_instance(self, config: ToolConfig) -> Optional[CustomTool]: + """创建自定义工具实例""" + custom_config = self.custom_repo.find_by_tool_id(self.db, config.id) + + if not custom_config: + return None + + tool_config = { + "base_url": custom_config.base_url, + "auth_type": custom_config.auth_type, + "auth_config": custom_config.auth_config or {}, + "timeout": custom_config.timeout or 30, + "schema_content": custom_config.schema_content, + "schema_url": custom_config.schema_url + } + + return CustomTool(str(config.id), tool_config) + + def _create_mcp_instance(self, config: ToolConfig) -> Optional[MCPTool]: + """创建MCP工具实例""" + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + + if not mcp_config: + return None + + tool_config = { + "server_url": mcp_config.server_url, + "connection_config": mcp_config.connection_config or {}, + "available_tools": mcp_config.available_tools or [] + } + + return MCPTool(str(config.id), tool_config) + + def _config_to_info(self, config: ToolConfig) -> ToolInfo: + """配置转换为信息对象""" + config_data = config.config_data or {} + + # 对于MCP工具,从MCPToolConfig获取额外信息 + if config.tool_type == ToolType.MCP.value: + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if mcp_config: + config_data.update({ + "last_health_check": int(mcp_config.last_health_check.timestamp() * 1000) if mcp_config.last_health_check else None, + "health_status": mcp_config.health_status, + "available_tools": mcp_config.available_tools or [] + }) + + return ToolInfo( + id=str(config.id), + name=config.name, + description=config.description or "", + icon=config.icon, + tool_type=ToolType(config.tool_type), + version=config.version or "1.0.0", + status=ToolStatus(config.status), + tags=config.tags or [], + tenant_id=str(config.tenant_id) if config.tenant_id else None, + config_data=config_data, + created_at=config.created_at + ) + + def _create_type_config(self, tool_config: ToolConfig, config: Dict[str, Any]): + """创建类型特定配置""" + if tool_config.tool_type == ToolType.CUSTOM.value: + # 从 schema 中解析 base_url + base_url = config.get("base_url") + if not base_url and (config.get("schema_content") or config.get("schema_url")): + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + parser = OpenAPISchemaParser() + + if config.get("schema_content"): + success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), "application/json") + else: + success, schema, _ = parser.parse_from_url(config["schema_url"]) + + if success: + tool_info = parser.extract_tool_info(schema) + servers = tool_info.get("servers", []) + base_url = servers[0].get("url") if servers else "" + except Exception as e: + logger.error(f"解析schema获取base_url失败: {e}") + + custom_config = CustomToolConfig( + id=tool_config.id, + base_url=base_url, + auth_type=config.get("auth_type", "none"), + auth_config=config.get("auth_config", {}), + timeout=config.get("timeout", 30), + schema_content=config.get("schema_content"), + schema_url=config.get("schema_url") + ) + self.db.add(custom_config) + + elif tool_config.tool_type == ToolType.MCP.value: + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=config.get("server_url"), + connection_config=config.get("connection_config", {}), + available_tools=config.get("available_tools", []) + ) + self.db.add(mcp_config) + + def _sync_type_config(self, tool_config: ToolConfig, config: Dict[str, Any], is_enabled: bool): + """同步到类型特定表""" + if tool_config.tool_type == ToolType.BUILTIN.value: + builtin_config = self.db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + if builtin_config: + builtin_config.parameters = config.get("parameters", {}) + if is_enabled is not None: + builtin_config.is_enabled = is_enabled + + elif tool_config.tool_type == ToolType.CUSTOM.value: + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == tool_config.id + ).first() + if custom_config: + base_url = config.get("base_url") + if not base_url and (config.get("schema_content") or config.get("schema_url")): + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + parser = OpenAPISchemaParser() + + if config.get("schema_content"): + success, schema, _ = parser.parse_from_content(json.dumps(config["schema_content"]), + "application/json") + else: + success, schema, _ = parser.parse_from_url(config["schema_url"]) + + if success: + tool_info = parser.extract_tool_info(schema) + servers = tool_info.get("servers", []) + base_url = servers[0].get("url") if servers else "" + except Exception as e: + logger.error(f"解析schema获取base_url失败: {e}") + custom_config.base_url = base_url + custom_config.auth_type = config.get("auth_type", "none") + custom_config.auth_config = config.get("auth_config", {}) + custom_config.timeout = config.get("timeout", 30) + custom_config.schema_content = config.get("schema_content") + custom_config.schema_url = config.get("schema_url") + + elif tool_config.tool_type == ToolType.MCP.value: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_config.id + ).first() + if mcp_config: + mcp_config.server_url = config.get("server_url") + mcp_config.connection_config = config.get("connection_config", {}) + mcp_config.available_tools = config.get("available_tools", []) + + @staticmethod + def _determine_initial_status(tool_info: Dict[str, Any]) -> str: + """确定工具初始状态""" + if tool_info.get('requires_config', False): + return ToolStatus.UNCONFIGURED + else: + return ToolStatus.AVAILABLE + + def _update_tool_status(self, tool_config: ToolConfig): + """更新工具状态逻辑""" + if tool_config.tool_type == ToolType.BUILTIN.value: + builtin_config = self.db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if builtin_config: + if builtin_config.requires_config: + # 需要配置的工具 + if self._is_tool_configured(builtin_config): + if tool_config.config_data.get("is_enabled", None): + tool_config.status = ToolStatus.AVAILABLE.value + else: + tool_config.status = ToolStatus.CONFIGURED_DISABLED.value + else: + tool_config.status = ToolStatus.UNCONFIGURED.value + else: + # 不需要配置的工具 + tool_config.status = ToolStatus.AVAILABLE.value + + elif tool_config.tool_type == ToolType.CUSTOM.value: + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == tool_config.id + ).first() + + if custom_config and tool_config.name and (custom_config.schema_content or custom_config.schema_url): + tool_config.status = ToolStatus.AVAILABLE.value + else: + tool_config.status = ToolStatus.UNCONFIGURED.value + + elif tool_config.tool_type == ToolType.MCP.value: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_config.id + ).first() + + if mcp_config: + if mcp_config.health_status == "healthy": + tool_config.status = ToolStatus.AVAILABLE.value + elif mcp_config.health_status == "error": + tool_config.status = ToolStatus.ERROR.value + else: + tool_config.status = ToolStatus.UNCONFIGURED.value + + def _is_tool_configured(self, builtin_config: BuiltinToolConfig) -> bool: + """检查工具是否已配置""" + # 从配置文件获取必需参数 + builtin_config_data = self._load_builtin_config() + required_params = {} + for key, value in builtin_config_data.items(): + if builtin_config.tool_class == value["tool_class"]: + required_params = value.get('parameters', {}) + break + + # 检查所有必需参数是否已配置 + for param_name, param_info in required_params.items(): + if param_info.get('required', False): + if not builtin_config.parameters.get(param_name): + return False + return True + + def _clear_tool_cache(self, tool_id: str): + """清除工具缓存""" + if tool_id in self._tool_cache: + del self._tool_cache[tool_id] + + def _record_execution_start( + self, + execution_id: str, + tool_id: str, + parameters: Dict[str, Any], + user_id: Optional[uuid.UUID], + workspace_id: Optional[uuid.UUID] + ): + """记录执行开始""" + try: + execution = ToolExecution( + execution_id=execution_id, + tool_config_id=uuid.UUID(tool_id), + status=ExecutionStatus.RUNNING.value, + input_data=parameters, + started_at=datetime.now(), + user_id=user_id, + workspace_id=workspace_id + ) + self.db.add(execution) + self.db.commit() + except Exception as e: + logger.error(f"记录执行开始失败: {execution_id}, {e}") + + def _record_execution_complete(self, execution_id: str, result: ToolResult): + """记录执行完成""" + try: + execution = self.db.query(ToolExecution).filter( + ToolExecution.execution_id == execution_id + ).first() + + if execution: + execution.status = ExecutionStatus.COMPLETED.value if result.success else ExecutionStatus.FAILED.value + execution.output_data = result.data if result.success else None + execution.error_message = result.error if not result.success else None + execution.completed_at = datetime.now() + execution.execution_time = result.execution_time + execution.token_usage = result.token_usage + self.db.commit() + except Exception as e: + logger.error(f"记录执行完成失败: {execution_id}, {e}") + + @staticmethod + def _load_builtin_config() -> Dict[str, Any]: + """加载内置工具配置""" + import json + from pathlib import Path + + config_file = Path(__file__).parent.parent / "core" / "tools" / "configs" / "builtin_tools.json" + try: + with open(config_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + logger.error(f"加载内置工具配置失败: {e}") + return {} + + async def _test_mcp_connection(self, config: ToolConfig) -> Dict[str, Any]: + """测试MCP连接""" + try: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == config.id + ).first() + + if not mcp_config: + return {"success": False, "message": "MCP配置不存在"} + + client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {}) + + if await client.connect(): + try: + tools = await client.list_tools() + await client.disconnect() + + # 更新连接状态 + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "healthy" + mcp_config.error_message = None + + # 更新工具状态 + self._update_tool_status(config) + self.db.commit() + + return { + "success": True, + "message": "MCP连接成功", + "details": {"server_url": mcp_config.server_url, "tools_count": len(tools)} + } + except Exception as e: + await client.disconnect() + + # 更新错误状态 + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "error" + mcp_config.error_message = str(e) + self._update_tool_status(config) + self.db.commit() + + return {"success": False, "message": f"MCP功能测试失败: {str(e)}"} + else: + # 更新连接失败状态 + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "error" + mcp_config.error_message = "连接失败" + self._update_tool_status(config) + self.db.commit() + + return {"success": False, "message": "MCP连接失败"} + + except Exception as e: + # 更新异常状态 + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == config.id + ).first() + if mcp_config: + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "error" + mcp_config.error_message = str(e) + self._update_tool_status(config) + self.db.commit() + + return {"success": False, "message": f"MCP测试异常: {str(e)}"} + + @staticmethod + async def parse_openapi_schema(schema_data: Dict[str, Any] = None, schema_url: str = None) -> Dict[str, Any]: + """解析OpenAPI schema获取接口信息""" + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + + parser = OpenAPISchemaParser() + + # 使用现有的解析器 + if schema_data: + success, schema, error = parser.parse_from_content(json.dumps(schema_data), "application/json") + elif schema_url: + success, schema, error = await parser.parse_from_url(schema_url) + else: + return {"success": False, "message": "schema_data或schema_url必须提供一个"} + + if not success: + return {"success": False, "message": error} + + # 提取工具信息 + tool_info = parser.extract_tool_info(schema) + + # 获取base_url + servers = tool_info.get("servers", []) + base_url = servers[0].get("url") if servers else "" + + return { + "success": True, + "data": { + "title": tool_info["name"], + "description": tool_info["description"], + "version": tool_info["version"], + "base_url": base_url, + "operations": list(tool_info["operations"].values()) + } + } + + except Exception as e: + logger.error(f"解析OpenAPI schema失败: {e}") + return {"success": False, "message": f"解析失败: {str(e)}"} + + async def sync_mcp_tools(self, tool_id: str, tenant_id: uuid.UUID) -> Dict[str, Any]: + """同步MCP工具列表到数据库""" + try: + config = self._get_tool_config(tool_id, tenant_id) + if not config or config.tool_type != ToolType.MCP.value: + return {"success": False, "message": "工具不存在或不是MCP工具"} + + mcp_config = self.mcp_repo.find_by_tool_id(self.db, config.id) + if not mcp_config: + return {"success": False, "message": "MCP配置不存在"} + + # 创建MCP客户端 + connection_config = mcp_config.connection_config or {} + + client = MCPClient(mcp_config.server_url, connection_config) + + if await client.connect(): + try: + # 获取工具列表 + tools = await client.list_tools() + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + + # 更新数据库 + mcp_config.available_tools = tool_names + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "healthy" + mcp_config.error_message = None + + # 更新工具状态 + config.status = ToolStatus.AVAILABLE.value + + self.db.commit() + + await client.disconnect() + + return { + "success": True, + "message": "工具列表同步成功", + "tools_count": len(tool_names), + "tools": tool_names + } + + except Exception as e: + await client.disconnect() + + # 更新错误状态 + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "error" + mcp_config.error_message = str(e) + config.status = ToolStatus.ERROR.value + self.db.commit() + + return {"success": False, "message": f"获取工具列表失败: {str(e)}"} + else: + # 连接失败 + mcp_config.last_health_check = datetime.now() + mcp_config.health_status = "error" + mcp_config.error_message = "连接失败" + config.status = ToolStatus.ERROR.value + self.db.commit() + + return {"success": False, "message": "MCP连接失败"} + + except Exception as e: + logger.error(f"同步MCP工具列表失败: {tool_id}, 错误: {e}") + return {"success": False, "message": f"同步失败: {str(e)}"} + + async def _test_custom_connection(self, config: ToolConfig) -> Dict[str, Any]: + """测试自定义工具连接(基础连接测试)""" + try: + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == config.id + ).first() + + if not custom_config or not custom_config.base_url: + return {"success": False, "message": "自定义工具配置不完整"} + + import aiohttp + async with aiohttp.ClientSession() as session: + async with session.get( + custom_config.base_url, + timeout=aiohttp.ClientTimeout(total=10) + ) as response: + if response.status == 200: + return {"success": True, "message": "自定义工具连接成功"} + else: + return {"success": False, "message": f"连接失败,状态码: {response.status}"} + + except Exception as e: + return {"success": False, "message": f"自定义工具测试失败: {str(e)}"} + + async def test_custom_tool( + self, + tool_id: str, + tenant_id: uuid.UUID, + method: str, + path: str, + parameters: Dict[str, Any] + ) -> Dict[str, Any]: + """测试自定义工具API调用""" + try: + config = self._get_tool_config(tool_id, tenant_id) + if not config or config.tool_type != ToolType.CUSTOM.value: + return {"success": False, "message": "工具不存在或不是自定义工具"} + + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == config.id + ).first() + + if not custom_config or not custom_config.base_url: + return {"success": False, "message": "自定义工具配置不完整"} + + # 构建完整URL + url = custom_config.base_url.rstrip('/') + '/' + path.lstrip('/') + + # 构建请求头 + headers = {"Content-Type": "application/json"} + + # 添加认证头 + if custom_config.auth_type != AuthType.NONE.value: + auth_config = custom_config.auth_config or {} + if custom_config.auth_type == AuthType.API_KEY.value: + key_name = auth_config.get("key_name", "X-API-Key") + api_key = auth_config.get("api_key") + if api_key: + headers[key_name] = api_key + elif custom_config.auth_type == AuthType.BEARER_TOKEN.value: + token = auth_config.get("token") + if token: + headers["Authorization"] = f"Bearer {token}" + elif custom_config.auth_type == AuthType.BASIC_AUTH.value: + import base64 + username = auth_config.get("username", "") + password = auth_config.get("password", "") + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + headers["Authorization"] = f"Basic {credentials}" + + import aiohttp + async with aiohttp.ClientSession() as session: + # 根据方法发送请求 + if method.upper() == "GET": + async with session.get( + url, + params=parameters, + headers=headers, + timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30) + ) as response: + result_data = await response.text() + return { + "success": True, + "message": "测试成功", + "status_code": response.status, + "response_data": result_data[:1000] # 限制返回数据长度 + } + else: + async with session.request( + method.upper(), + url, + json=parameters, + headers=headers, + timeout=aiohttp.ClientTimeout(total=custom_config.timeout or 30) + ) as response: + result_data = await response.text() + return { + "success": True, + "message": "测试成功", + "status_code": response.status, + "response_data": result_data[:1000] # 限制返回数据长度 + } + + except Exception as e: + logger.error(f"测试自定义工具API失败: {tool_id}, 错误: {e}") + return {"success": False, "message": f"测试失败: {str(e)}"} + + async def _test_builtin_connection(self, config: ToolConfig) -> Dict[str, Any]: + """测试内置工具连接""" + try: + # 获取工具实例 + tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + if not tool_instance: + return {"success": False, "message": "无法创建工具实例"} + + # 检查工具是否有test_connection方法 + if hasattr(tool_instance, 'test_connection'): + result = await tool_instance.test_connection() + return result + else: + # 检查是否需要配置 + builtin_config = self.builtin_repo.find_by_tool_id(self.db, config.id) + if builtin_config and builtin_config.requires_config: + # 检查必需参数是否已配置 + if self._is_tool_configured(builtin_config): + return {"success": True, "message": "内置工具已正确配置"} + else: + return {"success": False, "message": "工具缺少必需配置参数"} + else: + return {"success": True, "message": "内置工具无需连接测试"} + + except Exception as e: + logger.error(f"测试内置工具失败: {config.id}, 错误: {e}") + return {"success": False, "message": f"测试失败: {str(e)}"} diff --git a/api/test_tool_system.py b/api/test_tool_system.py deleted file mode 100644 index 30d60d23..00000000 --- a/api/test_tool_system.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python3 -""" -工具管理系统基础测试脚本 -用于验证系统的基本功能是否正常 -""" - -import asyncio -import uuid -from datetime import datetime - -# 测试导入 -def test_imports(): - """测试模块导入""" - print("测试模块导入...") - - try: - from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType - print("✓ 基础工具模块导入成功") - except ImportError as e: - print(f"✗ 基础工具模块导入失败: {e}") - return False - - try: - from app.core.tools.builtin.datetime_tool import DateTimeTool - from app.core.tools.builtin.json_tool import JsonTool - print("✓ 内置工具模块导入成功") - except ImportError as e: - print(f"✗ 内置工具模块导入失败: {e}") - return False - - try: - from app.core.tools.langchain_adapter import LangchainAdapter - print("✓ Langchain适配器导入成功") - except ImportError as e: - print(f"✗ Langchain适配器导入失败: {e}") - return False - - try: - from app.models.tool_model import ToolConfig, ToolType, ToolStatus - print("✓ 工具模型导入成功") - except ImportError as e: - print(f"✗ 工具模型导入失败: {e}") - return False - - try: - from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager - print("✓ 自定义工具模块导入成功") - except ImportError as e: - print(f"✗ 自定义工具模块导入失败: {e}") - return False - - try: - from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager - print("✓ MCP工具模块导入成功") - except ImportError as e: - print(f"✗ MCP工具模块导入失败: {e}") - return False - - return True - - -def test_tool_creation(): - """测试工具创建""" - print("\n测试工具创建...") - - try: - from app.core.tools.builtin.datetime_tool import DateTimeTool - - # 创建时间工具实例(全局工具) - tool_id = str(uuid.uuid4()) - config = { - "parameters": {"timezone": "UTC"}, - "tenant_id": None, # 全局工具 - "version": "1.0.0", - "tags": ["time", "utility", "builtin"] - } - - datetime_tool = DateTimeTool(tool_id, config) - - # 验证工具属性 - assert datetime_tool.name == "datetime_tool" - assert datetime_tool.tool_type.value == "builtin" - assert len(datetime_tool.parameters) > 0 - - print("✓ 时间工具创建成功(全局工具)") - return True - - except Exception as e: - print(f"✗ 工具创建失败: {e}") - return False - - -async def test_tool_execution(): - """测试工具执行""" - print("\n测试工具执行...") - - try: - from app.core.tools.builtin.datetime_tool import DateTimeTool - - # 创建时间工具实例 - tool_id = str(uuid.uuid4()) - config = { - "parameters": {"timezone": "UTC"}, - "tenant_id": None, # 全局工具 - "version": "1.0.0" - } - - datetime_tool = DateTimeTool(tool_id, config) - - # 测试获取当前时间 - result = await datetime_tool.safe_execute(operation="now") - - assert result.success == True - assert "datetime" in result.data - assert result.execution_time > 0 - - print("✓ 工具执行成功") - print(f" 执行时间: {result.execution_time:.3f}秒") - print(f" 返回数据: {result.data}") - - return True - - except Exception as e: - print(f"✗ 工具执行失败: {e}") - return False - - -def test_langchain_adapter(): - """测试Langchain适配器""" - print("\n测试Langchain适配器...") - - try: - from app.core.tools.builtin.json_tool import JsonTool - from app.core.tools.langchain_adapter import LangchainAdapter - - # 创建JSON工具实例 - tool_id = str(uuid.uuid4()) - config = { - "parameters": {"indent": 2}, - "tenant_id": None, # 全局工具 - "version": "1.0.0" - } - - json_tool = JsonTool(tool_id, config) - - # 验证Langchain兼容性 - is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool) - - if not is_compatible: - print(f"✗ Langchain兼容性验证失败: {issues}") - return False - - # 创建工具描述 - description = LangchainAdapter.create_tool_description(json_tool) - - assert "name" in description - assert "parameters" in description - assert description["langchain_compatible"] == True - - print("✓ Langchain适配器测试成功") - return True - - except Exception as e: - print(f"✗ Langchain适配器测试失败: {e}") - return False - - -def test_config_manager(): - """测试配置管理器""" - print("\n测试配置管理器...") - - try: - from app.core.tools.config_manager import ConfigManager - - # 创建配置管理器 - config_manager = ConfigManager() - - # 获取配置摘要 - summary = config_manager.get_config_summary() - - assert "config_dir" in summary - assert "total_configs" in summary - - print("✓ 配置管理器测试成功") - print(f" 配置目录: {summary['config_dir']}") - print(f" 总配置数: {summary['total_configs']}") - - return True - - except Exception as e: - print(f"✗ 配置管理器测试失败: {e}") - return False - - -def test_schema_parser(): - """测试OpenAPI Schema解析器""" - print("\n测试OpenAPI Schema解析器...") - - try: - from app.core.tools.custom.schema_parser import OpenAPISchemaParser - - # 创建解析器 - parser = OpenAPISchemaParser() - - # 测试简单的OpenAPI schema - test_schema = { - "openapi": "3.0.0", - "info": { - "title": "Test API", - "version": "1.0.0", - "description": "测试API" - }, - "paths": { - "/test": { - "get": { - "summary": "测试接口", - "operationId": "test_operation", - "responses": { - "200": { - "description": "成功" - } - } - } - } - } - } - - # 验证schema - is_valid, error_msg = parser.validate_schema(test_schema) - assert is_valid, f"Schema验证失败: {error_msg}" - - # 提取工具信息 - tool_info = parser.extract_tool_info(test_schema) - assert tool_info["name"] == "Test API" - assert "test_operation" in tool_info["operations"] - - print("✓ OpenAPI Schema解析器测试成功") - return True - - except Exception as e: - print(f"✗ OpenAPI Schema解析器测试失败: {e}") - return False - - -def test_auth_manager(): - """测试认证管理器""" - print("\n测试认证管理器...") - - try: - from app.core.tools.custom.auth_manager import AuthManager - from app.models.tool_model import AuthType - - # 创建认证管理器 - auth_manager = AuthManager() - - # 测试API Key认证配置 - api_key_config = { - "api_key": "test-key-123", - "key_name": "X-API-Key", - "location": "header" - } - - is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config) - assert is_valid, f"API Key配置验证失败: {error_msg}" - - # 测试Bearer Token认证配置 - bearer_config = { - "token": "bearer-token-123" - } - - is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config) - assert is_valid, f"Bearer Token配置验证失败: {error_msg}" - - # 测试认证应用 - url = "https://api.example.com/test" - headers = {} - params = {} - - new_url, new_headers, new_params = auth_manager.apply_authentication( - AuthType.API_KEY, api_key_config, url, headers, params - ) - - assert "X-API-Key" in new_headers - assert new_headers["X-API-Key"] == "test-key-123" - - print("✓ 认证管理器测试成功") - return True - - except Exception as e: - print(f"✗ 认证管理器测试失败: {e}") - return False - - -def test_builtin_initializer(): - """测试内置工具初始化器""" - print("\n测试内置工具初始化器...") - - try: - from app.core.tools.builtin_initializer import BuiltinToolInitializer - - # 注意:这里不能真正初始化,因为需要数据库连接 - # 只测试类的创建和基本方法 - - # 模拟数据库会话(实际使用中需要真实的数据库连接) - class MockDB: - def query(self, *args): - return self - def filter(self, *args): - return self - def first(self): - return None - def all(self): - return [] - - mock_db = MockDB() - initializer = BuiltinToolInitializer(mock_db) - - # 测试获取内置工具状态(会返回空列表,因为没有真实数据) - status = initializer.get_builtin_tools_status() - assert isinstance(status, list) - - print("✓ 内置工具初始化器测试成功") - return True - - except Exception as e: - print(f"✗ 内置工具初始化器测试失败: {e}") - return False - - -async def main(): - """主测试函数""" - print("=" * 50) - print("工具管理系统基础测试") - print("=" * 50) - - tests = [ - ("模块导入", test_imports), - ("工具创建", test_tool_creation), - ("工具执行", test_tool_execution), - ("Langchain适配", test_langchain_adapter), - ("配置管理", test_config_manager), - ("Schema解析器", test_schema_parser), - ("认证管理器", test_auth_manager), - ("内置工具初始化器", test_builtin_initializer) - ] - - passed = 0 - total = len(tests) - - for test_name, test_func in tests: - try: - if asyncio.iscoroutinefunction(test_func): - result = await test_func() - else: - result = test_func() - - if result: - passed += 1 - except Exception as e: - print(f"✗ {test_name}测试异常: {e}") - - print("\n" + "=" * 50) - print(f"测试结果: {passed}/{total} 通过") - - if passed == total: - print("🎉 所有基础测试通过!工具管理系统基本功能正常。") - return True - else: - print("⚠️ 部分测试失败,请检查相关模块。") - return False - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file