"""工具数据访问层""" import uuid from typing import List, Optional from sqlalchemy.orm import Session from sqlalchemy import func from app.models.tool_model import ( ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig, ToolExecution, ToolType, ToolStatus ) class ToolRepository: """工具仓储类""" @staticmethod def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]: """根据工作流ID获取tenant_id Args: db: 数据库会话 workflow_id: 工作流配置ID Returns: tenant_id或None """ from app.models.app_model import App from app.models.workflow_model import WorkflowConfig from app.models.workspace_model import Workspace result = db.query(Workspace.tenant_id).join( App, App.workspace_id == Workspace.id ).join( WorkflowConfig, WorkflowConfig.app_id == App.id ).filter( WorkflowConfig.id == workflow_id ).first() return result[0] if result else None @staticmethod def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]: """ 根据空间ID获取tenant_id Args: db: 数据库会话 workspace_id: 空间ID Returns: tenant_id或None """ from app.models.workspace_model import Workspace tenant_id = db.query(Workspace.tenant_id).filter( Workspace.id == workspace_id ).scalar() if tenant_id is not None and not isinstance(tenant_id, uuid.UUID): # 兼容数据库中字段类型不匹配的情况(比如存储为字符串) try: tenant_id = uuid.UUID(tenant_id) except (ValueError, TypeError): return None return tenant_id @staticmethod def find_by_tenant( db: Session, tenant_id: uuid.UUID, name: Optional[str] = None, tool_type: Optional[ToolType] = None, status: Optional[ToolStatus] = None, is_enabled: Optional[bool] = None ) -> List[ToolConfig]: """根据租户查找工具(只返回未删除的)""" query = db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True) ) if name: query = query.filter(ToolConfig.name.ilike(f"%{name}%")) if tool_type: query = query.filter(ToolConfig.tool_type == tool_type.value) if status: query = query.filter(ToolConfig.status == status.value) if is_enabled is not None: query = query.filter(ToolConfig.is_enabled == is_enabled) query = query.order_by(ToolConfig.created_at.desc()) return query.all() @staticmethod def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: """根据ID和租户查找工具(只返回未删除的)""" return db.query(ToolConfig).filter( ToolConfig.id == tool_id, ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True) ).first() @staticmethod def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]: """根据ID和租户查找工具(返回所有工具包括删除的)""" return db.query(ToolConfig).filter( ToolConfig.id == tool_id, ToolConfig.tenant_id == tenant_id ).first() @staticmethod def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: """统计租户工具数量(只统计未删除的)""" return db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True) ).count() @staticmethod def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: """获取状态统计""" return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True) ).group_by(ToolConfig.status).all() @staticmethod def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]: """获取类型统计""" return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True) ).group_by(ToolConfig.tool_type).all() @staticmethod def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int: """统计租户启用的工具数量""" return db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.is_active.is_(True), ToolConfig.is_enabled == True ).count() @staticmethod def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool: """检查租户是否已有内置工具""" return db.query(ToolConfig).filter( ToolConfig.tenant_id == tenant_id, ToolConfig.tool_type == ToolType.BUILTIN.value, ToolConfig.is_active.is_(True) ).count() > 0 class BuiltinToolRepository: """内置工具仓储类""" @staticmethod def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]: """根据工具ID查找内置工具配置""" return db.query(BuiltinToolConfig).filter( BuiltinToolConfig.id == tool_id ).first() class CustomToolRepository: """自定义工具仓储类""" @staticmethod def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]: """根据工具ID查找自定义工具配置""" return db.query(CustomToolConfig).filter( CustomToolConfig.id == tool_id ).first() class MCPToolRepository: """MCP工具仓储类""" @staticmethod def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]: """根据工具ID查找MCP工具配置""" return db.query(MCPToolConfig).filter( MCPToolConfig.id == tool_id ).first() @staticmethod def find_error_connections(db: Session) -> List[MCPToolConfig]: """查找连接错误的MCP工具""" return db.query(MCPToolConfig).filter( MCPToolConfig.connection_status == "error" ).all() class ToolExecutionRepository: """工具执行仓储类""" @staticmethod def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]: """根据执行ID查找执行记录""" return db.query(ToolExecution).filter( ToolExecution.execution_id == execution_id ).first() @staticmethod def find_by_tool_and_tenant( db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID, limit: int = 100 ) -> List[ToolExecution]: """根据工具和租户查找执行记录""" return db.query(ToolExecution).join( ToolConfig, ToolExecution.tool_config_id == ToolConfig.id ).filter( ToolConfig.id == tool_id, ToolConfig.tenant_id == tenant_id ).order_by(ToolExecution.started_at.desc()).limit(limit).all()