Merge pull request #399 from SuanmoSuanyangTechnology/feature/workflow-cycle-state

feat(workflow): include loop information in loop node outputs
This commit is contained in:
Mark
2026-02-24 18:02:11 +08:00
committed by GitHub
7 changed files with 167 additions and 16 deletions

View File

@@ -271,3 +271,11 @@ class EventStreamHandler:
}
}
@staticmethod
async def handle_cycle_item_event(data: dict):
yield {
"event": "cycle_item",
"data": data.get("data")
}

View File

@@ -279,6 +279,10 @@ class WorkflowExecutor:
async for error_event in self.event_handler.handle_node_error_event(data):
yield error_event
elif event_type == "cycle_item":
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
yield cycle_event
elif mode == "debug":
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
yield debug_event

View File

@@ -1,13 +1,17 @@
import asyncio
import logging
import re
import uuid
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph
from langgraph.config import get_stream_writer
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__)
@@ -25,6 +29,7 @@ class IterationRuntime:
def __init__(
self,
start_id: str,
stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
@@ -42,6 +47,7 @@ class IterationRuntime:
state: Current workflow state at the point of iteration.
"""
self.start_id = start_id
self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
@@ -49,6 +55,12 @@ class IterationRuntime:
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
self.output_value = None
self.result: list = []
@@ -91,7 +103,46 @@ class IterationRuntime:
item: The input element for this iteration.
idx: The index of this iteration.
"""
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
if self.stream:
async for event in self.graph.astream(
await self._init_iteration_state(item, idx),
stream_mode=["debug"],
config=self.checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
continue
if mode == "debug":
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
result = payload.get("result", {})
if not result.get("activate", {}).get(node_name):
continue
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
self.event_write({
"type": "cycle_item",
"data": {
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
if not cycle_variable else cycle_variable,
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
result = self.graph.get_state(config=self.checkpoint).values
else:
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
@@ -152,16 +203,9 @@ class IterationRuntime:
while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx]
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
child_state.append(result)
output = self.child_variable_pool.get_value(self.output_value)
result = await self.run_task(item, idx)
self.merge_conv_vars()
if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output)
else:
self.result.append(output)
if result["looping"] == 2:
self.looping = False
child_state.append(result)
idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed")
return {

View File

@@ -1,12 +1,15 @@
import logging
import uuid
from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
@@ -27,6 +30,7 @@ class LoopRuntime:
def __init__(
self,
start_id: str,
stream: bool,
graph: CompiledStateGraph,
node_id: str,
config: dict[str, Any],
@@ -46,6 +50,7 @@ class LoopRuntime:
child_variable_pool: A VariablePool instance for managing child node outputs.
"""
self.start_id = start_id
self.stream = stream
self.graph = graph
self.state = state
self.node_id = node_id
@@ -53,6 +58,13 @@ class LoopRuntime:
self.looping = True
self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
async def _init_loop_state(self):
"""
@@ -142,10 +154,12 @@ class LoopRuntime:
case _:
raise ValueError(f"Invalid condition: {operator}")
def merge_conv_vars(self):
def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {})
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool:
"""
@@ -175,6 +189,50 @@ class LoopRuntime:
else:
return any(conditions)
async def _run(self, loopstate, idx):
if self.stream:
async for event in self.graph.astream(
loopstate,
stream_mode=["debug"],
config=self.checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
continue
if mode == "debug":
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
result = payload.get("result", {})
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
if not result.get("activate", {}).get(node_name):
continue
cycle_variable = None
if node_type == NodeType.CYCLE_START:
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
self.event_write({
"type": "cycle_item",
"data": {
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
if not cycle_variable else cycle_variable,
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
return self.graph.get_state(config=self.checkpoint).values
else:
return await self.graph.ainvoke(loopstate)
async def run(self):
"""
Execute the loop node until termination conditions are met.
@@ -190,15 +248,17 @@ class LoopRuntime:
loopstate = await self._init_loop_state()
loop_time = self.typed_config.max_loop
child_state = []
idx = 0
while not self.evaluate_conditional() and self.looping and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
result = await self.graph.ainvoke(loopstate)
result = await self._run(loopstate, idx)
child_state.append(result)
self.merge_conv_vars()
self.merge_conv_vars(loopstate)
if result["looping"] == 2:
self.looping = False
loop_time -= 1
idx += 1
logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}

View File

@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
start_id=self.start_node_id,
stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.ITERATION:
return await IterationRuntime(
start_id=self.start_node_id,
stream=False,
graph=self.graph,
node_id=self.node_id,
config=self.config,
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
child_variable_pool=self.child_variable_pool
).run()
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
if self.node_type == NodeType.LOOP:
yield {
"__final__": True,
"result": await LoopRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool,
).run()
}
return
if self.node_type == NodeType.ITERATION:
yield {
"__final__": True,
"result": await IterationRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
).run()
}
return
raise RuntimeError("Unknown cycle node type")

View File

@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回")
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
class DraftRunResponse(BaseModel):

View File

@@ -588,7 +588,7 @@ class WorkflowService:
"message_length": len(payload.get("output", ""))
}
}
case "node_start" | "node_end" | "node_error":
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None
case _:
return event