feat(apikey system): tool system development

This commit is contained in:
谢俊男
2025-12-20 15:24:28 +08:00
parent 3fbd4f206e
commit c26af11f76
39 changed files with 9338 additions and 4 deletions

View File

@@ -0,0 +1,436 @@
"""工具注册表 - 管理所有工具的元数据和状态"""
import uuid
import asyncio
from typing import Dict, List, Optional, Type, Any
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolType, ToolStatus, ToolExecution, ExecutionStatus
)
from app.core.logging_config import get_business_logger
from .base import BaseTool, ToolInfo
from .custom.base import CustomTool
from .mcp.base import MCPTool
logger = get_business_logger()
class ToolRegistry:
"""工具注册表 - 管理所有工具的元数据和实例"""
def __init__(self, db: Session):
"""初始化工具注册表
Args:
db: 数据库会话
"""
self.db = db
self._tools: Dict[str, BaseTool] = {} # 工具实例缓存
self._tool_classes: Dict[str, Type[BaseTool]] = {} # 工具类注册表
self._lock = asyncio.Lock() # 异步锁
def register_tool_class(self, tool_class: Type[BaseTool], class_name: str = None):
"""注册工具类
Args:
tool_class: 工具类
class_name: 类名可选默认使用类的__name__
"""
class_name = class_name or tool_class.__name__
self._tool_classes[class_name] = tool_class
logger.info(f"工具类已注册: {class_name}")
async def register_tool(self, tool: BaseTool, tenant_id: Optional[uuid.UUID] = None) -> bool:
"""注册工具实例到系统
Args:
tool: 工具实例
tenant_id: 租户ID内置工具可以为None表示全局工具
Returns:
注册是否成功
"""
async with self._lock:
try:
# 检查工具是否已存在
if tenant_id:
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == tool.tool_type.value
)
).first()
else:
# 全局工具(内置工具)
existing_config = self.db.query(ToolConfig).filter(
and_(
ToolConfig.name == tool.name,
ToolConfig.tenant_id.is_(None),
ToolConfig.tool_type == tool.tool_type.value
)
).first()
if existing_config:
logger.warning(f"工具已存在: {tool.name} (tenant: {tenant_id or 'global'})")
return False
# 创建工具配置
tool_config = ToolConfig(
name=tool.name,
description=tool.description,
tool_type=tool.tool_type.value,
tenant_id=tenant_id,
version=tool.version,
tags=tool.tags,
config_data=tool.config
)
self.db.add(tool_config)
self.db.flush() # 获取ID
# 根据工具类型创建特定配置
if tool.tool_type == ToolType.BUILTIN:
builtin_config = BuiltinToolConfig(
id=tool_config.id,
tool_class=tool.__class__.__name__,
parameters=tool.config.get("parameters", {})
)
self.db.add(builtin_config)
elif tool.tool_type == ToolType.CUSTOM:
custom_config = CustomToolConfig(
id=tool_config.id,
schema_url=tool.config.get("schema_url"),
schema_content=tool.config.get("schema_content"),
auth_type=tool.config.get("auth_type", "none"),
auth_config=tool.config.get("auth_config", {}),
base_url=tool.config.get("base_url"),
timeout=tool.config.get("timeout", 30)
)
self.db.add(custom_config)
elif tool.tool_type == ToolType.MCP:
mcp_config = MCPToolConfig(
id=tool_config.id,
server_url=tool.config.get("server_url"),
connection_config=tool.config.get("connection_config", {}),
available_tools=tool.config.get("available_tools", [])
)
self.db.add(mcp_config)
self.db.commit()
# 缓存工具实例
tool.tool_id = str(tool_config.id)
self._tools[str(tool_config.id)] = tool
logger.info(f"工具注册成功: {tool.name} (ID: {tool_config.id})")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注册失败: {tool.name}, 错误: {e}")
return False
async def unregister_tool(self, tool_id: str) -> bool:
"""从系统注销工具
Args:
tool_id: 工具ID
Returns:
注销是否成功
"""
async with self._lock:
try:
# 检查工具是否存在
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 检查是否有正在执行的任务
running_executions = self.db.query(ToolExecution).filter(
and_(
ToolExecution.tool_config_id == uuid.UUID(tool_id),
ToolExecution.status.in_([ExecutionStatus.PENDING.value, ExecutionStatus.RUNNING.value])
)
).count()
if running_executions > 0:
logger.warning(f"工具有正在执行的任务,无法注销: {tool_id}")
return False
# 删除工具配置(级联删除相关记录)
self.db.delete(tool_config)
self.db.commit()
# 从缓存中移除
if tool_id in self._tools:
del self._tools[tool_id]
logger.info(f"工具注销成功: {tool_id}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具注销失败: {tool_id}, 错误: {e}")
return False
def get_tool(self, tool_id: str) -> Optional[BaseTool]:
"""获取工具实例
Args:
tool_id: 工具ID
Returns:
工具实例如果不存在返回None
"""
# 先从缓存获取
if tool_id in self._tools:
return self._tools[tool_id]
# 从数据库加载
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config or not tool_config.status == ToolStatus.ACTIVE.value:
return None
# 根据工具类型加载实例
tool_instance = self._load_tool_instance(tool_config)
if tool_instance:
self._tools[tool_id] = tool_instance
return tool_instance
except Exception as e:
logger.error(f"加载工具失败: {tool_id}, 错误: {e}")
return None
def list_tools(
self,
tenant_id: Optional[uuid.UUID] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
tags: Optional[List[str]] = None
) -> List[ToolInfo]:
"""列出工具
Args:
tenant_id: 租户ID过滤
tool_type: 工具类型过滤
status: 工具状态过滤
tags: 标签过滤
Returns:
工具信息列表
"""
try:
query = self.db.query(ToolConfig)
# 应用过滤条件
if tenant_id:
# 返回全局工具tenant_id为空和该租户的工具
query = query.filter(
or_(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tenant_id.is_(None)
)
)
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status == ToolStatus.ACTIVE:
query = query.filter(ToolConfig.is_enabled == True)
elif status == ToolStatus.INACTIVE:
query = query.filter(ToolConfig.is_enabled == False)
if tags:
for tag in tags:
query = query.filter(ToolConfig.tags.contains([tag]))
tool_configs = query.all()
# 转换为ToolInfo
tool_infos = []
for config in tool_configs:
tool_info = ToolInfo(
id=str(config.id),
name=config.name,
description=config.description or "",
tool_type=ToolType(config.tool_type),
version=config.version,
status=ToolStatus.ACTIVE if config.is_enabled else ToolStatus.INACTIVE,
tags=config.tags or [],
tenant_id=str(config.tenant_id) if config.tenant_id else None
)
# 尝试获取参数信息
tool_instance = self.get_tool(str(config.id))
if tool_instance:
tool_info.parameters = tool_instance.parameters
tool_infos.append(tool_info)
return tool_infos
except Exception as e:
logger.error(f"列出工具失败, 错误: {e}")
return []
async def update_tool_status(self, tool_id: str, status: ToolStatus) -> bool:
"""更新工具状态
Args:
tool_id: 工具ID
status: 新状态
Returns:
更新是否成功
"""
try:
tool_config = self.db.get(ToolConfig, uuid.UUID(tool_id))
if not tool_config:
logger.warning(f"工具不存在: {tool_id}")
return False
# 更新状态
if status == ToolStatus.ACTIVE:
tool_config.is_enabled = True
elif status == ToolStatus.INACTIVE:
tool_config.is_enabled = False
self.db.commit()
# 更新缓存中的工具状态
if tool_id in self._tools:
self._tools[tool_id].status = status
logger.info(f"工具状态更新成功: {tool_id} -> {status}")
return True
except Exception as e:
self.db.rollback()
logger.error(f"工具状态更新失败: {tool_id}, 错误: {e}")
return False
def _load_tool_instance(self, tool_config: type[ToolConfig] | None) -> Optional[BaseTool]:
"""从配置加载工具实例
Args:
tool_config: 工具配置
Returns:
工具实例
"""
try:
if tool_config.tool_type == ToolType.BUILTIN.value:
# 加载内置工具
builtin_config = self.db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_config.id
).first()
if builtin_config and builtin_config.tool_class in self._tool_classes:
tool_class = self._tool_classes[builtin_config.tool_class]
config = {
**tool_config.config_data,
"parameters": builtin_config.parameters,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return tool_class(str(tool_config.id), config)
elif tool_config.tool_type == ToolType.CUSTOM.value:
# 加载自定义工具
try:
custom_config = self.db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_config.id
).first()
if custom_config:
config = {
**tool_config.config_data,
"schema_url": custom_config.schema_url,
"schema_content": custom_config.schema_content,
"auth_type": custom_config.auth_type,
"auth_config": custom_config.auth_config,
"base_url": custom_config.base_url,
"timeout": custom_config.timeout,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return CustomTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入自定义工具模块: {e}")
elif tool_config.tool_type == ToolType.MCP.value:
# 加载MCP工具
try:
mcp_config = self.db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_config.id
).first()
if mcp_config:
config = {
**tool_config.config_data,
"server_url": mcp_config.server_url,
"connection_config": mcp_config.connection_config,
"available_tools": mcp_config.available_tools,
"tenant_id": str(tool_config.tenant_id) if tool_config.tenant_id else None,
"version": tool_config.version,
"tags": tool_config.tags
}
return MCPTool(str(tool_config.id), config)
except ImportError as e:
logger.error(f"无法导入MCP工具模块: {e}")
except Exception as e:
logger.error(f"加载工具实例失败: {tool_config.id}, 错误: {e}")
return None
def get_tool_statistics(self, tenant_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""获取工具统计信息
Args:
tenant_id: 租户ID
Returns:
统计信息字典
"""
try:
query = self.db.query(ToolConfig)
if tenant_id:
query = query.filter(ToolConfig.tenant_id == tenant_id)
total_tools = query.count()
active_tools = query.filter(ToolConfig.is_enabled == True).count()
# 按类型统计
type_stats = {}
for tool_type in ToolType:
count = query.filter(ToolConfig.tool_type == tool_type.value).count()
type_stats[tool_type.value] = count
return {
"total_tools": total_tools,
"active_tools": active_tools,
"inactive_tools": total_tools - active_tools,
"by_type": type_stats
}
except Exception as e:
logger.error(f"获取工具统计失败, 错误: {e}")
return {}
def clear_cache(self):
"""清空工具缓存"""
self._tools.clear()
logger.info("工具缓存已清空")