diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index 4e24faa0..a783fe3f 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -3,9 +3,9 @@ """ import uuid -from typing import Any, Annotated +from typing import Any, Annotated, Literal from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, select from fastapi import Depends from app.models.workflow_model import ( @@ -128,29 +128,36 @@ class WorkflowExecutionRepository: Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.app_id == app_id ).order_by( desc(WorkflowExecution.started_at) - ).limit(limit).offset(offset).all() + ).limit(limit).offset(offset) + return list(self.db.execute(stmt).scalars()) def get_by_conversation_id( self, - conversation_id: uuid.UUID + conversation_id: uuid.UUID, + status: Literal["running", "completed", "failed"] = None, + limit_count: int = 50 ) -> list[WorkflowExecution]: """根据会话 ID 获取执行记录列表 Args: + limit_count: conversation_id: 会话 ID + status: 状态(可选) Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.conversation_id == conversation_id - ).order_by( - desc(WorkflowExecution.started_at) - ).all() + ) + if status: + stmt = stmt.filter(WorkflowExecution.status == status) + stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count) + return list(self.db.execute(stmt).scalars()) def count_by_app_id(self, app_id: uuid.UUID) -> int: """统计应用的执行次数 @@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表(按执行顺序排序) """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id ).order_by( WorkflowNodeExecution.execution_order - ).all() + ) + return list(self.db.execute(stmt).scalars()) def get_by_node_id( self, @@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表 """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.node_id == node_id ).order_by( WorkflowNodeExecution.retry_count - ).all() + ) + return list(self.db.execute(stmt).scalars()) # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index c7d7f2b1..13267078 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -561,6 +561,24 @@ class WorkflowService: storage_type = 'neo4j' return storage_type, user_rag_memory_id + def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None: + executions = self.execution_repo.get_by_conversation_id( + conversation_id=conversation_id, + status="completed", + limit_count=1 + ) + + if executions: + last_state = executions[0].output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + # input_data["conv"] = conv_vars + # input_data["conv_messages"] = last_state.get("messages") or [] + conv_messages = last_state.get("messages") or [] + return conv_vars, conv_messages + return None + # ==================== 工作流执行 ==================== async def run( @@ -634,18 +652,11 @@ class WorkflowService: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break - + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) result = await execute_workflow( @@ -807,17 +818,11 @@ class WorkflowService: storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() async for event in execute_workflow_stream(