diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 27f65b1d..5cfbe536 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -32,6 +32,8 @@ from . import ( emotion_controller, emotion_config_controller, prompt_optimizer_controller, + tool_controller, + tool_execution_controller, ) # 创建管理端 API 路由器 @@ -66,4 +68,7 @@ manager_router.include_router(emotion_controller.router) manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) +manager_router.include_router(tool_controller.router) +manager_router.include_router(tool_execution_controller.router) + __all__ = ["manager_router"] diff --git a/api/app/controllers/tool_controller.py b/api/app/controllers/tool_controller.py new file mode 100644 index 00000000..433392d2 --- /dev/null +++ b/api/app/controllers/tool_controller.py @@ -0,0 +1,585 @@ +"""工具管理API控制器""" +import base64 +from typing import List, Optional, Dict, Any + +from fastapi import APIRouter, Depends, HTTPException, Body +from langfuse.api.core import jsonable_encoder +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field, PositiveInt, field_validator +from cryptography.fernet import Fernet + +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.models.tool_model import ToolConfig, BuiltinToolConfig, ToolType, ToolStatus, CustomToolConfig, MCPToolConfig +from app.core.logging_config import get_business_logger +from app.core.config import settings +from app.core.tools.config_manager import ConfigManager + +logger = get_business_logger() + +router = APIRouter(prefix="/tools", tags=["工具管理"]) + + +# ==================== 辅助函数 ==================== + + +def _encrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: + """加密敏感参数""" + 
cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) + cipher = Fernet(cipher_key) + + encrypted_params = {} + sensitive_keys = ['api_key', 'token', 'api_secret', 'password'] + + for key, value in parameters.items(): + if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: + encrypted_params[key] = cipher.encrypt(str(value).encode()).decode() + else: + encrypted_params[key] = value + + return encrypted_params + + +def _decrypt_sensitive_params(parameters: Dict[str, Any]) -> Dict[str, Any]: + """解密敏感参数""" + cipher_key = base64.urlsafe_b64encode(settings.SECRET_KEY[:32].ljust(32, '0').encode()) + cipher = Fernet(cipher_key) + + decrypted_params = {} + sensitive_keys = ['api_key', 'token', 'secret', 'password'] + + for key, value in parameters.items(): + if any(sensitive in key.lower() for sensitive in sensitive_keys) and value: + try: + decrypted_params[key] = cipher.decrypt(value.encode()).decode() + except Exception as e: + decrypted_params[key] = value + else: + decrypted_params[key] = value + + return decrypted_params + + +def _update_tool_status(tool_config: ToolConfig, builtin_config: BuiltinToolConfig = None, tool_info: Dict = None) -> str: + """更新工具状态并返回新状态""" + if tool_config.tool_type == ToolType.BUILTIN: + if not tool_info or not tool_info.get('requires_config', False): + new_status = ToolStatus.ACTIVE.value # 不需要配置的内置工具 + elif not builtin_config or not builtin_config.parameters: + new_status = ToolStatus.INACTIVE.value + else: + # 检查是否有必要的API密钥 + has_key = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) + new_status = ToolStatus.ACTIVE.value if has_key else ToolStatus.INACTIVE.value + else: # 自定义和MCP工具 + new_status = ToolStatus.ACTIVE.value if tool_config.config_data else ToolStatus.ERROR.value + + # 更新数据库中的状态 + if tool_config.status != new_status: + tool_config.status = new_status + + return new_status + + +# ==================== 请求/响应模型 ==================== 
+ +class ToolListResponse(BaseModel): + """工具列表响应""" + id: str + name: str + description: str + tool_type: str + category: str + version: str = "1.0.0" + status: str # active inactive error loading + requires_config: bool = False + # is_configured: bool = False + + class Config: + from_attributes = True + +class BuiltinToolConfigRequest(BaseModel): + """内置工具配置请求""" + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + + +class CustomToolCreateRequest(BaseModel): + """自定义工具创建请求体模型,包含参数校验规则""" + name: str = Field(..., min_length=1, max_length=100, description="工具名称,必填") + description: str = Field(None, description="工具描述") + base_url: str = Field(None, description="工具基础URL") + schema_url: str = Field(None, description="工具Schema URL") + schema_content: Optional[Dict[str, Any]] = Field(None, description="工具Schema内容,可选") + auth_type: str = Field("none", pattern=r"^(none|api_key|bearer_token)$", description="认证类型") + auth_config: Optional[Dict[str, Any]] = Field(None, description="认证配置,默认空字典") + timeout: PositiveInt = Field(30, ge=1, le=300, description="超时时间,1-300秒,默认30") + + # 自定义校验:当auth_type为api_key时,auth_config必须包含api_key字段 + @field_validator("auth_config") + def validate_auth_config(cls, v, values): + auth_type = values.data.get("auth_type") + if auth_type == "api_key" and (not v or "api_key" not in v): + raise ValueError("认证类型为api_key时,auth_config必须包含api_key字段") + if auth_type == "bearer_token" and (not v or "bearer_token" not in v): + raise ValueError("认证类型为bearer_token时,auth_config必须包含bearer_token字段") + return v + +class MCPToolCreateRequest(BaseModel): + """MCP工具创建请求体模型,适配MCP业务特性""" + # 基础必填字段(带长度/格式校验) + name: str = Field(..., min_length=1, max_length=100,description="MCP工具名称") + description: str = Field(None, description="MCP工具描述") + # MCP核心字段:服务端URL(强制HTTP/HTTPS格式) + server_url: str = Field(..., description="MCP服务端URL,仅支持http/https协议") + # 连接配置:默认空字典,可自定义校验规则(根据实际业务调整) + connection_config: Dict[str, Any] = 
Field({},description="MCP连接配置(如认证信息、超时、重试等),默认空字典") + + @field_validator("connection_config") + def validate_connection_config(cls, v): + # 示例1:若包含timeout,必须是1-300的整数 + if "timeout" in v: + timeout = v["timeout"] + if not isinstance(timeout, int) or timeout < 1 or timeout > 300: + raise ValueError("connection_config.timeout必须是1-300的整数") + return v + + # @field_validator("server_url") + # def validate_server_url_protocol(cls, v): + # if v.scheme != "https": + # raise ValueError("MCP服务端URL仅支持HTTPS协议(安全要求)") + # return v + + +# ==================== API端点 ==================== +@router.get("", response_model=List[ToolListResponse]) +async def list_tools( + name: Optional[str] = None, + tool_type: Optional[str] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取工具列表(包含内置工具、自定义工具和MCP工具)""" + try: + # 初始化内置工具(如果需要) + config_manager = ConfigManager() + config_manager.ensure_builtin_tools_initialized( + current_user.tenant_id, db, ToolConfig, BuiltinToolConfig, ToolType, ToolStatus + ) + + response_tools = [] + + query = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id + ) + if tool_type: + query = query.filter(ToolConfig.tool_type == tool_type) + + if name: + query = query.filter(ToolConfig.name.ilike(f"%{name}%")) + + tools = query.all() + builtin_tools = config_manager.load_builtin_tools_config() + configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} + + for tool_config in tools: + if tool_config.tool_type == ToolType.BUILTIN.value: + builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() + tool_info = configured_tools.get(builtin_config.tool_class) + status = _update_tool_status(tool_config, builtin_config, tool_info) + else: + status = _update_tool_status(tool_config) + + response_tools.append(ToolListResponse( + id=str(tool_config.id), + name=tool_config.name, + 
description=tool_config.description, + tool_type=tool_config.tool_type, + category=tool_info['category'] if tool_config.tool_type == ToolType.BUILTIN.value else tool_config.tool_type, + version="1.0.0", + status=status, + requires_config=tool_info['requires_config'] if tool_config.tool_type == ToolType.BUILTIN.value else False, + )) + + return response_tools + except Exception as e: + logger.error(f"获取工具列表失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/builtin/{tool_id}") +async def get_builtin_tool_detail( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取内置工具详情""" + try: + config_manager = ConfigManager() + builtin_tools = config_manager.load_builtin_tools_config() + configured_tools = {tool_info["tool_class"]: tool_info for tool_key, tool_info in builtin_tools.items()} + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id + ).first() + builtin_config = db.query(BuiltinToolConfig).filter(BuiltinToolConfig.id == tool_config.id).first() + tool_info = configured_tools.get(builtin_config.tool_class) + + is_configured = False + config_parameters = {} + + if builtin_config and builtin_config.parameters: + is_configured = bool(builtin_config.parameters.get('api_key') or builtin_config.parameters.get('token')) + # 不返回敏感信息,只返回非敏感配置 + config_parameters = {k: v for k, v in builtin_config.parameters.items() + if not any(sensitive in k.lower() for sensitive in ['key', 'secret', 'token', 'password'])} + + return { + "id": tool_config.id, + "name": tool_config.name, + "description": tool_config.description, + "category": tool_info['category'], + "status": tool_config.tool_type, + "requires_config": tool_info['requires_config'], + "is_configured": is_configured, + "config_parameters": config_parameters + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取工具详情失败: {tool_id}, 错误: {e}") + raise 
HTTPException(status_code=500, detail=str(e)) + + +@router.post("/builtin/{tool_id}/configure") +async def configure_builtin_tool( + tool_id: str, + request: BuiltinToolConfigRequest = Body(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """配置内置工具参数(租户级别)""" + try: + # 查询工具配置 + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id, + ToolConfig.tool_type == ToolType.BUILTIN + ).first() + + if not tool_config: + raise HTTPException(status_code=404, detail="工具不存在") + + # 获取内置工具配置 + builtin_config = db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if not builtin_config: + raise HTTPException(status_code=404, detail="内置工具配置不存在") + + # 获取全局工具信息 + config_manager = ConfigManager() + builtin_tools_config = config_manager.load_builtin_tools_config() + tool_info = None + for tool_key, info in builtin_tools_config.items(): + if info['tool_class'] == builtin_config.tool_class: + tool_info = info + break + + if not tool_info: + raise HTTPException(status_code=404, detail="工具信息不存在") + + # 加密敏感参数 + encrypted_params = _encrypt_sensitive_params(request.parameters) + + # 更新配置 + builtin_config.parameters = encrypted_params + + # 更新状态 + _update_tool_status(tool_config, builtin_config, tool_info) + + db.commit() + + return { + "success": True, + "message": f"工具 {tool_config.name} 配置成功" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"配置内置工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/builtin/{tool_id}/config") +async def get_builtin_tool_config( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """获取内置工具配置(用于使用)""" + try: + # 查询工具配置 + tool_config = db.query(ToolConfig).filter( + ToolConfig.tenant_id == current_user.tenant_id, + ToolConfig.id == tool_id, + ToolConfig.tool_type == ToolType.BUILTIN + ).first() + + 
if not tool_config: + raise HTTPException(status_code=404, detail="工具不存在") + + # 获取内置工具配置 + builtin_config = db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if not builtin_config: + raise HTTPException(status_code=404, detail="内置工具配置不存在") + + # 解密参数 + decrypted_params = _decrypt_sensitive_params(builtin_config.parameters or {}) + + return { + "tool_id": tool_id, + "tool_class": builtin_config.tool_class, + "name": tool_config.name, + "parameters": decrypted_params, + "status": tool_config.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"获取工具配置失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/custom") +async def create_custom_tool( + request: CustomToolCreateRequest = Body(...), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """创建自定义工具""" + try: + config_data = jsonable_encoder(request.model_dump()) + config_data["tool_type"] = "custom" + + config_manager = ConfigManager() + is_valid, error_msg = config_manager.validate_config(config_data, "custom") + if not is_valid: + raise HTTPException(status_code=400, detail=error_msg) + + # 创建数据库记录 + tool_config = ToolConfig( + name=request.name, + description=request.description, + tool_type=ToolType.CUSTOM, + tenant_id=current_user.tenant_id, + status=ToolStatus.ACTIVE.value, + config_data=config_data + ) + db.add(tool_config) + db.flush() + + # 创建CustomToolConfig记录 + custom_config = CustomToolConfig( + id=tool_config.id, + base_url=request.base_url, + schema_url=request.schema_url, + schema_content=request.schema_content, + auth_type=request.auth_type, + auth_config=request.auth_config, + timeout=request.timeout + ) + db.add(custom_config) + + db.commit() + + return { + "success": True, + "message": f"自定义工具 {request.name} 创建成功", + "tool_id": str(tool_config.id) + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"创建自定义工具失败: {e}") + raise 
HTTPException(status_code=500, detail=str(e)) + +@router.post("/mcp") +async def create_mcp_tool( + request: MCPToolCreateRequest = Body(..., description="MCP工具创建参数"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """创建MCP工具""" + try: + config_data = jsonable_encoder(request.model_dump()) + config_data["tool_type"] = "mcp" + + config_manager = ConfigManager() + is_valid, error_msg = config_manager.validate_config(config_data, "mcp") + if not is_valid: + raise HTTPException(status_code=400, detail=error_msg) + + # 创建数据库记录 + try: + tool_config = ToolConfig( + name=request.name, + description=request.description, + tool_type=ToolType.MCP, + tenant_id=current_user.tenant_id, + status=ToolStatus.ACTIVE.value, + config_data=config_data + ) + db.add(tool_config) + db.flush() + + # 创建MCPToolConfig记录 + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=request.server_url, + connection_config=request.connection_config + ) + db.add(mcp_config) + + db.commit() + except SQLAlchemyError as db_e: + db.rollback() + logger.error(f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id},工具名:{request.name}): {str(db_e)}", + exc_info=True) + raise HTTPException(status_code=500, detail=f"创建MCP工具数据库操作失败(租户ID:{current_user.tenant_id}," + f"工具名:{request.name}):{str(db_e)}") + + return { + "success": True, + "message": f"MCP工具 {request.name} 创建成功", + "tool_id": str(tool_config.id) + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"创建MCP工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@router.delete("/{tool_id}") +async def delete_tool( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """删除工具(仅限自定义和MCP工具)""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + if tool.tool_type == 
ToolType.BUILTIN: + raise HTTPException(status_code=403, detail="内置工具不允许删除") + + db.delete(tool) + db.commit() + + return { + "success": True, + "message": f"工具 {tool.name} 删除成功" + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"删除工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.put("/{tool_id}") +async def update_tool( + tool_id: str, + config_data: Optional[Dict[str, Any]] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """更新工具(仅限自定义和MCP工具)""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + if tool.tool_type == ToolType.BUILTIN: + raise HTTPException(status_code=403, detail="内置工具不允许修改") + + if config_data is not None: + tool.config_data = config_data + # 更新状态 + _update_tool_status(tool) + + db.commit() + db.refresh(tool) + + return { + "success": True, + "message": f"工具 {tool.name} 更新成功", + "status": tool.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"更新工具失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/{tool_id}/toggle") +async def toggle_tool_status( + tool_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """切换工具活跃/非活跃状态""" + try: + tool = db.query(ToolConfig).filter( + ToolConfig.id == tool_id, + ToolConfig.tenant_id == current_user.tenant_id + ).first() + + if not tool: + raise HTTPException(status_code=404, detail="工具不存在") + + # 在active和inactive之间切换 + if tool.status == ToolStatus.ACTIVE.value: + tool.status = ToolStatus.INACTIVE.value + elif tool.status == ToolStatus.INACTIVE.value: + tool.status = ToolStatus.ACTIVE.value + else: + raise HTTPException(status_code=400, detail="只有可用或非活跃状态的工具可以切换") + + db.commit() + db.refresh(tool) + + return { + "success": True, + 
"message": f"工具 {tool.name} 状态已更新为 {tool.status}", + "status": tool.status + } + + except HTTPException: + raise + except Exception as e: + logger.error(f"切换工具状态失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file diff --git a/api/app/controllers/tool_execution_controller.py b/api/app/controllers/tool_execution_controller.py new file mode 100644 index 00000000..486eb7cf --- /dev/null +++ b/api/app/controllers/tool_execution_controller.py @@ -0,0 +1,430 @@ +"""工具执行API控制器""" +import uuid +from typing import Dict, Any, List, Optional +from fastapi import APIRouter, Depends, HTTPException, Path, Query +from sqlalchemy.orm import Session +from pydantic import BaseModel, Field + +from app.db import get_db +from app.dependencies import get_current_user +from app.models import User +from app.core.tools.registry import ToolRegistry +from app.core.tools.executor import ToolExecutor +from app.core.tools.chain_manager import ChainManager, ChainDefinition, ChainStep, ChainExecutionMode +from app.core.tools.builtin import * +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + +router = APIRouter(prefix="/tools/execution", tags=["工具执行"]) + + +# ==================== 请求/响应模型 ==================== + +class ToolExecutionRequest(BaseModel): + """工具执行请求""" + tool_id: str = Field(..., description="工具ID") + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + timeout: Optional[float] = Field(None, ge=1, le=300, description="超时时间(秒)") + metadata: Optional[Dict[str, Any]] = Field(None, description="额外元数据") + + +class BatchExecutionRequest(BaseModel): + """批量执行请求""" + executions: List[ToolExecutionRequest] = Field(..., description="执行列表") + max_concurrency: int = Field(5, ge=1, le=20, description="最大并发数") + + +class ToolExecutionResponse(BaseModel): + """工具执行响应""" + success: bool + execution_id: str + tool_id: str + data: Any = None + error: Optional[str] = None + error_code: 
Optional[str] = None + execution_time: float + token_usage: Optional[Dict[str, int]] = None + metadata: Dict[str, Any] = Field(default_factory=dict) + + +class ChainStepRequest(BaseModel): + """链步骤请求""" + tool_id: str = Field(..., description="工具ID") + parameters: Dict[str, Any] = Field(default_factory=dict, description="工具参数") + condition: Optional[str] = Field(None, description="执行条件") + output_mapping: Optional[Dict[str, str]] = Field(None, description="输出映射") + error_handling: str = Field("stop", description="错误处理策略") + + +class ChainExecutionRequest(BaseModel): + """链执行请求""" + name: str = Field(..., description="链名称") + description: str = Field("", description="链描述") + steps: List[ChainStepRequest] = Field(..., description="执行步骤") + execution_mode: str = Field("sequential", description="执行模式") + initial_variables: Optional[Dict[str, Any]] = Field(None, description="初始变量") + global_timeout: Optional[float] = Field(None, description="全局超时") + + +class ExecutionHistoryResponse(BaseModel): + """执行历史响应""" + execution_id: str + tool_id: str + status: str + started_at: Optional[str] + completed_at: Optional[str] + execution_time: Optional[float] + user_id: Optional[str] + workspace_id: Optional[str] + input_data: Optional[Dict[str, Any]] + output_data: Optional[Any] + error_message: Optional[str] + token_usage: Optional[Dict[str, int]] + + +class ToolConnectionTestResponse(BaseModel): + """工具连接测试响应""" + success: bool + message: str + error: Optional[str] = None + details: Optional[Dict[str, Any]] = None + + +# ==================== 依赖注入 ==================== + +def get_tool_registry(db: Session = Depends(get_db)) -> ToolRegistry: + """获取工具注册表""" + registry = ToolRegistry(db) + + # 注册内置工具类 + registry.register_tool_class(DateTimeTool) + registry.register_tool_class(JsonTool) + registry.register_tool_class(BaiduSearchTool) + registry.register_tool_class(MinerUTool) + registry.register_tool_class(TextInTool) + + return registry + + +def get_tool_executor( + db: Session = 
Depends(get_db), + registry: ToolRegistry = Depends(get_tool_registry) +) -> ToolExecutor: + """获取工具执行器""" + return ToolExecutor(db, registry) + + +def get_chain_manager(executor: ToolExecutor = Depends(get_tool_executor)) -> ChainManager: + """获取链管理器""" + return ChainManager(executor) + + +# ==================== API端点 ==================== + +@router.post("/execute", response_model=ToolExecutionResponse) +async def execute_tool( + request: ToolExecutionRequest, + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """执行单个工具""" + try: + # 生成执行ID + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + + # 执行工具 + result = await executor.execute_tool( + tool_id=request.tool_id, + parameters=request.parameters, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + execution_id=execution_id, + timeout=request.timeout, + metadata=request.metadata + ) + + return ToolExecutionResponse( + success=result.success, + execution_id=execution_id, + tool_id=request.tool_id, + data=result.data, + error=result.error, + error_code=result.error_code, + execution_time=result.execution_time, + token_usage=result.token_usage, + metadata=result.metadata + ) + + except Exception as e: + logger.error(f"工具执行失败: {request.tool_id}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/batch", response_model=List[ToolExecutionResponse]) +async def execute_tools_batch( + request: BatchExecutionRequest, + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """批量执行工具""" + try: + # 准备执行配置 + execution_configs = [] + execution_ids = [] + + for exec_request in request.executions: + execution_id = f"exec_{uuid.uuid4().hex[:16]}" + execution_ids.append(execution_id) + + execution_configs.append({ + "tool_id": exec_request.tool_id, + "parameters": exec_request.parameters, + "user_id": current_user.id, + "workspace_id": 
current_user.current_workspace_id, + "execution_id": execution_id, + "timeout": exec_request.timeout, + "metadata": exec_request.metadata + }) + + # 批量执行 + results = await executor.execute_tools_batch( + execution_configs, + max_concurrency=request.max_concurrency + ) + + # 转换响应格式 + responses = [] + for i, result in enumerate(results): + responses.append(ToolExecutionResponse( + success=result.success, + execution_id=execution_ids[i], + tool_id=request.executions[i].tool_id, + data=result.data, + error=result.error, + error_code=result.error_code, + execution_time=result.execution_time, + token_usage=result.token_usage, + metadata=result.metadata + )) + + return responses + + except Exception as e: + logger.error(f"批量执行失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/chain", response_model=Dict[str, Any]) +async def execute_tool_chain( + request: ChainExecutionRequest, + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """执行工具链""" + try: + # 转换步骤格式 + steps = [] + for step_request in request.steps: + step = ChainStep( + tool_id=step_request.tool_id, + parameters=step_request.parameters, + condition=step_request.condition, + output_mapping=step_request.output_mapping, + error_handling=step_request.error_handling + ) + steps.append(step) + + # 创建链定义 + chain_definition = ChainDefinition( + name=request.name, + description=request.description, + steps=steps, + execution_mode=ChainExecutionMode(request.execution_mode), + global_timeout=request.global_timeout + ) + + # 注册并执行链 + chain_manager.register_chain(chain_definition) + + result = await chain_manager.execute_chain( + chain_name=request.name, + initial_variables=request.initial_variables + ) + + return result + + except Exception as e: + logger.error(f"工具链执行失败: {request.name}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/running", response_model=List[Dict[str, Any]]) +async def 
get_running_executions( + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取正在运行的执行""" + try: + running_executions = executor.get_running_executions() + + # 过滤当前工作空间的执行 + workspace_executions = [ + exec_info for exec_info in running_executions + if exec_info.get("workspace_id") == str(current_user.current_workspace_id) + ] + + return workspace_executions + + except Exception as e: + logger.error(f"获取运行中执行失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.delete("/cancel/{execution_id}", response_model=Dict[str, Any]) +async def cancel_execution( + execution_id: str = Path(..., description="执行ID"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """取消工具执行""" + try: + success = await executor.cancel_execution(execution_id) + + if success: + return { + "success": True, + "message": "执行已取消" + } + else: + raise HTTPException(status_code=404, detail="执行不存在或已完成") + + except HTTPException: + raise + except Exception as e: + logger.error(f"取消执行失败: {execution_id}, 错误: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/history", response_model=List[ExecutionHistoryResponse]) +async def get_execution_history( + tool_id: Optional[str] = Query(None, description="工具ID过滤"), + limit: int = Query(50, ge=1, le=200, description="返回数量限制"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取执行历史""" + try: + history = executor.get_execution_history( + tool_id=tool_id, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + limit=limit + ) + + # 转换响应格式 + responses = [] + for record in history: + responses.append(ExecutionHistoryResponse( + execution_id=record["execution_id"], + tool_id=record["tool_id"], + status=record["status"], + started_at=record["started_at"], + completed_at=record["completed_at"], + 
execution_time=record["execution_time"], + user_id=record["user_id"], + workspace_id=record["workspace_id"], + input_data=record["input_data"], + output_data=record["output_data"], + error_message=record["error_message"], + token_usage=record["token_usage"] + )) + + return responses + + except Exception as e: + logger.error(f"获取执行历史失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/statistics", response_model=Dict[str, Any]) +async def get_execution_statistics( + days: int = Query(7, ge=1, le=90, description="统计天数"), + current_user: User = Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """获取执行统计""" + try: + stats = executor.get_execution_statistics( + workspace_id=current_user.current_workspace_id, + days=days + ) + + return { + "success": True, + "statistics": stats + } + + except Exception as e: + logger.error(f"获取执行统计失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/chains/running", response_model=List[Dict[str, Any]]) +async def get_running_chains( + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """获取正在运行的工具链""" + try: + running_chains = chain_manager.get_running_chains() + return running_chains + + except Exception as e: + logger.error(f"获取运行中工具链失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.get("/chains", response_model=List[Dict[str, Any]]) +async def list_tool_chains( + current_user: User = Depends(get_current_user), + chain_manager: ChainManager = Depends(get_chain_manager) +): + """列出工具链""" + try: + chains = chain_manager.list_chains() + return chains + + except Exception as e: + logger.error(f"获取工具链列表失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@router.post("/test-connection/{tool_id}", response_model=ToolConnectionTestResponse) +async def test_tool_connection( + tool_id: str = Path(..., description="工具ID"), + current_user: User = 
Depends(get_current_user), + executor: ToolExecutor = Depends(get_tool_executor) +): + """测试工具连接""" + try: + result = await executor.test_tool_connection( + tool_id=tool_id, + user_id=current_user.id, + workspace_id=current_user.current_workspace_id + ) + + return ToolConnectionTestResponse( + success=result.get("success", False), + message=result.get("message", ""), + error=result.get("error"), + details=result.get("details") + ) + + except Exception as e: + logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}") + return ToolConnectionTestResponse( + success=False, + message="连接测试失败", + error=str(e) + ) \ No newline at end of file diff --git a/api/app/core/api_key_auth.py b/api/app/core/api_key_auth.py index d90bb00d..e1021c6f 100644 --- a/api/app/core/api_key_auth.py +++ b/api/app/core/api_key_auth.py @@ -37,9 +37,10 @@ def require_api_key( @require_api_key(scopes=["app"]) def chat_with_app( resource_id: uuid.UUID, - api_key_auth: ApiKeyAuth = Depends(), + request: Request, + api_key_auth: ApiKeyAuth = None, db: Session = Depends(get_db), - message: str + message: str = Query(..., description="聊天消息内容") ): # api_key_auth 包含验证后的API Key 信息 pass diff --git a/api/app/core/config.py b/api/app/core/config.py index 41e9f0cf..bf5ff45a 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -157,6 +157,12 @@ class Settings: MEMORY_RUNTIME_FILE: str = os.getenv("MEMORY_RUNTIME_FILE", "runtime.json") MEMORY_DBRUN_FILE: str = os.getenv("MEMORY_DBRUN_FILE", "dbrun.json") + # Tool Management Configuration + TOOL_CONFIG_DIR: str = os.getenv("TOOL_CONFIG_DIR", "app/core/tools") + TOOL_EXECUTION_TIMEOUT: int = int(os.getenv("TOOL_EXECUTION_TIMEOUT", "60")) + TOOL_MAX_CONCURRENCY: int = int(os.getenv("TOOL_MAX_CONCURRENCY", "10")) + ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true" + def get_memory_output_path(self, filename: str = "") -> str: """ Get the full path for memory module output files. 
diff --git a/api/app/core/tools/__init__.py b/api/app/core/tools/__init__.py new file mode 100644 index 00000000..109bac13 --- /dev/null +++ b/api/app/core/tools/__init__.py @@ -0,0 +1,37 @@ +"""工具管理核心模块""" + +from .base import BaseTool, ToolResult, ToolParameter +from .registry import ToolRegistry +from .executor import ToolExecutor +from .langchain_adapter import LangchainAdapter +from .config_manager import ConfigManager +from .chain_manager import ChainManager + +# 可选导入,避免导入错误 +try: + from .custom.base import CustomTool +except ImportError: + CustomTool = None + +try: + from .mcp.base import MCPTool +except ImportError: + MCPTool = None + +__all__ = [ + "BaseTool", + "ToolResult", + "ToolParameter", + "ToolRegistry", + "ToolExecutor", + "LangchainAdapter", + "ConfigManager", + "ChainManager" +] + +# 只有在成功导入时才添加到__all__ +if CustomTool: + __all__.append("CustomTool") + +if MCPTool: + __all__.append("MCPTool") \ No newline at end of file diff --git a/api/app/core/tools/base.py b/api/app/core/tools/base.py new file mode 100644 index 00000000..d674af76 --- /dev/null +++ b/api/app/core/tools/base.py @@ -0,0 +1,302 @@ +"""工具基础接口定义""" +import time +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel, Field +from enum import Enum + +from app.models.tool_model import ToolType, ToolStatus + + +class ParameterType(str, Enum): + """参数类型枚举""" + STRING = "string" + INTEGER = "integer" + NUMBER = "number" + BOOLEAN = "boolean" + ARRAY = "array" + OBJECT = "object" + + +class ToolParameter(BaseModel): + """工具参数定义""" + name: str = Field(..., description="参数名称") + type: ParameterType = Field(..., description="参数类型") + description: str = Field("", description="参数描述") + required: bool = Field(False, description="是否必需") + default: Any = Field(None, description="默认值") + enum: Optional[List[Any]] = Field(None, description="枚举值") + minimum: Optional[Union[int, float]] = Field(None, description="最小值") + maximum: 
Optional[Union[int, float]] = Field(None, description="最大值") + pattern: Optional[str] = Field(None, description="正则表达式模式") + + class Config: + use_enum_values = True + + +class ToolResult(BaseModel): + """工具执行结果""" + success: bool = Field(..., description="执行是否成功") + data: Any = Field(None, description="返回数据") + error: Optional[str] = Field(None, description="错误信息") + error_code: Optional[str] = Field(None, description="错误代码") + execution_time: float = Field(..., description="执行时间(秒)") + token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况") + metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据") + + @classmethod + def success_result( + cls, + data: Any, + execution_time: float, + token_usage: Optional[Dict[str, int]] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建成功结果""" + return cls( + success=True, + data=data, + execution_time=execution_time, + token_usage=token_usage, + metadata=metadata or {} + ) + + @classmethod + def error_result( + cls, + error: str, + execution_time: float, + error_code: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None + ) -> "ToolResult": + """创建错误结果""" + return cls( + success=False, + error=error, + error_code=error_code, + execution_time=execution_time, + metadata=metadata or {} + ) + + +class ToolInfo(BaseModel): + """工具信息""" + id: str = Field(..., description="工具ID") + name: str = Field(..., description="工具名称") + description: str = Field(..., description="工具描述") + tool_type: ToolType = Field(..., description="工具类型") + version: str = Field("1.0.0", description="工具版本") + parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") + status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态") + tags: List[str] = Field(default_factory=list, description="工具标签") + tenant_id: Optional[str] = Field(None, description="租户ID") + + class Config: + use_enum_values = True + + +class BaseTool(ABC): + """所有工具的基础抽象类""" + + def 
__init__(self, tool_id: str, config: Dict[str, Any]): + """初始化工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + self.tool_id = tool_id + self.config = config + self._status = ToolStatus.ACTIVE + + @property + @abstractmethod + def name(self) -> str: + """工具名称""" + pass + + @property + @abstractmethod + def description(self) -> str: + """工具描述""" + pass + + @property + @abstractmethod + def tool_type(self) -> ToolType: + """工具类型""" + pass + + @property + def version(self) -> str: + """工具版本""" + return self.config.get("version", "1.0.0") + + @property + def status(self) -> ToolStatus: + """工具状态""" + return self._status + + @status.setter + def status(self, value: ToolStatus): + """设置工具状态""" + self._status = value + + @property + @abstractmethod + def parameters(self) -> List[ToolParameter]: + """工具参数定义""" + pass + + @property + def tags(self) -> List[str]: + """工具标签""" + return self.config.get("tags", []) + + def get_info(self) -> ToolInfo: + """获取工具信息""" + return ToolInfo( + id=self.tool_id, + name=self.name, + description=self.description, + tool_type=self.tool_type, + version=self.version, + parameters=self.parameters, + status=self.status, + tags=self.tags, + tenant_id=self.config.get("tenant_id") + ) + + def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]: + """验证参数 + + Args: + parameters: 输入参数 + + Returns: + 验证错误字典,空字典表示验证通过 + """ + errors = {} + param_definitions = {p.name: p for p in self.parameters} + + # 检查必需参数 + for param_def in self.parameters: + if param_def.required and param_def.name not in parameters: + errors[param_def.name] = f"Required parameter '{param_def.name}' is missing" + + # 检查参数类型和约束 + for param_name, param_value in parameters.items(): + if param_name not in param_definitions: + continue + + param_def = param_definitions[param_name] + + # 类型检查 + if not self._validate_parameter_type(param_value, param_def): + errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}" + + # 约束检查 + 
constraint_error = self._validate_parameter_constraints(param_value, param_def) + if constraint_error: + errors[param_name] = constraint_error + + return errors + + def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool: + """Return True when value matches the declared parameter type (None is accepted only for optional parameters).""" + if value is None: + return not param_def.required + + type_mapping = { + ParameterType.STRING: str, + ParameterType.INTEGER: int, + ParameterType.NUMBER: (int, float), + ParameterType.BOOLEAN: bool, + ParameterType.ARRAY: list, + ParameterType.OBJECT: dict + } + + expected_type = type_mapping.get(param_def.type) + if expected_type: + return isinstance(value, expected_type) and (expected_type is bool or not isinstance(value, bool)) # bool subclasses int, so reject it for numeric types + + return True + + def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]: + """Check enum, range and pattern constraints; return an error message, or None when they hold.""" + if value is None: + return None + + # Enum membership check + if param_def.enum and value not in param_def.enum: + return f"Value must be one of {param_def.enum}" + + # Numeric range check (skipped for non-numeric values: the type check reports those, and comparing them would raise TypeError) + if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER] and isinstance(value, (int, float)): + if param_def.minimum is not None and value < param_def.minimum: + return f"Value must be >= {param_def.minimum}" + if param_def.maximum is not None and value > param_def.maximum: + return f"Value must be <= {param_def.maximum}" + + # String pattern check + if param_def.type == ParameterType.STRING and param_def.pattern: + import re + if not re.match(param_def.pattern, str(value)): + return f"Value must match pattern: {param_def.pattern}" + + return None + + @abstractmethod + async def execute(self, **kwargs) -> ToolResult: + """Execute the tool. + + Args: + **kwargs: Tool parameters. + + Returns: + The execution result. + """ + pass + + async def safe_execute(self, **kwargs) -> ToolResult: + """Execute the tool safely (with parameter validation and exception handling). + + Args: + **kwargs: Tool parameters. + + Returns: + The execution result. + """ + start_time = time.time() + + try: + # Parameter validation + validation_errors = self.validate_parameters(kwargs) + if validation_errors: + execution_time = time.time() - start_time + error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()]) + return
ToolResult.error_result( + error=f"Parameter validation failed: {error_msg}", + error_code="VALIDATION_ERROR", + execution_time=execution_time + ) + + # 执行工具 + result = await self.execute(**kwargs) + return result + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="EXECUTION_ERROR", + execution_time=execution_time + ) + + def to_langchain_tool(self): + """转换为Langchain工具格式""" + from .langchain_adapter import LangchainAdapter + return LangchainAdapter.convert_tool(self) + + def __repr__(self): + return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" \ No newline at end of file diff --git a/api/app/core/tools/builtin/__init__.py b/api/app/core/tools/builtin/__init__.py new file mode 100644 index 00000000..3813402c --- /dev/null +++ b/api/app/core/tools/builtin/__init__.py @@ -0,0 +1,17 @@ +"""内置工具模块""" + +from .base import BuiltinTool +from .datetime_tool import DateTimeTool +from .json_tool import JsonTool +from .baidu_search_tool import BaiduSearchTool +from .mineru_tool import MinerUTool +from .textin_tool import TextInTool + +__all__ = [ + "BuiltinTool", + "DateTimeTool", + "JsonTool", + "BaiduSearchTool", + "MinerUTool", + "TextInTool" +] \ No newline at end of file diff --git a/api/app/core/tools/builtin/baidu_search_tool.py b/api/app/core/tools/builtin/baidu_search_tool.py new file mode 100644 index 00000000..fddd6eb7 --- /dev/null +++ b/api/app/core/tools/builtin/baidu_search_tool.py @@ -0,0 +1,334 @@ +"""百度搜索工具 - 搜索引擎服务""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class BaiduSearchTool(BuiltinTool): + """百度搜索工具 - 提供网页搜索、新闻搜索、图片搜索、实时结果""" + + @property + def name(self) -> str: + return "baidu_search_tool" + + @property + def description(self) -> str: + return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果" + + def 
get_required_config_parameters(self) -> List[str]: + return ["api_key"] + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="query", + type=ParameterType.STRING, + description="搜索关键词", + required=True + ), + ToolParameter( + name="search_type", + type=ParameterType.STRING, + description="搜索类型", + required=False, + default="web", + enum=["web", "news", "image", "video"] + ), + ToolParameter( + name="page_size", + type=ParameterType.INTEGER, + description="每页结果数", + required=False, + default=10, + minimum=1, + maximum=50 + ), + ToolParameter( + name="page_num", + type=ParameterType.INTEGER, + description="页码(从1开始)", + required=False, + default=1, + minimum=1, + maximum=10 + ), + ToolParameter( + name="safe_search", + type=ParameterType.BOOLEAN, + description="是否启用安全搜索", + required=False, + default=True + ), + ToolParameter( + name="region", + type=ParameterType.STRING, + description="搜索地区", + required=False, + default="cn", + enum=["cn", "hk", "tw", "us", "jp", "kr"] + ), + ToolParameter( + name="time_filter", + type=ParameterType.STRING, + description="时间过滤", + required=False, + enum=["all", "day", "week", "month", "year"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行百度搜索""" + start_time = time.time() + + try: + query = kwargs.get("query") + search_type = kwargs.get("search_type", "web") + page_size = kwargs.get("page_size", 10) + page_num = kwargs.get("page_num", 1) + safe_search = kwargs.get("safe_search", True) + region = kwargs.get("region", "cn") + time_filter = kwargs.get("time_filter") + + if not query: + raise ValueError("query 参数是必需的") + + # 根据搜索类型调用不同的API + if search_type == "web": + result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter) + elif search_type == "news": + result = await self._news_search(query, page_size, page_num, region, time_filter) + elif search_type == "image": + result = await self._image_search(query, page_size, page_num, 
safe_search) + elif search_type == "video": + result = await self._video_search(query, page_size, page_num, safe_search) + else: + raise ValueError(f"不支持的搜索类型: {search_type}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="BAIDU_SEARCH_ERROR", + execution_time=execution_time + ) + + async def _web_search(self, query: str, page_size: int, page_num: int, + safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]: + """网页搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": min(page_size, 50)}], + "enable_full_content": True + } + + if time_filter: + time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"} + if time_filter in time_map: + payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}} + payload["search_recency_filter"] = time_filter + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "web", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _news_search(self, query: str, page_size: int, page_num: int, + region: str, time_filter: str = None) -> Dict[str, Any]: + """新闻搜索""" + payload = { + "messages": 
[{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "new", "top_k": min(page_size, 50)}], + "enable_full_content": True + } + + if time_filter: + time_map = {"day": "now-1d/d", "week": "now-1w/d", "month": "now-1M/d", "year": "now-1y/d"} + if time_filter in time_map: + payload["search_filter"] = {"range": {"page_time": {"gte": time_map[time_filter], "lt": "now/d"}}} + payload["search_recency_filter"] = time_filter + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "new", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _image_search(self, query: str, page_size: int, page_num: int, + safe_search: bool) -> Dict[str, Any]: + """图片搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "image", "top_k": min(page_size, 30)}], + "enable_full_content": True + } + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "image", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + 
"results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _video_search(self, query: str, page_size: int, page_num: int, + safe_search: bool) -> Dict[str, Any]: + """视频搜索""" + payload = { + "messages": [{"role": "user", "content": query}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "video", "top_k": min(page_size, 10)}], + "enable_full_content": True + } + + results = await self._call_baidu_ai_search_api(payload) + + search_results = [] + if "references" in results: + for item in results["references"]: + search_results.append({ + "title": item.get("title", ""), + "url": item.get("url", ""), + "snippet": item.get("content", ""), + "display_url": item.get("url", ""), + "rank": len(search_results) + 1 + }) + + return { + "search_type": "video", + "query": query, + "total_results": len(search_results), + "page_num": page_num, + "page_size": page_size, + "results": search_results, + "answer": results.get("result", ""), + "references": results.get("references", []) + } + + async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]: + """调用百度AI搜索API""" + api_key = self.get_config_parameter("api_key") + + if not api_key: + raise ValueError("百度搜索API密钥未配置") + + url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" + headers = { + 'Content-Type': 'application/json', + 'Authorization': f'Bearer {api_key}' + } + + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=payload) as response: + if response.status == 200: + return await response.json() + else: + raise Exception(f"HTTP错误: {response.status}") + + async def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + api_key = self.get_config_parameter("api_key") + + if not api_key: + return { + "success": False, + "error": "API密钥未配置" + } + + # 
发送测试请求验证API key是否有效 + test_payload = { + "messages": [{"role": "user", "content": "test"}], + "edition": "standard", + "search_source": "baidu_search_v2", + "resource_type_filter": [{"type": "web", "top_k": 1}] + } + + try: + await self._call_baidu_ai_search_api(test_payload) + return { + "success": True, + "message": "连接测试成功", + "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***" + } + except Exception as e: + return { + "success": False, + "error": f"API连接失败: {str(e)}" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/base.py b/api/app/core/tools/builtin/base.py new file mode 100644 index 00000000..532d0869 --- /dev/null +++ b/api/app/core/tools/builtin/base.py @@ -0,0 +1,118 @@ +"""内置工具基类""" +from abc import ABC, abstractmethod +from typing import Dict, Any, List + +from app.models.tool_model import ToolType +from app.core.tools.base import BaseTool, ToolResult, ToolParameter + + +class BuiltinTool(BaseTool, ABC): + """内置工具基类""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化内置工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + super().__init__(tool_id, config) + self.parameters_config = config.get("parameters", {}) + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.BUILTIN + + @property + @abstractmethod + def name(self) -> str: + """工具名称 - 子类必须实现""" + pass + + @property + @abstractmethod + def description(self) -> str: + """工具描述 - 子类必须实现""" + pass + + @property + @abstractmethod + def parameters(self) -> List[ToolParameter]: + """工具参数定义 - 子类必须实现""" + pass + + @abstractmethod + async def execute(self, **kwargs) -> ToolResult: + """执行工具 - 子类必须实现 + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + pass + + @property + def is_configured(self) -> bool: + """检查工具是否已正确配置""" + required_params = self.get_required_config_parameters() + for param in required_params: + if not 
self.parameters_config.get(param): + return False + return True + + def get_required_config_parameters(self) -> List[str]: + """获取必需的配置参数列表 + + Returns: + 必需配置参数名称列表 + """ + return [] + + def get_config_parameter(self, name: str, default: Any = None) -> Any: + """获取配置参数值 + + Args: + name: 参数名称 + default: 默认值 + + Returns: + 参数值 + """ + return self.parameters_config.get(name, default) + + def validate_configuration(self) -> tuple[bool, str]: + """验证工具配置 + + Returns: + (是否有效, 错误信息) + """ + if not self.is_configured: + required_params = self.get_required_config_parameters() + missing_params = [p for p in required_params if not self.parameters_config.get(p)] + return False, f"缺少必需的配置参数: {', '.join(missing_params)}" + + return True, "" + + async def safe_execute(self, **kwargs) -> ToolResult: + """安全执行工具(包含配置验证) + + Args: + **kwargs: 工具参数 + + Returns: + 执行结果 + """ + # 首先验证配置 + is_valid, error_msg = self.validate_configuration() + if not is_valid: + return ToolResult.error_result( + error=f"工具配置无效: {error_msg}", + error_code="CONFIGURATION_ERROR", + execution_time=0.0 + ) + + # 调用父类的安全执行 + return await super().safe_execute(**kwargs) \ No newline at end of file diff --git a/api/app/core/tools/builtin/datetime_tool.py b/api/app/core/tools/builtin/datetime_tool.py new file mode 100644 index 00000000..475ce7be --- /dev/null +++ b/api/app/core/tools/builtin/datetime_tool.py @@ -0,0 +1,307 @@ +"""时间工具 - 日期时间处理""" +import time +from datetime import datetime, timezone, timedelta +from typing import List +import pytz + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class DateTimeTool(BuiltinTool): + """时间工具 - 提供时间格式转换、时区转换、时间戳转换、时间计算功能""" + + @property + def name(self) -> str: + return "datetime_tool" + + @property + def description(self) -> str: + return "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算" + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + 
type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"] + ), + ToolParameter( + name="input_value", + type=ParameterType.STRING, + description="输入值(时间字符串或时间戳)", + required=False + ), + ToolParameter( + name="input_format", + type=ParameterType.STRING, + description="输入时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出时间格式(如:%Y-%m-%d %H:%M:%S)", + required=False, + default="%Y-%m-%d %H:%M:%S" + ), + ToolParameter( + name="from_timezone", + type=ParameterType.STRING, + description="源时区(如:UTC, Asia/Shanghai)", + required=False, + default="UTC" + ), + ToolParameter( + name="to_timezone", + type=ParameterType.STRING, + description="目标时区(如:UTC, Asia/Shanghai)", + required=False, + default="UTC" + ), + ToolParameter( + name="calculation", + type=ParameterType.STRING, + description="时间计算表达式(如:+1d, -2h, +30m)", + required=False + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行时间工具操作""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + + if operation == "now": + result = self._get_current_time(kwargs) + elif operation == "format": + result = self._format_datetime(kwargs) + elif operation == "convert_timezone": + result = self._convert_timezone(kwargs) + elif operation == "timestamp_to_datetime": + result = self._timestamp_to_datetime(kwargs) + elif operation == "datetime_to_timestamp": + result = self._datetime_to_timestamp(kwargs) + elif operation == "calculate": + result = self._calculate_datetime(kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + 
error=str(e), + error_code="DATETIME_ERROR", + execution_time=execution_time + ) + + def _get_current_time(self, kwargs) -> dict: + """获取当前时间""" + timezone_str = kwargs.get("to_timezone", "UTC") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + now = datetime.now(tz) + + return { + "datetime": now.strftime(output_format), + "timestamp": int(now.timestamp()), + "timezone": timezone_str, + "iso_format": now.isoformat() + } + + def _format_datetime(self, kwargs) -> dict: + """格式化时间""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + return { + "original": input_value, + "formatted": dt.strftime(output_format), + "timestamp": int(dt.timestamp()), + "iso_format": dt.isoformat() + } + + def _convert_timezone(self, kwargs) -> dict: + """时区转换""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + from_timezone = kwargs.get("from_timezone", "UTC") + to_timezone = kwargs.get("to_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置源时区 + if from_timezone == "UTC": + from_tz = pytz.UTC + else: + from_tz = pytz.timezone(from_timezone) + + # 设置目标时区 + if to_timezone == "UTC": + to_tz = pytz.UTC + else: + to_tz = pytz.timezone(to_timezone) + + # 本地化时间并转换时区 + if dt.tzinfo is None: + dt = from_tz.localize(dt) + + converted_dt = dt.astimezone(to_tz) + + return { + "original": input_value, + "original_timezone": from_timezone, + "converted": converted_dt.strftime(output_format), + 
"converted_timezone": to_timezone, + "timestamp": int(converted_dt.timestamp()) + } + + def _timestamp_to_datetime(self, kwargs) -> dict: + """时间戳转日期时间""" + input_value = kwargs.get("input_value") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + timezone_str = kwargs.get("to_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 转换时间戳 + timestamp = float(input_value) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + dt = datetime.fromtimestamp(timestamp, tz) + + return { + "timestamp": timestamp, + "datetime": dt.strftime(output_format), + "timezone": timezone_str, + "iso_format": dt.isoformat() + } + + def _datetime_to_timestamp(self, kwargs) -> dict: + """日期时间转时间戳""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + timezone_str = kwargs.get("from_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = pytz.timezone(timezone_str) + + # 本地化时间 + if dt.tzinfo is None: + dt = tz.localize(dt) + + return { + "datetime": input_value, + "timezone": timezone_str, + "timestamp": int(dt.timestamp()), + "iso_format": dt.isoformat() + } + + def _calculate_datetime(self, kwargs) -> dict: + """时间计算""" + input_value = kwargs.get("input_value") + input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S") + output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S") + calculation = kwargs.get("calculation") + timezone_str = kwargs.get("from_timezone", "UTC") + + if not input_value: + raise ValueError("input_value 参数是必需的") + + if not calculation: + raise ValueError("calculation 参数是必需的") + + # 解析输入时间 + dt = datetime.strptime(input_value, input_format) + + # 设置时区 + if timezone_str == "UTC": + tz = timezone.utc + else: + tz = 
pytz.timezone(timezone_str) + + if dt.tzinfo is None: + dt = tz.localize(dt) + + # 解析计算表达式 + delta = self._parse_time_delta(calculation) + calculated_dt = dt + delta + + return { + "original": input_value, + "calculation": calculation, + "result": calculated_dt.strftime(output_format), + "timezone": timezone_str, + "timestamp": int(calculated_dt.timestamp()) + } + + def _parse_time_delta(self, calculation: str) -> timedelta: + """解析时间计算表达式""" + import re + + # 支持的单位:d(天), h(小时), m(分钟), s(秒) + pattern = r'([+-]?\d+)([dhms])' + matches = re.findall(pattern, calculation.lower()) + + if not matches: + raise ValueError(f"无效的时间计算表达式: {calculation}") + + total_delta = timedelta() + + for value_str, unit in matches: + value = int(value_str) + + if unit == 'd': + total_delta += timedelta(days=value) + elif unit == 'h': + total_delta += timedelta(hours=value) + elif unit == 'm': + total_delta += timedelta(minutes=value) + elif unit == 's': + total_delta += timedelta(seconds=value) + + return total_delta \ No newline at end of file diff --git a/api/app/core/tools/builtin/json_tool.py b/api/app/core/tools/builtin/json_tool.py new file mode 100644 index 00000000..135d252a --- /dev/null +++ b/api/app/core/tools/builtin/json_tool.py @@ -0,0 +1,430 @@ +"""JSON转换工具 - 数据格式转换""" +import json +import time +from typing import List, Any, Dict +import yaml +import xml.etree.ElementTree as ET +from xml.dom import minidom + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class JsonTool(BuiltinTool): + """JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能""" + + @property + def name(self) -> str: + return "json_tool" + + @property + def description(self) -> str: + return "JSON转换工具 - 数据格式转换:JSON格式化、JSON压缩、JSON验证、格式转换" + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["format", "minify", "validate", "convert", 
"to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"] + ), + ToolParameter( + name="input_data", + type=ParameterType.STRING, + description="输入数据(JSON字符串、YAML字符串或XML字符串)", + required=True + ), + ToolParameter( + name="indent", + type=ParameterType.INTEGER, + description="JSON格式化缩进空格数", + required=False, + default=2, + minimum=0, + maximum=8 + ), + ToolParameter( + name="ensure_ascii", + type=ParameterType.BOOLEAN, + description="是否确保ASCII编码", + required=False, + default=False + ), + ToolParameter( + name="sort_keys", + type=ParameterType.BOOLEAN, + description="是否对键进行排序", + required=False, + default=False + ), + ToolParameter( + name="merge_data", + type=ParameterType.STRING, + description="要合并的JSON数据(用于merge操作)", + required=False + ), + ToolParameter( + name="json_path", + type=ParameterType.STRING, + description="JSON路径表达式(用于extract操作,如:$.user.name)", + required=False + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行JSON工具操作""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + input_data = kwargs.get("input_data") + + if not input_data: + raise ValueError("input_data 参数是必需的") + + if operation == "format": + result = self._format_json(input_data, kwargs) + elif operation == "minify": + result = self._minify_json(input_data) + elif operation == "validate": + result = self._validate_json(input_data) + elif operation == "convert": + result = self._convert_json(input_data) + elif operation == "to_yaml": + result = self._json_to_yaml(input_data) + elif operation == "from_yaml": + result = self._yaml_to_json(input_data, kwargs) + elif operation == "to_xml": + result = self._json_to_xml(input_data) + elif operation == "from_xml": + result = self._xml_to_json(input_data, kwargs) + elif operation == "merge": + result = self._merge_json(input_data, kwargs) + elif operation == "extract": + result = self._extract_json_path(input_data, kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = 
time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="JSON_ERROR", + execution_time=execution_time + ) + + def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """Pretty-print JSON using the indent, ensure_ascii and sort_keys options.""" + indent = kwargs.get("indent", 2) + ensure_ascii = kwargs.get("ensure_ascii", False) + sort_keys = kwargs.get("sort_keys", False) + + # Parse the JSON + data = json.loads(input_data) + + # Serialize with formatting + formatted = json.dumps( + data, + indent=indent, + ensure_ascii=ensure_ascii, + sort_keys=sort_keys, + separators=(',', ': ') + ) + + return { + "original_size": len(input_data), + "formatted_size": len(formatted), + "formatted_json": formatted, + "is_valid": True, + "settings": { + "indent": indent, + "ensure_ascii": ensure_ascii, + "sort_keys": sort_keys + } + } + + def _minify_json(self, input_data: str) -> Dict[str, Any]: + """Minify JSON and report the size reduction as a percentage.""" + # Parse and re-serialize compactly; ensure_ascii=False keeps non-ASCII text unescaped so the output cannot grow past the input + data = json.loads(input_data) + minified = json.dumps(data, separators=(',', ':'), ensure_ascii=False) + + return { + "original_size": len(input_data), + "minified_size": len(minified), + "compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2), + "minified_json": minified, + "is_valid": True + } + + def _validate_json(self, input_data: str) -> Dict[str, Any]: + """Validate JSON; return its structure summary, or the parse error with its line and column.""" + try: + data = json.loads(input_data) + + # Structure statistics + stats = self._analyze_json_structure(data) + + return { + "is_valid": True, + "error": None, + "size": len(input_data), + "structure": stats + } + + except json.JSONDecodeError as e: + return { + "is_valid": False, + "error": str(e), + "error_line": getattr(e, 'lineno', None), + "error_column": getattr(e, 'colno', None), + "size": len(input_data) + } + + def _convert_json(self, input_data: str) -> Dict[str, Any]: + """Re-serialize JSON with non-ASCII characters left unescaped.""" + data = json.loads(input_data) + converted = json.dumps(data, ensure_ascii=False) + + return
{ + "converted_json": converted, + "is_valid": True + } + + def _json_to_yaml(self, input_data: str) -> Dict[str, Any]: + """JSON转YAML""" + data = json.loads(input_data) + yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2) + + return { + "original_format": "json", + "target_format": "yaml", + "original_size": len(input_data), + "converted_size": len(yaml_output), + "converted_data": yaml_output + } + + def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """YAML转JSON""" + indent = kwargs.get("indent", 2) + ensure_ascii = kwargs.get("ensure_ascii", False) + + data = yaml.safe_load(input_data) + json_output = json.dumps(data, indent=indent, ensure_ascii=ensure_ascii) + + return { + "original_format": "yaml", + "target_format": "json", + "original_size": len(input_data), + "converted_size": len(json_output), + "converted_data": json_output + } + + def _json_to_xml(self, input_data: str) -> Dict[str, Any]: + """JSON转XML""" + data = json.loads(input_data) + + def dict_to_xml(data, root_name="root"): + """递归转换字典为XML""" + if isinstance(data, dict): + if len(data) == 1 and not root_name == "root": + # 如果字典只有一个键,使用该键作为根元素 + key, value = next(iter(data.items())) + return dict_to_xml(value, key) + + root = ET.Element(root_name) + for key, value in data.items(): + if isinstance(value, (dict, list)): + child = dict_to_xml(value, key) + root.append(child) + else: + child = ET.SubElement(root, key) + child.text = str(value) + return root + + elif isinstance(data, list): + root = ET.Element(root_name) + for i, item in enumerate(data): + if isinstance(item, (dict, list)): + child = dict_to_xml(item, f"item_{i}") + root.append(child) + else: + child = ET.SubElement(root, f"item_{i}") + child.text = str(item) + return root + + else: + root = ET.Element(root_name) + root.text = str(data) + return root + + xml_element = dict_to_xml(data) + xml_string = ET.tostring(xml_element, encoding='unicode') + + # 格式化XML + dom = 
minidom.parseString(xml_string) + formatted_xml = dom.toprettyxml(indent=" ") + + # 移除空行 + formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()]) + + return { + "original_format": "json", + "target_format": "xml", + "original_size": len(input_data), + "converted_size": len(formatted_xml), + "converted_data": formatted_xml + } + + def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """XML转JSON""" + indent = kwargs.get("indent", 2) + + def xml_to_dict(element): + """递归转换XML元素为字典""" + result = {} + + # 处理属性 + if element.attrib: + result.update(element.attrib) + + # 处理文本内容 + if element.text and element.text.strip(): + if len(element) == 0: # 叶子节点 + return element.text.strip() + else: + result['text'] = element.text.strip() + + # 处理子元素 + for child in element: + child_data = xml_to_dict(child) + if child.tag in result: + # 如果标签已存在,转换为列表 + if not isinstance(result[child.tag], list): + result[child.tag] = [result[child.tag]] + result[child.tag].append(child_data) + else: + result[child.tag] = child_data + + return result + + root = ET.fromstring(input_data) + data = {root.tag: xml_to_dict(root)} + json_output = json.dumps(data, indent=indent, ensure_ascii=False) + + return { + "original_format": "xml", + "target_format": "json", + "original_size": len(input_data), + "converted_size": len(json_output), + "converted_data": json_output + } + + def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """合并JSON""" + merge_data = kwargs.get("merge_data") + if not merge_data: + raise ValueError("merge_data 参数是必需的") + + data1 = json.loads(input_data) + data2 = json.loads(merge_data) + + def deep_merge(dict1, dict2): + """深度合并字典""" + result = dict1.copy() + for key, value in dict2.items(): + if key in result and isinstance(result[key], dict) and isinstance(value, dict): + result[key] = deep_merge(result[key], value) + else: + result[key] = value + return result + + if isinstance(data1, 
dict) and isinstance(data2, dict): + merged = deep_merge(data1, data2) + elif isinstance(data1, list) and isinstance(data2, list): + merged = data1 + data2 + else: + raise ValueError("无法合并不同类型的数据") + + merged_json = json.dumps(merged, indent=2, ensure_ascii=False) + + return { + "operation": "merge", + "original_size": len(input_data), + "merge_size": len(merge_data), + "result_size": len(merged_json), + "merged_data": merged_json + } + + def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取JSON路径""" + json_path = kwargs.get("json_path") + if not json_path: + raise ValueError("json_path 参数是必需的") + + data = json.loads(input_data) + + # 简单的JSONPath实现(支持基本的点号路径) + try: + result = data + if json_path.startswith('$.'): + path_parts = json_path[2:].split('.') + else: + path_parts = json_path.split('.') + + for part in path_parts: + if part.isdigit(): + result = result[int(part)] + else: + result = result[part] + + extracted_json = json.dumps(result, indent=2, ensure_ascii=False) + + return { + "operation": "extract", + "json_path": json_path, + "found": True, + "extracted_data": extracted_json, + "data_type": type(result).__name__ + } + + except (KeyError, IndexError, TypeError) as e: + return { + "operation": "extract", + "json_path": json_path, + "found": False, + "error": str(e), + "extracted_data": None + } + + def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]: + """分析JSON结构""" + if isinstance(data, dict): + return { + "type": "object", + "keys": len(data), + "depth": depth, + "children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()} + } + elif isinstance(data, list): + return { + "type": "array", + "length": len(data), + "depth": depth, + "item_types": list(set(type(item).__name__ for item in data)) + } + else: + return { + "type": type(data).__name__, + "depth": depth, + "value": str(data)[:100] + "..." 
if len(str(data)) > 100 else str(data) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/mineru_tool.py b/api/app/core/tools/builtin/mineru_tool.py new file mode 100644 index 00000000..b2a544c0 --- /dev/null +++ b/api/app/core/tools/builtin/mineru_tool.py @@ -0,0 +1,327 @@ +"""MinerU PDF解析工具""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class MinerUTool(BuiltinTool): + """MinerU PDF解析工具 - 提供PDF解析、表格提取、图片识别、文本提取功能""" + + @property + def name(self) -> str: + return "mineru_tool" + + @property + def description(self) -> str: + return "MinerU - PDF解析工具:PDF解析、表格提取、图片识别、文本提取" + + def get_required_config_parameters(self) -> List[str]: + return ["api_key", "api_url"] + + @property + def parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="operation", + type=ParameterType.STRING, + description="操作类型", + required=True, + enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"] + ), + ToolParameter( + name="file_content", + type=ParameterType.STRING, + description="PDF文件内容(Base64编码)", + required=False + ), + ToolParameter( + name="file_url", + type=ParameterType.STRING, + description="PDF文件URL", + required=False + ), + ToolParameter( + name="parse_mode", + type=ParameterType.STRING, + description="解析模式", + required=False, + default="auto", + enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"] + ), + ToolParameter( + name="extract_images", + type=ParameterType.BOOLEAN, + description="是否提取图片", + required=False, + default=True + ), + ToolParameter( + name="extract_tables", + type=ParameterType.BOOLEAN, + description="是否提取表格", + required=False, + default=True + ), + ToolParameter( + name="page_range", + type=ParameterType.STRING, + description="页面范围(如:1-5, 1,3,5)", + required=False + ), + ToolParameter( + name="output_format", + 
type=ParameterType.STRING, + description="输出格式", + required=False, + default="json", + enum=["json", "markdown", "html", "text"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行MinerU PDF解析""" + start_time = time.time() + + try: + operation = kwargs.get("operation") + file_content = kwargs.get("file_content") + file_url = kwargs.get("file_url") + + if not file_content and not file_url: + raise ValueError("必须提供 file_content 或 file_url 参数") + + if operation == "parse_pdf": + result = await self._parse_pdf(kwargs) + elif operation == "extract_text": + result = await self._extract_text(kwargs) + elif operation == "extract_tables": + result = await self._extract_tables(kwargs) + elif operation == "extract_images": + result = await self._extract_images(kwargs) + elif operation == "analyze_layout": + result = await self._analyze_layout(kwargs) + else: + raise ValueError(f"不支持的操作类型: {operation}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="MINERU_ERROR", + execution_time=execution_time + ) + + async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """完整PDF解析""" + parse_mode = kwargs.get("parse_mode", "auto") + extract_images = kwargs.get("extract_images", True) + extract_tables = kwargs.get("extract_tables", True) + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "json") + + # 构建请求参数 + request_data = { + "parse_mode": parse_mode, + "extract_images": extract_images, + "extract_tables": extract_tables, + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + # 添加文件数据 + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + # 
调用MinerU API + result = await self._call_mineru_api("parse", request_data) + + return { + "operation": "parse_pdf", + "parse_mode": parse_mode, + "total_pages": result.get("total_pages", 0), + "processed_pages": result.get("processed_pages", 0), + "text_content": result.get("text_content", ""), + "tables": result.get("tables", []), + "images": result.get("images", []), + "layout_info": result.get("layout_info", {}), + "metadata": result.get("metadata", {}), + "processing_time": result.get("processing_time", 0) + } + + async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取文本""" + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "text") + + request_data = { + "operation": "extract_text", + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_text", request_data) + + return { + "operation": "extract_text", + "total_pages": result.get("total_pages", 0), + "text_content": result.get("text_content", ""), + "word_count": len(result.get("text_content", "").split()), + "character_count": len(result.get("text_content", "")), + "pages_text": result.get("pages_text", []) + } + + async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取表格""" + page_range = kwargs.get("page_range") + output_format = kwargs.get("output_format", "json") + + request_data = { + "operation": "extract_tables", + "output_format": output_format + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_tables", request_data) + + return { + 
"operation": "extract_tables", + "total_tables": result.get("total_tables", 0), + "tables": result.get("tables", []), + "table_locations": result.get("table_locations", []) + } + + async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """提取图片""" + page_range = kwargs.get("page_range") + + request_data = { + "operation": "extract_images" + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("extract_images", request_data) + + return { + "operation": "extract_images", + "total_images": result.get("total_images", 0), + "images": result.get("images", []), + "image_locations": result.get("image_locations", []) + } + + async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """分析布局""" + page_range = kwargs.get("page_range") + + request_data = { + "operation": "analyze_layout" + } + + if page_range: + request_data["page_range"] = page_range + + if kwargs.get("file_content"): + request_data["file_content"] = kwargs["file_content"] + elif kwargs.get("file_url"): + request_data["file_url"] = kwargs["file_url"] + + result = await self._call_mineru_api("analyze_layout", request_data) + + return { + "operation": "analyze_layout", + "layout_info": result.get("layout_info", {}), + "page_layouts": result.get("page_layouts", []), + "text_blocks": result.get("text_blocks", []), + "image_blocks": result.get("image_blocks", []), + "table_blocks": result.get("table_blocks", []) + } + + async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: + """调用MinerU API""" + api_key = self.get_config_parameter("api_key") + api_url = self.get_config_parameter("api_url") + timeout_seconds = self.get_config_parameter("timeout", 60) + + if not api_key or not api_url: + raise ValueError("MinerU API配置未完成") + + 
# 构建完整URL + url = f"{api_url.rstrip('/')}/{endpoint}" + + # 构建请求头 + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + + # 发送请求 + timeout = aiohttp.ClientTimeout(total=timeout_seconds) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + result = await response.json() + if result.get("success", True): + return result.get("data", result) + else: + raise Exception(f"MinerU API错误: {result.get('message', '未知错误')}") + else: + error_text = await response.text() + raise Exception(f"HTTP错误 {response.status}: {error_text}") + + def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + api_key = self.get_config_parameter("api_key") + api_url = self.get_config_parameter("api_url") + + if not api_key or not api_url: + return { + "success": False, + "error": "API配置未完成" + } + + return { + "success": True, + "message": "连接配置有效", + "api_url": api_url, + "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/builtin/textin_tool.py b/api/app/core/tools/builtin/textin_tool.py new file mode 100644 index 00000000..ec3e214e --- /dev/null +++ b/api/app/core/tools/builtin/textin_tool.py @@ -0,0 +1,401 @@ +"""TextIn OCR文字识别工具""" +import time +from typing import List, Dict, Any +import aiohttp + +from app.core.tools.base import ToolParameter, ToolResult, ParameterType +from .base import BuiltinTool + + +class TextInTool(BuiltinTool): + """TextIn OCR工具 - 提供通用OCR、手写识别、多语言支持、高精度识别""" + + @property + def name(self) -> str: + return "textin_tool" + + @property + def description(self) -> str: + return "TextIn - OCR文字识别:通用OCR、手写识别、多语言支持、高精度识别" + + def get_required_config_parameters(self) -> List[str]: + return ["app_id", "secret_key", "api_url"] + + @property + def 
parameters(self) -> List[ToolParameter]: + return [ + ToolParameter( + name="image_content", + type=ParameterType.STRING, + description="图片内容(Base64编码)", + required=False + ), + ToolParameter( + name="image_url", + type=ParameterType.STRING, + description="图片URL", + required=False + ), + ToolParameter( + name="language", + type=ParameterType.STRING, + description="识别语言", + required=False, + default="auto", + enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"] + ), + ToolParameter( + name="recognition_mode", + type=ParameterType.STRING, + description="识别模式", + required=False, + default="general", + enum=["general", "accurate", "handwriting", "formula", "table", "document"] + ), + ToolParameter( + name="return_location", + type=ParameterType.BOOLEAN, + description="是否返回文字位置信息", + required=False, + default=False + ), + ToolParameter( + name="return_confidence", + type=ParameterType.BOOLEAN, + description="是否返回置信度", + required=False, + default=True + ), + ToolParameter( + name="merge_lines", + type=ParameterType.BOOLEAN, + description="是否合并行", + required=False, + default=True + ), + ToolParameter( + name="output_format", + type=ParameterType.STRING, + description="输出格式", + required=False, + default="text", + enum=["text", "json", "structured"] + ) + ] + + async def execute(self, **kwargs) -> ToolResult: + """执行TextIn OCR识别""" + start_time = time.time() + + try: + image_content = kwargs.get("image_content") + image_url = kwargs.get("image_url") + + if not image_content and not image_url: + raise ValueError("必须提供 image_content 或 image_url 参数") + + language = kwargs.get("language", "auto") + recognition_mode = kwargs.get("recognition_mode", "general") + return_location = kwargs.get("return_location", False) + return_confidence = kwargs.get("return_confidence", True) + merge_lines = kwargs.get("merge_lines", True) + output_format = kwargs.get("output_format", "text") + + # 根据识别模式调用不同的API + if recognition_mode == "general": + result = await 
self._general_ocr(kwargs) + elif recognition_mode == "accurate": + result = await self._accurate_ocr(kwargs) + elif recognition_mode == "handwriting": + result = await self._handwriting_ocr(kwargs) + elif recognition_mode == "formula": + result = await self._formula_ocr(kwargs) + elif recognition_mode == "table": + result = await self._table_ocr(kwargs) + elif recognition_mode == "document": + result = await self._document_ocr(kwargs) + else: + raise ValueError(f"不支持的识别模式: {recognition_mode}") + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="TEXTIN_ERROR", + execution_time=execution_time + ) + + async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """通用OCR识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "merge_lines": kwargs.get("merge_lines", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("general_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """高精度OCR识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "merge_lines": kwargs.get("merge_lines", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await 
self._call_textin_api("accurate_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """手写体识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True) + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("handwriting_ocr", request_data) + + return self._format_ocr_result(result, kwargs.get("output_format", "text")) + + async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """公式识别""" + request_data = { + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "output_latex": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("formula_ocr", request_data) + + return self._format_formula_result(result, kwargs.get("output_format", "text")) + + async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: + """表格识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "output_excel": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("table_ocr", request_data) + + return self._format_table_result(result, kwargs.get("output_format", "text")) + + async def _document_ocr(self, kwargs: Dict[str, Any]) -> 
Dict[str, Any]: + """文档识别""" + request_data = { + "language": kwargs.get("language", "auto"), + "return_location": kwargs.get("return_location", False), + "return_confidence": kwargs.get("return_confidence", True), + "layout_analysis": True + } + + if kwargs.get("image_content"): + request_data["image"] = kwargs["image_content"] + elif kwargs.get("image_url"): + request_data["image_url"] = kwargs["image_url"] + + result = await self._call_textin_api("document_ocr", request_data) + + return self._format_document_result(result, kwargs.get("output_format", "text")) + + def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None: + """格式化OCR结果""" + lines = result.get("lines", []) + + if output_format == "text": + text_content = "\n".join([line.get("text", "") for line in lines]) + return { + "recognition_mode": "ocr", + "text_content": text_content, + "line_count": len(lines), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + elif output_format == "json": + return { + "recognition_mode": "ocr", + "lines": lines, + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + elif output_format == "structured": + return { + "recognition_mode": "ocr", + "text_content": "\n".join([line.get("text", "") for line in lines]), + "structured_data": { + "lines": lines, + "paragraphs": self._group_lines_to_paragraphs(lines), + "statistics": { + "line_count": len(lines), + "word_count": sum(len(line.get("text", "").split()) for line in lines), + "character_count": sum(len(line.get("text", "")) for line in lines) + } + }, + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化公式识别结果""" + formulas = result.get("formulas", []) + + return { + "recognition_mode": "formula", + 
"formula_count": len(formulas), + "formulas": formulas, + "latex_content": "\n".join([f.get("latex", "") for f in formulas]), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化表格识别结果""" + tables = result.get("tables", []) + + return { + "recognition_mode": "table", + "table_count": len(tables), + "tables": tables, + "excel_data": result.get("excel_data"), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]: + """格式化文档识别结果""" + return { + "recognition_mode": "document", + "layout_info": result.get("layout_info", {}), + "text_blocks": result.get("text_blocks", []), + "image_blocks": result.get("image_blocks", []), + "table_blocks": result.get("table_blocks", []), + "full_text": result.get("full_text", ""), + "total_confidence": result.get("confidence", 0), + "processing_time": result.get("processing_time", 0) + } + + def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """将行分组为段落""" + paragraphs = [] + current_paragraph = [] + + for line in lines: + text = line.get("text", "").strip() + if text: + current_paragraph.append(line) + else: + if current_paragraph: + paragraphs.append({ + "text": " ".join([l.get("text", "") for l in current_paragraph]), + "lines": current_paragraph + }) + current_paragraph = [] + + if current_paragraph: + paragraphs.append({ + "text": " ".join([l.get("text", "") for l in current_paragraph]), + "lines": current_paragraph + }) + + return paragraphs + + async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]: + """调用TextIn API""" + app_id = self.get_config_parameter("app_id") + secret_key = self.get_config_parameter("secret_key") + api_url = 
self.get_config_parameter("api_url") + + if not app_id or not secret_key or not api_url: + raise ValueError("TextIn API配置未完成") + + # 构建完整URL + url = f"{api_url.rstrip('/')}/{endpoint}" + + # 构建请求头 + headers = { + "X-App-Id": app_id, + "X-Secret-Key": secret_key, + "Content-Type": "application/json" + } + + # 发送请求 + timeout = aiohttp.ClientTimeout(total=30) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, json=data, headers=headers) as response: + if response.status == 200: + result = await response.json() + if result.get("code") == 200: + return result.get("data", result) + else: + raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}") + else: + error_text = await response.text() + raise Exception(f"HTTP错误 {response.status}: {error_text}") + + def test_connection(self) -> Dict[str, Any]: + """测试连接""" + try: + app_id = self.get_config_parameter("app_id") + secret_key = self.get_config_parameter("secret_key") + api_url = self.get_config_parameter("api_url") + + if not app_id or not secret_key or not api_url: + return { + "success": False, + "error": "API配置未完成" + } + + return { + "success": True, + "message": "连接配置有效", + "api_url": api_url, + "app_id": app_id, + "secret_key_masked": secret_key[:8] + "***" if len(secret_key) > 8 else "***" + } + + except Exception as e: + return { + "success": False, + "error": str(e) + } \ No newline at end of file diff --git a/api/app/core/tools/chain_manager.py b/api/app/core/tools/chain_manager.py new file mode 100644 index 00000000..713baa39 --- /dev/null +++ b/api/app/core/tools/chain_manager.py @@ -0,0 +1,485 @@ +"""工具链管理器 - 支持langchain的工具链模式""" +from typing import List, Dict, Any, Optional +from dataclasses import dataclass +from enum import Enum + +from app.core.tools.base import ToolResult +from app.core.tools.executor import ToolExecutor +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class ChainExecutionMode(str, Enum): 
+ """链执行模式""" + SEQUENTIAL = "sequential" # 顺序执行 + PARALLEL = "parallel" # 并行执行 + CONDITIONAL = "conditional" # 条件执行 + + +@dataclass +class ChainStep: + """链步骤定义""" + tool_id: str + parameters: Dict[str, Any] + condition: Optional[str] = None # 执行条件 + output_mapping: Optional[Dict[str, str]] = None # 输出映射 + error_handling: str = "stop" # 错误处理:stop, continue, retry + + +@dataclass +class ChainDefinition: + """工具链定义""" + name: str + description: str + steps: List[ChainStep] + execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL + global_timeout: Optional[float] = None + retry_policy: Optional[Dict[str, Any]] = None + + +class ChainExecutionContext: + """链执行上下文""" + + def __init__(self, chain_id: str): + self.chain_id = chain_id + self.variables: Dict[str, Any] = {} + self.step_results: Dict[int, ToolResult] = {} + self.current_step = 0 + self.is_completed = False + self.is_failed = False + self.error_message: Optional[str] = None + + +class ChainManager: + """工具链管理器 - 支持langchain的工具链模式""" + + def __init__(self, executor: ToolExecutor): + """初始化工具链管理器 + + Args: + executor: 工具执行器 + """ + self.executor = executor + self._chains: Dict[str, ChainDefinition] = {} + self._running_chains: Dict[str, ChainExecutionContext] = {} + + def register_chain(self, chain: ChainDefinition) -> bool: + """注册工具链 + + Args: + chain: 工具链定义 + + Returns: + 注册是否成功 + """ + try: + # 验证工具链定义 + validation_result = self._validate_chain(chain) + if not validation_result[0]: + logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}") + return False + + self._chains[chain.name] = chain + logger.info(f"工具链注册成功: {chain.name}") + return True + + except Exception as e: + logger.error(f"工具链注册失败: {chain.name}, 错误: {e}") + return False + + def unregister_chain(self, chain_name: str) -> bool: + """注销工具链 + + Args: + chain_name: 工具链名称 + + Returns: + 注销是否成功 + """ + if chain_name in self._chains: + del self._chains[chain_name] + logger.info(f"工具链注销成功: {chain_name}") + return True + + return 
False + + def list_chains(self) -> List[Dict[str, Any]]: + """列出所有工具链 + + Returns: + 工具链信息列表 + """ + chains = [] + for name, chain in self._chains.items(): + chains.append({ + "name": name, + "description": chain.description, + "step_count": len(chain.steps), + "execution_mode": chain.execution_mode.value, + "global_timeout": chain.global_timeout + }) + + return chains + + async def execute_chain( + self, + chain_name: str, + initial_variables: Optional[Dict[str, Any]] = None, + chain_id: Optional[str] = None + ) -> Dict[str, Any] | None: + """执行工具链 + + Args: + chain_name: 工具链名称 + initial_variables: 初始变量 + chain_id: 链执行ID(可选) + + Returns: + 执行结果 + """ + if chain_name not in self._chains: + return { + "success": False, + "error": f"工具链不存在: {chain_name}", + "chain_id": chain_id + } + + chain = self._chains[chain_name] + + # 生成链ID + if not chain_id: + import uuid + chain_id = f"chain_{uuid.uuid4().hex[:16]}" + + # 创建执行上下文 + context = ChainExecutionContext(chain_id) + context.variables = initial_variables or {} + self._running_chains[chain_id] = context + + try: + logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})") + + # 根据执行模式执行 + if chain.execution_mode == ChainExecutionMode.SEQUENTIAL: + result = await self._execute_sequential(chain, context) + elif chain.execution_mode == ChainExecutionMode.PARALLEL: + result = await self._execute_parallel(chain, context) + elif chain.execution_mode == ChainExecutionMode.CONDITIONAL: + result = await self._execute_conditional(chain, context) + else: + raise ValueError(f"不支持的执行模式: {chain.execution_mode}") + + logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})") + return result + + except Exception as e: + logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}") + return { + "success": False, + "error": str(e), + "chain_id": chain_id, + "completed_steps": context.current_step, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + finally: + # 清理执行上下文 + if chain_id in 
self._running_chains: + del self._running_chains[chain_id] + + async def _execute_sequential( + self, + chain: ChainDefinition, + context: ChainExecutionContext + ) -> Dict[str, Any]: + """顺序执行工具链""" + for i, step in enumerate(chain.steps): + context.current_step = i + + # 检查执行条件 + if step.condition and not self._evaluate_condition(step.condition, context): + logger.debug(f"跳过步骤 {i}: 条件不满足") + continue + + # 准备参数 + parameters = self._prepare_parameters(step.parameters, context) + + # 执行工具 + try: + result = await self.executor.execute_tool( + tool_id=step.tool_id, + parameters=parameters + ) + + context.step_results[i] = result + + # 处理输出映射 + if step.output_mapping and result.success: + self._apply_output_mapping(step.output_mapping, result.data, context) + + # 处理执行失败 + if not result.success: + if step.error_handling == "stop": + context.is_failed = True + context.error_message = result.error + break + elif step.error_handling == "continue": + logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}") + continue + elif step.error_handling == "retry": + # 简单重试逻辑 + retry_result = await self.executor.execute_tool( + tool_id=step.tool_id, + parameters=parameters + ) + context.step_results[i] = retry_result + if not retry_result.success and step.error_handling == "stop": + context.is_failed = True + context.error_message = retry_result.error + break + + except Exception as e: + logger.error(f"步骤 {i} 执行异常: {e}") + if step.error_handling == "stop": + context.is_failed = True + context.error_message = str(e) + break + + context.is_completed = not context.is_failed + + return { + "success": context.is_completed, + "error": context.error_message, + "chain_id": context.chain_id, + "completed_steps": context.current_step + 1, + "total_steps": len(chain.steps), + "final_variables": context.variables, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + async def _execute_parallel( + self, + chain: ChainDefinition, + context: 
ChainExecutionContext + ) -> Dict[str, Any]: + """并行执行工具链""" + # 准备所有步骤的执行配置 + execution_configs = [] + + for i, step in enumerate(chain.steps): + # 检查执行条件 + if step.condition and not self._evaluate_condition(step.condition, context): + continue + + parameters = self._prepare_parameters(step.parameters, context) + execution_configs.append({ + "step_index": i, + "tool_id": step.tool_id, + "parameters": parameters + }) + + # 并行执行所有步骤 + try: + results = await self.executor.execute_tools_batch(execution_configs) + + # 处理结果 + for i, result in enumerate(results): + step_index = execution_configs[i]["step_index"] + context.step_results[step_index] = result + + # 处理输出映射 + step = chain.steps[step_index] + if step.output_mapping and result.success: + self._apply_output_mapping(step.output_mapping, result.data, context) + + # 检查是否有失败的步骤 + failed_steps = [i for i, result in context.step_results.items() if not result.success] + + context.is_completed = len(failed_steps) == 0 + if failed_steps: + context.error_message = f"步骤 {failed_steps} 执行失败" + + except Exception as e: + context.is_failed = True + context.error_message = str(e) + + return { + "success": context.is_completed, + "error": context.error_message, + "chain_id": context.chain_id, + "completed_steps": len(context.step_results), + "total_steps": len(chain.steps), + "final_variables": context.variables, + "step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()} + } + + async def _execute_conditional( + self, + chain: ChainDefinition, + context: ChainExecutionContext + ) -> Dict[str, Any]: + """条件执行工具链""" + # 条件执行类似于顺序执行,但更严格地检查条件 + return await self._execute_sequential(chain, context) + + def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]: + """验证工具链定义 + + Args: + chain: 工具链定义 + + Returns: + (是否有效, 错误信息) + """ + if not chain.name: + return False, "工具链名称不能为空" + + if not chain.steps: + return False, "工具链必须包含至少一个步骤" + + for i, step in enumerate(chain.steps): + if 
not step.tool_id: + return False, f"步骤 {i} 缺少工具ID" + + if step.error_handling not in ["stop", "continue", "retry"]: + return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}" + + return True, None + + def _prepare_parameters( + self, + parameters: Dict[str, Any], + context: ChainExecutionContext + ) -> Dict[str, Any]: + """准备参数(支持变量替换) + + Args: + parameters: 原始参数 + context: 执行上下文 + + Returns: + 处理后的参数 + """ + prepared = {} + + for key, value in parameters.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + # 变量替换 + var_name = value[2:-1] + if var_name in context.variables: + prepared[key] = context.variables[var_name] + else: + prepared[key] = value # 保持原值 + else: + prepared[key] = value + + return prepared + + def _evaluate_condition( + self, + condition: str, + context: ChainExecutionContext + ) -> bool: + """评估执行条件 + + Args: + condition: 条件表达式 + context: 执行上下文 + + Returns: + 条件是否满足 + """ + try: + # 简单的条件评估(可以扩展为更复杂的表达式解析) + # 支持格式:variable == value, variable != value, variable > value 等 + + if "==" in condition: + var_name, expected_value = condition.split("==", 1) + var_name = var_name.strip() + expected_value = expected_value.strip().strip('"\'') + + return str(context.variables.get(var_name, "")) == expected_value + + elif "!=" in condition: + var_name, expected_value = condition.split("!=", 1) + var_name = var_name.strip() + expected_value = expected_value.strip().strip('"\'') + + return str(context.variables.get(var_name, "")) != expected_value + + elif condition in context.variables: + # 简单的布尔检查 + return bool(context.variables[condition]) + + else: + # 默认为真 + return True + + except Exception as e: + logger.error(f"条件评估失败: {condition}, 错误: {e}") + return False + + def _apply_output_mapping( + self, + mapping: Dict[str, str], + output_data: Any, + context: ChainExecutionContext + ): + """应用输出映射 + + Args: + mapping: 输出映射配置 + output_data: 输出数据 + context: 执行上下文 + """ + try: + if isinstance(output_data, dict): + for 
def get_running_chains(self) -> List[Dict[str, Any]]:
    """Summarise every tool chain currently tracked as running.

    Returns:
        One dict per entry of ``self._running_chains`` holding the chain id,
        the current step, the completed/failed flags, and the number of
        variables and recorded step results in its context.
    """
    return [
        {
            "chain_id": chain_id,
            "current_step": ctx.current_step,
            "is_completed": ctx.is_completed,
            "is_failed": ctx.is_failed,
            "variables_count": len(ctx.variables),
            "completed_steps": len(ctx.step_results),
        }
        for chain_id, ctx in self._running_chains.items()
    ]
def load_builtin_tools_config(self) -> Dict[str, Any]:
    """Load the global built-in tool catalogue (kept for the legacy interface).

    Reads ``builtin_tools.json`` from the configuration directory.

    Returns:
        The parsed JSON object. An empty dict is returned (and the problem
        logged) when the file is missing or unreadable, is not valid
        UTF-8/JSON, or its top level is not a JSON object.
    """
    config_file = self.config_dir / "builtin_tools.json"
    try:
        with open(config_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.error(f"加载内置工具配置失败: {e}")
        return {}

    if not isinstance(data, dict):
        # The caller in this class iterates the result with .items(); a list
        # or scalar at the top level would crash there, so reject it here.
        logger.error(
            f"加载内置工具配置失败: 顶层结构必须是对象, 实际为 {type(data).__name__}"
        )
        return {}

    return data
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
    """Persist a tool configuration as ``<config_dir>/<tool_type>/<name>.json``.

    Both ``tool_type`` and ``config.name`` become path components, so each
    must be a single, non-empty path segment. Values containing a path
    separator, ``.``/``..`` or an absolute path are rejected so a crafted
    name cannot write outside the configuration directory.

    Args:
        config: Tool configuration; ``config.name`` is the file stem and
            ``config.model_dump()`` supplies the JSON payload.
        tool_type: Tool type, used as the sub-directory name.

    Returns:
        ``True`` when the file was written, ``False`` on any failure
        (failures are logged, never raised).
    """
    try:
        for component in (tool_type, config.name):
            if (
                not component
                or component in (".", "..")
                or Path(component).name != component
            ):
                raise ValueError(f"非法的路径片段: {component!r}")

        config_dir = self.config_dir / tool_type
        config_dir.mkdir(parents=True, exist_ok=True)

        config_file = config_dir / f"{config.name}.json"
        config_data = config.model_dump()

        with open(config_file, 'w', encoding='utf-8') as f:
            json.dump(config_data, f, indent=2, ensure_ascii=False)

        logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
        return True

    except Exception as e:
        # Deliberately broad: the contract is a boolean result, and callers
        # are not expected to handle exceptions from this method.
        logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
        return False
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
    """Validate raw configuration data against the schema for its tool type.

    Args:
        config_data: Configuration data, passed to the schema as keyword
            arguments.
        tool_type: One of ``"builtin"``, ``"custom"`` or ``"mcp"``.

    Returns:
        ``(True, None)`` when the data validates. Otherwise ``(False, msg)``
        where ``msg`` names the unsupported tool type, lists every
        validation error as ``<location>: <message>``, or carries the text
        of an unexpected exception.
    """
    try:
        schema_map = {
            "builtin": BuiltinToolConfigSchema,
            "custom": CustomToolConfigSchema,
            "mcp": MCPToolConfigSchema
        }

        schema_class = schema_map.get(tool_type)
        if not schema_class:
            return False, f"不支持的工具类型: {tool_type}"

        # Instantiating the schema performs the validation.
        schema_class(**config_data)
        return True, None

    except ValidationError as e:
        messages = []
        for err in e.errors():
            # Report the full location path (e.g. "tags.0"), not just its
            # first element; an empty location must not raise IndexError.
            location = ".".join(str(part) for part in err["loc"]) or "__root__"
            messages.append(f"{location}: {err['msg']}")
        error_msg = "; ".join(messages)
        return False, f"配置验证失败: {error_msg}"
    except Exception as e:
        return False, f"配置验证异常: {str(e)}"
-0,0 +1,14 @@ +{ + "name": "baidu_search_tool", + "description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能", + "tool_type": "builtin", + "tool_class": "BaiduSearchTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": "", + "secret_key": "", + "search_type": "web" + }, + "tags": ["search", "web", "baidu", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/datetime_tool.json b/api/app/core/tools/configs/builtin/datetime_tool.json new file mode 100644 index 00000000..8652fd05 --- /dev/null +++ b/api/app/core/tools/configs/builtin/datetime_tool.json @@ -0,0 +1,12 @@ +{ + "name": "datetime_tool", + "description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算", + "tool_type": "builtin", + "tool_class": "DateTimeTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "timezone": "UTC" + }, + "tags": ["time", "utility", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/json_tool.json b/api/app/core/tools/configs/builtin/json_tool.json new file mode 100644 index 00000000..4c9f8c4a --- /dev/null +++ b/api/app/core/tools/configs/builtin/json_tool.json @@ -0,0 +1,12 @@ +{ + "name": "json_tool", + "description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换", + "tool_type": "builtin", + "tool_class": "JsonTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "indent": 2 + }, + "tags": ["json", "data", "utility", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/mineru_tool.json b/api/app/core/tools/configs/builtin/mineru_tool.json new file mode 100644 index 00000000..e53d6a71 --- /dev/null +++ b/api/app/core/tools/configs/builtin/mineru_tool.json @@ -0,0 +1,14 @@ +{ + "name": "mineru_tool", + "description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能", + "tool_type": "builtin", + "tool_class": "MinerUTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": "", + "parse_mode": "auto", + "timeout": 60 + }, + 
"tags": ["pdf", "document", "ocr", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin/textin_tool.json b/api/app/core/tools/configs/builtin/textin_tool.json new file mode 100644 index 00000000..d954f8f1 --- /dev/null +++ b/api/app/core/tools/configs/builtin/textin_tool.json @@ -0,0 +1,14 @@ +{ + "name": "textin_tool", + "description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能", + "tool_type": "builtin", + "tool_class": "TextInTool", + "version": "1.0.0", + "enabled": true, + "parameters": { + "app_id": "", + "language": "auto", + "recognition_mode": "general" + }, + "tags": ["ocr", "image", "text", "builtin"] +} \ No newline at end of file diff --git a/api/app/core/tools/configs/builtin_tools.json b/api/app/core/tools/configs/builtin_tools.json new file mode 100644 index 00000000..ed0b87b1 --- /dev/null +++ b/api/app/core/tools/configs/builtin_tools.json @@ -0,0 +1,60 @@ +{ + "datetime": { + "name": "时间工具", + "description": "获取当前时间、日期计算", + "tool_class": "DateTimeTool", + "category": "utility", + "requires_config": false, + "version": "1.0.0", + "enabled": true, + "parameters": {} + }, + "json_converter": { + "name": "JSON转换工具", + "description": "JSON数据格式化和转换", + "tool_class": "JsonTool", + "category": "utility", + "requires_config": false, + "version": "1.0.0", + "enabled": true, + "parameters": {} + }, + "baidu_search": { + "name": "百度搜索", + "description": "百度网页搜索服务", + "tool_class": "BaiduSearchTool", + "category": "search", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true} + } + }, + "mineru": { + "name": "MinerU", + "description": "PDF文档解析工具", + "tool_class": "MinerUTool", + "category": "document", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true}, + 
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"} + } + }, + "textin": { + "name": "TextIn", + "description": "OCR文字识别服务", + "tool_class": "TextInTool", + "category": "ocr", + "requires_config": true, + "version": "1.0.0", + "enabled": true, + "parameters": { + "api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}, + "api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true} + } + } +} \ No newline at end of file diff --git a/api/app/core/tools/custom/__init__.py b/api/app/core/tools/custom/__init__.py new file mode 100644 index 00000000..87b0488a --- /dev/null +++ b/api/app/core/tools/custom/__init__.py @@ -0,0 +1,11 @@ +"""自定义工具模块""" + +from .base import CustomTool +from .schema_parser import OpenAPISchemaParser +from .auth_manager import AuthManager + +__all__ = [ + "CustomTool", + "OpenAPISchemaParser", + "AuthManager" +] \ No newline at end of file diff --git a/api/app/core/tools/custom/auth_manager.py b/api/app/core/tools/custom/auth_manager.py new file mode 100644 index 00000000..5d457f11 --- /dev/null +++ b/api/app/core/tools/custom/auth_manager.py @@ -0,0 +1,525 @@ +"""认证管理器 - 处理自定义工具的认证配置""" +import base64 +import hashlib +import hmac +import time +from typing import Dict, Any, Tuple +from urllib.parse import quote +import aiohttp + +from app.models.tool_model import AuthType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class AuthManager: + """认证管理器 - 支持多种认证方式""" + + def __init__(self): + """初始化认证管理器""" + self.supported_auth_types = [ + AuthType.NONE, + AuthType.API_KEY, + AuthType.BEARER_TOKEN + ] + + def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证认证配置 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + try: + if auth_type not in self.supported_auth_types: + return False, f"不支持的认证类型: 
{auth_type}" + + if auth_type == AuthType.NONE: + return True, "" + + elif auth_type == AuthType.API_KEY: + return self._validate_api_key_config(auth_config) + + elif auth_type == AuthType.BEARER_TOKEN: + return self._validate_bearer_token_config(auth_config) + + return False, "未知的认证类型" + + except Exception as e: + return False, f"验证认证配置时出错: {e}" + + def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证API Key认证配置 + + Args: + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + api_key = auth_config.get("api_key") + if not api_key: + return False, "API Key不能为空" + + if not isinstance(api_key, str): + return False, "API Key必须是字符串" + + # 验证key名称 + key_name = auth_config.get("key_name", "X-API-Key") + if not isinstance(key_name, str): + return False, "API Key名称必须是字符串" + + # 验证位置 + key_location = auth_config.get("location", "header") + if key_location not in ["header", "query", "cookie"]: + return False, "API Key位置必须是 header、query 或 cookie" + + return True, "" + + def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]: + """验证Bearer Token认证配置 + + Args: + auth_config: 认证配置 + + Returns: + (是否有效, 错误信息) + """ + token = auth_config.get("token") + if not token: + return False, "Bearer Token不能为空" + + if not isinstance(token, str): + return False, "Bearer Token必须是字符串" + + return True, "" + + def apply_authentication( + self, + auth_type: AuthType, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用认证到请求 + + Args: + auth_type: 认证类型 + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + try: + if auth_type == AuthType.NONE: + return url, headers, params + + elif auth_type == AuthType.API_KEY: + return self._apply_api_key_auth(auth_config, url, headers, params) + + elif auth_type == AuthType.BEARER_TOKEN: + return 
self._apply_bearer_token_auth(auth_config, url, headers, params) + + else: + logger.warning(f"不支持的认证类型: {auth_type}") + return url, headers, params + + except Exception as e: + logger.error(f"应用认证时出错: {e}") + return url, headers, params + + def _apply_api_key_auth( + self, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用API Key认证 + + Args: + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + api_key = auth_config.get("api_key") + key_name = auth_config.get("key_name", "X-API-Key") + location = auth_config.get("location", "header") + + if location == "header": + headers[key_name] = api_key + + elif location == "query": + # 添加到URL查询参数 + separator = "&" if "?" in url else "?" + encoded_key = quote(str(api_key)) + url += f"{separator}{key_name}={encoded_key}" + + elif location == "cookie": + # 添加到Cookie头 + cookie_value = f"{key_name}={api_key}" + if "Cookie" in headers: + headers["Cookie"] += f"; {cookie_value}" + else: + headers["Cookie"] = cookie_value + + return url, headers, params + + def _apply_bearer_token_auth( + self, + auth_config: Dict[str, Any], + url: str, + headers: Dict[str, str], + params: Dict[str, Any] + ) -> Tuple[str, Dict[str, str], Dict[str, Any]]: + """应用Bearer Token认证 + + Args: + auth_config: 认证配置 + url: 请求URL + headers: 请求头 + params: 请求参数 + + Returns: + (修改后的URL, 修改后的headers, 修改后的params) + """ + token = auth_config.get("token") + headers["Authorization"] = f"Bearer {token}" + + return url, headers, params + + def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]: + """加密认证配置中的敏感信息 + + Args: + auth_config: 认证配置 + encryption_key: 加密密钥 + + Returns: + 加密后的认证配置 + """ + try: + encrypted_config = auth_config.copy() + + # 需要加密的字段 + sensitive_fields = ["api_key", "token", "secret", "password"] + + for field in sensitive_fields: + if field 
in encrypted_config: + value = encrypted_config[field] + if isinstance(value, str) and value: + encrypted_value = self._encrypt_string(value, encryption_key) + encrypted_config[field] = encrypted_value + encrypted_config[f"{field}_encrypted"] = True + + return encrypted_config + + except Exception as e: + logger.error(f"加密认证配置失败: {e}") + return auth_config + + def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]: + """解密认证配置中的敏感信息 + + Args: + encrypted_config: 加密的认证配置 + encryption_key: 解密密钥 + + Returns: + 解密后的认证配置 + """ + try: + decrypted_config = encrypted_config.copy() + + # 需要解密的字段 + sensitive_fields = ["api_key", "token", "secret", "password"] + + for field in sensitive_fields: + if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"): + encrypted_value = decrypted_config[field] + if isinstance(encrypted_value, str) and encrypted_value: + decrypted_value = self._decrypt_string(encrypted_value, encryption_key) + decrypted_config[field] = decrypted_value + # 移除加密标记 + decrypted_config.pop(f"{field}_encrypted", None) + + return decrypted_config + + except Exception as e: + logger.error(f"解密认证配置失败: {e}") + return encrypted_config + + def _encrypt_string(self, value: str, key: str) -> str: + """加密字符串 + + Args: + value: 要加密的字符串 + key: 加密密钥 + + Returns: + 加密后的字符串(Base64编码) + """ + try: + # 使用HMAC-SHA256进行简单加密 + key_bytes = key.encode('utf-8') + value_bytes = value.encode('utf-8') + + # 生成HMAC + hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256) + signature = hmac_obj.hexdigest() + + # 组合原始值和签名,然后Base64编码 + combined = f"{value}:{signature}" + encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8') + + return encrypted + + except Exception as e: + logger.error(f"加密字符串失败: {e}") + return value + + def _decrypt_string(self, encrypted_value: str, key: str) -> str: + """解密字符串 + + Args: + encrypted_value: 加密的字符串 + key: 解密密钥 + + Returns: + 解密后的字符串 + """ + try: + # Base64解码 + decoded = 
def test_authentication(
    self,
    auth_type: AuthType,
    auth_config: Dict[str, Any],
    test_url: str = None
) -> Dict[str, Any]:
    """Dry-run an authentication config without sending a request.

    The config is validated first. If ``test_url`` is given, the
    credentials are applied to a throw-away request so the caller can see
    which headers would be sent. The result never contains the credential:
    the ``Authorization`` header, the ``Cookie`` header and the configured
    API-key header are removed, and the URL is reported as supplied (an
    API key placed in the query string is therefore not echoed back).

    Args:
        auth_type: Authentication type.
        auth_config: Authentication configuration.
        test_url: Optional URL to build the sample request for.

    Returns:
        A dict with ``success`` and ``auth_type``, plus ``error`` on
        failure or ``message`` (and, when ``test_url`` is given,
        ``test_url``, ``headers`` and ``has_auth_header``) on success.
    """
    try:
        is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
        if not is_valid:
            return {
                "success": False,
                "error": error_msg,
                "auth_type": auth_type.value
            }

        # Without a URL there is nothing more to check than the config.
        if not test_url:
            return {
                "success": True,
                "message": "认证配置有效",
                "auth_type": auth_type.value
            }

        headers = {"User-Agent": "AuthManager-Test/1.0"}
        params = {}

        # The authenticated URL is discarded on purpose: with
        # location == "query" it contains the API key.
        _, headers, params = self.apply_authentication(
            auth_type, auth_config, test_url, headers, params
        )

        # Header names that may carry the credential, compared
        # case-insensitively.
        credential_headers = {
            "authorization",
            "cookie",
            str(auth_config.get("key_name", "X-API-Key")).lower(),
        }
        safe_headers = {
            k: v for k, v in headers.items()
            if k.lower() not in credential_headers
        }

        return {
            "success": True,
            "message": "认证配置测试成功",
            "auth_type": auth_type.value,
            "test_url": test_url,
            "headers": safe_headers,
            "has_auth_header": "Authorization" in headers
        }

    except Exception as e:
        return {
            "success": False,
            "error": str(e),
            "auth_type": auth_type.value if auth_type else "unknown"
        }
timeout: 超时时间(秒) + + Returns: + 测试结果 + """ + try: + # 验证配置 + is_valid, error_msg = self.validate_auth_config(auth_type, auth_config) + if not is_valid: + return { + "success": False, + "error": error_msg, + "auth_type": auth_type.value + } + + # 构建请求 + headers = {"User-Agent": "AuthManager-Test/1.0"} + params = {} + + # 应用认证 + test_url, headers, params = self.apply_authentication( + auth_type, auth_config, test_url, headers, params + ) + + # 发送测试请求 + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.get(test_url, headers=headers) as response: + status_code = response.status + + # 根据状态码判断认证是否成功 + if status_code == 200: + return { + "success": True, + "message": "认证测试成功", + "status_code": status_code, + "auth_type": auth_type.value + } + elif status_code == 401: + return { + "success": False, + "error": "认证失败 - 401 Unauthorized", + "status_code": status_code, + "auth_type": auth_type.value + } + elif status_code == 403: + return { + "success": False, + "error": "认证失败 - 403 Forbidden", + "status_code": status_code, + "auth_type": auth_type.value + } + else: + return { + "success": True, + "message": f"请求成功,状态码: {status_code}", + "status_code": status_code, + "auth_type": auth_type.value + } + + except aiohttp.ClientError as e: + return { + "success": False, + "error": f"网络请求失败: {e}", + "auth_type": auth_type.value + } + except Exception as e: + return { + "success": False, + "error": f"测试认证时出错: {e}", + "auth_type": auth_type.value + } + + def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]: + """获取认证配置模板 + + Args: + auth_type: 认证类型 + + Returns: + 配置模板 + """ + templates = { + AuthType.NONE: {}, + + AuthType.API_KEY: { + "api_key": "", + "key_name": "X-API-Key", + "location": "header", # header, query, cookie + "description": "API Key认证配置" + }, + + AuthType.BEARER_TOKEN: { + "token": "", + "description": "Bearer Token认证配置" + } + } + + return 
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of an auth config that is safe to display.

    The fields ``api_key``, ``token``, ``secret`` and ``password`` are
    masked. Strings longer than eight characters keep their first and last
    two characters (``ab***yz``). Every other non-empty value is replaced
    by ``"***"``; this covers short strings, which would otherwise be
    mostly revealed, and non-string values. ``None`` and ``""`` are left
    as they are. The input dict is not modified.

    Args:
        auth_config: Authentication configuration.

    Returns:
        A shallow copy with the sensitive fields masked.
    """
    masked_config = auth_config.copy()

    sensitive_fields = ("api_key", "token", "secret", "password")
    # Below this length, showing four characters would expose half or more
    # of the secret, so the value is hidden completely.
    min_partial_length = 8

    for field in sensitive_fields:
        value = masked_config.get(field)
        if value is None or value == "":
            continue
        if isinstance(value, str) and len(value) > min_partial_length:
            masked_config[field] = f"{value[:2]}***{value[-2:]}"
        else:
            masked_config[field] = "***"

    return masked_config
@property
def description(self) -> str:
    """Human-readable description of the tool.

    Taken from ``info.description`` of the OpenAPI schema; when there is
    no schema content, or the schema has no such field, a generic default
    description is returned.
    """
    fallback = "自定义API工具"
    schema = self.schema_content
    if not schema:
        return fallback
    return schema.get("info", {}).get("description", fallback)
data=data + ) + + execution_time = time.time() - start_time + return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="CUSTOM_TOOL_ERROR", + execution_time=execution_time + ) + + def _parse_openapi_schema(self) -> Dict[str, Any]: + """解析OpenAPI schema""" + operations = {} + + if not self.schema_content: + return operations + + paths = self.schema_content.get("paths", {}) + + for path, path_item in paths.items(): + for method, operation in path_item.items(): + if method.lower() in ["get", "post", "put", "delete", "patch"]: + operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}") + + # 解析参数 + parameters = {} + if "parameters" in operation: + for param in operation["parameters"]: + param_name = param.get("name") + param_schema = param.get("schema", {}) + parameters[param_name] = { + "type": param_schema.get("type", "string"), + "description": param.get("description", ""), + "required": param.get("required", False), + "in": param.get("in", "query"), + **param_schema + } + + # 解析请求体 + request_body = None + if "requestBody" in operation: + content = operation["requestBody"].get("content", {}) + if "application/json" in content: + request_body = content["application/json"].get("schema", {}) + + operations[operation_id] = { + "method": method.upper(), + "path": path, + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": parameters, + "request_body": request_body + } + + return operations + + def _convert_openapi_type(self, openapi_type: str) -> ParameterType: + """转换OpenAPI类型到内部类型""" + type_mapping = { + "string": ParameterType.STRING, + "integer": ParameterType.INTEGER, + "number": ParameterType.NUMBER, + "boolean": ParameterType.BOOLEAN, + "array": ParameterType.ARRAY, + "object": ParameterType.OBJECT + } + return 
type_mapping.get(openapi_type, ParameterType.STRING) + + def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str: + """构建请求URL""" + path = operation["path"] + + # 替换路径参数 + for param_name, param_info in operation.get("parameters", {}).items(): + if param_info.get("in") == "path" and param_name in params: + path = path.replace(f"{{{param_name}}}", str(params[param_name])) + + # 构建完整URL + if self.base_url: + url = urljoin(self.base_url, path.lstrip("/")) + else: + # 从schema中获取服务器URL + servers = self.schema_content.get("servers", []) + if servers: + base_url = servers[0].get("url", "") + url = urljoin(base_url, path.lstrip("/")) + else: + url = path + + # 添加查询参数 + query_params = {} + for param_name, param_info in operation.get("parameters", {}).items(): + if param_info.get("in") == "query" and param_name in params: + query_params[param_name] = params[param_name] + + if query_params: + from urllib.parse import urlencode + url += "?" + urlencode(query_params) + + return url + + def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]: + """构建请求头""" + headers = { + "Content-Type": "application/json", + "User-Agent": "CustomTool/1.0" + } + + # 添加认证头 + if self.auth_type == AuthType.API_KEY: + api_key = self.auth_config.get("api_key") + key_name = self.auth_config.get("key_name", "X-API-Key") + if api_key: + headers[key_name] = api_key + + elif self.auth_type == AuthType.BEARER_TOKEN: + token = self.auth_config.get("token") + if token: + headers["Authorization"] = f"Bearer {token}" + + return headers + + def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """构建请求数据""" + if operation["method"] in ["POST", "PUT", "PATCH"]: + request_body = operation.get("request_body") + if request_body: + # 构建请求体数据 + data = {} + properties = request_body.get("properties", {}) + + for prop_name, prop_schema in properties.items(): + if prop_name in params: + data[prop_name] = 
async def _send_http_request(
    self,
    method: str,
    url: str,
    headers: Dict[str, str],
    data: Optional[Dict[str, Any]] = None
) -> Any:
    """Send one HTTP request and return the decoded response body.

    Args:
        method: Upper-case HTTP method.
        url: Absolute request URL.
        headers: Request headers.
        data: JSON payload; only sent for POST/PUT/PATCH when non-empty.

    Returns:
        The parsed JSON body, or the raw text when the body is not JSON.

    Raises:
        Exception: When the server answers with status >= 400; the message
            carries the status code and the response text.
    """
    timeout = aiohttp.ClientTimeout(total=self.timeout)

    request_kwargs: Dict[str, Any] = {"headers": headers}
    if data and method in ("POST", "PUT", "PATCH"):
        request_kwargs["json"] = data

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.request(method, url, **request_kwargs) as response:
            if response.status >= 400:
                error_text = await response.text()
                raise Exception(f"HTTP {response.status}: {error_text}")

            # Fall back to text only for "not JSON": aiohttp raises
            # ContentTypeError on a non-JSON content type, and decoding
            # errors are ValueError subclasses. Other failures propagate.
            try:
                return await response.json()
            except (aiohttp.ContentTypeError, ValueError):
                return await response.text()
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
    """Download an OpenAPI document from a URL, then parse and validate it.

    Args:
        schema_url: Absolute ``http``/``https`` URL of the schema document.
        timeout: Total request timeout in seconds.

    Returns:
        ``(ok, schema, error)``: on success ``(True, <schema dict>, "")``,
        otherwise ``(False, {}, <message>)``. Errors are reported through
        the tuple; the method does not raise.
    """
    try:
        parsed_url = urlparse(schema_url)
        # The HTTP client can only fetch http(s); reject every other scheme
        # up front instead of surfacing a low-level client error.
        if parsed_url.scheme not in ("http", "https") or not parsed_url.netloc:
            return False, {}, "无效的URL格式"

        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=client_timeout) as session:
            async with session.get(schema_url) as response:
                if response.status != 200:
                    return False, {}, f"HTTP错误: {response.status}"

                content_type = response.headers.get('content-type', '').lower()
                content = await response.text()

        schema_dict = self._parse_content(content, content_type)
        if not schema_dict:
            return False, {}, "无法解析schema内容"

        is_valid, error_msg = self.validate_schema(schema_dict)
        if not is_valid:
            return False, {}, error_msg

        return True, schema_dict, ""

    except asyncio.TimeoutError:
        return False, {}, "请求超时"
    except Exception as e:
        logger.error(f"从URL解析schema失败: {schema_url}, 错误: {e}")
        return False, {}, str(e)
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that a parsed document looks like a supported OpenAPI schema.

    Args:
        schema_dict: The parsed schema document.

    Returns:
        ``(True, "")`` when valid, otherwise ``(False, <reason>)``. Never
        raises; unexpected errors are reported through the tuple.
    """
    try:
        if not isinstance(schema_dict, dict):
            return False, "Schema必须是JSON对象"

        openapi_version = schema_dict.get("openapi")
        if not openapi_version:
            return False, "缺少openapi版本字段"

        # Match on major.minor so patch releases newer than the ones listed
        # in supported_versions (e.g. 3.0.4, 3.1.1) are accepted.
        supported_families = {version.rsplit(".", 1)[0] for version in self.supported_versions}
        version_family = ".".join(str(openapi_version).split(".")[:2])
        if openapi_version not in self.supported_versions and version_family not in supported_families:
            return False, f"不支持的OpenAPI版本: {openapi_version}"

        for field in ("info", "paths"):
            if field not in schema_dict:
                return False, f"缺少必需字段: {field}"

        info = schema_dict.get("info", {})
        if not isinstance(info, dict):
            return False, "info字段必须是对象"

        if "title" not in info:
            return False, "info.title字段是必需的"

        paths = schema_dict.get("paths", {})
        if not isinstance(paths, dict):
            return False, "paths字段必须是对象"

        # A tool needs at least one path to call.
        if not paths:
            return False, "至少需要定义一个API路径"

        return True, ""

    except Exception as e:
        return False, f"验证schema时出错: {e}"
info.get("title", "Custom API Tool"), + "description": info.get("description", ""), + "version": info.get("version", "1.0.0"), + "servers": schema_dict.get("servers", []), + "operations": self._extract_operations(schema_dict) + } + + def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]: + """提取API操作信息 + + Args: + schema_dict: Schema字典 + + Returns: + 操作信息字典 + """ + operations = {} + paths = schema_dict.get("paths", {}) + + for path, path_item in paths.items(): + if not isinstance(path_item, dict): + continue + + for method, operation in path_item.items(): + if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]: + continue + + if not isinstance(operation, dict): + continue + + # 生成操作ID + operation_id = operation.get("operationId") + if not operation_id: + operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}" + + # 提取操作信息 + operations[operation_id] = { + "method": method.upper(), + "path": path, + "summary": operation.get("summary", ""), + "description": operation.get("description", ""), + "parameters": self._extract_parameters(operation), + "request_body": self._extract_request_body(operation), + "responses": self._extract_responses(operation), + "tags": operation.get("tags", []) + } + + return operations + + def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]: + """提取操作参数 + + Args: + operation: 操作定义 + + Returns: + 参数信息字典 + """ + parameters = {} + + for param in operation.get("parameters", []): + if not isinstance(param, dict): + continue + + param_name = param.get("name") + if not param_name: + continue + + param_schema = param.get("schema", {}) + + parameters[param_name] = { + "name": param_name, + "in": param.get("in", "query"), + "description": param.get("description", ""), + "required": param.get("required", False), + "type": param_schema.get("type", "string"), + "format": param_schema.get("format"), + "enum": param_schema.get("enum"), + 
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build one flat parameter list covering every operation of the tool.

    Args:
        operations: Operation info keyed by operation id.

    Returns:
        A list of parameter definitions. With more than one operation the
        list starts with a required ``operation`` selector. Each name
        appears once and its first declaration supplies the type and
        constraints. A parameter is ``required`` only when every operation
        requires it, so a call to one operation is not forced to supply
        another operation's arguments.
    """
    parameters: List[Dict[str, Any]] = []

    if len(operations) > 1:
        parameters.append({
            "name": "operation",
            "type": "string",
            "description": "要执行的操作",
            "required": True,
            "enum": list(operations.keys())
        })

    constraint_keys = ("enum", "default", "minimum", "maximum", "pattern")
    merged: Dict[str, Dict[str, Any]] = {}
    required_count: Dict[str, int] = {}

    for operation in operations.values():
        # Per-operation view: name -> spec / required flag. Path and query
        # parameters take precedence over same-named body properties.
        specs: Dict[str, Any] = {}
        required_here: Dict[str, bool] = {}

        for name, info in (operation.get("parameters") or {}).items():
            specs.setdefault(name, info)
            required_here.setdefault(name, bool(info.get("required", False)))

        request_body = operation.get("request_body")
        if request_body:
            schema = request_body.get("schema") or {}
            required_props = schema.get("required") or []
            for name, prop_schema in (schema.get("properties") or {}).items():
                specs.setdefault(name, prop_schema)
                required_here.setdefault(name, name in required_props)

        for name, spec in specs.items():
            if name not in merged:
                entry = {
                    "name": name,
                    "type": spec.get("type", "string"),
                    "description": spec.get("description", ""),
                    "required": False,
                }
                entry.update((key, spec.get(key)) for key in constraint_keys)
                merged[name] = entry
            if required_here[name]:
                required_count[name] = required_count.get(name, 0) + 1

    operation_total = len(operations)
    for name, entry in merged.items():
        entry["required"] = required_count.get(name, 0) == operation_total

    parameters.extend(merged.values())
    return parameters
param_type = param_info.get("type", "string") + + # 类型验证 + if not self._validate_parameter_type(value, param_type): + errors.append(f"参数 {param_name} 类型错误,期望: {param_type}") + + # 枚举验证 + enum_values = param_info.get("enum") + if enum_values and value not in enum_values: + errors.append(f"参数 {param_name} 值无效,必须是: {enum_values}") + + # 验证请求体参数 + request_body = operation.get("request_body") + if request_body: + schema = request_body.get("schema", {}) + required_props = schema.get("required", []) + properties = schema.get("properties", {}) + + for prop_name in required_props: + if prop_name not in params: + errors.append(f"缺少必需的请求体参数: {prop_name}") + + for prop_name, value in params.items(): + if prop_name in properties: + prop_schema = properties[prop_name] + prop_type = prop_schema.get("type", "string") + + if not self._validate_parameter_type(value, prop_type): + errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}") + + return len(errors) == 0, errors + + def _validate_parameter_type(self, value: Any, expected_type: str) -> bool: + """验证参数类型 + + Args: + value: 参数值 + expected_type: 期望类型 + + Returns: + 是否类型匹配 + """ + if value is None: + return True + + type_mapping = { + "string": str, + "integer": int, + "number": (int, float), + "boolean": bool, + "array": list, + "object": dict + } + + expected_python_type = type_mapping.get(expected_type) + if expected_python_type: + return isinstance(value, expected_python_type) + + return True \ No newline at end of file diff --git a/api/app/core/tools/executor.py b/api/app/core/tools/executor.py new file mode 100644 index 00000000..c0ba87fb --- /dev/null +++ b/api/app/core/tools/executor.py @@ -0,0 +1,501 @@ +"""工具执行器 - 负责工具的实际调用和执行管理""" +import asyncio +import uuid +import time +from typing import Dict, Any, List, Optional +from datetime import datetime +from sqlalchemy.orm import Session + +from app.models.tool_model import ToolExecution, ExecutionStatus +from app.core.tools.base import BaseTool, ToolResult +from 
class ExecutionContext:
    """Mutable record describing one tool invocation while it is tracked."""

    def __init__(
        self,
        execution_id: str,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Store the identifiers and initialise the lifecycle fields.

        Args:
            execution_id: Identifier of this execution.
            tool_id: Identifier of the tool being executed.
            user_id: Optional id of the invoking user.
            workspace_id: Optional id of the workspace.
            timeout: Timeout in seconds; a missing or falsy value means 60.
            metadata: Optional extra data; defaults to an empty dict.
        """
        # Identity of the execution.
        self.execution_id, self.tool_id = execution_id, tool_id
        self.user_id, self.workspace_id = user_id, workspace_id

        # Falsy inputs are replaced by the defaults.
        self.timeout = timeout if timeout else 60.0
        self.metadata = metadata if metadata else {}

        # Lifecycle: starts now, not yet completed, pending.
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
        self.status = ExecutionStatus.PENDING
async def execute_tools_batch(
    self,
    tool_executions: List[Dict[str, Any]],
    max_concurrency: int = 5
) -> "List[ToolResult]":
    """Run several tool executions concurrently.

    Args:
        tool_executions: One config dict per execution. ``tool_id`` is
            mandatory; ``parameters``, ``user_id``, ``workspace_id``,
            ``execution_id``, ``timeout`` and ``metadata`` are optional.
        max_concurrency: Upper bound on executions in flight; values below
            1 are treated as 1.

    Returns:
        One result per config, in input order. A config whose execution
        raised is reported as a ``BATCH_EXECUTION_ERROR`` result.
    """
    if not tool_executions:
        return []

    # Semaphore(0) would never admit a task and a negative value raises
    # ValueError, so clamp to at least one slot.
    semaphore = asyncio.Semaphore(max(1, max_concurrency))

    async def execute_single(exec_config: Dict[str, Any]) -> "ToolResult":
        async with semaphore:
            return await self.execute_tool(
                tool_id=exec_config["tool_id"],
                parameters=exec_config.get("parameters", {}),
                user_id=exec_config.get("user_id"),
                workspace_id=exec_config.get("workspace_id"),
                execution_id=exec_config.get("execution_id"),
                timeout=exec_config.get("timeout"),
                metadata=exec_config.get("metadata")
            )

    # return_exceptions=True keeps one failing config (e.g. a missing
    # "tool_id") from cancelling the rest of the batch.
    outcomes = await asyncio.gather(
        *(execute_single(config) for config in tool_executions),
        return_exceptions=True
    )

    return [
        ToolResult.error_result(
            error=str(outcome),
            error_code="BATCH_EXECUTION_ERROR",
            execution_time=0.0
        )
        if isinstance(outcome, Exception) else outcome
        for outcome in outcomes
    ]
def get_running_executions(self) -> List[Dict[str, Any]]:
    """Describe every execution currently tracked as running.

    Returns:
        One dict per tracked execution with its ids (stringified when set),
        ISO-formatted start time, status value and elapsed seconds.
    """

    def as_text(identifier):
        # Ids are reported as strings; unset ids stay None.
        return str(identifier) if identifier else None

    return [
        {
            "execution_id": execution_id,
            "tool_id": ctx.tool_id,
            "user_id": as_text(ctx.user_id),
            "workspace_id": as_text(ctx.workspace_id),
            "started_at": ctx.started_at.isoformat(),
            "status": ctx.status.value,
            "elapsed_time": (datetime.now() - ctx.started_at).total_seconds(),
        }
        for execution_id, ctx in self._running_executions.items()
    ]
async def _record_execution_start(
    self,
    context: ExecutionContext,
    parameters: Dict[str, Any]
):
    """Persist a RUNNING execution record for the given context.

    Best effort: a failure is logged and swallowed so that bookkeeping
    problems never prevent the tool from running.

    Args:
        context: The execution being started.
        parameters: The tool input, stored as ``input_data``.
    """
    try:
        execution_record = ToolExecution(
            execution_id=context.execution_id,
            tool_config_id=uuid.UUID(context.tool_id),
            status=ExecutionStatus.RUNNING.value,
            input_data=parameters,
            started_at=context.started_at,
            user_id=context.user_id,
            workspace_id=context.workspace_id
        )

        self.db.add(execution_record)
        self.db.commit()

        logger.debug(f"执行记录已创建: {context.execution_id}")

    except Exception as e:
        logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
        # After a failed flush/commit the session rejects further work
        # until it is rolled back; without this, the later completion
        # update on the same session would fail as well.
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {rollback_error}")
def get_execution_statistics(
    self,
    workspace_id: Optional[uuid.UUID] = None,
    days: int = 7
) -> Dict[str, Any]:
    """Summarise executions started within the last ``days`` days.

    Args:
        workspace_id: Restrict the summary to one workspace when given.
        days: Size of the look-back window in days.

    Returns:
        A dict with ``period_days``, total/successful/failed counts,
        ``success_rate``, ``average_execution_time`` and per-tool counters
        under ``tool_statistics``. An empty dict when anything fails.
    """
    try:
        from datetime import timedelta

        window_start = datetime.now() - timedelta(days=days)

        query = self.db.query(ToolExecution).filter(
            ToolExecution.started_at >= window_start
        )
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)

        rows = query.all()

        completed_value = ExecutionStatus.COMPLETED.value
        failed_value = ExecutionStatus.FAILED.value

        # One pass collects the global counters, the durations and the
        # per-tool breakdown.
        succeeded = 0
        failed = 0
        durations = []
        per_tool = {}

        for row in rows:
            counters = per_tool.setdefault(
                str(row.tool_config_id),
                {"total": 0, "successful": 0, "failed": 0}
            )
            counters["total"] += 1

            if row.status == completed_value:
                succeeded += 1
                counters["successful"] += 1
            elif row.status == failed_value:
                failed += 1
                counters["failed"] += 1

            if row.execution_time is not None:
                durations.append(row.execution_time)

        total = len(rows)
        average_duration = sum(durations) / len(durations) if durations else 0

        return {
            "period_days": days,
            "total_executions": total,
            "successful_executions": succeeded,
            "failed_executions": failed,
            "success_rate": succeeded / total if total > 0 else 0,
            "average_execution_time": average_duration,
            "tool_statistics": per_tool
        }

    except Exception as e:
        logger.error(f"获取执行统计失败, 错误: {e}")
        return {}
class LangchainToolWrapper(LangchainBaseTool):
    """Expose an internal ``BaseTool`` through the LangChain tool interface."""

    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
    return_direct: bool = Field(False, description="是否直接返回结果")

    # The wrapped internal tool that performs the actual work.
    tool_instance: BaseTool = Field(..., description="内部工具实例")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, tool_instance: BaseTool, **kwargs):
        """Wrap an internal tool.

        Args:
            tool_instance: The internal tool; its name, description and
                parameters define the LangChain-facing metadata.
            **kwargs: Extra fields forwarded to the LangChain base class.
        """
        # The argument schema is generated from the tool's parameter list.
        args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)

        super().__init__(
            name=tool_instance.name,
            description=tool_instance.description,
            args_schema=args_schema,
            # The keyword must match the declared field name; passing
            # "_tool_instance" left the required field unset.
            tool_instance=tool_instance,
            **kwargs
        )

    def _run(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Synchronous entry point required by LangChain; not supported.

        Raises:
            NotImplementedError: Always; the wrapped tools are async only.
        """
        raise NotImplementedError("请使用 _arun 方法进行异步调用")

    async def _arun(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped tool and return its result as a string.

        Raises:
            ToolException: When executing or formatting the result fails.
        """
        try:
            result = await self.tool_instance.safe_execute(**kwargs)
            return LangchainAdapter._format_result_for_langchain(result)

        except Exception as e:
            logger.error(f"工具执行失败: {self.name}, 错误: {e}")
            raise ToolException(f"工具执行失败: {str(e)}")
param.description + } + + if param.default is not None: + field_kwargs["default"] = param.default + elif not param.required: + field_kwargs["default"] = None + else: + field_kwargs["default"] = ... # 必需字段 + + # 添加验证约束 + if param.enum: + # 枚举值约束 + field_kwargs["regex"] = f"^({'|'.join(map(str, param.enum))})$" + + if param.minimum is not None: + field_kwargs["ge"] = param.minimum + + if param.maximum is not None: + field_kwargs["le"] = param.maximum + + if param.pattern: + field_kwargs["regex"] = param.pattern + + fields[param.name] = Field(**field_kwargs) + annotations[param.name] = python_type + + # 动态创建Pydantic模型 + schema_class = type( + "ToolArgsSchema", + (BaseModel,), + { + "__annotations__": annotations, + **fields, + "Config": type("Config", (), {"extra": "forbid"}) + } + ) + + return schema_class + + @staticmethod + def _get_python_type(param_type: ParameterType) -> type: + """获取参数类型对应的Python类型 + + Args: + param_type: 参数类型 + + Returns: + Python类型 + """ + type_mapping = { + ParameterType.STRING: str, + ParameterType.INTEGER: int, + ParameterType.NUMBER: float, + ParameterType.BOOLEAN: bool, + ParameterType.ARRAY: list, + ParameterType.OBJECT: dict + } + + return type_mapping.get(param_type, str) + + @staticmethod + def _format_result_for_langchain(result: ToolResult) -> str: + """将工具结果格式化为Langchain标准格式 + + Args: + result: 工具执行结果 + + Returns: + 格式化的字符串结果 + """ + if not result.success: + # 错误结果 + error_info = { + "success": False, + "error": result.error, + "error_code": result.error_code, + "execution_time": result.execution_time + } + return json.dumps(error_info, ensure_ascii=False, indent=2) + + # 成功结果 + if isinstance(result.data, str): + # 如果数据已经是字符串,直接返回 + return result.data + elif isinstance(result.data, (dict, list)): + # 如果是结构化数据,转换为JSON + return json.dumps(result.data, ensure_ascii=False, indent=2) + else: + # 其他类型转换为字符串 + return str(result.data) + + @staticmethod + def create_tool_description(tool: BaseTool) -> Dict[str, Any]: + 
"""创建工具描述(用于工具发现和文档生成) + + Args: + tool: 工具实例 + + Returns: + 工具描述字典 + """ + return { + "name": tool.name, + "description": tool.description, + "tool_type": tool.tool_type.value, + "version": tool.version, + "status": tool.status.value, + "tags": tool.tags, + "parameters": [ + { + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } + for param in tool.parameters + ], + "langchain_compatible": True + } + + @staticmethod + def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]: + """验证工具是否与Langchain兼容 + + Args: + tool: 工具实例 + + Returns: + (是否兼容, 问题列表) + """ + issues = [] + + # 检查工具名称 + if not tool.name or not isinstance(tool.name, str): + issues.append("工具名称必须是非空字符串") + + # 检查工具描述 + if not tool.description or not isinstance(tool.description, str): + issues.append("工具描述必须是非空字符串") + + # 检查参数定义 + for param in tool.parameters: + if not param.name or not isinstance(param.name, str): + issues.append(f"参数名称无效: {param.name}") + + if param.type not in ParameterType: + issues.append(f"不支持的参数类型: {param.type}") + + if param.required and param.default is not None: + issues.append(f"必需参数不应有默认值: {param.name}") + + # 检查是否有execute方法 + if not hasattr(tool, 'execute') or not callable(getattr(tool, 'execute')): + issues.append("工具必须实现execute方法") + + return len(issues) == 0, issues + + @staticmethod + def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]: + """获取Langchain工具的OpenAPI schema + + Args: + tool: 工具实例 + + Returns: + OpenAPI schema字典 + """ + # 构建参数schema + properties = {} + required = [] + + for param in tool.parameters: + prop_schema = { + "type": LangchainAdapter._get_openapi_type(param.type), + "description": param.description + } + + if param.enum: + prop_schema["enum"] = param.enum + + if param.minimum is not None: + prop_schema["minimum"] = 
param.minimum + + if param.maximum is not None: + prop_schema["maximum"] = param.maximum + + if param.pattern: + prop_schema["pattern"] = param.pattern + + if param.default is not None: + prop_schema["default"] = param.default + + properties[param.name] = prop_schema + + if param.required: + required.append(param.name) + + return { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": properties, + "required": required + } + } + } + + @staticmethod + def _get_openapi_type(param_type: ParameterType) -> str: + """获取OpenAPI类型 + + Args: + param_type: 参数类型 + + Returns: + OpenAPI类型字符串 + """ + type_mapping = { + ParameterType.STRING: "string", + ParameterType.INTEGER: "integer", + ParameterType.NUMBER: "number", + ParameterType.BOOLEAN: "boolean", + ParameterType.ARRAY: "array", + ParameterType.OBJECT: "object" + } + + return type_mapping.get(param_type, "string") \ No newline at end of file diff --git a/api/app/core/tools/mcp/__init__.py b/api/app/core/tools/mcp/__init__.py new file mode 100644 index 00000000..faf13ceb --- /dev/null +++ b/api/app/core/tools/mcp/__init__.py @@ -0,0 +1,12 @@ +"""MCP工具模块""" + +from .base import MCPTool +from .client import MCPClient, MCPConnectionPool +from .service_manager import MCPServiceManager + +__all__ = [ + "MCPTool", + "MCPClient", + "MCPConnectionPool", + "MCPServiceManager" +] \ No newline at end of file diff --git a/api/app/core/tools/mcp/base.py b/api/app/core/tools/mcp/base.py new file mode 100644 index 00000000..241069cd --- /dev/null +++ b/api/app/core/tools/mcp/base.py @@ -0,0 +1,258 @@ +"""MCP工具基类""" +import time +from typing import Dict, Any, List +import aiohttp + +from app.models.tool_model import ToolType +from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class MCPTool(BaseTool): + """MCP工具 - Model 
Context Protocol工具""" + + def __init__(self, tool_id: str, config: Dict[str, Any]): + """初始化MCP工具 + + Args: + tool_id: 工具ID + config: 工具配置 + """ + super().__init__(tool_id, config) + self.server_url = config.get("server_url", "") + self.connection_config = config.get("connection_config", {}) + self.available_tools = config.get("available_tools", []) + self._client = None + self._connected = False + + @property + def name(self) -> str: + """工具名称""" + return f"mcp_tool_{self.tool_id[:8]}" + + @property + def description(self) -> str: + """工具描述""" + return f"MCP工具 - 连接到 {self.server_url}" + + @property + def tool_type(self) -> ToolType: + """工具类型""" + return ToolType.MCP + + @property + def parameters(self) -> List[ToolParameter]: + """工具参数定义""" + params = [] + + # 添加工具选择参数 + if len(self.available_tools) > 1: + params.append(ToolParameter( + name="tool_name", + type=ParameterType.STRING, + description="要调用的MCP工具名称", + required=True, + enum=self.available_tools + )) + + # 添加通用参数 + params.extend([ + ToolParameter( + name="arguments", + type=ParameterType.OBJECT, + description="工具参数(JSON对象)", + required=False, + default={} + ), + ToolParameter( + name="timeout", + type=ParameterType.INTEGER, + description="超时时间(秒)", + required=False, + default=30, + minimum=1, + maximum=300 + ) + ]) + + return params + + async def execute(self, **kwargs) -> ToolResult: + """执行MCP工具""" + start_time = time.time() + + try: + # 确保连接 + if not self._connected: + await self.connect() + + # 确定要调用的工具 + tool_name = kwargs.get("tool_name") + if not tool_name and len(self.available_tools) == 1: + tool_name = self.available_tools[0] + + if not tool_name: + raise ValueError("必须指定要调用的MCP工具名称") + + if tool_name not in self.available_tools: + raise ValueError(f"MCP工具不存在: {tool_name}") + + # 获取参数 + arguments = kwargs.get("arguments", {}) + timeout = kwargs.get("timeout", 30) + + # 调用MCP工具 + result = await self._call_mcp_tool(tool_name, arguments, timeout) + + execution_time = time.time() - start_time + 
return ToolResult.success_result( + data=result, + execution_time=execution_time + ) + + except Exception as e: + execution_time = time.time() - start_time + return ToolResult.error_result( + error=str(e), + error_code="MCP_ERROR", + execution_time=execution_time + ) + + async def connect(self) -> bool: + """连接到MCP服务器""" + try: + # 这里应该实现实际的MCP连接逻辑 + # 为了简化,这里只是模拟连接 + + # 测试服务器连接 + timeout = aiohttp.ClientTimeout(total=10) + async with aiohttp.ClientSession(timeout=timeout) as session: + # 尝试获取服务器信息 + async with session.get(f"{self.server_url}/info") as response: + if response.status == 200: + server_info = await response.json() + self.available_tools = server_info.get("tools", []) + self._connected = True + logger.info(f"MCP服务器连接成功: {self.server_url}") + return True + else: + raise Exception(f"服务器响应错误: {response.status}") + + except Exception as e: + logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}") + self._connected = False + return False + + async def disconnect(self) -> bool: + """断开MCP服务器连接""" + try: + if self._client: + # 这里应该实现实际的断开逻辑 + self._client = None + + self._connected = False + logger.info(f"MCP服务器连接已断开: {self.server_url}") + return True + + except Exception as e: + logger.error(f"断开MCP服务器连接失败: {e}") + return False + + def get_health_status(self) -> Dict[str, Any]: + """获取MCP服务健康状态""" + return { + "connected": self._connected, + "server_url": self.server_url, + "available_tools": self.available_tools, + "last_check": time.time() + } + + async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any: + """调用MCP工具""" + # 构建MCP请求 + request_data = { + "jsonrpc": "2.0", + "id": f"req_{int(time.time() * 1000)}", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + } + } + + # 发送请求 + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.post( + f"{self.server_url}/mcp", + json=request_data, + 
async def list_available_tools(self) -> List[Dict[str, Any]]:
    """Fetch the tool list from the MCP server via ``tools/list``.

    Connects first when not yet connected, then POSTs a JSON-RPC
    ``tools/list`` request to ``<server_url>/mcp``. On success the names of
    the returned tools are cached in ``self.available_tools``.

    Returns:
        The tool descriptors returned by the server, or an empty list when
        the request fails, the status is not 200, or the reply carries no
        ``result``.
    """
    try:
        if not self._connected:
            await self.connect()

        request_data = {
            "jsonrpc": "2.0",
            "id": f"req_{int(time.time() * 1000)}",
            "method": "tools/list"
        }

        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                f"{self.server_url}/mcp",
                json=request_data,
                headers={"Content-Type": "application/json"}
            ) as response:

                if response.status != 200:
                    # Previously silent; log so an empty list can be told
                    # apart from a server that really has no tools.
                    logger.warning(f"获取MCP工具列表失败: HTTP {response.status}")
                    return []

                result = await response.json()
                if "result" not in result:
                    return []

                tools = result["result"].get("tools", [])
                # Skip nameless entries so the cached list never contains
                # None (it is used for the `tool_name` enum and for the
                # membership check in execute()).
                self.available_tools = [tool["name"] for tool in tools if tool.get("name")]
                return tools

    except Exception as e:
        logger.error(f"获取MCP工具列表失败: {e}")
        return []
urllib.parse import urlparse +import aiohttp +import websockets +from websockets.exceptions import ConnectionClosed + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class MCPConnectionError(Exception): + """MCP连接错误""" + pass + + +class MCPProtocolError(Exception): + """MCP协议错误""" + pass + + +class MCPClient: + """MCP客户端 - 支持HTTP和WebSocket连接""" + + def __init__(self, server_url: str, connection_config: Dict[str, Any] = None): + """初始化MCP客户端 + + Args: + server_url: MCP服务器URL + connection_config: 连接配置 + """ + self.server_url = server_url + self.connection_config = connection_config or {} + + # 解析URL确定连接类型 + parsed_url = urlparse(server_url) + self.connection_type = "websocket" if parsed_url.scheme in ["ws", "wss"] else "http" + + # 连接状态 + self._connected = False + self._websocket = None + self._session = None + + # 请求管理 + self._request_id = 0 + self._pending_requests: Dict[str, asyncio.Future] = {} + + # 连接池配置 + self.max_connections = self.connection_config.get("max_connections", 10) + self.connection_timeout = self.connection_config.get("timeout", 30) + self.retry_attempts = self.connection_config.get("retry_attempts", 3) + self.retry_delay = self.connection_config.get("retry_delay", 1) + + # 健康检查 + self.health_check_interval = self.connection_config.get("health_check_interval", 60) + self._health_check_task = None + self._last_health_check = None + + # 事件回调 + self._on_connect_callbacks: List[Callable] = [] + self._on_disconnect_callbacks: List[Callable] = [] + self._on_error_callbacks: List[Callable] = [] + + async def connect(self) -> bool: + """连接到MCP服务器 + + Returns: + 连接是否成功 + """ + try: + if self._connected: + return True + + logger.info(f"连接MCP服务器: {self.server_url}") + + if self.connection_type == "websocket": + success = await self._connect_websocket() + else: + success = await self._connect_http() + + if success: + self._connected = True + await self._start_health_check() + await 
async def _connect_websocket(self) -> bool:
    """Open the WebSocket and perform the MCP ``initialize`` handshake.

    The background message listener is started only after the handshake
    reply has been read, so that this coroutine and the listener never call
    ``recv()`` on the socket at the same time.

    Returns:
        True when the socket is open and the server accepted
        ``initialize``; False otherwise (the socket is closed again).
    """
    try:
        extra_headers = self.connection_config.get("headers", {})

        # NOTE(review): `extra_headers` / `timeout` are the keyword names the
        # original code used; confirm them against the pinned websockets
        # version (newer releases use `additional_headers` / `open_timeout`).
        self._websocket = await websockets.connect(
            self.server_url,
            extra_headers=extra_headers,
            timeout=self.connection_timeout
        )

        init_message = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "clientInfo": {
                    "name": "ToolManagementSystem",
                    "version": "1.0.0"
                }
            }
        }

        await self._websocket.send(json.dumps(init_message))

        # Read the handshake reply directly; no listener is running yet.
        response = await asyncio.wait_for(
            self._websocket.recv(),
            timeout=self.connection_timeout
        )

        init_response = json.loads(response)
        if "error" in init_response:
            raise MCPProtocolError(f"初始化失败: {init_response['error']}")

        # Keep a reference: the event loop only holds weak references to
        # tasks, so an unreferenced listener could be garbage-collected.
        self._websocket_handler_task = asyncio.create_task(
            self._websocket_message_handler()
        )

        return True

    except Exception as e:
        logger.error(f"WebSocket连接失败: {e}")
        # Do not leave a half-initialised socket open behind a failed connect.
        websocket, self._websocket = self._websocket, None
        if websocket is not None:
            try:
                await websocket.close()
            except Exception as close_error:
                logger.debug(f"关闭WebSocket失败: {close_error}")
        return False
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
    """Return the tools advertised by the MCP server.

    Args:
        timeout: Seconds to wait for the reply.

    Returns:
        The ``tools`` list from the reply's ``result`` (empty when absent).

    Raises:
        MCPConnectionError: The client is not connected.
        MCPProtocolError: The server answered with an error, or the
            request timed out.
    """
    if not self._connected:
        raise MCPConnectionError("MCP客户端未连接")

    request_data = {
        "jsonrpc": "2.0",
        "id": self._get_next_request_id(),
        "method": "tools/list"
    }

    try:
        response = await self._send_request(request_data, timeout)
    except asyncio.TimeoutError as exc:
        raise MCPProtocolError("获取工具列表超时") from exc

    # A successful JSON-RPC reply has no "error" member at all, so it must
    # be read with .get(); indexing raised KeyError on every success.
    error = response.get("error")
    if error is not None:
        raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")

    result = response.get("result") or {}
    return result.get("tools", [])
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
    """Send one JSON-RPC request over the WebSocket and await its reply.

    A future is registered under ``request_id`` in
    ``self._pending_requests`` and awaited after the request is sent. The
    entry is always removed again, whether the call succeeds, fails, times
    out or is cancelled.

    Args:
        request_data: JSON-serialisable request object.
        request_id: Key under which the reply is awaited.
        timeout: Seconds to wait for the reply.

    Returns:
        The decoded reply message.

    Raises:
        MCPConnectionError: The socket is closed or sending failed.
        asyncio.TimeoutError: No reply arrived within ``timeout``.
    """
    if not self._websocket or self._websocket.closed:
        raise MCPConnectionError("WebSocket连接已断开")

    future = asyncio.get_running_loop().create_future()
    self._pending_requests[request_id] = future

    try:
        await self._websocket.send(json.dumps(request_data))
        return await asyncio.wait_for(future, timeout=timeout)
    except asyncio.TimeoutError:
        raise
    except Exception as e:
        raise MCPConnectionError(f"发送WebSocket请求失败: {e}") from e
    finally:
        # Also runs on cancellation (CancelledError is not an Exception),
        # which previously left a dead future registered forever.
        self._pending_requests.pop(request_id, None)
= time.time() + + return { + "healthy": True, + "response_time": response_time, + "timestamp": self._last_health_check, + "server_info": response.get("result", {}) + } + + except Exception as e: + return { + "healthy": False, + "error": str(e), + "timestamp": time.time() + } + + async def _start_health_check(self): + """启动健康检查任务""" + if self.health_check_interval > 0: + self._health_check_task = asyncio.create_task(self._health_check_loop()) + + async def _stop_health_check(self): + """停止健康检查任务""" + if self._health_check_task: + self._health_check_task.cancel() + try: + await self._health_check_task + except asyncio.CancelledError: + pass + self._health_check_task = None + + async def _health_check_loop(self): + """健康检查循环""" + try: + while self._connected: + await asyncio.sleep(self.health_check_interval) + + if self._connected: + health_status = await self.health_check() + if not health_status["healthy"]: + logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}") + # 可以在这里实现重连逻辑 + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"健康检查循环异常: {e}") + + def _get_next_request_id(self) -> str: + """获取下一个请求ID""" + self._request_id += 1 + return f"req_{self._request_id}_{int(time.time() * 1000)}" + + # 事件回调管理 + def on_connect(self, callback: Callable): + """注册连接回调""" + self._on_connect_callbacks.append(callback) + + def on_disconnect(self, callback: Callable): + """注册断开连接回调""" + self._on_disconnect_callbacks.append(callback) + + def on_error(self, callback: Callable): + """注册错误回调""" + self._on_error_callbacks.append(callback) + + async def _notify_connect_callbacks(self): + """通知连接回调""" + for callback in self._on_connect_callbacks: + try: + if asyncio.iscoroutinefunction(callback): + await callback() + else: + callback() + except Exception as e: + logger.error(f"连接回调执行失败: {e}") + + async def _notify_disconnect_callbacks(self): + """通知断开连接回调""" + for callback in self._on_disconnect_callbacks: + try: + if 
async def get_client(self, server_url: str, connection_config: Optional[Dict[str, Any]] = None) -> MCPClient:
    """Return a connected client for ``server_url``, creating one if needed.

    A pooled client is reused when it is connected or can be reconnected.
    Otherwise a new client is connected, and only after that succeeds is the
    oldest pooled client evicted if the pool is full.

    Args:
        server_url: MCP server URL; also the pool key.
        connection_config: Connection settings for a newly created client.

    Returns:
        A connected ``MCPClient``.

    Raises:
        MCPConnectionError: A new client could not connect.
    """
    async with self._lock:
        client = self._clients.get(server_url)
        if client is not None:
            if client.is_connected or await client.connect():
                return client
            # Reconnect failed: drop the stale entry and start fresh below.
            del self._clients[server_url]

        # Connect before evicting, so a failed attempt does not cost the
        # pool a working connection.
        client = MCPClient(server_url, connection_config)
        if not await client.connect():
            raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")

        if self._clients and len(self._clients) >= self.max_connections:
            # Dicts preserve insertion order, so the first key is the oldest.
            oldest_url = next(iter(self._clients))
            oldest_client = self._clients.pop(oldest_url)
            await oldest_client.disconnect()

        self._clients[server_url] = client
        return client
async def register_service(
    self,
    server_url: str,
    connection_config: Dict[str, Any],
    tenant_id: uuid.UUID,
    service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
    """Register an MCP service after probing it.

    Rejects a URL that already has an ``MCPToolConfig`` row, connects a
    throw-away client to read the tool list, stores a ``ToolConfig`` and an
    ``MCPToolConfig`` row, records the service in memory and starts its
    monitoring task.

    Args:
        server_url: MCP server URL.
        connection_config: Connection settings passed to the client.
        tenant_id: Owning tenant.
        service_name: Optional display name; derived from the last URL
            segment when omitted.

    Returns:
        ``(True, service_id, None)`` on success, otherwise
        ``(False, short_message, detail)``.
    """
    try:
        # Refuse a URL that is already registered.
        existing_service = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.server_url == server_url
        ).first()

        if existing_service:
            return False, "服务已存在", f"URL {server_url} 已被注册"

        # Probe the server with a temporary client.
        client = None
        try:
            client = MCPClient(server_url, connection_config)
            if not await client.connect():
                return False, "连接测试失败", "无法连接到MCP服务器"

            available_tools = await client.list_tools()
            tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]

        except Exception as e:
            return False, "连接测试失败", str(e)
        finally:
            # Always release the probe client. Previously a failing
            # list_tools() left it connected, with its health-check task
            # and session still alive.
            if client is not None:
                await client.disconnect()

        # Base tool configuration.
        if not service_name:
            service_name = f"mcp_service_{server_url.split('/')[-1]}"

        tool_config = ToolConfig(
            name=service_name,
            description=f"MCP服务 - {server_url}",
            tool_type=ToolType.MCP.value,
            tenant_id=tenant_id,
            version="1.0.0",
            config_data={
                "server_url": server_url,
                "connection_config": connection_config
            }
        )

        self.db.add(tool_config)
        self.db.flush()  # the MCP row below needs tool_config.id

        # MCP-specific configuration, sharing the base row's id.
        mcp_config = MCPToolConfig(
            id=tool_config.id,
            server_url=server_url,
            connection_config=connection_config,
            available_tools=tool_names,
            health_status="healthy",
            last_health_check=datetime.utcnow()
        )

        self.db.add(mcp_config)
        self.db.commit()

        service_id = str(tool_config.id)

        # In-memory bookkeeping used by the monitoring tasks.
        self._services[service_id] = {
            "id": service_id,
            "server_url": server_url,
            "connection_config": connection_config,
            "tenant_id": tenant_id,
            "available_tools": tool_names,
            "status": "healthy",
            "last_health_check": time.time(),
            "retry_count": 0,
            "created_at": time.time()
        }

        await self._start_service_monitoring(service_id)

        logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
        return True, service_id, None

    except Exception as e:
        self.db.rollback()
        logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
        return False, "注册失败", str(e)
mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if not mcp_config: + return False, "服务不存在" + + tool_config = mcp_config.base_config + + if connection_config is not None: + mcp_config.connection_config = connection_config + tool_config.config_data["connection_config"] = connection_config + + if enabled is not None: + tool_config.is_enabled = enabled + + self.db.commit() + + # 更新内存状态 + if service_id in self._services: + if connection_config is not None: + self._services[service_id]["connection_config"] = connection_config + + # 如果配置有变化,重启监控 + if connection_config is not None: + await self._restart_service_monitoring(service_id) + + logger.info(f"MCP服务更新成功: {service_id}") + return True, "" + + except Exception as e: + self.db.rollback() + logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}") + return False, str(e) + + async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]: + """获取服务状态 + + Args: + service_id: 服务ID + + Returns: + 服务状态信息 + """ + if service_id not in self._services: + return None + + service_info = self._services[service_id].copy() + + # 添加实时健康检查 + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + health_status = await client.health_check() + service_info["real_time_health"] = health_status + + except Exception as e: + service_info["real_time_health"] = { + "healthy": False, + "error": str(e), + "timestamp": time.time() + } + + return service_info + + async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]: + """列出所有服务 + + Args: + tenant_id: 租户ID过滤 + + Returns: + 服务列表 + """ + services = [] + + for service_id, service_info in self._services.items(): + if tenant_id and service_info["tenant_id"] != tenant_id: + continue + + services.append(service_info.copy()) + + return services + + async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]: + """获取服务的可用工具 
+ + Args: + service_id: 服务ID + + Returns: + 工具列表 + """ + if service_id not in self._services: + return [] + + service_info = self._services[service_id] + + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + tools = await client.list_tools() + + # 更新缓存的工具列表 + tool_names = [tool.get("name") for tool in tools if tool.get("name")] + service_info["available_tools"] = tool_names + + # 更新数据库 + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if mcp_config: + mcp_config.available_tools = tool_names + self.db.commit() + + return tools + + except Exception as e: + logger.error(f"获取服务工具失败: {service_id}, 错误: {e}") + return [] + + async def call_service_tool( + self, + service_id: str, + tool_name: str, + arguments: Dict[str, Any], + timeout: int = 30 + ) -> Dict[str, Any]: + """调用服务工具 + + Args: + service_id: 服务ID + tool_name: 工具名称 + arguments: 工具参数 + timeout: 超时时间 + + Returns: + 执行结果 + """ + if service_id not in self._services: + raise ValueError(f"服务不存在: {service_id}") + + service_info = self._services[service_id] + + try: + client = await self.connection_pool.get_client( + service_info["server_url"], + service_info["connection_config"] + ) + + result = await client.call_tool(tool_name, arguments, timeout) + + # 更新服务状态为健康 + service_info["status"] = "healthy" + service_info["last_health_check"] = time.time() + service_info["retry_count"] = 0 + + return result + + except Exception as e: + # 更新服务状态为错误 + service_info["status"] = "error" + service_info["last_error"] = str(e) + service_info["retry_count"] += 1 + + logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}") + raise + + async def _load_existing_services(self): + """加载现有服务""" + try: + mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter( + ToolConfig.is_enabled == True + ).all() + + for mcp_config in mcp_configs: + tool_config = mcp_config.base_config + service_id = 
async def _start_service_monitoring(self, service_id: str):
    """Start the monitoring task for a service unless one is already running.

    A task that has already finished (returned or crashed) does not count
    as running and is replaced, so monitoring can be restarted.

    Args:
        service_id: ID of the service to monitor.
    """
    existing_task = self._monitoring_tasks.get(service_id)
    if existing_task is not None and not existing_task.done():
        return

    self._monitoring_tasks[service_id] = asyncio.create_task(
        self._monitor_service(service_id)
    )
Exception as e: + logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}") + + else: + # 服务不健康 + service_info["status"] = "unhealthy" + service_info["last_error"] = health_status.get("error", "健康检查失败") + service_info["retry_count"] += 1 + + service_info["last_health_check"] = time.time() + + # 更新数据库 + await self._update_service_health_in_db(service_id, health_status) + + except Exception as e: + # 监控异常 + service_info["status"] = "error" + service_info["last_error"] = str(e) + service_info["retry_count"] += 1 + service_info["last_health_check"] = time.time() + + logger.error(f"服务监控异常: {service_id}, 错误: {e}") + + # 如果重试次数过多,暂停监控 + if service_info["retry_count"] >= self.max_retry_attempts: + logger.warning(f"服务 {service_id} 重试次数过多,暂停监控") + await asyncio.sleep(self.health_check_interval * 5) # 延长等待时间 + service_info["retry_count"] = 0 # 重置重试计数 + + # 等待下次检查 + await asyncio.sleep(self.health_check_interval) + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"服务监控任务异常: {service_id}, 错误: {e}") + + async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]): + """更新数据库中的服务健康状态""" + try: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == uuid.UUID(service_id) + ).first() + + if mcp_config: + mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy" + mcp_config.last_health_check = datetime.utcnow() + + if not health_status["healthy"]: + mcp_config.error_message = health_status.get("error", "") + else: + mcp_config.error_message = None + + self.db.commit() + + except Exception as e: + logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}") + self.db.rollback() + + async def _management_loop(self): + """管理循环 - 处理服务清理等任务""" + try: + while self._running: + # 清理失效的服务 + await self._cleanup_failed_services() + + # 等待下次循环 + await asyncio.sleep(300) # 5分钟 + + except asyncio.CancelledError: + pass + except Exception as e: + logger.error(f"管理循环异常: {e}") + + async def 
_cleanup_failed_services(self): + """清理长期失效的服务""" + try: + current_time = time.time() + cleanup_threshold = 24 * 60 * 60 # 24小时 + + services_to_cleanup = [] + + for service_id, service_info in self._services.items(): + # 检查服务是否长期失效 + if (service_info["status"] in ["error", "unhealthy"] and + current_time - service_info["last_health_check"] > cleanup_threshold): + + services_to_cleanup.append(service_id) + + for service_id in services_to_cleanup: + logger.warning(f"清理长期失效的服务: {service_id}") + + # 停止监控但不删除数据库记录 + await self._stop_service_monitoring(service_id) + + # 标记为禁用 + tool_config = self.db.get(ToolConfig, uuid.UUID(service_id)) + if tool_config: + tool_config.is_enabled = False + self.db.commit() + + # 从内存移除 + del self._services[service_id] + + except Exception as e: + logger.error(f"清理失效服务失败: {e}") + + def get_manager_status(self) -> Dict[str, Any]: + """获取管理器状态""" + return { + "running": self._running, + "total_services": len(self._services), + "healthy_services": len([s for s in self._services.values() if s["status"] == "healthy"]), + "unhealthy_services": len([s for s in self._services.values() if s["status"] in ["unhealthy", "error"]]), + "monitoring_tasks": len(self._monitoring_tasks), + "connection_pool_status": self.connection_pool.get_pool_status() + } \ No newline at end of file diff --git a/api/app/core/tools/registry.py b/api/app/core/tools/registry.py new file mode 100644 index 00000000..b56c1bf7 --- /dev/null +++ b/api/app/core/tools/registry.py @@ -0,0 +1,436 @@ +"""工具注册表 - 管理所有工具的元数据和状态""" +import uuid +import asyncio +from typing import Dict, List, Optional, Type, Any + +from sqlalchemy.orm import Session +from sqlalchemy import and_, or_ + +from app.models.tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolType, ToolStatus, ToolExecution, ExecutionStatus +) +from app.core.logging_config import get_business_logger +from .base import BaseTool, ToolInfo +from .custom.base import CustomTool +from .mcp.base 
import MCPTool + +logger = get_business_logger() + + +class ToolRegistry: + """工具注册表 - 管理所有工具的元数据和实例""" + + def __init__(self, db: Session): + """初始化工具注册表 + + Args: + db: 数据库会话 + """ + self.db = db + self._tools: Dict[str, BaseTool] = {} # 工具实例缓存 + self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表 + self._lock = asyncio.Lock() # 异步锁 + + def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None): + """注册工具类 + + Args: + tool_class: 工具类 + class_name: 类名(可选,默认使用类的__name__) + """ + class_name = class_name or tool_class.__name__ + self._tool_classes[class_name] = tool_class + logger.info(f"工具类已注册: {class_name}") + + async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool: + """注册工具实例到系统 + + Args: + tool: 工具实例 + tenant_id: 租户ID(内置工具可以为None,表示全局工具) + + Returns: + 注册是否成功 + """ + async with self._lock: + try: + # 检查工具是否已存在 + if tenant_id: + existing_config = self.db.query(ToolConfig).filter( + and_( + ToolConfig.name == tool.name, + ToolConfig.tenant_id == tenant_id, + ToolConfig.tool_type == tool.tool_type.value + ) + ).first() + else: + # 全局工具(内置工具) + existing_config = self.db.query(ToolConfig).filter( + and_( + ToolConfig.name == tool.name, + ToolConfig.tenant_id.is_(None), + ToolConfig.tool_type == tool.tool_type.value + ) + ).first() + + if existing_config: + logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})") + return False + + # 创建工具配置 + tool_config = ToolConfig( + name=tool.name, + description=tool.description, + tool_type=tool.tool_type.value, + tenant_id=tenant_id, + version=tool.version, + tags=tool.tags, + config_data=tool.config + ) + + self.db.add(tool_config) + self.db.flush() # 获取ID + + # 根据工具类型创建特定配置 + if tool.tool_type == ToolType.BUILTIN: + builtin_config = BuiltinToolConfig( + id=tool_config.id, + tool_class=tool.__class__.__name__, + parameters=tool.config.get("parameters", {}) + ) + self.db.add(builtin_config) + + elif tool.tool_type == ToolType.CUSTOM: + 
custom_config = CustomToolConfig( + id=tool_config.id, + schema_url=tool.config.get("schema_url"), + schema_content=tool.config.get("schema_content"), + auth_type=tool.config.get("auth_type", "none"), + auth_config=tool.config.get("auth_config", {}), + base_url=tool.config.get("base_url"), + timeout=tool.config.get("timeout", 30) + ) + self.db.add(custom_config) + + elif tool.tool_type == ToolType.MCP: + mcp_config = MCPToolConfig( + id=tool_config.id, + server_url=tool.config.get("server_url"), + connection_config=tool.config.get("connection_config", {}), + available_tools=tool.config.get("available_tools", []) + ) + self.db.add(mcp_config) + + self.db.commit() + + # 缓存工具实例 + tool.tool_id = str(tool_config.id) + self._tools[str(tool_config.id)] = tool + + logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具注册失败: {tool.name}, 错误: {e}") + return False + + async def unregister_tool(self, tool_id: str) -> bool: + """从系统注销工具 + + Args: + tool_id: 工具ID + + Returns: + 注销是否成功 + """ + async with self._lock: + try: + # 检查工具是否存在 + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config: + logger.warning(f"工具不存在: {tool_id}") + return False + + # 检查是否有正在执行的任务 + running_executions = self.db.query(ToolExecution).filter( + and_( + ToolExecution.tool_config_id == uuid.UUID(tool_id), + ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]) + ) + ).count() + + if running_executions > 0: + logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}") + return False + + # 删除工具配置(级联删除相关记录) + self.db.delete(tool_config) + self.db.commit() + + # 从缓存中移除 + if tool_id in self._tools: + del self._tools[tool_id] + + logger.info(f"工具注销成功: {tool_id}") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具注销失败: {tool_id}, 错误: {e}") + return False + + def get_tool(self, tool_id: str) -> Optional[BaseTool]: + """获取工具实例 + + Args: + tool_id: 工具ID 
+ + Returns: + 工具实例,如果不存在返回None + """ + # 先从缓存获取 + if tool_id in self._tools: + return self._tools[tool_id] + + # 从数据库加载 + try: + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value: + return None + + # 根据工具类型加载实例 + tool_instance = self._load_tool_instance(tool_config) + if tool_instance: + self._tools[tool_id] = tool_instance + return tool_instance + + except Exception as e: + logger.error(f"加载工具失败: {tool_id}, 错误: {e}") + + return None + + def list_tools( + self, + tenant_id: Optional[uuid.UUID] = None, + tool_type: Optional[ToolType] = None, + status: Optional[ToolStatus] = None, + tags: Optional[List[str]] = None + ) -> List[ToolInfo]: + """列出工具 + + Args: + tenant_id: 租户ID过滤 + tool_type: 工具类型过滤 + status: 工具状态过滤 + tags: 标签过滤 + + Returns: + 工具信息列表 + """ + try: + query = self.db.query(ToolConfig) + + # 应用过滤条件 + if tenant_id: + # 返回全局工具(tenant_id为空)和该租户的工具 + query = query.filter( + or_( + ToolConfig.tenant_id == tenant_id, + ToolConfig.tenant_id.is_(None) + ) + ) + + if tool_type: + query = query.filter(ToolConfig.tool_type == tool_type.value) + + if status == ToolStatus.ACTIVE: + query = query.filter(ToolConfig.is_enabled == True) + elif status == ToolStatus.INACTIVE: + query = query.filter(ToolConfig.is_enabled == False) + + if tags: + for tag in tags: + query = query.filter(ToolConfig.tags.contains([tag])) + + tool_configs = query.all() + + # 转换为ToolInfo + tool_infos = [] + for config in tool_configs: + tool_info = ToolInfo( + id=str(config.id), + name=config.name, + description=config.description or "", + tool_type=ToolType(config.tool_type), + version=config.version, + status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE, + tags=config.tags or [], + tenant_id=str(config.tenant_id) if config.tenant_id else None + ) + + # 尝试获取参数信息 + tool_instance = self.get_tool(str(config.id)) + if tool_instance: + tool_info.parameters = tool_instance.parameters + + 
tool_infos.append(tool_info) + + return tool_infos + + except Exception as e: + logger.error(f"列出工具失败, 错误: {e}") + return [] + + async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool: + """更新工具状态 + + Args: + tool_id: 工具ID + status: 新状态 + + Returns: + 更新是否成功 + """ + try: + tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id)) + if not tool_config: + logger.warning(f"工具不存在: {tool_id}") + return False + + # 更新状态 + if status == ToolStatus.ACTIVE: + tool_config.is_enabled = True + elif status == ToolStatus.INACTIVE: + tool_config.is_enabled = False + + self.db.commit() + + # 更新缓存中的工具状态 + if tool_id in self._tools: + self._tools[tool_id].status = status + + logger.info(f"工具状态更新成功: {tool_id} -> {status}") + return True + + except Exception as e: + self.db.rollback() + logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}") + return False + + def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]: + """从配置加载工具实例 + + Args: + tool_config: 工具配置 + + Returns: + 工具实例 + """ + try: + if tool_config.tool_type == ToolType.BUILTIN.value: + # 加载内置工具 + builtin_config = self.db.query(BuiltinToolConfig).filter( + BuiltinToolConfig.id == tool_config.id + ).first() + + if builtin_config and builtin_config.tool_class in self._tool_classes: + tool_class = self._tool_classes[builtin_config.tool_class] + config = { + **tool_config.config_data, + "parameters": builtin_config.parameters, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return tool_class(str(tool_config.id), config) + + elif tool_config.tool_type == ToolType.CUSTOM.value: + # 加载自定义工具 + try: + custom_config = self.db.query(CustomToolConfig).filter( + CustomToolConfig.id == tool_config.id + ).first() + + if custom_config: + config = { + **tool_config.config_data, + "schema_url": custom_config.schema_url, + "schema_content": custom_config.schema_content, + "auth_type": 
custom_config.auth_type, + "auth_config": custom_config.auth_config, + "base_url": custom_config.base_url, + "timeout": custom_config.timeout, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return CustomTool(str(tool_config.id), config) + except ImportError as e: + logger.error(f"无法导入自定义工具模块: {e}") + + elif tool_config.tool_type == ToolType.MCP.value: + # 加载MCP工具 + try: + mcp_config = self.db.query(MCPToolConfig).filter( + MCPToolConfig.id == tool_config.id + ).first() + + if mcp_config: + config = { + **tool_config.config_data, + "server_url": mcp_config.server_url, + "connection_config": mcp_config.connection_config, + "available_tools": mcp_config.available_tools, + "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None, + "version": tool_config.version, + "tags": tool_config.tags + } + return MCPTool(str(tool_config.id), config) + except ImportError as e: + logger.error(f"无法导入MCP工具模块: {e}") + + except Exception as e: + logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}") + + return None + + def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: + """获取工具统计信息 + + Args: + tenant_id: 租户ID + + Returns: + 统计信息字典 + """ + try: + query = self.db.query(ToolConfig) + if tenant_id: + query = query.filter(ToolConfig.tenant_id == tenant_id) + + total_tools = query.count() + active_tools = query.filter(ToolConfig.is_enabled == True).count() + + # 按类型统计 + type_stats = {} + for tool_type in ToolType: + count = query.filter(ToolConfig.tool_type == tool_type.value).count() + type_stats[tool_type.value] = count + + return { + "total_tools": total_tools, + "active_tools": active_tools, + "inactive_tools": total_tools - active_tools, + "by_type": type_stats + } + + except Exception as e: + logger.error(f"获取工具统计失败, 错误: {e}") + return {} + + def clear_cache(self): + """清空工具缓存""" + self._tools.clear() + logger.info("工具缓存已清空") \ No 
newline at end of file diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 3c4b8840..46f8cf08 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -4,8 +4,9 @@ 基于 LangGraph 的工作流执行引擎。 """ -import datetime import logging +# import uuid +import datetime from typing import Any from langchain_core.messages import HumanMessage @@ -15,6 +16,11 @@ from langgraph.graph.state import CompiledStateGraph from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.nodes.enums import NodeType +# from app.core.tools.registry import ToolRegistry +# from app.core.tools.executor import ToolExecutor +# from app.core.tools.langchain_adapter import LangchainAdapter +# TOOL_MANAGEMENT_AVAILABLE = True +# from app.db import get_db logger = logging.getLogger(__name__) @@ -457,3 +463,179 @@ async def execute_workflow_stream( ) async for event in executor.execute_stream(input_data): yield event + + +# ==================== 工具管理系统集成 ==================== + +# def get_workflow_tools(workspace_id: str, user_id: str) -> list: +# """获取工作流可用的工具列表 +# +# Args: +# workspace_id: 工作空间ID +# user_id: 用户ID +# +# Returns: +# 可用工具列表 +# """ +# if not TOOL_MANAGEMENT_AVAILABLE: +# logger.warning("工具管理系统不可用") +# return [] +# +# try: +# db = next(get_db()) +# +# # 创建工具注册表 +# registry = ToolRegistry(db) +# +# # 注册内置工具类 +# from app.core.tools.builtin import ( +# DateTimeTool, JsonTool, BaiduSearchTool, MinerUTool, TextInTool +# ) +# registry.register_tool_class(DateTimeTool) +# registry.register_tool_class(JsonTool) +# registry.register_tool_class(BaiduSearchTool) +# registry.register_tool_class(MinerUTool) +# registry.register_tool_class(TextInTool) +# +# # 获取活跃的工具 +# import uuid +# tools = registry.list_tools(workspace_id=uuid.UUID(workspace_id)) +# active_tools = [tool for tool in tools if tool.status.value == "active"] +# +# # 转换为Langchain工具 +# 
langchain_tools = [] +# for tool_info in active_tools: +# try: +# tool_instance = registry.get_tool(tool_info.id) +# if tool_instance: +# langchain_tool = LangchainAdapter.convert_tool(tool_instance) +# langchain_tools.append(langchain_tool) +# except Exception as e: +# logger.error(f"转换工具失败: {tool_info.name}, 错误: {e}") +# +# logger.info(f"为工作流获取了 {len(langchain_tools)} 个工具") +# return langchain_tools +# +# except Exception as e: +# logger.error(f"获取工作流工具失败: {e}") +# return [] +# +# +# class ToolWorkflowNode: +# """工具工作流节点 - 在工作流中执行工具""" +# +# def __init__(self, node_config: dict, workflow_config: dict): +# """初始化工具节点 +# +# Args: +# node_config: 节点配置 +# workflow_config: 工作流配置 +# """ +# self.node_config = node_config +# self.workflow_config = workflow_config +# self.tool_id = node_config.get("tool_id") +# self.tool_parameters = node_config.get("parameters", {}) +# +# async def run(self, state: WorkflowState) -> WorkflowState: +# """执行工具节点""" +# if not TOOL_MANAGEMENT_AVAILABLE: +# logger.error("工具管理系统不可用") +# state["error"] = "工具管理系统不可用" +# return state +# +# try: +# from sqlalchemy.orm import Session +# db = next(get_db()) +# +# # 创建工具执行器 +# registry = ToolRegistry(db) +# executor = ToolExecutor(db, registry) +# +# # 准备参数(支持变量替换) +# parameters = self._prepare_parameters(state) +# +# # 执行工具 +# result = await executor.execute_tool( +# tool_id=self.tool_id, +# parameters=parameters, +# user_id=uuid.UUID(state["user_id"]), +# workspace_id=uuid.UUID(state["workspace_id"]) +# ) +# +# # 更新状态 +# node_id = self.node_config.get("id") +# if result.success: +# state["node_outputs"][node_id] = { +# "type": "tool", +# "tool_id": self.tool_id, +# "output": result.data, +# "execution_time": result.execution_time, +# "token_usage": result.token_usage +# } +# +# # 更新运行时变量 +# if isinstance(result.data, dict): +# for key, value in result.data.items(): +# state["runtime_vars"][f"{node_id}.{key}"] = value +# else: +# state["runtime_vars"][f"{node_id}.result"] = result.data +# else: +# 
state["error"] = result.error +# state["error_node"] = node_id +# state["node_outputs"][node_id] = { +# "type": "tool", +# "tool_id": self.tool_id, +# "error": result.error, +# "execution_time": result.execution_time +# } +# +# return state +# +# except Exception as e: +# logger.error(f"工具节点执行失败: {e}") +# state["error"] = str(e) +# state["error_node"] = self.node_config.get("id") +# return state +# +# def _prepare_parameters(self, state: WorkflowState) -> dict: +# """准备工具参数(支持变量替换)""" +# parameters = {} +# +# for key, value in self.tool_parameters.items(): +# if isinstance(value, str) and value.startswith("${") and value.endswith("}"): +# # 变量替换 +# var_path = value[2:-1] +# +# # 支持多层级变量访问,如 ${sys.message} 或 ${node1.result} +# if "." in var_path: +# parts = var_path.split(".") +# current = state.get("variables", {}) +# +# for part in parts: +# if isinstance(current, dict) and part in current: +# current = current[part] +# else: +# # 尝试从运行时变量获取 +# runtime_key = ".".join(parts) +# current = state.get("runtime_vars", {}).get(runtime_key, value) +# break +# +# parameters[key] = current +# else: +# # 简单变量 +# variables = state.get("variables", {}) +# parameters[key] = variables.get(var_path, value) +# else: +# parameters[key] = value +# +# return parameters +# +# +# # 注册工具节点到NodeFactory(如果存在) +# try: +# from app.core.workflow.nodes import NodeFactory +# if hasattr(NodeFactory, 'register_node_type'): +# NodeFactory.register_node_type("tool", ToolWorkflowNode) +# logger.info("工具节点已注册到工作流系统") +# except Exception as e: +# logger.warning(f"注册工具节点失败: {e}") \ No newline at end of file diff --git a/api/app/models/__init__.py b/api/app/models/__init__.py index 198a788e..01dad24e 100644 --- a/api/app/models/__init__.py +++ b/api/app/models/__init__.py @@ -21,6 +21,10 @@ from .multi_agent_model import MultiAgentConfig, AgentInvocation from .workflow_model import WorkflowConfig, WorkflowExecution, WorkflowNodeExecution from .retrieval_info import RetrievalInfo from 
.prompt_optimizer_model import PromptOptimizerSession, PromptOptimizerSessionHistory +from .tool_model import ( + ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, + ToolExecution, ToolType, ToolStatus, AuthType, ExecutionStatus +) __all__ = [ "Tenants", @@ -57,5 +61,15 @@ __all__ = [ "WorkflowNodeExecution", "RetrievalInfo", "PromptOptimizerSession", - "PromptOptimizerSessionHistory" + "PromptOptimizerSessionHistory", + "RetrievalInfo", + "ToolConfig", + "BuiltinToolConfig", + "CustomToolConfig", + "MCPToolConfig", + "ToolExecution", + "ToolType", + "ToolStatus", + "AuthType", + "ExecutionStatus" ] diff --git a/api/app/models/tenant_model.py b/api/app/models/tenant_model.py index fd3d9a31..552e87b5 100644 --- a/api/app/models/tenant_model.py +++ b/api/app/models/tenant_model.py @@ -21,3 +21,6 @@ class Tenants(Base): # Relationship to workspaces owned by the tenant owned_workspaces = relationship("Workspace", back_populates="tenant") + + # Relationship to tool configs owned by the tenant + tool_configs = relationship("ToolConfig", back_populates="tenant") diff --git a/api/app/models/tool_model.py b/api/app/models/tool_model.py new file mode 100644 index 00000000..ac719317 --- /dev/null +++ b/api/app/models/tool_model.py @@ -0,0 +1,226 @@ +"""工具管理相关数据模型""" +import uuid +from datetime import datetime +from enum import StrEnum + +from sqlalchemy import Column, String, Text, DateTime, JSON, ForeignKey, Integer, Float +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.orm import relationship + +from app.db import Base + + +class ToolType(StrEnum): + """工具类型枚举""" + BUILTIN = "builtin" + CUSTOM = "custom" + MCP = "mcp" + + +class ToolStatus(StrEnum): + """工具状态枚举""" + ACTIVE = "active" + INACTIVE = "inactive" + ERROR = "error" + LOADING = "loading" + + +class AuthType(StrEnum): + """认证类型枚举""" + NONE = "none" + API_KEY = "api_key" + BEARER_TOKEN = "bearer_token" + + +class ExecutionStatus(StrEnum): + """执行状态枚举""" + PENDING = "pending" + RUNNING 
= "running" + COMPLETED = "completed" + FAILED = "failed" + TIMEOUT = "timeout" + + +class ToolConfig(Base): + """工具配置基础模型""" + __tablename__ = "tool_configs" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + name = Column(String(255), nullable=False, index=True) + description = Column(Text) + tool_type = Column(String(50), nullable=False, index=True) + tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id"), nullable=False, index=True) # 必须属于租户 + status = Column(String(50), default=ToolStatus.INACTIVE.value, nullable=False, index=True) # 工具状态 + + # 工具特定配置(JSON格式存储) + config_data = Column(JSON, default=dict) + + # 元数据 + version = Column(String(50), default="1.0.0") + tags = Column(JSON, default=list) # 标签列表 + + # 时间戳 + created_at = Column(DateTime, default=datetime.now, nullable=False) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) + + # 关联关系 + tenant = relationship("Tenants", back_populates="tool_configs") + executions = relationship("ToolExecution", back_populates="tool_config", cascade="all, delete-orphan") + + def __repr__(self): + return f"" + + +class BuiltinToolConfig(Base): + """内置工具配置模型""" + __tablename__ = "builtin_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + tool_class = Column(String(255), nullable=False) # 工具类名 + parameters = Column(JSON, default=dict) # 工具参数配置 + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class CustomToolConfig(Base): + """自定义工具配置模型""" + __tablename__ = "custom_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + schema_url = Column(String(1000)) # OpenAPI schema URL + schema_content = Column(JSON) # OpenAPI schema 内容 + + # 认证配置 + auth_type = Column(String(50), default=AuthType.NONE.value, nullable=False) + auth_config = Column(JSON, default=dict) # 认证配置(加密存储) + + # API配置 + base_url = 
Column(String(1000)) # API基础URL + timeout = Column(Integer, default=30) # 超时时间(秒) + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class MCPToolConfig(Base): + """MCP工具配置模型""" + __tablename__ = "mcp_tool_configs" + + id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), primary_key=True) + server_url = Column(String(1000), nullable=False) # MCP服务器URL + connection_config = Column(JSON, default=dict) # 连接配置 + + # 服务状态 + last_health_check = Column(DateTime) + health_status = Column(String(50), default="unknown") + error_message = Column(Text) + + # 可用工具列表 + available_tools = Column(JSON, default=list) + + # 关联关系 + base_config = relationship("ToolConfig", foreign_keys=[id]) + + def __repr__(self): + return f"" + + +class ToolExecution(Base): + """工具执行记录模型""" + __tablename__ = "tool_executions" + + id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + tool_config_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False, index=True) + + # 执行信息 + execution_id = Column(String(255), nullable=False, index=True) # 执行ID(可用于关联工作流等) + status = Column(String(50), default=ExecutionStatus.PENDING.value, nullable=False, index=True) + + # 输入输出 + input_data = Column(JSON) # 输入参数 + output_data = Column(JSON) # 输出结果 + error_message = Column(Text) # 错误信息 + + # 性能指标 + started_at = Column(DateTime, nullable=False, index=True) + completed_at = Column(DateTime) + execution_time = Column(Float) # 执行时间(秒) + + # Token使用情况(如果适用) + token_usage = Column(JSON) + + # 用户信息 + user_id = Column(UUID(as_uuid=True), ForeignKey("users.id"), index=True) + workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False, index=True) + + # 关联关系 + tool_config = relationship("ToolConfig", back_populates="executions") + user = relationship("User") + workspace = relationship("Workspace") + + def __repr__(self): + return f"" + + +# class ToolDependency(Base): +# """工具依赖关系模型""" +# 
__tablename__ = "tool_dependencies" +# +# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) +# tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False) +# depends_on_tool_id = Column(UUID(as_uuid=True), ForeignKey("tool_configs.id"), nullable=False) +# +# # 依赖类型和版本要求 +# dependency_type = Column(String(50), default="required") # required, optional +# version_constraint = Column(String(100)) # 版本约束,如 ">=1.0.0" +# +# # 时间戳 +# created_at = Column(DateTime, default=datetime.now, nullable=False) +# +# # 关联关系 +# tool = relationship("ToolConfig", foreign_keys=[tool_id]) +# depends_on_tool = relationship("ToolConfig", foreign_keys=[depends_on_tool_id]) +# +# def __repr__(self): +# return f"" + + +# class PluginConfig(Base): +# """插件配置模型""" +# __tablename__ = "plugin_configs" +# +# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) +# name = Column(String(255), nullable=False, unique=True, index=True) +# description = Column(Text) +# +# # 插件信息 +# plugin_path = Column(String(1000), nullable=False) # 插件文件路径 +# entry_point = Column(String(255), nullable=False) # 入口点 +# version = Column(String(50), default="1.0.0") +# +# # 状态 +# is_enabled = Column(Boolean, default=True, nullable=False) +# is_loaded = Column(Boolean, default=False, nullable=False) +# load_error = Column(Text) # 加载错误信息 +# +# # 配置 +# config_schema = Column(JSON) # 配置schema +# config_data = Column(JSON, default=dict) # 配置数据 +# +# # 依赖 +# dependencies = Column(JSON, default=list) # 依赖的其他插件 +# +# # 时间戳 +# created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) +# updated_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) +# last_loaded_at = Column(DateTime) +# +# def __repr__(self): +# return f"" \ No newline at end of file diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 4c011a87..3ca7bddd 100644 --- a/api/app/services/agent_tools.py +++ 
b/api/app/services/agent_tools.py @@ -14,6 +14,7 @@ from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger from app.repositories import workspace_repository, knowledge_repository + logger = get_business_logger() @@ -328,4 +329,4 @@ def create_agent_invocation_tool( ) return f"调用 Agent 失败: {str(e)}" - return invoke_agent + return invoke_agent \ No newline at end of file diff --git a/api/test_tool_system.py b/api/test_tool_system.py new file mode 100644 index 00000000..30d60d23 --- /dev/null +++ b/api/test_tool_system.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python3 +""" +工具管理系统基础测试脚本 +用于验证系统的基本功能是否正常 +""" + +import asyncio +import uuid +from datetime import datetime + +# 测试导入 +def test_imports(): + """测试模块导入""" + print("测试模块导入...") + + try: + from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType + print("✓ 基础工具模块导入成功") + except ImportError as e: + print(f"✗ 基础工具模块导入失败: {e}") + return False + + try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + from app.core.tools.builtin.json_tool import JsonTool + print("✓ 内置工具模块导入成功") + except ImportError as e: + print(f"✗ 内置工具模块导入失败: {e}") + return False + + try: + from app.core.tools.langchain_adapter import LangchainAdapter + print("✓ Langchain适配器导入成功") + except ImportError as e: + print(f"✗ Langchain适配器导入失败: {e}") + return False + + try: + from app.models.tool_model import ToolConfig, ToolType, ToolStatus + print("✓ 工具模型导入成功") + except ImportError as e: + print(f"✗ 工具模型导入失败: {e}") + return False + + try: + from app.core.tools.custom import CustomTool, OpenAPISchemaParser, AuthManager + print("✓ 自定义工具模块导入成功") + except ImportError as e: + print(f"✗ 自定义工具模块导入失败: {e}") + return False + + try: + from app.core.tools.mcp import MCPTool, MCPClient, MCPServiceManager + print("✓ MCP工具模块导入成功") + except ImportError as e: + print(f"✗ MCP工具模块导入失败: {e}") + return False + + return True + + +def test_tool_creation(): + """测试工具创建""" + print("\n测试工具创建...") + + 
try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + + # 创建时间工具实例(全局工具) + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"timezone": "UTC"}, + "tenant_id": None, # 全局工具 + "version": "1.0.0", + "tags": ["time", "utility", "builtin"] + } + + datetime_tool = DateTimeTool(tool_id, config) + + # 验证工具属性 + assert datetime_tool.name == "datetime_tool" + assert datetime_tool.tool_type.value == "builtin" + assert len(datetime_tool.parameters) > 0 + + print("✓ 时间工具创建成功(全局工具)") + return True + + except Exception as e: + print(f"✗ 工具创建失败: {e}") + return False + + +async def test_tool_execution(): + """测试工具执行""" + print("\n测试工具执行...") + + try: + from app.core.tools.builtin.datetime_tool import DateTimeTool + + # 创建时间工具实例 + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"timezone": "UTC"}, + "tenant_id": None, # 全局工具 + "version": "1.0.0" + } + + datetime_tool = DateTimeTool(tool_id, config) + + # 测试获取当前时间 + result = await datetime_tool.safe_execute(operation="now") + + assert result.success == True + assert "datetime" in result.data + assert result.execution_time > 0 + + print("✓ 工具执行成功") + print(f" 执行时间: {result.execution_time:.3f}秒") + print(f" 返回数据: {result.data}") + + return True + + except Exception as e: + print(f"✗ 工具执行失败: {e}") + return False + + +def test_langchain_adapter(): + """测试Langchain适配器""" + print("\n测试Langchain适配器...") + + try: + from app.core.tools.builtin.json_tool import JsonTool + from app.core.tools.langchain_adapter import LangchainAdapter + + # 创建JSON工具实例 + tool_id = str(uuid.uuid4()) + config = { + "parameters": {"indent": 2}, + "tenant_id": None, # 全局工具 + "version": "1.0.0" + } + + json_tool = JsonTool(tool_id, config) + + # 验证Langchain兼容性 + is_compatible, issues = LangchainAdapter.validate_langchain_compatibility(json_tool) + + if not is_compatible: + print(f"✗ Langchain兼容性验证失败: {issues}") + return False + + # 创建工具描述 + description = LangchainAdapter.create_tool_description(json_tool) + + assert "name" in 
description + assert "parameters" in description + assert description["langchain_compatible"] == True + + print("✓ Langchain适配器测试成功") + return True + + except Exception as e: + print(f"✗ Langchain适配器测试失败: {e}") + return False + + +def test_config_manager(): + """测试配置管理器""" + print("\n测试配置管理器...") + + try: + from app.core.tools.config_manager import ConfigManager + + # 创建配置管理器 + config_manager = ConfigManager() + + # 获取配置摘要 + summary = config_manager.get_config_summary() + + assert "config_dir" in summary + assert "total_configs" in summary + + print("✓ 配置管理器测试成功") + print(f" 配置目录: {summary['config_dir']}") + print(f" 总配置数: {summary['total_configs']}") + + return True + + except Exception as e: + print(f"✗ 配置管理器测试失败: {e}") + return False + + +def test_schema_parser(): + """测试OpenAPI Schema解析器""" + print("\n测试OpenAPI Schema解析器...") + + try: + from app.core.tools.custom.schema_parser import OpenAPISchemaParser + + # 创建解析器 + parser = OpenAPISchemaParser() + + # 测试简单的OpenAPI schema + test_schema = { + "openapi": "3.0.0", + "info": { + "title": "Test API", + "version": "1.0.0", + "description": "测试API" + }, + "paths": { + "/test": { + "get": { + "summary": "测试接口", + "operationId": "test_operation", + "responses": { + "200": { + "description": "成功" + } + } + } + } + } + } + + # 验证schema + is_valid, error_msg = parser.validate_schema(test_schema) + assert is_valid, f"Schema验证失败: {error_msg}" + + # 提取工具信息 + tool_info = parser.extract_tool_info(test_schema) + assert tool_info["name"] == "Test API" + assert "test_operation" in tool_info["operations"] + + print("✓ OpenAPI Schema解析器测试成功") + return True + + except Exception as e: + print(f"✗ OpenAPI Schema解析器测试失败: {e}") + return False + + +def test_auth_manager(): + """测试认证管理器""" + print("\n测试认证管理器...") + + try: + from app.core.tools.custom.auth_manager import AuthManager + from app.models.tool_model import AuthType + + # 创建认证管理器 + auth_manager = AuthManager() + + # 测试API Key认证配置 + api_key_config = { + "api_key": "test-key-123", 
+ "key_name": "X-API-Key", + "location": "header" + } + + is_valid, error_msg = auth_manager.validate_auth_config(AuthType.API_KEY, api_key_config) + assert is_valid, f"API Key配置验证失败: {error_msg}" + + # 测试Bearer Token认证配置 + bearer_config = { + "token": "bearer-token-123" + } + + is_valid, error_msg = auth_manager.validate_auth_config(AuthType.BEARER_TOKEN, bearer_config) + assert is_valid, f"Bearer Token配置验证失败: {error_msg}" + + # 测试认证应用 + url = "https://api.example.com/test" + headers = {} + params = {} + + new_url, new_headers, new_params = auth_manager.apply_authentication( + AuthType.API_KEY, api_key_config, url, headers, params + ) + + assert "X-API-Key" in new_headers + assert new_headers["X-API-Key"] == "test-key-123" + + print("✓ 认证管理器测试成功") + return True + + except Exception as e: + print(f"✗ 认证管理器测试失败: {e}") + return False + + +def test_builtin_initializer(): + """测试内置工具初始化器""" + print("\n测试内置工具初始化器...") + + try: + from app.core.tools.builtin_initializer import BuiltinToolInitializer + + # 注意:这里不能真正初始化,因为需要数据库连接 + # 只测试类的创建和基本方法 + + # 模拟数据库会话(实际使用中需要真实的数据库连接) + class MockDB: + def query(self, *args): + return self + def filter(self, *args): + return self + def first(self): + return None + def all(self): + return [] + + mock_db = MockDB() + initializer = BuiltinToolInitializer(mock_db) + + # 测试获取内置工具状态(会返回空列表,因为没有真实数据) + status = initializer.get_builtin_tools_status() + assert isinstance(status, list) + + print("✓ 内置工具初始化器测试成功") + return True + + except Exception as e: + print(f"✗ 内置工具初始化器测试失败: {e}") + return False + + +async def main(): + """主测试函数""" + print("=" * 50) + print("工具管理系统基础测试") + print("=" * 50) + + tests = [ + ("模块导入", test_imports), + ("工具创建", test_tool_creation), + ("工具执行", test_tool_execution), + ("Langchain适配", test_langchain_adapter), + ("配置管理", test_config_manager), + ("Schema解析器", test_schema_parser), + ("认证管理器", test_auth_manager), + ("内置工具初始化器", test_builtin_initializer) + ] + + passed = 0 + total = len(tests) + + for test_name, 
test_func in tests: + try: + if asyncio.iscoroutinefunction(test_func): + result = await test_func() + else: + result = test_func() + + if result: + passed += 1 + except Exception as e: + print(f"✗ {test_name}测试异常: {e}") + + print("\n" + "=" * 50) + print(f"测试结果: {passed}/{total} 通过") + + if passed == total: + print("🎉 所有基础测试通过!工具管理系统基本功能正常。") + return True + else: + print("⚠️ 部分测试失败,请检查相关模块。") + return False + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file