feat(workflow): include loop information in loop node outputs

This commit is contained in:
Eternity
2026-02-24 15:26:32 +08:00
parent b272a52b57
commit 44083aec79
7 changed files with 167 additions and 16 deletions

View File

@@ -271,3 +271,11 @@ class EventStreamHandler:
} }
} }
@staticmethod
async def handle_cycle_item_event(data: dict):
yield {
"event": "cycle_item",
"data": data.get("data")
}

View File

@@ -279,6 +279,10 @@ class WorkflowExecutor:
async for error_event in self.event_handler.handle_node_error_event(data): async for error_event in self.event_handler.handle_node_error_event(data):
yield error_event yield error_event
elif event_type == "cycle_item":
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
yield cycle_event
elif mode == "debug": elif mode == "debug":
async for debug_event in self.event_handler.handle_debug_event(data, input_data): async for debug_event in self.event_handler.handle_debug_event(data, input_data):
yield debug_event yield debug_event

View File

@@ -1,13 +1,17 @@
import asyncio import asyncio
import logging import logging
import re import re
import uuid
from typing import Any from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from langgraph.config import get_stream_writer
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -25,6 +29,7 @@ class IterationRuntime:
def __init__( def __init__(
self, self,
start_id: str, start_id: str,
stream: bool,
graph: CompiledStateGraph, graph: CompiledStateGraph,
node_id: str, node_id: str,
config: dict[str, Any], config: dict[str, Any],
@@ -42,6 +47,7 @@ class IterationRuntime:
state: Current workflow state at the point of iteration. state: Current workflow state at the point of iteration.
""" """
self.start_id = start_id self.start_id = start_id
self.stream = stream
self.graph = graph self.graph = graph
self.state = state self.state = state
self.node_id = node_id self.node_id = node_id
@@ -49,6 +55,12 @@ class IterationRuntime:
self.looping = True self.looping = True
self.variable_pool = variable_pool self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
self.output_value = None self.output_value = None
self.result: list = [] self.result: list = []
@@ -91,7 +103,46 @@ class IterationRuntime:
item: The input element for this iteration. item: The input element for this iteration.
idx: The index of this iteration. idx: The index of this iteration.
""" """
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) if self.stream:
async for event in self.graph.astream(
await self._init_iteration_state(item, idx),
stream_mode=["debug"],
config=self.checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
continue
if mode == "debug":
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
result = payload.get("result", {})
if not result.get("activate", {}).get(node_name):
continue
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
self.event_write({
"type": "cycle_item",
"data": {
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
if not cycle_variable else cycle_variable,
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
result = self.graph.get_state(config=self.checkpoint).values
else:
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
output = self.child_variable_pool.get_value(self.output_value) output = self.child_variable_pool.get_value(self.output_value)
if isinstance(output, list) and self.typed_config.flatten: if isinstance(output, list) and self.typed_config.flatten:
self.result.extend(output) self.result.extend(output)
@@ -152,16 +203,9 @@ class IterationRuntime:
while idx < len(array_obj) and self.looping: while idx < len(array_obj) and self.looping:
logger.info(f"Iteration node {self.node_id}: running") logger.info(f"Iteration node {self.node_id}: running")
item = array_obj[idx] item = array_obj[idx]
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx)) result = await self.run_task(item, idx)
child_state.append(result)
output = self.child_variable_pool.get_value(self.output_value)
self.merge_conv_vars() self.merge_conv_vars()
if isinstance(output, list) and self.typed_config.flatten: child_state.append(result)
self.result.extend(output)
else:
self.result.append(output)
if result["looping"] == 2:
self.looping = False
idx += 1 idx += 1
logger.info(f"Iteration node {self.node_id}: execution completed") logger.info(f"Iteration node {self.node_id}: execution completed")
return { return {

View File

@@ -1,12 +1,15 @@
import logging import logging
import uuid
from typing import Any from typing import Any
from langchain_core.runnables import RunnableConfig
from langgraph.config import get_stream_writer
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression from app.core.workflow.utils.expression_evaluator import evaluate_expression
@@ -27,6 +30,7 @@ class LoopRuntime:
def __init__( def __init__(
self, self,
start_id: str, start_id: str,
stream: bool,
graph: CompiledStateGraph, graph: CompiledStateGraph,
node_id: str, node_id: str,
config: dict[str, Any], config: dict[str, Any],
@@ -46,6 +50,7 @@ class LoopRuntime:
child_variable_pool: A VariablePool instance for managing child node outputs. child_variable_pool: A VariablePool instance for managing child node outputs.
""" """
self.start_id = start_id self.start_id = start_id
self.stream = stream
self.graph = graph self.graph = graph
self.state = state self.state = state
self.node_id = node_id self.node_id = node_id
@@ -53,6 +58,13 @@ class LoopRuntime:
self.looping = True self.looping = True
self.variable_pool = variable_pool self.variable_pool = variable_pool
self.child_variable_pool = child_variable_pool self.child_variable_pool = child_variable_pool
self.event_write = get_stream_writer()
self.checkpoint = RunnableConfig(
configurable={
"thread_id": uuid.uuid4()
}
)
async def _init_loop_state(self): async def _init_loop_state(self):
""" """
@@ -142,10 +154,12 @@ class LoopRuntime:
case _: case _:
raise ValueError(f"Invalid condition: {operator}") raise ValueError(f"Invalid condition: {operator}")
def merge_conv_vars(self): def merge_conv_vars(self, loopstate):
self.variable_pool.variables["conv"].update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables.get("conv", {}) self.child_variable_pool.variables.get("conv", {})
) )
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool: def evaluate_conditional(self) -> bool:
""" """
@@ -175,6 +189,50 @@ class LoopRuntime:
else: else:
return any(conditions) return any(conditions)
async def _run(self, loopstate, idx):
if self.stream:
async for event in self.graph.astream(
loopstate,
stream_mode=["debug"],
config=self.checkpoint
):
if isinstance(event, tuple) and len(event) == 2:
mode, data = event
else:
continue
if mode == "debug":
event_type = data.get("type")
payload = data.get("payload", {})
node_name = payload.get("name")
if node_name and node_name.startswith("nop"):
continue
if event_type == "task_result":
result = payload.get("result", {})
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
if not result.get("activate", {}).get(node_name):
continue
cycle_variable = None
if node_type == NodeType.CYCLE_START:
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
self.event_write({
"type": "cycle_item",
"data": {
"cycle_id": self.node_id,
"cycle_idx": idx,
"node_id": node_name,
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
if not cycle_variable else cycle_variable,
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
if not cycle_variable else cycle_variable,
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
}
})
return self.graph.get_state(config=self.checkpoint).values
else:
return await self.graph.ainvoke(loopstate)
async def run(self): async def run(self):
""" """
Execute the loop node until termination conditions are met. Execute the loop node until termination conditions are met.
@@ -190,15 +248,17 @@ class LoopRuntime:
loopstate = await self._init_loop_state() loopstate = await self._init_loop_state()
loop_time = self.typed_config.max_loop loop_time = self.typed_config.max_loop
child_state = [] child_state = []
idx = 0
while not self.evaluate_conditional() and self.looping and loop_time > 0: while not self.evaluate_conditional() and self.looping and loop_time > 0:
logger.info(f"loop node {self.node_id}: running") logger.info(f"loop node {self.node_id}: running")
result = await self.graph.ainvoke(loopstate) result = await self._run(loopstate, idx)
child_state.append(result) child_state.append(result)
self.merge_conv_vars() self.merge_conv_vars(loopstate)
if result["looping"] == 2: if result["looping"] == 2:
self.looping = False self.looping = False
loop_time -= 1 loop_time -= 1
idx += 1
logger.info(f"loop node {self.node_id}: execution completed") logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}

View File

@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
stream=False,
graph=self.graph, graph=self.graph,
node_id=self.node_id, node_id=self.node_id,
config=self.config, config=self.config,
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
if self.node_type == NodeType.ITERATION: if self.node_type == NodeType.ITERATION:
return await IterationRuntime( return await IterationRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
stream=False,
graph=self.graph, graph=self.graph,
node_id=self.node_id, node_id=self.node_id,
config=self.config, config=self.config,
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
child_variable_pool=self.child_variable_pool child_variable_pool=self.child_variable_pool
).run() ).run()
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
if self.node_type == NodeType.LOOP:
yield {
"__final__": True,
"result": await LoopRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool,
).run()
}
return
if self.node_type == NodeType.ITERATION:
yield {
"__final__": True,
"result": await IterationRuntime(
start_id=self.start_node_id,
stream=True,
graph=self.graph,
node_id=self.node_id,
config=self.config,
state=state,
variable_pool=variable_pool,
child_variable_pool=self.child_variable_pool
).run()
}
return
raise RuntimeError("Unknown cycle node type")

View File

@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理") user_id: Optional[str] = Field(default=None, description="用户ID用于会话管理")
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
stream: bool = Field(default=False, description="是否流式返回") stream: bool = Field(default=False, description="是否流式返回")
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)") files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
class DraftRunResponse(BaseModel): class DraftRunResponse(BaseModel):

View File

@@ -588,7 +588,7 @@ class WorkflowService:
"message_length": len(payload.get("output", "")) "message_length": len(payload.get("output", ""))
} }
} }
case "node_start" | "node_end" | "node_error": case "node_start" | "node_end" | "node_error" | "cycle_item":
return None return None
case _: case _:
return event return event