refactor(workflow): optimize workflow history queries and migrate ORM to SQLAlchemy 2.0
- Migrate historical workflow queries from the legacy ORM Query API to SQLAlchemy 2.0 select() + execute()
- Add an optional status filter and a row limit to the conversation history query to reduce the data returned, improving performance
- Preserve the original ordering and filtering logic
This commit is contained in:
@@ -3,9 +3,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated, Literal
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc, select
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
from app.models.workflow_model import (
|
from app.models.workflow_model import (
|
||||||
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
执行记录列表
|
执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowExecution).filter(
|
stmt = select(WorkflowExecution).filter(
|
||||||
WorkflowExecution.app_id == app_id
|
WorkflowExecution.app_id == app_id
|
||||||
).order_by(
|
).order_by(
|
||||||
desc(WorkflowExecution.started_at)
|
desc(WorkflowExecution.started_at)
|
||||||
).limit(limit).offset(offset).all()
|
).limit(limit).offset(offset)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def get_by_conversation_id(
|
def get_by_conversation_id(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID,
|
||||||
|
status: Literal["running", "completed", "failed"] = None,
|
||||||
|
limit_count: int = 50
|
||||||
) -> list[WorkflowExecution]:
|
) -> list[WorkflowExecution]:
|
||||||
"""根据会话 ID 获取执行记录列表
|
"""根据会话 ID 获取执行记录列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
limit_count: maximum number of records to return (default 50)
|
||||||
conversation_id: 会话 ID
|
conversation_id: 会话 ID
|
||||||
|
status: 状态(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行记录列表
|
执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowExecution).filter(
|
stmt = select(WorkflowExecution).filter(
|
||||||
WorkflowExecution.conversation_id == conversation_id
|
WorkflowExecution.conversation_id == conversation_id
|
||||||
).order_by(
|
)
|
||||||
desc(WorkflowExecution.started_at)
|
if status:
|
||||||
).all()
|
stmt = stmt.filter(WorkflowExecution.status == status)
|
||||||
|
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||||
"""统计应用的执行次数
|
"""统计应用的执行次数
|
||||||
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
节点执行记录列表(按执行顺序排序)
|
节点执行记录列表(按执行顺序排序)
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowNodeExecution).filter(
|
stmt = select(WorkflowNodeExecution).filter(
|
||||||
WorkflowNodeExecution.execution_id == execution_id
|
WorkflowNodeExecution.execution_id == execution_id
|
||||||
).order_by(
|
).order_by(
|
||||||
WorkflowNodeExecution.execution_order
|
WorkflowNodeExecution.execution_order
|
||||||
).all()
|
)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def get_by_node_id(
|
def get_by_node_id(
|
||||||
self,
|
self,
|
||||||
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
节点执行记录列表
|
节点执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowNodeExecution).filter(
|
stmt = select(WorkflowNodeExecution).filter(
|
||||||
WorkflowNodeExecution.execution_id == execution_id,
|
WorkflowNodeExecution.execution_id == execution_id,
|
||||||
WorkflowNodeExecution.node_id == node_id
|
WorkflowNodeExecution.node_id == node_id
|
||||||
).order_by(
|
).order_by(
|
||||||
WorkflowNodeExecution.retry_count
|
WorkflowNodeExecution.retry_count
|
||||||
).all()
|
)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|||||||
@@ -561,6 +561,24 @@ class WorkflowService:
|
|||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
return storage_type, user_rag_memory_id
|
return storage_type, user_rag_memory_id
|
||||||
|
|
||||||
|
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
|
||||||
|
executions = self.execution_repo.get_by_conversation_id(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
status="completed",
|
||||||
|
limit_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if executions:
|
||||||
|
last_state = executions[0].output_data
|
||||||
|
if isinstance(last_state, dict):
|
||||||
|
variables = last_state.get("variables", {})
|
||||||
|
conv_vars = variables.get("conv", {})
|
||||||
|
# input_data["conv"] = conv_vars
|
||||||
|
# input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
|
conv_messages = last_state.get("messages") or []
|
||||||
|
return conv_vars, conv_messages
|
||||||
|
return None
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -634,18 +652,11 @@ class WorkflowService:
|
|||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
history = self._get_history_info(conversation_id_uuid)
|
||||||
|
if history:
|
||||||
for exec_res in executions:
|
conv_vars, conv_messages = history
|
||||||
if exec_res.status == "completed":
|
input_data["conv"] = conv_vars
|
||||||
last_state = exec_res.output_data
|
input_data["conv_messages"] = conv_messages
|
||||||
if isinstance(last_state, dict):
|
|
||||||
variables = last_state.get("variables", {})
|
|
||||||
conv_vars = variables.get("conv", {})
|
|
||||||
input_data["conv"] = conv_vars
|
|
||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
|
||||||
break
|
|
||||||
|
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
@@ -807,17 +818,11 @@ class WorkflowService:
|
|||||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
history = self._get_history_info(conversation_id_uuid)
|
||||||
|
if history:
|
||||||
for exec_res in executions:
|
conv_vars, conv_messages = history
|
||||||
if exec_res.status == "completed":
|
input_data["conv"] = conv_vars
|
||||||
last_state = exec_res.output_data
|
input_data["conv_messages"] = conv_messages
|
||||||
if isinstance(last_state, dict):
|
|
||||||
variables = last_state.get("variables", {})
|
|
||||||
conv_vars = variables.get("conv", {})
|
|
||||||
input_data["conv"] = conv_vars
|
|
||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
|
||||||
break
|
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
|
|||||||
Reference in New Issue
Block a user