Files
MemoryBear/api/app/core/tools/base.py
2025-12-20 15:24:28 +08:00

302 lines
9.7 KiB
Python

"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
    """Type names a tool parameter may declare.

    The values mirror the JSON Schema primitive type names. Because the enum
    mixes in ``str``, every member compares equal to its lowercase value,
    e.g. ``ParameterType.STRING == "string"``.
    """

    # Scalar types
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"
    # Container types
    ARRAY = "array"
    OBJECT = "object"
class ToolParameter(BaseModel):
    """Declaration of one input parameter accepted by a tool.

    Because ``Config.use_enum_values`` is enabled, the validated ``type``
    field holds the plain string value (e.g. ``"string"``) rather than the
    ``ParameterType`` member. ``minimum``/``maximum`` are consulted for
    integer/number parameters and ``pattern`` for string parameters by
    ``BaseTool`` validation.
    """

    # Parameter name (the keyword argument name).
    name: str = Field(..., description="参数名称")
    # Declared type of the parameter.
    type: ParameterType = Field(..., description="参数类型")
    # Human-readable description.
    description: str = Field(default="", description="参数描述")
    # Whether the argument must be supplied.
    required: bool = Field(default=False, description="是否必需")
    # Default value (declarative only; not applied by this model).
    default: Any = Field(default=None, description="默认值")
    # Allowed values, if restricted.
    enum: Optional[List[Any]] = Field(default=None, description="枚举值")
    # Inclusive lower bound for numeric parameters.
    minimum: Optional[Union[int, float]] = Field(default=None, description="最小值")
    # Inclusive upper bound for numeric parameters.
    maximum: Optional[Union[int, float]] = Field(default=None, description="最大值")
    # Regular expression a string value must match.
    pattern: Optional[str] = Field(default=None, description="正则表达式模式")

    class Config:
        # Store enum-typed fields as their raw values instead of enum members.
        use_enum_values = True
