"""工具基础接口定义""" import time from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, Field from enum import Enum from app.models.tool_model import ToolType, ToolStatus class ParameterType(str, Enum): """参数类型枚举""" STRING = "string" INTEGER = "integer" NUMBER = "number" BOOLEAN = "boolean" ARRAY = "array" OBJECT = "object" class ToolParameter(BaseModel): """工具参数定义""" name: str = Field(..., description="参数名称") type: ParameterType = Field(..., description="参数类型") description: str = Field("", description="参数描述") required: bool = Field(False, description="是否必需") default: Any = Field(None, description="默认值") enum: Optional[List[Any]] = Field(None, description="枚举值") minimum: Optional[Union[int, float]] = Field(None, description="最小值") maximum: Optional[Union[int, float]] = Field(None, description="最大值") pattern: Optional[str] = Field(None, description="正则表达式模式") class Config: use_enum_values = True class ToolResult(BaseModel): """工具执行结果""" success: bool = Field(..., description="执行是否成功") data: Any = Field(None, description="返回数据") error: Optional[str] = Field(None, description="错误信息") error_code: Optional[str] = Field(None, description="错误代码") execution_time: float = Field(..., description="执行时间(秒)") token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况") metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据") @classmethod def success_result( cls, data: Any, execution_time: float, token_usage: Optional[Dict[str, int]] = None, metadata: Optional[Dict[str, Any]] = None ) -> "ToolResult": """创建成功结果""" return cls( success=True, data=data, execution_time=execution_time, token_usage=token_usage, metadata=metadata or {} ) @classmethod def error_result( cls, error: str, execution_time: float, error_code: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None ) -> "ToolResult": """创建错误结果""" return cls( success=False, error=error, error_code=error_code, execution_time=execution_time, metadata=metadata or {} ) class ToolInfo(BaseModel): """工具信息""" id: str = Field(..., description="工具ID") name: str = Field(..., description="工具名称") description: str = Field(..., description="工具描述") tool_type: ToolType = Field(..., description="工具类型") version: str = Field("1.0.0", description="工具版本") parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数") status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态") tags: List[str] = Field(default_factory=list, description="工具标签") tenant_id: Optional[str] = Field(None, description="租户ID") class Config: use_enum_values = True class BaseTool(ABC): """所有工具的基础抽象类""" def __init__(self, tool_id: str, config: Dict[str, Any]): """初始化工具 Args: tool_id: 工具ID config: 工具配置 """ self.tool_id = tool_id self.config = config self._status = ToolStatus.ACTIVE @property @abstractmethod def name(self) -> str: """工具名称""" pass @property @abstractmethod def description(self) -> str: """工具描述""" pass @property @abstractmethod def tool_type(self) -> ToolType: """工具类型""" pass @property def version(self) -> str: """工具版本""" return self.config.get("version", "1.0.0") @property def status(self) -> ToolStatus: """工具状态""" return self._status @status.setter def status(self, value: ToolStatus): """设置工具状态""" self._status = value @property @abstractmethod def parameters(self) -> List[ToolParameter]: """工具参数定义""" pass @property def tags(self) -> List[str]: """工具标签""" return self.config.get("tags", []) def get_info(self) -> ToolInfo: """获取工具信息""" return ToolInfo( id=self.tool_id, name=self.name, description=self.description, tool_type=self.tool_type, version=self.version, parameters=self.parameters, status=self.status, tags=self.tags, tenant_id=self.config.get("tenant_id") ) def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]: """验证参数 Args: parameters: 输入参数 Returns: 验证错误字典,空字典表示验证通过 """ errors = {} param_definitions = {p.name: p for p in self.parameters} # 检查必需参数 for param_def in self.parameters: if param_def.required and param_def.name not in parameters: errors[param_def.name] = f"Required parameter '{param_def.name}' is missing" # 检查参数类型和约束 for param_name, param_value in parameters.items(): if param_name not in param_definitions: continue param_def = param_definitions[param_name] # 类型检查 if not self._validate_parameter_type(param_value, param_def): errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}" # 约束检查 constraint_error = self._validate_parameter_constraints(param_value, param_def) if constraint_error: errors[param_name] = constraint_error return errors def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool: """验证参数类型""" if value is None: return not param_def.required type_mapping = { ParameterType.STRING: str, ParameterType.INTEGER: int, ParameterType.NUMBER: (int, float), ParameterType.BOOLEAN: bool, ParameterType.ARRAY: list, ParameterType.OBJECT: dict } expected_type = type_mapping.get(param_def.type) if expected_type: return isinstance(value, expected_type) return True def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]: """验证参数约束""" if value is None: return None # 枚举值检查 if param_def.enum and value not in param_def.enum: return f"Value must be one of {param_def.enum}" # 数值范围检查 if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]: if param_def.minimum is not None and value < param_def.minimum: return f"Value must be >= {param_def.minimum}" if param_def.maximum is not None and value > param_def.maximum: return f"Value must be <= {param_def.maximum}" # 字符串模式检查 if param_def.type == ParameterType.STRING and param_def.pattern: import re if not re.match(param_def.pattern, str(value)): return f"Value must match pattern: {param_def.pattern}" return None @abstractmethod async def execute(self, **kwargs) -> ToolResult: """执行工具 Args: **kwargs: 工具参数 Returns: 执行结果 """ pass async def safe_execute(self, **kwargs) -> ToolResult: """安全执行工具(包含参数验证和异常处理) Args: **kwargs: 工具参数 Returns: 执行结果 """ start_time = time.time() try: # 参数验证 validation_errors = self.validate_parameters(kwargs) if validation_errors: execution_time = time.time() - start_time error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()]) return ToolResult.error_result( error=f"Parameter validation failed: {error_msg}", error_code="VALIDATION_ERROR", execution_time=execution_time ) # 执行工具 result = await self.execute(**kwargs) return result except Exception as e: execution_time = time.time() - start_time return ToolResult.error_result( error=str(e), error_code="EXECUTION_ERROR", execution_time=execution_time ) def to_langchain_tool(self): """转换为Langchain工具格式""" from .langchain_adapter import LangchainAdapter return LangchainAdapter.convert_tool(self) def __repr__(self): return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"