feat(apikey system): tool system development
This commit is contained in:
436
api/app/core/tools/registry.py
Normal file
436
api/app/core/tools/registry.py
Normal file
@@ -0,0 +1,436 @@
|
||||
"""工具注册表 - 管理所有工具的元数据和状态"""
|
||||
import uuid
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Type, Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
|
||||
from app.models.tool_model import (
|
||||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||||
ToolType, ToolStatus, ToolExecution, ExecutionStatus
|
||||
)
|
||||
from app.core.logging_config import get_business_logger
|
||||
from .base import BaseTool, ToolInfo
|
||||
from .custom.base import CustomTool
|
||||
from .mcp.base import MCPTool
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ToolRegistry:
|
||||
"""工具注册表 - 管理所有工具的元数据和实例"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""初始化工具注册表
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
"""
|
||||
self.db = db
|
||||
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
|
||||
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
|
||||
self._lock = asyncio.Lock() # 异步锁
|
||||
|
||||
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
|
||||
"""注册工具类
|
||||
|
||||
Args:
|
||||
tool_class: 工具类
|
||||
class_name: 类名(可选,默认使用类的__name__)
|
||||
"""
|
||||
class_name = class_name or tool_class.__name__
|
||||
self._tool_classes[class_name] = tool_class
|
||||
logger.info(f"工具类已注册: {class_name}")
|
||||
|
||||
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
|
||||
"""注册工具实例到系统
|
||||
|
||||
Args:
|
||||
tool: 工具实例
|
||||
tenant_id: 租户ID(内置工具可以为None,表示全局工具)
|
||||
|
||||
Returns:
|
||||
注册是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否已存在
|
||||
if tenant_id:
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
else:
|
||||
# 全局工具(内置工具)
|
||||
existing_config = self.db.query(ToolConfig).filter(
|
||||
and_(
|
||||
ToolConfig.name == tool.name,
|
||||
ToolConfig.tenant_id.is_(None),
|
||||
ToolConfig.tool_type == tool.tool_type.value
|
||||
)
|
||||
).first()
|
||||
|
||||
if existing_config:
|
||||
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
|
||||
return False
|
||||
|
||||
# 创建工具配置
|
||||
tool_config = ToolConfig(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
tool_type=tool.tool_type.value,
|
||||
tenant_id=tenant_id,
|
||||
version=tool.version,
|
||||
tags=tool.tags,
|
||||
config_data=tool.config
|
||||
)
|
||||
|
||||
self.db.add(tool_config)
|
||||
self.db.flush() # 获取ID
|
||||
|
||||
# 根据工具类型创建特定配置
|
||||
if tool.tool_type == ToolType.BUILTIN:
|
||||
builtin_config = BuiltinToolConfig(
|
||||
id=tool_config.id,
|
||||
tool_class=tool.__class__.__name__,
|
||||
parameters=tool.config.get("parameters", {})
|
||||
)
|
||||
self.db.add(builtin_config)
|
||||
|
||||
elif tool.tool_type == ToolType.CUSTOM:
|
||||
custom_config = CustomToolConfig(
|
||||
id=tool_config.id,
|
||||
schema_url=tool.config.get("schema_url"),
|
||||
schema_content=tool.config.get("schema_content"),
|
||||
auth_type=tool.config.get("auth_type", "none"),
|
||||
auth_config=tool.config.get("auth_config", {}),
|
||||
base_url=tool.config.get("base_url"),
|
||||
timeout=tool.config.get("timeout", 30)
|
||||
)
|
||||
self.db.add(custom_config)
|
||||
|
||||
elif tool.tool_type == ToolType.MCP:
|
||||
mcp_config = MCPToolConfig(
|
||||
id=tool_config.id,
|
||||
server_url=tool.config.get("server_url"),
|
||||
connection_config=tool.config.get("connection_config", {}),
|
||||
available_tools=tool.config.get("available_tools", [])
|
||||
)
|
||||
self.db.add(mcp_config)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 缓存工具实例
|
||||
tool.tool_id = str(tool_config.id)
|
||||
self._tools[str(tool_config.id)] = tool
|
||||
|
||||
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
|
||||
return False
|
||||
|
||||
async def unregister_tool(self, tool_id: str) -> bool:
|
||||
"""从系统注销工具
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
注销是否成功
|
||||
"""
|
||||
async with self._lock:
|
||||
try:
|
||||
# 检查工具是否存在
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 检查是否有正在执行的任务
|
||||
running_executions = self.db.query(ToolExecution).filter(
|
||||
and_(
|
||||
ToolExecution.tool_config_id == uuid.UUID(tool_id),
|
||||
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
|
||||
)
|
||||
).count()
|
||||
|
||||
if running_executions > 0:
|
||||
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
|
||||
return False
|
||||
|
||||
# 删除工具配置(级联删除相关记录)
|
||||
self.db.delete(tool_config)
|
||||
self.db.commit()
|
||||
|
||||
# 从缓存中移除
|
||||
if tool_id in self._tools:
|
||||
del self._tools[tool_id]
|
||||
|
||||
logger.info(f"工具注销成功: {tool_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
|
||||
"""获取工具实例
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
|
||||
Returns:
|
||||
工具实例,如果不存在返回None
|
||||
"""
|
||||
# 先从缓存获取
|
||||
if tool_id in self._tools:
|
||||
return self._tools[tool_id]
|
||||
|
||||
# 从数据库加载
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
|
||||
return None
|
||||
|
||||
# 根据工具类型加载实例
|
||||
tool_instance = self._load_tool_instance(tool_config)
|
||||
if tool_instance:
|
||||
self._tools[tool_id] = tool_instance
|
||||
return tool_instance
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tenant_id: Optional[uuid.UUID] = None,
|
||||
tool_type: Optional[ToolType] = None,
|
||||
status: Optional[ToolStatus] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[ToolInfo]:
|
||||
"""列出工具
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID过滤
|
||||
tool_type: 工具类型过滤
|
||||
status: 工具状态过滤
|
||||
tags: 标签过滤
|
||||
|
||||
Returns:
|
||||
工具信息列表
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
|
||||
# 应用过滤条件
|
||||
if tenant_id:
|
||||
# 返回全局工具(tenant_id为空)和该租户的工具
|
||||
query = query.filter(
|
||||
or_(
|
||||
ToolConfig.tenant_id == tenant_id,
|
||||
ToolConfig.tenant_id.is_(None)
|
||||
)
|
||||
)
|
||||
|
||||
if tool_type:
|
||||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||||
|
||||
if status == ToolStatus.ACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == True)
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
query = query.filter(ToolConfig.is_enabled == False)
|
||||
|
||||
if tags:
|
||||
for tag in tags:
|
||||
query = query.filter(ToolConfig.tags.contains([tag]))
|
||||
|
||||
tool_configs = query.all()
|
||||
|
||||
# 转换为ToolInfo
|
||||
tool_infos = []
|
||||
for config in tool_configs:
|
||||
tool_info = ToolInfo(
|
||||
id=str(config.id),
|
||||
name=config.name,
|
||||
description=config.description or "",
|
||||
tool_type=ToolType(config.tool_type),
|
||||
version=config.version,
|
||||
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
|
||||
tags=config.tags or [],
|
||||
tenant_id=str(config.tenant_id) if config.tenant_id else None
|
||||
)
|
||||
|
||||
# 尝试获取参数信息
|
||||
tool_instance = self.get_tool(str(config.id))
|
||||
if tool_instance:
|
||||
tool_info.parameters = tool_instance.parameters
|
||||
|
||||
tool_infos.append(tool_info)
|
||||
|
||||
return tool_infos
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出工具失败, 错误: {e}")
|
||||
return []
|
||||
|
||||
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
|
||||
"""更新工具状态
|
||||
|
||||
Args:
|
||||
tool_id: 工具ID
|
||||
status: 新状态
|
||||
|
||||
Returns:
|
||||
更新是否成功
|
||||
"""
|
||||
try:
|
||||
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
|
||||
if not tool_config:
|
||||
logger.warning(f"工具不存在: {tool_id}")
|
||||
return False
|
||||
|
||||
# 更新状态
|
||||
if status == ToolStatus.ACTIVE:
|
||||
tool_config.is_enabled = True
|
||||
elif status == ToolStatus.INACTIVE:
|
||||
tool_config.is_enabled = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# 更新缓存中的工具状态
|
||||
if tool_id in self._tools:
|
||||
self._tools[tool_id].status = status
|
||||
|
||||
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
|
||||
return False
|
||||
|
||||
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
|
||||
"""从配置加载工具实例
|
||||
|
||||
Args:
|
||||
tool_config: 工具配置
|
||||
|
||||
Returns:
|
||||
工具实例
|
||||
"""
|
||||
try:
|
||||
if tool_config.tool_type == ToolType.BUILTIN.value:
|
||||
# 加载内置工具
|
||||
builtin_config = self.db.query(BuiltinToolConfig).filter(
|
||||
BuiltinToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if builtin_config and builtin_config.tool_class in self._tool_classes:
|
||||
tool_class = self._tool_classes[builtin_config.tool_class]
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"parameters": builtin_config.parameters,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return tool_class(str(tool_config.id), config)
|
||||
|
||||
elif tool_config.tool_type == ToolType.CUSTOM.value:
|
||||
# 加载自定义工具
|
||||
try:
|
||||
custom_config = self.db.query(CustomToolConfig).filter(
|
||||
CustomToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if custom_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"schema_url": custom_config.schema_url,
|
||||
"schema_content": custom_config.schema_content,
|
||||
"auth_type": custom_config.auth_type,
|
||||
"auth_config": custom_config.auth_config,
|
||||
"base_url": custom_config.base_url,
|
||||
"timeout": custom_config.timeout,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return CustomTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入自定义工具模块: {e}")
|
||||
|
||||
elif tool_config.tool_type == ToolType.MCP.value:
|
||||
# 加载MCP工具
|
||||
try:
|
||||
mcp_config = self.db.query(MCPToolConfig).filter(
|
||||
MCPToolConfig.id == tool_config.id
|
||||
).first()
|
||||
|
||||
if mcp_config:
|
||||
config = {
|
||||
**tool_config.config_data,
|
||||
"server_url": mcp_config.server_url,
|
||||
"connection_config": mcp_config.connection_config,
|
||||
"available_tools": mcp_config.available_tools,
|
||||
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
|
||||
"version": tool_config.version,
|
||||
"tags": tool_config.tags
|
||||
}
|
||||
return MCPTool(str(tool_config.id), config)
|
||||
except ImportError as e:
|
||||
logger.error(f"无法导入MCP工具模块: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
|
||||
"""获取工具统计信息
|
||||
|
||||
Args:
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
try:
|
||||
query = self.db.query(ToolConfig)
|
||||
if tenant_id:
|
||||
query = query.filter(ToolConfig.tenant_id == tenant_id)
|
||||
|
||||
total_tools = query.count()
|
||||
active_tools = query.filter(ToolConfig.is_enabled == True).count()
|
||||
|
||||
# 按类型统计
|
||||
type_stats = {}
|
||||
for tool_type in ToolType:
|
||||
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
|
||||
type_stats[tool_type.value] = count
|
||||
|
||||
return {
|
||||
"total_tools": total_tools,
|
||||
"active_tools": active_tools,
|
||||
"inactive_tools": total_tools - active_tools,
|
||||
"by_type": type_stats
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工具统计失败, 错误: {e}")
|
||||
return {}
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空工具缓存"""
|
||||
self._tools.clear()
|
||||
logger.info("工具缓存已清空")
|
||||
Reference in New Issue
Block a user