class ToolResult(BaseModel):
    """Outcome of a single tool invocation.

    Prefer the ``success_result`` / ``error_result`` constructors over
    building instances by hand: they set ``success`` and replace a missing
    or empty ``metadata`` argument with a fresh empty dict.
    """

    # Whether the invocation succeeded.
    success: bool = Field(..., description="执行是否成功")
    # Payload produced by the tool; left as None by error_result.
    data: Any = Field(default=None, description="返回数据")
    # Human-readable error message, set on failure.
    error: Optional[str] = Field(default=None, description="错误信息")
    # Machine-readable error code (e.g. "VALIDATION_ERROR").
    error_code: Optional[str] = Field(default=None, description="错误代码")
    # Elapsed execution time in seconds.
    execution_time: float = Field(..., description="执行时间(秒)")
    # Token usage counters keyed by counter name (key set not fixed here).
    token_usage: Optional[Dict[str, int]] = Field(default=None, description="Token使用情况")
    # Free-form extra information.
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")

    @classmethod
    def success_result(
        cls,
        data: Any,
        execution_time: float,
        token_usage: Optional[Dict[str, int]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a successful result.

        Args:
            data: Payload returned by the tool.
            execution_time: Elapsed time in seconds.
            token_usage: Optional token usage counters.
            metadata: Optional extra information; a falsy value becomes ``{}``.

        Returns:
            A ``ToolResult`` with ``success=True``.
        """
        payload: Dict[str, Any] = {
            "success": True,
            "data": data,
            "execution_time": execution_time,
            "token_usage": token_usage,
            "metadata": metadata if metadata else {},
        }
        return cls(**payload)

    @classmethod
    def error_result(
        cls,
        error: str,
        execution_time: float,
        error_code: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a failed result.

        Args:
            error: Human-readable error message.
            execution_time: Elapsed time in seconds.
            error_code: Optional machine-readable error code.
            metadata: Optional extra information; a falsy value becomes ``{}``.

        Returns:
            A ``ToolResult`` with ``success=False`` and no ``data``.
        """
        payload: Dict[str, Any] = {
            "success": False,
            "error": error,
            "error_code": error_code,
            "execution_time": execution_time,
            "metadata": metadata if metadata else {},
        }
        return cls(**payload)
class ToolInfo(BaseModel):
    """Serializable summary of a tool, as produced by ``BaseTool.get_info``.

    NOTE(review): pydantic does not validate defaults unless configured to,
    so an omitted ``status`` presumably stays a ``ToolStatus`` member while an
    explicitly passed one is stored as its raw value (``use_enum_values``).
    Confirm if consumers compare ``status`` against plain strings.
    """

    # Tool identifier.
    id: str = Field(..., description="工具ID")
    # Tool display name.
    name: str = Field(..., description="工具名称")
    # Tool description.
    description: str = Field(..., description="工具描述")
    # Tool category.
    tool_type: ToolType = Field(..., description="工具类型")
    # Tool version string.
    version: str = Field(default="1.0.0", description="工具版本")
    # Declared input parameters.
    parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
    # Current status of the tool.
    status: ToolStatus = Field(default=ToolStatus.ACTIVE, description="工具状态")
    # Free-form tags.
    tags: List[str] = Field(default_factory=list, description="工具标签")
    # Owning tenant, if any.
    tenant_id: Optional[str] = Field(default=None, description="租户ID")

    class Config:
        # Store enum-typed fields as their raw values instead of enum members.
        use_enum_values = True
class BaseTool(ABC):
    """Abstract base class that every tool implementation derives from.

    Subclasses provide ``name``, ``description``, ``tool_type``,
    ``parameters`` and the async ``execute`` coroutine. Callers normally use
    ``safe_execute``, which validates keyword arguments against
    ``parameters`` before delegating to ``execute`` and converts any
    ``Exception`` into an error ``ToolResult``.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the tool.

        Args:
            tool_id: Identifier of this tool instance.
            config: Tool configuration. This base class reads the optional
                ``version``, ``tags`` and ``tenant_id`` keys.
        """
        self.tool_id = tool_id
        self.config = config
        self._status = ToolStatus.ACTIVE

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description."""
        pass

    @property
    @abstractmethod
    def tool_type(self) -> ToolType:
        """Tool type."""
        pass

    @property
    def version(self) -> str:
        """Tool version from ``config['version']``, defaulting to ``"1.0.0"``."""
        return self.config.get("version", "1.0.0")

    @property
    def status(self) -> ToolStatus:
        """Current tool status (``ToolStatus.ACTIVE`` after construction)."""
        return self._status

    @status.setter
    def status(self, value: ToolStatus):
        """Set the tool status."""
        self._status = value

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Declared input parameters of the tool."""
        pass

    @property
    def tags(self) -> List[str]:
        """Tool tags from ``config['tags']`` (empty list if absent).

        The list stored in ``config`` is returned as-is, not copied.
        """
        return self.config.get("tags", [])

    def get_info(self) -> ToolInfo:
        """Return a ``ToolInfo`` snapshot describing this tool."""
        return ToolInfo(
            id=self.tool_id,
            name=self.name,
            description=self.description,
            tool_type=self.tool_type,
            version=self.version,
            parameters=self.parameters,
            status=self.status,
            tags=self.tags,
            tenant_id=self.config.get("tenant_id")
        )

    def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
        """Validate arguments against the declared parameters.

        Args:
            parameters: Mapping of argument name to value. Names that are not
                declared in ``self.parameters`` are ignored.

        Returns:
            Mapping of parameter name to error message; an empty dict means
            validation passed. Each parameter gets at most one error: a type
            error takes precedence over any constraint error.
        """
        errors: Dict[str, str] = {}
        # Read the (possibly computed) abstract property only once.
        declared_params = self.parameters
        param_definitions = {p.name: p for p in declared_params}

        # Required parameters must be present.
        for param_def in declared_params:
            if param_def.required and param_def.name not in parameters:
                errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"

        # Type and constraint checks for supplied, declared arguments.
        for param_name, param_value in parameters.items():
            param_def = param_definitions.get(param_name)
            if param_def is None:
                continue

            if not self._validate_parameter_type(param_value, param_def):
                errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
                # Skip constraint checks: comparing a wrongly-typed value with
                # minimum/maximum could raise TypeError (e.g. "a" < 1), and the
                # constraint message would otherwise hide the type error.
                continue

            constraint_error = self._validate_parameter_constraints(param_value, param_def)
            if constraint_error:
                errors[param_name] = constraint_error

        return errors

    def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
        """Return True if ``value`` matches the declared type of ``param_def``.

        ``None`` is accepted only for non-required parameters, and undeclared
        type names are accepted. ``bool`` is rejected for integer/number
        parameters even though ``bool`` subclasses ``int`` in Python.
        """
        if value is None:
            return not param_def.required

        # Keys are ParameterType members; lookups with the plain string values
        # stored by ToolParameter (use_enum_values) also match because
        # ParameterType is a str-based Enum.
        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: (int, float),
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }
        expected_type = type_mapping.get(param_def.type)
        if expected_type is None:
            return True

        if isinstance(value, bool) and param_def.type in (ParameterType.INTEGER, ParameterType.NUMBER):
            return False

        return isinstance(value, expected_type)

    def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
        """Check enum, numeric-range and string-pattern constraints.

        Assumes ``value`` already passed ``_validate_parameter_type``.

        Returns:
            An error message for the first violated constraint, or None.
        """
        if value is None:
            return None

        # Allowed-values check.
        if param_def.enum and value not in param_def.enum:
            return f"Value must be one of {param_def.enum}"

        # Inclusive numeric bounds.
        if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
            if param_def.minimum is not None and value < param_def.minimum:
                return f"Value must be >= {param_def.minimum}"
            if param_def.maximum is not None and value > param_def.maximum:
                return f"Value must be <= {param_def.maximum}"

        # String pattern. NOTE: re.match anchors only at the start of the
        # string, so a pattern without a trailing "$" accepts any suffix.
        if param_def.type == ParameterType.STRING and param_def.pattern:
            import re
            if not re.match(param_def.pattern, str(value)):
                return f"Value must match pattern: {param_def.pattern}"

        return None

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Validate arguments, run ``execute`` and never raise ``Exception``.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The result of ``execute``. If validation fails, the result is an
            error with code ``"VALIDATION_ERROR"``. If an ``Exception`` is
            raised, it is ``"EXECUTION_ERROR"``. ``asyncio.CancelledError`` (a
            ``BaseException`` since Python 3.8) still propagates.
        """
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        try:
            validation_errors = self.validate_parameters(kwargs)
            if validation_errors:
                error_msg = "; ".join(f"{k}: {v}" for k, v in validation_errors.items())
                return ToolResult.error_result(
                    error=f"Parameter validation failed: {error_msg}",
                    error_code="VALIDATION_ERROR",
                    execution_time=time.perf_counter() - start_time
                )

            return await self.execute(**kwargs)
        except Exception as e:
            # Boundary handler: report the failure as a result. Fall back to the
            # exception class name when str(e) is empty (e.g. TimeoutError()).
            return ToolResult.error_result(
                error=str(e) or type(e).__name__,
                error_code="EXECUTION_ERROR",
                execution_time=time.perf_counter() - start_time
            )

    def to_langchain_tool(self):
        """Convert this tool to a LangChain tool via ``LangchainAdapter``."""
        from .langchain_adapter import LangchainAdapter
        return LangchainAdapter.convert_tool(self)

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"