feat(workflow): optimize streaming output logic for sequential execution of multiple END nodes

This commit is contained in:
Eternity
2026-03-19 21:26:59 +08:00
parent 9d8c26b999
commit fcc81ac025
4 changed files with 209 additions and 111 deletions

View File

@@ -23,6 +23,17 @@ from app.core.workflow.utils.expression_evaluator import evaluate_condition
logger = logging.getLogger(__name__)
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
class GraphBuilder:
def __init__(
@@ -41,17 +52,20 @@ class GraphBuilder:
self.end_node_ids = []
self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_branch_node = lru_cache(
self._find_upstream_activation_dep = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
)(self._find_upstream_activation_dep)
if variable_pool:
self.variable_pool = variable_pool
else:
self.variable_pool = VariablePool()
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@@ -87,60 +101,48 @@ class GraphBuilder:
result[node[0]].append(node[1])
return result
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
"""
Recursively find all upstream branch (control) nodes that influence the execution
of the given target node.
def _build_reverse_adj(self):
    """Populate ``self._reverse_adj`` with the incoming edges of each node.

    Every edge in ``self.edges`` contributes one record
    ``{"id": <source>, "branch": <label>}`` to the list stored under the
    edge's target. A missing ``label`` is recorded as ``None``.
    """
    reverse_adj = self._reverse_adj
    for edge in self.edges:
        # Look up the target's bucket first, then build the upstream record.
        incoming = reverse_adj[edge.get("target")]
        upstream = {"id": edge["source"], "branch": edge.get("label")}
        incoming.append(upstream)
This method walks upstream along the workflow graph starting from `target_node`.
It distinguishes between:
- branch nodes (node types listed in `BRANCH_NODES`)
- non-branch nodes (ordinary processing nodes)
def _find_upstream_activation_dep(
self,
target_node: str
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
"""Find upstream dependencies that affect the activation of a target node.
Traversal rules:
1. For each immediate upstream node:
- If it is a branch node, it is recorded as an affecting control node.
- If it is a non-branch node, the traversal continues recursively upstream.
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
a branch node, the traversal is considered invalid:
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
Walks upstream along the workflow graph from the target node, collecting
two types of dependencies:
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
routing outcome determines whether the target node executes.
- Output nodes: upstream END nodes that must complete their output
before the target node can activate.
Special case:
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
it is considered directly reachable from the workflow entry, and therefore
has no controlling branch nodes.
The traversal terminates early and returns empty tuples if any upstream
path reaches START/CYCLE_START without encountering a branch or output
node, indicating the target node is directly reachable and should be
activated immediately.
Args:
target_node (str):
The identifier of the node whose upstream control branches
are to be resolved.
target_node: The ID of the node whose upstream activation
dependencies are to be resolved.
Returns:
tuple[bool, tuple[tuple[str, str]]]:
- has_branch (bool):
True if every upstream path from `target_node` encounters
at least one branch node.
False if any path reaches a start node without a branch.
- branch_nodes (tuple[tuple[str, str]]):
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
A tuple of two elements:
- A deduplicated tuple of (branch_node_id, branch_label) pairs
representing upstream branch control dependencies. Empty if
any clean path to START exists.
- A deduplicated tuple of upstream output node IDs that must
complete before this node activates.
"""
source_nodes = [
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return False, tuple()
source_nodes = self._reverse_adj[target_node]
if not source_nodes or self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return tuple(), tuple()
branch_nodes = []
output_nodes = []
non_branch_nodes = []
for node_info in source_nodes:
@@ -149,19 +151,23 @@ class GraphBuilder:
(node_info["id"], node_info["branch"])
)
else:
if self.get_node_type(node_info["id"]) == NodeType.END:
output_nodes.append(node_info["id"])
non_branch_nodes.append(node_info["id"])
has_branch = True
for node_id in non_branch_nodes:
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
has_branch = has_branch and node_has_branch
if not has_branch:
break
branch_nodes.extend(nodes)
if not has_branch:
branch_nodes = []
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
if not upstream_control_nodes:
if not upstream_output_nodes and node_id not in output_nodes:
return tuple(), tuple()
branch_nodes = []
has_branch = False
if has_branch:
branch_nodes.extend(upstream_control_nodes)
output_nodes.extend(upstream_output_nodes)
return has_branch, tuple(set(branch_nodes))
return tuple(set(branch_nodes)), tuple(set(output_nodes))
def _analyze_end_node_output(self):
"""
@@ -195,42 +201,33 @@ class GraphBuilder:
if not output:
continue
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Split output into ordered segments
output_template = list(re.findall(pattern, output))
output_template = list(_OUTPUT_PATTERN.findall(output))
# Determine whether each segment is literal text
# True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value)
output_flag = [
not bool(variable_pattern.match(item))
not bool(_VARIABLE_PATTERN.match(item))
for item in output_template
]
# Stream mode: output activation depends on upstream branch nodes and upstream output (END) nodes
if self.stream:
# Find upstream branch nodes and upstream output (END) nodes that gate this End node
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
# Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
# Output is active immediately only when there is no upstream branch or output dependency
activate=not has_branch,
activate=activate,
# Branch nodes that control activation of this End node
control_nodes=self._merge_control_nodes(control_nodes),
control_nodes=self._merge_control_nodes(upstream_control_nodes),
upstream_output_nodes=list(upstream_output_nodes),
control_resolved=not bool(upstream_control_nodes),
output_resolved=not bool(upstream_output_nodes),
# Convert output segments into OutputContent objects
outputs=list(
@@ -249,14 +246,16 @@ class GraphBuilder:
cursor=0
)
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {not has_branch}, "
f"control_nodes: {control_nodes},"
f"activate: {activate}, "
f"control_nodes: {upstream_control_nodes},"
f"ref_outputs: {upstream_output_nodes},"
f"output: {output_template},"
f"output_activate: {output_flag}")
# Non-stream mode: all outputs are activated by default
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
activate=True,
control_nodes={},
outputs=list(
@@ -269,7 +268,10 @@ class GraphBuilder:
for output_string, activate in zip(output_template, output_flag)
]
),
cursor=0
cursor=0,
upstream_output_nodes=[],
control_resolved=True,
output_resolved=True,
)
def add_nodes(self):

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11
import re
from queue import Queue
from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
activate: bool = Field(
...,
description=(
"Whether this output segment is currently active.\n"
"- True: allowed to be emitted/output\n"
"Whether this output segment is currently active."
"- True: allowed to be emitted/output"
"- False: blocked until activated by branch control"
)
)
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
is_variable: bool = Field(
...,
description=(
"Whether this segment represents a variable placeholder.\n"
"True -> variable (e.g. {{ node.field }})\n"
"Whether this segment represents a variable placeholder."
"True -> variable (e.g. {{ node.field }})"
"False -> literal text"
)
)
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
- which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated
"""
id: str = Field(
...,
description="ID of the End node this configuration belongs to."
)
activate: bool = Field(
...,
description=(
"Global activation flag for the End node output.\n"
"When False, output segments should not be emitted even if available.\n"
"Global activation flag for the End node output."
"When False, output segments should not be emitted even if available."
"This flag typically becomes True once required control branch conditions "
"are satisfied."
)
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
control_nodes: dict[str, list[str]] = Field(
...,
description=(
"Control branch conditions for this End node output.\n"
"Mapping of `branch_node_id -> expected_branch_label`.\n"
"Control branch conditions for this End node output."
"Mapping of `branch_node_id -> expected_branch_label`."
"The End node output becomes globally active when a controlling branch node "
"reports a matching completion status."
)
)
# IDs of upstream END/output nodes this End node waits on.
# Adjacent string literals are concatenated verbatim, so each sentence
# carries its own trailing space to keep the description readable.
upstream_output_nodes: list[str] = Field(
    ...,
    description=(
        "Upstream output node dependencies (data flow). "
        "Represents END/output nodes that this output depends on. "
        "These nodes provide data sources required before this output can be activated "
        "or streamed. "
        "Used to ensure correct ordering and dependency resolution in streaming mode."
    )
)
# Flag tracking whether the branch (control-flow) dependencies are satisfied.
# Adjacent string literals are concatenated verbatim, so the first sentence
# ends with a space to avoid running into the next one.
control_resolved: bool = Field(
    ...,
    description=(
        "Whether all upstream branch control dependencies have been satisfied. "
        "True if no upstream branch nodes exist or the required branch "
        "conditions have been met."
    )
)
# Flag tracking whether the upstream output-node (data-flow) dependencies are done.
# Adjacent string literals are concatenated verbatim, so the first sentence
# ends with a space to avoid running into the next one.
output_resolved: bool = Field(
    ...,
    description=(
        "Whether all upstream output node dependencies have been completed. "
        "True if no upstream output nodes exist or all upstream output "
        "nodes have finished their output."
    )
)
outputs: list[OutputContent] = Field(
...,
description=(
"Ordered list of output segments parsed from the output template.\n"
"Ordered list of output segments parsed from the output template."
"Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently."
)
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates the next output segment index to be emitted.\n"
"Streaming cursor index."
"Indicates the next output segment index to be emitted."
"Segments with index < cursor are considered already streamed."
)
)
# When set, segments are emitted regardless of their activation state.
# Adjacent string literals are concatenated verbatim, so each sentence
# carries its own trailing space to keep the description readable.
force: bool = Field(
    default=False,
    description=(
        "Force flag for output emission. "
        "When True, all output segments are emitted regardless of activation state. "
        "Triggered when this output node has finished execution."
    )
)
def update_activate(self, scope: str, status=None):
"""
Update streaming activation state based on an upstream node or special variable.
Update streaming activation state based on upstream events.
Args:
scope (str):
Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`.
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
it may appear in output segments.
status (optional):
Completion status of the control branch node.
Required when `scope` refers to a control node.
Behavior:
1. Control branch nodes:
- If `scope` matches a key in `control_nodes` and `status` matches the expected
branch label, the End node output becomes globally active (`activate = True`).
1. Force activation:
- If `self.force` is True, the method returns immediately.
- If `scope == self.id`, the node marks itself as completed:
- `activate = True`
- `force = True`
This is typically used for final flushing when the node finishes execution.
2. Variable output segments:
- For each segment that is a variable (`is_variable=True`):
- If the segment literal references `scope`, mark the segment as active.
- This applies both to regular node variables (e.g., "node_id.field")
and special system variables (e.g., "sys.xxx").
2. Control dependency resolution:
- If `scope` matches a key in `control_nodes`:
- `status` must be provided.
- If `status` matches expected branch labels, mark control as resolved
(`control_resolved = True`).
3. Upstream output dependency resolution:
- If `scope` is in `upstream_output_nodes`,
mark data dependency as resolved (`output_resolved = True`).
4. Global activation condition:
- The node becomes active when BOTH conditions are satisfied:
- control_resolved == True
- output_resolved == True
- Once activated, `activate` remains True.
5. Variable segment activation:
- For each output segment that is a variable (`is_variable=True`):
- If the segment depends on the given `scope`,
mark the segment as active.
- This applies to both node variables (e.g., "node_id.field")
and system variables (e.g., "sys.xxx").
Notes:
- This method does not emit output or advance the streaming cursor.
- It only updates activation flags based on upstream events or special variables.
- This method does NOT emit output or advance the streaming cursor.
- It only updates activation and dependency resolution states.
- Activation is driven by both control flow (branch nodes) and
data flow (upstream output nodes).
"""
if self.force:
return
# Case 1: resolve control branch dependency
if scope == self.id:
self.activate = True
self.force = True
return
# resolve control branch dependency
if scope in self.control_nodes:
if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided")
if status in self.control_nodes[scope]:
self.activate = True
self.control_resolved = True
# Case 2: activate variable segments related to this node
if scope in self.upstream_output_nodes:
self.upstream_output_nodes.remove(scope)
if not self.upstream_output_nodes:
self.output_resolved = True
self.activate = self.activate or (self.control_resolved and self.output_resolved)
# activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
@@ -174,6 +256,8 @@ class StreamOutputCoordinator:
def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None
self.output_queue: Queue = Queue()
self.processed_outputs = []
def initialize_end_outputs(
self,
@@ -211,8 +295,11 @@ class StreamOutputCoordinator:
"""
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
if self.end_outputs[node].activate and node not in self.processed_outputs:
self.output_queue.put(node)
self.processed_outputs.append(node)
if self.activate_end is None and not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
async def emit_activate_chunk(
self,
@@ -256,7 +343,7 @@ class StreamOutputCoordinator:
final_chunk = ''
current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force:
if not current_segment.activate and not force and not end_info.force:
# Stop processing until this segment becomes active
break
@@ -285,8 +372,7 @@ class StreamOutputCoordinator:
end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
self.pop_current_activate_end()
async def flush_remaining_chunk(
self,
@@ -325,6 +411,8 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event
if not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
# Move to next active End node if current one is done
if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from app.schemas import FileType
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
"""
if isinstance(var, str):
return cls.STRING
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, bool):
return cls.BOOLEAN
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE
elif isinstance(var, dict):
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
content_cache: dict = Field(default_factory=dict)
is_file: bool
_byte_content: bytes | None = None
_byte_content: bytes | None = PrivateAttr(default=None)
def get_content(self):
return self._byte_content

View File

@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
class StringVariable(BaseVariable):
value: str
type = 'str'
def valid_value(self, value) -> str:
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
class NumberVariable(BaseVariable):
value: int | float
type = 'number'
def valid_value(self, value) -> int | float:
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
class BooleanVariable(BaseVariable):
value: bool
type = 'boolean'
def valid_value(self, value) -> bool:
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
class DictVariable(BaseVariable):
value: dict
type = 'object'
def valid_value(self, value) -> dict:
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
class FileVariable(BaseVariable):
value: FileObject
type = 'file'
def valid_value(self, value) -> FileObject:
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
class ArrayVariable(BaseVariable, Generic[T]):
value: list[T]
type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]):
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
class NestedArrayVariable(BaseVariable):
value: list[ArrayVariable]
type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]:
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
category=RuntimeWarning
)
class AnyVariable(BaseVariable):
value: Any
type = 'any'
def valid_value(self, value: Any) -> Any: