From 44083aec7999d1565fcfedef746877d53d47bc7c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 24 Feb 2026 15:26:32 +0800 Subject: [PATCH] feat(workflow): include loop information in loop node outputs --- .../workflow/engine/event_stream_handler.py | 8 +++ api/app/core/workflow/executor.py | 4 ++ .../workflow/nodes/cycle_graph/iteration.py | 64 ++++++++++++++--- .../core/workflow/nodes/cycle_graph/loop.py | 68 +++++++++++++++++-- .../core/workflow/nodes/cycle_graph/node.py | 35 ++++++++++ api/app/schemas/app_schema.py | 2 +- api/app/services/workflow_service.py | 2 +- 7 files changed, 167 insertions(+), 16 deletions(-) diff --git a/api/app/core/workflow/engine/event_stream_handler.py b/api/app/core/workflow/engine/event_stream_handler.py index e49a2e8b..5b7d8de2 100644 --- a/api/app/core/workflow/engine/event_stream_handler.py +++ b/api/app/core/workflow/engine/event_stream_handler.py @@ -271,3 +271,11 @@ class EventStreamHandler: } } + @staticmethod + async def handle_cycle_item_event(data: dict): + yield { + "event": "cycle_item", + "data": data.get("data") + } + + diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index ff48fb07..2b554a60 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -279,6 +279,10 @@ class WorkflowExecutor: async for error_event in self.event_handler.handle_node_error_event(data): yield error_event + elif event_type == "cycle_item": + async for cycle_event in self.event_handler.handle_cycle_item_event(data): + yield cycle_event + elif mode == "debug": async for debug_event in self.event_handler.handle_debug_event(data, input_data): yield debug_event diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index f1138840..e4026f2d 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -1,13 +1,17 @@ import asyncio import logging import re +import uuid from typing import Any +from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph +from langgraph.config import get_stream_writer from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.cycle_graph import IterationNodeConfig +from app.core.workflow.nodes.enums import NodeType from app.core.workflow.variable.base_variable import VariableType logger = logging.getLogger(__name__) @@ -25,6 +29,7 @@ class IterationRuntime: def __init__( self, start_id: str, + stream: bool, graph: CompiledStateGraph, node_id: str, config: dict[str, Any], @@ -42,6 +47,7 @@ class IterationRuntime: state: Current workflow state at the point of iteration. """ self.start_id = start_id + self.stream = stream self.graph = graph self.state = state self.node_id = node_id @@ -49,6 +55,12 @@ class IterationRuntime: self.looping = True self.variable_pool = variable_pool self.child_variable_pool = child_variable_pool + self.event_write = get_stream_writer() + self.checkpoint = RunnableConfig( + configurable={ + "thread_id": uuid.uuid4() + } + ) self.output_value = None self.result: list = [] @@ -91,7 +103,46 @@ class IterationRuntime: item: The input element for this iteration. idx: The index of this iteration. """ - result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) + if self.stream: + async for event in self.graph.astream( + await self._init_iteration_state(item, idx), + stream_mode=["debug"], + config=self.checkpoint + ): + if isinstance(event, tuple) and len(event) == 2: + mode, data = event + else: + continue + if mode == "debug": + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + if node_name and node_name.startswith("nop"): + continue + if event_type == "task_result": + result = payload.get("result", {}) + if not result.get("activate", {}).get(node_name): + continue + node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type") + cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None + self.event_write({ + "type": "cycle_item", + "data": { + "cycle_id": self.node_id, + "cycle_idx": idx, + "node_id": node_name, + "input": result.get("node_outputs", {}).get(node_name, {}).get("input") + if not cycle_variable else cycle_variable, + "output": result.get("node_outputs", {}).get(node_name, {}).get("output") + if not cycle_variable else cycle_variable, + "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), + "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") + } + }) + result = self.graph.get_state(config=self.checkpoint).values + else: + result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) output = self.child_variable_pool.get_value(self.output_value) if isinstance(output, list) and self.typed_config.flatten: self.result.extend(output) @@ -152,16 +203,9 @@ class IterationRuntime: while idx < len(array_obj) and self.looping: logger.info(f"Iteration node {self.node_id}: running") item = array_obj[idx] - result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) - child_state.append(result) - output = self.child_variable_pool.get_value(self.output_value) + result = await self.run_task(item, idx) self.merge_conv_vars() - if isinstance(output, list) and self.typed_config.flatten: - self.result.extend(output) - else: - self.result.append(output) - if result["looping"] == 2: - self.looping = False + child_state.append(result) idx += 1 logger.info(f"Iteration node {self.node_id}: execution completed") return { diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index b4406f74..cebadfdc 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -1,12 +1,15 @@ import logging +import uuid from typing import Any +from langchain_core.runnables import RunnableConfig +from langgraph.config import get_stream_writer from langgraph.graph.state import CompiledStateGraph from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.cycle_graph import LoopNodeConfig -from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.utils.expression_evaluator import evaluate_expression @@ -27,6 +30,7 @@ class LoopRuntime: def __init__( self, start_id: str, + stream: bool, graph: CompiledStateGraph, node_id: str, config: dict[str, Any], @@ -46,6 +50,7 @@ class LoopRuntime: child_variable_pool: A VariablePool instance for managing child node outputs. """ self.start_id = start_id + self.stream = stream self.graph = graph self.state = state self.node_id = node_id @@ -53,6 +58,13 @@ class LoopRuntime: self.looping = True self.variable_pool = variable_pool self.child_variable_pool = child_variable_pool + self.event_write = get_stream_writer() + + self.checkpoint = RunnableConfig( + configurable={ + "thread_id": uuid.uuid4() + } + ) async def _init_loop_state(self): """ @@ -142,10 +154,12 @@ class LoopRuntime: case _: raise ValueError(f"Invalid condition: {operator}") - def merge_conv_vars(self): + def merge_conv_vars(self, loopstate): self.variable_pool.variables["conv"].update( self.child_variable_pool.variables.get("conv", {}) ) + loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) + loopstate["node_outputs"][self.node_id] = loop_vars def evaluate_conditional(self) -> bool: """ @@ -175,6 +189,50 @@ class LoopRuntime: else: return any(conditions) + async def _run(self, loopstate, idx): + if self.stream: + async for event in self.graph.astream( + loopstate, + stream_mode=["debug"], + config=self.checkpoint + ): + if isinstance(event, tuple) and len(event) == 2: + mode, data = event + else: + continue + if mode == "debug": + event_type = data.get("type") + payload = data.get("payload", {}) + node_name = payload.get("name") + + if node_name and node_name.startswith("nop"): + continue + if event_type == "task_result": + result = payload.get("result", {}) + node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type") + if not result.get("activate", {}).get(node_name): + continue + cycle_variable = None + if node_type == NodeType.CYCLE_START: + cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {}) + self.event_write({ + "type": "cycle_item", + "data": { + "cycle_id": self.node_id, + "cycle_idx": idx, + "node_id": node_name, + "input": result.get("node_outputs", {}).get(node_name, {}).get("input") + if not cycle_variable else cycle_variable, + "output": result.get("node_outputs", {}).get(node_name, {}).get("output") + if not cycle_variable else cycle_variable, + "elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"), + "token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage") + } + }) + return self.graph.get_state(config=self.checkpoint).values + else: + return await self.graph.ainvoke(loopstate) + async def run(self): """ Execute the loop node until termination conditions are met. @@ -190,15 +248,17 @@ class LoopRuntime: loopstate = await self._init_loop_state() loop_time = self.typed_config.max_loop child_state = [] + idx = 0 while not self.evaluate_conditional() and self.looping and loop_time > 0: logger.info(f"loop node {self.node_id}: running") - result = await self.graph.ainvoke(loopstate) + result = await self._run(loopstate, idx) child_state.append(result) - self.merge_conv_vars() + self.merge_conv_vars(loopstate) if result["looping"] == 2: self.looping = False loop_time -= 1 + idx += 1 logger.info(f"loop node {self.node_id}: execution completed") return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 72768b77..f2912e2c 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode): if self.node_type == NodeType.LOOP: return await LoopRuntime( start_id=self.start_node_id, + stream=False, graph=self.graph, node_id=self.node_id, config=self.config, @@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode): if self.node_type == NodeType.ITERATION: return await IterationRuntime( start_id=self.start_node_id, + stream=False, graph=self.graph, node_id=self.node_id, config=self.config, @@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode): child_variable_pool=self.child_variable_pool ).run() raise RuntimeError("Unknown cycle node type") + + async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + if self.node_type == NodeType.LOOP: + yield { + "__final__": True, + "result": await LoopRuntime( + start_id=self.start_node_id, + stream=True, + graph=self.graph, + node_id=self.node_id, + config=self.config, + state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool, + ).run() + } + return + if self.node_type == NodeType.ITERATION: + yield { + "__final__": True, + "result": await IterationRuntime( + start_id=self.start_node_id, + stream=True, + graph=self.graph, + node_id=self.node_id, + config=self.config, + state=state, + variable_pool=variable_pool, + child_variable_pool=self.child_variable_pool + ).run() + } + return + raise RuntimeError("Unknown cycle node type") diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 792a32ac..8cf81b92 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel): user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") - files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)") + files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)") class DraftRunResponse(BaseModel): diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index fb88f804..d06a05d7 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -588,7 +588,7 @@ class WorkflowService: "message_length": len(payload.get("output", "")) } } - case "node_start" | "node_end" | "node_error": + case "node_start" | "node_end" | "node_error" | "cycle_item": return None case _: return event