feat(workflow): include loop information in loop node outputs
This commit is contained in:
@@ -271,3 +271,11 @@ class EventStreamHandler:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_cycle_item_event(data: dict):
|
||||||
|
yield {
|
||||||
|
"event": "cycle_item",
|
||||||
|
"data": data.get("data")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -279,6 +279,10 @@ class WorkflowExecutor:
|
|||||||
async for error_event in self.event_handler.handle_node_error_event(data):
|
async for error_event in self.event_handler.handle_node_error_event(data):
|
||||||
yield error_event
|
yield error_event
|
||||||
|
|
||||||
|
elif event_type == "cycle_item":
|
||||||
|
async for cycle_event in self.event_handler.handle_cycle_item_event(data):
|
||||||
|
yield cycle_event
|
||||||
|
|
||||||
elif mode == "debug":
|
elif mode == "debug":
|
||||||
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
async for debug_event in self.event_handler.handle_debug_event(data, input_data):
|
||||||
yield debug_event
|
yield debug_event
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph import IterationNodeConfig
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -25,6 +29,7 @@ class IterationRuntime:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
start_id: str,
|
start_id: str,
|
||||||
|
stream: bool,
|
||||||
graph: CompiledStateGraph,
|
graph: CompiledStateGraph,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
@@ -42,6 +47,7 @@ class IterationRuntime:
|
|||||||
state: Current workflow state at the point of iteration.
|
state: Current workflow state at the point of iteration.
|
||||||
"""
|
"""
|
||||||
self.start_id = start_id
|
self.start_id = start_id
|
||||||
|
self.stream = stream
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.state = state
|
self.state = state
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
@@ -49,6 +55,12 @@ class IterationRuntime:
|
|||||||
self.looping = True
|
self.looping = True
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
self.child_variable_pool = child_variable_pool
|
self.child_variable_pool = child_variable_pool
|
||||||
|
self.event_write = get_stream_writer()
|
||||||
|
self.checkpoint = RunnableConfig(
|
||||||
|
configurable={
|
||||||
|
"thread_id": uuid.uuid4()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
self.output_value = None
|
self.output_value = None
|
||||||
self.result: list = []
|
self.result: list = []
|
||||||
@@ -91,7 +103,46 @@ class IterationRuntime:
|
|||||||
item: The input element for this iteration.
|
item: The input element for this iteration.
|
||||||
idx: The index of this iteration.
|
idx: The index of this iteration.
|
||||||
"""
|
"""
|
||||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
if self.stream:
|
||||||
|
async for event in self.graph.astream(
|
||||||
|
await self._init_iteration_state(item, idx),
|
||||||
|
stream_mode=["debug"],
|
||||||
|
config=self.checkpoint
|
||||||
|
):
|
||||||
|
if isinstance(event, tuple) and len(event) == 2:
|
||||||
|
mode, data = event
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if mode == "debug":
|
||||||
|
event_type = data.get("type")
|
||||||
|
payload = data.get("payload", {})
|
||||||
|
node_name = payload.get("name")
|
||||||
|
|
||||||
|
if node_name and node_name.startswith("nop"):
|
||||||
|
continue
|
||||||
|
if event_type == "task_result":
|
||||||
|
result = payload.get("result", {})
|
||||||
|
if not result.get("activate", {}).get(node_name):
|
||||||
|
continue
|
||||||
|
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||||
|
cycle_variable = {"item": item} if node_type == NodeType.CYCLE_START else None
|
||||||
|
self.event_write({
|
||||||
|
"type": "cycle_item",
|
||||||
|
"data": {
|
||||||
|
"cycle_id": self.node_id,
|
||||||
|
"cycle_idx": idx,
|
||||||
|
"node_id": node_name,
|
||||||
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
|
if not cycle_variable else cycle_variable,
|
||||||
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
if not cycle_variable else cycle_variable,
|
||||||
|
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||||
|
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
result = self.graph.get_state(config=self.checkpoint).values
|
||||||
|
else:
|
||||||
|
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
||||||
output = self.child_variable_pool.get_value(self.output_value)
|
output = self.child_variable_pool.get_value(self.output_value)
|
||||||
if isinstance(output, list) and self.typed_config.flatten:
|
if isinstance(output, list) and self.typed_config.flatten:
|
||||||
self.result.extend(output)
|
self.result.extend(output)
|
||||||
@@ -152,16 +203,9 @@ class IterationRuntime:
|
|||||||
while idx < len(array_obj) and self.looping:
|
while idx < len(array_obj) and self.looping:
|
||||||
logger.info(f"Iteration node {self.node_id}: running")
|
logger.info(f"Iteration node {self.node_id}: running")
|
||||||
item = array_obj[idx]
|
item = array_obj[idx]
|
||||||
result = await self.graph.ainvoke(await self._init_iteration_state(item, idx))
|
result = await self.run_task(item, idx)
|
||||||
child_state.append(result)
|
|
||||||
output = self.child_variable_pool.get_value(self.output_value)
|
|
||||||
self.merge_conv_vars()
|
self.merge_conv_vars()
|
||||||
if isinstance(output, list) and self.typed_config.flatten:
|
child_state.append(result)
|
||||||
self.result.extend(output)
|
|
||||||
else:
|
|
||||||
self.result.append(output)
|
|
||||||
if result["looping"] == 2:
|
|
||||||
self.looping = False
|
|
||||||
idx += 1
|
idx += 1
|
||||||
logger.info(f"Iteration node {self.node_id}: execution completed")
|
logger.info(f"Iteration node {self.node_id}: execution completed")
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.runnables import RunnableConfig
|
||||||
|
from langgraph.config import get_stream_writer
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||||
|
|
||||||
@@ -27,6 +30,7 @@ class LoopRuntime:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
start_id: str,
|
start_id: str,
|
||||||
|
stream: bool,
|
||||||
graph: CompiledStateGraph,
|
graph: CompiledStateGraph,
|
||||||
node_id: str,
|
node_id: str,
|
||||||
config: dict[str, Any],
|
config: dict[str, Any],
|
||||||
@@ -46,6 +50,7 @@ class LoopRuntime:
|
|||||||
child_variable_pool: A VariablePool instance for managing child node outputs.
|
child_variable_pool: A VariablePool instance for managing child node outputs.
|
||||||
"""
|
"""
|
||||||
self.start_id = start_id
|
self.start_id = start_id
|
||||||
|
self.stream = stream
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.state = state
|
self.state = state
|
||||||
self.node_id = node_id
|
self.node_id = node_id
|
||||||
@@ -53,6 +58,13 @@ class LoopRuntime:
|
|||||||
self.looping = True
|
self.looping = True
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
self.child_variable_pool = child_variable_pool
|
self.child_variable_pool = child_variable_pool
|
||||||
|
self.event_write = get_stream_writer()
|
||||||
|
|
||||||
|
self.checkpoint = RunnableConfig(
|
||||||
|
configurable={
|
||||||
|
"thread_id": uuid.uuid4()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
async def _init_loop_state(self):
|
async def _init_loop_state(self):
|
||||||
"""
|
"""
|
||||||
@@ -142,10 +154,12 @@ class LoopRuntime:
|
|||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid condition: {operator}")
|
raise ValueError(f"Invalid condition: {operator}")
|
||||||
|
|
||||||
def merge_conv_vars(self):
|
def merge_conv_vars(self, loopstate):
|
||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.variables.get("conv", {})
|
self.child_variable_pool.variables.get("conv", {})
|
||||||
)
|
)
|
||||||
|
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||||
|
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||||
|
|
||||||
def evaluate_conditional(self) -> bool:
|
def evaluate_conditional(self) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -175,6 +189,50 @@ class LoopRuntime:
|
|||||||
else:
|
else:
|
||||||
return any(conditions)
|
return any(conditions)
|
||||||
|
|
||||||
|
async def _run(self, loopstate, idx):
|
||||||
|
if self.stream:
|
||||||
|
async for event in self.graph.astream(
|
||||||
|
loopstate,
|
||||||
|
stream_mode=["debug"],
|
||||||
|
config=self.checkpoint
|
||||||
|
):
|
||||||
|
if isinstance(event, tuple) and len(event) == 2:
|
||||||
|
mode, data = event
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if mode == "debug":
|
||||||
|
event_type = data.get("type")
|
||||||
|
payload = data.get("payload", {})
|
||||||
|
node_name = payload.get("name")
|
||||||
|
|
||||||
|
if node_name and node_name.startswith("nop"):
|
||||||
|
continue
|
||||||
|
if event_type == "task_result":
|
||||||
|
result = payload.get("result", {})
|
||||||
|
node_type = result.get("node_outputs", {}).get(node_name, {}).get("node_type")
|
||||||
|
if not result.get("activate", {}).get(node_name):
|
||||||
|
continue
|
||||||
|
cycle_variable = None
|
||||||
|
if node_type == NodeType.CYCLE_START:
|
||||||
|
cycle_variable = loopstate.get("node_outputs", {}).get(self.node_id, {})
|
||||||
|
self.event_write({
|
||||||
|
"type": "cycle_item",
|
||||||
|
"data": {
|
||||||
|
"cycle_id": self.node_id,
|
||||||
|
"cycle_idx": idx,
|
||||||
|
"node_id": node_name,
|
||||||
|
"input": result.get("node_outputs", {}).get(node_name, {}).get("input")
|
||||||
|
if not cycle_variable else cycle_variable,
|
||||||
|
"output": result.get("node_outputs", {}).get(node_name, {}).get("output")
|
||||||
|
if not cycle_variable else cycle_variable,
|
||||||
|
"elapsed_time": result.get("node_outputs", {}).get(node_name, {}).get("elapsed_time"),
|
||||||
|
"token_usage": result.get("node_outputs", {}).get(node_name, {}).get("token_usage")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return self.graph.get_state(config=self.checkpoint).values
|
||||||
|
else:
|
||||||
|
return await self.graph.ainvoke(loopstate)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Execute the loop node until termination conditions are met.
|
Execute the loop node until termination conditions are met.
|
||||||
@@ -190,15 +248,17 @@ class LoopRuntime:
|
|||||||
loopstate = await self._init_loop_state()
|
loopstate = await self._init_loop_state()
|
||||||
loop_time = self.typed_config.max_loop
|
loop_time = self.typed_config.max_loop
|
||||||
child_state = []
|
child_state = []
|
||||||
|
idx = 0
|
||||||
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
while not self.evaluate_conditional() and self.looping and loop_time > 0:
|
||||||
logger.info(f"loop node {self.node_id}: running")
|
logger.info(f"loop node {self.node_id}: running")
|
||||||
result = await self.graph.ainvoke(loopstate)
|
result = await self._run(loopstate, idx)
|
||||||
child_state.append(result)
|
child_state.append(result)
|
||||||
|
|
||||||
self.merge_conv_vars()
|
self.merge_conv_vars(loopstate)
|
||||||
if result["looping"] == 2:
|
if result["looping"] == 2:
|
||||||
self.looping = False
|
self.looping = False
|
||||||
loop_time -= 1
|
loop_time -= 1
|
||||||
|
idx += 1
|
||||||
|
|
||||||
logger.info(f"loop node {self.node_id}: execution completed")
|
logger.info(f"loop node {self.node_id}: execution completed")
|
||||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||||
|
|||||||
@@ -172,6 +172,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
|
stream=False,
|
||||||
graph=self.graph,
|
graph=self.graph,
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -182,6 +183,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
if self.node_type == NodeType.ITERATION:
|
if self.node_type == NodeType.ITERATION:
|
||||||
return await IterationRuntime(
|
return await IterationRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
|
stream=False,
|
||||||
graph=self.graph,
|
graph=self.graph,
|
||||||
node_id=self.node_id,
|
node_id=self.node_id,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
@@ -190,3 +192,36 @@ class CycleGraphNode(BaseNode):
|
|||||||
child_variable_pool=self.child_variable_pool
|
child_variable_pool=self.child_variable_pool
|
||||||
).run()
|
).run()
|
||||||
raise RuntimeError("Unknown cycle node type")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|
||||||
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
|
if self.node_type == NodeType.LOOP:
|
||||||
|
yield {
|
||||||
|
"__final__": True,
|
||||||
|
"result": await LoopRuntime(
|
||||||
|
start_id=self.start_node_id,
|
||||||
|
stream=True,
|
||||||
|
graph=self.graph,
|
||||||
|
node_id=self.node_id,
|
||||||
|
config=self.config,
|
||||||
|
state=state,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
child_variable_pool=self.child_variable_pool,
|
||||||
|
).run()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
if self.node_type == NodeType.ITERATION:
|
||||||
|
yield {
|
||||||
|
"__final__": True,
|
||||||
|
"result": await IterationRuntime(
|
||||||
|
start_id=self.start_node_id,
|
||||||
|
stream=True,
|
||||||
|
graph=self.graph,
|
||||||
|
node_id=self.node_id,
|
||||||
|
config=self.config,
|
||||||
|
state=state,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
child_variable_pool=self.child_variable_pool
|
||||||
|
).run()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|||||||
@@ -439,7 +439,7 @@ class DraftRunRequest(BaseModel):
|
|||||||
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)")
|
||||||
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值")
|
||||||
stream: bool = Field(default=False, description="是否流式返回")
|
stream: bool = Field(default=False, description="是否流式返回")
|
||||||
files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)")
|
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
class DraftRunResponse(BaseModel):
|
class DraftRunResponse(BaseModel):
|
||||||
|
|||||||
@@ -588,7 +588,7 @@ class WorkflowService:
|
|||||||
"message_length": len(payload.get("output", ""))
|
"message_length": len(payload.get("output", ""))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case "node_start" | "node_end" | "node_error":
|
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||||
return None
|
return None
|
||||||
case _:
|
case _:
|
||||||
return event
|
return event
|
||||||
|
|||||||
Reference in New Issue
Block a user