436 lines
16 KiB
Python
436 lines
16 KiB
Python
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||
import uuid
|
||
import asyncio
|
||
from typing import Dict, List, Optional, Type, Any
|
||
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import and_, or_
|
||
|
||
from app.models.tool_model import (
|
||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||
)
|
||
from app.core.logging_config import get_business_logger
|
||
from .base import BaseTool, ToolInfo
|
||
from .custom.base import CustomTool
|
||
from .mcp.base import MCPTool
|
||
|
||
logger = get_business_logger()
|
||
|
||
|
||
class ToolRegistry:
    """Tool registry that manages tool metadata and cached tool instances.

    Tool metadata is read from and written to the database through the
    session given to the constructor. Instantiated tools are cached in
    memory, keyed by the string form of their configuration id.
    """

    def __init__(self, db: Session):
        """Initialize the tool registry.

        Args:
            db: Database session used for every query and commit.
        """
        self.db = db
        # Tool instance cache, keyed by tool id (string form of the config id).
        self._tools: Dict[str, BaseTool] = {}
        # Registered tool classes, keyed by class name.
        self._tool_classes: Dict[str, Type[BaseTool]] = {}
        # Held by register_tool / unregister_tool for their whole duration.
        self._lock = asyncio.Lock()

    def register_tool_class(self, tool_class: Type[BaseTool], class_name: Optional[str] = None):
        """Register a tool class so stored builtin tools can be instantiated.

        Args:
            tool_class: The tool class to register.
            class_name: Key to register under; defaults to ``tool_class.__name__``.
        """
        class_name = class_name or tool_class.__name__
        self._tool_classes[class_name] = tool_class
        logger.info(f"工具类已注册: {class_name}")

    async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
        """Persist a tool instance's configuration and cache the instance.

        Args:
            tool: The tool instance to register.
            tenant_id: Owning tenant id. A falsy value registers the tool
                with a NULL tenant (a global tool).

        Returns:
            True when the tool was stored and cached. False when a tool with
            the same name, tenant and type already exists, or when any
            exception occurred (the session is rolled back in that case).
        """
        async with self._lock:
            try:
                # A falsy tenant_id means "global": match rows whose tenant is NULL.
                if tenant_id:
                    tenant_filter = ToolConfig.tenant_id == tenant_id
                else:
                    tenant_filter = ToolConfig.tenant_id.is_(None)

                # Check whether the tool already exists.
                existing_config = self.db.query(ToolConfig).filter(
                    and_(
                        ToolConfig.name == tool.name,
                        tenant_filter,
                        ToolConfig.tool_type == tool.tool_type.value,
                    )
                ).first()

                if existing_config:
                    logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
                    return False

                # Create the common tool configuration row.
                tool_config = ToolConfig(
                    name=tool.name,
                    description=tool.description,
                    tool_type=tool.tool_type.value,
                    tenant_id=tenant_id,
                    version=tool.version,
                    tags=tool.tags,
                    config_data=tool.config,
                )
                self.db.add(tool_config)
                self.db.flush()  # flush so tool_config.id is populated

                # Create the type-specific configuration row sharing the same id.
                if tool.tool_type == ToolType.BUILTIN:
                    self.db.add(
                        BuiltinToolConfig(
                            id=tool_config.id,
                            tool_class=tool.__class__.__name__,
                            parameters=tool.config.get("parameters", {}),
                        )
                    )
                elif tool.tool_type == ToolType.CUSTOM:
                    self.db.add(
                        CustomToolConfig(
                            id=tool_config.id,
                            schema_url=tool.config.get("schema_url"),
                            schema_content=tool.config.get("schema_content"),
                            auth_type=tool.config.get("auth_type", "none"),
                            auth_config=tool.config.get("auth_config", {}),
                            base_url=tool.config.get("base_url"),
                            timeout=tool.config.get("timeout", 30),
                        )
                    )
                elif tool.tool_type == ToolType.MCP:
                    self.db.add(
                        MCPToolConfig(
                            id=tool_config.id,
                            server_url=tool.config.get("server_url"),
                            connection_config=tool.config.get("connection_config", {}),
                            available_tools=tool.config.get("available_tools", []),
                        )
                    )

                self.db.commit()

                # Cache the tool instance under its new id.
                registered_id = str(tool_config.id)
                tool.tool_id = registered_id
                self._tools[registered_id] = tool

                logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
                return False

    async def unregister_tool(self, tool_id: str) -> bool:
        """Delete a tool's configuration and drop it from the cache.

        Args:
            tool_id: Tool id as a UUID string.

        Returns:
            True when the tool was deleted. False when it does not exist,
            when it still has pending or running executions, or when any
            exception occurred (including an unparsable ``tool_id``; the
            session is rolled back in that case).
        """
        async with self._lock:
            try:
                # Parse once; a malformed id raises ValueError, handled below.
                tool_uuid = uuid.UUID(tool_id)

                # Check whether the tool exists.
                tool_config = self.db.get(ToolConfig, tool_uuid)
                if not tool_config:
                    logger.warning(f"工具不存在: {tool_id}")
                    return False

                # Refuse to unregister while executions are pending or running.
                running_executions = self.db.query(ToolExecution).filter(
                    and_(
                        ToolExecution.tool_config_id == tool_uuid,
                        ToolExecution.status.in_(
                            [ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value]
                        ),
                    )
                ).count()

                if running_executions > 0:
                    logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
                    return False

                # Delete the tool configuration.
                # NOTE(review): presumably related rows are removed by cascade
                # rules on the model - confirm against tool_model.
                self.db.delete(tool_config)
                self.db.commit()

                # Remove from the cache if present.
                self._tools.pop(tool_id, None)

                logger.info(f"工具注销成功: {tool_id}")
                return True

            except Exception as e:
                self.db.rollback()
                logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
                return False

    def get_tool(self, tool_id: str) -> Optional[BaseTool]:
        """Return the tool instance for ``tool_id``.

        The cache is consulted first; on a miss the configuration is loaded
        from the database, instantiated and cached.

        Args:
            tool_id: Tool id as a UUID string.

        Returns:
            The tool instance, or None when the tool is missing, not active,
            cannot be instantiated, or loading raised an exception.
        """
        # Serve from the cache first.
        cached = self._tools.get(tool_id)
        if tool_id in self._tools:
            return cached

        # Fall back to loading from the database.
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            # NOTE(review): this reads `status`, while list_tools and
            # update_tool_status use `is_enabled` - confirm both exist on
            # ToolConfig and agree with each other.
            if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
                return None

            # Build the instance according to the tool type.
            tool_instance = self._load_tool_instance(tool_config)
            if tool_instance:
                self._tools[tool_id] = tool_instance
                return tool_instance

        except Exception as e:
            logger.error(f"加载工具失败: {tool_id}, 错误: {e}")

        return None

    def list_tools(
        self,
        tenant_id: Optional[uuid.UUID] = None,
        tool_type: Optional[ToolType] = None,
        status: Optional[ToolStatus] = None,
        tags: Optional[List[str]] = None
    ) -> List[ToolInfo]:
        """List tools matching the given filters.

        Args:
            tenant_id: When given, return that tenant's tools plus global
                tools (NULL tenant). When omitted, no tenant filter is applied.
            tool_type: Restrict to this tool type.
            status: ACTIVE selects enabled tools, INACTIVE selects disabled
                tools; any other value applies no status filter.
            tags: Every listed tag must be contained in the tool's tags.

        Returns:
            A list of ToolInfo objects; an empty list when the query fails.
        """
        try:
            query = self.db.query(ToolConfig)

            # Apply the filters.
            if tenant_id:
                # Global tools (NULL tenant) are visible to every tenant.
                query = query.filter(
                    or_(
                        ToolConfig.tenant_id == tenant_id,
                        ToolConfig.tenant_id.is_(None),
                    )
                )

            if tool_type:
                query = query.filter(ToolConfig.tool_type == tool_type.value)

            if status == ToolStatus.ACTIVE:
                query = query.filter(ToolConfig.is_enabled == True)  # noqa: E712 - SQL expression
            elif status == ToolStatus.INACTIVE:
                query = query.filter(ToolConfig.is_enabled == False)  # noqa: E712 - SQL expression

            if tags:
                # Chained filters: a tool must contain all requested tags.
                for tag in tags:
                    query = query.filter(ToolConfig.tags.contains([tag]))

            tool_configs = query.all()

            # Convert rows to ToolInfo.
            tool_infos = []
            for config in tool_configs:
                tool_info = ToolInfo(
                    id=str(config.id),
                    name=config.name,
                    description=config.description or "",
                    tool_type=ToolType(config.tool_type),
                    version=config.version,
                    status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
                    tags=config.tags or [],
                    tenant_id=str(config.tenant_id) if config.tenant_id else None,
                )

                # Attach parameter info when an instance can be obtained.
                tool_instance = self.get_tool(str(config.id))
                if tool_instance:
                    tool_info.parameters = tool_instance.parameters

                tool_infos.append(tool_info)

            return tool_infos

        except Exception as e:
            logger.error(f"列出工具失败, 错误: {e}")
            return []

    async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
        """Update a tool's enabled flag and the cached instance's status.

        ACTIVE sets ``is_enabled`` to True and INACTIVE sets it to False; any
        other status leaves ``is_enabled`` untouched but is still written to
        the cached instance.

        Args:
            tool_id: Tool id as a UUID string.
            status: The new status.

        Returns:
            True on success. False when the tool does not exist or an
            exception occurred (the session is rolled back in that case).
        """
        try:
            tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
            if not tool_config:
                logger.warning(f"工具不存在: {tool_id}")
                return False

            # Update the persisted flag.
            if status == ToolStatus.ACTIVE:
                tool_config.is_enabled = True
            elif status == ToolStatus.INACTIVE:
                tool_config.is_enabled = False

            self.db.commit()

            # Keep the cached instance's status in sync.
            if tool_id in self._tools:
                self._tools[tool_id].status = status

            logger.info(f"工具状态更新成功: {tool_id} -> {status}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
            return False

    @staticmethod
    def _merge_instance_config(tool_config: ToolConfig, specific: Dict[str, Any]) -> Dict[str, Any]:
        """Combine stored config data, type-specific fields and shared metadata.

        Later groups override earlier ones: ``config_data`` first, then
        ``specific``, then tenant_id / version / tags.
        """
        # config_data may be NULL in the database; treat that as empty.
        merged: Dict[str, Any] = dict(tool_config.config_data or {})
        merged.update(specific)
        merged["tenant_id"] = str(tool_config.tenant_id) if tool_config.tenant_id else None
        merged["version"] = tool_config.version
        merged["tags"] = tool_config.tags
        return merged

    def _load_tool_instance(self, tool_config: Optional[ToolConfig]) -> Optional[BaseTool]:
        """Instantiate a tool from its stored configuration.

        Args:
            tool_config: The tool configuration row, or None.

        Returns:
            The tool instance, or None when ``tool_config`` is None, the
            type-specific row is missing, a builtin tool's class is not
            registered, the tool type is unknown, or instantiation raised.
        """
        if tool_config is None:
            # Nothing to load; also keeps the error handler below from
            # dereferencing None.
            return None

        try:
            if tool_config.tool_type == ToolType.BUILTIN.value:
                # Load a builtin tool.
                builtin_config = self.db.query(BuiltinToolConfig).filter(
                    BuiltinToolConfig.id == tool_config.id
                ).first()

                if builtin_config and builtin_config.tool_class in self._tool_classes:
                    tool_class = self._tool_classes[builtin_config.tool_class]
                    config = self._merge_instance_config(
                        tool_config,
                        {"parameters": builtin_config.parameters},
                    )
                    return tool_class(str(tool_config.id), config)

            elif tool_config.tool_type == ToolType.CUSTOM.value:
                # Load a custom tool.
                try:
                    custom_config = self.db.query(CustomToolConfig).filter(
                        CustomToolConfig.id == tool_config.id
                    ).first()

                    if custom_config:
                        config = self._merge_instance_config(
                            tool_config,
                            {
                                "schema_url": custom_config.schema_url,
                                "schema_content": custom_config.schema_content,
                                "auth_type": custom_config.auth_type,
                                "auth_config": custom_config.auth_config,
                                "base_url": custom_config.base_url,
                                "timeout": custom_config.timeout,
                            },
                        )
                        return CustomTool(str(tool_config.id), config)
                except ImportError as e:
                    logger.error(f"无法导入自定义工具模块: {e}")

            elif tool_config.tool_type == ToolType.MCP.value:
                # Load an MCP tool.
                try:
                    mcp_config = self.db.query(MCPToolConfig).filter(
                        MCPToolConfig.id == tool_config.id
                    ).first()

                    if mcp_config:
                        config = self._merge_instance_config(
                            tool_config,
                            {
                                "server_url": mcp_config.server_url,
                                "connection_config": mcp_config.connection_config,
                                "available_tools": mcp_config.available_tools,
                            },
                        )
                        return MCPTool(str(tool_config.id), config)
                except ImportError as e:
                    logger.error(f"无法导入MCP工具模块: {e}")

        except Exception as e:
            logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")

        return None

    def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
        """Return tool counts overall, by enabled state and by type.

        Args:
            tenant_id: When given, only tools owned by this tenant are
                counted (global tools are not included). When omitted, all
                tools are counted.

        Returns:
            A dict with ``total_tools``, ``active_tools``, ``inactive_tools``
            and ``by_type`` (tool type value -> count); an empty dict when
            the query fails.
        """
        try:
            query = self.db.query(ToolConfig)
            if tenant_id:
                query = query.filter(ToolConfig.tenant_id == tenant_id)

            total_tools = query.count()
            active_tools = query.filter(ToolConfig.is_enabled == True).count()  # noqa: E712 - SQL expression

            # Count per tool type; every ToolType member gets an entry.
            type_stats = {}
            for tool_type in ToolType:
                count = query.filter(ToolConfig.tool_type == tool_type.value).count()
                type_stats[tool_type.value] = count

            return {
                "total_tools": total_tools,
                "active_tools": active_tools,
                "inactive_tools": total_tools - active_tools,
                "by_type": type_stats,
            }

        except Exception as e:
            logger.error(f"获取工具统计失败, 错误: {e}")
            return {}

    def clear_cache(self):
        """Drop every cached tool instance."""
        self._tools.clear()
        logger.info("工具缓存已清空")