feat(tool system): Tool system reengineering
This commit is contained in:
@@ -1,11 +1,7 @@
|
||||
"""工具管理核心模块"""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolParameter
|
||||
from .registry import ToolRegistry
|
||||
from .executor import ToolExecutor
|
||||
from .langchain_adapter import LangchainAdapter
|
||||
from .config_manager import ConfigManager
|
||||
from .chain_manager import ChainManager
|
||||
|
||||
# 可选导入,避免导入错误
|
||||
try:
|
||||
@@ -22,11 +18,7 @@ __all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"ToolParameter",
|
||||
"ToolRegistry",
|
||||
"ToolExecutor",
|
||||
"LangchainAdapter",
|
||||
"ConfigManager",
|
||||
"ChainManager"
|
||||
"LangchainAdapter"
|
||||
]
|
||||
|
||||
# 只有在成功导入时才添加到__all__
|
||||
|
||||
@@ -1,98 +1,10 @@
|
||||
"""工具基础接口定义"""
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.models.tool_model import ToolType, ToolStatus
|
||||
|
||||
|
||||
class ParameterType(str, Enum):
|
||||
"""参数类型枚举"""
|
||||
STRING = "string"
|
||||
INTEGER = "integer"
|
||||
NUMBER = "number"
|
||||
BOOLEAN = "boolean"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
|
||||
|
||||
class ToolParameter(BaseModel):
|
||||
"""工具参数定义"""
|
||||
name: str = Field(..., description="参数名称")
|
||||
type: ParameterType = Field(..., description="参数类型")
|
||||
description: str = Field("", description="参数描述")
|
||||
required: bool = Field(False, description="是否必需")
|
||||
default: Any = Field(None, description="默认值")
|
||||
enum: Optional[List[Any]] = Field(None, description="枚举值")
|
||||
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
|
||||
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
|
||||
pattern: Optional[str] = Field(None, description="正则表达式模式")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""工具执行结果"""
|
||||
success: bool = Field(..., description="执行是否成功")
|
||||
data: Any = Field(None, description="返回数据")
|
||||
error: Optional[str] = Field(None, description="错误信息")
|
||||
error_code: Optional[str] = Field(None, description="错误代码")
|
||||
execution_time: float = Field(..., description="执行时间(秒)")
|
||||
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
|
||||
|
||||
@classmethod
|
||||
def success_result(
|
||||
cls,
|
||||
data: Any,
|
||||
execution_time: float,
|
||||
token_usage: Optional[Dict[str, int]] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建成功结果"""
|
||||
return cls(
|
||||
success=True,
|
||||
data=data,
|
||||
execution_time=execution_time,
|
||||
token_usage=token_usage,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def error_result(
|
||||
cls,
|
||||
error: str,
|
||||
execution_time: float,
|
||||
error_code: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> "ToolResult":
|
||||
"""创建错误结果"""
|
||||
return cls(
|
||||
success=False,
|
||||
error=error,
|
||||
error_code=error_code,
|
||||
execution_time=execution_time,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class ToolInfo(BaseModel):
|
||||
"""工具信息"""
|
||||
id: str = Field(..., description="工具ID")
|
||||
name: str = Field(..., description="工具名称")
|
||||
description: str = Field(..., description="工具描述")
|
||||
tool_type: ToolType = Field(..., description="工具类型")
|
||||
version: str = Field("1.0.0", description="工具版本")
|
||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
|
||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
from app.schemas.tool_schema import ToolParameter, ParameterType, ToolResult
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
@@ -107,7 +19,7 @@ class BaseTool(ABC):
|
||||
"""
|
||||
self.tool_id = tool_id
|
||||
self.config = config
|
||||
self._status = ToolStatus.ACTIVE
|
||||
self._status = ToolStatus.AVAILABLE
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -153,20 +65,6 @@ class BaseTool(ABC):
|
||||
"""工具标签"""
|
||||
return self.config.get("tags", [])
|
||||
|
||||
def get_info(self) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
return ToolInfo(
|
||||
id=self.tool_id,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
tool_type=self.tool_type,
|
||||
version=self.version,
|
||||
parameters=self.parameters,
|
||||
status=self.status,
|
||||
tags=self.tags,
|
||||
tenant_id=self.config.get("tenant_id")
|
||||
)
|
||||
|
||||
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""验证参数
|
||||
|
||||
|
||||
@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any, List
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolResult, ToolParameter
|
||||
|
||||
|
||||
class BuiltinTool(BaseTool, ABC):
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime, timezone, timedelta
|
||||
from typing import List
|
||||
import pytz
|
||||
|
||||
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from .base import BuiltinTool
|
||||
|
||||
|
||||
@@ -54,14 +54,14 @@ class DateTimeTool(BuiltinTool):
|
||||
type=ParameterType.STRING,
|
||||
description="源时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="to_timezone",
|
||||
type=ParameterType.STRING,
|
||||
description="目标时区(如:UTC, Asia/Shanghai)",
|
||||
required=False,
|
||||
default="UTC"
|
||||
default="Asia/Shanghai"
|
||||
),
|
||||
ToolParameter(
|
||||
name="calculation",
|
||||
@@ -106,10 +106,11 @@ class DateTimeTool(BuiltinTool):
|
||||
error_code="DATETIME_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _get_current_time(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _get_current_time(kwargs) -> dict:
|
||||
"""获取当前时间"""
|
||||
timezone_str = kwargs.get("to_timezone", "UTC")
|
||||
timezone_str = kwargs.get("to_timezone", "Asia/Shanghai")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
|
||||
if timezone_str == "UTC":
|
||||
@@ -118,15 +119,20 @@ class DateTimeTool(BuiltinTool):
|
||||
tz = pytz.timezone(timezone_str)
|
||||
|
||||
now = datetime.now(tz)
|
||||
|
||||
utc_now = datetime.now(timezone.utc)
|
||||
|
||||
return {
|
||||
"datetime": now.strftime(output_format),
|
||||
"timestamp": int(now.timestamp()),
|
||||
"timezone": timezone_str,
|
||||
"iso_format": now.isoformat()
|
||||
"iso_format": now.isoformat(),
|
||||
"timestamp_ms": int(now.timestamp() * 1000),
|
||||
"utc_datetime": utc_now.strftime(output_format)
|
||||
}
|
||||
|
||||
def _format_datetime(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _format_datetime(kwargs) -> dict:
|
||||
"""格式化时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -144,8 +150,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _convert_timezone(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _convert_timezone(kwargs) -> dict:
|
||||
"""时区转换"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -184,8 +191,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"converted_timezone": to_timezone,
|
||||
"timestamp": int(converted_dt.timestamp())
|
||||
}
|
||||
|
||||
def _timestamp_to_datetime(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _timestamp_to_datetime(kwargs) -> dict:
|
||||
"""时间戳转日期时间"""
|
||||
input_value = kwargs.get("input_value")
|
||||
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -196,6 +204,8 @@ class DateTimeTool(BuiltinTool):
|
||||
|
||||
# 转换时间戳
|
||||
timestamp = float(input_value)
|
||||
if timestamp > 1e12:
|
||||
timestamp = timestamp / 1000
|
||||
|
||||
# 设置时区
|
||||
if timezone_str == "UTC":
|
||||
@@ -211,8 +221,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timezone": timezone_str,
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
def _datetime_to_timestamp(self, kwargs) -> dict:
|
||||
|
||||
@staticmethod
|
||||
def _datetime_to_timestamp(kwargs) -> dict:
|
||||
"""日期时间转时间戳"""
|
||||
input_value = kwargs.get("input_value")
|
||||
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
|
||||
@@ -240,7 +251,7 @@ class DateTimeTool(BuiltinTool):
|
||||
"timestamp": int(dt.timestamp()),
|
||||
"iso_format": dt.isoformat()
|
||||
}
|
||||
|
||||
|
||||
def _calculate_datetime(self, kwargs) -> dict:
|
||||
"""时间计算"""
|
||||
input_value = kwargs.get("input_value")
|
||||
@@ -278,8 +289,9 @@ class DateTimeTool(BuiltinTool):
|
||||
"timezone": timezone_str,
|
||||
"timestamp": int(calculated_dt.timestamp())
|
||||
}
|
||||
|
||||
def _parse_time_delta(self, calculation: str) -> timedelta:
|
||||
|
||||
@staticmethod
|
||||
def _parse_time_delta(calculation: str) -> timedelta:
|
||||
"""解析时间计算表达式"""
|
||||
import re
|
||||
|
||||
|
||||
@@ -121,8 +121,9 @@ class JsonTool(BuiltinTool):
|
||||
error_code="JSON_ERROR",
|
||||
execution_time=execution_time
|
||||
)
|
||||
|
||||
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""格式化JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
@@ -151,12 +152,13 @@ class JsonTool(BuiltinTool):
|
||||
"sort_keys": sort_keys
|
||||
}
|
||||
}
|
||||
|
||||
def _minify_json(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _minify_json(input_data: str) -> Dict[str, Any]:
|
||||
"""压缩JSON"""
|
||||
# 解析并压缩
|
||||
data = json.loads(input_data)
|
||||
minified = json.dumps(data, separators=(',', ':'))
|
||||
minified = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"original_size": len(input_data),
|
||||
@@ -165,7 +167,7 @@ class JsonTool(BuiltinTool):
|
||||
"minified_json": minified,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
|
||||
def _validate_json(self, input_data: str) -> Dict[str, Any]:
|
||||
"""验证JSON"""
|
||||
try:
|
||||
@@ -190,17 +192,19 @@ class JsonTool(BuiltinTool):
|
||||
"size": len(input_data)
|
||||
}
|
||||
|
||||
def _convert_json(self, input_data: str) -> Dict[str, Any]:
|
||||
@staticmethod
|
||||
def _convert_json(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转义"""
|
||||
data = json.loads(input_data)
|
||||
converted = json.dumps(data, ensure_ascii=False)
|
||||
converted = json.dumps(data, ensure_ascii=True, separators=(',', ':'))
|
||||
|
||||
return {
|
||||
"converted_json": converted,
|
||||
"is_valid": True
|
||||
}
|
||||
|
||||
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _json_to_yaml(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转YAML"""
|
||||
data = json.loads(input_data)
|
||||
yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)
|
||||
@@ -212,8 +216,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(yaml_output),
|
||||
"converted_data": yaml_output
|
||||
}
|
||||
|
||||
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _yaml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""YAML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
ensure_ascii = kwargs.get("ensure_ascii", False)
|
||||
@@ -228,10 +233,11 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _json_to_xml(input_data: str) -> Dict[str, Any]:
|
||||
"""JSON转XML"""
|
||||
data = json.loads(input_data)
|
||||
json_data = json.loads(input_data)
|
||||
|
||||
def dict_to_xml(data, root_name="root"):
|
||||
"""递归转换字典为XML"""
|
||||
@@ -267,7 +273,7 @@ class JsonTool(BuiltinTool):
|
||||
root.text = str(data)
|
||||
return root
|
||||
|
||||
xml_element = dict_to_xml(data)
|
||||
xml_element = dict_to_xml(json_data)
|
||||
xml_string = ET.tostring(xml_element, encoding='unicode')
|
||||
|
||||
# 格式化XML
|
||||
@@ -284,8 +290,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(formatted_xml),
|
||||
"converted_data": formatted_xml
|
||||
}
|
||||
|
||||
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _xml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""XML转JSON"""
|
||||
indent = kwargs.get("indent", 2)
|
||||
|
||||
@@ -328,8 +335,9 @@ class JsonTool(BuiltinTool):
|
||||
"converted_size": len(json_output),
|
||||
"converted_data": json_output
|
||||
}
|
||||
|
||||
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _merge_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""合并JSON"""
|
||||
merge_data = kwargs.get("merge_data")
|
||||
if not merge_data:
|
||||
@@ -364,8 +372,9 @@ class JsonTool(BuiltinTool):
|
||||
"result_size": len(merged_json),
|
||||
"merged_data": merged_json
|
||||
}
|
||||
|
||||
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_path( input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取JSON路径"""
|
||||
json_path = kwargs.get("json_path")
|
||||
if not json_path:
|
||||
|
||||
@@ -275,8 +275,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_formula_result( result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化公式识别结果"""
|
||||
formulas = result.get("formulas", [])
|
||||
|
||||
@@ -288,8 +289,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_table_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化表格识别结果"""
|
||||
tables = result.get("tables", [])
|
||||
|
||||
@@ -301,8 +303,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _format_document_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
|
||||
"""格式化文档识别结果"""
|
||||
return {
|
||||
"recognition_mode": "document",
|
||||
@@ -314,8 +317,9 @@ class TextInTool(BuiltinTool):
|
||||
"total_confidence": result.get("confidence", 0),
|
||||
"processing_time": result.get("processing_time", 0)
|
||||
}
|
||||
|
||||
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _group_lines_to_paragraphs(lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将行分组为段落"""
|
||||
paragraphs = []
|
||||
current_paragraph = []
|
||||
|
||||
@@ -1,485 +0,0 @@
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from app.core.tools.base import ToolResult
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ChainExecutionMode(str, Enum):
|
||||
"""链执行模式"""
|
||||
SEQUENTIAL = "sequential" # 顺序执行
|
||||
PARALLEL = "parallel" # 并行执行
|
||||
CONDITIONAL = "conditional" # 条件执行
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainStep:
|
||||
"""链步骤定义"""
|
||||
tool_id: str
|
||||
parameters: Dict[str, Any]
|
||||
condition: Optional[str] = None # 执行条件
|
||||
output_mapping: Optional[Dict[str, str]] = None # 输出映射
|
||||
error_handling: str = "stop" # 错误处理:stop, continue, retry
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChainDefinition:
|
||||
"""工具链定义"""
|
||||
name: str
|
||||
description: str
|
||||
steps: List[ChainStep]
|
||||
execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
|
||||
global_timeout: Optional[float] = None
|
||||
retry_policy: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ChainExecutionContext:
|
||||
"""链执行上下文"""
|
||||
|
||||
def __init__(self, chain_id: str):
|
||||
self.chain_id = chain_id
|
||||
self.variables: Dict[str, Any] = {}
|
||||
self.step_results: Dict[int, ToolResult] = {}
|
||||
self.current_step = 0
|
||||
self.is_completed = False
|
||||
self.is_failed = False
|
||||
self.error_message: Optional[str] = None
|
||||
|
||||
|
||||
class ChainManager:
|
||||
"""工具链管理器 - 支持langchain的工具链模式"""
|
||||
|
||||
def __init__(self, executor: ToolExecutor):
|
||||
"""初始化工具链管理器
|
||||
|
||||
Args:
|
||||
executor: 工具执行器
|
||||
"""
|
||||
self.executor = executor
|
||||
self._chains: Dict[str, ChainDefinition] = {}
|
||||
self._running_chains: Dict[str, ChainExecutionContext] = {}
|
||||
|
||||
def register_chain(self, chain: ChainDefinition) -> bool:
|
||||
"""注册工具链
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
try:
|
||||
# 验证工具链定义
|
||||
validation_result = self._validate_chain(chain)
|
||||
if not validation_result[0]:
|
||||
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
|
||||
return False
|
||||
|
||||
self._chains[chain.name] = chain
|
||||
logger.info(f"工具链注册成功: {chain.name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def unregister_chain(self, chain_name: str) -> bool:
|
||||
"""注销工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
if chain_name in self._chains:
|
||||
del self._chains[chain_name]
|
||||
logger.info(f"工具链注销成功: {chain_name}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def list_chains(self) -> List[Dict[str, Any]]:
|
||||
"""列出所有工具链
|
||||
|
||||
Returns:
|
||||
工具链信息列表
|
||||
"""
|
||||
chains = []
|
||||
for name, chain in self._chains.items():
|
||||
chains.append({
|
||||
"name": name,
|
||||
"description": chain.description,
|
||||
"step_count": len(chain.steps),
|
||||
"execution_mode": chain.execution_mode.value,
|
||||
"global_timeout": chain.global_timeout
|
||||
})
|
||||
|
||||
return chains
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
chain_name: str,
|
||||
initial_variables: Optional[Dict[str, Any]] = None,
|
||||
chain_id: Optional[str] = None
|
||||
) -> Dict[str, Any] | None:
|
||||
"""执行工具链
|
||||
|
||||
Args:
|
||||
chain_name: 工具链名称
|
||||
initial_variables: 初始变量
|
||||
chain_id: 链执行ID(可选)
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
if chain_name not in self._chains:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"工具链不存在: {chain_name}",
|
||||
"chain_id": chain_id
|
||||
}
|
||||
|
||||
chain = self._chains[chain_name]
|
||||
|
||||
# 生成链ID
|
||||
if not chain_id:
|
||||
import uuid
|
||||
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ChainExecutionContext(chain_id)
|
||||
context.variables = initial_variables or {}
|
||||
self._running_chains[chain_id] = context
|
||||
|
||||
try:
|
||||
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
|
||||
|
||||
# 根据执行模式执行
|
||||
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
|
||||
result = await self._execute_sequential(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
|
||||
result = await self._execute_parallel(chain, context)
|
||||
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
|
||||
result = await self._execute_conditional(chain, context)
|
||||
else:
|
||||
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
|
||||
|
||||
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"chain_id": chain_id,
|
||||
"completed_steps": context.current_step,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
if chain_id in self._running_chains:
|
||||
del self._running_chains[chain_id]
|
||||
|
||||
async def _execute_sequential(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""顺序执行工具链"""
|
||||
for i, step in enumerate(chain.steps):
|
||||
context.current_step = i
|
||||
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
logger.debug(f"跳过步骤 {i}: 条件不满足")
|
||||
continue
|
||||
|
||||
# 准备参数
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
|
||||
# 执行工具
|
||||
try:
|
||||
result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
|
||||
context.step_results[i] = result
|
||||
|
||||
# 处理输出映射
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 处理执行失败
|
||||
if not result.success:
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = result.error
|
||||
break
|
||||
elif step.error_handling == "continue":
|
||||
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
|
||||
continue
|
||||
elif step.error_handling == "retry":
|
||||
# 简单重试逻辑
|
||||
retry_result = await self.executor.execute_tool(
|
||||
tool_id=step.tool_id,
|
||||
parameters=parameters
|
||||
)
|
||||
context.step_results[i] = retry_result
|
||||
if not retry_result.success and step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = retry_result.error
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"步骤 {i} 执行异常: {e}")
|
||||
if step.error_handling == "stop":
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
break
|
||||
|
||||
context.is_completed = not context.is_failed
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": context.current_step + 1,
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_parallel(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""并行执行工具链"""
|
||||
# 准备所有步骤的执行配置
|
||||
execution_configs = []
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
# 检查执行条件
|
||||
if step.condition and not self._evaluate_condition(step.condition, context):
|
||||
continue
|
||||
|
||||
parameters = self._prepare_parameters(step.parameters, context)
|
||||
execution_configs.append({
|
||||
"step_index": i,
|
||||
"tool_id": step.tool_id,
|
||||
"parameters": parameters
|
||||
})
|
||||
|
||||
# 并行执行所有步骤
|
||||
try:
|
||||
results = await self.executor.execute_tools_batch(execution_configs)
|
||||
|
||||
# 处理结果
|
||||
for i, result in enumerate(results):
|
||||
step_index = execution_configs[i]["step_index"]
|
||||
context.step_results[step_index] = result
|
||||
|
||||
# 处理输出映射
|
||||
step = chain.steps[step_index]
|
||||
if step.output_mapping and result.success:
|
||||
self._apply_output_mapping(step.output_mapping, result.data, context)
|
||||
|
||||
# 检查是否有失败的步骤
|
||||
failed_steps = [i for i, result in context.step_results.items() if not result.success]
|
||||
|
||||
context.is_completed = len(failed_steps) == 0
|
||||
if failed_steps:
|
||||
context.error_message = f"步骤 {failed_steps} 执行失败"
|
||||
|
||||
except Exception as e:
|
||||
context.is_failed = True
|
||||
context.error_message = str(e)
|
||||
|
||||
return {
|
||||
"success": context.is_completed,
|
||||
"error": context.error_message,
|
||||
"chain_id": context.chain_id,
|
||||
"completed_steps": len(context.step_results),
|
||||
"total_steps": len(chain.steps),
|
||||
"final_variables": context.variables,
|
||||
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
|
||||
}
|
||||
|
||||
async def _execute_conditional(
|
||||
self,
|
||||
chain: ChainDefinition,
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""条件执行工具链"""
|
||||
# 条件执行类似于顺序执行,但更严格地检查条件
|
||||
return await self._execute_sequential(chain, context)
|
||||
|
||||
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具链定义
|
||||
|
||||
Args:
|
||||
chain: 工具链定义
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
if not chain.name:
|
||||
return False, "工具链名称不能为空"
|
||||
|
||||
if not chain.steps:
|
||||
return False, "工具链必须包含至少一个步骤"
|
||||
|
||||
for i, step in enumerate(chain.steps):
|
||||
if not step.tool_id:
|
||||
return False, f"步骤 {i} 缺少工具ID"
|
||||
|
||||
if step.error_handling not in ["stop", "continue", "retry"]:
|
||||
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
|
||||
|
||||
return True, None
|
||||
|
||||
def _prepare_parameters(
|
||||
self,
|
||||
parameters: Dict[str, Any],
|
||||
context: ChainExecutionContext
|
||||
) -> Dict[str, Any]:
|
||||
"""准备参数(支持变量替换)
|
||||
|
||||
Args:
|
||||
parameters: 原始参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
处理后的参数
|
||||
"""
|
||||
prepared = {}
|
||||
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||
# 变量替换
|
||||
var_name = value[2:-1]
|
||||
if var_name in context.variables:
|
||||
prepared[key] = context.variables[var_name]
|
||||
else:
|
||||
prepared[key] = value # 保持原值
|
||||
else:
|
||||
prepared[key] = value
|
||||
|
||||
return prepared
|
||||
|
||||
def _evaluate_condition(
|
||||
self,
|
||||
condition: str,
|
||||
context: ChainExecutionContext
|
||||
) -> bool:
|
||||
"""评估执行条件
|
||||
|
||||
Args:
|
||||
condition: 条件表达式
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
条件是否满足
|
||||
"""
|
||||
try:
|
||||
# 简单的条件评估(可以扩展为更复杂的表达式解析)
|
||||
# 支持格式:variable == value, variable != value, variable > value 等
|
||||
|
||||
if "==" in condition:
|
||||
var_name, expected_value = condition.split("==", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) == expected_value
|
||||
|
||||
elif "!=" in condition:
|
||||
var_name, expected_value = condition.split("!=", 1)
|
||||
var_name = var_name.strip()
|
||||
expected_value = expected_value.strip().strip('"\'')
|
||||
|
||||
return str(context.variables.get(var_name, "")) != expected_value
|
||||
|
||||
elif condition in context.variables:
|
||||
# 简单的布尔检查
|
||||
return bool(context.variables[condition])
|
||||
|
||||
else:
|
||||
# 默认为真
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"条件评估失败: {condition}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _apply_output_mapping(
|
||||
self,
|
||||
mapping: Dict[str, str],
|
||||
output_data: Any,
|
||||
context: ChainExecutionContext
|
||||
):
|
||||
"""应用输出映射
|
||||
|
||||
Args:
|
||||
mapping: 输出映射配置
|
||||
output_data: 输出数据
|
||||
context: 执行上下文
|
||||
"""
|
||||
try:
|
||||
if isinstance(output_data, dict):
|
||||
for source_key, target_var in mapping.items():
|
||||
if source_key in output_data:
|
||||
context.variables[target_var] = output_data[source_key]
|
||||
else:
|
||||
# 如果输出不是字典,将整个输出映射到指定变量
|
||||
if "result" in mapping:
|
||||
context.variables[mapping["result"]] = output_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"输出映射失败: {e}")
|
||||
|
||||
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
|
||||
"""序列化工具结果
|
||||
|
||||
Args:
|
||||
result: 工具结果
|
||||
|
||||
Returns:
|
||||
序列化的结果
|
||||
"""
|
||||
return {
|
||||
"success": result.success,
|
||||
"data": result.data,
|
||||
"error": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time,
|
||||
"token_usage": result.token_usage,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
def get_running_chains(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的工具链
|
||||
|
||||
Returns:
|
||||
运行中的工具链列表
|
||||
"""
|
||||
chains = []
|
||||
for chain_id, context in self._running_chains.items():
|
||||
chains.append({
|
||||
"chain_id": chain_id,
|
||||
"current_step": context.current_step,
|
||||
"is_completed": context.is_completed,
|
||||
"is_failed": context.is_failed,
|
||||
"variables_count": len(context.variables),
|
||||
"completed_steps": len(context.step_results)
|
||||
})
|
||||
|
||||
return chains
|
||||
@@ -1,264 +0,0 @@
|
||||
"""工具配置管理器 - 管理工具配置的加载和验证"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolConfigSchema(BaseModel):
|
||||
"""工具配置基础Schema"""
|
||||
name: str
|
||||
description: str
|
||||
tool_type: str
|
||||
version: str = "1.0.0"
|
||||
enabled: bool = True
|
||||
parameters: Dict[str, Any] = {}
|
||||
tags: list[str] = []
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
|
||||
|
||||
class BuiltinToolConfigSchema(ToolConfigSchema):
|
||||
"""内置工具配置Schema"""
|
||||
tool_class: str
|
||||
tool_type: str = "builtin"
|
||||
|
||||
|
||||
class CustomToolConfigSchema(ToolConfigSchema):
|
||||
"""自定义工具配置Schema"""
|
||||
schema_url: Optional[str] = None
|
||||
schema_content: Optional[Dict[str, Any]] = None
|
||||
auth_type: str = "none"
|
||||
auth_config: Dict[str, Any] = {}
|
||||
base_url: Optional[str] = None
|
||||
timeout: int = 30
|
||||
tool_type: str = "custom"
|
||||
|
||||
|
||||
class MCPToolConfigSchema(ToolConfigSchema):
|
||||
"""MCP工具配置Schema"""
|
||||
server_url: str
|
||||
connection_config: Dict[str, Any] = {}
|
||||
available_tools: list[str] = []
|
||||
tool_type: str = "mcp"
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""工具配置管理器"""
|
||||
|
||||
def __init__(self, config_dir: Optional[str] = None):
|
||||
"""初始化配置管理器
|
||||
|
||||
Args:
|
||||
config_dir: 配置文件目录,默认使用系统配置
|
||||
"""
|
||||
self.config_dir = Path(config_dir or self._get_default_config_dir())
|
||||
self.config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
|
||||
|
||||
def _get_default_config_dir(self) -> str:
|
||||
"""获取默认配置目录"""
|
||||
# 获取tools目录下的configs子目录
|
||||
tools_dir = Path(__file__).parent
|
||||
return str(tools_dir / "configs")
|
||||
|
||||
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
|
||||
"""加载内置工具配置
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
configs = {}
|
||||
builtin_dir = self.config_dir / "builtin"
|
||||
|
||||
if not builtin_dir.exists():
|
||||
logger.info("内置工具配置目录不存在,创建默认配置")
|
||||
self._create_default_builtin_configs(builtin_dir)
|
||||
|
||||
for config_file in builtin_dir.glob("*.json"):
|
||||
try:
|
||||
config_data = self._load_config_file(config_file)
|
||||
config = BuiltinToolConfigSchema(**config_data)
|
||||
configs[config.name] = config
|
||||
logger.debug(f"加载内置工具配置: {config.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
|
||||
|
||||
return configs
|
||||
|
||||
def load_builtin_tools_config(self) -> Dict[str, Any]:
|
||||
"""加载全局内置工具配置(兼容原有接口)
|
||||
|
||||
Returns:
|
||||
内置工具配置字典
|
||||
"""
|
||||
config_file = self.config_dir / "builtin_tools.json"
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载内置工具配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
|
||||
"""确保内置工具已初始化到数据库
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
db_session: 数据库会话
|
||||
tool_config_model: ToolConfig模型类
|
||||
builtin_tool_config_model: BuiltinToolConfig模型类
|
||||
tool_type_enum: ToolType枚举
|
||||
tool_status_enum: ToolStatus枚举
|
||||
"""
|
||||
# 检查是否已初始化
|
||||
existing_count = db_session.query(tool_config_model).filter(
|
||||
tool_config_model.tenant_id == tenant_id,
|
||||
tool_config_model.tool_type == tool_type_enum.BUILTIN
|
||||
).count()
|
||||
|
||||
if existing_count > 0:
|
||||
return # 已初始化
|
||||
|
||||
# 加载全局配置
|
||||
builtin_tools = self.load_builtin_tools_config()
|
||||
|
||||
# 为租户创建内置工具记录
|
||||
for tool_key, tool_info in builtin_tools.items():
|
||||
# 设置初始状态
|
||||
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
|
||||
|
||||
tool_config = tool_config_model(
|
||||
name=tool_info['name'],
|
||||
description=tool_info['description'],
|
||||
tool_type=tool_type_enum.BUILTIN,
|
||||
tenant_id=tenant_id,
|
||||
status=initial_status
|
||||
)
|
||||
db_session.add(tool_config)
|
||||
db_session.flush()
|
||||
|
||||
builtin_config = builtin_tool_config_model(
|
||||
id=tool_config.id,
|
||||
tool_class=tool_info['tool_class'],
|
||||
parameters={}
|
||||
)
|
||||
db_session.add(builtin_config)
|
||||
|
||||
db_session.commit()
|
||||
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
|
||||
|
||||
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
|
||||
"""保存工具配置
|
||||
|
||||
Args:
|
||||
config: 工具配置
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
保存是否成功
|
||||
"""
|
||||
try:
|
||||
config_dir = self.config_dir / tool_type
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
config_file = config_dir / f"{config.name}.json"
|
||||
config_data = config.model_dump()
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
|
||||
"""删除工具配置
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
删除是否成功
|
||||
"""
|
||||
try:
|
||||
config_file = self.config_dir / tool_type / f"{tool_name}.json"
|
||||
|
||||
if config_file.exists():
|
||||
config_file.unlink()
|
||||
logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
|
||||
"""验证工具配置
|
||||
|
||||
Args:
|
||||
config_data: 配置数据
|
||||
tool_type: 工具类型
|
||||
|
||||
Returns:
|
||||
(是否有效, 错误信息)
|
||||
"""
|
||||
try:
|
||||
schema_map = {
|
||||
"builtin": BuiltinToolConfigSchema,
|
||||
"custom": CustomToolConfigSchema,
|
||||
"mcp": MCPToolConfigSchema
|
||||
}
|
||||
|
||||
schema_class = schema_map.get(tool_type)
|
||||
if not schema_class:
|
||||
return False, f"不支持的工具类型: {tool_type}"
|
||||
|
||||
# 验证配置
|
||||
schema_class(**config_data)
|
||||
return True, None
|
||||
|
||||
except ValidationError as e:
|
||||
error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()])
|
||||
return False, f"配置验证失败: {error_msg}"
|
||||
except Exception as e:
|
||||
return False, f"配置验证异常: {str(e)}"
|
||||
|
||||
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
|
||||
"""加载配置文件
|
||||
|
||||
Args:
|
||||
config_file: 配置文件路径
|
||||
|
||||
Returns:
|
||||
配置数据字典
|
||||
"""
|
||||
try:
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def _create_default_builtin_configs(self, builtin_dir: Path):
|
||||
"""创建默认内置工具配置
|
||||
|
||||
Args:
|
||||
builtin_dir: 内置工具配置目录
|
||||
"""
|
||||
builtin_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"内置工具配置目录已创建: {builtin_dir}")
|
||||
# 配置文件已经通过其他方式创建,这里只需要确保目录存在
|
||||
@@ -54,7 +54,8 @@
|
||||
"enabled": true,
|
||||
"parameters": {
|
||||
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
|
||||
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
|
||||
"base_url": {"type": "string", "description": "API地址", "default": "https://api.textin.com/v1"}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import time
|
||||
from typing import Dict, Any, Tuple
|
||||
from urllib.parse import quote
|
||||
import aiohttp
|
||||
@@ -51,8 +50,9 @@ class AuthManager:
|
||||
|
||||
except Exception as e:
|
||||
return False, f"验证认证配置时出错: {e}"
|
||||
|
||||
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
|
||||
@staticmethod
|
||||
def _validate_api_key_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证API Key认证配置
|
||||
|
||||
Args:
|
||||
@@ -79,8 +79,9 @@ class AuthManager:
|
||||
return False, "API Key位置必须是 header、query 或 cookie"
|
||||
|
||||
return True, ""
|
||||
|
||||
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
|
||||
@staticmethod
|
||||
def _validate_bearer_token_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""验证Bearer Token认证配置
|
||||
|
||||
Args:
|
||||
@@ -135,9 +136,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"应用认证时出错: {e}")
|
||||
return url, headers, params
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _apply_api_key_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
@@ -176,9 +177,9 @@ class AuthManager:
|
||||
headers["Cookie"] = cookie_value
|
||||
|
||||
return url, headers, params
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _apply_bearer_token_auth(
|
||||
self,
|
||||
auth_config: Dict[str, Any],
|
||||
url: str,
|
||||
headers: Dict[str, str],
|
||||
@@ -260,8 +261,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"解密认证配置失败: {e}")
|
||||
return encrypted_config
|
||||
|
||||
def _encrypt_string(self, value: str, key: str) -> str:
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_string(value: str, key: str) -> str:
|
||||
"""加密字符串
|
||||
|
||||
Args:
|
||||
@@ -289,8 +291,9 @@ class AuthManager:
|
||||
except Exception as e:
|
||||
logger.error(f"加密字符串失败: {e}")
|
||||
return value
|
||||
|
||||
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
|
||||
|
||||
@staticmethod
|
||||
def _decrypt_string(encrypted_value: str, key: str) -> str:
|
||||
"""解密字符串
|
||||
|
||||
Args:
|
||||
@@ -471,8 +474,9 @@ class AuthManager:
|
||||
"error": f"测试认证时出错: {e}",
|
||||
"auth_type": auth_type.value
|
||||
}
|
||||
|
||||
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def get_auth_config_template(auth_type: AuthType) -> Dict[str, Any]:
|
||||
"""获取认证配置模板
|
||||
|
||||
Args:
|
||||
@@ -498,8 +502,9 @@ class AuthManager:
|
||||
}
|
||||
|
||||
return templates.get(auth_type, {})
|
||||
|
||||
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def mask_sensitive_config(auth_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""遮蔽认证配置中的敏感信息
|
||||
|
||||
Args:
|
||||
|
||||
@@ -5,7 +5,8 @@ import aiohttp
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from app.models.tool_model import ToolType, AuthType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -173,8 +174,9 @@ class CustomTool(BaseTool):
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
|
||||
|
||||
@staticmethod
|
||||
def _convert_openapi_type(openapi_type: str) -> ParameterType:
|
||||
"""转换OpenAPI类型到内部类型"""
|
||||
type_mapping = {
|
||||
"string": ParameterType.STRING,
|
||||
@@ -239,8 +241,9 @@ class CustomTool(BaseTool):
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
return headers
|
||||
|
||||
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _build_request_data(operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""构建请求数据"""
|
||||
if operation["method"] in ["POST", "PUT", "PATCH"]:
|
||||
request_body = operation.get("request_body")
|
||||
@@ -284,6 +287,7 @@ class CustomTool(BaseTool):
|
||||
try:
|
||||
return await response.json()
|
||||
except Exception as e:
|
||||
logger.error(f"解析HTTP响应JSON失败: {str(e)}")
|
||||
return await response.text()
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
# SchemaParser = OpenAPISchemaParser = None
|
||||
|
||||
|
||||
class OpenAPISchemaParser:
|
||||
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
|
||||
@@ -88,8 +91,9 @@ class OpenAPISchemaParser:
|
||||
except Exception as e:
|
||||
logger.error(f"解析schema内容失败: {e}")
|
||||
return False, {}, str(e)
|
||||
|
||||
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _parse_content(content: str, content_type: str) -> Optional[Dict[str, Any]]:
|
||||
"""解析内容为字典
|
||||
|
||||
Args:
|
||||
@@ -101,7 +105,7 @@ class OpenAPISchemaParser:
|
||||
"""
|
||||
try:
|
||||
# 根据内容类型解析
|
||||
if 'json' in content_type:
|
||||
if 'application/json' in content_type:
|
||||
return json.loads(content)
|
||||
elif 'yaml' in content_type or 'yml' in content_type:
|
||||
return yaml.safe_load(content)
|
||||
@@ -228,8 +232,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return operations
|
||||
|
||||
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_parameters(operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取操作参数
|
||||
|
||||
Args:
|
||||
@@ -266,8 +271,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return parameters
|
||||
|
||||
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_request_body(operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""提取请求体信息
|
||||
|
||||
Args:
|
||||
@@ -298,8 +304,9 @@ class OpenAPISchemaParser:
|
||||
"schema": schema,
|
||||
"content_types": list(content.keys())
|
||||
}
|
||||
|
||||
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
@staticmethod
|
||||
def _extract_responses(operation: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""提取响应信息
|
||||
|
||||
Args:
|
||||
@@ -331,8 +338,9 @@ class OpenAPISchemaParser:
|
||||
}
|
||||
|
||||
return responses
|
||||
|
||||
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
|
||||
@staticmethod
|
||||
def generate_tool_parameters(operations: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""生成工具参数定义
|
||||
|
||||
Args:
|
||||
@@ -396,7 +404,7 @@ class OpenAPISchemaParser:
|
||||
parameters.extend(all_params.values())
|
||||
|
||||
return parameters
|
||||
|
||||
|
||||
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
|
||||
"""验证操作参数
|
||||
|
||||
@@ -447,8 +455,9 @@ class OpenAPISchemaParser:
|
||||
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
|
||||
|
||||
return len(errors) == 0, errors
|
||||
|
||||
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
|
||||
|
||||
@staticmethod
|
||||
def _validate_parameter_type(value: Any, expected_type: str) -> bool:
|
||||
"""验证参数类型
|
||||
|
||||
Args:
|
||||
@@ -474,4 +483,7 @@ class OpenAPISchemaParser:
|
||||
if expected_python_type:
|
||||
return isinstance(value, expected_python_type)
|
||||
|
||||
return True
|
||||
return True
|
||||
|
||||
# 为了兼容性,创建别名
|
||||
SchemaParser = OpenAPISchemaParser
|
||||
@@ -1,501 +0,0 @@
|
||||
"""工具执行器 - 负责工具的实际调用和执行管理"""
|
||||
import asyncio
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import ToolExecution, ExecutionStatus
|
||||
from app.core.tools.base import BaseTool, ToolResult
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ExecutionContext:
|
||||
"""执行上下文"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
execution_id: str,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
self.execution_id = execution_id
|
||||
self.tool_id = tool_id
|
||||
self.user_id = user_id
|
||||
self.workspace_id = workspace_id
|
||||
self.timeout = timeout or 60.0 # 默认60秒超时
|
||||
self.metadata = metadata or {}
|
||||
self.started_at = datetime.now()
|
||||
self.completed_at: Optional[datetime] = None
|
||||
self.status = ExecutionStatus.PENDING
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""工具执行器 - 使用langchain标准接口执行工具"""
|
||||
|
||||
def __init__(self, db: Session, registry: ToolRegistry):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
registry: 工具注册表
|
||||
"""
|
||||
self.db = db
|
||||
self.registry = registry
|
||||
self._running_executions: Dict[str, ExecutionContext] = {}
|
||||
self._execution_lock = asyncio.Lock()
|
||||
|
||||
async def execute_tool(
|
||||
self,
|
||||
tool_id: str,
|
||||
parameters: Dict[str, Any],
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
execution_id: Optional[str] = None,
|
||||
timeout: Optional[float] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
) -> ToolResult:
|
||||
"""执行工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
parameters: 工具参数
|
||||
user_id: 用户ID
|
||||
workspace_id: 工作空间ID
|
||||
execution_id: 执行ID(可选,自动生成)
|
||||
timeout: 超时时间(秒)
|
||||
metadata: 额外元数据
|
||||
|
||||
Returns:
|
||||
工具执行结果
|
||||
"""
|
||||
# 生成执行ID
|
||||
if not execution_id:
|
||||
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 创建执行上下文
|
||||
context = ExecutionContext(
|
||||
execution_id=execution_id,
|
||||
tool_id=tool_id,
|
||||
user_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
timeout=timeout,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取工具实例
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if not tool:
|
||||
return ToolResult.error_result(
|
||||
error=f"工具不存在: {tool_id}",
|
||||
error_code="TOOL_NOT_FOUND",
|
||||
execution_time=0.0
|
||||
)
|
||||
|
||||
# 记录执行开始
|
||||
await self._record_execution_start(context, parameters)
|
||||
|
||||
# 执行工具
|
||||
result = await self._execute_with_timeout(tool, parameters, context)
|
||||
|
||||
# 记录执行完成
|
||||
await self._record_execution_complete(context, result)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
|
||||
|
||||
# 记录执行失败
|
||||
error_result = ToolResult.error_result(
|
||||
error=str(e),
|
||||
error_code="EXECUTION_ERROR",
|
||||
execution_time=time.time() - context.started_at.timestamp()
|
||||
)
|
||||
await self._record_execution_complete(context, error_result)
|
||||
|
||||
return error_result
|
||||
|
||||
finally:
|
||||
# 清理执行上下文
|
||||
async with self._execution_lock:
|
||||
if execution_id in self._running_executions:
|
||||
del self._running_executions[execution_id]
|
||||
|
||||
async def execute_tools_batch(
|
||||
self,
|
||||
tool_executions: List[Dict[str, Any]],
|
||||
max_concurrency: int = 5
|
||||
) -> List[ToolResult]:
|
||||
"""批量执行工具
|
||||
|
||||
Args:
|
||||
tool_executions: 工具执行配置列表,每个包含tool_id和parameters
|
||||
max_concurrency: 最大并发数
|
||||
|
||||
Returns:
|
||||
执行结果列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(max_concurrency)
|
||||
|
||||
async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
|
||||
async with semaphore:
|
||||
return await self.execute_tool(
|
||||
tool_id=exec_config["tool_id"],
|
||||
parameters=exec_config.get("parameters", {}),
|
||||
user_id=exec_config.get("user_id"),
|
||||
workspace_id=exec_config.get("workspace_id"),
|
||||
timeout=exec_config.get("timeout"),
|
||||
metadata=exec_config.get("metadata")
|
||||
)
|
||||
|
||||
# 并发执行所有工具
|
||||
tasks = [execute_single(config) for config in tool_executions]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 处理异常结果
|
||||
processed_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
processed_results.append(
|
||||
ToolResult.error_result(
|
||||
error=str(result),
|
||||
error_code="BATCH_EXECUTION_ERROR",
|
||||
execution_time=0.0
|
||||
)
|
||||
)
|
||||
else:
|
||||
processed_results.append(result)
|
||||
|
||||
return processed_results
|
||||
|
||||
async def cancel_execution(self, execution_id: str) -> bool:
|
||||
"""取消工具执行
|
||||
|
||||
Args:
|
||||
execution_id: 执行ID
|
||||
|
||||
Returns:
|
||||
是否成功取消
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
if execution_id not in self._running_executions:
|
||||
return False
|
||||
|
||||
context = self._running_executions[execution_id]
|
||||
context.status = ExecutionStatus.FAILED
|
||||
|
||||
# 更新数据库记录
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = ExecutionStatus.FAILED.value
|
||||
execution_record.error_message = "执行被取消"
|
||||
execution_record.completed_at = datetime.now()
|
||||
self.db.commit()
|
||||
|
||||
logger.info(f"工具执行已取消: {execution_id}")
|
||||
return True
|
||||
|
||||
def get_running_executions(self) -> List[Dict[str, Any]]:
|
||||
"""获取正在运行的执行列表
|
||||
|
||||
Returns:
|
||||
执行信息列表
|
||||
"""
|
||||
executions = []
|
||||
for execution_id, context in self._running_executions.items():
|
||||
executions.append({
|
||||
"execution_id": execution_id,
|
||||
"tool_id": context.tool_id,
|
||||
"user_id": str(context.user_id) if context.user_id else None,
|
||||
"workspace_id": str(context.workspace_id) if context.workspace_id else None,
|
||||
"started_at": context.started_at.isoformat(),
|
||||
"status": context.status.value,
|
||||
"elapsed_time": (datetime.now() - context.started_at).total_seconds()
|
||||
})
|
||||
|
||||
return executions
|
||||
|
||||
async def _execute_with_timeout(
|
||||
self,
|
||||
tool: BaseTool,
|
||||
parameters: Dict[str, Any],
|
||||
context: ExecutionContext
|
||||
) -> ToolResult:
|
||||
"""带超时的工具执行
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
parameters: 参数
|
||||
context: 执行上下文
|
||||
|
||||
Returns:
|
||||
执行结果
|
||||
"""
|
||||
async with self._execution_lock:
|
||||
self._running_executions[context.execution_id] = context
|
||||
context.status = ExecutionStatus.RUNNING
|
||||
|
||||
try:
|
||||
# 使用asyncio.wait_for实现超时控制
|
||||
result = await asyncio.wait_for(
|
||||
tool.safe_execute(**parameters),
|
||||
timeout=context.timeout
|
||||
)
|
||||
|
||||
context.status = ExecutionStatus.COMPLETED
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
context.status = ExecutionStatus.TIMEOUT
|
||||
return ToolResult.error_result(
|
||||
error=f"工具执行超时({context.timeout}秒)",
|
||||
error_code="EXECUTION_TIMEOUT",
|
||||
execution_time=context.timeout
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
context.status = ExecutionStatus.FAILED
|
||||
raise
|
||||
|
||||
async def _record_execution_start(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
parameters: Dict[str, Any]
|
||||
):
|
||||
"""记录执行开始"""
|
||||
try:
|
||||
execution_record = ToolExecution(
|
||||
execution_id=context.execution_id,
|
||||
tool_config_id=uuid.UUID(context.tool_id),
|
||||
status=ExecutionStatus.RUNNING.value,
|
||||
input_data=parameters,
|
||||
started_at=context.started_at,
|
||||
user_id=context.user_id,
|
||||
workspace_id=context.workspace_id
|
||||
)
|
||||
|
||||
self.db.add(execution_record)
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已创建: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
async def _record_execution_complete(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
result: ToolResult
|
||||
):
|
||||
"""记录执行完成"""
|
||||
try:
|
||||
context.completed_at = datetime.now()
|
||||
|
||||
execution_record = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.execution_id == context.execution_id
|
||||
).first()
|
||||
|
||||
if execution_record:
|
||||
execution_record.status = (
|
||||
ExecutionStatus.COMPLETED.value if result.success
|
||||
else ExecutionStatus.FAILED.value
|
||||
)
|
||||
execution_record.output_data = result.data if result.success else None
|
||||
execution_record.error_message = result.error if not result.success else None
|
||||
execution_record.completed_at = context.completed_at
|
||||
execution_record.execution_time = result.execution_time
|
||||
execution_record.token_usage = result.token_usage
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.debug(f"执行记录已更新: {context.execution_id}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
|
||||
|
||||
def get_execution_history(
|
||||
self,
|
||||
tool_id: Optional[str] = None,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取执行历史
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID过滤
|
||||
user_id: 用户ID过滤
|
||||
workspace_id: 工作空间ID过滤
|
||||
limit: 返回数量限制
|
||||
|
||||
Returns:
|
||||
执行历史列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolExecution).order_by(
|
||||
ToolExecution.started_at.desc()
|
||||
)
|
||||
|
||||
if tool_id:
|
||||
query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
|
||||
|
||||
if user_id:
|
||||
query = query.filter(ToolExecution.user_id == user_id)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.limit(limit).all()
|
||||
|
||||
history = []
|
||||
for execution in executions:
|
||||
history.append({
|
||||
"execution_id": execution.execution_id,
|
||||
"tool_id": str(execution.tool_config_id),
|
||||
"status": execution.status,
|
||||
"started_at": execution.started_at.isoformat() if execution.started_at else None,
|
||||
"completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
|
||||
"execution_time": execution.execution_time,
|
||||
"user_id": str(execution.user_id) if execution.user_id else None,
|
||||
"workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
|
||||
"input_data": execution.input_data,
|
||||
"output_data": execution.output_data,
|
||||
"error_message": execution.error_message,
|
||||
"token_usage": execution.token_usage
|
||||
})
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行历史失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
def get_execution_statistics(
|
||||
self,
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
days: int = 7
|
||||
) -> Dict[str, Any]:
|
||||
"""获取执行统计信息
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
days: 统计天数
|
||||
|
||||
Returns:
|
||||
统计信息
|
||||
"""
|
||||
try:
|
||||
from datetime import timedelta
|
||||
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
query = self.db.query(ToolExecution).filter(
|
||||
ToolExecution.started_at >= start_date
|
||||
)
|
||||
|
||||
if workspace_id:
|
||||
query = query.filter(ToolExecution.workspace_id == workspace_id)
|
||||
|
||||
executions = query.all()
|
||||
|
||||
# 统计数据
|
||||
total_executions = len(executions)
|
||||
successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value])
|
||||
failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value])
|
||||
|
||||
# 平均执行时间
|
||||
completed_executions = [e for e in executions if e.execution_time is not None]
|
||||
avg_execution_time = (
|
||||
sum(e.execution_time for e in completed_executions) / len(completed_executions)
|
||||
if completed_executions else 0
|
||||
)
|
||||
|
||||
# 按工具统计
|
||||
tool_stats = {}
|
||||
for execution in executions:
|
||||
tool_id = str(execution.tool_config_id)
|
||||
if tool_id not in tool_stats:
|
||||
tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}
|
||||
|
||||
tool_stats[tool_id]["total"] += 1
|
||||
if execution.status == ExecutionStatus.COMPLETED.value:
|
||||
tool_stats[tool_id]["successful"] += 1
|
||||
elif execution.status == ExecutionStatus.FAILED.value:
|
||||
tool_stats[tool_id]["failed"] += 1
|
||||
|
||||
return {
|
||||
"period_days": days,
|
||||
"total_executions": total_executions,
|
||||
"successful_executions": successful_executions,
|
||||
"failed_executions": failed_executions,
|
||||
"success_rate": successful_executions / total_executions if total_executions > 0 else 0,
|
||||
"average_execution_time": avg_execution_time,
|
||||
"tool_statistics": tool_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取执行统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
async def test_tool_connection(
|
||||
self,
|
||||
tool_id: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""测试工具连接"""
|
||||
try:
|
||||
from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
|
||||
from .mcp.client import MCPClient
|
||||
|
||||
tool_config = self.db.query(ToolConfig).filter(
|
||||
ToolConfig.id == uuid.UUID(tool_id)
|
||||
).first()
|
||||
|
||||
if not tool_config:
|
||||
return {"success": False, "message": "工具不存在"}
|
||||
|
||||
if tool_config.tool_type == ToolType.MCP.value:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if not mcp_config:
|
||||
return {"success": False, "message": "MCP配置不存在"}
|
||||
|
||||
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
|
||||
|
||||
if await client.connect():
|
||||
try:
|
||||
tools = await client.list_tools()
|
||||
await client.disconnect()
|
||||
return {
|
||||
"success": True,
|
||||
"message": "MCP连接成功",
|
||||
"details": {"server_url": mcp_config.server_url, "tools": len(tools)}
|
||||
}
|
||||
except:
|
||||
await client.disconnect()
|
||||
return {"success": False, "message": "MCP功能测试失败"}
|
||||
else:
|
||||
return {"success": False, "message": "MCP连接失败"}
|
||||
else:
|
||||
tool = self.registry.get_tool(tool_id)
|
||||
if tool and hasattr(tool, 'test_connection'):
|
||||
result = tool.test_connection()
|
||||
return {"success": result.get("success", False), "message": result.get("message", "")}
|
||||
return {"success": True, "message": "工具无需连接测试"}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": "测试失败", "error": str(e)}
|
||||
@@ -4,7 +4,8 @@ from typing import Dict, Any, List
|
||||
import aiohttp
|
||||
|
||||
from app.models.tool_model import ToolType
|
||||
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
|
||||
from app.core.tools.base import BaseTool
|
||||
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
|
||||
async def connect(self) -> bool:
|
||||
"""连接到MCP服务器"""
|
||||
try:
|
||||
# 这里应该实现实际的MCP连接逻辑
|
||||
# 为了简化,这里只是模拟连接
|
||||
from .client import MCPClient
|
||||
|
||||
# 测试服务器连接
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
# 尝试获取服务器信息
|
||||
async with session.get(f"{self.server_url}/info") as response:
|
||||
if response.status == 200:
|
||||
server_info = await response.json()
|
||||
self.available_tools = server_info.get("tools", [])
|
||||
self._connected = True
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
raise Exception(f"服务器响应错误: {response.status}")
|
||||
if self._connected:
|
||||
return True
|
||||
|
||||
self._client = MCPClient(self.server_url, self.connection_config)
|
||||
|
||||
if await self._client.connect():
|
||||
self._connected = True
|
||||
# 更新可用工具列表
|
||||
await self._update_available_tools()
|
||||
logger.info(f"MCP服务器连接成功: {self.server_url}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
|
||||
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
|
||||
self._connected = False
|
||||
return False
|
||||
|
||||
async def _update_available_tools(self):
|
||||
"""更新可用工具列表"""
|
||||
try:
|
||||
if self._client and self._connected:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
|
||||
except Exception as e:
|
||||
logger.error(f"更新MCP工具列表失败: {e}")
|
||||
|
||||
async def disconnect(self) -> bool:
|
||||
"""断开MCP服务器连接"""
|
||||
try:
|
||||
if self._client:
|
||||
# 这里应该实现实际的断开逻辑
|
||||
await self._client.disconnect()
|
||||
self._client = None
|
||||
|
||||
self._connected = False
|
||||
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
|
||||
|
||||
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
|
||||
"""调用MCP工具"""
|
||||
# 构建MCP请求
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
}
|
||||
}
|
||||
if not self._client or not self._connected:
|
||||
raise Exception("MCP客户端未连接")
|
||||
|
||||
# 发送请求
|
||||
client_timeout = aiohttp.ClientTimeout(total=timeout)
|
||||
async with aiohttp.ClientSession(timeout=client_timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
|
||||
|
||||
result = await response.json()
|
||||
|
||||
# 检查MCP响应
|
||||
if "error" in result:
|
||||
error = result["error"]
|
||||
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
|
||||
|
||||
return result.get("result", {})
|
||||
try:
|
||||
result = await self._client.call_tool(tool_name, arguments, timeout)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
async def list_available_tools(self) -> List[Dict[str, Any]]:
|
||||
"""列出可用的MCP工具"""
|
||||
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
# 获取工具列表
|
||||
request_data = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": f"req_{int(time.time() * 1000)}",
|
||||
"method": "tools/list"
|
||||
}
|
||||
|
||||
timeout = aiohttp.ClientTimeout(total=10)
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
async with session.post(
|
||||
f"{self.server_url}/mcp",
|
||||
json=request_data,
|
||||
headers={"Content-Type": "application/json"}
|
||||
) as response:
|
||||
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
if "result" in result:
|
||||
tools = result["result"].get("tools", [])
|
||||
self.available_tools = [tool.get("name") for tool in tools]
|
||||
return tools
|
||||
if self._client:
|
||||
tools = await self._client.list_tools()
|
||||
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
|
||||
return tools
|
||||
|
||||
return []
|
||||
|
||||
|
||||
@@ -134,11 +134,40 @@ class MCPClient:
|
||||
logger.error(f"断开MCP服务器连接失败: {e}")
|
||||
return False
|
||||
|
||||
def _build_auth_headers(self) -> Dict[str, str]:
|
||||
"""构建认证头"""
|
||||
headers = {}
|
||||
auth_type = self.connection_config.get("auth_type", "none")
|
||||
auth_config = self.connection_config.get("auth_config", {})
|
||||
|
||||
if auth_type == "api_key":
|
||||
api_key = auth_config.get("api_key")
|
||||
key_name = auth_config.get("key_name", "X-API-Key")
|
||||
if api_key:
|
||||
headers[key_name] = api_key
|
||||
|
||||
elif auth_type == "bearer_token":
|
||||
token = auth_config.get("token")
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
elif auth_type == "basic_auth":
|
||||
username = auth_config.get("username")
|
||||
password = auth_config.get("password")
|
||||
if username and password:
|
||||
import base64
|
||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {credentials}"
|
||||
|
||||
return headers
|
||||
|
||||
async def _connect_websocket(self) -> bool:
|
||||
"""建立WebSocket连接"""
|
||||
try:
|
||||
# WebSocket连接配置
|
||||
extra_headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
extra_headers.update(auth_headers)
|
||||
|
||||
self._websocket = await websockets.connect(
|
||||
self.server_url,
|
||||
@@ -190,6 +219,8 @@ class MCPClient:
|
||||
# HTTP会话配置
|
||||
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
|
||||
headers = self.connection_config.get("headers", {})
|
||||
auth_headers = self._build_auth_headers()
|
||||
headers.update(auth_headers)
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
timeout=timeout,
|
||||
@@ -251,8 +282,9 @@ class MCPClient:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息失败: {e}")
|
||||
|
||||
async def _handle_notification(self, message: Dict[str, Any]):
|
||||
|
||||
@staticmethod
|
||||
async def _handle_notification(message: Dict[str, Any]):
|
||||
"""处理通知消息"""
|
||||
method = message.get("method")
|
||||
params = message.get("params", {})
|
||||
@@ -327,7 +359,7 @@ class MCPClient:
|
||||
try:
|
||||
response = await self._send_request(request_data, timeout)
|
||||
|
||||
if not response["error"] is None:
|
||||
if response.get("error", None) is not None:
|
||||
error = response["error"]
|
||||
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
|
||||
|
||||
@@ -372,10 +404,10 @@ class MCPClient:
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise
|
||||
except Exception as e:
|
||||
self._pending_requests.pop(request_id, None)
|
||||
await self._pending_requests.pop(request_id, None)
|
||||
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
|
||||
|
||||
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
@@ -424,9 +456,9 @@ class MCPClient:
|
||||
|
||||
start_time = time.time()
|
||||
response = await self._send_request(request_data, timeout=5)
|
||||
response_time = time.time() - start_time
|
||||
response_time = round((time.time() - start_time) * 1000)
|
||||
|
||||
self._last_health_check = time.time()
|
||||
self._last_health_check = round(time.time() * 1000)
|
||||
|
||||
return {
|
||||
"healthy": True,
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
|
||||
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .client import MCPClient, MCPConnectionPool
|
||||
|
||||
@@ -148,7 +148,7 @@ class MCPServiceManager:
|
||||
connection_config=connection_config,
|
||||
available_tools=tool_names,
|
||||
health_status="healthy",
|
||||
last_health_check=datetime.utcnow()
|
||||
last_health_check=datetime.now()
|
||||
)
|
||||
|
||||
self.db.add(mcp_config)
|
||||
@@ -410,7 +410,8 @@ class MCPServiceManager:
|
||||
"""加载现有服务"""
|
||||
try:
|
||||
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
|
||||
ToolConfig.is_enabled == True
|
||||
ToolConfig.status == ToolStatus.AVAILABLE.value,
|
||||
ToolConfig.tool_type == ToolType.MCP.value
|
||||
).all()
|
||||
|
||||
for mcp_config in mcp_configs:
|
||||
@@ -531,7 +532,7 @@ class MCPServiceManager:
|
||||
|
||||
if mcp_config:
|
||||
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
|
||||
mcp_config.last_health_check = datetime.utcnow()
|
||||
mcp_config.last_health_check = datetime.now()
|
||||
|
||||
if not health_status["healthy"]:
|
||||
mcp_config.error_message = health_status.get("error", "")
|
||||
|
||||
@@ -1,436 +0,0 @@
|
||||
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .base import BaseTool, ToolInfo
|
||||
from .custom.base import CustomTool
|
||||
from .mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表 - 管理所有工具的元数据和实例"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化工具注册表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
|
||||
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
|
||||
self._lock = asyncio.Lock() # 异步锁
|
||||
|
||||
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
|
||||
"""注册工具类
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
class_name: 类名(可选,默认使用类的__name__)
|
||||
"""
|
||||
class_name = class_name or tool_class.__name__
|
||||
self._tool_classes[class_name] = tool_class
|
||||
logger.info(f"工具类已注册: {class_name}")
|
||||
|
||||
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
|
||||
"""注册工具实例到系统
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
tenant_id: 租户ID(内置工具可以为None,表示全局工具)
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否已存在
|
||||
if tenant_id:
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
# 全局工具(内置工具)
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id.is_(None),
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
|
||||
return False
|
||||
|
||||
# 创建工具配置
|
||||
tool_config = ToolConfig(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
tool_type=tool.tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
version=tool.version,
|
||||
tags=tool.tags,
|
||||
config_data=tool.config
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush() # 获取ID
|
||||
|
||||
# 根据工具类型创建特定配置
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
builtin_config = BuiltinToolConfig(
|
||||
id=tool_config.id,
|
||||
tool_class=tool.__class__.__name__,
|
||||
parameters=tool.config.get("parameters", {})
|
||||
)
|
||||
self.db.add(builtin_config)
|
||||
|
||||
elif tool.tool_type == ToolType.CUSTOM:
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
schema_url=tool.config.get("schema_url"),
|
||||
schema_content=tool.config.get("schema_content"),
|
||||
auth_type=tool.config.get("auth_type", "none"),
|
||||
auth_config=tool.config.get("auth_config", {}),
|
||||
base_url=tool.config.get("base_url"),
|
||||
timeout=tool.config.get("timeout", 30)
|
||||
)
|
||||
self.db.add(custom_config)
|
||||
|
||||
elif tool.tool_type == ToolType.MCP:
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=tool.config.get("server_url"),
|
||||
connection_config=tool.config.get("connection_config", {}),
|
||||
available_tools=tool.config.get("available_tools", [])
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 缓存工具实例
|
||||
tool.tool_id = str(tool_config.id)
|
||||
self._tools[str(tool_config.id)] = tool
|
||||
|
||||
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> bool:
|
||||
"""从系统注销工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否存在
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 检查是否有正在执行的任务
|
||||
running_executions = self.db.query(ToolExecution).filter(
|
||||
and_(
|
||||
ToolExecution.tool_config_id == uuid.UUID(tool_id),
|
||||
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
|
||||
)
|
||||
).count()
|
||||
|
||||
if running_executions > 0:
|
||||
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
|
||||
return False
|
||||
|
||||
# 删除工具配置(级联删除相关记录)
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 从缓存中移除
|
||||
if tool_id in self._tools:
|
||||
del self._tools[tool_id]
|
||||
|
||||
logger.info(f"工具注销成功: {tool_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
|
||||
"""获取工具实例
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
工具实例,如果不存在返回None
|
||||
"""
|
||||
# 先从缓存获取
|
||||
if tool_id in self._tools:
|
||||
return self._tools[tool_id]
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
|
||||
return None
|
||||
|
||||
# 根据工具类型加载实例
|
||||
tool_instance = self._load_tool_instance(tool_config)
|
||||
if tool_instance:
|
||||
self._tools[tool_id] = tool_instance
|
||||
return tool_instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[ToolInfo]:
|
||||
"""列出工具
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
tool_type: 工具类型过滤
|
||||
status: 工具状态过滤
|
||||
tags: 标签过滤
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
|
||||
# 应用过滤条件
|
||||
if tenant_id:
|
||||
# 返回全局工具(tenant_id为空)和该租户的工具
|
||||
query = query.filter(
|
||||
or_(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tenant_id.is_(None)
|
||||
)
|
||||
)
|
||||
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||||
|
||||
if status == ToolStatus.ACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == True)
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == False)
|
||||
|
||||
if tags:
|
||||
for tag in tags:
|
||||
query = query.filter(ToolConfig.tags.contains([tag]))
|
||||
|
||||
tool_configs = query.all()
|
||||
|
||||
# 转换为ToolInfo
|
||||
tool_infos = []
|
||||
for config in tool_configs:
|
||||
tool_info = ToolInfo(
|
||||
id=str(config.id),
|
||||
name=config.name,
|
||||
description=config.description or "",
|
||||
tool_type=ToolType(config.tool_type),
|
||||
version=config.version,
|
||||
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None
|
||||
)
|
||||
|
||||
# 尝试获取参数信息
|
||||
tool_instance = self.get_tool(str(config.id))
|
||||
if tool_instance:
|
||||
tool_info.parameters = tool_instance.parameters
|
||||
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
return tool_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出工具失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
|
||||
"""更新工具状态
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 更新状态
|
||||
if status == ToolStatus.ACTIVE:
|
||||
tool_config.is_enabled = True
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
tool_config.is_enabled = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新缓存中的工具状态
|
||||
if tool_id in self._tools:
|
||||
self._tools[tool_id].status = status
|
||||
|
||||
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
|
||||
"""从配置加载工具实例
|
||||
|
||||
Args:
|
||||
tool_config: 工具配置
|
||||
|
||||
Returns:
|
||||
工具实例
|
||||
"""
|
||||
try:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
# 加载内置工具
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if builtin_config and builtin_config.tool_class in self._tool_classes:
|
||||
tool_class = self._tool_classes[builtin_config.tool_class]
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"parameters": builtin_config.parameters,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return tool_class(str(tool_config.id), config)
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
# 加载自定义工具
|
||||
try:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if custom_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"schema_url": custom_config.schema_url,
|
||||
"schema_content": custom_config.schema_content,
|
||||
"auth_type": custom_config.auth_type,
|
||||
"auth_config": custom_config.auth_config,
|
||||
"base_url": custom_config.base_url,
|
||||
"timeout": custom_config.timeout,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return CustomTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入自定义工具模块: {e}")
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
# 加载MCP工具
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config,
|
||||
"available_tools": mcp_config.available_tools,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return MCPTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入MCP工具模块: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""获取工具统计信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
if tenant_id:
|
||||
query = query.filter(ToolConfig.tenant_id == tenant_id)
|
||||
|
||||
total_tools = query.count()
|
||||
active_tools = query.filter(ToolConfig.is_enabled == True).count()
|
||||
|
||||
# 按类型统计
|
||||
type_stats = {}
|
||||
for tool_type in ToolType:
|
||||
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
|
||||
type_stats[tool_type.value] = count
|
||||
|
||||
return {
|
||||
"total_tools": total_tools,
|
||||
"active_tools": active_tools,
|
||||
"inactive_tools": total_tools - active_tools,
|
||||
"by_type": type_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空工具缓存"""
|
||||
self._tools.clear()
|
||||
logger.info("工具缓存已清空")
|
||||
Reference in New Issue
Block a user