feat(tool system): Tool system reengineering

This commit is contained in:
谢俊男
2025-12-25 17:30:20 +08:00
parent 3bcaead413
commit 04be3088a2
25 changed files with 1887 additions and 3325 deletions

View File

@@ -1,11 +1,7 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .registry import ToolRegistry
from .executor import ToolExecutor
from .langchain_adapter import LangchainAdapter
from .config_manager import ConfigManager
from .chain_manager import ChainManager
# 可选导入,避免导入错误
try:
@@ -22,11 +18,7 @@ __all__ = [
"BaseTool",
"ToolResult",
"ToolParameter",
"ToolRegistry",
"ToolExecutor",
"LangchainAdapter",
"ConfigManager",
"ChainManager"
"LangchainAdapter"
]
# 只有在成功导入时才添加到__all__

View File

@@ -1,98 +1,10 @@
"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from typing import Any, Dict, List, Optional
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
"""参数类型枚举"""
STRING = "string"
INTEGER = "integer"
NUMBER = "number"
BOOLEAN = "boolean"
ARRAY = "array"
OBJECT = "object"
class ToolParameter(BaseModel):
"""工具参数定义"""
name: str = Field(..., description="参数名称")
type: ParameterType = Field(..., description="参数类型")
description: str = Field("", description="参数描述")
required: bool = Field(False, description="是否必需")
default: Any = Field(None, description="默认值")
enum: Optional[List[Any]] = Field(None, description="枚举值")
minimum: Optional[Union[int, float]] = Field(None, description="最小值")
maximum: Optional[Union[int, float]] = Field(None, description="最大值")
pattern: Optional[str] = Field(None, description="正则表达式模式")
class Config:
use_enum_values = True
class ToolResult(BaseModel):
"""工具执行结果"""
success: bool = Field(..., description="执行是否成功")
data: Any = Field(None, description="返回数据")
error: Optional[str] = Field(None, description="错误信息")
error_code: Optional[str] = Field(None, description="错误代码")
execution_time: float = Field(..., description="执行时间(秒)")
token_usage: Optional[Dict[str, int]] = Field(None, description="Token使用情况")
metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")
@classmethod
def success_result(
cls,
data: Any,
execution_time: float,
token_usage: Optional[Dict[str, int]] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建成功结果"""
return cls(
success=True,
data=data,
execution_time=execution_time,
token_usage=token_usage,
metadata=metadata or {}
)
@classmethod
def error_result(
cls,
error: str,
execution_time: float,
error_code: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None
) -> "ToolResult":
"""创建错误结果"""
return cls(
success=False,
error=error,
error_code=error_code,
execution_time=execution_time,
metadata=metadata or {}
)
class ToolInfo(BaseModel):
"""工具信息"""
id: str = Field(..., description="工具ID")
name: str = Field(..., description="工具名称")
description: str = Field(..., description="工具描述")
tool_type: ToolType = Field(..., description="工具类型")
version: str = Field("1.0.0", description="工具版本")
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
status: ToolStatus = Field(ToolStatus.ACTIVE, description="工具状态")
tags: List[str] = Field(default_factory=list, description="工具标签")
tenant_id: Optional[str] = Field(None, description="租户ID")
class Config:
use_enum_values = True
from app.schemas.tool_schema import ToolParameter, ParameterType, ToolResult
class BaseTool(ABC):
@@ -107,7 +19,7 @@ class BaseTool(ABC):
"""
self.tool_id = tool_id
self.config = config
self._status = ToolStatus.ACTIVE
self._status = ToolStatus.AVAILABLE
@property
@abstractmethod
@@ -153,20 +65,6 @@ class BaseTool(ABC):
"""工具标签"""
return self.config.get("tags", [])
def get_info(self) -> ToolInfo:
"""获取工具信息"""
return ToolInfo(
id=self.tool_id,
name=self.name,
description=self.description,
tool_type=self.tool_type,
version=self.version,
parameters=self.parameters,
status=self.status,
tags=self.tags,
tenant_id=self.config.get("tenant_id")
)
def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
"""验证参数

View File

@@ -3,7 +3,8 @@ from abc import ABC, abstractmethod
from typing import Dict, Any, List
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolResult, ToolParameter
class BuiltinTool(BaseTool, ABC):

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timezone, timedelta
from typing import List
import pytz
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
@@ -54,14 +54,14 @@ class DateTimeTool(BuiltinTool):
type=ParameterType.STRING,
description="源时区UTC, Asia/Shanghai",
required=False,
default="UTC"
default="Asia/Shanghai"
),
ToolParameter(
name="to_timezone",
type=ParameterType.STRING,
description="目标时区UTC, Asia/Shanghai",
required=False,
default="UTC"
default="Asia/Shanghai"
),
ToolParameter(
name="calculation",
@@ -106,10 +106,11 @@ class DateTimeTool(BuiltinTool):
error_code="DATETIME_ERROR",
execution_time=execution_time
)
def _get_current_time(self, kwargs) -> dict:
@staticmethod
def _get_current_time(kwargs) -> dict:
"""获取当前时间"""
timezone_str = kwargs.get("to_timezone", "UTC")
timezone_str = kwargs.get("to_timezone", "Asia/Shanghai")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
if timezone_str == "UTC":
@@ -118,15 +119,20 @@ class DateTimeTool(BuiltinTool):
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
utc_now = datetime.now(timezone.utc)
return {
"datetime": now.strftime(output_format),
"timestamp": int(now.timestamp()),
"timezone": timezone_str,
"iso_format": now.isoformat()
"iso_format": now.isoformat(),
"timestamp_ms": int(now.timestamp() * 1000),
"utc_datetime": utc_now.strftime(output_format)
}
def _format_datetime(self, kwargs) -> dict:
@staticmethod
def _format_datetime(kwargs) -> dict:
"""格式化时间"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -144,8 +150,9 @@ class DateTimeTool(BuiltinTool):
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _convert_timezone(self, kwargs) -> dict:
@staticmethod
def _convert_timezone(kwargs) -> dict:
"""时区转换"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -184,8 +191,9 @@ class DateTimeTool(BuiltinTool):
"converted_timezone": to_timezone,
"timestamp": int(converted_dt.timestamp())
}
def _timestamp_to_datetime(self, kwargs) -> dict:
@staticmethod
def _timestamp_to_datetime(kwargs) -> dict:
"""时间戳转日期时间"""
input_value = kwargs.get("input_value")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
@@ -196,6 +204,8 @@ class DateTimeTool(BuiltinTool):
# 转换时间戳
timestamp = float(input_value)
if timestamp > 1e12:
timestamp = timestamp / 1000
# 设置时区
if timezone_str == "UTC":
@@ -211,8 +221,9 @@ class DateTimeTool(BuiltinTool):
"timezone": timezone_str,
"iso_format": dt.isoformat()
}
def _datetime_to_timestamp(self, kwargs) -> dict:
@staticmethod
def _datetime_to_timestamp(kwargs) -> dict:
"""日期时间转时间戳"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
@@ -240,7 +251,7 @@ class DateTimeTool(BuiltinTool):
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _calculate_datetime(self, kwargs) -> dict:
"""时间计算"""
input_value = kwargs.get("input_value")
@@ -278,8 +289,9 @@ class DateTimeTool(BuiltinTool):
"timezone": timezone_str,
"timestamp": int(calculated_dt.timestamp())
}
def _parse_time_delta(self, calculation: str) -> timedelta:
@staticmethod
def _parse_time_delta(calculation: str) -> timedelta:
"""解析时间计算表达式"""
import re

