Create OpenClawTool class inheriting BuiltinTool with dedicated config Remove all x-openclaw special handling from CustomTool (~270 lines) Add multi-operation support (print_task, device_query, image_understand, general) Change ensure_builtin_tools_initialized to incremental mode for auto-provisioning Fix OperationTool and LangchainAdapter to support OpenClaw operation routing
229 lines
7.8 KiB
Python
229 lines
7.8 KiB
Python
"""工具数据访问层"""
|
||
import uuid
|
||
from typing import List, Optional
|
||
from sqlalchemy.orm import Session
|
||
from sqlalchemy import func
|
||
|
||
from app.models.tool_model import (
|
||
ToolConfig, BuiltinToolConfig, CustomToolConfig, MCPToolConfig,
|
||
ToolExecution, ToolType, ToolStatus
|
||
)
|
||
|
||
|
||
class ToolRepository:
|
||
"""工具仓储类"""
|
||
|
||
@staticmethod
|
||
def get_tenant_id_by_workflow_id(db: Session, workflow_id: uuid.UUID) -> Optional[uuid.UUID]:
|
||
"""根据工作流ID获取tenant_id
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workflow_id: 工作流配置ID
|
||
|
||
Returns:
|
||
tenant_id或None
|
||
"""
|
||
from app.models.app_model import App
|
||
from app.models.workflow_model import WorkflowConfig
|
||
from app.models.workspace_model import Workspace
|
||
|
||
result = db.query(Workspace.tenant_id).join(
|
||
App, App.workspace_id == Workspace.id
|
||
).join(
|
||
WorkflowConfig, WorkflowConfig.app_id == App.id
|
||
).filter(
|
||
WorkflowConfig.id == workflow_id
|
||
).first()
|
||
|
||
return result[0] if result else None
|
||
|
||
@staticmethod
|
||
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
|
||
"""
|
||
根据空间ID获取tenant_id
|
||
|
||
Args:
|
||
db: 数据库会话
|
||
workspace_id: 空间ID
|
||
|
||
Returns:
|
||
tenant_id或None
|
||
"""
|
||
from app.models.workspace_model import Workspace
|
||
|
||
tenant_id = db.query(Workspace.tenant_id).filter(
|
||
Workspace.id == workspace_id
|
||
).scalar()
|
||
|
||
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
|
||
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
|
||
try:
|
||
tenant_id = uuid.UUID(tenant_id)
|
||
except (ValueError, TypeError):
|
||
return None
|
||
|
||
return tenant_id
|
||
|
||
@staticmethod
|
||
def find_by_tenant(
|
||
db: Session,
|
||
tenant_id: uuid.UUID,
|
||
name: Optional[str] = None,
|
||
tool_type: Optional[ToolType] = None,
|
||
status: Optional[ToolStatus] = None,
|
||
is_enabled: Optional[bool] = None
|
||
) -> List[ToolConfig]:
|
||
"""根据租户查找工具(只返回未删除的)"""
|
||
query = db.query(ToolConfig).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True)
|
||
)
|
||
|
||
if name:
|
||
query = query.filter(ToolConfig.name.ilike(f"%{name}%"))
|
||
if tool_type:
|
||
query = query.filter(ToolConfig.tool_type == tool_type.value)
|
||
if status:
|
||
query = query.filter(ToolConfig.status == status.value)
|
||
if is_enabled is not None:
|
||
query = query.filter(ToolConfig.is_enabled == is_enabled)
|
||
query = query.order_by(ToolConfig.created_at.desc())
|
||
return query.all()
|
||
|
||
@staticmethod
|
||
def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||
"""根据ID和租户查找工具(只返回未删除的)"""
|
||
return db.query(ToolConfig).filter(
|
||
ToolConfig.id == tool_id,
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True)
|
||
).first()
|
||
|
||
@staticmethod
|
||
def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||
"""根据ID和租户查找工具(返回所有工具包括删除的)"""
|
||
return db.query(ToolConfig).filter(
|
||
ToolConfig.id == tool_id,
|
||
ToolConfig.tenant_id == tenant_id
|
||
).first()
|
||
|
||
@staticmethod
|
||
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||
"""统计租户工具数量(只统计未删除的)"""
|
||
return db.query(ToolConfig).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True)
|
||
).count()
|
||
|
||
@staticmethod
|
||
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||
"""获取状态统计"""
|
||
return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True)
|
||
).group_by(ToolConfig.status).all()
|
||
|
||
@staticmethod
|
||
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||
"""获取类型统计"""
|
||
return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True)
|
||
).group_by(ToolConfig.tool_type).all()
|
||
|
||
@staticmethod
|
||
def count_enabled_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||
"""统计租户启用的工具数量"""
|
||
return db.query(ToolConfig).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.is_active.is_(True),
|
||
ToolConfig.is_enabled == True
|
||
).count()
|
||
|
||
@staticmethod
|
||
def exists_builtin_for_tenant(db: Session, tenant_id: uuid.UUID) -> bool:
|
||
"""检查租户是否已有内置工具"""
|
||
return db.query(ToolConfig).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.tool_type == ToolType.BUILTIN.value,
|
||
ToolConfig.is_active.is_(True)
|
||
).count() > 0
|
||
|
||
|
||
class BuiltinToolRepository:
|
||
"""内置工具仓储类"""
|
||
|
||
@staticmethod
|
||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[BuiltinToolConfig]:
|
||
"""根据工具ID查找内置工具配置"""
|
||
return db.query(BuiltinToolConfig).filter(
|
||
BuiltinToolConfig.id == tool_id
|
||
).first()
|
||
|
||
@staticmethod
|
||
def get_existing_tool_classes(db: Session, tenant_id: uuid.UUID) -> set:
|
||
"""获取该租户已有的内置工具 tool_class 集合"""
|
||
rows = db.query(BuiltinToolConfig.tool_class).join(
|
||
ToolConfig, BuiltinToolConfig.id == ToolConfig.id
|
||
).filter(
|
||
ToolConfig.tenant_id == tenant_id,
|
||
ToolConfig.tool_type == ToolType.BUILTIN.value
|
||
).all()
|
||
return {row[0] for row in rows}
|
||
|
||
|
||
class CustomToolRepository:
|
||
"""自定义工具仓储类"""
|
||
|
||
@staticmethod
|
||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[CustomToolConfig]:
|
||
"""根据工具ID查找自定义工具配置"""
|
||
return db.query(CustomToolConfig).filter(
|
||
CustomToolConfig.id == tool_id
|
||
).first()
|
||
|
||
|
||
class MCPToolRepository:
|
||
"""MCP工具仓储类"""
|
||
|
||
@staticmethod
|
||
def find_by_tool_id(db: Session, tool_id: uuid.UUID) -> Optional[MCPToolConfig]:
|
||
"""根据工具ID查找MCP工具配置"""
|
||
return db.query(MCPToolConfig).filter(
|
||
MCPToolConfig.id == tool_id
|
||
).first()
|
||
|
||
@staticmethod
|
||
def find_error_connections(db: Session) -> List[MCPToolConfig]:
|
||
"""查找连接错误的MCP工具"""
|
||
return db.query(MCPToolConfig).filter(
|
||
MCPToolConfig.connection_status == "error"
|
||
).all()
|
||
|
||
|
||
class ToolExecutionRepository:
|
||
"""工具执行仓储类"""
|
||
|
||
@staticmethod
|
||
def find_by_execution_id(db: Session, execution_id: str) -> Optional[ToolExecution]:
|
||
"""根据执行ID查找执行记录"""
|
||
return db.query(ToolExecution).filter(
|
||
ToolExecution.execution_id == execution_id
|
||
).first()
|
||
|
||
@staticmethod
|
||
def find_by_tool_and_tenant(
|
||
db: Session,
|
||
tool_id: uuid.UUID,
|
||
tenant_id: uuid.UUID,
|
||
limit: int = 100
|
||
) -> List[ToolExecution]:
|
||
"""根据工具和租户查找执行记录"""
|
||
return db.query(ToolExecution).join(
|
||
ToolConfig, ToolExecution.tool_config_id == ToolConfig.id
|
||
).filter(
|
||
ToolConfig.id == tool_id,
|
||
ToolConfig.tenant_id == tenant_id
|
||
).order_by(ToolExecution.started_at.desc()).limit(limit).all()
|