Merge pull request #399 from SuanmoSuanyangTechnology/feature/workflow-cycle-state
feat(workflow): include loop information in loop node outputs
This commit is contained in:
@@ -271,3 +271,11 @@ class EventStreamHandler:
|
||||
}
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
async def handle_cycle_item_event(data: dict):
|
||||
yield {
|
||||
"event": "cycle_item",
|
||||
"data": data.get("data")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -279,6 +279,10 @@ class WorkflowExecutor:
|
||||
async for error_event in self.event_handler.handle_node_error_event(data):
|
||||
yield error_event
|
||||
|
||||
elif event_type == "cycle_item":
|
||||
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
|
||||
yield cycle_event
|
||||
|
||||
elif mode == "debug":
|
||||
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
||||
yield debug_event
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
from langgraph.config import get_stream_writer
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,6 +29,7 @@ class IterationRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -42,6 +47,7 @@ class IterationRuntime:
|
||||
state: Current workflow state at the point of iteration.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -49,6 +55,12 @@ class IterationRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
self.output_value = None
|
||||
self.result: list = []
|
||||
@@ -91,7 +103,46 @@ class IterationRuntime:
|
||||
item: The input element for this iteration.
|
||||
idx: The index of this iteration.
|
||||
"""
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
await self._init_iteration_state(item, idx),
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
result = self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
@@ -152,16 +203,9 @@ class IterationRuntime:
|
||||
while idx < len(array_obj) and self.looping:
|
||||
logger.info(f"Iteration node {self.node_id}: running")
|
||||
item = array_obj[idx]
|
||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||
child_state.append(result)
|
||||
output = self.child_variable_pool.get_value(self.output_value)
|
||||
result = await self.run_task(item, idx)
|
||||
self.merge_conv_vars()
|
||||
if isinstance(output, list) and self.typed_config.flatten:
|
||||
self.result.extend(output)
|
||||
else:
|
||||
self.result.append(output)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
child_state.append(result)
|
||||
idx += 1
|
||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||
return {
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.config import get_stream_writer
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
@@ -27,6 +30,7 @@ class LoopRuntime:
|
||||
def __init__(
|
||||
self,
|
||||
start_id: str,
|
||||
stream: bool,
|
||||
graph: CompiledStateGraph,
|
||||
node_id: str,
|
||||
config: dict[str, Any],
|
||||
@@ -46,6 +50,7 @@ class LoopRuntime:
|
||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||
"""
|
||||
self.start_id = start_id
|
||||
self.stream = stream
|
||||
self.graph = graph
|
||||
self.state = state
|
||||
self.node_id = node_id
|
||||
@@ -53,6 +58,13 @@ class LoopRuntime:
|
||||
self.looping = True
|
||||
self.variable_pool = variable_pool
|
||||
self.child_variable_pool = child_variable_pool
|
||||
self.event_write = get_stream_writer()
|
||||
|
||||
self.checkpoint = RunnableConfig(
|
||||
configurable={
|
||||
"thread_id": uuid.uuid4()
|
||||
}
|
||||
)
|
||||
|
||||
async def _init_loop_state(self):
|
||||
"""
|
||||
@@ -142,10 +154,12 @@ class LoopRuntime:
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {operator}")
|
||||
|
||||
def merge_conv_vars(self):
|
||||
def merge_conv_vars(self, loopstate):
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables.get("conv", {})
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
"""
|
||||
@@ -175,6 +189,50 @@ class LoopRuntime:
|
||||
else:
|
||||
return any(conditions)
|
||||
|
||||
async def _run(self, loopstate, idx):
|
||||
if self.stream:
|
||||
async for event in self.graph.astream(
|
||||
loopstate,
|
||||
stream_mode=["debug"],
|
||||
config=self.checkpoint
|
||||
):
|
||||
if isinstance(event, tuple) and len(event) == 2:
|
||||
mode, data = event
|
||||
else:
|
||||
continue
|
||||
if mode == "debug":
|
||||
event_type = data.get("type")
|
||||
payload = data.get("payload", {})
|
||||
node_name = payload.get("name")
|
||||
|
||||
if node_name and node_name.startswith("nop"):
|
||||
continue
|
||||
if event_type == "task_result":
|
||||
result = payload.get("result", {})
|
||||
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||
if not result.get("activate", {}).get(node_name):
|
||||
continue
|
||||
cycle_variable = None
|
||||
if node_type == NodeType.CYCLE_START:
|
||||
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
|
||||
self.event_write({
|
||||
"type": "cycle_item",
|
||||
"data": {
|
||||
"cycle_id": self.node_id,
|
||||
"cycle_idx": idx,
|
||||
"node_id": node_name,
|
||||
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||
if not cycle_variable else cycle_variable,
|
||||
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||
}
|
||||
})
|
||||
return self.graph.get_state(config=self.checkpoint).values
|
||||
else:
|
||||
return await self.graph.ainvoke(loopstate)
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
Execute the loop node until termination conditions are met.
|
||||
@@ -190,15 +248,17 @@ class LoopRuntime:
|
||||
loopstate = await self._init_loop_state()
|
||||
loop_time = self.typed_config.max_loop
|
||||
child_state = []
|
||||
idx = 0
|
||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||
logger.info(f"loop node {self.node_id}: running")
|
||||
result = await self.graph.ainvoke(loopstate)
|
||||
result = await self._run(loopstate, idx)
|
||||
child_state.append(result)
|
||||
|
||||
self.merge_conv_vars()
|
||||
self.merge_conv_vars(loopstate)
|
||||
if result["looping"] == 2:
|
||||
self.looping = False
|
||||
loop_time -= 1
|
||||
idx += 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
|
||||
@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
return await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=False,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool,
|
||||
).run()
|
||||
}
|
||||
return
|
||||
if self.node_type == NodeType.ITERATION:
|
||||
yield {
|
||||
"__final__": True,
|
||||
"result": await IterationRuntime(
|
||||
start_id=self.start_node_id,
|
||||
stream=True,
|
||||
graph=self.graph,
|
||||
node_id=self.node_id,
|
||||
config=self.config,
|
||||
state=state,
|
||||
variable_pool=variable_pool,
|
||||
child_variable_pool=self.child_variable_pool
|
||||
).run()
|
||||
}
|
||||
return
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
|
||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||
stream: bool = Field(default=False, description="是否流式返回")
|
||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
||||
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||
|
||||
|
||||
class DraftRunResponse(BaseModel):
|
||||
|
||||
@@ -588,7 +588,7 @@ class WorkflowService:
|
||||
"message_length": len(payload.get("output", ""))
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error":
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
Reference in New Issue
Block a user