[ADD] Merge code
This commit is contained in:
@@ -27,9 +27,9 @@ class ApiKeyRepository:
|
||||
return db.get(ApiKey, api_key_id)
|
||||
|
||||
@staticmethod
|
||||
def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
|
||||
"""根据哈希值获取 API Key"""
|
||||
stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
|
||||
def get_by_api_key(db: Session, api_key: str) -> Optional[ApiKey]:
|
||||
"""根据 API Key 获取 API Key"""
|
||||
stmt = select(ApiKey).where(ApiKey.api_key == api_key)
|
||||
return db.scalars(stmt).first()
|
||||
|
||||
@staticmethod
|
||||
@@ -63,11 +63,15 @@ class ApiKeyRepository:
|
||||
@staticmethod
|
||||
def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> ApiKey | None:
|
||||
"""更新 API Key"""
|
||||
allow_none_fields = {"description", "quota_limit", "expires_at"}
|
||||
api_key = db.get(ApiKey, api_key_id)
|
||||
if api_key:
|
||||
for key, value in update_data.items():
|
||||
if value is not None:
|
||||
if key in allow_none_fields:
|
||||
setattr(api_key, key, value)
|
||||
else:
|
||||
if value is not None:
|
||||
setattr(api_key, key, value)
|
||||
db.flush()
|
||||
return api_key
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class AppRepository:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> List[App]:
|
||||
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
try:
|
||||
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all()
|
||||
@@ -24,7 +24,19 @@ class AppRepository:
|
||||
db_logger.error(f"查询工作空间 {workspace_id} 下应用时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_apps_by_id(self, app_id: uuid.UUID) -> App:
|
||||
try:
|
||||
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first()
|
||||
return app
|
||||
except Exception as e:
|
||||
raise
|
||||
|
||||
def get_apps_by_workspace_id(db: Session, workspace_id: uuid.UUID) -> List[App]:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_workspace_id(workspace_id)
|
||||
|
||||
def get_apps_by_id(db: Session, app_id: uuid.UUID) -> App:
|
||||
"""根据工作空间ID查询应用"""
|
||||
repo = AppRepository(db)
|
||||
return repo.get_apps_by_id(app_id)
|
||||
|
||||
247
api/app/repositories/workflow_repository.py
Normal file
247
api/app/repositories/workflow_repository.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
工作流数据访问层
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Annotated
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import (
|
||||
WorkflowConfig,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution
|
||||
)
|
||||
from app.db import get_db
|
||||
|
||||
|
||||
class WorkflowConfigRepository:
|
||||
"""工作流配置仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_app_id(self, app_id: uuid.UUID) -> WorkflowConfig | None:
|
||||
"""根据应用 ID 获取工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
工作流配置或 None
|
||||
"""
|
||||
return self.db.query(WorkflowConfig).filter(
|
||||
WorkflowConfig.app_id == app_id,
|
||||
WorkflowConfig.is_active == True
|
||||
).first()
|
||||
|
||||
def create_or_update(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None
|
||||
) -> WorkflowConfig:
|
||||
"""创建或更新工作流配置
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
nodes: 节点列表
|
||||
edges: 边列表
|
||||
variables: 变量列表
|
||||
execution_config: 执行配置
|
||||
triggers: 触发器列表
|
||||
|
||||
Returns:
|
||||
工作流配置
|
||||
"""
|
||||
# 查找现有配置
|
||||
existing = self.get_by_app_id(app_id)
|
||||
|
||||
if existing:
|
||||
# 更新现有配置
|
||||
existing.nodes = nodes
|
||||
existing.edges = edges
|
||||
if variables is not None:
|
||||
existing.variables = variables
|
||||
if execution_config is not None:
|
||||
existing.execution_config = execution_config
|
||||
if triggers is not None:
|
||||
existing.triggers = triggers
|
||||
self.db.commit()
|
||||
self.db.refresh(existing)
|
||||
return existing
|
||||
else:
|
||||
# 创建新配置
|
||||
config = WorkflowConfig(
|
||||
app_id=app_id,
|
||||
nodes=nodes,
|
||||
edges=edges,
|
||||
variables=variables or [],
|
||||
execution_config=execution_config or {},
|
||||
triggers=triggers or []
|
||||
)
|
||||
self.db.add(config)
|
||||
self.db.commit()
|
||||
self.db.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
class WorkflowExecutionRepository:
|
||||
"""工作流执行记录仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_execution_id(self, execution_id: str) -> WorkflowExecution | None:
|
||||
"""根据执行 ID 获取执行记录
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
|
||||
Returns:
|
||||
执行记录或 None
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.execution_id == execution_id
|
||||
).first()
|
||||
|
||||
def get_by_app_id(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据应用 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
limit: 返回数量限制
|
||||
offset: 偏移量
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).limit(limit).offset(offset).all()
|
||||
|
||||
def get_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据会话 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.conversation_id == conversation_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).all()
|
||||
|
||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||
"""统计应用的执行次数
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
|
||||
Returns:
|
||||
执行次数
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).count()
|
||||
|
||||
def count_by_status(self, app_id: uuid.UUID, status: str) -> int:
|
||||
"""统计指定状态的执行次数
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
status: 状态
|
||||
|
||||
Returns:
|
||||
执行次数
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id,
|
||||
WorkflowExecution.status == status
|
||||
).count()
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository:
|
||||
"""工作流节点执行记录仓储"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def get_by_execution_id(
|
||||
self,
|
||||
execution_id: uuid.UUID
|
||||
) -> list[WorkflowNodeExecution]:
|
||||
"""根据执行 ID 获取节点执行记录列表
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
|
||||
Returns:
|
||||
节点执行记录列表(按执行顺序排序)
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.execution_order
|
||||
).all()
|
||||
|
||||
def get_by_node_id(
|
||||
self,
|
||||
execution_id: uuid.UUID,
|
||||
node_id: str
|
||||
) -> list[WorkflowNodeExecution]:
|
||||
"""根据节点 ID 获取节点执行记录(可能有多次重试)
|
||||
|
||||
Args:
|
||||
execution_id: 执行 ID
|
||||
node_id: 节点 ID
|
||||
|
||||
Returns:
|
||||
节点执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id,
|
||||
WorkflowNodeExecution.node_id == node_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.retry_count
|
||||
).all()
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_workflow_config_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowConfigRepository:
|
||||
"""获取工作流配置仓储(依赖注入)"""
|
||||
return WorkflowConfigRepository(db)
|
||||
|
||||
|
||||
def get_workflow_execution_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowExecutionRepository:
|
||||
"""获取工作流执行记录仓储(依赖注入)"""
|
||||
return WorkflowExecutionRepository(db)
|
||||
|
||||
|
||||
def get_workflow_node_execution_repository(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""获取工作流节点执行记录仓储(依赖注入)"""
|
||||
return WorkflowNodeExecutionRepository(db)
|
||||
Reference in New Issue
Block a user