View File

@@ -121,8 +121,9 @@ class JsonTool(BuiltinTool):
error_code="JSON_ERROR",
execution_time=execution_time
)
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _format_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""格式化JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
@@ -151,12 +152,13 @@ class JsonTool(BuiltinTool):
"sort_keys": sort_keys
}
}
def _minify_json(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _minify_json(input_data: str) -> Dict[str, Any]:
"""压缩JSON"""
# 解析并压缩
data = json.loads(input_data)
minified = json.dumps(data, separators=(',', ':'))
minified = json.dumps(data, ensure_ascii=False, separators=(',', ':'))
return {
"original_size": len(input_data),
@@ -165,7 +167,7 @@ class JsonTool(BuiltinTool):
"minified_json": minified,
"is_valid": True
}
def _validate_json(self, input_data: str) -> Dict[str, Any]:
"""验证JSON"""
try:
@@ -190,17 +192,19 @@ class JsonTool(BuiltinTool):
"size": len(input_data)
}
def _convert_json(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _convert_json(input_data: str) -> Dict[str, Any]:
"""JSON转义"""
data = json.loads(input_data)
converted = json.dumps(data, ensure_ascii=False)
converted = json.dumps(data, ensure_ascii=True, separators=(',', ':'))
return {
"converted_json": converted,
"is_valid": True
}
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _json_to_yaml(input_data: str) -> Dict[str, Any]:
"""JSON转YAML"""
data = json.loads(input_data)
yaml_output = yaml.dump(data, default_flow_style=False, allow_unicode=True, indent=2)
@@ -212,8 +216,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(yaml_output),
"converted_data": yaml_output
}
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _yaml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""YAML转JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
@@ -228,10 +233,11 @@ class JsonTool(BuiltinTool):
"converted_size": len(json_output),
"converted_data": json_output
}
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
@staticmethod
def _json_to_xml(input_data: str) -> Dict[str, Any]:
"""JSON转XML"""
data = json.loads(input_data)
json_data = json.loads(input_data)
def dict_to_xml(data, root_name="root"):
"""递归转换字典为XML"""
@@ -267,7 +273,7 @@ class JsonTool(BuiltinTool):
root.text = str(data)
return root
xml_element = dict_to_xml(data)
xml_element = dict_to_xml(json_data)
xml_string = ET.tostring(xml_element, encoding='unicode')
# 格式化XML
@@ -284,8 +290,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(formatted_xml),
"converted_data": formatted_xml
}
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _xml_to_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""XML转JSON"""
indent = kwargs.get("indent", 2)
@@ -328,8 +335,9 @@ class JsonTool(BuiltinTool):
"converted_size": len(json_output),
"converted_data": json_output
}
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _merge_json(input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""合并JSON"""
merge_data = kwargs.get("merge_data")
if not merge_data:
@@ -364,8 +372,9 @@ class JsonTool(BuiltinTool):
"result_size": len(merged_json),
"merged_data": merged_json
}
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_json_path( input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""提取JSON路径"""
json_path = kwargs.get("json_path")
if not json_path:

View File

@@ -275,8 +275,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_formula_result( result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化公式识别结果"""
formulas = result.get("formulas", [])
@@ -288,8 +289,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_table_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化表格识别结果"""
tables = result.get("tables", [])
@@ -301,8 +303,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
@staticmethod
def _format_document_result(result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
"""格式化文档识别结果"""
return {
"recognition_mode": "document",
@@ -314,8 +317,9 @@ class TextInTool(BuiltinTool):
"total_confidence": result.get("confidence", 0),
"processing_time": result.get("processing_time", 0)
}
def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
@staticmethod
def _group_lines_to_paragraphs(lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将行分组为段落"""
paragraphs = []
current_paragraph = []

View File

@@ -1,485 +0,0 @@
"""工具链管理器 - 支持langchain的工具链模式"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from app.core.tools.base import ToolResult
from app.core.tools.executor import ToolExecutor
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ChainExecutionMode(str, Enum):
"""链执行模式"""
SEQUENTIAL = "sequential" # 顺序执行
PARALLEL = "parallel" # 并行执行
CONDITIONAL = "conditional" # 条件执行
@dataclass
class ChainStep:
"""链步骤定义"""
tool_id: str
parameters: Dict[str, Any]
condition: Optional[str] = None # 执行条件
output_mapping: Optional[Dict[str, str]] = None # 输出映射
error_handling: str = "stop" # 错误处理stop, continue, retry
@dataclass
class ChainDefinition:
"""工具链定义"""
name: str
description: str
steps: List[ChainStep]
execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
global_timeout: Optional[float] = None
retry_policy: Optional[Dict[str, Any]] = None
class ChainExecutionContext:
"""链执行上下文"""
def __init__(self, chain_id: str):
self.chain_id = chain_id
self.variables: Dict[str, Any] = {}
self.step_results: Dict[int, ToolResult] = {}
self.current_step = 0
self.is_completed = False
self.is_failed = False
self.error_message: Optional[str] = None
class ChainManager:
"""工具链管理器 - 支持langchain的工具链模式"""
def __init__(self, executor: ToolExecutor):
"""初始化工具链管理器
Args:
executor: 工具执行器
"""
self.executor = executor
self._chains: Dict[str, ChainDefinition] = {}
self._running_chains: Dict[str, ChainExecutionContext] = {}
def register_chain(self, chain: ChainDefinition) -> bool:
"""注册工具链
Args:
chain: 工具链定义
Returns:
注册是否成功
"""
try:
# 验证工具链定义
validation_result = self._validate_chain(chain)
if not validation_result[0]:
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
return False
self._chains[chain.name] = chain
logger.info(f"工具链注册成功: {chain.name}")
return True
except Exception as e:
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
return False
def unregister_chain(self, chain_name: str) -> bool:
"""注销工具链
Args:
chain_name: 工具链名称
Returns:
注销是否成功
"""
if chain_name in self._chains:
del self._chains[chain_name]
logger.info(f"工具链注销成功: {chain_name}")
return True
return False
def list_chains(self) -> List[Dict[str, Any]]:
"""列出所有工具链
Returns:
工具链信息列表
"""
chains = []
for name, chain in self._chains.items():
chains.append({
"name": name,
"description": chain.description,
"step_count": len(chain.steps),
"execution_mode": chain.execution_mode.value,
"global_timeout": chain.global_timeout
})
return chains
async def execute_chain(
self,
chain_name: str,
initial_variables: Optional[Dict[str, Any]] = None,
chain_id: Optional[str] = None
) -> Dict[str, Any] | None:
"""执行工具链
Args:
chain_name: 工具链名称
initial_variables: 初始变量
chain_id: 链执行ID可选
Returns:
执行结果
"""
if chain_name not in self._chains:
return {
"success": False,
"error": f"工具链不存在: {chain_name}",
"chain_id": chain_id
}
chain = self._chains[chain_name]
# 生成链ID
if not chain_id:
import uuid
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
# 创建执行上下文
context = ChainExecutionContext(chain_id)
context.variables = initial_variables or {}
self._running_chains[chain_id] = context
try:
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
# 根据执行模式执行
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
result = await self._execute_sequential(chain, context)
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
result = await self._execute_parallel(chain, context)
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
result = await self._execute_conditional(chain, context)
else:
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
return result
except Exception as e:
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
return {
"success": False,
"error": str(e),
"chain_id": chain_id,
"completed_steps": context.current_step,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
finally:
# 清理执行上下文
if chain_id in self._running_chains:
del self._running_chains[chain_id]
async def _execute_sequential(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""顺序执行工具链"""
for i, step in enumerate(chain.steps):
context.current_step = i
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
logger.debug(f"跳过步骤 {i}: 条件不满足")
continue
# 准备参数
parameters = self._prepare_parameters(step.parameters, context)
# 执行工具
try:
result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = result
# 处理输出映射
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 处理执行失败
if not result.success:
if step.error_handling == "stop":
context.is_failed = True
context.error_message = result.error
break
elif step.error_handling == "continue":
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
continue
elif step.error_handling == "retry":
# 简单重试逻辑
retry_result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = retry_result
if not retry_result.success and step.error_handling == "stop":
context.is_failed = True
context.error_message = retry_result.error
break
except Exception as e:
logger.error(f"步骤 {i} 执行异常: {e}")
if step.error_handling == "stop":
context.is_failed = True
context.error_message = str(e)
break
context.is_completed = not context.is_failed
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": context.current_step + 1,
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_parallel(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""并行执行工具链"""
# 准备所有步骤的执行配置
execution_configs = []
for i, step in enumerate(chain.steps):
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
continue
parameters = self._prepare_parameters(step.parameters, context)
execution_configs.append({
"step_index": i,
"tool_id": step.tool_id,
"parameters": parameters
})
# 并行执行所有步骤
try:
results = await self.executor.execute_tools_batch(execution_configs)
# 处理结果
for i, result in enumerate(results):
step_index = execution_configs[i]["step_index"]
context.step_results[step_index] = result
# 处理输出映射
step = chain.steps[step_index]
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 检查是否有失败的步骤
failed_steps = [i for i, result in context.step_results.items() if not result.success]
context.is_completed = len(failed_steps) == 0
if failed_steps:
context.error_message = f"步骤 {failed_steps} 执行失败"
except Exception as e:
context.is_failed = True
context.error_message = str(e)
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": len(context.step_results),
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_conditional(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""条件执行工具链"""
# 条件执行类似于顺序执行,但更严格地检查条件
return await self._execute_sequential(chain, context)
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
"""验证工具链定义
Args:
chain: 工具链定义
Returns:
(是否有效, 错误信息)
"""
if not chain.name:
return False, "工具链名称不能为空"
if not chain.steps:
return False, "工具链必须包含至少一个步骤"
for i, step in enumerate(chain.steps):
if not step.tool_id:
return False, f"步骤 {i} 缺少工具ID"
if step.error_handling not in ["stop", "continue", "retry"]:
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
return True, None
def _prepare_parameters(
self,
parameters: Dict[str, Any],
context: ChainExecutionContext
) -> Dict[str, Any]:
"""准备参数(支持变量替换)
Args:
parameters: 原始参数
context: 执行上下文
Returns:
处理后的参数
"""
prepared = {}
for key, value in parameters.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 变量替换
var_name = value[2:-1]
if var_name in context.variables:
prepared[key] = context.variables[var_name]
else:
prepared[key] = value # 保持原值
else:
prepared[key] = value
return prepared
def _evaluate_condition(
self,
condition: str,
context: ChainExecutionContext
) -> bool:
"""评估执行条件
Args:
condition: 条件表达式
context: 执行上下文
Returns:
条件是否满足
"""
try:
# 简单的条件评估(可以扩展为更复杂的表达式解析)
# 支持格式variable == value, variable != value, variable > value 等
if "==" in condition:
var_name, expected_value = condition.split("==", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) == expected_value
elif "!=" in condition:
var_name, expected_value = condition.split("!=", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) != expected_value
elif condition in context.variables:
# 简单的布尔检查
return bool(context.variables[condition])
else:
# 默认为真
return True
except Exception as e:
logger.error(f"条件评估失败: {condition}, 错误: {e}")
return False
def _apply_output_mapping(
self,
mapping: Dict[str, str],
output_data: Any,
context: ChainExecutionContext
):
"""应用输出映射
Args:
mapping: 输出映射配置
output_data: 输出数据
context: 执行上下文
"""
try:
if isinstance(output_data, dict):
for source_key, target_var in mapping.items():
if source_key in output_data:
context.variables[target_var] = output_data[source_key]
else:
# 如果输出不是字典,将整个输出映射到指定变量
if "result" in mapping:
context.variables[mapping["result"]] = output_data
except Exception as e:
logger.error(f"输出映射失败: {e}")
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
"""序列化工具结果
Args:
result: 工具结果
Returns:
序列化的结果
"""
return {
"success": result.success,
"data": result.data,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time,
"token_usage": result.token_usage,
"metadata": result.metadata
}
def get_running_chains(self) -> List[Dict[str, Any]]:
"""获取正在运行的工具链
Returns:
运行中的工具链列表
"""
chains = []
for chain_id, context in self._running_chains.items():
chains.append({
"chain_id": chain_id,
"current_step": context.current_step,
"is_completed": context.is_completed,
"is_failed": context.is_failed,
"variables_count": len(context.variables),
"completed_steps": len(context.step_results)
})
return chains

View File

@@ -1,264 +0,0 @@
"""工具配置管理器 - 管理工具配置的加载和验证"""
import json
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, ValidationError
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ToolConfigSchema(BaseModel):
"""工具配置基础Schema"""
name: str
description: str
tool_type: str
version: str = "1.0.0"
enabled: bool = True
parameters: Dict[str, Any] = {}
tags: list[str] = []
class Config:
extra = "allow"
class BuiltinToolConfigSchema(ToolConfigSchema):
"""内置工具配置Schema"""
tool_class: str
tool_type: str = "builtin"
class CustomToolConfigSchema(ToolConfigSchema):
"""自定义工具配置Schema"""
schema_url: Optional[str] = None
schema_content: Optional[Dict[str, Any]] = None
auth_type: str = "none"
auth_config: Dict[str, Any] = {}
base_url: Optional[str] = None
timeout: int = 30
tool_type: str = "custom"
class MCPToolConfigSchema(ToolConfigSchema):
"""MCP工具配置Schema"""
server_url: str
connection_config: Dict[str, Any] = {}
available_tools: list[str] = []
tool_type: str = "mcp"
class ConfigManager:
"""工具配置管理器"""
def __init__(self, config_dir: Optional[str] = None):
"""初始化配置管理器
Args:
config_dir: 配置文件目录,默认使用系统配置
"""
self.config_dir = Path(config_dir or self._get_default_config_dir())
self.config_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"配置管理器初始化完成,配置目录: {self.config_dir}")
def _get_default_config_dir(self) -> str:
"""获取默认配置目录"""
# 获取tools目录下的configs子目录
tools_dir = Path(__file__).parent
return str(tools_dir / "configs")
def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
"""加载内置工具配置
Returns:
内置工具配置字典
"""
configs = {}
builtin_dir = self.config_dir / "builtin"
if not builtin_dir.exists():
logger.info("内置工具配置目录不存在,创建默认配置")
self._create_default_builtin_configs(builtin_dir)
for config_file in builtin_dir.glob("*.json"):
try:
config_data = self._load_config_file(config_file)
config = BuiltinToolConfigSchema(**config_data)
configs[config.name] = config
logger.debug(f"加载内置工具配置: {config.name}")
except Exception as e:
logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
return configs
def load_builtin_tools_config(self) -> Dict[str, Any]:
"""加载全局内置工具配置(兼容原有接口)
Returns:
内置工具配置字典
"""
config_file = self.config_dir / "builtin_tools.json"
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载内置工具配置失败: {e}")
return {}
def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
"""确保内置工具已初始化到数据库
Args:
tenant_id: 租户ID
db_session: 数据库会话
tool_config_model: ToolConfig模型类
builtin_tool_config_model: BuiltinToolConfig模型类
tool_type_enum: ToolType枚举
tool_status_enum: ToolStatus枚举
"""
# 检查是否已初始化
existing_count = db_session.query(tool_config_model).filter(
tool_config_model.tenant_id == tenant_id,
tool_config_model.tool_type == tool_type_enum.BUILTIN
).count()
if existing_count > 0:
return # 已初始化
# 加载全局配置
builtin_tools = self.load_builtin_tools_config()
# 为租户创建内置工具记录
for tool_key, tool_info in builtin_tools.items():
# 设置初始状态
initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
tool_config = tool_config_model(
name=tool_info['name'],
description=tool_info['description'],
tool_type=tool_type_enum.BUILTIN,
tenant_id=tenant_id,
status=initial_status
)
db_session.add(tool_config)
db_session.flush()
builtin_config = builtin_tool_config_model(
id=tool_config.id,
tool_class=tool_info['tool_class'],
parameters={}
)
db_session.add(builtin_config)
db_session.commit()
logger.info(f"租户 {tenant_id} 的内置工具初始化完成")
def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
"""保存工具配置
Args:
config: 工具配置
tool_type: 工具类型
Returns:
保存是否成功
"""
try:
config_dir = self.config_dir / tool_type
config_dir.mkdir(parents=True, exist_ok=True)
config_file = config_dir / f"{config.name}.json"
config_data = config.model_dump()
with open(config_file, 'w', encoding='utf-8') as f:
json.dump(config_data, f, indent=2, ensure_ascii=False)
logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
return True
except Exception as e:
logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
return False
def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
"""删除工具配置
Args:
tool_name: 工具名称
tool_type: 工具类型
Returns:
删除是否成功
"""
try:
config_file = self.config_dir / tool_type / f"{tool_name}.json"
if config_file.exists():
config_file.unlink()
logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
return True
else:
logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
return False
except Exception as e:
logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
return False
def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
"""验证工具配置
Args:
config_data: 配置数据
tool_type: 工具类型
Returns:
(是否有效, 错误信息)
"""
try:
schema_map = {
"builtin": BuiltinToolConfigSchema,
"custom": CustomToolConfigSchema,
"mcp": MCPToolConfigSchema
}
schema_class = schema_map.get(tool_type)
if not schema_class:
return False, f"不支持的工具类型: {tool_type}"
# 验证配置
schema_class(**config_data)
return True, None
except ValidationError as e:
error_msg = "; ".join([f"{err['loc'][0]}: {err['msg']}" for err in e.errors()])
return False, f"配置验证失败: {error_msg}"
except Exception as e:
return False, f"配置验证异常: {str(e)}"
def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
"""加载配置文件
Args:
config_file: 配置文件路径
Returns:
配置数据字典
"""
try:
with open(config_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
raise
def _create_default_builtin_configs(self, builtin_dir: Path):
"""创建默认内置工具配置
Args:
builtin_dir: 内置工具配置目录
"""
builtin_dir.mkdir(parents=True, exist_ok=True)
logger.info(f"内置工具配置目录已创建: {builtin_dir}")
# 配置文件已经通过其他方式创建,这里只需要确保目录存在

View File

@@ -54,7 +54,8 @@
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true}
"api_secret": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"base_url": {"type": "string", "description": "API地址", "default": "https://api.textin.com/v1"}
}
}
}

