refactor(workflow): optimize workflow history queries and migrate ORM to SQLAlchemy 2.0

- Migrate historical workflow queries from legacy ORM Query API to SQLAlchemy 2.0 select() + execute()
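- Push the status filter and a row limit (`limit_count`) into the SQL query instead of loading every execution of a conversation and filtering in Python, reducing returned data and improving performance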
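- Preserve original ordering; note that `get_by_conversation_id` now accepts an optional `status` filter and returns at most `limit_count` rows (default 50) instead of every row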
This commit is contained in:
Eternity
2026-03-27 11:56:22 +08:00
parent a5bce221bd
commit 4534b65d6a
2 changed files with 50 additions and 36 deletions

View File

@@ -3,9 +3,9 @@
""" """
import uuid import uuid
from typing import Any, Annotated from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc from sqlalchemy import desc, select
from fastapi import Depends from fastapi import Depends
from app.models.workflow_model import ( from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id WorkflowExecution.app_id == app_id
).order_by( ).order_by(
desc(WorkflowExecution.started_at) desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all() ).limit(limit).offset(offset)
return list(self.db.execute(stmt).scalars())
def get_by_conversation_id( def get_by_conversation_id(
self, self,
conversation_id: uuid.UUID conversation_id: uuid.UUID,
status: Literal["running", "completed", "failed"] = None,
limit_count: int = 50
) -> list[WorkflowExecution]: ) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表 """根据会话 ID 获取执行记录列表
Args: Args:
limit_count:
conversation_id: 会话 ID conversation_id: 会话 ID
status: 状态(可选)
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id WorkflowExecution.conversation_id == conversation_id
).order_by( )
desc(WorkflowExecution.started_at) if status:
).all() stmt = stmt.filter(WorkflowExecution.status == status)
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int: def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数 """统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表(按执行顺序排序) 节点执行记录列表(按执行顺序排序)
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id WorkflowNodeExecution.execution_id == execution_id
).order_by( ).order_by(
WorkflowNodeExecution.execution_order WorkflowNodeExecution.execution_order
).all() )
return list(self.db.execute(stmt).scalars())
def get_by_node_id( def get_by_node_id(
self, self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表 节点执行记录列表
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id WorkflowNodeExecution.node_id == node_id
).order_by( ).order_by(
WorkflowNodeExecution.retry_count WorkflowNodeExecution.retry_count
).all() )
return list(self.db.execute(stmt).scalars())
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================

View File

@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j' storage_type = 'neo4j'
return storage_type, user_rag_memory_id return storage_type, user_rag_memory_id
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
executions = self.execution_repo.get_by_conversation_id(
conversation_id=conversation_id,
status="completed",
limit_count=1
)
if executions:
last_state = executions[0].output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
# input_data["conv"] = conv_vars
# input_data["conv_messages"] = last_state.get("messages") or []
conv_messages = last_state.get("messages") or []
return conv_vars, conv_messages
return None
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow( result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4() message_id = uuid.uuid4()
async for event in execute_workflow_stream( async for event in execute_workflow_stream(