Files
MemoryBear/api/app/repositories/tool_repository.py
miao 55b2e05ba8 feat(tools): refactor migrate OpenClaw from custom tool to builtin tool
Create OpenClawTool class inheriting BuiltinTool with dedicated config
Remove all x-openclaw special handling from CustomTool (~270 lines)
Add multi-operation support (print_task, device_query, image_understand, general)
Change ensure_builtin_tools_initialized to incremental mode for auto-provisioning
Fix OperationTool and LangchainAdapter to support OpenClaw operation routing
2026-04-09 18:14:31 +08:00

229 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""工具数据访问层"""
import uuid
from typing import List, Optional
from sqlalchemy.orm import Session
from sqlalchemy import func
from app.models.tool_model import (
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
ToolExecution, ToolType, ToolStatus
)
class ToolRepository:
"""工具仓储类"""
@staticmethod
def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]:
"""根据工作流ID获取tenant_id
Args:
db: 数据库会话
workflow_id: 工作流配置ID
Returns:
tenant_id或None
"""
from app.models.app_model import App
from app.models.workflow_model import WorkflowConfig
from app.models.workspace_model import Workspace
result = db.query(Workspace.tenant_id).join(
App, App.workspace_id == Workspace.id
).join(
WorkflowConfig, WorkflowConfig.app_id == App.id
).filter(
WorkflowConfig.id == workflow_id
).first()
return result[0] if result else None
@staticmethod
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
"""
根据空间ID获取tenant_id
Args:
db: 数据库会话
workspace_id: 空间ID
Returns:
tenant_id或None
"""
from app.models.workspace_model import Workspace
tenant_id = db.query(Workspace.tenant_id).filter(
Workspace.id == workspace_id
).scalar()
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
try:
tenant_id = uuid.UUID(tenant_id)
except (ValueError, TypeError):
return None
return tenant_id
@staticmethod
def find_by_tenant(
db: Session,
tenant_id: uuid.UUID,
name: Optional[str] = None,
tool_type: Optional[ToolType] = None,
status: Optional[ToolStatus] = None,
is_enabled: Optional[bool] = None
) -> List[ToolConfig]:
"""根据租户查找工具(只返回未删除的)"""
query = db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
)
if name:
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
if tool_type:
query = query.filter(ToolConfig.tool_type == tool_type.value)
if status:
query = query.filter(ToolConfig.status == status.value)
if is_enabled is not None:
query = query.filter(ToolConfig.is_enabled == is_enabled)
query = query.order_by(ToolConfig.created_at.desc())
return query.all()
@staticmethod
def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""根据ID和租户查找工具只返回未删除的"""
return db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).first()
@staticmethod
def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
"""根据ID和租户查找工具返回所有工具包括删除的"""
return db.query(ToolConfig).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id
).first()
@staticmethod
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
"""统计租户工具数量(只统计未删除的)"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).count()
@staticmethod
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取状态统计"""
return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).group_by(ToolConfig.status).all()
@staticmethod
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
"""获取类型统计"""
return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True)
).group_by(ToolConfig.tool_type).all()
@staticmethod
def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
"""统计租户启用的工具数量"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.is_active.is_(True),
ToolConfig.is_enabled == True
).count()
@staticmethod
def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool:
"""检查租户是否已有内置工具"""
return db.query(ToolConfig).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == ToolType.BUILTIN.value,
ToolConfig.is_active.is_(True)
).count() > 0
class BuiltinToolRepository:
"""内置工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]:
"""根据工具ID查找内置工具配置"""
return db.query(BuiltinToolConfig).filter(
BuiltinToolConfig.id == tool_id
).first()
@staticmethod
def get_existing_tool_classes(db: Session, tenant_id: uuid.UUID) -> set:
"""获取该租户已有的内置工具 tool_class 集合"""
rows = db.query(BuiltinToolConfig.tool_class).join(
ToolConfig, BuiltinToolConfig.id == ToolConfig.id
).filter(
ToolConfig.tenant_id == tenant_id,
ToolConfig.tool_type == ToolType.BUILTIN.value
).all()
return {row[0] for row in rows}
class CustomToolRepository:
"""自定义工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]:
"""根据工具ID查找自定义工具配置"""
return db.query(CustomToolConfig).filter(
CustomToolConfig.id == tool_id
).first()
class MCPToolRepository:
"""MCP工具仓储类"""
@staticmethod
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]:
"""根据工具ID查找MCP工具配置"""
return db.query(MCPToolConfig).filter(
MCPToolConfig.id == tool_id
).first()
@staticmethod
def find_error_connections(db: Session) -> List[MCPToolConfig]:
"""查找连接错误的MCP工具"""
return db.query(MCPToolConfig).filter(
MCPToolConfig.connection_status == "error"
).all()
class ToolExecutionRepository:
"""工具执行仓储类"""
@staticmethod
def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]:
"""根据执行ID查找执行记录"""
return db.query(ToolExecution).filter(
ToolExecution.execution_id == execution_id
).first()
@staticmethod
def find_by_tool_and_tenant(
db: Session,
tool_id: uuid.UUID,
tenant_id: uuid.UUID,
limit: int = 100
) -> List[ToolExecution]:
"""根据工具和租户查找执行记录"""
return db.query(ToolExecution).join(
ToolConfig, ToolExecution.tool_config_id == ToolConfig.id
).filter(
ToolConfig.id == tool_id,
ToolConfig.tenant_id == tenant_id
).order_by(ToolExecution.started_at.desc()).limit(limit).all()