View File

@@ -2,7 +2,6 @@
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
@@ -51,8 +50,9 @@ class AuthManager:
except Exception as e:
return False, f"验证认证配置时出错: {e}"
def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
@staticmethod
def _validate_api_key_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证API Key认证配置
Args:
@@ -79,8 +79,9 @@ class AuthManager:
return False, "API Key位置必须是 header、query 或 cookie"
return True, ""
def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
@staticmethod
def _validate_bearer_token_config(auth_config: Dict[str, Any]) -> Tuple[bool, str]:
"""验证Bearer Token认证配置
Args:
@@ -135,9 +136,9 @@ class AuthManager:
except Exception as e:
logger.error(f"应用认证时出错: {e}")
return url, headers, params
@staticmethod
def _apply_api_key_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
@@ -176,9 +177,9 @@ class AuthManager:
headers["Cookie"] = cookie_value
return url, headers, params
@staticmethod
def _apply_bearer_token_auth(
self,
auth_config: Dict[str, Any],
url: str,
headers: Dict[str, str],
@@ -260,8 +261,9 @@ class AuthManager:
except Exception as e:
logger.error(f"解密认证配置失败: {e}")
return encrypted_config
def _encrypt_string(self, value: str, key: str) -> str:
@staticmethod
def _encrypt_string(value: str, key: str) -> str:
"""加密字符串
Args:
@@ -289,8 +291,9 @@ class AuthManager:
except Exception as e:
logger.error(f"加密字符串失败: {e}")
return value
def _decrypt_string(self, encrypted_value: str, key: str) -> str:
@staticmethod
def _decrypt_string(encrypted_value: str, key: str) -> str:
"""解密字符串
Args:
@@ -471,8 +474,9 @@ class AuthManager:
"error": f"测试认证时出错: {e}",
"auth_type": auth_type.value
}
def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
@staticmethod
def get_auth_config_template(auth_type: AuthType) -> Dict[str, Any]:
"""获取认证配置模板
Args:
@@ -498,8 +502,9 @@ class AuthManager:
}
return templates.get(auth_type, {})
def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def mask_sensitive_config(auth_config: Dict[str, Any]) -> Dict[str, Any]:
"""遮蔽认证配置中的敏感信息
Args:

View File

@@ -5,7 +5,8 @@ import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -173,8 +174,9 @@ class CustomTool(BaseTool):
}
return operations
def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
@staticmethod
def _convert_openapi_type(openapi_type: str) -> ParameterType:
"""转换OpenAPI类型到内部类型"""
type_mapping = {
"string": ParameterType.STRING,
@@ -239,8 +241,9 @@ class CustomTool(BaseTool):
headers["Authorization"] = f"Bearer {token}"
return headers
def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
@staticmethod
def _build_request_data(operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""构建请求数据"""
if operation["method"] in ["POST", "PUT", "PATCH"]:
request_body = operation.get("request_body")
@@ -284,6 +287,7 @@ class CustomTool(BaseTool):
try:
return await response.json()
except Exception as e:
logger.error(f"解析HTTP响应JSON失败: {str(e)}")
return await response.text()
@classmethod

