Merge pull request #219 from SuanmoSuanyangTechnology/fix/workflow-stream

fix(workflow): fix streaming output issues with multi-output End nodes
This commit is contained in:
Mark
2026-01-28 15:32:48 +08:00
committed by GitHub
4 changed files with 275 additions and 170 deletions

View File

@@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.template_renderer import render_template
logger = logging.getLogger(__name__)
@@ -157,12 +156,137 @@ class WorkflowExecutor:
"error": result.get("error"),
}
def _update_end_activate(self, node_id):
def _update_scope_activate(self, scope, status=None):
"""
Update the activation state of all End nodes based on a completed scope (node or variable).
Iterates over all End nodes in `self.end_outputs` and calls
`update_activate` on each, which may:
- Activate variable segments that depend on the completed node/scope.
- Activate the entire End node output if all control conditions are met.
If any End node becomes active and `self.activate_end` is not yet set,
this node will be marked as the currently active End node.
Args:
scope (str): The node ID or scope that has completed execution.
status (str | None): Optional status of the node (used for branch/control nodes).
"""
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(node_id)
self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
def _update_stream_output_status(self, activate, data):
"""
Update the stream output state of End nodes based on workflow state updates.
This method checks which nodes/scopes are activated and propagates
activation to End nodes accordingly.
Args:
activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated.
data (dict): Mapping of node_id -> node runtime data, including outputs.
Behavior:
For each node in `data`:
1. If the node is activated (`activate[node_id]` is True),
retrieve its output status from `runtime_vars`.
2. Call `_update_scope_activate` to propagate the activation
to all relevant End nodes and update `self.activate_end`.
"""
for node_id in data.keys():
if activate.get(node_id):
node_output_status = (
data[node_id]
.get('runtime_vars', {})
.get(node_id)
.get("output")
)
self._update_scope_activate(node_id, status=node_output_status)
async def _emit_active_chunks(
self,
node_outputs: dict,
variables: dict,
force=False
):
"""
Process and yield all currently active output segments for the currently active End node.
This method handles stream-mode output for an End node by iterating through its output segments
(`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless
`force=True`, which allows all segments to be processed regardless of their activation state.
Behavior:
1. Iterates from the current `cursor` position to the end of the outputs list.
2. For each segment:
- If the segment is literal text (`is_variable=False`), append it directly.
- If the segment is a variable (`is_variable=True`), evaluate it using
`evaluate_expression` with the given `node_outputs` and `variables`,
then transform the result with `_trans_output_string`.
3. Yield a stream event of type "message" containing the processed chunk.
4. Move the `cursor` forward after processing each segment.
5. When all segments have been processed, remove this End node from `end_outputs`
and reset `activate_end` to None.
Args:
node_outputs (dict): Current runtime node outputs, used for variable evaluation.
variables (dict): Current runtime variables, used for variable evaluation.
force (bool, default=False): If True, process segments even if `activate=False`.
Yields:
dict: A stream event of type "message" containing the processed chunk.
Notes:
- Segments that fail evaluation (ValueError) are skipped with a warning logged.
- This method only processes the currently active End node (`self.activate_end`).
- Use `force=True` for final emission regardless of activation state.
"""
end_info = self.end_outputs[self.activate_end]
while end_info.cursor < len(end_info.outputs):
final_chunk = ''
current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force:
# Stop processing until this segment becomes active
break
# Literal segment
if not current_segment.is_variable:
final_chunk += current_segment.literal
else:
# Variable segment: evaluate and transform
try:
chunk = evaluate_expression(
current_segment.literal,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
final_chunk += chunk
except ValueError:
# Log failed evaluation but continue streaming
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}")
if final_chunk:
yield {
"event": "message",
"data": {
"chunk": final_chunk
}
}
# Advance cursor after processing
end_info.cursor += 1
# Remove End node from active tracking if all segments have been processed
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
@staticmethod
def _trans_output_string(content):
if isinstance(content, str):
@@ -218,14 +342,8 @@ class WorkflowExecutor:
result = await graph.ainvoke(initial_state, config=self.checkpoint_config)
full_content = ''
for end_info in self.end_outputs.values():
output_template = "".join([output.literal for output in end_info.outputs])
full_content += render_template(
output_template,
result.get("variables", {}),
result.get("runtime_vars", {}),
strict=False
)
for end_id in self.end_outputs.keys():
full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '')
result["messages"].extend(
[
{
@@ -306,7 +424,7 @@ class WorkflowExecutor:
try:
chunk_count = 0
full_content = ''
self._update_scope_activate("sys")
async for event in graph.astream(
initial_state,
stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode
@@ -333,9 +451,12 @@ class WorkflowExecutor:
if not end_info or end_info.cursor >= len(end_info.outputs):
continue
current_output = end_info.outputs[end_info.cursor]
if current_output.is_variable and current_output.depends_on_node(node_id):
if current_output.is_variable and current_output.depends_on_scope(node_id):
if data.get("done"):
end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
else:
full_content += data.get("chunk")
yield {
@@ -415,91 +536,53 @@ class WorkflowExecutor:
elif mode == "updates":
# Handle state updates - store final state
for node_id in data.keys():
self._update_end_activate(node_id)
wait = False
state = graph.get_state(config=self.checkpoint_config)
node_outputs = state.values.get("runtime_vars", {})
for _ in data.keys():
node_outputs = node_outputs | data.get(_).get("runtime_vars", {})
state = graph.get_state(config=self.checkpoint_config).values
node_outputs = state.get("runtime_vars", {})
variables = state.get("variables", {})
activate = state.get("activate", {})
for _, node_data in data.items():
node_outputs |= node_data.get("runtime_vars", {})
variables |= node_data.get("variables", {})
self._update_stream_output_status(activate, data)
wait = False
while self.activate_end and not wait:
message = ''
logger.info(self.activate_end)
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
while content.activate:
if not content.is_variable:
full_content += content.literal
message += content.literal
else:
try:
chunk = evaluate_expression(
content.literal,
variables={},
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
break
content = end_info.outputs[end_info.cursor]
if end_info.cursor != len(end_info.outputs):
async for msg_event in self._emit_active_chunks(
node_outputs=node_outputs,
variables=variables
):
full_content += msg_event["data"]['chunk']
yield msg_event
if self.activate_end:
wait = True
else:
self.end_outputs.pop(self.activate_end)
self.activate_end = None
for node_id in data.keys():
self._update_end_activate(node_id)
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
self._update_stream_output_status(activate, data)
logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} "
f"- execution_id: {self.execution_id}")
result = graph.get_state(self.checkpoint_config).values
while self.activate_end:
message = ''
end_info = self.end_outputs[self.activate_end]
content = end_info.outputs[end_info.cursor]
if not content.is_variable:
message += content.literal
else:
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
try:
chunk = evaluate_expression(
content.literal,
node_outputs = result.get("runtime_vars", {})
variables = result.get("variables", {})
self.end_outputs = {
node_id: node_info
for node_id, node_info in self.end_outputs.items()
if node_info.activate
}
if self.end_outputs or self.activate_end:
while self.activate_end:
async for msg_event in self._emit_active_chunks(
node_outputs=node_outputs,
variables=variables,
node_outputs=node_outputs
)
chunk = self._trans_output_string(chunk)
message += chunk
full_content += chunk
except ValueError:
pass
end_info.cursor += 1
if end_info.cursor == len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
if self.end_outputs:
force=True
):
full_content += msg_event["data"]['chunk']
yield msg_event
if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]
if message:
yield {
"event": "message",
"data": {
"chunk": message
}
}
# 计算耗时
end_time = datetime.datetime.now()

View File

@@ -53,114 +53,110 @@ class OutputContent(BaseModel):
)
)
def depends_on_node(self, node_id: str) -> bool:
def depends_on_scope(self, scope: str) -> bool:
"""
Check if this output segment depends on a specific node's variable.
This method examines the `literal` of the output segment to see if it
contains a variable placeholder referencing the given node in the form:
{{ node_id.field_name }}
It uses a regular expression to match the exact node ID, avoiding
false positives from substring matches (e.g., 'node1' should not match 'node10').
Check if this segment depends on a given scope.
Args:
node_id (str): The ID of the node to check for in this segment's variable placeholders.
scope (str): Node ID or special variable prefix (e.g., "sys").
Returns:
bool:
- True if the segment contains a variable referencing the given node.
- False otherwise.
Example:
literal = "{{node1.name}}"
depends_on_node("node1") -> True
depends_on_node("node2") -> False
Usage:
This method is primarily used in stream mode to determine whether
a particular variable output segment should be activated when a
specific upstream node completes execution.
bool: True if this segment references the given scope.
"""
variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}"
pattern = re.compile(variable_pattern)
match = pattern.search(self.literal)
if match:
return True
return False
pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}"
return bool(re.search(pattern, self.literal))
class StreamOutputConfig(BaseModel):
"""
Streaming output configuration for an End node.
This structure controls:
- whether the End node output is globally active
- which upstream branch nodes are responsible for activation
- how each output segment behaves in streaming mode
This configuration describes how the End node output behaves in streaming mode,
including:
- whether output emission is globally activated
- which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated
"""
activate: bool = Field(
...,
description=(
"Global activation state of the End node output.\n"
"If False, no output should be emitted until all control nodes are resolved."
"Global activation flag for the End node output.\n"
"When False, output segments should not be emitted even if available.\n"
"This flag typically becomes True once required control branch conditions "
"are satisfied."
)
)
control_nodes: list[str] = Field(
control_nodes: dict[str, str] = Field(
...,
description=(
"List of upstream branch node IDs that control this End node.\n"
"Each node must signal completion before output becomes active."
"Control branch conditions for this End node output.\n"
"Mapping of `branch_node_id -> expected_branch_label`.\n"
"The End node output becomes globally active when a controlling branch node "
"reports a matching completion status."
)
)
outputs: list[OutputContent] = Field(
...,
description="Ordered list of output segments parsed from the output template."
description=(
"Ordered list of output segments parsed from the output template.\n"
"Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently."
)
)
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates how many output segments have already been emitted."
"Indicates the next output segment index to be emitted.\n"
"Segments with index < cursor are considered already streamed."
)
)
def update_activate(self, node_id):
def update_activate(self, scope: str, status=None):
"""
Update activation state based on an upstream node completion.
Update streaming activation state based on an upstream node or special variable.
This method is typically called when a branch/control node finishes execution.
Args:
scope (str):
Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`.
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
status (optional):
Completion status of the control branch node.
Required when `scope` refers to a control node.
Behavior:
1. If the node is a control node:
- Remove it from `control_nodes`
- If all control nodes are resolved, activate the entire output
1. Control branch nodes:
- If `scope` matches a key in `control_nodes` and `status` matches the expected
branch label, the End node output becomes globally active (`activate = True`).
2. Activate variable output segments that depend on this node:
- If an output segment is a variable
- And its literal references the completed node_id
- Mark that segment as active
2. Variable output segments:
- For each segment that is a variable (`is_variable=True`):
- If the segment literal references `scope`, mark the segment as active.
- This applies both to regular node variables (e.g., "node_id.field")
and special system variables (e.g., "sys.xxx").
Notes:
- This method does not emit output or advance the streaming cursor.
- It only updates activation flags based on upstream events or special variables.
"""
# Case 1: resolve control branch dependency
if node_id in self.control_nodes:
self.control_nodes.remove(node_id)
# All branch constraints resolved → enable output
if not self.control_nodes:
if scope in self.control_nodes.keys():
if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided")
if status == self.control_nodes[scope]:
self.activate = True
# Case 2: activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
and self.outputs[i].depends_on_node(node_id)
and self.outputs[i].depends_on_scope(scope)
):
self.outputs[i].activate = True
@@ -184,11 +180,11 @@ class GraphBuilder:
self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
self._analyze_end_node_output()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property
@@ -216,30 +212,53 @@ class GraphBuilder:
except KeyError:
raise RuntimeError(f"Node not found: Id={node_id}")
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]:
"""Find upstream branch nodes for a given target node in the workflow graph.
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
"""
Recursively find all upstream branch (control) nodes that influence the execution
of the given target node.
This method identifies all upstream control (branch) nodes that can affect
the execution of `target_node`. If `target_node` is reachable from a start
node (i.e., a node with no upstream nodes), the method returns an empty tuple.
This method walks upstream along the workflow graph starting from `target_node`.
It distinguishes between:
- branch nodes (node types listed in `BRANCH_NODES`)
- non-branch nodes (ordinary processing nodes)
The function distinguishes between branch nodes (defined in `BRANCH_NODES`)
and non-branch nodes, recursively traversing upstream through non-branch
nodes. If any non-branch upstream path does not lead to a branch node,
the result will indicate that no valid upstream branch node exists.
Traversal rules:
1. For each immediate upstream node:
- If it is a branch node, it is recorded as an affecting control node.
- If it is a non-branch node, the traversal continues recursively upstream.
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
a branch node, the traversal is considered invalid:
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
Special case:
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
it is considered directly reachable from the workflow entry, and therefore
has no controlling branch nodes.
Args:
target_node (str): The identifier of the target node.
target_node (str):
The identifier of the node whose upstream control branches
are to be resolved.
Returns:
tuple[bool, tuple[str]]:
- has_branch (bool): True if all upstream non-branch paths lead to at least
one branch node; False if any path reaches a start node without a branch.
- branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs
affecting `target_node`. Returns an empty tuple if `has_branch` is False.
tuple[bool, tuple[tuple[str, str]]]:
- has_branch (bool):
True if every upstream path from `target_node` encounters
at least one branch node.
False if any path reaches a start node without a branch.
- branch_nodes (tuple[tuple[str, str]]):
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
"""
source_nodes = [
edge.get("source")
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
@@ -249,11 +268,13 @@ class GraphBuilder:
branch_nodes = []
non_branch_nodes = []
for node_id in source_nodes:
if self.get_node_type(node_id) in BRANCH_NODES:
branch_nodes.append(node_id)
for node_info in source_nodes:
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
branch_nodes.append(
(node_info["id"], node_info["branch"])
)
else:
non_branch_nodes.append(node_id)
non_branch_nodes.append(node_info["id"])
has_branch = True
for node_id in non_branch_nodes:
@@ -334,7 +355,7 @@ class GraphBuilder:
activate=not has_branch,
# Branch nodes that control activation of this End node
control_nodes=list(control_nodes),
control_nodes=dict(control_nodes),
# Convert output segments into OutputContent objects
outputs=list(
@@ -362,7 +383,7 @@ class GraphBuilder:
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
activate=True,
control_nodes=[],
control_nodes={},
outputs=list(
[
OutputContent(

View File

@@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig):
...
)
config_id: UUID = Field(
config_id: UUID | int = Field(
...
)

View File

@@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = MemoryWriteNodeConfig(**self.config)
self.typed_config: MemoryWriteNodeConfig | None = None
async def execute(self, state: WorkflowState) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", state)
if not end_user_id: