feat(apikey system): tool system development

This commit is contained in:
谢俊男
2025-12-20 15:24:28 +08:00
parent 3fbd4f206e
commit c26af11f76
39 changed files with 9338 additions and 4 deletions

View File

@@ -0,0 +1,37 @@
"""工具管理核心模块"""
from .base import BaseTool, ToolResult, ToolParameter
from .registry import ToolRegistry
from .executor import ToolExecutor
from .langchain_adapter import LangchainAdapter
from .config_manager import ConfigManager
from .chain_manager import ChainManager
# Optional imports: fall back to None so that a missing optional subpackage
# does not break importing the core tools package.
try:
    from .custom.base import CustomTool
except ImportError:
    CustomTool = None

try:
    from .mcp.base import MCPTool
except ImportError:
    MCPTool = None

__all__ = [
    "BaseTool",
    "ToolResult",
    "ToolParameter",
    "ToolRegistry",
    "ToolExecutor",
    "LangchainAdapter",
    "ConfigManager",
    "ChainManager",
]

# Export the optional classes only when their import succeeded.
if CustomTool:
    __all__ += ["CustomTool"]
if MCPTool:
    __all__ += ["MCPTool"]

302
api/app/core/tools/base.py Normal file
View File

@@ -0,0 +1,302 @@
"""工具基础接口定义"""
import time
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from enum import Enum
from app.models.tool_model import ToolType, ToolStatus
class ParameterType(str, Enum):
    """Value types a tool parameter may declare.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (``ParameterType.STRING == "string"``).
    """

    # Scalar types
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"

    # Container types
    ARRAY = "array"
    OBJECT = "object"
class ToolParameter(BaseModel):
    """Declarative description of one parameter accepted by a tool.

    Holds the parameter's name, declared type, whether it is required, its
    default value and the optional constraints (enum, numeric range, regex
    pattern) that ``BaseTool`` checks during validation.
    """

    # Identity and documentation
    name: str = Field(..., description="参数名称")
    type: ParameterType = Field(..., description="参数类型")
    description: str = Field(default="", description="参数描述")

    # Presence and default
    required: bool = Field(default=False, description="是否必需")
    default: Any = Field(default=None, description="默认值")

    # Optional value constraints
    enum: Optional[List[Any]] = Field(default=None, description="枚举值")
    minimum: Optional[Union[int, float]] = Field(default=None, description="最小值")
    maximum: Optional[Union[int, float]] = Field(default=None, description="最大值")
    pattern: Optional[str] = Field(default=None, description="正则表达式模式")

    class Config:
        # Store enum fields as their plain values rather than enum members.
        use_enum_values = True