View File

@@ -10,6 +10,9 @@ from app.core.logging_config import get_business_logger
logger = get_business_logger()
# 为了兼容性,创建别名
# SchemaParser = OpenAPISchemaParser = None
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
@@ -88,8 +91,9 @@ class OpenAPISchemaParser:
except Exception as e:
logger.error(f"解析schema内容失败: {e}")
return False, {}, str(e)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
@staticmethod
def _parse_content(content: str, content_type: str) -> Optional[Dict[str, Any]]:
"""解析内容为字典
Args:
@@ -101,7 +105,7 @@ class OpenAPISchemaParser:
"""
try:
# 根据内容类型解析
if 'json' in content_type:
if 'application/json' in content_type:
return json.loads(content)
elif 'yaml' in content_type or 'yml' in content_type:
return yaml.safe_load(content)
@@ -228,8 +232,9 @@ class OpenAPISchemaParser:
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_parameters(operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
@@ -266,8 +271,9 @@ class OpenAPISchemaParser:
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
@staticmethod
def _extract_request_body(operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
@@ -298,8 +304,9 @@ class OpenAPISchemaParser:
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
@staticmethod
def _extract_responses(operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
@@ -331,8 +338,9 @@ class OpenAPISchemaParser:
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
@staticmethod
def generate_tool_parameters(operations: Dict[str, Any]) -> List[Dict[str, Any]]:
"""生成工具参数定义
Args:
@@ -396,7 +404,7 @@ class OpenAPISchemaParser:
parameters.extend(all_params.values())
return parameters
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
"""验证操作参数
@@ -447,8 +455,9 @@ class OpenAPISchemaParser:
errors.append(f"请求体参数 {prop_name} 类型错误,期望: {prop_type}")
return len(errors) == 0, errors
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
@staticmethod
def _validate_parameter_type(value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
@@ -474,4 +483,7 @@ class OpenAPISchemaParser:
if expected_python_type:
return isinstance(value, expected_python_type)
return True
return True
# 为了兼容性,创建别名
SchemaParser = OpenAPISchemaParser

View File

@@ -1,501 +0,0 @@
"""工具执行器 - 负责工具的实际调用和执行管理"""
import asyncio
import uuid
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import ToolExecution, ExecutionStatus
from app.core.tools.base import BaseTool, ToolResult
from app.core.tools.registry import ToolRegistry
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ExecutionContext:
"""执行上下文"""
def __init__(
self,
execution_id: str,
tool_id: str,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None
):
self.execution_id = execution_id
self.tool_id = tool_id
self.user_id = user_id
self.workspace_id = workspace_id
self.timeout = timeout or 60.0 # 默认60秒超时
self.metadata = metadata or {}
self.started_at = datetime.now()
self.completed_at: Optional[datetime] = None
self.status = ExecutionStatus.PENDING
class ToolExecutor:
"""工具执行器 - 使用langchain标准接口执行工具"""
def __init__(self, db: Session, registry: ToolRegistry):
"""初始化工具执行器
Args:
db: 数据库会话
registry: 工具注册表
"""
self.db = db
self.registry = registry
self._running_executions: Dict[str, ExecutionContext] = {}
self._execution_lock = asyncio.Lock()
async def execute_tool(
self,
tool_id: str,
parameters: Dict[str, Any],
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
execution_id: Optional[str] = None,
timeout: Optional[float] = None,
metadata: Optional[Dict[str, Any]] = None
) -> ToolResult:
"""执行工具
Args:
tool_id: 工具ID
parameters: 工具参数
user_id: 用户ID
workspace_id: 工作空间ID
execution_id: 执行ID可选自动生成
timeout: 超时时间(秒)
metadata: 额外元数据
Returns:
工具执行结果
"""
# 生成执行ID
if not execution_id:
execution_id = f"exec_{uuid.uuid4().hex[:16]}"
# 创建执行上下文
context = ExecutionContext(
execution_id=execution_id,
tool_id=tool_id,
user_id=user_id,
workspace_id=workspace_id,
timeout=timeout,
metadata=metadata
)
try:
# 获取工具实例
tool = self.registry.get_tool(tool_id)
if not tool:
return ToolResult.error_result(
error=f"工具不存在: {tool_id}",
error_code="TOOL_NOT_FOUND",
execution_time=0.0
)
# 记录执行开始
await self._record_execution_start(context, parameters)
# 执行工具
result = await self._execute_with_timeout(tool, parameters, context)
# 记录执行完成
await self._record_execution_complete(context, result)
return result
except Exception as e:
logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
# 记录执行失败
error_result = ToolResult.error_result(
error=str(e),
error_code="EXECUTION_ERROR",
execution_time=time.time() - context.started_at.timestamp()
)
await self._record_execution_complete(context, error_result)
return error_result
finally:
# 清理执行上下文
async with self._execution_lock:
if execution_id in self._running_executions:
del self._running_executions[execution_id]
async def execute_tools_batch(
self,
tool_executions: List[Dict[str, Any]],
max_concurrency: int = 5
) -> List[ToolResult]:
"""批量执行工具
Args:
tool_executions: 工具执行配置列表每个包含tool_id和parameters
max_concurrency: 最大并发数
Returns:
执行结果列表
"""
semaphore = asyncio.Semaphore(max_concurrency)
async def execute_single(exec_config: Dict[str, Any]) -> ToolResult:
async with semaphore:
return await self.execute_tool(
tool_id=exec_config["tool_id"],
parameters=exec_config.get("parameters", {}),
user_id=exec_config.get("user_id"),
workspace_id=exec_config.get("workspace_id"),
timeout=exec_config.get("timeout"),
metadata=exec_config.get("metadata")
)
# 并发执行所有工具
tasks = [execute_single(config) for config in tool_executions]
results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理异常结果
processed_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
processed_results.append(
ToolResult.error_result(
error=str(result),
error_code="BATCH_EXECUTION_ERROR",
execution_time=0.0
)
)
else:
processed_results.append(result)
return processed_results
async def cancel_execution(self, execution_id: str) -> bool:
"""取消工具执行
Args:
execution_id: 执行ID
Returns:
是否成功取消
"""
async with self._execution_lock:
if execution_id not in self._running_executions:
return False
context = self._running_executions[execution_id]
context.status = ExecutionStatus.FAILED
# 更新数据库记录
execution_record = self.db.query(ToolExecution).filter(
ToolExecution.execution_id == execution_id
).first()
if execution_record:
execution_record.status = ExecutionStatus.FAILED.value
execution_record.error_message = "执行被取消"
execution_record.completed_at = datetime.now()
self.db.commit()
logger.info(f"工具执行已取消: {execution_id}")
return True
def get_running_executions(self) -> List[Dict[str, Any]]:
"""获取正在运行的执行列表
Returns:
执行信息列表
"""
executions = []
for execution_id, context in self._running_executions.items():
executions.append({
"execution_id": execution_id,
"tool_id": context.tool_id,
"user_id": str(context.user_id) if context.user_id else None,
"workspace_id": str(context.workspace_id) if context.workspace_id else None,
"started_at": context.started_at.isoformat(),
"status": context.status.value,
"elapsed_time": (datetime.now() - context.started_at).total_seconds()
})
return executions
async def _execute_with_timeout(
self,
tool: BaseTool,
parameters: Dict[str, Any],
context: ExecutionContext
) -> ToolResult:
"""带超时的工具执行
Args:
tool: 工具实例
parameters: 参数
context: 执行上下文
Returns:
执行结果
"""
async with self._execution_lock:
self._running_executions[context.execution_id] = context
context.status = ExecutionStatus.RUNNING
try:
# 使用asyncio.wait_for实现超时控制
result = await asyncio.wait_for(
tool.safe_execute(**parameters),
timeout=context.timeout
)
context.status = ExecutionStatus.COMPLETED
return result
except asyncio.TimeoutError:
context.status = ExecutionStatus.TIMEOUT
return ToolResult.error_result(
error=f"工具执行超时({context.timeout}秒)",
error_code="EXECUTION_TIMEOUT",
execution_time=context.timeout
)
except Exception as e:
context.status = ExecutionStatus.FAILED
raise
async def _record_execution_start(
self,
context: ExecutionContext,
parameters: Dict[str, Any]
):
"""记录执行开始"""
try:
execution_record = ToolExecution(
execution_id=context.execution_id,
tool_config_id=uuid.UUID(context.tool_id),
status=ExecutionStatus.RUNNING.value,
input_data=parameters,
started_at=context.started_at,
user_id=context.user_id,
workspace_id=context.workspace_id
)
self.db.add(execution_record)
self.db.commit()
logger.debug(f"执行记录已创建: {context.execution_id}")
except Exception as e:
logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
async def _record_execution_complete(
self,
context: ExecutionContext,
result: ToolResult
):
"""记录执行完成"""
try:
context.completed_at = datetime.now()
execution_record = self.db.query(ToolExecution).filter(
ToolExecution.execution_id == context.execution_id
).first()
if execution_record:
execution_record.status = (
ExecutionStatus.COMPLETED.value if result.success
else ExecutionStatus.FAILED.value
)
execution_record.output_data = result.data if result.success else None
execution_record.error_message = result.error if not result.success else None
execution_record.completed_at = context.completed_at
execution_record.execution_time = result.execution_time
execution_record.token_usage = result.token_usage
self.db.commit()
logger.debug(f"执行记录已更新: {context.execution_id}")
except Exception as e:
logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
def get_execution_history(
self,
tool_id: Optional[str] = None,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None,
limit: int = 100
) -> List[Dict[str, Any]]:
"""获取执行历史
Args:
tool_id: 工具ID过滤
user_id: 用户ID过滤
workspace_id: 工作空间ID过滤
limit: 返回数量限制
Returns:
执行历史列表
"""
try:
query = self.db.query(ToolExecution).order_by(
ToolExecution.started_at.desc()
)
if tool_id:
query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
if user_id:
query = query.filter(ToolExecution.user_id == user_id)
if workspace_id:
query = query.filter(ToolExecution.workspace_id == workspace_id)
executions = query.limit(limit).all()
history = []
for execution in executions:
history.append({
"execution_id": execution.execution_id,
"tool_id": str(execution.tool_config_id),
"status": execution.status,
"started_at": execution.started_at.isoformat() if execution.started_at else None,
"completed_at": execution.completed_at.isoformat() if execution.completed_at else None,
"execution_time": execution.execution_time,
"user_id": str(execution.user_id) if execution.user_id else None,
"workspace_id": str(execution.workspace_id) if execution.workspace_id else None,
"input_data": execution.input_data,
"output_data": execution.output_data,
"error_message": execution.error_message,
"token_usage": execution.token_usage
})
return history
except Exception as e:
logger.error(f"获取执行历史失败, 错误: {e}")
return []
def get_execution_statistics(
self,
workspace_id: Optional[uuid.UUID] = None,
days: int = 7
) -> Dict[str, Any]:
"""获取执行统计信息
Args:
workspace_id: 工作空间ID
days: 统计天数
Returns:
统计信息
"""
try:
from datetime import timedelta
start_date = datetime.now() - timedelta(days=days)
query = self.db.query(ToolExecution).filter(
ToolExecution.started_at >= start_date
)
if workspace_id:
query = query.filter(ToolExecution.workspace_id == workspace_id)
executions = query.all()
# 统计数据
total_executions = len(executions)
successful_executions = len([e for e in executions if e.status == ExecutionStatus.COMPLETED.value])
failed_executions = len([e for e in executions if e.status == ExecutionStatus.FAILED.value])
# 平均执行时间
completed_executions = [e for e in executions if e.execution_time is not None]
avg_execution_time = (
sum(e.execution_time for e in completed_executions) / len(completed_executions)
if completed_executions else 0
)
# 按工具统计
tool_stats = {}
for execution in executions:
tool_id = str(execution.tool_config_id)
if tool_id not in tool_stats:
tool_stats[tool_id] = {"total": 0, "successful": 0, "failed": 0}
tool_stats[tool_id]["total"] += 1
if execution.status == ExecutionStatus.COMPLETED.value:
tool_stats[tool_id]["successful"] += 1
elif execution.status == ExecutionStatus.FAILED.value:
tool_stats[tool_id]["failed"] += 1
return {
"period_days": days,
"total_executions": total_executions,
"successful_executions": successful_executions,
"failed_executions": failed_executions,
"success_rate": successful_executions / total_executions if total_executions > 0 else 0,
"average_execution_time": avg_execution_time,
"tool_statistics": tool_stats
}
except Exception as e:
logger.error(f"获取执行统计失败, 错误: {e}")
return {}
async def test_tool_connection(
self,
tool_id: str,
user_id: Optional[uuid.UUID] = None,
workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
"""测试工具连接"""
try:
from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
from .mcp.client import MCPClient
tool_config = self.db.query(ToolConfig).filter(
ToolConfig.id == uuid.UUID(tool_id)
).first()
if not tool_config:
return {"success": False, "message": "工具不存在"}
if tool_config.tool_type == ToolType.MCP.value:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if not mcp_config:
return {"success": False, "message": "MCP配置不存在"}
client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
if await client.connect():
try:
tools = await client.list_tools()
await client.disconnect()
return {
"success": True,
"message": "MCP连接成功",
"details": {"server_url": mcp_config.server_url, "tools": len(tools)}
}
except:
await client.disconnect()
return {"success": False, "message": "MCP功能测试失败"}
else:
return {"success": False, "message": "MCP连接失败"}
else:
tool = self.registry.get_tool(tool_id)
if tool and hasattr(tool, 'test_connection'):
result = tool.test_connection()
return {"success": result.get("success", False), "message": result.get("message", "")}
return {"success": True, "message": "工具无需连接测试"}
except Exception as e:
return {"success": False, "message": "测试失败", "error": str(e)}

View File

@@ -4,7 +4,8 @@ from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.tools.base import BaseTool
from app.schemas.tool_schema import ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
@@ -123,33 +124,43 @@ class MCPTool(BaseTool):
async def connect(self) -> bool:
"""连接到MCP服务器"""
try:
# 这里应该实现实际的MCP连接逻辑
# 为了简化,这里只是模拟连接
from .client import MCPClient
# 测试服务器连接
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
# 尝试获取服务器信息
async with session.get(f"{self.server_url}/info") as response:
if response.status == 200:
server_info = await response.json()
self.available_tools = server_info.get("tools", [])
self._connected = True
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
raise Exception(f"服务器响应错误: {response.status}")
if self._connected:
return True
self._client = MCPClient(self.server_url, self.connection_config)
if await self._client.connect():
self._connected = True
# 更新可用工具列表
await self._update_available_tools()
logger.info(f"MCP服务器连接成功: {self.server_url}")
return True
else:
logger.error(f"MCP服务器连接失败: {self.server_url}")
return False
except Exception as e:
logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
logger.error(f"MCP服务器连接异常: {self.server_url}, 错误: {e}")
self._connected = False
return False
async def _update_available_tools(self):
"""更新可用工具列表"""
try:
if self._client and self._connected:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
logger.info(f"MCP工具列表已更新: {len(self.available_tools)} 个工具")
except Exception as e:
logger.error(f"更新MCP工具列表失败: {e}")
async def disconnect(self) -> bool:
"""断开MCP服务器连接"""
try:
if self._client:
# 这里应该实现实际的断开逻辑
await self._client.disconnect()
self._client = None
self._connected = False
@@ -171,38 +182,15 @@ class MCPTool(BaseTool):
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
"""调用MCP工具"""
# 构建MCP请求
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/call",
"params": {
"name": tool_name,
"arguments": arguments
}
}
if not self._client or not self._connected:
raise Exception("MCP客户端未连接")
# 发送请求
client_timeout = aiohttp.ClientTimeout(total=timeout)
async with aiohttp.ClientSession(timeout=client_timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status != 200:
error_text = await response.text()
raise Exception(f"MCP请求失败 {response.status}: {error_text}")
result = await response.json()
# 检查MCP响应
if "error" in result:
error = result["error"]
raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
return result.get("result", {})
try:
result = await self._client.call_tool(tool_name, arguments, timeout)
return result
except Exception as e:
logger.error(f"MCP工具调用失败: {tool_name}, 错误: {e}")
raise
async def list_available_tools(self) -> List[Dict[str, Any]]:
"""列出可用的MCP工具"""
@@ -210,27 +198,10 @@ class MCPTool(BaseTool):
if not self._connected:
await self.connect()
# 获取工具列表
request_data = {
"jsonrpc": "2.0",
"id": f"req_{int(time.time() * 1000)}",
"method": "tools/list"
}
timeout = aiohttp.ClientTimeout(total=10)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
f"{self.server_url}/mcp",
json=request_data,
headers={"Content-Type": "application/json"}
) as response:
if response.status == 200:
result = await response.json()
if "result" in result:
tools = result["result"].get("tools", [])
self.available_tools = [tool.get("name") for tool in tools]
return tools
if self._client:
tools = await self._client.list_tools()
self.available_tools = [tool.get("name") for tool in tools if tool.get("name")]
return tools
return []

View File

@@ -134,11 +134,40 @@ class MCPClient:
logger.error(f"断开MCP服务器连接失败: {e}")
return False
def _build_auth_headers(self) -> Dict[str, str]:
"""构建认证头"""
headers = {}
auth_type = self.connection_config.get("auth_type", "none")
auth_config = self.connection_config.get("auth_config", {})
if auth_type == "api_key":
api_key = auth_config.get("api_key")
key_name = auth_config.get("key_name", "X-API-Key")
if api_key:
headers[key_name] = api_key
elif auth_type == "bearer_token":
token = auth_config.get("token")
if token:
headers["Authorization"] = f"Bearer {token}"
elif auth_type == "basic_auth":
username = auth_config.get("username")
password = auth_config.get("password")
if username and password:
import base64
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
headers["Authorization"] = f"Basic {credentials}"
return headers
async def _connect_websocket(self) -> bool:
"""建立WebSocket连接"""
try:
# WebSocket连接配置
extra_headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
extra_headers.update(auth_headers)
self._websocket = await websockets.connect(
self.server_url,
@@ -190,6 +219,8 @@ class MCPClient:
# HTTP会话配置
timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
headers = self.connection_config.get("headers", {})
auth_headers = self._build_auth_headers()
headers.update(auth_headers)
self._session = aiohttp.ClientSession(
timeout=timeout,
@@ -251,8 +282,9 @@ class MCPClient:
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
@staticmethod
async def _handle_notification(message: Dict[str, Any]):
"""处理通知消息"""
method = message.get("method")
params = message.get("params", {})
@@ -327,7 +359,7 @@ class MCPClient:
try:
response = await self._send_request(request_data, timeout)
if not response["error"] is None:
if response.get("error", None) is not None:
error = response["error"]
raise MCPProtocolError(f"获取工具列表失败: {error.get('message', '未知错误')}")
@@ -372,10 +404,10 @@ class MCPClient:
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
await self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
@@ -424,9 +456,9 @@ class MCPClient:
start_time = time.time()
response = await self._send_request(request_data, timeout=5)
response_time = time.time() - start_time
response_time = round((time.time() - start_time) * 1000)
self._last_health_check = time.time()
self._last_health_check = round(time.time() * 1000)
return {
"healthy": True,

View File

@@ -6,7 +6,7 @@ from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType, ToolStatus
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
@@ -148,7 +148,7 @@ class MCPServiceManager:
connection_config=connection_config,
available_tools=tool_names,
health_status="healthy",
last_health_check=datetime.utcnow()
last_health_check=datetime.now()
)
self.db.add(mcp_config)
@@ -410,7 +410,8 @@ class MCPServiceManager:
"""加载现有服务"""
try:
mcp_configs = self.db.query(MCPToolConfig).join(ToolConfig).filter(
ToolConfig.is_enabled == True
ToolConfig.status == ToolStatus.AVAILABLE.value,
ToolConfig.tool_type == ToolType.MCP.value
).all()
for mcp_config in mcp_configs:
@@ -531,7 +532,7 @@ class MCPServiceManager:
if mcp_config:
mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
mcp_config.last_health_check = datetime.utcnow()
mcp_config.last_health_check = datetime.now()
if not health_status["healthy"]:
mcp_config.error_message = health_status.get("error", "")

View File

@@ -1,436 +0,0 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
"""工具注册表 - 管理所有工具的元数据和实例"""
def __init__(self, db: Session):
"""初始化工具注册表
Args:
db: 数据库会话
"""
self.db = db
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
self._lock = asyncio.Lock() # 异步锁
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
"""注册工具类
Args:
tool_class: 工具类
class_name: 类名可选默认使用类的__name__
"""
class_name = class_name or tool_class.__name__
self._tool_classes[class_name] = tool_class
logger.info(f"工具类已注册: {class_name}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
"""注册工具实例到系统
Args:
tool: 工具实例
tenant_id: 租户ID内置工具可以为None表示全局工具
Returns:
注册是否成功
"""
async with self._lock:
try:
# 检查工具是否已存在
if tenant_id:
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == tool.tool_type.value
)
).first()
else:
# 全局工具(内置工具)
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id.is_(None),
ToolConfig.tool_type == tool.tool_type.value
)
).first()
if existing_config:
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
return False
# 创建工具配置
tool_config = ToolConfig(
name=tool.name,
description=tool.description,
tool_type=tool.tool_type.value,
tenant_id=tenant_id,
version=tool.version,
tags=tool.tags,
config_data=tool.config
)
self.db.add(tool_config)
self.db.flush() # 获取ID
# 根据工具类型创建特定配置
if tool.tool_type == ToolType.BUILTIN:
builtin_config = BuiltinToolConfig(
id=tool_config.id,
tool_class=tool.__class__.__name__,
parameters=tool.config.get("parameters", {})
)
self.db.add(builtin_config)
elif tool.tool_type == ToolType.CUSTOM:
custom_config = CustomToolConfig(
id=tool_config.id,
schema_url=tool.config.get("schema_url"),
schema_content=tool.config.get("schema_content"),
auth_type=tool.config.get("auth_type", "none"),
auth_config=tool.config.get("auth_config", {}),
base_url=tool.config.get("base_url"),
timeout=tool.config.get("timeout", 30)
)
self.db.add(custom_config)
elif tool.tool_type == ToolType.MCP:
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=tool.config.get("server_url"),
connection_config=tool.config.get("connection_config", {}),
available_tools=tool.config.get("available_tools", [])
)
self.db.add(mcp_config)
self.db.commit()
# 缓存工具实例
tool.tool_id = str(tool_config.id)
self._tools[str(tool_config.id)] = tool
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
return False
async def unregister_tool(self, tool_id: str) -> bool:
"""从系统注销工具
Args:
tool_id: 工具ID
Returns:
注销是否成功
"""
async with self._lock:
try:
# 检查工具是否存在
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 检查是否有正在执行的任务
running_executions = self.db.query(ToolExecution).filter(
and_(
ToolExecution.tool_config_id == uuid.UUID(tool_id),
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
)
).count()
if running_executions > 0:
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
return False
# 删除工具配置(级联删除相关记录)
self.db.delete(tool_config)
self.db.commit()
# 从缓存中移除
if tool_id in self._tools:
del self._tools[tool_id]
logger.info(f"工具注销成功: {tool_id}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
"""获取工具实例
Args:
tool_id: 工具ID
Returns:
工具实例如果不存在返回None
"""
# 先从缓存获取
if tool_id in self._tools:
return self._tools[tool_id]
# 从数据库加载
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
return None
# 根据工具类型加载实例
tool_instance = self._load_tool_instance(tool_config)
if tool_instance:
self._tools[tool_id] = tool_instance
return tool_instance
except Exception as e:
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
return None
def list_tools(
self,
tenant_id: Optional[uuid.UUID] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
tags: Optional[List[str]] = None
) -> List[ToolInfo]:
"""列出工具
Args:
tenant_id: 租户ID过滤
tool_type: 工具类型过滤
status: 工具状态过滤
tags: 标签过滤
Returns:
工具信息列表
"""
try:
query = self.db.query(ToolConfig)
# 应用过滤条件
if tenant_id:
# 返回全局工具tenant_id为空和该租户的工具
query = query.filter(
or_(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tenant_id.is_(None)
)
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status == ToolStatus.ACTIVE:
query = query.filter(ToolConfig.is_enabled == True)
elif status == ToolStatus.INACTIVE:
query = query.filter(ToolConfig.is_enabled == False)
if tags:
for tag in tags:
query = query.filter(ToolConfig.tags.contains([tag]))
tool_configs = query.all()
# 转换为ToolInfo
tool_infos = []
for config in tool_configs:
tool_info = ToolInfo(
id=str(config.id),
name=config.name,
description=config.description or "",
tool_type=ToolType(config.tool_type),
version=config.version,
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None
)
# 尝试获取参数信息
tool_instance = self.get_tool(str(config.id))
if tool_instance:
tool_info.parameters = tool_instance.parameters
tool_infos.append(tool_info)
return tool_infos
except Exception as e:
logger.error(f"列出工具失败, 错误: {e}")
return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
"""更新工具状态
Args:
tool_id: 工具ID
status: 新状态
Returns:
更新是否成功
"""
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 更新状态
if status == ToolStatus.ACTIVE:
tool_config.is_enabled = True
elif status == ToolStatus.INACTIVE:
tool_config.is_enabled = False
self.db.commit()
# 更新缓存中的工具状态
if tool_id in self._tools:
self._tools[tool_id].status = status
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
return False
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
"""从配置加载工具实例
Args:
tool_config: 工具配置
Returns:
工具实例
"""
try:
if tool_config.tool_type == ToolType.BUILTIN.value:
# 加载内置工具
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config and builtin_config.tool_class in self._tool_classes:
tool_class = self._tool_classes[builtin_config.tool_class]
config = {
**tool_config.config_data,
"parameters": builtin_config.parameters,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return tool_class(str(tool_config.id), config)
elif tool_config.tool_type == ToolType.CUSTOM.value:
# 加载自定义工具
try:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config:
config = {
**tool_config.config_data,
"schema_url": custom_config.schema_url,
"schema_content": custom_config.schema_content,
"auth_type": custom_config.auth_type,
"auth_config": custom_config.auth_config,
"base_url": custom_config.base_url,
"timeout": custom_config.timeout,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return CustomTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入自定义工具模块: {e}")
elif tool_config.tool_type == ToolType.MCP.value:
# 加载MCP工具
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
config = {
**tool_config.config_data,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config,
"available_tools": mcp_config.available_tools,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return MCPTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入MCP工具模块: {e}")
except Exception as e:
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""获取工具统计信息
Args:
tenant_id: 租户ID
Returns:
统计信息字典
"""
try:
query = self.db.query(ToolConfig)
if tenant_id:
query = query.filter(ToolConfig.tenant_id == tenant_id)
total_tools = query.count()
active_tools = query.filter(ToolConfig.is_enabled == True).count()
# 按类型统计
type_stats = {}
for tool_type in ToolType:
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
type_stats[tool_type.value] = count
return {
"total_tools": total_tools,
"active_tools": active_tools,
"inactive_tools": total_tools - active_tools,
"by_type": type_stats
}
except Exception as e:
logger.error(f"获取工具统计失败, 错误: {e}")
return {}
def clear_cache(self):
"""清空工具缓存"""
self._tools.clear()
logger.info("工具缓存已清空")