fix(workflow): fix loop node termination and iteration node startup issues (#181)
This commit is contained in:
@@ -133,7 +133,7 @@ class WorkflowExecutor:
|
||||
for node in self.workflow_config.get("nodes")
|
||||
if node.get("type") in [NodeType.LOOP, NodeType.ITERATION]
|
||||
], # loop, iteration node id
|
||||
"looping": False, # loop runing flag, only use in loop node,not use in main loop
|
||||
"looping": 0, # loop runing flag, only use in loop node,not use in main loop
|
||||
"activate": {
|
||||
self.start_node_id: True
|
||||
}
|
||||
@@ -358,6 +358,7 @@ class WorkflowExecutor:
|
||||
|
||||
elif mode == "updates":
|
||||
# Handle state updates - store final state
|
||||
# TODO:流式输出点
|
||||
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
|
||||
f"- execution_id: {self.execution_id}")
|
||||
|
||||
|
||||
@@ -19,13 +19,17 @@ from app.core.workflow.variable_pool import VariablePool
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def merget_activate_state(x, y):
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
return y if y > x else x
|
||||
|
||||
|
||||
class WorkflowState(TypedDict):
|
||||
"""Workflow state
|
||||
|
||||
@@ -36,7 +40,7 @@ class WorkflowState(TypedDict):
|
||||
|
||||
# Set of loop node IDs, used for assigning values in loop nodes
|
||||
cycle_nodes: list
|
||||
looping: Annotated[bool, lambda x, y: x and y]
|
||||
looping: Annotated[int, merge_looping_state]
|
||||
|
||||
# Input variables (passed from configured variables)
|
||||
# Uses a deep merge function, supporting nested dict updates (e.g., conv.xxx)
|
||||
@@ -68,7 +72,7 @@ class WorkflowState(TypedDict):
|
||||
streaming_buffer: Annotated[dict[str, Any], lambda x, y: {**x, **y}]
|
||||
|
||||
# node activate status
|
||||
activate: Annotated[dict[str, bool], merget_activate_state]
|
||||
activate: Annotated[dict[str, bool], merge_activate_state]
|
||||
|
||||
|
||||
class BaseNode(ABC):
|
||||
|
||||
@@ -28,6 +28,6 @@ class BreakNode(BaseNode):
|
||||
Returns:
|
||||
Optional dictionary indicating the loop has been stopped.
|
||||
"""
|
||||
state["looping"] = False
|
||||
state["looping"] = 2
|
||||
logger.info(f"Setting cycle node exit flag, cycle={self.cycle}, looping={state['looping']}")
|
||||
|
||||
|
||||
@@ -58,10 +58,10 @@ class IterationRuntime:
|
||||
idx: Index of the element in the input array.
|
||||
|
||||
Returns:
|
||||
A deep copy of the workflow state with iteration-specific variables set.
|
||||
A copy of the workflow state with iteration-specific variables set.
|
||||
"""
|
||||
loopstate = WorkflowState(
|
||||
**copy.deepcopy(self.state)
|
||||
**self.state
|
||||
)
|
||||
loopstate["runtime_vars"][self.node_id] = {
|
||||
"item": item,
|
||||
@@ -71,7 +71,7 @@ class IterationRuntime:
|
||||
"item": item,
|
||||
"index": idx,
|
||||
}
|
||||
loopstate["looping"] = True
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
@@ -89,7 +89,7 @@ class IterationRuntime:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if not result["looping"]:
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
return result
|
||||
|
||||
@@ -150,10 +150,9 @@ class IterationRuntime:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if not result["looping"]:
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
idx += 1
|
||||
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
"output": self.result,
|
||||
|
||||
@@ -46,6 +46,7 @@ class LoopRuntime:
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
self.typed_config = LoopNodeConfig(**config)
|
||||
self.looping = True
|
||||
|
||||
def _init_loop_state(self):
|
||||
"""
|
||||
@@ -88,7 +89,7 @@ class LoopRuntime:
|
||||
loopstate = WorkflowState(
|
||||
**self.state
|
||||
)
|
||||
loopstate["looping"] = True
|
||||
loopstate["looping"] = 1
|
||||
loopstate["activate"][self.start_id] = True
|
||||
return loopstate
|
||||
|
||||
@@ -179,9 +180,12 @@ class LoopRuntime:
|
||||
loopstate = self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
|
||||
while self.evaluate_conditional(loopstate) and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
child_state.append(await self.graph.ainvoke(loopstate))
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
child_state.append(result)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
|
||||
Reference in New Issue
Block a user