302 lines
9.7 KiB
Python
302 lines
9.7 KiB
Python
"""工具基础接口定义"""
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
from typing import Any, Dict, List, Optional, Union
|
|
from pydantic import BaseModel, Field
|
|
from enum import Enum
|
|
|
|
from app.models.tool_model import ToolType, ToolStatus
|
|
|
|
|
|
class ParameterType(str, Enum):
|
|
"""参数类型枚举"""
|
|
STRING = "string"
|
|
INTEGER = "integer"
|
|
NUMBER = "number"
|
|
BOOLEAN = "boolean"
|
|
ARRAY = "array"
|
|
OBJECT = "object"
|
|
|
|
|
|
class ToolParameter(BaseModel):
|
|
"""工具参数定义"""
|
|
name: str = Field(..., description="参数名称")
|
|
type: ParameterType = Field(..., description="参数类型")
|
|
description: str = Field("", description="参数描述")
|
|
required: bool = Field(False, description="是否必需")
|
|
default: Any = Field(None, description="默认值")
|
|
enum: Optional[List[Any]] = Field(None, description="枚举值")
|
|
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
|
|
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
|
|
pattern: Optional[str] = Field(None, description="正则表达式模式")
|
|
|
|
class Config:
|
|
use_enum_values = True
|
|
|
|
|
|
class ToolResult(BaseModel):
|
|
"""工具执行结果"""
|
|
success: bool = Field(..., description="执行是否成功")
|
|
data: Any = Field(None, description="返回数据")
|
|
error: Optional[str] = Field(None, description="错误信息")
|
|
error_code: Optional[str] = Field(None, description="错误代码")
|
|
execution_time: float = Field(..., description="执行时间(秒)")
|
|
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
|
|
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
|
|
|
@classmethod
|
|
def success_result(
|
|
cls,
|
|
data: Any,
|
|
execution_time: float,
|
|
token_usage: Optional[Dict[str, int]] = None,
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
) -> "ToolResult":
|
|
"""创建成功结果"""
|
|
return cls(
|
|
success=True,
|
|
data=data,
|
|
execution_time=execution_time,
|
|
token_usage=token_usage,
|
|
metadata=metadata or {}
|
|
)
|
|
|
|
@classmethod
|
|
def error_result(
|
|
cls,
|
|
error: str,
|
|
execution_time: float,
|
|
error_code: Optional[str] = None,
|
|
metadata: Optional[Dict[str, Any]] = None
|
|
) -> "ToolResult":
|
|
"""创建错误结果"""
|
|
return cls(
|
|
success=False,
|
|
error=error,
|
|
error_code=error_code,
|
|
execution_time=execution_time,
|
|
metadata=metadata or {}
|
|
)
|
|
|
|
|
|
class ToolInfo(BaseModel):
|
|
"""工具信息"""
|
|
id: str = Field(..., description="工具ID")
|
|
name: str = Field(..., description="工具名称")
|
|
description: str = Field(..., description="工具描述")
|
|
tool_type: ToolType = Field(..., description="工具类型")
|
|
version: str = Field("1.0.0", description="工具版本")
|
|
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
|
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
|
|
tags: List[str] = Field(default_factory=list, description="工具标签")
|
|
tenant_id: Optional[str] = Field(None, description="租户ID")
|
|
|
|
class Config:
|
|
use_enum_values = True
|
|
|
|
|
|
class BaseTool(ABC):
|
|
"""所有工具的基础抽象类"""
|
|
|
|
def __init__(self, tool_id: str, config: Dict[str, Any]):
|
|
"""初始化工具
|
|
|
|
Args:
|
|
tool_id: 工具ID
|
|
config: 工具配置
|
|
"""
|
|
self.tool_id = tool_id
|
|
self.config = config
|
|
self._status = ToolStatus.ACTIVE
|
|
|
|
@property
|
|
@abstractmethod
|
|
def name(self) -> str:
|
|
"""工具名称"""
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def description(self) -> str:
|
|
"""工具描述"""
|
|
pass
|
|
|
|
@property
|
|
@abstractmethod
|
|
def tool_type(self) -> ToolType:
|
|
"""工具类型"""
|
|
pass
|
|
|
|
@property
|
|
def version(self) -> str:
|
|
"""工具版本"""
|
|
return self.config.get("version", "1.0.0")
|
|
|
|
@property
|
|
def status(self) -> ToolStatus:
|
|
"""工具状态"""
|
|
return self._status
|
|
|
|
@status.setter
|
|
def status(self, value: ToolStatus):
|
|
"""设置工具状态"""
|
|
self._status = value
|
|
|
|
@property
|
|
@abstractmethod
|
|
def parameters(self) -> List[ToolParameter]:
|
|
"""工具参数定义"""
|
|
pass
|
|
|
|
@property
|
|
def tags(self) -> List[str]:
|
|
"""工具标签"""
|
|
return self.config.get("tags", [])
|
|
|
|
def get_info(self) -> ToolInfo:
|
|
"""获取工具信息"""
|
|
return ToolInfo(
|
|
id=self.tool_id,
|
|
name=self.name,
|
|
description=self.description,
|
|
tool_type=self.tool_type,
|
|
version=self.version,
|
|
parameters=self.parameters,
|
|
status=self.status,
|
|
tags=self.tags,
|
|
tenant_id=self.config.get("tenant_id")
|
|
)
|
|
|
|
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
|
|
"""验证参数
|
|
|
|
Args:
|
|
parameters: 输入参数
|
|
|
|
Returns:
|
|
验证错误字典,空字典表示验证通过
|
|
"""
|
|
errors = {}
|
|
param_definitions = {p.name: p for p in self.parameters}
|
|
|
|
# 检查必需参数
|
|
for param_def in self.parameters:
|
|
if param_def.required and param_def.name not in parameters:
|
|
errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"
|
|
|
|
# 检查参数类型和约束
|
|
for param_name, param_value in parameters.items():
|
|
if param_name not in param_definitions:
|
|
continue
|
|
|
|
param_def = param_definitions[param_name]
|
|
|
|
# 类型检查
|
|
if not self._validate_parameter_type(param_value, param_def):
|
|
errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
|
|
|
|
# 约束检查
|
|
constraint_error = self._validate_parameter_constraints(param_value, param_def)
|
|
if constraint_error:
|
|
errors[param_name] = constraint_error
|
|
|
|
return errors
|
|
|
|
def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
|
|
"""验证参数类型"""
|
|
if value is None:
|
|
return not param_def.required
|
|
|
|
type_mapping = {
|
|
ParameterType.STRING: str,
|
|
ParameterType.INTEGER: int,
|
|
ParameterType.NUMBER: (int, float),
|
|
ParameterType.BOOLEAN: bool,
|
|
ParameterType.ARRAY: list,
|
|
ParameterType.OBJECT: dict
|
|
}
|
|
|
|
expected_type = type_mapping.get(param_def.type)
|
|
if expected_type:
|
|
return isinstance(value, expected_type)
|
|
|
|
return True
|
|
|
|
def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
|
|
"""验证参数约束"""
|
|
if value is None:
|
|
return None
|
|
|
|
# 枚举值检查
|
|
if param_def.enum and value not in param_def.enum:
|
|
return f"Value must be one of {param_def.enum}"
|
|
|
|
# 数值范围检查
|
|
if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
|
|
if param_def.minimum is not None and value < param_def.minimum:
|
|
return f"Value must be >= {param_def.minimum}"
|
|
if param_def.maximum is not None and value > param_def.maximum:
|
|
return f"Value must be <= {param_def.maximum}"
|
|
|
|
# 字符串模式检查
|
|
if param_def.type == ParameterType.STRING and param_def.pattern:
|
|
import re
|
|
if not re.match(param_def.pattern, str(value)):
|
|
return f"Value must match pattern: {param_def.pattern}"
|
|
|
|
return None
|
|
|
|
@abstractmethod
|
|
async def execute(self, **kwargs) -> ToolResult:
|
|
"""执行工具
|
|
|
|
Args:
|
|
**kwargs: 工具参数
|
|
|
|
Returns:
|
|
执行结果
|
|
"""
|
|
pass
|
|
|
|
async def safe_execute(self, **kwargs) -> ToolResult:
|
|
"""安全执行工具(包含参数验证和异常处理)
|
|
|
|
Args:
|
|
**kwargs: 工具参数
|
|
|
|
Returns:
|
|
执行结果
|
|
"""
|
|
start_time = time.time()
|
|
|
|
try:
|
|
# 参数验证
|
|
validation_errors = self.validate_parameters(kwargs)
|
|
if validation_errors:
|
|
execution_time = time.time() - start_time
|
|
error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
|
|
return ToolResult.error_result(
|
|
error=f"Parameter validation failed: {error_msg}",
|
|
error_code="VALIDATION_ERROR",
|
|
execution_time=execution_time
|
|
)
|
|
|
|
# 执行工具
|
|
result = await self.execute(**kwargs)
|
|
return result
|
|
|
|
except Exception as e:
|
|
execution_time = time.time() - start_time
|
|
return ToolResult.error_result(
|
|
error=str(e),
|
|
error_code="EXECUTION_ERROR",
|
|
execution_time=execution_time
|
|
)
|
|
|
|
def to_langchain_tool(self):
|
|
"""转换为Langchain工具格式"""
|
|
from .langchain_adapter import LangchainAdapter
|
|
return LangchainAdapter.convert_tool(self)
|
|
|
|
def __repr__(self):
|
|
return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>" |