class ToolResult(BaseModel):
    """Outcome of a single tool execution.

    Use ``success_result`` / ``error_result`` to build instances instead of
    filling the fields by hand.
    """

    success: bool = Field(..., description="执行是否成功")
    data: Any = Field(default=None, description="返回数据")
    error: Optional[str] = Field(default=None, description="错误信息")
    error_code: Optional[str] = Field(default=None, description="错误代码")
    execution_time: float = Field(..., description="执行时间(秒)")
    token_usage: Optional[Dict[str, int]] = Field(default=None, description="Token使用情况")
    metadata: Dict[str, Any] = Field(default_factory=dict, description="额外元数据")

    @classmethod
    def success_result(
        cls,
        data: Any,
        execution_time: float,
        token_usage: Optional[Dict[str, int]] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result describing a successful execution.

        Args:
            data: Payload returned by the tool.
            execution_time: Elapsed time in seconds.
            token_usage: Optional token accounting.
            metadata: Optional extra metadata; an empty dict is used when
                omitted or falsy.

        Returns:
            A ``ToolResult`` with ``success=True``.
        """
        fields = {
            "success": True,
            "data": data,
            "execution_time": execution_time,
            "token_usage": token_usage,
            "metadata": metadata or {},
        }
        return cls(**fields)

    @classmethod
    def error_result(
        cls,
        error: str,
        execution_time: float,
        error_code: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> "ToolResult":
        """Build a result describing a failed execution.

        Args:
            error: Human-readable error message.
            execution_time: Elapsed time in seconds.
            error_code: Optional machine-readable error code.
            metadata: Optional extra metadata; an empty dict is used when
                omitted or falsy.

        Returns:
            A ``ToolResult`` with ``success=False``.
        """
        fields = {
            "success": False,
            "error": error,
            "error_code": error_code,
            "execution_time": execution_time,
            "metadata": metadata or {},
        }
        return cls(**fields)
class ToolInfo(BaseModel):
    """Metadata snapshot describing a tool (as produced by ``BaseTool.get_info``)."""

    # Identity
    id: str = Field(..., description="工具ID")
    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    tool_type: ToolType = Field(..., description="工具类型")
    version: str = Field(default="1.0.0", description="工具版本")

    # Interface and state
    parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
    status: ToolStatus = Field(default=ToolStatus.ACTIVE, description="工具状态")
    tags: List[str] = Field(default_factory=list, description="工具标签")
    tenant_id: Optional[str] = Field(default=None, description="租户ID")

    class Config:
        # Store enum fields as their plain values rather than enum members.
        use_enum_values = True
class BaseTool(ABC):
    """Abstract base class for all tools.

    Subclasses implement ``name``, ``description``, ``tool_type``,
    ``parameters`` and ``execute``. The base class contributes parameter
    validation, the exception-safe ``safe_execute`` wrapper and the
    ``get_info`` metadata snapshot.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the tool.

        Args:
            tool_id: Identifier of the tool.
            config: Tool configuration mapping.
        """
        self.tool_id = tool_id
        self.config = config
        self._status = ToolStatus.ACTIVE

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description."""
        pass

    @property
    @abstractmethod
    def tool_type(self) -> ToolType:
        """Tool type."""
        pass

    @property
    def version(self) -> str:
        """Tool version, read from ``config["version"]`` (default ``"1.0.0"``)."""
        return self.config.get("version", "1.0.0")

    @property
    def status(self) -> ToolStatus:
        """Current tool status."""
        return self._status

    @status.setter
    def status(self, value: ToolStatus):
        """Set the tool status."""
        self._status = value

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Definitions of the parameters the tool accepts."""
        pass

    @property
    def tags(self) -> List[str]:
        """Tool tags, read from ``config["tags"]`` (default empty list)."""
        return self.config.get("tags", [])

    def get_info(self) -> ToolInfo:
        """Return a ``ToolInfo`` snapshot built from the tool's properties."""
        return ToolInfo(
            id=self.tool_id,
            name=self.name,
            description=self.description,
            tool_type=self.tool_type,
            version=self.version,
            parameters=self.parameters,
            status=self.status,
            tags=self.tags,
            tenant_id=self.config.get("tenant_id")
        )

    def validate_parameters(self, parameters: Dict[str, Any]) -> Dict[str, str]:
        """Validate call arguments against the declared parameter definitions.

        Args:
            parameters: Arguments supplied by the caller, keyed by name.

        Returns:
            Mapping of parameter name to error message; an empty dict means
            validation passed. Arguments without a matching definition are
            ignored.
        """
        errors: Dict[str, str] = {}
        definitions = self.parameters
        param_definitions = {p.name: p for p in definitions}

        # Required parameters must be present.
        for param_def in definitions:
            if param_def.required and param_def.name not in parameters:
                errors[param_def.name] = f"Required parameter '{param_def.name}' is missing"

        # Type and constraint checks for every supplied, known parameter.
        for param_name, param_value in parameters.items():
            param_def = param_definitions.get(param_name)
            if param_def is None:
                continue

            if not self._validate_parameter_type(param_value, param_def):
                errors[param_name] = f"Parameter '{param_name}' has invalid type, expected {param_def.type}"
                # Constraint checks assume a correctly typed value: comparing
                # e.g. a str against ``minimum`` would raise TypeError, and a
                # constraint message would hide the more useful type message.
                continue

            constraint_error = self._validate_parameter_constraints(param_value, param_def)
            if constraint_error:
                errors[param_name] = constraint_error

        return errors

    def _validate_parameter_type(self, value: Any, param_def: ToolParameter) -> bool:
        """Return True when ``value`` matches the declared parameter type.

        ``None`` is accepted only for non-required parameters. Booleans are
        rejected for integer/number parameters.
        """
        if value is None:
            return not param_def.required

        # bool is a subclass of int, so a plain isinstance check would accept
        # True/False wherever an integer or number is declared.
        if isinstance(value, bool) and param_def.type in (ParameterType.INTEGER, ParameterType.NUMBER):
            return False

        type_mapping = {
            ParameterType.STRING: str,
            ParameterType.INTEGER: int,
            ParameterType.NUMBER: (int, float),
            ParameterType.BOOLEAN: bool,
            ParameterType.ARRAY: list,
            ParameterType.OBJECT: dict
        }
        expected_type = type_mapping.get(param_def.type)
        if expected_type:
            return isinstance(value, expected_type)
        # Unknown declared types are not restricted.
        return True

    def _validate_parameter_constraints(self, value: Any, param_def: ToolParameter) -> Optional[str]:
        """Check enum, numeric range and regex constraints.

        Returns:
            An error message for the first violated constraint, or ``None``
            when the value satisfies all of them (or is ``None``).
        """
        if value is None:
            return None

        # Enum membership
        if param_def.enum and value not in param_def.enum:
            return f"Value must be one of {param_def.enum}"

        # Numeric range
        if param_def.type in [ParameterType.INTEGER, ParameterType.NUMBER]:
            if param_def.minimum is not None and value < param_def.minimum:
                return f"Value must be >= {param_def.minimum}"
            if param_def.maximum is not None and value > param_def.maximum:
                return f"Value must be <= {param_def.maximum}"

        # String pattern; re.match anchors only at the start of the value.
        if param_def.type == ParameterType.STRING and param_def.pattern:
            import re
            if not re.match(param_def.pattern, str(value)):
                return f"Value must match pattern: {param_def.pattern}"

        return None

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Run the tool with parameter validation and exception handling.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The result of ``execute``; a ``VALIDATION_ERROR`` result when the
            arguments are invalid; an ``EXECUTION_ERROR`` result when
            validation or ``execute`` raises an ``Exception``.
        """
        start_time = time.time()
        try:
            validation_errors = self.validate_parameters(kwargs)
            if validation_errors:
                execution_time = time.time() - start_time
                error_msg = "; ".join([f"{k}: {v}" for k, v in validation_errors.items()])
                return ToolResult.error_result(
                    error=f"Parameter validation failed: {error_msg}",
                    error_code="VALIDATION_ERROR",
                    execution_time=execution_time
                )

            return await self.execute(**kwargs)
        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="EXECUTION_ERROR",
                execution_time=execution_time
            )

    def to_langchain_tool(self):
        """Convert this tool via ``LangchainAdapter.convert_tool``."""
        # Imported lazily inside the method rather than at module import time.
        from .langchain_adapter import LangchainAdapter
        return LangchainAdapter.convert_tool(self)

    def __repr__(self):
        return f"<{self.__class__.__name__}(id={self.tool_id}, name={self.name})>"

View File

@@ -0,0 +1,17 @@
"""内置工具模块"""
from .base import BuiltinTool
from .datetime_tool import DateTimeTool
from .json_tool import JsonTool
from .baidu_search_tool import BaiduSearchTool
from .mineru_tool import MinerUTool
from .textin_tool import TextInTool
__all__ = [
"BuiltinTool",
"DateTimeTool",
"JsonTool",
"BaiduSearchTool",
"MinerUTool",
"TextInTool"
]

View File

@@ -0,0 +1,334 @@
"""百度搜索工具 - 搜索引擎服务"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class BaiduSearchTool(BuiltinTool):
    """Built-in search tool backed by Baidu's AI search HTTP API.

    Supports the ``web``, ``news``, ``image`` and ``video`` search types. Each
    search is one POST request authenticated with the ``api_key``
    configuration parameter. ``page_num``, ``safe_search`` and ``region`` are
    accepted but are not forwarded to the API; ``page_num`` is only echoed in
    the result.
    """

    _API_URL = "https://qianfan.baidubce.com/v2/ai_search/chat/completions"
    _REQUEST_TIMEOUT_SECONDS = 30

    # Start of the "page_time" range sent for each recognised time filter.
    _TIME_FILTER_RANGES = {
        "day": "now-1d/d",
        "week": "now-1w/d",
        "month": "now-1M/d",
        "year": "now-1y/d",
    }

    @property
    def name(self) -> str:
        """Tool name."""
        return "baidu_search_tool"

    @property
    def description(self) -> str:
        """Tool description."""
        return "百度搜索 - 搜索引擎服务:网页搜索、新闻搜索、图片搜索、实时结果"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration parameters that must be set: ``api_key``."""
        return ["api_key"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Definitions of the arguments accepted by ``execute``."""
        return [
            ToolParameter(
                name="query",
                type=ParameterType.STRING,
                description="搜索关键词",
                required=True
            ),
            ToolParameter(
                name="search_type",
                type=ParameterType.STRING,
                description="搜索类型",
                required=False,
                default="web",
                enum=["web", "news", "image", "video"]
            ),
            ToolParameter(
                name="page_size",
                type=ParameterType.INTEGER,
                description="每页结果数",
                required=False,
                default=10,
                minimum=1,
                maximum=50
            ),
            ToolParameter(
                name="page_num",
                type=ParameterType.INTEGER,
                description="页码从1开始",
                required=False,
                default=1,
                minimum=1,
                maximum=10
            ),
            ToolParameter(
                name="safe_search",
                type=ParameterType.BOOLEAN,
                description="是否启用安全搜索",
                required=False,
                default=True
            ),
            ToolParameter(
                name="region",
                type=ParameterType.STRING,
                description="搜索地区",
                required=False,
                default="cn",
                enum=["cn", "hk", "tw", "us", "jp", "kr"]
            ),
            ToolParameter(
                name="time_filter",
                type=ParameterType.STRING,
                description="时间过滤",
                required=False,
                enum=["all", "day", "week", "month", "year"]
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run a search and wrap the outcome in a ``ToolResult``.

        Args:
            **kwargs: ``query`` (required) plus the optional ``search_type``,
                ``page_size``, ``page_num``, ``safe_search``, ``region`` and
                ``time_filter`` arguments.

        Returns:
            A success result carrying the search data, or an error result
            with code ``BAIDU_SEARCH_ERROR`` when anything raises.
        """
        start_time = time.time()
        try:
            query = kwargs.get("query")
            search_type = kwargs.get("search_type", "web")
            page_size = kwargs.get("page_size", 10)
            page_num = kwargs.get("page_num", 1)
            safe_search = kwargs.get("safe_search", True)
            region = kwargs.get("region", "cn")
            time_filter = kwargs.get("time_filter")

            if not query:
                raise ValueError("query 参数是必需的")

            if search_type == "web":
                result = await self._web_search(query, page_size, page_num, safe_search, region, time_filter)
            elif search_type == "news":
                result = await self._news_search(query, page_size, page_num, region, time_filter)
            elif search_type == "image":
                result = await self._image_search(query, page_size, page_num, safe_search)
            elif search_type == "video":
                result = await self._video_search(query, page_size, page_num, safe_search)
            else:
                raise ValueError(f"不支持的搜索类型: {search_type}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )
        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="BAIDU_SEARCH_ERROR",
                execution_time=execution_time
            )

    async def _search(
        self,
        *,
        search_type: str,
        resource_type: str,
        max_top_k: int,
        query: str,
        page_size: int,
        page_num: int,
        time_filter: str = None,
    ) -> Dict[str, Any]:
        """Shared implementation behind the four per-type search methods.

        Args:
            search_type: Value echoed as ``search_type`` in the result.
            resource_type: ``type`` sent in ``resource_type_filter``.
            max_top_k: Upper bound applied to ``page_size`` for ``top_k``.
            query: Search keywords.
            page_size: Requested number of results.
            page_num: Page number; echoed in the result only.
            time_filter: Optional time filter name; values without an entry
                in ``_TIME_FILTER_RANGES`` (such as ``"all"``) add no filter.

        Returns:
            Dict with the normalised ``results`` list plus the raw ``answer``
            and ``references`` taken from the API response.
        """
        payload = {
            "messages": [{"role": "user", "content": query}],
            "edition": "standard",
            "search_source": "baidu_search_v2",
            "resource_type_filter": [{"type": resource_type, "top_k": min(page_size, max_top_k)}],
            "enable_full_content": True
        }

        range_start = self._TIME_FILTER_RANGES.get(time_filter) if time_filter else None
        if range_start is not None:
            payload["search_filter"] = {"range": {"page_time": {"gte": range_start, "lt": "now/d"}}}
            payload["search_recency_filter"] = time_filter

        response = await self._call_baidu_ai_search_api(payload)

        search_results = [
            {
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("content", ""),
                "display_url": item.get("url", ""),
                "rank": rank
            }
            for rank, item in enumerate(response.get("references") or [], start=1)
        ]

        return {
            "search_type": search_type,
            "query": query,
            "total_results": len(search_results),
            "page_num": page_num,
            "page_size": page_size,
            "results": search_results,
            "answer": response.get("result", ""),
            "references": response.get("references", [])
        }

    async def _web_search(self, query: str, page_size: int, page_num: int,
                          safe_search: bool, region: str, time_filter: str = None) -> Dict[str, Any]:
        """Web search (``top_k`` capped at 50); ``safe_search`` and ``region`` are unused."""
        return await self._search(
            search_type="web",
            resource_type="web",
            max_top_k=50,
            query=query,
            page_size=page_size,
            page_num=page_num,
            time_filter=time_filter,
        )

    async def _news_search(self, query: str, page_size: int, page_num: int,
                           region: str, time_filter: str = None) -> Dict[str, Any]:
        """News search (``top_k`` capped at 50); ``region`` is unused."""
        # NOTE(review): the resource type sent to the API is "new", kept from
        # the original code - confirm against the Baidu API reference that
        # this is the intended value. The echoed search_type is now "news" so
        # that it matches the requested type like the other search methods.
        return await self._search(
            search_type="news",
            resource_type="new",
            max_top_k=50,
            query=query,
            page_size=page_size,
            page_num=page_num,
            time_filter=time_filter,
        )

    async def _image_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Image search (``top_k`` capped at 30); ``safe_search`` is unused."""
        return await self._search(
            search_type="image",
            resource_type="image",
            max_top_k=30,
            query=query,
            page_size=page_size,
            page_num=page_num,
        )

    async def _video_search(self, query: str, page_size: int, page_num: int,
                            safe_search: bool) -> Dict[str, Any]:
        """Video search (``top_k`` capped at 10); ``safe_search`` is unused."""
        return await self._search(
            search_type="video",
            resource_type="video",
            max_top_k=10,
            query=query,
            page_size=page_size,
            page_num=page_num,
        )

    async def _call_baidu_ai_search_api(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to the Baidu AI search endpoint and return the JSON body.

        Raises:
            ValueError: If the ``api_key`` configuration parameter is not set.
            RuntimeError: If the HTTP status is not 200.
        """
        api_key = self.get_config_parameter("api_key")
        if not api_key:
            raise ValueError("百度搜索API密钥未配置")

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {api_key}'
        }
        timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT_SECONDS)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(self._API_URL, headers=headers, json=payload) as response:
                if response.status != 200:
                    raise RuntimeError(f"HTTP错误: {response.status}")
                return await response.json()

    async def test_connection(self) -> Dict[str, Any]:
        """Check that the configured API key can reach the search API.

        Returns:
            Dict with ``success`` plus either ``message`` and a masked key, or
            ``error``. Never raises for ordinary exceptions.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            if not api_key:
                return {
                    "success": False,
                    "error": "API密钥未配置"
                }

            # Minimal one-result request used only to verify the key.
            test_payload = {
                "messages": [{"role": "user", "content": "test"}],
                "edition": "standard",
                "search_source": "baidu_search_v2",
                "resource_type_filter": [{"type": "web", "top_k": 1}]
            }

            try:
                await self._call_baidu_ai_search_api(test_payload)
                return {
                    "success": True,
                    "message": "连接测试成功",
                    "api_key_masked": api_key[:8] + "***" if len(api_key) > 8 else "***"
                }
            except Exception as e:
                return {
                    "success": False,
                    "error": f"API连接失败: {str(e)}"
                }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,118 @@
"""内置工具基类"""
from abc import ABC, abstractmethod
from typing import Dict, Any, List
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolResult, ToolParameter
class BuiltinTool(BaseTool, ABC):
    """Base class for built-in tools.

    Adds handling of the ``config["parameters"]`` mapping (for example API
    keys) on top of ``BaseTool``, and refuses to execute while required
    configuration parameters are missing.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the built-in tool.

        Args:
            tool_id: Identifier of the tool.
            config: Tool configuration; its ``"parameters"`` entry holds the
                tool-specific configuration parameters.
        """
        super().__init__(tool_id, config)
        # An explicit ``"parameters": None`` entry must behave like a missing
        # one; otherwise every later ``.get`` on it raises AttributeError.
        configured = config.get("parameters")
        self.parameters_config = {} if configured is None else configured

    @property
    def tool_type(self) -> ToolType:
        """Tool type: always ``ToolType.BUILTIN``."""
        return ToolType.BUILTIN

    @property
    @abstractmethod
    def name(self) -> str:
        """Tool name - must be implemented by subclasses."""
        pass

    @property
    @abstractmethod
    def description(self) -> str:
        """Tool description - must be implemented by subclasses."""
        pass

    @property
    @abstractmethod
    def parameters(self) -> List[ToolParameter]:
        """Tool parameter definitions - must be implemented by subclasses."""
        pass

    @abstractmethod
    async def execute(self, **kwargs) -> ToolResult:
        """Run the tool - must be implemented by subclasses.

        Args:
            **kwargs: Tool arguments.

        Returns:
            The execution result.
        """
        pass

    def _missing_config_parameters(self) -> List[str]:
        """Return the required configuration parameters that are unset or falsy."""
        return [
            param
            for param in self.get_required_config_parameters()
            if not self.parameters_config.get(param)
        ]

    @property
    def is_configured(self) -> bool:
        """Whether every required configuration parameter has a truthy value."""
        return not self._missing_config_parameters()

    def get_required_config_parameters(self) -> List[str]:
        """Return the names of the required configuration parameters.

        Returns:
            An empty list by default; subclasses override this.
        """
        return []

    def get_config_parameter(self, name: str, default: Any = None) -> Any:
        """Return one configuration parameter value.

        Args:
            name: Parameter name.
            default: Value returned when the parameter is not set.

        Returns:
            The configured value, or ``default``.
        """
        return self.parameters_config.get(name, default)

    def validate_configuration(self) -> tuple[bool, str]:
        """Validate the tool configuration.

        Returns:
            ``(True, "")`` when configured, otherwise ``(False, message)``
            where the message lists the missing parameters.
        """
        missing_params = self._missing_config_parameters()
        if missing_params:
            return False, f"缺少必需的配置参数: {', '.join(missing_params)}"
        return True, ""

    async def safe_execute(self, **kwargs) -> ToolResult:
        """Run the tool after verifying its configuration.

        Args:
            **kwargs: Tool arguments.

        Returns:
            A ``CONFIGURATION_ERROR`` result when configuration is invalid,
            otherwise the result of ``BaseTool.safe_execute``.
        """
        is_valid, error_msg = self.validate_configuration()
        if not is_valid:
            return ToolResult.error_result(
                error=f"工具配置无效: {error_msg}",
                error_code="CONFIGURATION_ERROR",
                execution_time=0.0
            )

        return await super().safe_execute(**kwargs)

View File

@@ -0,0 +1,307 @@
"""时间工具 - 日期时间处理"""
import time
from datetime import datetime, timezone, timedelta
from typing import List
import pytz
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class DateTimeTool(BuiltinTool):
"""时间工具 - 提供时间格式转换、时区转换、时间戳转换、时间计算功能"""
@property
def name(self) -> str:
    """Tool name: ``datetime_tool``."""
    tool_name = "datetime_tool"
    return tool_name
@property
def description(self) -> str:
    """Human-readable description of the date/time tool."""
    text = "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算"
    return text
@property
def parameters(self) -> List[ToolParameter]:
    """Definitions of the arguments accepted by ``execute``."""
    default_format = "%Y-%m-%d %H:%M:%S"
    default_timezone = "UTC"

    operation = ToolParameter(
        name="operation",
        type=ParameterType.STRING,
        description="操作类型",
        required=True,
        enum=["format", "convert_timezone", "timestamp_to_datetime", "datetime_to_timestamp", "calculate", "now"],
    )
    input_value = ToolParameter(
        name="input_value",
        type=ParameterType.STRING,
        description="输入值(时间字符串或时间戳)",
        required=False,
    )
    input_format = ToolParameter(
        name="input_format",
        type=ParameterType.STRING,
        description="输入时间格式(如:%Y-%m-%d %H:%M:%S",
        required=False,
        default=default_format,
    )
    output_format = ToolParameter(
        name="output_format",
        type=ParameterType.STRING,
        description="输出时间格式(如:%Y-%m-%d %H:%M:%S",
        required=False,
        default=default_format,
    )
    from_timezone = ToolParameter(
        name="from_timezone",
        type=ParameterType.STRING,
        description="源时区UTC, Asia/Shanghai",
        required=False,
        default=default_timezone,
    )
    to_timezone = ToolParameter(
        name="to_timezone",
        type=ParameterType.STRING,
        description="目标时区UTC, Asia/Shanghai",
        required=False,
        default=default_timezone,
    )
    calculation = ToolParameter(
        name="calculation",
        type=ParameterType.STRING,
        description="时间计算表达式(如:+1d, -2h, +30m",
        required=False,
    )

    return [
        operation,
        input_value,
        input_format,
        output_format,
        from_timezone,
        to_timezone,
        calculation,
    ]
async def execute(self, **kwargs) -> ToolResult:
    """Dispatch to the handler for ``operation`` and wrap its outcome.

    Returns:
        A success result carrying the handler's dict, or an error result
        with code ``DATETIME_ERROR`` when the operation is unsupported or
        the handler raises.
    """
    started_at = time.time()
    try:
        operation = kwargs.get("operation")
        handlers = (
            ("now", self._get_current_time),
            ("format", self._format_datetime),
            ("convert_timezone", self._convert_timezone),
            ("timestamp_to_datetime", self._timestamp_to_datetime),
            ("datetime_to_timestamp", self._datetime_to_timestamp),
            ("calculate", self._calculate_datetime),
        )
        for op_name, handler in handlers:
            if operation == op_name:
                result = handler(kwargs)
                break
        else:
            raise ValueError(f"不支持的操作类型: {operation}")

        return ToolResult.success_result(
            data=result,
            execution_time=time.time() - started_at
        )
    except Exception as e:
        return ToolResult.error_result(
            error=str(e),
            error_code="DATETIME_ERROR",
            execution_time=time.time() - started_at
        )
def _get_current_time(self, kwargs) -> dict:
"""获取当前时间"""
timezone_str = kwargs.get("to_timezone", "UTC")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
if timezone_str == "UTC":
tz = timezone.utc
else:
tz = pytz.timezone(timezone_str)
now = datetime.now(tz)
return {
"datetime": now.strftime(output_format),
"timestamp": int(now.timestamp()),
"timezone": timezone_str,
"iso_format": now.isoformat()
}
def _format_datetime(self, kwargs) -> dict:
"""格式化时间"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
if not input_value:
raise ValueError("input_value 参数是必需的")
# 解析输入时间
dt = datetime.strptime(input_value, input_format)
return {
"original": input_value,
"formatted": dt.strftime(output_format),
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _convert_timezone(self, kwargs) -> dict:
    """Convert a time string from ``from_timezone`` to ``to_timezone``.

    Naive parsed values are localised to the source timezone before the
    conversion. Raises ``ValueError`` when ``input_value`` is missing.
    """
    raw_value = kwargs.get("input_value")
    parse_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
    render_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
    source_name = kwargs.get("from_timezone", "UTC")
    target_name = kwargs.get("to_timezone", "UTC")

    if not raw_value:
        raise ValueError("input_value 参数是必需的")

    moment = datetime.strptime(raw_value, parse_format)

    source_tz = pytz.UTC if source_name == "UTC" else pytz.timezone(source_name)
    target_tz = pytz.UTC if target_name == "UTC" else pytz.timezone(target_name)

    # Attach the source timezone only when parsing produced a naive value.
    if moment.tzinfo is None:
        moment = source_tz.localize(moment)
    converted = moment.astimezone(target_tz)

    return {
        "original": raw_value,
        "original_timezone": source_name,
        "converted": converted.strftime(render_format),
        "converted_timezone": target_name,
        "timestamp": int(converted.timestamp())
    }
def _timestamp_to_datetime(self, kwargs) -> dict:
"""时间戳转日期时间"""
input_value = kwargs.get("input_value")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
timezone_str = kwargs.get("to_timezone", "UTC")
if not input_value:
raise ValueError("input_value 参数是必需的")
# 转换时间戳
timestamp = float(input_value)
# 设置时区
if timezone_str == "UTC":
tz = timezone.utc
else:
tz = pytz.timezone(timezone_str)
dt = datetime.fromtimestamp(timestamp, tz)
return {
"timestamp": timestamp,
"datetime": dt.strftime(output_format),
"timezone": timezone_str,
"iso_format": dt.isoformat()
}
def _datetime_to_timestamp(self, kwargs) -> dict:
"""日期时间转时间戳"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
timezone_str = kwargs.get("from_timezone", "UTC")
if not input_value:
raise ValueError("input_value 参数是必需的")
# 解析输入时间
dt = datetime.strptime(input_value, input_format)
# 设置时区
if timezone_str == "UTC":
tz = timezone.utc
else:
tz = pytz.timezone(timezone_str)
# 本地化时间
if dt.tzinfo is None:
dt = tz.localize(dt)
return {
"datetime": input_value,
"timezone": timezone_str,
"timestamp": int(dt.timestamp()),
"iso_format": dt.isoformat()
}
def _calculate_datetime(self, kwargs) -> dict:
"""时间计算"""
input_value = kwargs.get("input_value")
input_format = kwargs.get("input_format", "%Y-%m-%d %H:%M:%S")
output_format = kwargs.get("output_format", "%Y-%m-%d %H:%M:%S")
calculation = kwargs.get("calculation")
timezone_str = kwargs.get("from_timezone", "UTC")
if not input_value:
raise ValueError("input_value 参数是必需的")
if not calculation:
raise ValueError("calculation 参数是必需的")
# 解析输入时间
dt = datetime.strptime(input_value, input_format)
# 设置时区
if timezone_str == "UTC":
tz = timezone.utc
else:
tz = pytz.timezone(timezone_str)
if dt.tzinfo is None:
dt = tz.localize(dt)
# 解析计算表达式
delta = self._parse_time_delta(calculation)
calculated_dt = dt + delta
return {
"original": input_value,
"calculation": calculation,
"result": calculated_dt.strftime(output_format),
"timezone": timezone_str,
"timestamp": int(calculated_dt.timestamp())
}
def _parse_time_delta(self, calculation: str) -> timedelta:
"""解析时间计算表达式"""
import re
# 支持的单位d(天), h(小时), m(分钟), s(秒)
pattern = r'([+-]?\d+)([dhms])'
matches = re.findall(pattern, calculation.lower())
if not matches:
raise ValueError(f"无效的时间计算表达式: {calculation}")
total_delta = timedelta()
for value_str, unit in matches:
value = int(value_str)
if unit == 'd':
total_delta += timedelta(days=value)
elif unit == 'h':
total_delta += timedelta(hours=value)
elif unit == 'm':
total_delta += timedelta(minutes=value)
elif unit == 's':
total_delta += timedelta(seconds=value)
return total_delta

View File

@@ -0,0 +1,430 @@
"""JSON转换工具 - 数据格式转换"""
import json
import time
from typing import List, Any, Dict
import yaml
import xml.etree.ElementTree as ET
from xml.dom import minidom
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class JsonTool(BuiltinTool):
"""JSON转换工具 - 提供JSON格式化、压缩、验证、格式转换功能"""
@property
def name(self) -> str:
    """Tool name: ``json_tool``."""
    tool_name = "json_tool"
    return tool_name
@property
def description(self) -> str:
    """Human-readable description of the JSON tool."""
    text = "JSON转换工具 - 数据格式转换JSON格式化、JSON压缩、JSON验证、格式转换"
    return text
@property
def parameters(self) -> List[ToolParameter]:
    """Definitions of the arguments accepted by ``execute``."""
    operation = ToolParameter(
        name="operation",
        type=ParameterType.STRING,
        description="操作类型",
        required=True,
        enum=["format", "minify", "validate", "convert", "to_yaml", "from_yaml", "to_xml", "from_xml", "merge", "extract"],
    )
    input_data = ToolParameter(
        name="input_data",
        type=ParameterType.STRING,
        description="输入数据JSON字符串、YAML字符串或XML字符串",
        required=True,
    )
    indent = ToolParameter(
        name="indent",
        type=ParameterType.INTEGER,
        description="JSON格式化缩进空格数",
        required=False,
        default=2,
        minimum=0,
        maximum=8,
    )
    ensure_ascii = ToolParameter(
        name="ensure_ascii",
        type=ParameterType.BOOLEAN,
        description="是否确保ASCII编码",
        required=False,
        default=False,
    )
    sort_keys = ToolParameter(
        name="sort_keys",
        type=ParameterType.BOOLEAN,
        description="是否对键进行排序",
        required=False,
        default=False,
    )
    merge_data = ToolParameter(
        name="merge_data",
        type=ParameterType.STRING,
        description="要合并的JSON数据用于merge操作",
        required=False,
    )
    json_path = ToolParameter(
        name="json_path",
        type=ParameterType.STRING,
        description="JSON路径表达式用于extract操作$.user.name",
        required=False,
    )

    return [
        operation,
        input_data,
        indent,
        ensure_ascii,
        sort_keys,
        merge_data,
        json_path,
    ]
async def execute(self, **kwargs) -> ToolResult:
    """Dispatch to the handler for ``operation`` and wrap its outcome.

    Returns:
        A success result carrying the handler's dict, or an error result
        with code ``JSON_ERROR`` when ``input_data`` is missing, the
        operation is unsupported or the handler raises.
    """
    started_at = time.time()
    try:
        operation = kwargs.get("operation")
        input_data = kwargs.get("input_data")

        if not input_data:
            raise ValueError("input_data 参数是必需的")

        # (operation name, handler, whether the handler also takes kwargs)
        handlers = (
            ("format", self._format_json, True),
            ("minify", self._minify_json, False),
            ("validate", self._validate_json, False),
            ("convert", self._convert_json, False),
            ("to_yaml", self._json_to_yaml, False),
            ("from_yaml", self._yaml_to_json, True),
            ("to_xml", self._json_to_xml, False),
            ("from_xml", self._xml_to_json, True),
            ("merge", self._merge_json, True),
            ("extract", self._extract_json_path, True),
        )
        for op_name, handler, takes_kwargs in handlers:
            if operation == op_name:
                if takes_kwargs:
                    result = handler(input_data, kwargs)
                else:
                    result = handler(input_data)
                break
        else:
            raise ValueError(f"不支持的操作类型: {operation}")

        return ToolResult.success_result(
            data=result,
            execution_time=time.time() - started_at
        )
    except Exception as e:
        return ToolResult.error_result(
            error=str(e),
            error_code="JSON_ERROR",
            execution_time=time.time() - started_at
        )
def _format_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""格式化JSON"""
indent = kwargs.get("indent", 2)
ensure_ascii = kwargs.get("ensure_ascii", False)
sort_keys = kwargs.get("sort_keys", False)
# 解析JSON
data = json.loads(input_data)
# 格式化输出
formatted = json.dumps(
data,
indent=indent,
ensure_ascii=ensure_ascii,
sort_keys=sort_keys,
separators=(',', ': ')
)
return {
"original_size": len(input_data),
"formatted_size": len(formatted),
"formatted_json": formatted,
"is_valid": True,
"settings": {
"indent": indent,
"ensure_ascii": ensure_ascii,
"sort_keys": sort_keys
}
}
def _minify_json(self, input_data: str) -> Dict[str, Any]:
"""压缩JSON"""
# 解析并压缩
data = json.loads(input_data)
minified = json.dumps(data, separators=(',', ':'))
return {
"original_size": len(input_data),
"minified_size": len(minified),
"compression_ratio": round((1 - len(minified) / len(input_data)) * 100, 2),
"minified_json": minified,
"is_valid": True
}
def _validate_json(self, input_data: str) -> Dict[str, Any]:
"""验证JSON"""
try:
data = json.loads(input_data)
# 统计信息
stats = self._analyze_json_structure(data)
return {
"is_valid": True,
"error": None,
"size": len(input_data),
"structure": stats
}
except json.JSONDecodeError as e:
return {
"is_valid": False,
"error": str(e),
"error_line": getattr(e, 'lineno', None),
"error_column": getattr(e, 'colno', None),
"size": len(input_data)
}
def _convert_json(self, input_data: str) -> Dict[str, Any]:
"""JSON转义"""
data = json.loads(input_data)
converted = json.dumps(data, ensure_ascii=False)
return {
"converted_json": converted,
"is_valid": True
}
def _json_to_yaml(self, input_data: str) -> Dict[str, Any]:
    """Convert a JSON string to block-style YAML text.

    Returns the YAML text together with the source/target format names and
    the input/output sizes.
    """
    parsed = json.loads(input_data)
    dump_options = {
        "default_flow_style": False,
        "allow_unicode": True,
        "indent": 2,
    }
    yaml_text = yaml.dump(parsed, **dump_options)

    return {
        "original_format": "json",
        "target_format": "yaml",
        "original_size": len(input_data),
        "converted_size": len(yaml_text),
        "converted_data": yaml_text
    }
def _yaml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Convert YAML text to a JSON string.

    Reads ``indent`` (default 2) and ``ensure_ascii`` (default False) from
    ``kwargs``. The YAML is parsed with ``yaml.safe_load``.
    """
    dump_options = {
        "indent": kwargs.get("indent", 2),
        "ensure_ascii": kwargs.get("ensure_ascii", False),
    }

    parsed = yaml.safe_load(input_data)
    json_text = json.dumps(parsed, **dump_options)

    return {
        "original_format": "yaml",
        "target_format": "json",
        "original_size": len(input_data),
        "converted_size": len(json_text),
        "converted_data": json_text
    }
def _json_to_xml(self, input_data: str) -> Dict[str, Any]:
"""JSON转XML"""
data = json.loads(input_data)
def dict_to_xml(data, root_name="root"):
"""递归转换字典为XML"""
if isinstance(data, dict):
if len(data) == 1 and not root_name == "root":
# 如果字典只有一个键,使用该键作为根元素
key, value = next(iter(data.items()))
return dict_to_xml(value, key)
root = ET.Element(root_name)
for key, value in data.items():
if isinstance(value, (dict, list)):
child = dict_to_xml(value, key)
root.append(child)
else:
child = ET.SubElement(root, key)
child.text = str(value)
return root
elif isinstance(data, list):
root = ET.Element(root_name)
for i, item in enumerate(data):
if isinstance(item, (dict, list)):
child = dict_to_xml(item, f"item_{i}")
root.append(child)
else:
child = ET.SubElement(root, f"item_{i}")
child.text = str(item)
return root
else:
root = ET.Element(root_name)
root.text = str(data)
return root
xml_element = dict_to_xml(data)
xml_string = ET.tostring(xml_element, encoding='unicode')
# 格式化XML
dom = minidom.parseString(xml_string)
formatted_xml = dom.toprettyxml(indent=" ")
# 移除空行
formatted_xml = '\n'.join([line for line in formatted_xml.split('\n') if line.strip()])
return {
"original_format": "json",
"target_format": "xml",
"original_size": len(input_data),
"converted_size": len(formatted_xml),
"converted_data": formatted_xml
}
def _xml_to_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""XML转JSON"""
indent = kwargs.get("indent", 2)
def xml_to_dict(element):
"""递归转换XML元素为字典"""
result = {}
# 处理属性
if element.attrib:
result.update(element.attrib)
# 处理文本内容
if element.text and element.text.strip():
if len(element) == 0: # 叶子节点
return element.text.strip()
else:
result['text'] = element.text.strip()
# 处理子元素
for child in element:
child_data = xml_to_dict(child)
if child.tag in result:
# 如果标签已存在,转换为列表
if not isinstance(result[child.tag], list):
result[child.tag] = [result[child.tag]]
result[child.tag].append(child_data)
else:
result[child.tag] = child_data
return result
root = ET.fromstring(input_data)
data = {root.tag: xml_to_dict(root)}
json_output = json.dumps(data, indent=indent, ensure_ascii=False)
return {
"original_format": "xml",
"target_format": "json",
"original_size": len(input_data),
"converted_size": len(json_output),
"converted_data": json_output
}
def _merge_json(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""合并JSON"""
merge_data = kwargs.get("merge_data")
if not merge_data:
raise ValueError("merge_data 参数是必需的")
data1 = json.loads(input_data)
data2 = json.loads(merge_data)
def deep_merge(dict1, dict2):
"""深度合并字典"""
result = dict1.copy()
for key, value in dict2.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = deep_merge(result[key], value)
else:
result[key] = value
return result
if isinstance(data1, dict) and isinstance(data2, dict):
merged = deep_merge(data1, data2)
elif isinstance(data1, list) and isinstance(data2, list):
merged = data1 + data2
else:
raise ValueError("无法合并不同类型的数据")
merged_json = json.dumps(merged, indent=2, ensure_ascii=False)
return {
"operation": "merge",
"original_size": len(input_data),
"merge_size": len(merge_data),
"result_size": len(merged_json),
"merged_data": merged_json
}
def _extract_json_path(self, input_data: str, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""提取JSON路径"""
json_path = kwargs.get("json_path")
if not json_path:
raise ValueError("json_path 参数是必需的")
data = json.loads(input_data)
# 简单的JSONPath实现支持基本的点号路径
try:
result = data
if json_path.startswith('$.'):
path_parts = json_path[2:].split('.')
else:
path_parts = json_path.split('.')
for part in path_parts:
if part.isdigit():
result = result[int(part)]
else:
result = result[part]
extracted_json = json.dumps(result, indent=2, ensure_ascii=False)
return {
"operation": "extract",
"json_path": json_path,
"found": True,
"extracted_data": extracted_json,
"data_type": type(result).__name__
}
except (KeyError, IndexError, TypeError) as e:
return {
"operation": "extract",
"json_path": json_path,
"found": False,
"error": str(e),
"extracted_data": None
}
def _analyze_json_structure(self, data: Any, depth: int = 0) -> Dict[str, Any]:
"""分析JSON结构"""
if isinstance(data, dict):
return {
"type": "object",
"keys": len(data),
"depth": depth,
"children": {k: self._analyze_json_structure(v, depth + 1) for k, v in data.items()}
}
elif isinstance(data, list):
return {
"type": "array",
"length": len(data),
"depth": depth,
"item_types": list(set(type(item).__name__ for item in data))
}
else:
return {
"type": type(data).__name__,
"depth": depth,
"value": str(data)[:100] + "..." if len(str(data)) > 100 else str(data)
}

View File

@@ -0,0 +1,327 @@
"""MinerU PDF解析工具"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class MinerUTool(BuiltinTool):
    """Built-in tool that forwards PDF processing requests to a MinerU HTTP API.

    Supported operations are full parsing, text extraction, table extraction,
    image extraction and layout analysis. The ``api_key`` and ``api_url``
    configuration parameters are required; ``timeout`` (seconds, default 60)
    is optional.
    """

    @property
    def name(self) -> str:
        """Return the identifier of this tool."""
        return "mineru_tool"

    @property
    def description(self) -> str:
        """Return the human-readable description of this tool."""
        return "MinerU - PDF解析工具PDF解析、表格提取、图片识别、文本提取"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration keys that must be set before calling the API."""
        return ["api_key", "api_url"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Return the call parameters accepted by ``execute``."""
        operation = ToolParameter(
            name="operation",
            type=ParameterType.STRING,
            description="操作类型",
            required=True,
            enum=["parse_pdf", "extract_text", "extract_tables", "extract_images", "analyze_layout"]
        )
        file_content = ToolParameter(
            name="file_content",
            type=ParameterType.STRING,
            description="PDF文件内容Base64编码",
            required=False
        )
        file_url = ToolParameter(
            name="file_url",
            type=ParameterType.STRING,
            description="PDF文件URL",
            required=False
        )
        parse_mode = ToolParameter(
            name="parse_mode",
            type=ParameterType.STRING,
            description="解析模式",
            required=False,
            default="auto",
            enum=["auto", "text_only", "table_priority", "image_priority", "layout_analysis"]
        )
        extract_images = ToolParameter(
            name="extract_images",
            type=ParameterType.BOOLEAN,
            description="是否提取图片",
            required=False,
            default=True
        )
        extract_tables = ToolParameter(
            name="extract_tables",
            type=ParameterType.BOOLEAN,
            description="是否提取表格",
            required=False,
            default=True
        )
        page_range = ToolParameter(
            name="page_range",
            type=ParameterType.STRING,
            description="页面范围1-5, 1,3,5",
            required=False
        )
        output_format = ToolParameter(
            name="output_format",
            type=ParameterType.STRING,
            description="输出格式",
            required=False,
            default="json",
            enum=["json", "markdown", "html", "text"]
        )
        return [
            operation,
            file_content,
            file_url,
            parse_mode,
            extract_images,
            extract_tables,
            page_range,
            output_format,
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run the requested MinerU operation.

        Args:
            **kwargs: Call parameters as listed by ``parameters``. Either
                ``file_content`` or ``file_url`` must be provided.

        Returns:
            A success result carrying the operation's data, or an error
            result with code ``MINERU_ERROR`` when anything fails.
        """
        started = time.time()
        routes = (
            ("parse_pdf", self._parse_pdf),
            ("extract_text", self._extract_text),
            ("extract_tables", self._extract_tables),
            ("extract_images", self._extract_images),
            ("analyze_layout", self._analyze_layout),
        )
        try:
            operation = kwargs.get("operation")
            file_content = kwargs.get("file_content")
            file_url = kwargs.get("file_url")
            if not (file_content or file_url):
                raise ValueError("必须提供 file_content 或 file_url 参数")

            handler = next((fn for key, fn in routes if operation == key), None)
            if handler is None:
                raise ValueError(f"不支持的操作类型: {operation}")

            payload = await handler(kwargs)
            elapsed = time.time() - started
            return ToolResult.success_result(data=payload, execution_time=elapsed)
        except Exception as e:
            # Every failure is reported through the result object.
            elapsed = time.time() - started
            return ToolResult.error_result(
                error=str(e),
                error_code="MINERU_ERROR",
                execution_time=elapsed
            )

    @staticmethod
    def _attach_document(request_data: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add the optional page range and the document source to a request."""
        page_range = kwargs.get("page_range")
        if page_range:
            request_data["page_range"] = page_range
        # Inline content takes precedence over a URL when both are given.
        if kwargs.get("file_content"):
            request_data["file_content"] = kwargs["file_content"]
        elif kwargs.get("file_url"):
            request_data["file_url"] = kwargs["file_url"]
        return request_data

    async def _parse_pdf(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``parse`` endpoint and return the full parsing result."""
        parse_mode = kwargs.get("parse_mode", "auto")
        request_data = self._attach_document(
            {
                "parse_mode": parse_mode,
                "extract_images": kwargs.get("extract_images", True),
                "extract_tables": kwargs.get("extract_tables", True),
                "output_format": kwargs.get("output_format", "json"),
            },
            kwargs,
        )
        response = await self._call_mineru_api("parse", request_data)
        return {
            "operation": "parse_pdf",
            "parse_mode": parse_mode,
            "total_pages": response.get("total_pages", 0),
            "processed_pages": response.get("processed_pages", 0),
            "text_content": response.get("text_content", ""),
            "tables": response.get("tables", []),
            "images": response.get("images", []),
            "layout_info": response.get("layout_info", {}),
            "metadata": response.get("metadata", {}),
            "processing_time": response.get("processing_time", 0)
        }

    async def _extract_text(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``extract_text`` endpoint and add word/character counts."""
        request_data = self._attach_document(
            {
                "operation": "extract_text",
                "output_format": kwargs.get("output_format", "text"),
            },
            kwargs,
        )
        response = await self._call_mineru_api("extract_text", request_data)
        text = response.get("text_content", "")
        return {
            "operation": "extract_text",
            "total_pages": response.get("total_pages", 0),
            "text_content": text,
            "word_count": len(text.split()),
            "character_count": len(text),
            "pages_text": response.get("pages_text", [])
        }

    async def _extract_tables(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``extract_tables`` endpoint and return the tables found."""
        request_data = self._attach_document(
            {
                "operation": "extract_tables",
                "output_format": kwargs.get("output_format", "json"),
            },
            kwargs,
        )
        response = await self._call_mineru_api("extract_tables", request_data)
        return {
            "operation": "extract_tables",
            "total_tables": response.get("total_tables", 0),
            "tables": response.get("tables", []),
            "table_locations": response.get("table_locations", [])
        }

    async def _extract_images(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``extract_images`` endpoint and return the images found."""
        request_data = self._attach_document({"operation": "extract_images"}, kwargs)
        response = await self._call_mineru_api("extract_images", request_data)
        return {
            "operation": "extract_images",
            "total_images": response.get("total_images", 0),
            "images": response.get("images", []),
            "image_locations": response.get("image_locations", [])
        }

    async def _analyze_layout(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``analyze_layout`` endpoint and return the layout blocks."""
        request_data = self._attach_document({"operation": "analyze_layout"}, kwargs)
        response = await self._call_mineru_api("analyze_layout", request_data)
        return {
            "operation": "analyze_layout",
            "layout_info": response.get("layout_info", {}),
            "page_layouts": response.get("page_layouts", []),
            "text_blocks": response.get("text_blocks", []),
            "image_blocks": response.get("image_blocks", []),
            "table_blocks": response.get("table_blocks", [])
        }

    async def _call_mineru_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to a MinerU endpoint and return the payload.

        Args:
            endpoint: Path segment appended to the configured ``api_url``.
            data: JSON-serialisable request body.

        Returns:
            The ``data`` member of the response, or the whole response body
            when that member is absent.

        Raises:
            ValueError: If ``api_key`` or ``api_url`` is not configured.
            Exception: On a non-200 status or when the response reports
                ``success`` as false.
        """
        api_key = self.get_config_parameter("api_key")
        api_url = self.get_config_parameter("api_url")
        timeout_seconds = self.get_config_parameter("timeout", 60)
        if not (api_key and api_url):
            raise ValueError("MinerU API配置未完成")

        url = f"{api_url.rstrip('/')}/{endpoint}"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        timeout = aiohttp.ClientTimeout(total=timeout_seconds)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP错误 {response.status}: {error_text}")

                body = await response.json()
                # A missing "success" flag is treated as success.
                if not body.get("success", True):
                    raise Exception(f"MinerU API错误: {body.get('message', '未知错误')}")
                return body.get("data", body)

    def test_connection(self) -> Dict[str, Any]:
        """Check that the API configuration is present.

        No network request is made; only the configured values are inspected.

        Returns:
            A dict with ``success`` and either a confirmation message plus
            the URL and a masked key, or an ``error`` description.
        """
        try:
            api_key = self.get_config_parameter("api_key")
            api_url = self.get_config_parameter("api_url")
            if not (api_key and api_url):
                return {
                    "success": False,
                    "error": "API配置未完成"
                }

            masked_key = api_key[:8] + "***" if len(api_key) > 8 else "***"
            return {
                "success": True,
                "message": "连接配置有效",
                "api_url": api_url,
                "api_key_masked": masked_key
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,401 @@
"""TextIn OCR文字识别工具"""
import time
from typing import List, Dict, Any
import aiohttp
from app.core.tools.base import ToolParameter, ToolResult, ParameterType
from .base import BuiltinTool
class TextInTool(BuiltinTool):
    """Built-in tool that forwards OCR requests to a TextIn HTTP API.

    Supported recognition modes are general, accurate, handwriting, formula,
    table and document. The ``app_id``, ``secret_key`` and ``api_url``
    configuration parameters are required; ``timeout`` (seconds, default 30)
    is optional.
    """

    @property
    def name(self) -> str:
        """Return the identifier of this tool."""
        return "textin_tool"

    @property
    def description(self) -> str:
        """Return the human-readable description of this tool."""
        return "TextIn - OCR文字识别通用OCR、手写识别、多语言支持、高精度识别"

    def get_required_config_parameters(self) -> List[str]:
        """Return the configuration keys that must be set before calling the API."""
        return ["app_id", "secret_key", "api_url"]

    @property
    def parameters(self) -> List[ToolParameter]:
        """Return the call parameters accepted by ``execute``."""
        return [
            ToolParameter(
                name="image_content",
                type=ParameterType.STRING,
                description="图片内容Base64编码",
                required=False
            ),
            ToolParameter(
                name="image_url",
                type=ParameterType.STRING,
                description="图片URL",
                required=False
            ),
            ToolParameter(
                name="language",
                type=ParameterType.STRING,
                description="识别语言",
                required=False,
                default="auto",
                enum=["auto", "zh-cn", "zh-tw", "en", "ja", "ko", "fr", "de", "es", "ru"]
            ),
            ToolParameter(
                name="recognition_mode",
                type=ParameterType.STRING,
                description="识别模式",
                required=False,
                default="general",
                enum=["general", "accurate", "handwriting", "formula", "table", "document"]
            ),
            ToolParameter(
                name="return_location",
                type=ParameterType.BOOLEAN,
                description="是否返回文字位置信息",
                required=False,
                default=False
            ),
            ToolParameter(
                name="return_confidence",
                type=ParameterType.BOOLEAN,
                description="是否返回置信度",
                required=False,
                default=True
            ),
            ToolParameter(
                name="merge_lines",
                type=ParameterType.BOOLEAN,
                description="是否合并行",
                required=False,
                default=True
            ),
            ToolParameter(
                name="output_format",
                type=ParameterType.STRING,
                description="输出格式",
                required=False,
                default="text",
                enum=["text", "json", "structured"]
            )
        ]

    async def execute(self, **kwargs) -> ToolResult:
        """Run OCR in the requested recognition mode.

        Args:
            **kwargs: Call parameters as listed by ``parameters``. Either
                ``image_content`` or ``image_url`` must be provided.

        Returns:
            A success result carrying the formatted recognition data, or an
            error result with code ``TEXTIN_ERROR`` when anything fails.
        """
        start_time = time.time()
        try:
            if not kwargs.get("image_content") and not kwargs.get("image_url"):
                raise ValueError("必须提供 image_content 或 image_url 参数")

            # Each mode has its own endpoint and request shape. The handlers
            # read the remaining options from kwargs themselves.
            recognition_mode = kwargs.get("recognition_mode", "general")
            if recognition_mode == "general":
                result = await self._general_ocr(kwargs)
            elif recognition_mode == "accurate":
                result = await self._accurate_ocr(kwargs)
            elif recognition_mode == "handwriting":
                result = await self._handwriting_ocr(kwargs)
            elif recognition_mode == "formula":
                result = await self._formula_ocr(kwargs)
            elif recognition_mode == "table":
                result = await self._table_ocr(kwargs)
            elif recognition_mode == "document":
                result = await self._document_ocr(kwargs)
            else:
                raise ValueError(f"不支持的识别模式: {recognition_mode}")

            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )
        except Exception as e:
            # Every failure is reported through the result object.
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="TEXTIN_ERROR",
                execution_time=execution_time
            )

    @staticmethod
    def _attach_image(request_data: Dict[str, Any], kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Add the image source to a request and return the request."""
        # Inline content takes precedence over a URL when both are given.
        if kwargs.get("image_content"):
            request_data["image"] = kwargs["image_content"]
        elif kwargs.get("image_url"):
            request_data["image_url"] = kwargs["image_url"]
        return request_data

    @staticmethod
    def _common_options(kwargs: Dict[str, Any], include_language: bool = True) -> Dict[str, Any]:
        """Return the request options shared by the recognition modes."""
        options: Dict[str, Any] = {}
        if include_language:
            options["language"] = kwargs.get("language", "auto")
        options["return_location"] = kwargs.get("return_location", False)
        options["return_confidence"] = kwargs.get("return_confidence", True)
        return options

    async def _general_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``general_ocr`` endpoint and format the recognised lines."""
        request_data = self._common_options(kwargs)
        request_data["merge_lines"] = kwargs.get("merge_lines", True)
        self._attach_image(request_data, kwargs)
        result = await self._call_textin_api("general_ocr", request_data)
        return self._format_ocr_result(result, kwargs.get("output_format", "text"))

    async def _accurate_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``accurate_ocr`` endpoint and format the recognised lines."""
        request_data = self._common_options(kwargs)
        request_data["merge_lines"] = kwargs.get("merge_lines", True)
        self._attach_image(request_data, kwargs)
        result = await self._call_textin_api("accurate_ocr", request_data)
        return self._format_ocr_result(result, kwargs.get("output_format", "text"))

    async def _handwriting_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``handwriting_ocr`` endpoint and format the recognised lines."""
        request_data = self._attach_image(self._common_options(kwargs), kwargs)
        result = await self._call_textin_api("handwriting_ocr", request_data)
        return self._format_ocr_result(result, kwargs.get("output_format", "text"))

    async def _formula_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``formula_ocr`` endpoint, always requesting LaTeX output."""
        request_data = self._common_options(kwargs, include_language=False)
        request_data["output_latex"] = True
        self._attach_image(request_data, kwargs)
        result = await self._call_textin_api("formula_ocr", request_data)
        return self._format_formula_result(result, kwargs.get("output_format", "text"))

    async def _table_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``table_ocr`` endpoint, always requesting Excel output."""
        request_data = self._common_options(kwargs)
        request_data["output_excel"] = True
        self._attach_image(request_data, kwargs)
        result = await self._call_textin_api("table_ocr", request_data)
        return self._format_table_result(result, kwargs.get("output_format", "text"))

    async def _document_ocr(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Call the ``document_ocr`` endpoint with layout analysis enabled."""
        request_data = self._common_options(kwargs)
        request_data["layout_analysis"] = True
        self._attach_image(request_data, kwargs)
        result = await self._call_textin_api("document_ocr", request_data)
        return self._format_document_result(result, kwargs.get("output_format", "text"))

    def _format_ocr_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any] | None:
        """Shape a line-based OCR response for the requested output format.

        Args:
            result: API payload; its ``lines`` list holds the recognised lines.
            output_format: One of ``text``, ``json`` or ``structured``.

        Returns:
            The formatted result, or None for any other ``output_format``.
        """
        lines = result.get("lines", [])
        total_confidence = result.get("confidence", 0)
        processing_time = result.get("processing_time", 0)

        if output_format == "text":
            return {
                "recognition_mode": "ocr",
                "text_content": "\n".join(line.get("text", "") for line in lines),
                "line_count": len(lines),
                "total_confidence": total_confidence,
                "processing_time": processing_time
            }

        if output_format == "json":
            return {
                "recognition_mode": "ocr",
                "lines": lines,
                "total_confidence": total_confidence,
                "processing_time": processing_time
            }

        if output_format == "structured":
            texts = [line.get("text", "") for line in lines]
            return {
                "recognition_mode": "ocr",
                "text_content": "\n".join(texts),
                "structured_data": {
                    "lines": lines,
                    "paragraphs": self._group_lines_to_paragraphs(lines),
                    "statistics": {
                        "line_count": len(lines),
                        "word_count": sum(len(text.split()) for text in texts),
                        "character_count": sum(len(text) for text in texts)
                    }
                },
                "total_confidence": total_confidence,
                "processing_time": processing_time
            }

        # Unknown formats yield no data, as the return annotation allows.
        return None

    def _format_formula_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a formula response; ``output_format`` is not used here."""
        formulas = result.get("formulas", [])
        return {
            "recognition_mode": "formula",
            "formula_count": len(formulas),
            "formulas": formulas,
            "latex_content": "\n".join(formula.get("latex", "") for formula in formulas),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0)
        }

    def _format_table_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a table response; ``output_format`` is not used here."""
        tables = result.get("tables", [])
        return {
            "recognition_mode": "table",
            "table_count": len(tables),
            "tables": tables,
            "excel_data": result.get("excel_data"),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0)
        }

    def _format_document_result(self, result: Dict[str, Any], output_format: str) -> Dict[str, Any]:
        """Shape a document response; ``output_format`` is not used here."""
        return {
            "recognition_mode": "document",
            "layout_info": result.get("layout_info", {}),
            "text_blocks": result.get("text_blocks", []),
            "image_blocks": result.get("image_blocks", []),
            "table_blocks": result.get("table_blocks", []),
            "full_text": result.get("full_text", ""),
            "total_confidence": result.get("confidence", 0),
            "processing_time": result.get("processing_time", 0)
        }

    def _group_lines_to_paragraphs(self, lines: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Group consecutive non-blank lines into paragraphs.

        A line whose stripped text is empty ends the current paragraph. Each
        paragraph holds its lines and their texts joined by single spaces.
        """
        paragraphs: List[Dict[str, Any]] = []
        current: List[Dict[str, Any]] = []

        def close_paragraph() -> None:
            """Store the pending lines as one paragraph, if there are any."""
            if current:
                paragraphs.append({
                    "text": " ".join(item.get("text", "") for item in current),
                    "lines": list(current)
                })
                current.clear()

        for line in lines:
            if line.get("text", "").strip():
                current.append(line)
            else:
                close_paragraph()
        close_paragraph()
        return paragraphs

    async def _call_textin_api(self, endpoint: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` as JSON to a TextIn endpoint and return the payload.

        Args:
            endpoint: Path segment appended to the configured ``api_url``.
            data: JSON-serialisable request body.

        Returns:
            The ``data`` member of the response, or the whole response body
            when that member is absent.

        Raises:
            ValueError: If ``app_id``, ``secret_key`` or ``api_url`` is not
                configured.
            Exception: On a non-200 status or when the response ``code`` is
                not 200.
        """
        app_id = self.get_config_parameter("app_id")
        secret_key = self.get_config_parameter("secret_key")
        api_url = self.get_config_parameter("api_url")
        # Configurable like the MinerU tool's timeout; 30 seconds stays the
        # default so existing configurations behave as before.
        timeout_seconds = self.get_config_parameter("timeout", 30)
        if not app_id or not secret_key or not api_url:
            raise ValueError("TextIn API配置未完成")

        url = f"{api_url.rstrip('/')}/{endpoint}"
        headers = {
            "X-App-Id": app_id,
            "X-Secret-Key": secret_key,
            "Content-Type": "application/json"
        }

        timeout = aiohttp.ClientTimeout(total=timeout_seconds)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=data, headers=headers) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise Exception(f"HTTP错误 {response.status}: {error_text}")

                result = await response.json()
                if result.get("code") != 200:
                    raise Exception(f"TextIn API错误: {result.get('message', '未知错误')}")
                return result.get("data", result)

    def test_connection(self) -> Dict[str, Any]:
        """Check that the API configuration is present.

        No network request is made; only the configured values are inspected.

        Returns:
            A dict with ``success`` and either a confirmation message plus
            the URL, app id and a masked secret key, or an ``error``
            description.
        """
        try:
            app_id = self.get_config_parameter("app_id")
            secret_key = self.get_config_parameter("secret_key")
            api_url = self.get_config_parameter("api_url")
            if not app_id or not secret_key or not api_url:
                return {
                    "success": False,
                    "error": "API配置未完成"
                }

            return {
                "success": True,
                "message": "连接配置有效",
                "api_url": api_url,
                "app_id": app_id,
                "secret_key_masked": secret_key[:8] + "***" if len(secret_key) > 8 else "***"
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e)
            }

View File

@@ -0,0 +1,485 @@
"""工具链管理器 - 支持langchain的工具链模式"""
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
from app.core.tools.base import ToolResult
from app.core.tools.executor import ToolExecutor
from app.core.logging_config import get_business_logger
# Module-level logger obtained from the project's logging configuration.
logger = get_business_logger()
class ChainExecutionMode(str, Enum):
    """Strategies available for running the steps of a tool chain."""

    # Run the steps one after another.
    SEQUENTIAL = "sequential"
    # Run the steps concurrently.
    PARALLEL = "parallel"
    # Run the steps subject to their conditions.
    CONDITIONAL = "conditional"
@dataclass
class ChainStep:
    """One step of a tool chain: which tool to call and how to call it."""

    # Identifier of the tool to execute.
    tool_id: str
    # Arguments for the tool; "${name}" values are replaced by chain variables.
    parameters: Dict[str, Any]
    # Optional expression that must hold for the step to run.
    condition: Optional[str] = None
    # Optional mapping from output keys to chain variable names.
    output_mapping: Optional[Dict[str, str]] = None
    # Failure policy: "stop", "continue" or "retry".
    error_handling: str = "stop"
@dataclass
class ChainDefinition:
    """Static description of a tool chain: its identity, steps and run options."""

    # Name under which the chain is registered.
    name: str
    # Human-readable summary of what the chain does.
    description: str
    # Steps of the chain, in declaration order.
    steps: List[ChainStep]
    # How the steps are scheduled; sequential unless stated otherwise.
    execution_mode: ChainExecutionMode = ChainExecutionMode.SEQUENTIAL
    # Optional overall timeout value for the chain.
    global_timeout: Optional[float] = None
    # Optional retry settings for the chain.
    retry_policy: Optional[Dict[str, Any]] = None
class ChainExecutionContext:
    """Mutable state of a single chain run."""

    def __init__(self, chain_id: str):
        """Create an empty context for the run identified by ``chain_id``."""
        self.chain_id = chain_id

        # Progress flags and the last error, all unset for a fresh run.
        self.current_step = 0
        self.is_completed = False
        self.is_failed = False
        self.error_message: Optional[str] = None

        # Variables shared between steps, and results keyed by step index.
        self.variables: Dict[str, Any] = {}
        self.step_results: Dict[int, ToolResult] = {}
class ChainManager:
"""工具链管理器 - 支持langchain的工具链模式"""
def __init__(self, executor: ToolExecutor):
"""初始化工具链管理器
Args:
executor: 工具执行器
"""
self.executor = executor
self._chains: Dict[str, ChainDefinition] = {}
self._running_chains: Dict[str, ChainExecutionContext] = {}
def register_chain(self, chain: ChainDefinition) -> bool:
"""注册工具链
Args:
chain: 工具链定义
Returns:
注册是否成功
"""
try:
# 验证工具链定义
validation_result = self._validate_chain(chain)
if not validation_result[0]:
logger.error(f"工具链验证失败: {chain.name}, 错误: {validation_result[1]}")
return False
self._chains[chain.name] = chain
logger.info(f"工具链注册成功: {chain.name}")
return True
except Exception as e:
logger.error(f"工具链注册失败: {chain.name}, 错误: {e}")
return False
def unregister_chain(self, chain_name: str) -> bool:
"""注销工具链
Args:
chain_name: 工具链名称
Returns:
注销是否成功
"""
if chain_name in self._chains:
del self._chains[chain_name]
logger.info(f"工具链注销成功: {chain_name}")
return True
return False
def list_chains(self) -> List[Dict[str, Any]]:
"""列出所有工具链
Returns:
工具链信息列表
"""
chains = []
for name, chain in self._chains.items():
chains.append({
"name": name,
"description": chain.description,
"step_count": len(chain.steps),
"execution_mode": chain.execution_mode.value,
"global_timeout": chain.global_timeout
})
return chains
async def execute_chain(
self,
chain_name: str,
initial_variables: Optional[Dict[str, Any]] = None,
chain_id: Optional[str] = None
) -> Dict[str, Any] | None:
"""执行工具链
Args:
chain_name: 工具链名称
initial_variables: 初始变量
chain_id: 链执行ID可选
Returns:
执行结果
"""
if chain_name not in self._chains:
return {
"success": False,
"error": f"工具链不存在: {chain_name}",
"chain_id": chain_id
}
chain = self._chains[chain_name]
# 生成链ID
if not chain_id:
import uuid
chain_id = f"chain_{uuid.uuid4().hex[:16]}"
# 创建执行上下文
context = ChainExecutionContext(chain_id)
context.variables = initial_variables or {}
self._running_chains[chain_id] = context
try:
logger.info(f"开始执行工具链: {chain_name} (ID: {chain_id})")
# 根据执行模式执行
if chain.execution_mode == ChainExecutionMode.SEQUENTIAL:
result = await self._execute_sequential(chain, context)
elif chain.execution_mode == ChainExecutionMode.PARALLEL:
result = await self._execute_parallel(chain, context)
elif chain.execution_mode == ChainExecutionMode.CONDITIONAL:
result = await self._execute_conditional(chain, context)
else:
raise ValueError(f"不支持的执行模式: {chain.execution_mode}")
logger.info(f"工具链执行完成: {chain_name} (ID: {chain_id})")
return result
except Exception as e:
logger.error(f"工具链执行失败: {chain_name} (ID: {chain_id}), 错误: {e}")
return {
"success": False,
"error": str(e),
"chain_id": chain_id,
"completed_steps": context.current_step,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
finally:
# 清理执行上下文
if chain_id in self._running_chains:
del self._running_chains[chain_id]
async def _execute_sequential(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""顺序执行工具链"""
for i, step in enumerate(chain.steps):
context.current_step = i
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
logger.debug(f"跳过步骤 {i}: 条件不满足")
continue
# 准备参数
parameters = self._prepare_parameters(step.parameters, context)
# 执行工具
try:
result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = result
# 处理输出映射
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 处理执行失败
if not result.success:
if step.error_handling == "stop":
context.is_failed = True
context.error_message = result.error
break
elif step.error_handling == "continue":
logger.warning(f"步骤 {i} 执行失败,继续执行: {result.error}")
continue
elif step.error_handling == "retry":
# 简单重试逻辑
retry_result = await self.executor.execute_tool(
tool_id=step.tool_id,
parameters=parameters
)
context.step_results[i] = retry_result
if not retry_result.success and step.error_handling == "stop":
context.is_failed = True
context.error_message = retry_result.error
break
except Exception as e:
logger.error(f"步骤 {i} 执行异常: {e}")
if step.error_handling == "stop":
context.is_failed = True
context.error_message = str(e)
break
context.is_completed = not context.is_failed
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": context.current_step + 1,
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_parallel(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""并行执行工具链"""
# 准备所有步骤的执行配置
execution_configs = []
for i, step in enumerate(chain.steps):
# 检查执行条件
if step.condition and not self._evaluate_condition(step.condition, context):
continue
parameters = self._prepare_parameters(step.parameters, context)
execution_configs.append({
"step_index": i,
"tool_id": step.tool_id,
"parameters": parameters
})
# 并行执行所有步骤
try:
results = await self.executor.execute_tools_batch(execution_configs)
# 处理结果
for i, result in enumerate(results):
step_index = execution_configs[i]["step_index"]
context.step_results[step_index] = result
# 处理输出映射
step = chain.steps[step_index]
if step.output_mapping and result.success:
self._apply_output_mapping(step.output_mapping, result.data, context)
# 检查是否有失败的步骤
failed_steps = [i for i, result in context.step_results.items() if not result.success]
context.is_completed = len(failed_steps) == 0
if failed_steps:
context.error_message = f"步骤 {failed_steps} 执行失败"
except Exception as e:
context.is_failed = True
context.error_message = str(e)
return {
"success": context.is_completed,
"error": context.error_message,
"chain_id": context.chain_id,
"completed_steps": len(context.step_results),
"total_steps": len(chain.steps),
"final_variables": context.variables,
"step_results": {k: self._serialize_result(v) for k, v in context.step_results.items()}
}
async def _execute_conditional(
self,
chain: ChainDefinition,
context: ChainExecutionContext
) -> Dict[str, Any]:
"""条件执行工具链"""
# 条件执行类似于顺序执行,但更严格地检查条件
return await self._execute_sequential(chain, context)
def _validate_chain(self, chain: ChainDefinition) -> tuple[bool, Optional[str]]:
"""验证工具链定义
Args:
chain: 工具链定义
Returns:
(是否有效, 错误信息)
"""
if not chain.name:
return False, "工具链名称不能为空"
if not chain.steps:
return False, "工具链必须包含至少一个步骤"
for i, step in enumerate(chain.steps):
if not step.tool_id:
return False, f"步骤 {i} 缺少工具ID"
if step.error_handling not in ["stop", "continue", "retry"]:
return False, f"步骤 {i} 错误处理策略无效: {step.error_handling}"
return True, None
def _prepare_parameters(
self,
parameters: Dict[str, Any],
context: ChainExecutionContext
) -> Dict[str, Any]:
"""准备参数(支持变量替换)
Args:
parameters: 原始参数
context: 执行上下文
Returns:
处理后的参数
"""
prepared = {}
for key, value in parameters.items():
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
# 变量替换
var_name = value[2:-1]
if var_name in context.variables:
prepared[key] = context.variables[var_name]
else:
prepared[key] = value # 保持原值
else:
prepared[key] = value
return prepared
def _evaluate_condition(
self,
condition: str,
context: ChainExecutionContext
) -> bool:
"""评估执行条件
Args:
condition: 条件表达式
context: 执行上下文
Returns:
条件是否满足
"""
try:
# 简单的条件评估(可以扩展为更复杂的表达式解析)
# 支持格式variable == value, variable != value, variable > value 等
if "==" in condition:
var_name, expected_value = condition.split("==", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) == expected_value
elif "!=" in condition:
var_name, expected_value = condition.split("!=", 1)
var_name = var_name.strip()
expected_value = expected_value.strip().strip('"\'')
return str(context.variables.get(var_name, "")) != expected_value
elif condition in context.variables:
# 简单的布尔检查
return bool(context.variables[condition])
else:
# 默认为真
return True
except Exception as e:
logger.error(f"条件评估失败: {condition}, 错误: {e}")
return False
def _apply_output_mapping(
    self,
    mapping: Dict[str, str],
    output_data: Any,
    context: ChainExecutionContext
):
    """Copy step output into context variables according to ``mapping``.

    For dict output, each ``source_key -> variable_name`` entry copies
    ``output_data[source_key]`` when that key exists. For any other output,
    the whole value is stored under ``mapping["result"]`` if that entry is
    present. Failures are logged and otherwise ignored.

    Args:
        mapping: Output key to variable name mapping.
        output_data: Data produced by the step.
        context: Execution context whose ``variables`` are updated in place.
    """
    try:
        if not isinstance(output_data, dict):
            # Non-dict output can only be captured as a whole.
            if "result" in mapping:
                context.variables[mapping["result"]] = output_data
            return

        for source_key, variable_name in mapping.items():
            if source_key in output_data:
                context.variables[variable_name] = output_data[source_key]
    except Exception as exc:
        logger.error(f"输出映射失败: {exc}")
def _serialize_result(self, result: ToolResult) -> Dict[str, Any]:
    """Convert a tool result into a plain dict.

    Args:
        result: Tool result to serialize.

    Returns:
        Dict with the keys ``success``, ``data``, ``error``, ``error_code``,
        ``execution_time``, ``token_usage`` and ``metadata`` (in that order),
        each holding the attribute of the same name.
    """
    exported_fields = (
        "success",
        "data",
        "error",
        "error_code",
        "execution_time",
        "token_usage",
        "metadata",
    )
    return {field_name: getattr(result, field_name) for field_name in exported_fields}
def get_running_chains(self) -> List[Dict[str, Any]]:
    """Summarize every chain currently tracked in ``self._running_chains``.

    Returns:
        One dict per chain with its id, current step, completed/failed flags,
        the number of context variables and the number of recorded step
        results.
    """
    return [
        {
            "chain_id": chain_id,
            "current_step": ctx.current_step,
            "is_completed": ctx.is_completed,
            "is_failed": ctx.is_failed,
            "variables_count": len(ctx.variables),
            "completed_steps": len(ctx.step_results)
        }
        for chain_id, ctx in self._running_chains.items()
    ]

View File

@@ -0,0 +1,264 @@
"""工具配置管理器 - 管理工具配置的加载和验证"""
import json
from pathlib import Path
from typing import Dict, Any, Optional
from pydantic import BaseModel, ValidationError
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ToolConfigSchema(BaseModel):
    """Base schema shared by all tool configuration files.

    ``name``, ``description`` and ``tool_type`` are required; the remaining
    fields have defaults.
    """
    name: str
    description: str
    tool_type: str
    version: str = "1.0.0"
    enabled: bool = True
    parameters: Dict[str, Any] = {}
    tags: list[str] = []

    class Config:
        # Keys not declared on the schema are kept rather than rejected.
        extra = "allow"
class BuiltinToolConfigSchema(ToolConfigSchema):
    """Schema for built-in tool configuration files."""
    # Name of the implementing tool class, e.g. "DateTimeTool" in the bundled configs.
    tool_class: str
    # Overrides the required base field with a fixed default.
    tool_type: str = "builtin"
class CustomToolConfigSchema(ToolConfigSchema):
    """Schema for custom (OpenAPI-described) tool configuration."""
    # The schema may be given by URL or inline; neither is required here.
    schema_url: Optional[str] = None
    schema_content: Optional[Dict[str, Any]] = None
    auth_type: str = "none"
    auth_config: Dict[str, Any] = {}
    base_url: Optional[str] = None
    # presumably seconds - TODO confirm against the code that consumes it
    timeout: int = 30
    # Overrides the required base field with a fixed default.
    tool_type: str = "custom"
class MCPToolConfigSchema(ToolConfigSchema):
    """Schema for MCP tool configuration."""
    # Required: address of the MCP server.
    server_url: str
    connection_config: Dict[str, Any] = {}
    available_tools: list[str] = []
    # Overrides the required base field with a fixed default.
    tool_type: str = "mcp"
class ConfigManager:
    """Manage tool configuration files stored under a configuration directory.

    Layout used by the methods below: ``<config_dir>/<tool_type>/<tool_name>.json``
    for per-tool files and ``<config_dir>/builtin_tools.json`` for the global
    catalogue of built-in tools.
    """

    def __init__(self, config_dir: Optional[str] = None):
        """Create the manager and make sure the configuration directory exists.

        Args:
            config_dir: Directory holding the configuration files. When omitted,
                the ``configs`` directory next to this module is used.
        """
        self.config_dir = Path(config_dir or self._get_default_config_dir())
        self.config_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"配置管理器初始化完成，配置目录: {self.config_dir}")

    def _get_default_config_dir(self) -> str:
        """Return the default configuration directory (``configs`` beside this file)."""
        tools_dir = Path(__file__).parent
        return str(tools_dir / "configs")

    def _resolve_config_file(self, tool_name: str, tool_type: str) -> Path:
        """Return the path of a tool's config file, refusing path traversal.

        Both ``tool_type`` and the resulting file name must be a single path
        component, so values such as ``"../x"`` cannot address files outside
        ``<config_dir>/<tool_type>/``.

        Raises:
            ValueError: If either value contains a path separator or
                ``tool_type`` is ``".."``.
        """
        file_name = f"{tool_name}.json"
        type_is_plain = tool_type != ".." and Path(tool_type).name == tool_type
        if not type_is_plain or Path(file_name).name != file_name:
            raise ValueError(f"非法的工具配置路径: {tool_type}/{tool_name}")
        return self.config_dir / tool_type / file_name

    @staticmethod
    def _format_error_location(loc) -> str:
        """Render a validation error location as a dotted path.

        An empty location (as produced by model-level validators) is shown as
        ``__root__`` instead of raising ``IndexError``.
        """
        return ".".join(str(part) for part in loc) or "__root__"

    def load_builtin_tool_configs(self) -> Dict[str, BuiltinToolConfigSchema]:
        """Load every ``*.json`` file in ``<config_dir>/builtin``.

        The directory is created when missing. Files that cannot be read or do
        not validate are logged and skipped.

        Returns:
            Mapping from tool name to its validated configuration.
        """
        configs = {}
        builtin_dir = self.config_dir / "builtin"
        if not builtin_dir.exists():
            logger.info("内置工具配置目录不存在，创建默认配置")
            self._create_default_builtin_configs(builtin_dir)
        for config_file in builtin_dir.glob("*.json"):
            try:
                config_data = self._load_config_file(config_file)
                config = BuiltinToolConfigSchema(**config_data)
                configs[config.name] = config
                logger.debug(f"加载内置工具配置: {config.name}")
            except Exception as e:
                # One broken file must not prevent the others from loading.
                logger.error(f"加载内置工具配置失败: {config_file}, 错误: {e}")
        return configs

    def load_builtin_tools_config(self) -> Dict[str, Any]:
        """Load the global ``builtin_tools.json`` catalogue (legacy interface).

        Returns:
            The parsed JSON content, or an empty dict when the file cannot be
            read or parsed (the error is logged).
        """
        config_file = self.config_dir / "builtin_tools.json"
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"加载内置工具配置失败: {e}")
            return {}

    def ensure_builtin_tools_initialized(self, tenant_id, db_session, tool_config_model, builtin_tool_config_model, tool_type_enum, tool_status_enum):
        """Create the tenant's built-in tool records if none exist yet.

        Does nothing when the tenant already has at least one built-in tool
        record. Otherwise one tool record plus one built-in detail record is
        added per entry of ``builtin_tools.json`` and the session is committed.

        Args:
            tenant_id: Tenant identifier.
            db_session: Database session (queried, flushed and committed here).
            tool_config_model: ToolConfig model class.
            builtin_tool_config_model: BuiltinToolConfig model class.
            tool_type_enum: ToolType enum.
            tool_status_enum: ToolStatus enum.

        Raises:
            KeyError: If a catalogue entry lacks ``requires_config``, ``name``,
                ``description`` or ``tool_class``.
        """
        existing_count = db_session.query(tool_config_model).filter(
            tool_config_model.tenant_id == tenant_id,
            tool_config_model.tool_type == tool_type_enum.BUILTIN
        ).count()
        if existing_count > 0:
            return  # already initialized

        builtin_tools = self.load_builtin_tools_config()
        for tool_key, tool_info in builtin_tools.items():
            # Tools that need user-supplied configuration start out inactive.
            initial_status = tool_status_enum.ACTIVE.value if not tool_info['requires_config'] else tool_status_enum.INACTIVE.value
            tool_config = tool_config_model(
                name=tool_info['name'],
                description=tool_info['description'],
                tool_type=tool_type_enum.BUILTIN,
                tenant_id=tenant_id,
                status=initial_status
            )
            db_session.add(tool_config)
            # Flush so tool_config.id is available for the detail record below.
            db_session.flush()
            builtin_config = builtin_tool_config_model(
                id=tool_config.id,
                tool_class=tool_info['tool_class'],
                parameters={}
            )
            db_session.add(builtin_config)
        db_session.commit()
        logger.info(f"租户 {tenant_id} 的内置工具初始化完成")

    def save_tool_config(self, config: ToolConfigSchema, tool_type: str) -> bool:
        """Write a tool configuration to ``<config_dir>/<tool_type>/<name>.json``.

        Args:
            config: Configuration to persist.
            tool_type: Tool type, used as the sub-directory name.

        Returns:
            ``True`` on success; ``False`` when writing failed or the target
            path would leave the tool-type directory (the error is logged).
        """
        try:
            config_file = self._resolve_config_file(config.name, tool_type)
            config_file.parent.mkdir(parents=True, exist_ok=True)
            config_data = config.model_dump()
            with open(config_file, 'w', encoding='utf-8') as f:
                json.dump(config_data, f, indent=2, ensure_ascii=False)
            logger.info(f"工具配置保存成功: {config.name} ({tool_type})")
            return True
        except Exception as e:
            logger.error(f"工具配置保存失败: {config.name}, 错误: {e}")
            return False

    def delete_tool_config(self, tool_name: str, tool_type: str) -> bool:
        """Delete ``<config_dir>/<tool_type>/<tool_name>.json``.

        Args:
            tool_name: Tool name.
            tool_type: Tool type.

        Returns:
            ``True`` when the file existed and was removed; ``False`` when it
            does not exist, the path is rejected, or deletion failed.
        """
        try:
            config_file = self._resolve_config_file(tool_name, tool_type)
            if config_file.exists():
                config_file.unlink()
                logger.info(f"工具配置删除成功: {tool_name} ({tool_type})")
                return True
            else:
                logger.warning(f"工具配置文件不存在: {tool_name} ({tool_type})")
                return False
        except Exception as e:
            logger.error(f"工具配置删除失败: {tool_name}, 错误: {e}")
            return False

    def validate_config(self, config_data: Dict[str, Any], tool_type: str) -> tuple[bool, Optional[str]]:
        """Validate raw configuration data against the schema for ``tool_type``.

        Args:
            config_data: Configuration data.
            tool_type: One of ``"builtin"``, ``"custom"`` or ``"mcp"``.

        Returns:
            ``(True, None)`` when valid, otherwise ``(False, message)``. For
            schema violations the message lists every failing field with its
            full dotted path.
        """
        try:
            schema_map = {
                "builtin": BuiltinToolConfigSchema,
                "custom": CustomToolConfigSchema,
                "mcp": MCPToolConfigSchema
            }
            schema_class = schema_map.get(tool_type)
            if not schema_class:
                return False, f"不支持的工具类型: {tool_type}"
            schema_class(**config_data)
            return True, None
        except ValidationError as e:
            error_msg = "; ".join(
                f"{self._format_error_location(err['loc'])}: {err['msg']}"
                for err in e.errors()
            )
            return False, f"配置验证失败: {error_msg}"
        except Exception as e:
            return False, f"配置验证异常: {str(e)}"

    def _load_config_file(self, config_file: Path) -> Dict[str, Any]:
        """Read and parse one JSON configuration file.

        Args:
            config_file: Path of the file.

        Returns:
            The parsed content.

        Raises:
            Exception: Whatever opening or parsing raised, after logging it.
        """
        try:
            with open(config_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except Exception as e:
            logger.error(f"加载配置文件失败: {config_file}, 错误: {e}")
            raise

    def _create_default_builtin_configs(self, builtin_dir: Path):
        """Ensure the built-in configuration directory exists.

        No files are written here; only the directory is created.

        Args:
            builtin_dir: Built-in tool configuration directory.
        """
        builtin_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"内置工具配置目录已创建: {builtin_dir}")

View File

@@ -0,0 +1,14 @@
{
"name": "baidu_search_tool",
"description": "百度搜索工具 - 网络搜索:提供网页搜索、新闻搜索、图片搜索功能",
"tool_type": "builtin",
"tool_class": "BaiduSearchTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"secret_key": "",
"search_type": "web"
},
"tags": ["search", "web", "baidu", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "datetime_tool",
"description": "时间工具 - 日期时间处理:提供时间格式转化、时区转换、时间戳转换、时间计算",
"tool_type": "builtin",
"tool_class": "DateTimeTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"timezone": "UTC"
},
"tags": ["time", "utility", "builtin"]
}

View File

@@ -0,0 +1,12 @@
{
"name": "json_tool",
"description": "JSON工具 - 数据格式处理:提供JSON格式化、压缩、验证、格式转换",
"tool_type": "builtin",
"tool_class": "JsonTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"indent": 2
},
"tags": ["json", "data", "utility", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "mineru_tool",
"description": "MinerU PDF解析工具 - 文档处理:提供PDF解析、表格提取、图片识别、文本提取功能",
"tool_type": "builtin",
"tool_class": "MinerUTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": "",
"parse_mode": "auto",
"timeout": 60
},
"tags": ["pdf", "document", "ocr", "builtin"]
}

View File

@@ -0,0 +1,14 @@
{
"name": "textin_tool",
"description": "TextIn OCR工具 - 图像识别:提供通用OCR、手写识别、多语言支持功能",
"tool_type": "builtin",
"tool_class": "TextInTool",
"version": "1.0.0",
"enabled": true,
"parameters": {
"app_id": "",
"language": "auto",
"recognition_mode": "general"
},
"tags": ["ocr", "image", "text", "builtin"]
}

View File

@@ -0,0 +1,60 @@
{
"datetime": {
"name": "时间工具",
"description": "获取当前时间、日期计算",
"tool_class": "DateTimeTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"json_converter": {
"name": "JSON转换工具",
"description": "JSON数据格式化和转换",
"tool_class": "JsonTool",
"category": "utility",
"requires_config": false,
"version": "1.0.0",
"enabled": true,
"parameters": {}
},
"baidu_search": {
"name": "百度搜索",
"description": "百度网页搜索服务",
"tool_class": "BaiduSearchTool",
"category": "search",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "百度搜索API密钥", "sensitive": true, "required": true}
}
},
"mineru": {
"name": "MinerU",
"description": "PDF文档解析工具",
"tool_class": "MinerUTool",
"category": "document",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "MinerU API密钥", "sensitive": true, "required": true},
"base_url": {"type": "string", "description": "API地址", "default": "https://api.mineru.com"}
}
},
"textin": {
"name": "TextIn",
"description": "OCR文字识别服务",
"tool_class": "TextInTool",
"category": "ocr",
"requires_config": true,
"version": "1.0.0",
"enabled": true,
"parameters": {
"api_key": {"type": "string", "description": "TextIn API密钥", "sensitive": true, "required": true},
"api_secret": {"type": "string", "description": "TextIn API Secret", "sensitive": true, "required": true}
}
}
}

View File

@@ -0,0 +1,11 @@
"""自定义工具模块"""
from .base import CustomTool
from .schema_parser import OpenAPISchemaParser
from .auth_manager import AuthManager
__all__ = [
"CustomTool",
"OpenAPISchemaParser",
"AuthManager"
]

View File

@@ -0,0 +1,525 @@
"""认证管理器 - 处理自定义工具的认证配置"""
import base64
import hashlib
import hmac
import time
from typing import Dict, Any, Tuple
from urllib.parse import quote
import aiohttp
from app.models.tool_model import AuthType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class AuthManager:
    """Validate, apply, store and test authentication settings for custom tools.

    Supported schemes: no authentication, API key (sent as header, query
    parameter or cookie) and bearer token.
    """

    def __init__(self):
        """Record the authentication types this manager accepts."""
        self.supported_auth_types = [
            AuthType.NONE,
            AuthType.API_KEY,
            AuthType.BEARER_TOKEN
        ]

    @staticmethod
    def _auth_type_label(auth_type) -> str:
        """Return a printable label for ``auth_type`` without raising.

        Enum members yield their ``value``; a falsy input yields ``"unknown"``;
        anything else is returned as given. Used when building result dicts so
        that an unexpected ``auth_type`` cannot raise from an error handler.
        """
        if not auth_type:
            return "unknown"
        return getattr(auth_type, "value", auth_type)

    @staticmethod
    def _redact_appended_credential(original_url: str, authed_url: str) -> str:
        """Return ``authed_url`` with a credential appended by auth hidden.

        Query-string API keys are appended to the URL as ``<sep>name=value``;
        the value is replaced with ``***``. If the URL was changed in any other
        way, the original URL is returned so that nothing secret is exposed.
        """
        if authed_url == original_url:
            return authed_url
        if authed_url.startswith(original_url):
            name, separator, _secret = authed_url[len(original_url):].partition("=")
            if separator:
                return f"{original_url}{name}{separator}***"
        return original_url

    def validate_auth_config(self, auth_type: AuthType, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
        """Validate an authentication configuration.

        Args:
            auth_type: Authentication type.
            auth_config: Authentication configuration.

        Returns:
            ``(is_valid, error_message)``; the message is empty when valid.
        """
        try:
            if auth_type not in self.supported_auth_types:
                return False, f"不支持的认证类型: {auth_type}"
            if auth_type == AuthType.NONE:
                return True, ""
            elif auth_type == AuthType.API_KEY:
                return self._validate_api_key_config(auth_config)
            elif auth_type == AuthType.BEARER_TOKEN:
                return self._validate_bearer_token_config(auth_config)
            return False, "未知的认证类型"
        except Exception as e:
            return False, f"验证认证配置时出错: {e}"

    def _validate_api_key_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
        """Validate an API key configuration.

        Requires a non-empty string ``api_key``; ``key_name`` (default
        ``X-API-Key``) must be a string and ``location`` (default ``header``)
        must be ``header``, ``query`` or ``cookie``.

        Returns:
            ``(is_valid, error_message)``.
        """
        api_key = auth_config.get("api_key")
        if not api_key:
            return False, "API Key不能为空"
        if not isinstance(api_key, str):
            return False, "API Key必须是字符串"
        key_name = auth_config.get("key_name", "X-API-Key")
        if not isinstance(key_name, str):
            return False, "API Key名称必须是字符串"
        key_location = auth_config.get("location", "header")
        if key_location not in ["header", "query", "cookie"]:
            return False, "API Key位置必须是 header、query 或 cookie"
        return True, ""

    def _validate_bearer_token_config(self, auth_config: Dict[str, Any]) -> Tuple[bool, str]:
        """Validate a bearer token configuration (non-empty string ``token``).

        Returns:
            ``(is_valid, error_message)``.
        """
        token = auth_config.get("token")
        if not token:
            return False, "Bearer Token不能为空"
        if not isinstance(token, str):
            return False, "Bearer Token必须是字符串"
        return True, ""

    def apply_authentication(
        self,
        auth_type: AuthType,
        auth_config: Dict[str, Any],
        url: str,
        headers: Dict[str, str],
        params: Dict[str, Any]
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Attach credentials to a request.

        The passed ``headers`` dict is modified in place by the scheme-specific
        helpers and is also returned. Unsupported types and errors are logged
        and the request parts are returned as they are at that point.

        Args:
            auth_type: Authentication type.
            auth_config: Authentication configuration.
            url: Request URL.
            headers: Request headers.
            params: Request parameters (returned unchanged).

        Returns:
            ``(url, headers, params)`` after applying authentication.
        """
        try:
            if auth_type == AuthType.NONE:
                return url, headers, params
            elif auth_type == AuthType.API_KEY:
                return self._apply_api_key_auth(auth_config, url, headers, params)
            elif auth_type == AuthType.BEARER_TOKEN:
                return self._apply_bearer_token_auth(auth_config, url, headers, params)
            else:
                logger.warning(f"不支持的认证类型: {auth_type}")
                return url, headers, params
        except Exception as e:
            logger.error(f"应用认证时出错: {e}")
            return url, headers, params

    def _apply_api_key_auth(
        self,
        auth_config: Dict[str, Any],
        url: str,
        headers: Dict[str, str],
        params: Dict[str, Any]
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Attach an API key as a header, query parameter or cookie.

        Args:
            auth_config: Configuration with ``api_key`` and optional
                ``key_name`` / ``location``.
            url: Request URL.
            headers: Request headers (modified in place).
            params: Request parameters (returned unchanged).

        Returns:
            ``(url, headers, params)`` after applying the key.
        """
        api_key = auth_config.get("api_key")
        key_name = auth_config.get("key_name", "X-API-Key")
        location = auth_config.get("location", "header")
        if location == "header":
            headers[key_name] = api_key
        elif location == "query":
            # Append to the URL's query string, percent-encoding the key value.
            separator = "&" if "?" in url else "?"
            encoded_key = quote(str(api_key))
            url += f"{separator}{key_name}={encoded_key}"
        elif location == "cookie":
            # Extend an existing Cookie header instead of overwriting it.
            cookie_value = f"{key_name}={api_key}"
            if "Cookie" in headers:
                headers["Cookie"] += f"; {cookie_value}"
            else:
                headers["Cookie"] = cookie_value
        return url, headers, params

    def _apply_bearer_token_auth(
        self,
        auth_config: Dict[str, Any],
        url: str,
        headers: Dict[str, str],
        params: Dict[str, Any]
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Set the ``Authorization: Bearer <token>`` header.

        Returns:
            ``(url, headers, params)`` with the header added.
        """
        token = auth_config.get("token")
        headers["Authorization"] = f"Bearer {token}"
        return url, headers, params

    def encrypt_auth_config(self, auth_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
        """Protect the sensitive fields of an authentication configuration.

        Each non-empty string among ``api_key``, ``token``, ``secret`` and
        ``password`` is replaced via ``_encrypt_string`` and flagged with
        ``<field>_encrypted = True``. See ``_encrypt_string`` for the (weak)
        protection this provides.

        Args:
            auth_config: Authentication configuration (not modified).
            encryption_key: Key material.

        Returns:
            A processed shallow copy, or the input itself if processing failed.
        """
        try:
            encrypted_config = auth_config.copy()
            sensitive_fields = ["api_key", "token", "secret", "password"]
            for field in sensitive_fields:
                if field in encrypted_config:
                    value = encrypted_config[field]
                    if isinstance(value, str) and value:
                        encrypted_value = self._encrypt_string(value, encryption_key)
                        encrypted_config[field] = encrypted_value
                        encrypted_config[f"{field}_encrypted"] = True
            return encrypted_config
        except Exception as e:
            logger.error(f"加密认证配置失败: {e}")
            return auth_config

    def decrypt_auth_config(self, encrypted_config: Dict[str, Any], encryption_key: str) -> Dict[str, Any]:
        """Reverse ``encrypt_auth_config``.

        Only fields flagged with ``<field>_encrypted`` are processed; the flag
        is removed afterwards.

        Args:
            encrypted_config: Configuration produced by ``encrypt_auth_config``.
            encryption_key: Key material used when it was produced.

        Returns:
            A restored shallow copy, or the input itself if processing failed.
        """
        try:
            decrypted_config = encrypted_config.copy()
            sensitive_fields = ["api_key", "token", "secret", "password"]
            for field in sensitive_fields:
                if field in decrypted_config and decrypted_config.get(f"{field}_encrypted"):
                    encrypted_value = decrypted_config[field]
                    if isinstance(encrypted_value, str) and encrypted_value:
                        decrypted_value = self._decrypt_string(encrypted_value, encryption_key)
                        decrypted_config[field] = decrypted_value
                    decrypted_config.pop(f"{field}_encrypted", None)
            return decrypted_config
        except Exception as e:
            logger.error(f"解密认证配置失败: {e}")
            return encrypted_config

    def _encrypt_string(self, value: str, key: str) -> str:
        """Encode a string together with an HMAC-SHA256 tag.

        NOTE(review): despite the name this is NOT encryption. The result is
        ``base64("<value>:<hmac hex>")``, so the plaintext can be recovered by
        anyone without the key; the key only allows tamper detection. Replace
        with real authenticated encryption before relying on it to keep
        secrets confidential.

        Args:
            value: String to protect.
            key: Key used for the HMAC.

        Returns:
            The Base64-encoded string, or ``value`` unchanged on failure.
        """
        try:
            key_bytes = key.encode('utf-8')
            value_bytes = value.encode('utf-8')
            hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
            signature = hmac_obj.hexdigest()
            combined = f"{value}:{signature}"
            encrypted = base64.b64encode(combined.encode('utf-8')).decode('utf-8')
            return encrypted
        except Exception as e:
            logger.error(f"加密字符串失败: {e}")
            return value

    def _decrypt_string(self, encrypted_value: str, key: str) -> str:
        """Decode a string produced by ``_encrypt_string`` and verify its tag.

        Args:
            encrypted_value: Encoded string.
            key: Key used when it was encoded.

        Returns:
            The original value when the HMAC matches; otherwise the input is
            returned unchanged (also when it is not in the expected format).
        """
        try:
            decoded = base64.b64decode(encrypted_value.encode('utf-8')).decode('utf-8')
            if ':' not in decoded:
                return encrypted_value  # probably not a protected value
            # The value itself may contain ':'; the tag is the last segment.
            value, signature = decoded.rsplit(':', 1)
            key_bytes = key.encode('utf-8')
            value_bytes = value.encode('utf-8')
            hmac_obj = hmac.new(key_bytes, value_bytes, hashlib.sha256)
            expected_signature = hmac_obj.hexdigest()
            # Constant-time comparison; bytes are used because compare_digest
            # rejects non-ASCII str input.
            if hmac.compare_digest(signature.encode('utf-8'), expected_signature.encode('utf-8')):
                return value
            else:
                logger.warning("解密时签名验证失败")
                return encrypted_value
        except Exception as e:
            logger.error(f"解密字符串失败: {e}")
            return encrypted_value

    def test_authentication(
        self,
        auth_type: AuthType,
        auth_config: Dict[str, Any],
        test_url: str = None
    ) -> Dict[str, Any]:
        """Check an authentication configuration without sending a request.

        Args:
            auth_type: Authentication type.
            auth_config: Authentication configuration.
            test_url: Optional URL; when given, authentication is applied to a
                sample request and a credential-free view of it is returned.

        Returns:
            Result dict with ``success`` and ``auth_type`` plus either
            ``error`` or ``message``. With ``test_url`` it also contains
            ``test_url`` (query credential shown as ``***``), ``headers``
            (only headers not set or altered by authentication) and
            ``has_auth_header``.
        """
        try:
            is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
            if not is_valid:
                return {
                    "success": False,
                    "error": error_msg,
                    "auth_type": self._auth_type_label(auth_type)
                }
            if not test_url:
                return {
                    "success": True,
                    "message": "认证配置有效",
                    "auth_type": auth_type.value
                }
            base_headers = {"User-Agent": "AuthManager-Test/1.0"}
            authed_url, authed_headers, _params = self.apply_authentication(
                auth_type, auth_config, test_url, dict(base_headers), {}
            )
            # Report only what carries no credential: API keys may live in a
            # custom header, the Cookie header or the query string, not just
            # in "Authorization".
            visible_headers = {
                name: value
                for name, value in authed_headers.items()
                if base_headers.get(name) == value
            }
            return {
                "success": True,
                "message": "认证配置测试成功",
                "auth_type": auth_type.value,
                "test_url": self._redact_appended_credential(test_url, authed_url),
                "headers": visible_headers,
                "has_auth_header": "Authorization" in authed_headers
            }
        except Exception as e:
            return {
                "success": False,
                "error": str(e),
                "auth_type": self._auth_type_label(auth_type)
            }

    async def test_authentication_with_request(
        self,
        auth_type: AuthType,
        auth_config: Dict[str, Any],
        test_url: str,
        timeout: int = 10
    ) -> Dict[str, Any]:
        """Check an authentication configuration by sending a real GET request.

        Status 200 is reported as success and 401/403 as authentication
        failures; any other status is reported as a successful request with
        the status code included.

        Args:
            auth_type: Authentication type.
            auth_config: Authentication configuration.
            test_url: URL to request.
            timeout: Total request timeout in seconds.

        Returns:
            Result dict with ``success``, ``auth_type`` and either ``error`` or
            ``message`` (plus ``status_code`` when a response was received).
        """
        try:
            is_valid, error_msg = self.validate_auth_config(auth_type, auth_config)
            if not is_valid:
                return {
                    "success": False,
                    "error": error_msg,
                    "auth_type": self._auth_type_label(auth_type)
                }
            headers = {"User-Agent": "AuthManager-Test/1.0"}
            params = {}
            test_url, headers, params = self.apply_authentication(
                auth_type, auth_config, test_url, headers, params
            )
            client_timeout = aiohttp.ClientTimeout(total=timeout)
            async with aiohttp.ClientSession(timeout=client_timeout) as session:
                async with session.get(test_url, headers=headers) as response:
                    status_code = response.status
                    if status_code == 200:
                        return {
                            "success": True,
                            "message": "认证测试成功",
                            "status_code": status_code,
                            "auth_type": auth_type.value
                        }
                    elif status_code == 401:
                        return {
                            "success": False,
                            "error": "认证失败 - 401 Unauthorized",
                            "status_code": status_code,
                            "auth_type": auth_type.value
                        }
                    elif status_code == 403:
                        return {
                            "success": False,
                            "error": "认证失败 - 403 Forbidden",
                            "status_code": status_code,
                            "auth_type": auth_type.value
                        }
                    else:
                        return {
                            "success": True,
                            "message": f"请求成功，状态码: {status_code}",
                            "status_code": status_code,
                            "auth_type": auth_type.value
                        }
        except aiohttp.ClientError as e:
            return {
                "success": False,
                "error": f"网络请求失败: {e}",
                "auth_type": self._auth_type_label(auth_type)
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"测试认证时出错: {e}",
                "auth_type": self._auth_type_label(auth_type)
            }

    def get_auth_config_template(self, auth_type: AuthType) -> Dict[str, Any]:
        """Return an empty configuration template for ``auth_type``.

        Args:
            auth_type: Authentication type.

        Returns:
            Template dict; empty for ``NONE`` and for unknown types.
        """
        templates = {
            AuthType.NONE: {},
            AuthType.API_KEY: {
                "api_key": "",
                "key_name": "X-API-Key",
                "location": "header",  # header, query, cookie
                "description": "API Key认证配置"
            },
            AuthType.BEARER_TOKEN: {
                "token": "",
                "description": "Bearer Token认证配置"
            }
        }
        return templates.get(auth_type, {})

    def mask_sensitive_config(self, auth_config: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of the configuration with secrets masked for display.

        Strings longer than four characters keep their first and last two
        characters (``ab***yz``); shorter non-empty strings become ``***``.

        Args:
            auth_config: Authentication configuration (not modified).

        Returns:
            Shallow copy with ``api_key``, ``token``, ``secret`` and
            ``password`` masked.
        """
        masked_config = auth_config.copy()
        sensitive_fields = ["api_key", "token", "secret", "password"]
        for field in sensitive_fields:
            if field in masked_config:
                value = masked_config[field]
                if isinstance(value, str) and len(value) > 4:
                    masked_config[field] = f"{value[:2]}***{value[-2:]}"
                elif isinstance(value, str) and value:
                    masked_config[field] = "***"
        return masked_config

View File

@@ -0,0 +1,318 @@
"""自定义工具基类"""
import time
from typing import Dict, Any, List, Optional
import aiohttp
from urllib.parse import urljoin
from app.models.tool_model import ToolType, AuthType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class CustomTool(BaseTool):
    """Custom tool backed by an OpenAPI schema.

    Each operation found in the schema's ``paths`` can be invoked through
    ``execute``; requests are sent with aiohttp.
    """

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the tool from its configuration.

        Args:
            tool_id: Tool identifier.
            config: Tool configuration; reads ``schema_content``,
                ``schema_url``, ``auth_type``, ``auth_config``, ``base_url``
                and ``timeout``.
        """
        super().__init__(tool_id, config)
        self.schema_content = config.get("schema_content", {})
        self.schema_url = config.get("schema_url")
        self.auth_type = AuthType(config.get("auth_type", "none"))
        self.auth_config = config.get("auth_config", {})
        self.base_url = config.get("base_url", "")
        self.timeout = config.get("timeout", 30)
        # Parse once; the result maps operation id -> operation description.
        self._parsed_operations = self._parse_openapi_schema()

    @property
    def name(self) -> str:
        """Tool name: the schema's ``info.title`` or a name derived from the id."""
        if self.schema_content:
            info = self.schema_content.get("info", {})
            return info.get("title", f"custom_tool_{self.tool_id[:8]}")
        return f"custom_tool_{self.tool_id[:8]}"

    @property
    def description(self) -> str:
        """Tool description: the schema's ``info.description`` or a default."""
        if self.schema_content:
            info = self.schema_content.get("info", {})
            return info.get("description", "自定义API工具")
        return "自定义API工具"

    @property
    def tool_type(self) -> ToolType:
        """Tool type (always ``ToolType.CUSTOM``)."""
        return ToolType.CUSTOM

    @property
    def parameters(self) -> List[ToolParameter]:
        """Parameter definitions exposed by the tool.

        When the schema has more than one operation, a required ``operation``
        selector is listed first. The remaining entries are the declared
        parameters of the first operation only; request-body properties are
        not listed.
        """
        params = []
        if len(self._parsed_operations) > 1:
            params.append(ToolParameter(
                name="operation",
                type=ParameterType.STRING,
                description="要执行的操作",
                required=True,
                enum=list(self._parsed_operations.keys())
            ))
        if self._parsed_operations:
            first_operation = next(iter(self._parsed_operations.values()))
            for param_name, param_info in first_operation.get("parameters", {}).items():
                params.append(ToolParameter(
                    name=param_name,
                    type=self._convert_openapi_type(param_info.get("type", "string")),
                    description=param_info.get("description", ""),
                    required=param_info.get("required", False),
                    default=param_info.get("default"),
                    enum=param_info.get("enum"),
                    minimum=param_info.get("minimum"),
                    maximum=param_info.get("maximum"),
                    pattern=param_info.get("pattern")
                ))
        return params

    async def execute(self, **kwargs) -> ToolResult:
        """Run one operation of the tool.

        Args:
            **kwargs: ``operation`` selects the operation (optional when the
                schema defines exactly one); the other entries supply path,
                query and body values.

        Returns:
            A success result carrying the response, or an error result with
            code ``CUSTOM_TOOL_ERROR``; both include the execution time.
        """
        start_time = time.time()
        try:
            operation_name = kwargs.get("operation")
            if not operation_name and len(self._parsed_operations) == 1:
                operation_name = next(iter(self._parsed_operations.keys()))
            if not operation_name or operation_name not in self._parsed_operations:
                raise ValueError(f"无效的操作: {operation_name}")
            operation = self._parsed_operations[operation_name]
            url = self._build_request_url(operation, kwargs)
            headers = self._build_request_headers(operation)
            data = self._build_request_data(operation, kwargs)
            result = await self._send_http_request(
                method=operation["method"],
                url=url,
                headers=headers,
                data=data
            )
            execution_time = time.time() - start_time
            return ToolResult.success_result(
                data=result,
                execution_time=execution_time
            )
        except Exception as e:
            execution_time = time.time() - start_time
            return ToolResult.error_result(
                error=str(e),
                error_code="CUSTOM_TOOL_ERROR",
                execution_time=execution_time
            )

    def _parse_openapi_schema(self) -> Dict[str, Any]:
        """Extract the operations declared in ``schema_content['paths']``.

        Only get/post/put/delete/patch entries are considered. ``$ref``
        references are not resolved.

        Returns:
            Mapping from operation id (``operationId`` or a name derived from
            method and path) to a dict with ``method``, ``path``, ``summary``,
            ``description``, ``parameters`` and ``request_body``.
        """
        operations = {}
        if not self.schema_content:
            return operations
        paths = self.schema_content.get("paths", {})
        for path, path_item in paths.items():
            if not isinstance(path_item, dict):
                continue
            for method, operation in path_item.items():
                if method.lower() not in ["get", "post", "put", "delete", "patch"]:
                    continue
                if not isinstance(operation, dict):
                    continue
                operation_id = operation.get("operationId", f"{method}_{path.replace('/', '_')}")
                parameters = {}
                for param in operation.get("parameters", []):
                    if not isinstance(param, dict):
                        continue
                    param_name = param.get("name")
                    if not param_name:
                        continue
                    param_schema = param.get("schema", {})
                    if not isinstance(param_schema, dict):
                        param_schema = {}
                    # Schema keywords (default, enum, minimum, ...) come first
                    # so that the parameter-level description/required below
                    # cannot be overwritten by same-named schema keys.
                    parameters[param_name] = {
                        **param_schema,
                        "type": param_schema.get("type", "string"),
                        "description": param.get("description", param_schema.get("description", "")),
                        "required": param.get("required", False),
                        "in": param.get("in", "query")
                    }
                request_body = None
                if "requestBody" in operation:
                    content = operation["requestBody"].get("content", {})
                    if "application/json" in content:
                        request_body = content["application/json"].get("schema", {})
                operations[operation_id] = {
                    "method": method.upper(),
                    "path": path,
                    "summary": operation.get("summary", ""),
                    "description": operation.get("description", ""),
                    "parameters": parameters,
                    "request_body": request_body
                }
        return operations

    def _convert_openapi_type(self, openapi_type: str) -> ParameterType:
        """Map an OpenAPI type name to ``ParameterType`` (unknown -> STRING)."""
        type_mapping = {
            "string": ParameterType.STRING,
            "integer": ParameterType.INTEGER,
            "number": ParameterType.NUMBER,
            "boolean": ParameterType.BOOLEAN,
            "array": ParameterType.ARRAY,
            "object": ParameterType.OBJECT
        }
        return type_mapping.get(openapi_type, ParameterType.STRING)

    def _build_request_url(self, operation: Dict[str, Any], params: Dict[str, Any]) -> str:
        """Build the full request URL for an operation.

        Path placeholders are filled with percent-encoded values, the result
        is appended to ``base_url`` (or the schema's first server URL) and the
        supplied query parameters are added. An API key configured with
        ``location == "query"`` is added to the query string as well.

        Args:
            operation: Parsed operation description.
            params: Values supplied by the caller.

        Returns:
            The request URL.
        """
        from urllib.parse import quote, urlencode

        declared = operation.get("parameters", {})
        path = operation["path"]
        for param_name, param_info in declared.items():
            if param_info.get("in") == "path" and param_name in params:
                # safe="" so that "/" or "?" inside a value cannot alter the path.
                encoded = quote(str(params[param_name]), safe="")
                path = path.replace(f"{{{param_name}}}", encoded)

        base_url = self.base_url
        if not base_url:
            servers = self.schema_content.get("servers", [])
            base_url = servers[0].get("url", "") if servers else ""
        if base_url:
            # Plain concatenation: urljoin would drop the last path segment of
            # a base such as "https://host/v1" when it has no trailing slash.
            url = f"{base_url.rstrip('/')}/{path.lstrip('/')}"
        else:
            url = path

        query_params = {
            param_name: params[param_name]
            for param_name, param_info in declared.items()
            if param_info.get("in") == "query" and param_name in params
        }
        if self.auth_type == AuthType.API_KEY and self.auth_config.get("location", "header") == "query":
            api_key = self.auth_config.get("api_key")
            if api_key:
                query_params[self.auth_config.get("key_name", "X-API-Key")] = api_key
        if query_params:
            separator = "&" if "?" in url else "?"
            url += separator + urlencode(query_params)
        return url

    def _build_request_headers(self, operation: Dict[str, Any]) -> Dict[str, str]:
        """Build the request headers, including header/cookie credentials.

        An API key is sent as a header unless its configured ``location`` is
        ``cookie`` (sent in the ``Cookie`` header) or ``query`` (added to the
        URL by ``_build_request_url``).

        Args:
            operation: Parsed operation description (currently unused).

        Returns:
            Header dict.
        """
        headers = {
            "Content-Type": "application/json",
            "User-Agent": "CustomTool/1.0"
        }
        if self.auth_type == AuthType.API_KEY:
            api_key = self.auth_config.get("api_key")
            key_name = self.auth_config.get("key_name", "X-API-Key")
            location = self.auth_config.get("location", "header")
            if api_key:
                if location == "cookie":
                    headers["Cookie"] = f"{key_name}={api_key}"
                elif location != "query":
                    headers[key_name] = api_key
        elif self.auth_type == AuthType.BEARER_TOKEN:
            token = self.auth_config.get("token")
            if token:
                headers["Authorization"] = f"Bearer {token}"
        return headers

    def _build_request_data(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Collect the JSON body for POST/PUT/PATCH operations.

        Only top-level properties declared in the operation's JSON request
        body schema are taken from ``params``.

        Returns:
            The body dict, or ``None`` when there is nothing to send.
        """
        if operation["method"] in ["POST", "PUT", "PATCH"]:
            request_body = operation.get("request_body")
            if request_body:
                properties = request_body.get("properties", {})
                data = {
                    prop_name: params[prop_name]
                    for prop_name in properties
                    if prop_name in params
                }
                return data if data else None
        return None

    async def _send_http_request(
        self,
        method: str,
        url: str,
        headers: Dict[str, str],
        data: Optional[Dict[str, Any]] = None
    ) -> Any:
        """Send the HTTP request and return the decoded response.

        Args:
            method: HTTP method.
            url: Request URL.
            headers: Request headers.
            data: JSON body (sent only for POST/PUT/PATCH).

        Returns:
            The parsed JSON body, or the raw text when it is not JSON.

        Raises:
            Exception: When the response status is 400 or above.
        """
        timeout = aiohttp.ClientTimeout(total=self.timeout)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            kwargs = {
                "headers": headers
            }
            if data and method in ["POST", "PUT", "PATCH"]:
                kwargs["json"] = data
            async with session.request(method, url, **kwargs) as response:
                if response.status >= 400:
                    error_text = await response.text()
                    raise Exception(f"HTTP {response.status}: {error_text}")
                try:
                    return await response.json()
                except Exception:
                    # Not JSON (or wrong content type): fall back to the text.
                    return await response.text()

    @classmethod
    def from_url(cls, schema_url: str, auth_config: Dict[str, Any], tool_id: Optional[str] = None) -> 'CustomTool':
        """Create a tool that refers to an OpenAPI schema by URL.

        The schema is not downloaded here, so the returned tool has no
        operations until ``schema_content`` is provided.

        Args:
            schema_url: URL of the OpenAPI schema.
            auth_config: Authentication configuration; its ``type`` entry
                selects the authentication type.
            tool_id: Tool identifier; a random UUID is generated when omitted.
        """
        import uuid
        if not tool_id:
            tool_id = str(uuid.uuid4())
        config = {
            "schema_url": schema_url,
            "auth_config": auth_config,
            "auth_type": auth_config.get("type", "none")
        }
        return cls(tool_id, config)

    @classmethod
    def from_schema(cls, schema_dict: Dict[str, Any], auth_config: Dict[str, Any], tool_id: Optional[str] = None) -> 'CustomTool':
        """Create a tool from an already loaded OpenAPI schema dict.

        Args:
            schema_dict: OpenAPI schema.
            auth_config: Authentication configuration; its ``type`` entry
                selects the authentication type.
            tool_id: Tool identifier; a random UUID is generated when omitted.
        """
        import uuid
        if not tool_id:
            tool_id = str(uuid.uuid4())
        config = {
            "schema_content": schema_dict,
            "auth_config": auth_config,
            "auth_type": auth_config.get("type", "none")
        }
        return cls(tool_id, config)

View File

@@ -0,0 +1,477 @@
"""OpenAPI Schema解析器"""
import json
import yaml
from typing import Dict, Any, List, Optional, Tuple
from urllib.parse import urlparse
import aiohttp
import asyncio
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class OpenAPISchemaParser:
"""OpenAPI Schema解析器 - 解析OpenAPI 3.0规范"""
def __init__(self):
    """Set up the parser with the OpenAPI versions it accepts."""
    releases_3_0 = [f"3.0.{patch}" for patch in range(4)]
    self.supported_versions = [*releases_3_0, "3.1.0"]
async def parse_from_url(self, schema_url: str, timeout: int = 30) -> Tuple[bool, Dict[str, Any], str]:
    """Download an OpenAPI schema from a URL, then parse and validate it.

    Args:
        schema_url: Location of the schema; must have a scheme and a host.
        timeout: Total request timeout in seconds.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict; on success the message is empty.
    """
    try:
        target = urlparse(schema_url)
        if not (target.scheme and target.netloc):
            return False, {}, "无效的URL格式"

        session_timeout = aiohttp.ClientTimeout(total=timeout)
        async with aiohttp.ClientSession(timeout=session_timeout) as session:
            async with session.get(schema_url) as response:
                if response.status != 200:
                    return False, {}, f"HTTP错误: {response.status}"

                # The declared content type selects the parser (JSON / YAML).
                declared_type = response.headers.get('content-type', '').lower()
                body = await response.text()

                parsed = self._parse_content(body, declared_type)
                if not parsed:
                    return False, {}, "无法解析schema内容"

                valid, reason = self.validate_schema(parsed)
                if valid:
                    return True, parsed, ""
                return False, {}, reason
    except asyncio.TimeoutError:
        return False, {}, "请求超时"
    except Exception as exc:
        logger.error(f"从URL解析schema失败: {schema_url}, 错误: {exc}")
        return False, {}, str(exc)
def parse_from_content(self, content: str, content_type: str = "application/json") -> Tuple[bool, Dict[str, Any], str]:
    """Parse and validate an OpenAPI schema supplied as text.

    Args:
        content: Schema text.
        content_type: Content type used to choose the parser.

    Returns:
        ``(success, schema, error_message)``. On failure the schema is an
        empty dict; on success the message is empty.
    """
    try:
        parsed = self._parse_content(content, content_type)
        if parsed:
            valid, reason = self.validate_schema(parsed)
            return (True, parsed, "") if valid else (False, {}, reason)
        return False, {}, "无法解析schema内容"
    except Exception as exc:
        logger.error(f"解析schema内容失败: {exc}")
        return False, {}, str(exc)
def _parse_content(self, content: str, content_type: str) -> Optional[Dict[str, Any]]:
    """Turn schema text into Python data.

    A content type mentioning ``json`` selects the JSON parser, one
    mentioning ``yaml``/``yml`` selects the YAML parser. Otherwise JSON is
    tried first and YAML second.

    Args:
        content: Text to parse.
        content_type: Declared content type.

    Returns:
        The parsed data, or ``None`` when parsing failed.
    """
    try:
        if 'json' in content_type:
            return json.loads(content)
        if any(marker in content_type for marker in ('yaml', 'yml')):
            return yaml.safe_load(content)

        # Unknown content type: probe the formats in turn.
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass
        try:
            return yaml.safe_load(content)
        except yaml.YAMLError:
            return None
    except Exception as exc:
        logger.error(f"解析内容失败: {exc}")
        return None
def validate_schema(self, schema_dict: Dict[str, Any]) -> Tuple[bool, str]:
    """Check that a parsed document looks like a supported OpenAPI schema.

    The ``openapi`` version is matched on ``major.minor`` only: any numeric
    patch release of a series listed in ``self.supported_versions`` is
    accepted (e.g. ``3.0.4`` when ``3.0.3`` is listed), because patch
    releases of the specification do not change the feature set.

    Args:
        schema_dict: Parsed schema.

    Returns:
        ``(is_valid, error_message)``; the message is empty when valid.
    """
    try:
        if not isinstance(schema_dict, dict):
            return False, "Schema必须是JSON对象"

        openapi_version = schema_dict.get("openapi")
        if not openapi_version:
            return False, "缺少openapi版本字段"

        # Compare "major.minor" so that new patch releases are not rejected.
        supported_series = {version.rsplit(".", 1)[0] for version in self.supported_versions}
        parts = openapi_version.split(".") if isinstance(openapi_version, str) else []
        is_release = len(parts) == 3 and all(part.isdigit() for part in parts)
        if not (is_release and ".".join(parts[:2]) in supported_series):
            return False, f"不支持的OpenAPI版本: {openapi_version}"

        required_fields = ["info", "paths"]
        for field in required_fields:
            if field not in schema_dict:
                return False, f"缺少必需字段: {field}"

        info = schema_dict.get("info", {})
        if not isinstance(info, dict):
            return False, "info字段必须是对象"
        if "title" not in info:
            return False, "info.title字段是必需的"

        paths = schema_dict.get("paths", {})
        if not isinstance(paths, dict):
            return False, "paths字段必须是对象"
        if not paths:
            return False, "至少需要定义一个API路径"

        return True, ""
    except Exception as e:
        return False, f"验证schema时出错: {e}"
def extract_tool_info(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise a schema as a tool-information dictionary.

    Args:
        schema_dict: Parsed OpenAPI schema.

    Returns:
        A dict with ``name``, ``description`` and ``version`` taken from the
        ``info`` object (with defaults), the ``servers`` list, and the
        ``operations`` produced by ``_extract_operations``.
    """
    meta = schema_dict.get("info", {})
    tool_info: Dict[str, Any] = {}
    tool_info["name"] = meta.get("title", "Custom API Tool")
    tool_info["description"] = meta.get("description", "")
    tool_info["version"] = meta.get("version", "1.0.0")
    tool_info["servers"] = schema_dict.get("servers", [])
    tool_info["operations"] = self._extract_operations(schema_dict)
    return tool_info
def _extract_operations(self, schema_dict: Dict[str, Any]) -> Dict[str, Any]:
"""提取API操作信息
Args:
schema_dict: Schema字典
Returns:
操作信息字典
"""
operations = {}
paths = schema_dict.get("paths", {})
for path, path_item in paths.items():
if not isinstance(path_item, dict):
continue
for method, operation in path_item.items():
if method.lower() not in ["get", "post", "put", "delete", "patch", "head", "options"]:
continue
if not isinstance(operation, dict):
continue
# 生成操作ID
operation_id = operation.get("operationId")
if not operation_id:
operation_id = f"{method.lower()}_{path.replace('/', '_').replace('{', '').replace('}', '')}"
# 提取操作信息
operations[operation_id] = {
"method": method.upper(),
"path": path,
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": self._extract_parameters(operation),
"request_body": self._extract_request_body(operation),
"responses": self._extract_responses(operation),
"tags": operation.get("tags", [])
}
return operations
def _extract_parameters(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取操作参数
Args:
operation: 操作定义
Returns:
参数信息字典
"""
parameters = {}
for param in operation.get("parameters", []):
if not isinstance(param, dict):
continue
param_name = param.get("name")
if not param_name:
continue
param_schema = param.get("schema", {})
parameters[param_name] = {
"name": param_name,
"in": param.get("in", "query"),
"description": param.get("description", ""),
"required": param.get("required", False),
"type": param_schema.get("type", "string"),
"format": param_schema.get("format"),
"enum": param_schema.get("enum"),
"default": param_schema.get("default"),
"minimum": param_schema.get("minimum"),
"maximum": param_schema.get("maximum"),
"pattern": param_schema.get("pattern"),
"example": param.get("example") or param_schema.get("example")
}
return parameters
def _extract_request_body(self, operation: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""提取请求体信息
Args:
operation: 操作定义
Returns:
请求体信息如果没有返回None
"""
request_body = operation.get("requestBody")
if not request_body:
return None
content = request_body.get("content", {})
# 优先使用application/json
if "application/json" in content:
schema = content["application/json"].get("schema", {})
elif content:
# 使用第一个可用的内容类型
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema", {})
else:
return None
return {
"description": request_body.get("description", ""),
"required": request_body.get("required", False),
"schema": schema,
"content_types": list(content.keys())
}
def _extract_responses(self, operation: Dict[str, Any]) -> Dict[str, Any]:
"""提取响应信息
Args:
operation: 操作定义
Returns:
响应信息字典
"""
responses = {}
for status_code, response in operation.get("responses", {}).items():
if not isinstance(response, dict):
continue
content = response.get("content", {})
schema = None
# 尝试获取响应schema
if "application/json" in content:
schema = content["application/json"].get("schema")
elif content:
first_content_type = next(iter(content.keys()))
schema = content[first_content_type].get("schema")
responses[status_code] = {
"description": response.get("description", ""),
"schema": schema,
"content_types": list(content.keys()) if content else []
}
return responses
def generate_tool_parameters(self, operations: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Build the flat parameter list a tool exposes for its operations.

    When more than one operation exists, a required ``operation`` selector
    (enum of operation ids) comes first. After it come the de-duplicated
    parameters of all operations: path/query parameters and request-body
    properties, where the first definition seen for a name wins.

    Args:
        operations: Operation dict as produced by ``_extract_operations``.

    Returns:
        List of parameter definition dicts.
    """

    def describe(name: str, spec: Dict[str, Any], required: Any) -> Dict[str, Any]:
        """Build one parameter definition from a parameter or property spec."""
        return {
            "name": name,
            "type": spec.get("type", "string"),
            "description": spec.get("description", ""),
            "required": required,
            "enum": spec.get("enum"),
            "default": spec.get("default"),
            "minimum": spec.get("minimum"),
            "maximum": spec.get("maximum"),
            "pattern": spec.get("pattern")
        }

    merged: Dict[str, Dict[str, Any]] = {}
    for op in operations.values():
        # Path and query parameters.
        for name, info in op.get("parameters", {}).items():
            if name in merged:
                continue
            merged[name] = describe(name, info, info.get("required", False))

        # Request-body properties.
        body = op.get("request_body")
        if not body:
            continue
        body_schema = body.get("schema", {})
        for name, prop in body_schema.get("properties", {}).items():
            if name in merged:
                continue
            merged[name] = describe(name, prop, name in body_schema.get("required", []))

    result: List[Dict[str, Any]] = []
    if len(operations) > 1:
        result.append({
            "name": "operation",
            "type": "string",
            "description": "要执行的操作",
            "required": True,
            "enum": list(operations.keys())
        })
    result.extend(merged.values())
    return result
def validate_operation_parameters(self, operation: Dict[str, Any], params: Dict[str, Any]) -> Tuple[bool, List[str]]:
    """Validate caller-supplied values against an operation definition.

    Checks required path/query parameters, their types and enum membership,
    then the required request-body properties and the types of any supplied
    values that match a declared body property.

    Args:
        operation: Operation dict as produced by ``_extract_operations``.
        params: Input values keyed by parameter name.

    Returns:
        ``(is_valid, errors)`` where ``errors`` lists every problem found.
    """
    problems: List[str] = []

    # Path and query parameters.
    for name, spec in operation.get("parameters", {}).items():
        supplied = name in params
        if spec.get("required", False) and not supplied:
            problems.append(f"缺少必需参数: {name}")
        if not supplied:
            continue

        value = params[name]
        expected = spec.get("type", "string")
        if not self._validate_parameter_type(value, expected):
            problems.append(f"参数 {name} 类型错误,期望: {expected}")

        allowed = spec.get("enum")
        if allowed and value not in allowed:
            problems.append(f"参数 {name} 值无效,必须是: {allowed}")

    # Request-body properties.
    body = operation.get("request_body")
    if body:
        body_schema = body.get("schema", {})
        required_props = body_schema.get("required", [])
        declared = body_schema.get("properties", {})

        problems.extend(
            f"缺少必需的请求体参数: {name}" for name in required_props if name not in params
        )
        for name, value in params.items():
            if name not in declared:
                continue
            expected = declared[name].get("type", "string")
            if not self._validate_parameter_type(value, expected):
                problems.append(f"请求体参数 {name} 类型错误,期望: {expected}")

    return not problems, problems
def _validate_parameter_type(self, value: Any, expected_type: str) -> bool:
"""验证参数类型
Args:
value: 参数值
expected_type: 期望类型
Returns:
是否类型匹配
"""
if value is None:
return True
type_mapping = {
"string": str,
"integer": int,
"number": (int, float),
"boolean": bool,
"array": list,
"object": dict
}
expected_python_type = type_mapping.get(expected_type)
if expected_python_type:
return isinstance(value, expected_python_type)
return True

View File

@@ -0,0 +1,501 @@
"""工具执行器 - 负责工具的实际调用和执行管理"""
import asyncio
import uuid
import time
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import ToolExecution, ExecutionStatus
from app.core.tools.base import BaseTool, ToolResult
from app.core.tools.registry import ToolRegistry
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class ExecutionContext:
    """Bookkeeping for a single tool invocation.

    Stores the identifiers passed in, the effective timeout, free-form
    metadata, the start/completion timestamps and the current
    ``ExecutionStatus`` (initially ``PENDING``).
    """

    def __init__(
        self,
        execution_id: str,
        tool_id: str,
        user_id: Optional[uuid.UUID] = None,
        workspace_id: Optional[uuid.UUID] = None,
        timeout: Optional[float] = None,
        metadata: Optional[Dict[str, Any]] = None
    ):
        # Identity of the run and of who triggered it.
        self.execution_id = execution_id
        self.tool_id = tool_id
        self.user_id = user_id
        self.workspace_id = workspace_id

        # A falsy timeout (None or 0) falls back to 60 seconds.
        self.timeout = timeout if timeout else 60.0
        self.metadata = metadata if metadata else {}

        # Lifecycle state.
        self.status = ExecutionStatus.PENDING
        self.started_at = datetime.now()
        self.completed_at: Optional[datetime] = None
class ToolExecutor:
    """Tool executor: runs registered tools and records each execution."""

    def __init__(self, db: Session, registry: ToolRegistry):
        """Create an executor bound to a database session and a tool registry.

        Args:
            db: Database session used to persist execution records.
            registry: Registry from which tool instances are looked up.
        """
        self.registry = registry
        self.db = db
        # In-flight executions keyed by execution id; guarded by the lock below.
        self._execution_lock = asyncio.Lock()
        self._running_executions: Dict[str, ExecutionContext] = {}
async def execute_tool(
    self,
    tool_id: str,
    parameters: Dict[str, Any],
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    execution_id: Optional[str] = None,
    timeout: Optional[float] = None,
    metadata: Optional[Dict[str, Any]] = None
) -> ToolResult:
    """Execute one tool and record the run.

    Args:
        tool_id: Id of the tool to look up in the registry.
        parameters: Keyword arguments passed to the tool.
        user_id: Id of the invoking user.
        workspace_id: Id of the workspace the run belongs to.
        execution_id: Execution id; generated as ``exec_<16 hex chars>``
            when omitted.
        timeout: Timeout in seconds.
        metadata: Extra metadata stored on the execution context.

    Returns:
        The tool's result. An unknown tool yields a ``TOOL_NOT_FOUND`` error
        result, and any exception raised while running is converted into an
        ``EXECUTION_ERROR`` result rather than propagated.
    """
    execution_id = execution_id or f"exec_{uuid.uuid4().hex[:16]}"
    context = ExecutionContext(
        execution_id=execution_id,
        tool_id=tool_id,
        user_id=user_id,
        workspace_id=workspace_id,
        timeout=timeout,
        metadata=metadata
    )

    try:
        tool = self.registry.get_tool(tool_id)
        if tool:
            await self._record_execution_start(context, parameters)
            result = await self._execute_with_timeout(tool, parameters, context)
            await self._record_execution_complete(context, result)
            return result

        # Unknown tool: nothing is recorded for this attempt.
        return ToolResult.error_result(
            error=f"工具不存在: {tool_id}",
            error_code="TOOL_NOT_FOUND",
            execution_time=0.0
        )
    except Exception as e:
        logger.error(f"工具执行异常: {execution_id}, 错误: {e}")
        elapsed = time.time() - context.started_at.timestamp()
        failure = ToolResult.error_result(
            error=str(e),
            error_code="EXECUTION_ERROR",
            execution_time=elapsed
        )
        await self._record_execution_complete(context, failure)
        return failure
    finally:
        # Always drop the context from the in-flight table.
        async with self._execution_lock:
            self._running_executions.pop(execution_id, None)
async def execute_tools_batch(
    self,
    tool_executions: List[Dict[str, Any]],
    max_concurrency: int = 5
) -> List[ToolResult]:
    """Execute several tools concurrently with a concurrency cap.

    Args:
        tool_executions: Execution configs. Each must contain ``tool_id`` and
            may contain ``parameters``, ``user_id``, ``workspace_id``,
            ``timeout`` and ``metadata``.
        max_concurrency: Maximum number of executions running at once.

    Returns:
        Results in the same order as ``tool_executions``. An execution that
        raised is replaced by a ``BATCH_EXECUTION_ERROR`` error result.
    """
    gate = asyncio.Semaphore(max_concurrency)

    async def run_one(spec: Dict[str, Any]) -> ToolResult:
        async with gate:
            return await self.execute_tool(
                tool_id=spec["tool_id"],
                parameters=spec.get("parameters", {}),
                user_id=spec.get("user_id"),
                workspace_id=spec.get("workspace_id"),
                timeout=spec.get("timeout"),
                metadata=spec.get("metadata")
            )

    # return_exceptions=True keeps one failure from cancelling the rest.
    outcomes = await asyncio.gather(
        *(run_one(spec) for spec in tool_executions),
        return_exceptions=True
    )
    return [
        ToolResult.error_result(
            error=str(outcome),
            error_code="BATCH_EXECUTION_ERROR",
            execution_time=0.0
        ) if isinstance(outcome, Exception) else outcome
        for outcome in outcomes
    ]
async def cancel_execution(self, execution_id: str) -> bool:
    """Mark a running execution as cancelled.

    Sets the in-memory context and the database record to ``FAILED`` with a
    cancellation message. This method only updates status; it does not
    itself stop the coroutine that is running the tool.

    Args:
        execution_id: Id of the execution to cancel.

    Returns:
        True if the execution was found among the running ones, else False.
    """
    async with self._execution_lock:
        try:
            context = self._running_executions[execution_id]
        except KeyError:
            return False

        context.status = ExecutionStatus.FAILED

        record = self.db.query(ToolExecution).filter(
            ToolExecution.execution_id == execution_id
        ).first()
        if record:
            record.status = ExecutionStatus.FAILED.value
            record.error_message = "执行被取消"
            record.completed_at = datetime.now()
            self.db.commit()

        logger.info(f"工具执行已取消: {execution_id}")
        return True
def get_running_executions(self) -> List[Dict[str, Any]]:
    """List the executions currently tracked as running.

    Returns:
        One dict per in-flight execution with its ids, ISO start time,
        status value and elapsed seconds.
    """

    def text_or_none(identifier: Any) -> Optional[str]:
        return str(identifier) if identifier else None

    return [
        {
            "execution_id": execution_id,
            "tool_id": context.tool_id,
            "user_id": text_or_none(context.user_id),
            "workspace_id": text_or_none(context.workspace_id),
            "started_at": context.started_at.isoformat(),
            "status": context.status.value,
            "elapsed_time": (datetime.now() - context.started_at).total_seconds()
        }
        for execution_id, context in self._running_executions.items()
    ]
async def _execute_with_timeout(
    self,
    tool: BaseTool,
    parameters: Dict[str, Any],
    context: ExecutionContext
) -> ToolResult:
    """Run a tool under the context's timeout and track its status.

    Registers the context as running, awaits ``tool.safe_execute`` with
    ``asyncio.wait_for`` and updates ``context.status`` to ``COMPLETED``,
    ``TIMEOUT`` or ``FAILED``.

    Args:
        tool: Tool instance to run.
        parameters: Keyword arguments for the tool.
        context: Execution context for this run.

    Returns:
        The tool's result, or an ``EXECUTION_TIMEOUT`` error result when the
        timeout expires. Other exceptions are re-raised.
    """
    async with self._execution_lock:
        self._running_executions[context.execution_id] = context
    context.status = ExecutionStatus.RUNNING

    try:
        result = await asyncio.wait_for(
            tool.safe_execute(**parameters),
            timeout=context.timeout
        )
    except asyncio.TimeoutError:
        context.status = ExecutionStatus.TIMEOUT
        return ToolResult.error_result(
            error=f"工具执行超时({context.timeout}秒)",
            error_code="EXECUTION_TIMEOUT",
            execution_time=context.timeout
        )
    except Exception:
        context.status = ExecutionStatus.FAILED
        raise
    else:
        context.status = ExecutionStatus.COMPLETED
        return result
async def _record_execution_start(
    self,
    context: ExecutionContext,
    parameters: Dict[str, Any]
):
    """Persist a ``RUNNING`` execution record for a run that is starting.

    Best effort: failures are logged and never propagated. After a failed
    add/commit the session is rolled back so that later operations on the
    same session are not blocked by the failed transaction.

    Args:
        context: Execution context of the run.
        parameters: Tool input stored as the record's ``input_data``.
    """
    try:
        execution_record = ToolExecution(
            execution_id=context.execution_id,
            # Raises ValueError if tool_id is not a UUID string; handled below.
            tool_config_id=uuid.UUID(context.tool_id),
            status=ExecutionStatus.RUNNING.value,
            input_data=parameters,
            started_at=context.started_at,
            user_id=context.user_id,
            workspace_id=context.workspace_id
        )
        self.db.add(execution_record)
        self.db.commit()
        logger.debug(f"执行记录已创建: {context.execution_id}")
    except Exception as e:
        logger.error(f"创建执行记录失败: {context.execution_id}, 错误: {e}")
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
async def _record_execution_complete(
    self,
    context: ExecutionContext,
    result: ToolResult
):
    """Update the execution record with the outcome of a finished run.

    Best effort: failures are logged and never propagated. After a failed
    query/commit the session is rolled back so that later operations on the
    same session are not blocked by the failed transaction.

    Args:
        context: Execution context of the run; ``completed_at`` is set here.
        result: Result whose success flag, data, error, timing and token
            usage are copied onto the record.
    """
    try:
        context.completed_at = datetime.now()
        execution_record = self.db.query(ToolExecution).filter(
            ToolExecution.execution_id == context.execution_id
        ).first()
        if execution_record:
            # NOTE(review): any unsuccessful result, including a timeout, is
            # stored as FAILED here; confirm that is the intended status.
            execution_record.status = (
                ExecutionStatus.COMPLETED.value if result.success
                else ExecutionStatus.FAILED.value
            )
            execution_record.output_data = result.data if result.success else None
            execution_record.error_message = result.error if not result.success else None
            execution_record.completed_at = context.completed_at
            execution_record.execution_time = result.execution_time
            execution_record.token_usage = result.token_usage
            self.db.commit()
            logger.debug(f"执行记录已更新: {context.execution_id}")
    except Exception as e:
        logger.error(f"更新执行记录失败: {context.execution_id}, 错误: {e}")
        try:
            self.db.rollback()
        except Exception as rollback_error:
            logger.error(f"回滚执行记录事务失败: {context.execution_id}, 错误: {rollback_error}")
def get_execution_history(
    self,
    tool_id: Optional[str] = None,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """Return recent execution records, newest first.

    Args:
        tool_id: Restrict to this tool id (parsed as a UUID).
        user_id: Restrict to this user.
        workspace_id: Restrict to this workspace.
        limit: Maximum number of records to return.

    Returns:
        List of record dicts; an empty list if the query fails.
    """

    def iso_or_none(moment: Any) -> Optional[str]:
        return moment.isoformat() if moment else None

    def text_or_none(identifier: Any) -> Optional[str]:
        return str(identifier) if identifier else None

    try:
        query = self.db.query(ToolExecution).order_by(
            ToolExecution.started_at.desc()
        )
        if tool_id:
            query = query.filter(ToolExecution.tool_config_id == uuid.UUID(tool_id))
        if user_id:
            query = query.filter(ToolExecution.user_id == user_id)
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)

        return [
            {
                "execution_id": record.execution_id,
                "tool_id": str(record.tool_config_id),
                "status": record.status,
                "started_at": iso_or_none(record.started_at),
                "completed_at": iso_or_none(record.completed_at),
                "execution_time": record.execution_time,
                "user_id": text_or_none(record.user_id),
                "workspace_id": text_or_none(record.workspace_id),
                "input_data": record.input_data,
                "output_data": record.output_data,
                "error_message": record.error_message,
                "token_usage": record.token_usage
            }
            for record in query.limit(limit).all()
        ]
    except Exception as e:
        logger.error(f"获取执行历史失败, 错误: {e}")
        return []
def get_execution_statistics(
    self,
    workspace_id: Optional[uuid.UUID] = None,
    days: int = 7
) -> Dict[str, Any]:
    """Aggregate execution statistics over a recent time window.

    Args:
        workspace_id: Restrict to this workspace.
        days: Size of the window in days, counted back from now.

    Returns:
        Dict with totals, success rate, average execution time and per-tool
        counts; an empty dict if the query fails.
    """
    try:
        from datetime import timedelta

        window_start = datetime.now() - timedelta(days=days)
        query = self.db.query(ToolExecution).filter(
            ToolExecution.started_at >= window_start
        )
        if workspace_id:
            query = query.filter(ToolExecution.workspace_id == workspace_id)
        rows = query.all()

        completed_value = ExecutionStatus.COMPLETED.value
        failed_value = ExecutionStatus.FAILED.value

        # One pass collects the overall counters, durations and per-tool buckets.
        succeeded = 0
        failed = 0
        durations = []
        per_tool: Dict[str, Dict[str, int]] = {}
        for row in rows:
            bucket = per_tool.setdefault(
                str(row.tool_config_id),
                {"total": 0, "successful": 0, "failed": 0}
            )
            bucket["total"] += 1
            if row.status == completed_value:
                succeeded += 1
                bucket["successful"] += 1
            elif row.status == failed_value:
                failed += 1
                bucket["failed"] += 1
            if row.execution_time is not None:
                durations.append(row.execution_time)

        total = len(rows)
        average = sum(durations) / len(durations) if durations else 0

        return {
            "period_days": days,
            "total_executions": total,
            "successful_executions": succeeded,
            "failed_executions": failed,
            "success_rate": succeeded / total if total > 0 else 0,
            "average_execution_time": average,
            "tool_statistics": per_tool
        }
    except Exception as e:
        logger.error(f"获取执行统计失败, 错误: {e}")
        return {}
async def test_tool_connection(
    self,
    tool_id: str,
    user_id: Optional[uuid.UUID] = None,
    workspace_id: Optional[uuid.UUID] = None
) -> Dict[str, Any]:
    """Check whether a tool can be reached.

    MCP tools are tested by connecting to the configured server and listing
    its tools. Other tools are tested through their own ``test_connection``
    method when they have one.

    Args:
        tool_id: Tool config id (parsed as a UUID).
        user_id: Unused by this method.
        workspace_id: Unused by this method.

    Returns:
        Dict with ``success`` and ``message``; may also carry ``details`` on
        MCP success or ``error`` when the test itself failed.
    """
    try:
        from app.models.tool_model import ToolConfig, ToolType, MCPToolConfig
        from .mcp.client import MCPClient

        tool_config = self.db.query(ToolConfig).filter(
            ToolConfig.id == uuid.UUID(tool_id)
        ).first()
        if not tool_config:
            return {"success": False, "message": "工具不存在"}

        if tool_config.tool_type == ToolType.MCP.value:
            mcp_config = self.db.query(MCPToolConfig).filter(
                MCPToolConfig.id == tool_config.id
            ).first()
            if not mcp_config:
                return {"success": False, "message": "MCP配置不存在"}

            client = MCPClient(mcp_config.server_url, mcp_config.connection_config or {})
            if not await client.connect():
                return {"success": False, "message": "MCP连接失败"}

            try:
                tools = await client.list_tools()
                tool_count = len(tools)
            except Exception as e:
                # Catch Exception only, so task cancellation and interpreter
                # shutdown signals are not swallowed.
                logger.error(f"MCP功能测试失败: {tool_id}, 错误: {e}")
                return {"success": False, "message": "MCP功能测试失败"}
            finally:
                # Release the connection on every path.
                await client.disconnect()

            return {
                "success": True,
                "message": "MCP连接成功",
                "details": {"server_url": mcp_config.server_url, "tools": tool_count}
            }

        tool = self.registry.get_tool(tool_id)
        if tool and hasattr(tool, 'test_connection'):
            result = tool.test_connection()
            return {"success": result.get("success", False), "message": result.get("message", "")}
        return {"success": True, "message": "工具无需连接测试"}
    except Exception as e:
        logger.error(f"工具连接测试失败: {tool_id}, 错误: {e}")
        return {"success": False, "message": "测试失败", "error": str(e)}

View File

@@ -0,0 +1,375 @@
"""Langchain适配器 - 将工具转换为langchain兼容格式"""
import json
from typing import Dict, Any, List, Optional, Type
from pydantic import BaseModel, Field
from langchain.tools import BaseTool as LangchainBaseTool
from langchain_core.tools import ToolException
from app.core.tools.base import BaseTool, ToolResult, ToolParameter, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class LangchainToolWrapper(LangchainBaseTool):
    """Langchain tool that delegates execution to an internal ``BaseTool``.

    The wrapped tool is stored in the declared ``tool_instance`` field. The
    constructor must populate that field by its declared name and the
    methods must read it through the same name; an underscore-prefixed
    keyword is not a model field and would leave the required field unset.
    """

    name: str = Field(..., description="工具名称")
    description: str = Field(..., description="工具描述")
    args_schema: Optional[Type[BaseModel]] = Field(None, description="参数schema")
    return_direct: bool = Field(False, description="是否直接返回结果")
    # The wrapped internal tool.
    tool_instance: BaseTool = Field(..., description="内部工具实例")

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, tool_instance: BaseTool, **kwargs):
        """Wrap an internal tool.

        Args:
            tool_instance: Internal tool to expose to Langchain. Its name,
                description and parameters seed the wrapper's fields.
            **kwargs: Extra field values forwarded to the Langchain base class.
        """
        # Build the argument schema dynamically from the tool's parameters.
        args_schema = LangchainAdapter._create_pydantic_schema(tool_instance.parameters)
        super().__init__(
            name=tool_instance.name,
            description=tool_instance.description,
            args_schema=args_schema,
            tool_instance=tool_instance,
            **kwargs
        )

    def _run(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Synchronous entry point; not supported because tools are async.

        Raises:
            NotImplementedError: Always; use ``_arun`` instead.
        """
        raise NotImplementedError("请使用 _arun 方法进行异步调用")

    async def _arun(
        self,
        run_manager=None,
        **kwargs: Any,
    ) -> str:
        """Run the wrapped tool asynchronously and format its result.

        Args:
            run_manager: Langchain callback manager (unused here).
            **kwargs: Arguments forwarded to the tool's ``safe_execute``.

        Returns:
            The result formatted as a string for Langchain.

        Raises:
            ToolException: If executing or formatting raises.
        """
        try:
            result = await self.tool_instance.safe_execute(**kwargs)
            return LangchainAdapter._format_result_for_langchain(result)
        except Exception as e:
            logger.error(f"工具执行失败: {self.name}, 错误: {e}")
            raise ToolException(f"工具执行失败: {str(e)}") from e
class LangchainAdapter:
    """Langchain adapter: converts internal tools and their results to Langchain formats."""

    @staticmethod
    def convert_tool(tool: BaseTool) -> LangchainToolWrapper:
        """Wrap one internal tool as a Langchain tool.

        Args:
            tool: Internal tool instance.

        Returns:
            The Langchain-compatible wrapper.

        Raises:
            Exception: Whatever the wrapper construction raised, after the
                failure has been logged.
        """
        try:
            wrapped = LangchainToolWrapper(tool_instance=tool)
            logger.debug(f"工具转换成功: {tool.name} -> Langchain格式")
        except Exception as e:
            logger.error(f"工具转换失败: {tool.name}, 错误: {e}")
            raise
        return wrapped
@staticmethod
def convert_tools(tools: List[BaseTool]) -> List[LangchainToolWrapper]:
    """Convert a list of internal tools, skipping the ones that fail.

    Args:
        tools: Internal tool instances.

    Returns:
        Wrappers for the tools that converted successfully, in input order.
    """
    converted_tools: List[LangchainToolWrapper] = []
    for candidate in tools:
        try:
            converted_tools.append(LangchainAdapter.convert_tool(candidate))
        except Exception as e:
            # One bad tool must not abort the whole batch.
            logger.error(f"跳过工具转换: {candidate.name}, 错误: {e}")

    logger.info(f"批量转换完成: {len(converted_tools)}/{len(tools)} 个工具")
    return converted_tools
@staticmethod
def _create_pydantic_schema(parameters: List[ToolParameter]) -> Type[BaseModel]:
    """Build a Pydantic model class describing a tool's arguments.

    Enum constraints are expressed as a ``Literal`` type rather than a
    regular expression assembled from the values, so values containing regex
    metacharacters and non-string enums work, and a ``pattern`` can no
    longer silently overwrite the enum. Numeric bounds are applied only to
    numeric parameters and ``pattern`` only to string parameters without an
    enum. The pattern keyword is chosen to match the installed Pydantic
    major version.

    Args:
        parameters: Tool parameter definitions.

    Returns:
        A dynamically created ``BaseModel`` subclass that forbids extra fields.
    """
    # Local imports: neither name is imported at module level in this file.
    from typing import Literal
    import pydantic

    # Pydantic v1 spells the constraint ``regex``; v2 renamed it to ``pattern``.
    major_version = str(getattr(pydantic, "VERSION", "1")).split(".")[0]
    pattern_kwarg = "regex" if major_version == "1" else "pattern"

    fields = {}
    annotations = {}
    for param in parameters:
        python_type = LangchainAdapter._get_python_type(param.type)

        field_kwargs = {
            "description": param.description
        }
        if param.default is not None:
            field_kwargs["default"] = param.default
        elif not param.required:
            field_kwargs["default"] = None
        else:
            field_kwargs["default"] = ...  # required field

        enum_values = list(param.enum) if param.enum else []
        if enum_values:
            # Only scalar enums can be expressed as Literal; others are left unconstrained.
            if all(isinstance(value, (str, int, float, bool)) for value in enum_values):
                python_type = Literal[tuple(enum_values)]
        else:
            if python_type in (int, float):
                if param.minimum is not None:
                    field_kwargs["ge"] = param.minimum
                if param.maximum is not None:
                    field_kwargs["le"] = param.maximum
            if python_type is str and param.pattern:
                field_kwargs[pattern_kwarg] = param.pattern

        # Optional parameters also accept None.
        if not param.required:
            python_type = Optional[python_type]

        fields[param.name] = Field(**field_kwargs)
        annotations[param.name] = python_type

    # Create the model class dynamically.
    schema_class = type(
        "ToolArgsSchema",
        (BaseModel,),
        {
            "__annotations__": annotations,
            **fields,
            "Config": type("Config", (), {"extra": "forbid"})
        }
    )
    return schema_class
@staticmethod
def _get_python_type(param_type: ParameterType) -> type:
    """Map a parameter type to the Python type used in the args schema.

    Args:
        param_type: Parameter type.

    Returns:
        The matching Python type; ``str`` for anything unrecognised.
    """
    python_types = {
        ParameterType.STRING: str,
        ParameterType.INTEGER: int,
        ParameterType.NUMBER: float,
        ParameterType.BOOLEAN: bool,
        ParameterType.ARRAY: list,
        ParameterType.OBJECT: dict
    }
    try:
        return python_types[param_type]
    except KeyError:
        return str
@staticmethod
def _format_result_for_langchain(result: ToolResult) -> str:
"""将工具结果格式化为Langchain标准格式
Args:
result: 工具执行结果
Returns:
格式化的字符串结果
"""
if not result.success:
# 错误结果
error_info = {
"success": False,
"error": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}
return json.dumps(error_info, ensure_ascii=False, indent=2)
# 成功结果
if isinstance(result.data, str):
# 如果数据已经是字符串,直接返回
return result.data
elif isinstance(result.data, (dict, list)):
# 如果是结构化数据转换为JSON
return json.dumps(result.data, ensure_ascii=False, indent=2)
else:
# 其他类型转换为字符串
return str(result.data)
@staticmethod
def create_tool_description(tool: BaseTool) -> Dict[str, Any]:
"""创建工具描述(用于工具发现和文档生成)
Args:
tool: 工具实例
Returns:
工具描述字典
"""
return {
"name": tool.name,
"description": tool.description,
"tool_type": tool.tool_type.value,
"version": tool.version,
"status": tool.status.value,
"tags": tool.tags,
"parameters": [
{
"name": param.name,
"type": param.type.value,
"description": param.description,
"required": param.required,
"default": param.default,
"enum": param.enum,
"minimum": param.minimum,
"maximum": param.maximum,
"pattern": param.pattern
}
for param in tool.parameters
],
"langchain_compatible": True
}
@staticmethod
def validate_langchain_compatibility(tool: BaseTool) -> tuple[bool, List[str]]:
    """Check whether a tool can be exposed through Langchain.

    Args:
        tool: Tool instance.

    Returns:
        ``(is_compatible, issues)`` where ``issues`` lists every problem found.
    """
    issues = []

    # Tool name and description must be non-empty strings.
    if not tool.name or not isinstance(tool.name, str):
        issues.append("工具名称必须是非空字符串")
    if not tool.description or not isinstance(tool.description, str):
        issues.append("工具描述必须是非空字符串")

    for param in tool.parameters:
        if not param.name or not isinstance(param.name, str):
            issues.append(f"参数名称无效: {param.name}")

        # ``param.type`` may be a plain string. A membership test against the
        # enum class raises TypeError for non-members on Python < 3.12, so
        # validate by value lookup instead.
        try:
            ParameterType(param.type)
        except ValueError:
            issues.append(f"不支持的参数类型: {param.type}")

        if param.required and param.default is not None:
            issues.append(f"必需参数不应有默认值: {param.name}")

    # The tool must expose a callable ``execute``.
    if not callable(getattr(tool, 'execute', None)):
        issues.append("工具必须实现execute方法")

    return len(issues) == 0, issues
@staticmethod
def get_langchain_tool_schema(tool: BaseTool) -> Dict[str, Any]:
    """Build a function-style schema for a tool.

    Args:
        tool: Tool instance.

    Returns:
        Dict of the form ``{"type": "function", "function": {...}}`` whose
        ``parameters`` entry is an object schema built from the tool's
        parameter definitions.
    """
    properties: Dict[str, Any] = {}
    required: List[str] = []

    for param in tool.parameters:
        prop_schema: Dict[str, Any] = {
            "type": LangchainAdapter._get_openapi_type(param.type),
            "description": param.description
        }

        # Optional constraints, added in a fixed order and only when set.
        if param.enum:
            prop_schema["enum"] = param.enum
        for key, bound in (("minimum", param.minimum), ("maximum", param.maximum)):
            if bound is not None:
                prop_schema[key] = bound
        if param.pattern:
            prop_schema["pattern"] = param.pattern
        if param.default is not None:
            prop_schema["default"] = param.default

        properties[param.name] = prop_schema
        if param.required:
            required.append(param.name)

    parameters_schema = {
        "type": "object",
        "properties": properties,
        "required": required
    }
    return {
        "type": "function",
        "function": {
            "name": tool.name,
            "description": tool.description,
            "parameters": parameters_schema
        }
    }
@staticmethod
def _get_openapi_type(param_type: ParameterType) -> str:
    """Map a parameter type to its OpenAPI type name.

    Args:
        param_type: Parameter type.

    Returns:
        The OpenAPI type string; ``"string"`` for anything unrecognised.
    """
    openapi_names = {
        ParameterType.STRING: "string",
        ParameterType.INTEGER: "integer",
        ParameterType.NUMBER: "number",
        ParameterType.BOOLEAN: "boolean",
        ParameterType.ARRAY: "array",
        ParameterType.OBJECT: "object"
    }
    try:
        return openapi_names[param_type]
    except KeyError:
        return "string"

View File

@@ -0,0 +1,12 @@
"""MCP tool package: re-exports the MCP tool, client, connection pool and service manager."""
from .base import MCPTool
from .client import MCPClient, MCPConnectionPool
from .service_manager import MCPServiceManager
__all__ = [
"MCPTool",
"MCPClient",
"MCPConnectionPool",
"MCPServiceManager"
]

View File

@@ -0,0 +1,258 @@
"""MCP工具基类"""
import time
from typing import Dict, Any, List
import aiohttp
from app.models.tool_model import ToolType
from app.core.tools.base import BaseTool, ToolParameter, ToolResult, ParameterType
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPTool(BaseTool):
    """MCP tool: a tool backed by a Model Context Protocol server."""

    def __init__(self, tool_id: str, config: Dict[str, Any]):
        """Initialise the tool from its configuration.

        Args:
            tool_id: Tool id.
            config: Tool configuration; ``server_url``, ``connection_config``
                and ``available_tools`` are read from it with empty defaults.
        """
        super().__init__(tool_id, config)

        # Connection state: no client yet and not connected.
        self._connected = False
        self._client = None

        # Settings taken from the configuration.
        self.server_url = config.get("server_url", "")
        self.connection_config = config.get("connection_config", {})
        self.available_tools = config.get("available_tools", [])
@property
def name(self) -> str:
    """Tool name: ``mcp_tool_`` followed by the first 8 characters of the tool id."""
    short_id = self.tool_id[:8]
    return f"mcp_tool_{short_id}"
@property
def description(self) -> str:
    """Tool description, which names the server URL this tool connects to."""
    target = self.server_url
    return f"MCP工具 - 连接到 {target}"
@property
def tool_type(self) -> ToolType:
    """Tool type; always ``ToolType.MCP``."""
    kind = ToolType.MCP
    return kind
@property
def parameters(self) -> List[ToolParameter]:
    """Parameter definitions of this tool.

    Always contains ``arguments`` and ``timeout``. When more than one MCP
    tool is available, a required ``tool_name`` selector is placed in front
    of them.
    """
    arguments_param = ToolParameter(
        name="arguments",
        type=ParameterType.OBJECT,
        description="工具参数JSON对象",
        required=False,
        default={}
    )
    timeout_param = ToolParameter(
        name="timeout",
        type=ParameterType.INTEGER,
        description="超时时间(秒)",
        required=False,
        default=30,
        minimum=1,
        maximum=300
    )

    if len(self.available_tools) > 1:
        selector = ToolParameter(
            name="tool_name",
            type=ParameterType.STRING,
            description="要调用的MCP工具名称",
            required=True,
            enum=self.available_tools
        )
        return [selector, arguments_param, timeout_param]
    return [arguments_param, timeout_param]
async def execute(self, **kwargs) -> ToolResult:
    """Call one MCP tool on the server.

    Args:
        **kwargs: ``tool_name`` (optional when exactly one tool is
            available), ``arguments`` (default ``{}``) and ``timeout`` in
            seconds (default 30).

    Returns:
        A success result carrying the server's response, or an ``MCP_ERROR``
        error result describing what went wrong.
    """
    started = time.time()
    try:
        # Attempt a connection on first use; the outcome is not checked here.
        if not self._connected:
            await self.connect()

        # Resolve the target tool, defaulting to the only one available.
        tool_name = kwargs.get("tool_name")
        if not tool_name and len(self.available_tools) == 1:
            tool_name = self.available_tools[0]
        if not tool_name:
            raise ValueError("必须指定要调用的MCP工具名称")
        if tool_name not in self.available_tools:
            raise ValueError(f"MCP工具不存在: {tool_name}")

        payload = await self._call_mcp_tool(
            tool_name,
            kwargs.get("arguments", {}),
            kwargs.get("timeout", 30)
        )
        return ToolResult.success_result(
            data=payload,
            execution_time=time.time() - started
        )
    except Exception as e:
        return ToolResult.error_result(
            error=str(e),
            error_code="MCP_ERROR",
            execution_time=time.time() - started
        )
async def connect(self) -> bool:
    """Probe the MCP server and mark this tool as connected.

    Sends a GET to ``<server_url>/info`` with a 10-second total timeout. On
    HTTP 200 the ``tools`` list from the response replaces
    ``available_tools``.

    Returns:
        True on success, False on any failure (which is logged).
    """
    try:
        probe_timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=probe_timeout) as session:
            async with session.get(f"{self.server_url}/info") as response:
                if response.status != 200:
                    raise Exception(f"服务器响应错误: {response.status}")

                server_info = await response.json()
                self.available_tools = server_info.get("tools", [])
                self._connected = True
                logger.info(f"MCP服务器连接成功: {self.server_url}")
                return True
    except Exception as e:
        logger.error(f"MCP服务器连接失败: {self.server_url}, 错误: {e}")
        self._connected = False
        return False
async def disconnect(self) -> bool:
    """Drop the client reference and mark this tool as disconnected.

    Returns:
        True on success, False if an error occurred (which is logged).
    """
    try:
        if self._client:
            self._client = None
        self._connected = False
        logger.info(f"MCP服务器连接已断开: {self.server_url}")
    except Exception as e:
        logger.error(f"断开MCP服务器连接失败: {e}")
        return False
    return True
def get_health_status(self) -> Dict[str, Any]:
    """Report the current connection state of this tool.

    Returns:
        Dict with the connected flag, server URL, available tools and the
        current time as ``last_check``.
    """
    status: Dict[str, Any] = {}
    status["connected"] = self._connected
    status["server_url"] = self.server_url
    status["available_tools"] = self.available_tools
    status["last_check"] = time.time()
    return status
async def _call_mcp_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int) -> Any:
    """Send a JSON-RPC ``tools/call`` request to the MCP server.

    Args:
        tool_name: Name of the MCP tool to call.
        arguments: Arguments for the tool.
        timeout: Total request timeout in seconds.

    Returns:
        The ``result`` member of the response (``{}`` if absent).

    Raises:
        Exception: On a non-200 HTTP status or when the response carries an
            ``error`` member.
    """
    request_id = f"req_{int(time.time() * 1000)}"
    request_data = {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",
        "params": {
            "name": tool_name,
            "arguments": arguments
        }
    }
    endpoint = f"{self.server_url}/mcp"

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=timeout)
    ) as session, session.post(
        endpoint,
        json=request_data,
        headers={"Content-Type": "application/json"}
    ) as response:
        if response.status != 200:
            error_text = await response.text()
            raise Exception(f"MCP请求失败 {response.status}: {error_text}")

        result = await response.json()
        # The server reports tool-level failures through an "error" member.
        if "error" in result:
            error = result["error"]
            raise Exception(f"MCP工具错误: {error.get('message', '未知错误')}")
        return result.get("result", {})
async def list_available_tools(self) -> List[Dict[str, Any]]:
    """Fetch the tool list from the MCP server via ``tools/list``.

    On success, ``available_tools`` is refreshed with the names of the
    returned tools.

    Returns:
        The tool descriptions from the server, or an empty list when the
        request fails or the response has no ``result``.
    """
    try:
        if not self._connected:
            await self.connect()

        request_data = {
            "jsonrpc": "2.0",
            "id": f"req_{int(time.time() * 1000)}",
            "method": "tools/list"
        }
        endpoint = f"{self.server_url}/mcp"

        async with aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=10)
        ) as session, session.post(
            endpoint,
            json=request_data,
            headers={"Content-Type": "application/json"}
        ) as response:
            if response.status == 200:
                result = await response.json()
                if "result" in result:
                    tools = result["result"].get("tools", [])
                    self.available_tools = [entry.get("name") for entry in tools]
                    return tools
        return []
    except Exception as e:
        logger.error(f"获取MCP工具列表失败: {e}")
        return []
def test_connection(self) -> Dict[str, Any]:
"""测试MCP连接"""
try:
# 这里应该实现同步的连接测试
# 为了简化,返回基本信息
return {
"success": bool(self.server_url),
"server_url": self.server_url,
"connected": self._connected,
"available_tools_count": len(self.available_tools),
"message": "MCP配置有效" if self.server_url else "缺少服务器URL配置"
}
except Exception as e:
return {
"success": False,
"error": str(e)
}

View File

@@ -0,0 +1,626 @@
"""MCP客户端 - Model Context Protocol客户端实现"""
import asyncio
import json
import time
from typing import Dict, Any, List, Optional, Callable
from urllib.parse import urlparse
import aiohttp
import websockets
from websockets.exceptions import ConnectionClosed
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MCPConnectionError(Exception):
    """Raised for MCP transport failures (not connected, dropped socket, failed HTTP call)."""
class MCPProtocolError(Exception):
    """Raised when the MCP server answers with a JSON-RPC error or a call times out."""
class MCPClient:
    """MCP client that talks to a server over HTTP or WebSocket."""

    def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
        """Store the target URL and derive all tunables from the config.

        Args:
            server_url: MCP server URL; a ``ws``/``wss`` scheme selects the
                WebSocket transport, anything else selects HTTP.
            connection_config: Optional settings (``max_connections``,
                ``timeout``, ``retry_attempts``, ``retry_delay``,
                ``health_check_interval``); missing keys use defaults.
        """
        config = connection_config or {}
        self.server_url = server_url
        self.connection_config = config

        # Transport is chosen purely from the URL scheme.
        scheme = urlparse(server_url).scheme
        self.connection_type = "websocket" if scheme in ("ws", "wss") else "http"

        # Live connection state.
        self._connected = False
        self._websocket = None
        self._session = None

        # Request bookkeeping: id counter and futures awaiting a reply.
        self._request_id = 0
        self._pending_requests: Dict[str, asyncio.Future] = {}

        # Tunables with their defaults.
        self.max_connections = config.get("max_connections", 10)
        self.connection_timeout = config.get("timeout", 30)
        self.retry_attempts = config.get("retry_attempts", 3)
        self.retry_delay = config.get("retry_delay", 1)

        # Periodic health check.
        self.health_check_interval = config.get("health_check_interval", 60)
        self._health_check_task = None
        self._last_health_check = None

        # Event listeners.
        self._on_connect_callbacks: List[Callable] = []
        self._on_disconnect_callbacks: List[Callable] = []
        self._on_error_callbacks: List[Callable] = []
async def connect(self) -> bool:
    """Open the connection using the transport selected at construction.

    On success the health-check task is started and connect callbacks are
    notified. Failures are logged, forwarded to the error callbacks and
    reported as ``False``.

    Returns:
        True when connected (or already connected), False otherwise.
    """
    try:
        if self._connected:
            return True
        logger.info(f"连接MCP服务器: {self.server_url}")
        if self.connection_type == "websocket":
            opener = self._connect_websocket
        else:
            opener = self._connect_http
        success = await opener()
        if not success:
            return success
        self._connected = True
        await self._start_health_check()
        await self._notify_connect_callbacks()
        logger.info(f"MCP服务器连接成功: {self.server_url}")
        return success
    except Exception as e:
        logger.error(f"连接MCP服务器失败: {self.server_url}, 错误: {e}")
        await self._notify_error_callbacks(e)
        return False
async def disconnect(self) -> bool:
    """Tear the connection down and release its resources.

    Stops the health check, cancels every pending request future, closes
    the WebSocket or HTTP session and notifies disconnect callbacks.

    Returns:
        True on success (or when already disconnected), False if the
        teardown raised.
    """
    try:
        if not self._connected:
            return True
        logger.info(f"断开MCP服务器连接: {self.server_url}")

        await self._stop_health_check()

        # Nobody will answer these any more.
        unfinished = [f for f in self._pending_requests.values() if not f.done()]
        for future in unfinished:
            future.cancel()
        self._pending_requests.clear()

        if self.connection_type == "websocket" and self._websocket:
            await self._websocket.close()
            self._websocket = None
        elif self._session:
            await self._session.close()
            self._session = None

        self._connected = False
        await self._notify_disconnect_callbacks()
        logger.info(f"MCP服务器连接已断开: {self.server_url}")
        return True
    except Exception as e:
        logger.error(f"断开MCP服务器连接失败: {e}")
        return False
async def _connect_websocket(self) -> bool:
    """Open the WebSocket and perform the MCP ``initialize`` handshake.

    Returns:
        True when the handshake succeeded; False otherwise (the error is
        logged and any half-open socket is closed).
    """
    try:
        extra_headers = self.connection_config.get("headers", {})
        self._websocket = await websockets.connect(
            self.server_url,
            extra_headers=extra_headers,
            # `open_timeout` bounds the opening handshake; the previous
            # `timeout` kwarg does not express a connect timeout.
            open_timeout=self.connection_timeout
        )
        init_message = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {
                    "tools": {}
                },
                "clientInfo": {
                    "name": "ToolManagementSystem",
                    "version": "1.0.0"
                }
            }
        }
        await self._websocket.send(json.dumps(init_message))
        # Read the handshake reply BEFORE starting the background reader:
        # two coroutines must never call recv() on the same socket, and the
        # reader would otherwise swallow this reply.
        response = await asyncio.wait_for(
            self._websocket.recv(),
            timeout=self.connection_timeout
        )
        init_response = json.loads(response)
        if "error" in init_response:
            raise MCPProtocolError(f"初始化失败: {init_response['error']}")
        # Keep a reference so the reader task is not garbage-collected mid-flight.
        self._message_handler_task = asyncio.create_task(self._websocket_message_handler())
        return True
    except Exception as e:
        logger.error(f"WebSocket连接失败: {e}")
        # Do not leak a socket that was opened but failed the handshake.
        websocket, self._websocket = self._websocket, None
        if websocket is not None:
            try:
                await websocket.close()
            except Exception as close_error:
                logger.debug(f"WebSocket close after failed connect raised: {close_error}")
        return False
async def _connect_http(self) -> bool:
    """Create the shared HTTP session and probe the server.

    Probes ``<server_url>/health`` first; if that is not 200, falls back to
    the root URL and accepts any status below 400.

    Returns:
        True when a probe succeeded. On any failure (exception or bad
        status) the session is closed again and False is returned.
    """
    reachable = False
    try:
        timeout = aiohttp.ClientTimeout(total=self.connection_timeout)
        headers = self.connection_config.get("headers", {})
        self._session = aiohttp.ClientSession(
            timeout=timeout,
            headers=headers
        )
        test_url = f"{self.server_url}/health" if not self.server_url.endswith('/') else f"{self.server_url}health"
        async with self._session.get(test_url) as response:
            reachable = response.status == 200
        if not reachable:
            # No health endpoint: any non-error answer from the root will do.
            async with self._session.get(self.server_url) as root_response:
                reachable = root_response.status < 400
    except Exception as e:
        logger.error(f"HTTP连接失败: {e}")
        reachable = False
    # Previously the session was only closed on exceptions, leaking it when
    # both probes merely returned a failing status.
    if not reachable and self._session:
        await self._session.close()
        self._session = None
    return reachable
async def _websocket_message_handler(self):
    """Background reader: dispatch incoming WebSocket frames until the socket closes.

    Malformed or failing messages are logged and skipped. When the loop
    ends the client is marked disconnected, every request still waiting for
    a reply is failed immediately, and disconnect callbacks are notified.
    """
    try:
        while self._websocket and not self._websocket.closed:
            try:
                message = await self._websocket.recv()
                await self._handle_message(json.loads(message))
            except ConnectionClosed:
                break
            except json.JSONDecodeError as e:
                logger.error(f"解析WebSocket消息失败: {e}")
            except Exception as e:
                logger.error(f"处理WebSocket消息失败: {e}")
    except Exception as e:
        logger.error(f"WebSocket消息处理器异常: {e}")
    finally:
        self._connected = False
        # No reply can arrive any more: fail in-flight requests now instead
        # of letting every caller wait out its own timeout.
        for future in list(self._pending_requests.values()):
            if not future.done():
                future.set_exception(MCPConnectionError("WebSocket连接已断开"))
        self._pending_requests.clear()
        await self._notify_disconnect_callbacks()
async def _handle_message(self, message: Dict[str, Any]):
"""处理收到的消息"""
try:
# 检查是否是响应消息
if "id" in message:
request_id = str(message["id"])
if request_id in self._pending_requests:
future = self._pending_requests.pop(request_id)
if not future.done():
future.set_result(message)
# 处理通知消息
elif "method" in message:
await self._handle_notification(message)
except Exception as e:
logger.error(f"处理消息失败: {e}")
async def _handle_notification(self, message: Dict[str, Any]):
    """Log a server notification; no notification is acted upon yet."""
    method, params = message.get("method"), message.get("params", {})
    logger.debug(f"收到MCP通知: {method}, 参数: {params}")
    # Hook point for specific notifications, e.g. tool list updates or
    # server status changes.
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], timeout: int = 30) -> Dict[str, Any]:
    """Run a remote tool through a JSON-RPC ``tools/call`` request.

    Args:
        tool_name: Name of the tool to run.
        arguments: Arguments passed to the tool.
        timeout: Seconds to wait for the reply.

    Returns:
        The ``result`` member of the reply (``{}`` when absent).

    Raises:
        MCPConnectionError: The client is not connected.
        MCPProtocolError: The server returned an error or the call timed out.
    """
    if not self._connected:
        raise MCPConnectionError("MCP客户端未连接")
    payload = {
        "jsonrpc": "2.0",
        "id": self._get_next_request_id(),
        "method": "tools/call",
        "params": {"name": tool_name, "arguments": arguments},
    }
    try:
        response = await self._send_request(payload, timeout)
    except asyncio.TimeoutError:
        raise MCPProtocolError(f"工具调用超时: {tool_name}")
    if "error" in response:
        failure = response["error"]
        raise MCPProtocolError(f"工具调用失败: {failure.get('message', '未知错误')}")
    return response.get("result", {})
async def list_tools(self, timeout: int = 10) -> List[Dict[str, Any]]:
    """Fetch the server's tool catalogue (``tools/list``).

    Args:
        timeout: Seconds to wait for the reply.

    Returns:
        The list under ``result.tools`` (``[]`` when absent).

    Raises:
        MCPConnectionError: The client is not connected.
        MCPProtocolError: The server returned an error or the call timed out.
    """
    if not self._connected:
        raise MCPConnectionError("MCP客户端未连接")
    request_data = {
        "jsonrpc": "2.0",
        "id": self._get_next_request_id(),
        "method": "tools/list"
    }
    try:
        response = await self._send_request(request_data, timeout)
        # A successful JSON-RPC reply has no "error" member at all, so it
        # must be read with .get(); indexing raised KeyError on every success.
        error = response.get("error")
        if error is not None:
            message = error.get("message", "未知错误") if isinstance(error, dict) else str(error)
            raise MCPProtocolError(f"获取工具列表失败: {message}")
        result = response.get("result") or {}
        return result.get("tools", [])
    except asyncio.TimeoutError:
        raise MCPProtocolError("获取工具列表超时")
async def _send_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
"""发送请求并等待响应
Args:
request_data: 请求数据
timeout: 超时时间(秒)
Returns:
响应数据
"""
request_id = str(request_data["id"])
if self.connection_type == "websocket":
return await self._send_websocket_request(request_data, request_id, timeout)
else:
return await self._send_http_request(request_data, timeout)
async def _send_websocket_request(self, request_data: Dict[str, Any], request_id: str, timeout: int) -> Dict[str, Any]:
"""发送WebSocket请求"""
if not self._websocket or self._websocket.closed:
raise MCPConnectionError("WebSocket连接已断开")
# 创建Future等待响应
future = asyncio.Future()
self._pending_requests[request_id] = future
try:
# 发送请求
await self._websocket.send(json.dumps(request_data))
# 等待响应
response = await asyncio.wait_for(future, timeout=timeout)
return response
except asyncio.TimeoutError:
self._pending_requests.pop(request_id, None)
raise
except Exception as e:
self._pending_requests.pop(request_id, None)
raise MCPConnectionError(f"发送WebSocket请求失败: {e}")
async def _send_http_request(self, request_data: Dict[str, Any], timeout: int) -> Dict[str, Any]:
    """POST a JSON-RPC request to ``<server_url>/mcp`` and return the JSON reply.

    Raises:
        MCPConnectionError: No session exists, the status is not 200, or
            aiohttp reported a client error.
    """
    if not self._session:
        raise MCPConnectionError("HTTP会话未建立")
    try:
        # Join without doubling the slash when the base URL already ends in one.
        if self.server_url.endswith('/'):
            url = f"{self.server_url}mcp"
        else:
            url = f"{self.server_url}/mcp"
        per_call_timeout = aiohttp.ClientTimeout(total=timeout)
        async with self._session.post(url, json=request_data, timeout=per_call_timeout) as response:
            if response.status == 200:
                return await response.json()
            error_text = await response.text()
            raise MCPConnectionError(f"HTTP请求失败 {response.status}: {error_text}")
    except aiohttp.ClientError as e:
        raise MCPConnectionError(f"HTTP请求失败: {e}")
async def health_check(self) -> Dict[str, Any]:
    """Ping the server and report the outcome; never raises.

    Returns:
        ``{"healthy": True, "response_time", "timestamp", "server_info"}``
        when a reply was received, otherwise
        ``{"healthy": False, "error", "timestamp"}``.
    """
    try:
        if not self._connected:
            return {"healthy": False, "error": "未连接", "timestamp": time.time()}
        ping = {
            "jsonrpc": "2.0",
            "id": self._get_next_request_id(),
            "method": "ping",
        }
        started_at = time.time()
        response = await self._send_request(ping, timeout=5)
        elapsed = time.time() - started_at
        self._last_health_check = time.time()
        report = {"healthy": True, "response_time": elapsed}
        report["timestamp"] = self._last_health_check
        report["server_info"] = response.get("result", {})
        return report
    except Exception as e:
        return {"healthy": False, "error": str(e), "timestamp": time.time()}
async def _start_health_check(self):
"""启动健康检查任务"""
if self.health_check_interval > 0:
self._health_check_task = asyncio.create_task(self._health_check_loop())
async def _stop_health_check(self):
"""停止健康检查任务"""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
self._health_check_task = None
async def _health_check_loop(self):
"""健康检查循环"""
try:
while self._connected:
await asyncio.sleep(self.health_check_interval)
if self._connected:
health_status = await self.health_check()
if not health_status["healthy"]:
logger.warning(f"MCP服务器健康检查失败: {health_status.get('error')}")
# 可以在这里实现重连逻辑
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"健康检查循环异常: {e}")
def _get_next_request_id(self) -> str:
"""获取下一个请求ID"""
self._request_id += 1
return f"req_{self._request_id}_{int(time.time() * 1000)}"
# 事件回调管理
def on_connect(self, callback: Callable):
    """Register a callable to be invoked (without arguments) after a successful connect."""
    listeners = self._on_connect_callbacks
    listeners.append(callback)
def on_disconnect(self, callback: Callable):
    """Register a callable to be invoked (without arguments) when the connection ends."""
    listeners = self._on_disconnect_callbacks
    listeners.append(callback)
def on_error(self, callback: Callable):
    """Register a callable to be invoked with the exception when connecting fails."""
    listeners = self._on_error_callbacks
    listeners.append(callback)
async def _notify_connect_callbacks(self):
"""通知连接回调"""
for callback in self._on_connect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"连接回调执行失败: {e}")
async def _notify_disconnect_callbacks(self):
"""通知断开连接回调"""
for callback in self._on_disconnect_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback()
else:
callback()
except Exception as e:
logger.error(f"断开连接回调执行失败: {e}")
async def _notify_error_callbacks(self, error: Exception):
"""通知错误回调"""
for callback in self._on_error_callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(error)
else:
callback(error)
except Exception as e:
logger.error(f"错误回调执行失败: {e}")
@property
def is_connected(self) -> bool:
    """Current value of the internal connection flag."""
    connected = self._connected
    return connected
@property
def last_health_check(self) -> Optional[float]:
    """Timestamp recorded by the last successful ``health_check`` ping, or None if there was none."""
    stamp = self._last_health_check
    return stamp
def get_connection_info(self) -> Dict[str, Any]:
    """Return a snapshot of the client's URL, transport, state and config."""
    info = {
        "server_url": self.server_url,
        "connection_type": self.connection_type,
        "connected": self._connected,
        "last_health_check": self._last_health_check,
    }
    info["pending_requests"] = len(self._pending_requests)
    info["config"] = self.connection_config
    return info
async def __aenter__(self):
"""异步上下文管理器入口"""
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
await self.disconnect()
class MCPConnectionPool:
    """Keeps at most ``max_connections`` MCP clients, one per server URL."""

    def __init__(self, max_connections: int = 10):
        """Create an empty pool.

        Args:
            max_connections: Upper bound on simultaneously cached clients.
        """
        self.max_connections = max_connections
        self._clients: Dict[str, MCPClient] = {}
        self._lock = asyncio.Lock()

    async def get_client(self, server_url: str, connection_config: Dict[str, Any] = None) -> MCPClient:
        """Return a connected client for ``server_url``, creating one if needed.

        A cached client is reused when connected (or after a successful
        reconnect). When the pool is full, the earliest-inserted client is
        disconnected and evicted first.

        Args:
            server_url: Server URL, also the cache key.
            connection_config: Config used only when a new client is built.

        Raises:
            MCPConnectionError: A new client could not connect.
        """
        async with self._lock:
            cached = self._clients.get(server_url)
            if cached is not None:
                if cached.is_connected:
                    return cached
                # Cached but down: one reconnect attempt, else drop it.
                if await cached.connect():
                    return cached
                del self._clients[server_url]

            if len(self._clients) >= self.max_connections:
                # Dicts keep insertion order, so the first key is the oldest entry.
                oldest_url = next(iter(self._clients))
                await self._clients[oldest_url].disconnect()
                del self._clients[oldest_url]

            fresh = MCPClient(server_url, connection_config)
            if not await fresh.connect():
                raise MCPConnectionError(f"无法连接到MCP服务器: {server_url}")
            self._clients[server_url] = fresh
            return fresh

    async def disconnect_all(self):
        """Disconnect every cached client and empty the pool."""
        async with self._lock:
            for pooled in self._clients.values():
                await pooled.disconnect()
            self._clients.clear()

    def get_pool_status(self) -> Dict[str, Any]:
        """Return pool size, capacity and per-URL connection info."""
        per_url = {}
        for url, pooled in self._clients.items():
            per_url[url] = pooled.get_connection_info()
        return {
            "total_connections": len(self._clients),
            "max_connections": self.max_connections,
            "connections": per_url,
        }

View File

@@ -0,0 +1,604 @@
"""MCP服务管理器 - 管理MCP服务的注册、更新、删除和状态监控"""
import asyncio
import time
import uuid
from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.models.tool_model import MCPToolConfig, ToolConfig, ToolType
from app.core.logging_config import get_business_logger
from .client import MCPClient, MCPConnectionPool
logger = get_business_logger()
class MCPServiceManager:
    """Manages the lifecycle of registered MCP services."""

    def __init__(self, db: Session):
        """Set up empty in-memory state around a database session.

        Args:
            db: Database session used for all persistence.
        """
        self.db = db
        self.connection_pool = MCPConnectionPool(max_connections=20)

        # In-memory registry and per-service monitor tasks, keyed by service id.
        self._services: Dict[str, Dict[str, Any]] = {}
        self._monitoring_tasks: Dict[str, asyncio.Task] = {}

        # Monitoring tunables.
        self.health_check_interval = 60  # seconds between health checks
        self.max_retry_attempts = 3      # failures before monitoring backs off
        self.retry_delay = 5             # seconds between retries

        # Manager run state.
        self._running = False
        self._manager_task = None
async def start(self):
    """Start the manager: load persisted services and launch the management loop.

    Does nothing when the manager is already running.
    """
    if not self._running:
        self._running = True
        logger.info("MCP服务管理器启动")
        await self._load_existing_services()
        loop_coro = self._management_loop()
        self._manager_task = asyncio.create_task(loop_coro)
async def stop(self):
    """Stop the manager: cancel the management loop and all monitors, then close the pool.

    Does nothing when the manager is not running.
    """
    if not self._running:
        return
    self._running = False
    logger.info("MCP服务管理器停止")

    manager_task = self._manager_task
    if manager_task:
        manager_task.cancel()
        try:
            await manager_task
        except asyncio.CancelledError:
            pass

    # Cancel every monitor first, then wait for all of them together.
    monitors = self._monitoring_tasks
    for monitor in monitors.values():
        monitor.cancel()
    if monitors:
        await asyncio.gather(*monitors.values(), return_exceptions=True)
    monitors.clear()

    await self.connection_pool.disconnect_all()
async def register_service(
    self,
    server_url: str,
    connection_config: Dict[str, Any],
    tenant_id: uuid.UUID,
    service_name: str = None
) -> Tuple[bool, str, Optional[str]]:
    """Register an MCP service after verifying that it is reachable.

    Connects once to list the server's tools, persists a ``ToolConfig``
    plus ``MCPToolConfig`` row, caches the service in memory and starts
    monitoring it.

    Args:
        server_url: MCP server URL; must not be registered already.
        connection_config: Connection settings stored with the service.
        tenant_id: Owning tenant.
        service_name: Optional display name; derived from the URL if omitted.

    Returns:
        ``(True, service_id, None)`` on success, otherwise
        ``(False, short_error, detail)``.
    """
    try:
        existing_service = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.server_url == server_url
        ).first()
        if existing_service:
            return False, "服务已存在", f"URL {server_url} 已被注册"

        # Probe the server. The finally guarantees the probe client is torn
        # down even when connect() or list_tools() fails; previously it
        # stayed connected if list_tools() raised.
        client = None
        try:
            client = MCPClient(server_url, connection_config)
            if not await client.connect():
                return False, "连接测试失败", "无法连接到MCP服务器"
            available_tools = await client.list_tools()
            tool_names = [tool.get("name") for tool in available_tools if tool.get("name")]
        except Exception as e:
            return False, "连接测试失败", str(e)
        finally:
            if client is not None:
                await client.disconnect()

        if not service_name:
            # Strip trailing slashes first so "http://host/" yields "host"
            # rather than an empty suffix.
            service_name = f"mcp_service_{server_url.rstrip('/').split('/')[-1]}"

        tool_config = ToolConfig(
            name=service_name,
            description=f"MCP服务 - {server_url}",
            tool_type=ToolType.MCP.value,
            tenant_id=tenant_id,
            version="1.0.0",
            config_data={
                "server_url": server_url,
                "connection_config": connection_config
            }
        )
        self.db.add(tool_config)
        self.db.flush()  # assigns tool_config.id

        mcp_config = MCPToolConfig(
            id=tool_config.id,
            server_url=server_url,
            connection_config=connection_config,
            available_tools=tool_names,
            health_status="healthy",
            last_health_check=datetime.utcnow()
        )
        self.db.add(mcp_config)
        self.db.commit()

        service_id = str(tool_config.id)
        self._services[service_id] = {
            "id": service_id,
            "server_url": server_url,
            "connection_config": connection_config,
            "tenant_id": tenant_id,
            "available_tools": tool_names,
            "status": "healthy",
            "last_health_check": time.time(),
            "retry_count": 0,
            "created_at": time.time()
        }
        await self._start_service_monitoring(service_id)
        logger.info(f"MCP服务注册成功: {service_id} ({server_url})")
        return True, service_id, None
    except Exception as e:
        self.db.rollback()
        logger.error(f"注册MCP服务失败: {server_url}, 错误: {e}")
        return False, "注册失败", str(e)
async def unregister_service(self, service_id: str) -> Tuple[bool, str]:
    """Delete a service from the database, stop its monitor and drop it from memory.

    Args:
        service_id: Service id (UUID string).

    Returns:
        ``(True, "")`` on success, otherwise ``(False, error_message)``.
    """
    try:
        service_uuid = uuid.UUID(service_id)
        tool_config = self.db.get(ToolConfig, service_uuid)
        if not tool_config:
            return False, "服务不存在"

        self.db.delete(tool_config)
        self.db.commit()

        await self._stop_service_monitoring(service_id)
        self._services.pop(service_id, None)

        logger.info(f"MCP服务注销成功: {service_id}")
        return True, ""
    except Exception as e:
        self.db.rollback()
        logger.error(f"注销MCP服务失败: {service_id}, 错误: {e}")
        return False, str(e)
async def update_service(
    self,
    service_id: str,
    connection_config: Dict[str, Any] = None,
    enabled: bool = None
) -> Tuple[bool, str]:
    """Update a service's connection config and/or enabled flag.

    Args:
        service_id: Service id (UUID string).
        connection_config: New connection config; ``None`` leaves it unchanged.
        enabled: New enabled flag; ``None`` leaves it unchanged.

    Returns:
        ``(True, "")`` on success, otherwise ``(False, error_message)``.
    """
    try:
        mcp_config = self.db.query(MCPToolConfig).filter(
            MCPToolConfig.id == uuid.UUID(service_id)
        ).first()
        if not mcp_config:
            return False, "服务不存在"
        tool_config = mcp_config.base_config

        if connection_config is not None:
            mcp_config.connection_config = connection_config
            # Rebind a new dict instead of mutating in place: SQLAlchemy does
            # not detect in-place changes to a plain JSON column, so the old
            # item assignment could be silently dropped at commit. Also
            # tolerates a NULL config_data.
            tool_config.config_data = {
                **(tool_config.config_data or {}),
                "connection_config": connection_config
            }
        if enabled is not None:
            tool_config.is_enabled = enabled
        self.db.commit()

        if service_id in self._services:
            if connection_config is not None:
                self._services[service_id]["connection_config"] = connection_config

        # A changed connection config only takes effect in a fresh monitor task.
        if connection_config is not None:
            await self._restart_service_monitoring(service_id)

        logger.info(f"MCP服务更新成功: {service_id}")
        return True, ""
    except Exception as e:
        self.db.rollback()
        logger.error(f"更新MCP服务失败: {service_id}, 错误: {e}")
        return False, str(e)
async def get_service_status(self, service_id: str) -> Optional[Dict[str, Any]]:
    """Return a copy of the cached service info plus a live health probe.

    Args:
        service_id: Service id.

    Returns:
        ``None`` for an unknown service; otherwise a shallow copy of the
        cached info with ``real_time_health`` added (a failed probe is
        reported there instead of being raised).
    """
    cached = self._services.get(service_id) if service_id in self._services else None
    if service_id not in self._services:
        return None
    snapshot = cached.copy()
    try:
        client = await self.connection_pool.get_client(
            snapshot["server_url"],
            snapshot["connection_config"]
        )
        live = await client.health_check()
    except Exception as e:
        live = {"healthy": False, "error": str(e), "timestamp": time.time()}
    snapshot["real_time_health"] = live
    return snapshot
async def list_services(self, tenant_id: uuid.UUID = None) -> List[Dict[str, Any]]:
    """List cached services, optionally restricted to one tenant.

    Args:
        tenant_id: When given (truthy), only that tenant's services are returned.

    Returns:
        Shallow copies of the matching service info dicts.
    """
    return [
        info.copy()
        for info in self._services.values()
        if not (tenant_id and info["tenant_id"] != tenant_id)
    ]
async def get_service_tools(self, service_id: str) -> List[Dict[str, Any]]:
    """Fetch a service's tool list live and refresh the cached/persisted names.

    Args:
        service_id: Service id.

    Returns:
        The raw tool descriptors, or ``[]`` for an unknown service or on
        any failure (which is logged).
    """
    if service_id not in self._services:
        return []
    service_info = self._services[service_id]
    try:
        client = await self.connection_pool.get_client(
            service_info["server_url"],
            service_info["connection_config"]
        )
        tools = await client.list_tools()

        tool_names = [tool.get("name") for tool in tools if tool.get("name")]
        service_info["available_tools"] = tool_names

        try:
            mcp_config = self.db.query(MCPToolConfig).filter(
                MCPToolConfig.id == uuid.UUID(service_id)
            ).first()
            if mcp_config:
                mcp_config.available_tools = tool_names
                self.db.commit()
        except Exception:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; do that before reporting the failure below.
            self.db.rollback()
            raise
        return tools
    except Exception as e:
        logger.error(f"获取服务工具失败: {service_id}, 错误: {e}")
        return []
async def call_service_tool(
    self,
    service_id: str,
    tool_name: str,
    arguments: Dict[str, Any],
    timeout: int = 30
) -> Dict[str, Any]:
    """Run a tool on a registered service and track the outcome in its cached state.

    Args:
        service_id: Service id.
        tool_name: Tool to run.
        arguments: Tool arguments.
        timeout: Seconds to wait for the tool.

    Returns:
        The tool's result.

    Raises:
        ValueError: The service is unknown.
        Exception: Whatever the pool or client raised; the service is
            marked ``error`` and its retry counter is bumped first.
    """
    if service_id not in self._services:
        raise ValueError(f"服务不存在: {service_id}")
    service_info = self._services[service_id]
    try:
        client = await self.connection_pool.get_client(
            service_info["server_url"],
            service_info["connection_config"]
        )
        result = await client.call_tool(tool_name, arguments, timeout)
    except Exception as e:
        service_info["status"] = "error"
        service_info["last_error"] = str(e)
        service_info["retry_count"] += 1
        logger.error(f"调用服务工具失败: {service_id}/{tool_name}, 错误: {e}")
        raise
    # A successful call doubles as a health signal.
    service_info.update(status="healthy", last_health_check=time.time(), retry_count=0)
    return result
async def _load_existing_services(self):
    """Load every enabled MCP service from the database and start monitoring it.

    Any failure is logged; services loaded before the failure stay registered.
    """
    try:
        enabled_configs = (
            self.db.query(MCPToolConfig)
            .join(ToolConfig)
            .filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQL expression
            .all()
        )
        for mcp_config in enabled_configs:
            tool_config = mcp_config.base_config
            service_id = str(mcp_config.id)
            checked_at = mcp_config.last_health_check
            entry = {
                "id": service_id,
                "server_url": mcp_config.server_url,
                "connection_config": mcp_config.connection_config or {},
                "tenant_id": tool_config.tenant_id,
                "available_tools": mcp_config.available_tools or [],
                "status": mcp_config.health_status or "unknown",
                "last_health_check": checked_at.timestamp() if checked_at else 0,
                "retry_count": 0,
                "created_at": tool_config.created_at.timestamp(),
            }
            self._services[service_id] = entry
            await self._start_service_monitoring(service_id)
        logger.info(f"加载了 {len(enabled_configs)} 个MCP服务")
    except Exception as e:
        logger.error(f"加载现有服务失败: {e}")
async def _start_service_monitoring(self, service_id: str):
"""启动服务监控"""
if service_id in self._monitoring_tasks:
return
task = asyncio.create_task(self._monitor_service(service_id))
self._monitoring_tasks[service_id] = task
async def _stop_service_monitoring(self, service_id: str):
"""停止服务监控"""
if service_id in self._monitoring_tasks:
task = self._monitoring_tasks.pop(service_id)
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
async def _restart_service_monitoring(self, service_id: str):
"""重启服务监控"""
await self._stop_service_monitoring(service_id)
await self._start_service_monitoring(service_id)
async def _monitor_service(self, service_id: str):
    """Monitor one service until the manager stops or the service is removed.

    Each round health-checks the service through the connection pool,
    updates its cached state (status, retry counter, tool names, last check
    time) and persists the health result. After ``max_retry_attempts``
    consecutive failures the loop sleeps five extra intervals and resets
    the counter. Cancellation ends the task quietly.
    """
    try:
        while self._running and service_id in self._services:
            service_info = self._services[service_id]
            try:
                # Run the health check on a pooled client.
                client = await self.connection_pool.get_client(
                    service_info["server_url"],
                    service_info["connection_config"]
                )
                health_status = await client.health_check()
                if health_status["healthy"]:
                    # Healthy: reset the failure counter.
                    service_info["status"] = "healthy"
                    service_info["retry_count"] = 0
                    # Refresh the cached tool names (in memory only; a
                    # failure here does not make the service unhealthy).
                    try:
                        tools = await client.list_tools()
                        tool_names = [tool.get("name") for tool in tools if tool.get("name")]
                        service_info["available_tools"] = tool_names
                    except Exception as e:
                        logger.warning(f"更新工具列表失败: {service_id}, 错误: {e}")
                else:
                    # Unhealthy: remember why and count the failure.
                    service_info["status"] = "unhealthy"
                    service_info["last_error"] = health_status.get("error", "健康检查失败")
                    service_info["retry_count"] += 1
                service_info["last_health_check"] = time.time()
                # Persist the health result.
                await self._update_service_health_in_db(service_id, health_status)
            except Exception as e:
                # The check itself blew up (e.g. the pool could not connect).
                service_info["status"] = "error"
                service_info["last_error"] = str(e)
                service_info["retry_count"] += 1
                service_info["last_health_check"] = time.time()
                logger.error(f"服务监控异常: {service_id}, 错误: {e}")
            # Too many consecutive failures: back off before checking again.
            if service_info["retry_count"] >= self.max_retry_attempts:
                logger.warning(f"服务 {service_id} 重试次数过多,暂停监控")
                await asyncio.sleep(self.health_check_interval * 5)  # extended wait
                service_info["retry_count"] = 0  # reset the retry counter
            # Wait for the next round.
            await asyncio.sleep(self.health_check_interval)
    except asyncio.CancelledError:
        pass
    except Exception as e:
        logger.error(f"服务监控任务异常: {service_id}, 错误: {e}")
async def _update_service_health_in_db(self, service_id: str, health_status: Dict[str, Any]):
    """Persist a health result on the service's ``MCPToolConfig`` row.

    Args:
        service_id: Service id (UUID string).
        health_status: Result of a health check; ``healthy`` is required,
            ``error`` is stored when unhealthy.

    Failures are logged and rolled back, never raised.
    """
    try:
        row_id = uuid.UUID(service_id)
        mcp_config = self.db.query(MCPToolConfig).filter(MCPToolConfig.id == row_id).first()
        if not mcp_config:
            return
        mcp_config.health_status = "healthy" if health_status["healthy"] else "unhealthy"
        mcp_config.last_health_check = datetime.utcnow()
        # The error text is kept only while the service is unhealthy.
        if health_status["healthy"]:
            mcp_config.error_message = None
        else:
            mcp_config.error_message = health_status.get("error", "")
        self.db.commit()
    except Exception as e:
        logger.error(f"更新数据库健康状态失败: {service_id}, 错误: {e}")
        self.db.rollback()
async def _management_loop(self):
"""管理循环 - 处理服务清理等任务"""
try:
while self._running:
# 清理失效的服务
await self._cleanup_failed_services()
# 等待下次循环
await asyncio.sleep(300) # 5分钟
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"管理循环异常: {e}")
async def _cleanup_failed_services(self):
"""清理长期失效的服务"""
try:
current_time = time.time()
cleanup_threshold = 24 * 60 * 60 # 24小时
services_to_cleanup = []
for service_id, service_info in self._services.items():
# 检查服务是否长期失效
if (service_info["status"] in ["error", "unhealthy"] and
current_time - service_info["last_health_check"] > cleanup_threshold):
services_to_cleanup.append(service_id)
for service_id in services_to_cleanup:
logger.warning(f"清理长期失效的服务: {service_id}")
# 停止监控但不删除数据库记录
await self._stop_service_monitoring(service_id)
# 标记为禁用
tool_config = self.db.get(ToolConfig, uuid.UUID(service_id))
if tool_config:
tool_config.is_enabled = False
self.db.commit()
# 从内存移除
del self._services[service_id]
except Exception as e:
logger.error(f"清理失效服务失败: {e}")
def get_manager_status(self) -> Dict[str, Any]:
    """Return run state, service counts by health, monitor count and pool status."""
    statuses = [service["status"] for service in self._services.values()]
    healthy_count = sum(1 for status in statuses if status == "healthy")
    unhealthy_count = sum(1 for status in statuses if status in ("unhealthy", "error"))
    return {
        "running": self._running,
        "total_services": len(self._services),
        "healthy_services": healthy_count,
        "unhealthy_services": unhealthy_count,
        "monitoring_tasks": len(self._monitoring_tasks),
        "connection_pool_status": self.connection_pool.get_pool_status(),
    }

View File

@@ -0,0 +1,436 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
    """Registry holding the metadata and cached instances of all tools."""

    def __init__(self, db: Session):
        """Create an empty registry bound to a database session.

        Args:
            db: Database session used for all persistence.
        """
        self.db = db
        # Instance cache (tool id -> instance) and class table (name -> class).
        self._tools: Dict[str, BaseTool] = {}
        self._tool_classes: Dict[str, Type[BaseTool]] = {}
        # Serialises register/unregister operations.
        self._lock = asyncio.Lock()
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
    """Register a tool class so stored built-in tools can be instantiated by name.

    Args:
        tool_class: The tool class.
        class_name: Registry key; defaults to the class's ``__name__``.
    """
    key = class_name or tool_class.__name__
    self._tool_classes[key] = tool_class
    logger.info(f"工具类已注册: {key}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
    """Persist a tool instance and add it to the instance cache.

    Writes a ``ToolConfig`` row plus the row specific to the tool type
    (built-in, custom or MCP). A tool with the same name, tenant and type
    is never registered twice.

    Args:
        tool: Tool instance to register.
        tenant_id: Owning tenant; ``None`` registers a global tool.

    Returns:
        True when registered; False when it already exists or on error
        (the transaction is rolled back).
    """
    async with self._lock:
        try:
            # Tenant tools match on their tenant; global tools have a NULL tenant.
            if tenant_id:
                tenant_clause = ToolConfig.tenant_id == tenant_id
            else:
                tenant_clause = ToolConfig.tenant_id.is_(None)
            duplicate = self.db.query(ToolConfig).filter(
                and_(
                    ToolConfig.name == tool.name,
                    tenant_clause,
                    ToolConfig.tool_type == tool.tool_type.value
                )
            ).first()
            if duplicate:
                logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
                return False

            tool_config = ToolConfig(
                name=tool.name,
                description=tool.description,
                tool_type=tool.tool_type.value,
                tenant_id=tenant_id,
                version=tool.version,
                tags=tool.tags,
                config_data=tool.config
            )
            self.db.add(tool_config)
            self.db.flush()  # assigns tool_config.id

            # Build the type-specific row, if the type has one.
            detail = None
            if tool.tool_type == ToolType.BUILTIN:
                detail = BuiltinToolConfig(
                    id=tool_config.id,
                    tool_class=tool.__class__.__name__,
                    parameters=tool.config.get("parameters", {})
                )
            elif tool.tool_type == ToolType.CUSTOM:
                detail = CustomToolConfig(
                    id=tool_config.id,
                    schema_url=tool.config.get("schema_url"),
                    schema_content=tool.config.get("schema_content"),
                    auth_type=tool.config.get("auth_type", "none"),
                    auth_config=tool.config.get("auth_config", {}),
                    base_url=tool.config.get("base_url"),
                    timeout=tool.config.get("timeout", 30)
                )
            elif tool.tool_type == ToolType.MCP:
                detail = MCPToolConfig(
                    id=tool_config.id,
                    server_url=tool.config.get("server_url"),
                    connection_config=tool.config.get("connection_config", {}),
                    available_tools=tool.config.get("available_tools", [])
                )
            if detail is not None:
                self.db.add(detail)
            self.db.commit()

            cache_key = str(tool_config.id)
            tool.tool_id = cache_key
            self._tools[cache_key] = tool
            logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
            return True
        except Exception as e:
            self.db.rollback()
            logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
            return False
async def unregister_tool(self, tool_id: str) -> bool:
    """Delete a tool from the database and the instance cache.

    Refused while the tool has pending or running executions.

    Args:
        tool_id: Tool id (UUID string).

    Returns:
        True when removed; False when unknown, busy, or on error (the
        transaction is rolled back).
    """
    async with self._lock:
        try:
            tool_uuid = uuid.UUID(tool_id)
            tool_config = self.db.get(ToolConfig, tool_uuid)
            if not tool_config:
                logger.warning(f"工具不存在: {tool_id}")
                return False

            active_states = [ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]
            running_executions = self.db.query(ToolExecution).filter(
                and_(
                    ToolExecution.tool_config_id == tool_uuid,
                    ToolExecution.status.in_(active_states)
                )
            ).count()
            if running_executions > 0:
                logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
                return False

            # Related rows are expected to go with it via cascade.
            self.db.delete(tool_config)
            self.db.commit()
            self._tools.pop(tool_id, None)

            logger.info(f"工具注销成功: {tool_id}")
            return True
        except Exception as e:
            self.db.rollback()
            logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
            return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
    """Return the tool instance for ``tool_id``, loading and caching it on a miss.

    Args:
        tool_id: Tool id (UUID string).

    Returns:
        The instance, or ``None`` when the tool is unknown, not active,
        cannot be instantiated, or loading raised (which is logged).
    """
    try:
        return self._tools[tool_id]
    except KeyError:
        pass  # cache miss: fall through to the database

    try:
        tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
        # NOTE(review): other methods here derive a tool's status from
        # `is_enabled`; confirm that ToolConfig really exposes `status`.
        if not tool_config or tool_config.status != ToolStatus.ACTIVE.value:
            return None
        instance = self._load_tool_instance(tool_config)
        if instance:
            self._tools[tool_id] = instance
        return instance
    except Exception as e:
        logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
        return None
def list_tools(
    self,
    tenant_id: Optional[uuid.UUID] = None,
    tool_type: Optional[ToolType] = None,
    status: Optional[ToolStatus] = None,
    tags: Optional[List[str]] = None
) -> List[ToolInfo]:
    """List tools matching the given filters.

    Args:
        tenant_id: When given, returns that tenant's tools plus global ones.
        tool_type: Restrict to one tool type.
        status: ``ACTIVE``/``INACTIVE`` filter on the enabled flag; other
            values do not filter.
        tags: Every listed tag must be present on the tool.

    Returns:
        ``ToolInfo`` objects (with parameters when the tool can be
        instantiated); ``[]`` on error, which is logged.
    """
    try:
        query = self.db.query(ToolConfig)
        if tenant_id:
            visible = or_(
                ToolConfig.tenant_id == tenant_id,
                ToolConfig.tenant_id.is_(None)
            )
            query = query.filter(visible)
        if tool_type:
            query = query.filter(ToolConfig.tool_type == tool_type.value)
        if status == ToolStatus.ACTIVE:
            query = query.filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQL expression
        elif status == ToolStatus.INACTIVE:
            query = query.filter(ToolConfig.is_enabled == False)  # noqa: E712 - SQL expression
        for tag in tags or ():
            query = query.filter(ToolConfig.tags.contains([tag]))

        tool_infos = []
        for config in query.all():
            config_id = str(config.id)
            owner = str(config.tenant_id) if config.tenant_id else None
            state = ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE
            tool_info = ToolInfo(
                id=config_id,
                name=config.name,
                description=config.description or "",
                tool_type=ToolType(config.tool_type),
                version=config.version,
                status=state,
                tags=config.tags or [],
                tenant_id=owner
            )
            # Parameters are only known once the tool can be instantiated.
            tool_instance = self.get_tool(config_id)
            if tool_instance:
                tool_info.parameters = tool_instance.parameters
            tool_infos.append(tool_info)
        return tool_infos
    except Exception as e:
        logger.error(f"列出工具失败, 错误: {e}")
        return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
    """Persist a new enabled/disabled state for a tool and sync the cache.

    ACTIVE sets ``is_enabled`` to True and INACTIVE sets it to False; any
    other status leaves ``is_enabled`` untouched. After the commit, the
    cached instance (if any) gets its ``status`` attribute set to ``status``.

    Args:
        tool_id: Tool ID; parsed as a UUID for the database lookup.
        status: The new status.

    Returns:
        True when the update was committed; False when the tool does not
        exist or an error occurred (the session is rolled back and the error
        is logged).
    """
    try:
        record = self.db.get(ToolConfig, uuid.UUID(tool_id))
        if record:
            # Map the status onto the persisted flag.
            if status == ToolStatus.ACTIVE:
                record.is_enabled = True
            elif status == ToolStatus.INACTIVE:
                record.is_enabled = False
            self.db.commit()

            # Keep an already-loaded instance in step with the new status.
            if tool_id in self._tools:
                self._tools[tool_id].status = status

            logger.info(f"工具状态更新成功: {tool_id} -> {status}")
            return True

        logger.warning(f"工具不存在: {tool_id}")
        return False
    except Exception as e:
        self.db.rollback()
        logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
        return False
def _load_tool_instance(self, tool_config: Optional[ToolConfig]) -> Optional[BaseTool]:
    """Build a tool instance from a stored tool configuration.

    Looks up the type-specific configuration row (builtin, custom or MCP)
    that shares the tool's ID, merges it with the generic configuration and
    passes the result to the matching tool class.

    Args:
        tool_config: The generic tool configuration record (an instance, not
            the class), or None.

    Returns:
        The constructed tool instance, or None when ``tool_config`` is None,
        its tool type is not handled here, the type-specific row is missing,
        a builtin tool's class is not registered in ``_tool_classes``, or
        loading raises (the error is logged rather than propagated).
    """
    # Without this guard a None config raised AttributeError in the try body
    # and then again inside the except handler (which formats
    # tool_config.id), so the error escaped to the caller instead of
    # producing None.
    if tool_config is None:
        return None

    def build_config(type_specific: Dict[str, Any]) -> Dict[str, Any]:
        """Merge generic config data, type-specific fields and shared metadata."""
        # Later entries win, so the explicit fields override any same-named
        # keys stored in config_data.
        return {
            **tool_config.config_data,
            **type_specific,
            "tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
            "version": tool_config.version,
            "tags": tool_config.tags,
        }

    try:
        if tool_config.tool_type == ToolType.BUILTIN.value:
            # Builtin tool: the class is resolved through the registered classes.
            builtin_config = self.db.query(BuiltinToolConfig).filter(
                BuiltinToolConfig.id == tool_config.id
            ).first()

            if builtin_config and builtin_config.tool_class in self._tool_classes:
                tool_class = self._tool_classes[builtin_config.tool_class]
                return tool_class(
                    str(tool_config.id),
                    build_config({"parameters": builtin_config.parameters}),
                )

        elif tool_config.tool_type == ToolType.CUSTOM.value:
            # Custom tool.
            try:
                custom_config = self.db.query(CustomToolConfig).filter(
                    CustomToolConfig.id == tool_config.id
                ).first()

                if custom_config:
                    return CustomTool(
                        str(tool_config.id),
                        build_config({
                            "schema_url": custom_config.schema_url,
                            "schema_content": custom_config.schema_content,
                            "auth_type": custom_config.auth_type,
                            "auth_config": custom_config.auth_config,
                            "base_url": custom_config.base_url,
                            "timeout": custom_config.timeout,
                        }),
                    )
            except ImportError as e:
                logger.error(f"无法导入自定义工具模块: {e}")

        elif tool_config.tool_type == ToolType.MCP.value:
            # MCP tool.
            try:
                mcp_config = self.db.query(MCPToolConfig).filter(
                    MCPToolConfig.id == tool_config.id
                ).first()

                if mcp_config:
                    return MCPTool(
                        str(tool_config.id),
                        build_config({
                            "server_url": mcp_config.server_url,
                            "connection_config": mcp_config.connection_config,
                            "available_tools": mcp_config.available_tools,
                        }),
                    )
            except ImportError as e:
                logger.error(f"无法导入MCP工具模块: {e}")

    except Exception as e:
        logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")

    return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
    """Count tool configurations overall, by enabled state and by type.

    Args:
        tenant_id: When given, only tools whose tenant_id equals this value
            are counted.

    Returns:
        A dict with the keys ``total_tools``, ``active_tools``,
        ``inactive_tools`` and ``by_type`` (tool type value -> count), or an
        empty dict if anything raises (the error is logged).
    """
    try:
        base_query = self.db.query(ToolConfig)
        if tenant_id:
            base_query = base_query.filter(ToolConfig.tenant_id == tenant_id)

        total = base_query.count()
        active = base_query.filter(ToolConfig.is_enabled == True).count()

        # One count per tool type, keyed by the type's value.
        per_type = {
            kind.value: base_query.filter(ToolConfig.tool_type == kind.value).count()
            for kind in ToolType
        }

        return {
            "total_tools": total,
            "active_tools": active,
            "inactive_tools": total - active,
            "by_type": per_type
        }
    except Exception as e:
        logger.error(f"获取工具统计失败, 错误: {e}")
        return {}
def clear_cache(self):
    """Empty the in-memory tool instance cache and log that it was cleared."""
    # Clear in place so the same cache object stays attached to the registry.
    cached_tools = self._tools
    cached_tools.clear()
    logger.info("工具缓存已清空")