Files
MemoryBear/api/app/core/tools/registry.py
2025-12-20 15:24:28 +08:00

436 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
"""工具注册表 - 管理所有工具的元数据和实例"""
def __init__(self, db: Session):
"""初始化工具注册表
Args:
db: 数据库会话
"""
self.db = db
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
self._lock = asyncio.Lock() # 异步锁
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
"""注册工具类
Args:
tool_class: 工具类
class_name: 类名可选默认使用类的__name__
"""
class_name = class_name or tool_class.__name__
self._tool_classes[class_name] = tool_class
logger.info(f"工具类已注册: {class_name}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
"""注册工具实例到系统
Args:
tool: 工具实例
tenant_id: 租户ID内置工具可以为None表示全局工具
Returns:
注册是否成功
"""
async with self._lock:
try:
# 检查工具是否已存在
if tenant_id:
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == tool.tool_type.value
)
).first()
else:
# 全局工具(内置工具)
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id.is_(None),
ToolConfig.tool_type == tool.tool_type.value
)
).first()
if existing_config:
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
return False
# 创建工具配置
tool_config = ToolConfig(
name=tool.name,
description=tool.description,
tool_type=tool.tool_type.value,
tenant_id=tenant_id,
version=tool.version,
tags=tool.tags,
config_data=tool.config
)
self.db.add(tool_config)
self.db.flush() # 获取ID
# 根据工具类型创建特定配置
if tool.tool_type == ToolType.BUILTIN:
builtin_config = BuiltinToolConfig(
id=tool_config.id,
tool_class=tool.__class__.__name__,
parameters=tool.config.get("parameters", {})
)
self.db.add(builtin_config)
elif tool.tool_type == ToolType.CUSTOM:
custom_config = CustomToolConfig(
id=tool_config.id,
schema_url=tool.config.get("schema_url"),
schema_content=tool.config.get("schema_content"),
auth_type=tool.config.get("auth_type", "none"),
auth_config=tool.config.get("auth_config", {}),
base_url=tool.config.get("base_url"),
timeout=tool.config.get("timeout", 30)
)
self.db.add(custom_config)
elif tool.tool_type == ToolType.MCP:
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=tool.config.get("server_url"),
connection_config=tool.config.get("connection_config", {}),
available_tools=tool.config.get("available_tools", [])
)
self.db.add(mcp_config)
self.db.commit()
# 缓存工具实例
tool.tool_id = str(tool_config.id)
self._tools[str(tool_config.id)] = tool
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
return False
async def unregister_tool(self, tool_id: str) -> bool:
"""从系统注销工具
Args:
tool_id: 工具ID
Returns:
注销是否成功
"""
async with self._lock:
try:
# 检查工具是否存在
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 检查是否有正在执行的任务
running_executions = self.db.query(ToolExecution).filter(
and_(
ToolExecution.tool_config_id == uuid.UUID(tool_id),
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
)
).count()
if running_executions > 0:
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
return False
# 删除工具配置(级联删除相关记录)
self.db.delete(tool_config)
self.db.commit()
# 从缓存中移除
if tool_id in self._tools:
del self._tools[tool_id]
logger.info(f"工具注销成功: {tool_id}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
"""获取工具实例
Args:
tool_id: 工具ID
Returns:
工具实例如果不存在返回None
"""
# 先从缓存获取
if tool_id in self._tools:
return self._tools[tool_id]
# 从数据库加载
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
return None
# 根据工具类型加载实例
tool_instance = self._load_tool_instance(tool_config)
if tool_instance:
self._tools[tool_id] = tool_instance
return tool_instance
except Exception as e:
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
return None
def list_tools(
self,
tenant_id: Optional[uuid.UUID] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
tags: Optional[List[str]] = None
) -> List[ToolInfo]:
"""列出工具
Args:
tenant_id: 租户ID过滤
tool_type: 工具类型过滤
status: 工具状态过滤
tags: 标签过滤
Returns:
工具信息列表
"""
try:
query = self.db.query(ToolConfig)
# 应用过滤条件
if tenant_id:
# 返回全局工具tenant_id为空和该租户的工具
query = query.filter(
or_(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tenant_id.is_(None)
)
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status == ToolStatus.ACTIVE:
query = query.filter(ToolConfig.is_enabled == True)
elif status == ToolStatus.INACTIVE:
query = query.filter(ToolConfig.is_enabled == False)
if tags:
for tag in tags:
query = query.filter(ToolConfig.tags.contains([tag]))
tool_configs = query.all()
# 转换为ToolInfo
tool_infos = []
for config in tool_configs:
tool_info = ToolInfo(
id=str(config.id),
name=config.name,
description=config.description or "",
tool_type=ToolType(config.tool_type),
version=config.version,
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None
)
# 尝试获取参数信息
tool_instance = self.get_tool(str(config.id))
if tool_instance:
tool_info.parameters = tool_instance.parameters
tool_infos.append(tool_info)
return tool_infos
except Exception as e:
logger.error(f"列出工具失败, 错误: {e}")
return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
"""更新工具状态
Args:
tool_id: 工具ID
status: 新状态
Returns:
更新是否成功
"""
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 更新状态
if status == ToolStatus.ACTIVE:
tool_config.is_enabled = True
elif status == ToolStatus.INACTIVE:
tool_config.is_enabled = False
self.db.commit()
# 更新缓存中的工具状态
if tool_id in self._tools:
self._tools[tool_id].status = status
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
return False
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
"""从配置加载工具实例
Args:
tool_config: 工具配置
Returns:
工具实例
"""
try:
if tool_config.tool_type == ToolType.BUILTIN.value:
# 加载内置工具
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config and builtin_config.tool_class in self._tool_classes:
tool_class = self._tool_classes[builtin_config.tool_class]
config = {
**tool_config.config_data,
"parameters": builtin_config.parameters,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return tool_class(str(tool_config.id), config)
elif tool_config.tool_type == ToolType.CUSTOM.value:
# 加载自定义工具
try:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config:
config = {
**tool_config.config_data,
"schema_url": custom_config.schema_url,
"schema_content": custom_config.schema_content,
"auth_type": custom_config.auth_type,
"auth_config": custom_config.auth_config,
"base_url": custom_config.base_url,
"timeout": custom_config.timeout,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return CustomTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入自定义工具模块: {e}")
elif tool_config.tool_type == ToolType.MCP.value:
# 加载MCP工具
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
config = {
**tool_config.config_data,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config,
"available_tools": mcp_config.available_tools,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return MCPTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入MCP工具模块: {e}")
except Exception as e:
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""获取工具统计信息
Args:
tenant_id: 租户ID
Returns:
统计信息字典
"""
try:
query = self.db.query(ToolConfig)
if tenant_id:
query = query.filter(ToolConfig.tenant_id == tenant_id)
total_tools = query.count()
active_tools = query.filter(ToolConfig.is_enabled == True).count()
# 按类型统计
type_stats = {}
for tool_type in ToolType:
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
type_stats[tool_type.value] = count
return {
"total_tools": total_tools,
"active_tools": active_tools,
"inactive_tools": total_tools - active_tools,
"by_type": type_stats
}
except Exception as e:
logger.error(f"获取工具统计失败, 错误: {e}")
return {}
def clear_cache(self):
"""清空工具缓存"""
self._tools.clear()
logger.info("工具缓存已清空")