feat(workflow): optimize streaming output logic for sequential execution of multiple END nodes
This commit is contained in:
@@ -23,6 +23,17 @@ from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Regex to split output into:
|
||||||
|
# - variable placeholders: {{ ... }}
|
||||||
|
# - normal literal text
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||||
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -41,17 +52,20 @@ class GraphBuilder:
|
|||||||
self.end_node_ids = []
|
self.end_node_ids = []
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_activation_dep = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_activation_dep)
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
self.variable_pool = VariablePool()
|
self.variable_pool = VariablePool()
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
self._build_reverse_adj()
|
||||||
self._analyze_end_node_output()
|
self._analyze_end_node_output()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
@@ -87,60 +101,48 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
def _build_reverse_adj(self):
|
||||||
"""
|
for edge in self.edges:
|
||||||
Recursively find all upstream branch (control) nodes that influence the execution
|
self._reverse_adj[edge.get("target")].append({
|
||||||
of the given target node.
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
|
})
|
||||||
|
|
||||||
This method walks upstream along the workflow graph starting from `target_node`.
|
def _find_upstream_activation_dep(
|
||||||
It distinguishes between:
|
self,
|
||||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
target_node: str
|
||||||
- non-branch nodes (ordinary processing nodes)
|
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||||
|
"""Find upstream dependencies that affect the activation of a target node.
|
||||||
|
|
||||||
Traversal rules:
|
Walks upstream along the workflow graph from the target node, collecting
|
||||||
1. For each immediate upstream node:
|
two types of dependencies:
|
||||||
- If it is a branch node, it is recorded as an affecting control node.
|
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
routing outcome determines whether the target node executes.
|
||||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
- Output nodes: upstream END nodes that must complete their output
|
||||||
a branch node, the traversal is considered invalid:
|
before the target node can activate.
|
||||||
- `has_branch` will be False
|
|
||||||
- no branch nodes are returned.
|
|
||||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
|
||||||
branch node will `has_branch` be True.
|
|
||||||
|
|
||||||
Special case:
|
The traversal terminates early and returns empty tuples if any upstream
|
||||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
path reaches START/CYCLE_START without encountering a branch or output
|
||||||
it is considered directly reachable from the workflow entry, and therefore
|
node, indicating the target node is directly reachable and should be
|
||||||
has no controlling branch nodes.
|
activated immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str):
|
target_node: The ID of the node whose upstream activation
|
||||||
The identifier of the node whose upstream control branches
|
dependencies are to be resolved.
|
||||||
are to be resolved.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[tuple[str, str]]]:
|
A tuple of two elements:
|
||||||
- has_branch (bool):
|
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||||
True if every upstream path from `target_node` encounters
|
representing upstream branch control dependencies. Empty if
|
||||||
at least one branch node.
|
any clean path to START exists.
|
||||||
False if any path reaches a start node without a branch.
|
- A deduplicated tuple of upstream output node IDs that must
|
||||||
- branch_nodes (tuple[tuple[str, str]]):
|
complete before this node activates.
|
||||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
|
||||||
representing all branch nodes that can influence `target_node`.
|
|
||||||
Returns an empty tuple if `has_branch` is False.
|
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = self._reverse_adj[target_node]
|
||||||
{
|
if not source_nodes or self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
"id": edge.get("source"),
|
return tuple(), tuple()
|
||||||
"branch": edge.get("label")
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
|
||||||
if edge.get("target") == target_node
|
|
||||||
]
|
|
||||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
|
||||||
return False, tuple()
|
|
||||||
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
|
output_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
@@ -149,19 +151,23 @@ class GraphBuilder:
|
|||||||
(node_info["id"], node_info["branch"])
|
(node_info["id"], node_info["branch"])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||||
|
output_nodes.append(node_info["id"])
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||||
has_branch = has_branch and node_has_branch
|
if not upstream_control_nodes:
|
||||||
if not has_branch:
|
if not upstream_output_nodes and node_id not in output_nodes:
|
||||||
break
|
return tuple(), tuple()
|
||||||
branch_nodes.extend(nodes)
|
branch_nodes = []
|
||||||
if not has_branch:
|
has_branch = False
|
||||||
branch_nodes = []
|
if has_branch:
|
||||||
|
branch_nodes.extend(upstream_control_nodes)
|
||||||
|
output_nodes.extend(upstream_output_nodes)
|
||||||
|
|
||||||
return has_branch, tuple(set(branch_nodes))
|
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||||
|
|
||||||
def _analyze_end_node_output(self):
|
def _analyze_end_node_output(self):
|
||||||
"""
|
"""
|
||||||
@@ -195,42 +201,33 @@ class GraphBuilder:
|
|||||||
if not output:
|
if not output:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Regex to split output into:
|
|
||||||
# - variable placeholders: {{ ... }}
|
|
||||||
# - normal literal text
|
|
||||||
#
|
|
||||||
# Example:
|
|
||||||
# "Hello {{user.name}}!" ->
|
|
||||||
# ["Hello ", "{{user.name}}", "!"]
|
|
||||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
|
||||||
|
|
||||||
# Strict variable format: {{ node_id.field_name }}
|
|
||||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
|
||||||
variable_pattern = re.compile(variable_pattern_string)
|
|
||||||
|
|
||||||
# Split output into ordered segments
|
# Split output into ordered segments
|
||||||
output_template = list(re.findall(pattern, output))
|
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||||
|
|
||||||
# Determine whether each segment is literal text
|
# Determine whether each segment is literal text
|
||||||
# True -> literal (can be directly output)
|
# True -> literal (can be directly output)
|
||||||
# False -> variable placeholder (needs runtime value)
|
# False -> variable placeholder (needs runtime value)
|
||||||
output_flag = [
|
output_flag = [
|
||||||
not bool(variable_pattern.match(item))
|
not bool(_VARIABLE_PATTERN.match(item))
|
||||||
for item in output_template
|
for item in output_template
|
||||||
]
|
]
|
||||||
|
|
||||||
# Stream mode: output activation depends on upstream branch nodes
|
# Stream mode: output activation depends on upstream branch nodes
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Find upstream branch nodes that can control this End node
|
# Find upstream branch nodes that can control this End node
|
||||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||||
|
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||||
# Build StreamOutputConfig for this End node
|
# Build StreamOutputConfig for this End node
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
# If there is no upstream branch, output is active immediately
|
# If there is no upstream branch, output is active immediately
|
||||||
activate=not has_branch,
|
activate=activate,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=self._merge_control_nodes(control_nodes),
|
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||||
|
upstream_output_nodes=list(upstream_output_nodes),
|
||||||
|
control_resolved=not bool(upstream_control_nodes),
|
||||||
|
output_resolved=not bool(upstream_output_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -249,14 +246,16 @@ class GraphBuilder:
|
|||||||
cursor=0
|
cursor=0
|
||||||
)
|
)
|
||||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
f"activate: {not has_branch}, "
|
f"activate: {activate}, "
|
||||||
f"control_nodes: {control_nodes},"
|
f"control_nodes: {upstream_control_nodes},"
|
||||||
|
f"ref_outputs: {upstream_output_nodes},"
|
||||||
f"output: {output_template},"
|
f"output: {output_template},"
|
||||||
f"output_activate: {output_flag}")
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
# Non-stream mode: all outputs are activated by default
|
# Non-stream mode: all outputs are activated by default
|
||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes={},
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -269,7 +268,10 @@ class GraphBuilder:
|
|||||||
for output_string, activate in zip(output_template, output_flag)
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
cursor=0
|
cursor=0,
|
||||||
|
upstream_output_nodes=[],
|
||||||
|
control_resolved=True,
|
||||||
|
output_resolved=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 15:11
|
# @Time : 2026/2/9 15:11
|
||||||
import re
|
import re
|
||||||
|
from queue import Queue
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
|||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this output segment is currently active.\n"
|
"Whether this output segment is currently active."
|
||||||
"- True: allowed to be emitted/output\n"
|
"- True: allowed to be emitted/output"
|
||||||
"- False: blocked until activated by branch control"
|
"- False: blocked until activated by branch control"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
|
|||||||
is_variable: bool = Field(
|
is_variable: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this segment represents a variable placeholder.\n"
|
"Whether this segment represents a variable placeholder."
|
||||||
"True -> variable (e.g. {{ node.field }})\n"
|
"True -> variable (e.g. {{ node.field }})"
|
||||||
"False -> literal text"
|
"False -> literal text"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
|||||||
- which upstream branch/control nodes gate the activation
|
- which upstream branch/control nodes gate the activation
|
||||||
- how each parsed output segment is streamed and activated
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="ID of the End node this configuration belongs to."
|
||||||
|
)
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation flag for the End node output.\n"
|
"Global activation flag for the End node output."
|
||||||
"When False, output segments should not be emitted even if available.\n"
|
"When False, output segments should not be emitted even if available."
|
||||||
"This flag typically becomes True once required control branch conditions "
|
"This flag typically becomes True once required control branch conditions "
|
||||||
"are satisfied."
|
"are satisfied."
|
||||||
)
|
)
|
||||||
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
|||||||
control_nodes: dict[str, list[str]] = Field(
|
control_nodes: dict[str, list[str]] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Control branch conditions for this End node output.\n"
|
"Control branch conditions for this End node output."
|
||||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||||
"The End node output becomes globally active when a controlling branch node "
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
"reports a matching completion status."
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upstream_output_nodes: list[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Upstream output node dependencies (data flow)."
|
||||||
|
"Represents END/output nodes that this output depends on."
|
||||||
|
"These nodes provide data sources required before this output can be activated "
|
||||||
|
"or streamed."
|
||||||
|
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream branch control dependencies have been satisfied."
|
||||||
|
"True if no upstream branch nodes exist or the required branch "
|
||||||
|
"conditions have been met."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream output node dependencies have been completed."
|
||||||
|
"True if no upstream output nodes exist or all upstream output "
|
||||||
|
"nodes have finished their output."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Ordered list of output segments parsed from the output template.\n"
|
"Ordered list of output segments parsed from the output template."
|
||||||
"Each segment represents either a literal text block or a variable placeholder "
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
"that may be activated independently."
|
"that may be activated independently."
|
||||||
)
|
)
|
||||||
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
|||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index."
|
||||||
"Indicates the next output segment index to be emitted.\n"
|
"Indicates the next output segment index to be emitted."
|
||||||
"Segments with index < cursor are considered already streamed."
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
force: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Force flag for output emission."
|
||||||
|
"When True, all output segments are emitted regardless of activation state."
|
||||||
|
"Triggered when this output node has finished execution."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def update_activate(self, scope: str, status=None):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update streaming activation state based on an upstream node or special variable.
|
Update streaming activation state based on upstream events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scope (str):
|
scope (str):
|
||||||
Identifier of the completed upstream entity.
|
Identifier of the completed upstream entity.
|
||||||
- If a control branch node, it should match a key in `control_nodes`.
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||||
|
it may appear in output segments.
|
||||||
|
|
||||||
status (optional):
|
status (optional):
|
||||||
Completion status of the control branch node.
|
Completion status of the control branch node.
|
||||||
Required when `scope` refers to a control node.
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. Control branch nodes:
|
1. Force activation:
|
||||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
- If `self.force` is True, the method returns immediately.
|
||||||
branch label, the End node output becomes globally active (`activate = True`).
|
- If `scope == self.id`, the node marks itself as completed:
|
||||||
|
- `activate = True`
|
||||||
|
- `force = True`
|
||||||
|
This is typically used for final flushing when the node finishes execution.
|
||||||
|
|
||||||
2. Variable output segments:
|
2. Control dependency resolution:
|
||||||
- For each segment that is a variable (`is_variable=True`):
|
- If `scope` matches a key in `control_nodes`:
|
||||||
- If the segment literal references `scope`, mark the segment as active.
|
- `status` must be provided.
|
||||||
- This applies both to regular node variables (e.g., "node_id.field")
|
- If `status` matches expected branch labels, mark control as resolved
|
||||||
and special system variables (e.g., "sys.xxx").
|
(`control_resolved = True`).
|
||||||
|
|
||||||
|
3. Upstream output dependency resolution:
|
||||||
|
- If `scope` is in `upstream_output_nodes`,
|
||||||
|
mark data dependency as resolved (`output_resolved = True`).
|
||||||
|
|
||||||
|
4. Global activation condition:
|
||||||
|
- The node becomes active when BOTH conditions are satisfied:
|
||||||
|
- control_resolved == True
|
||||||
|
- output_resolved == True
|
||||||
|
- Once activated, `activate` remains True.
|
||||||
|
|
||||||
|
5. Variable segment activation:
|
||||||
|
- For each output segment that is a variable (`is_variable=True`):
|
||||||
|
- If the segment depends on the given `scope`,
|
||||||
|
mark the segment as active.
|
||||||
|
- This applies to both node variables (e.g., "node_id.field")
|
||||||
|
and system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This method does not emit output or advance the streaming cursor.
|
- This method does NOT emit output or advance the streaming cursor.
|
||||||
- It only updates activation flags based on upstream events or special variables.
|
- It only updates activation and dependency resolution states.
|
||||||
|
- Activation is driven by both control flow (branch nodes) and
|
||||||
|
data flow (upstream output nodes).
|
||||||
"""
|
"""
|
||||||
|
if self.force:
|
||||||
|
return
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
if scope == self.id:
|
||||||
|
self.activate = True
|
||||||
|
self.force = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# resolve control branch dependency
|
||||||
if scope in self.control_nodes:
|
if scope in self.control_nodes:
|
||||||
if status is None:
|
if status is None:
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
if status in self.control_nodes[scope]:
|
if status in self.control_nodes[scope]:
|
||||||
self.activate = True
|
self.control_resolved = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
if scope in self.upstream_output_nodes:
|
||||||
|
self.upstream_output_nodes.remove(scope)
|
||||||
|
if not self.upstream_output_nodes:
|
||||||
|
self.output_resolved = True
|
||||||
|
|
||||||
|
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||||
|
|
||||||
|
# activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
@@ -174,6 +256,8 @@ class StreamOutputCoordinator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
self.activate_end: str | None = None
|
self.activate_end: str | None = None
|
||||||
|
self.output_queue: Queue = Queue()
|
||||||
|
self.processed_outputs = []
|
||||||
|
|
||||||
def initialize_end_outputs(
|
def initialize_end_outputs(
|
||||||
self,
|
self,
|
||||||
@@ -211,8 +295,11 @@ class StreamOutputCoordinator:
|
|||||||
"""
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||||
self.activate_end = node
|
self.output_queue.put(node)
|
||||||
|
self.processed_outputs.append(node)
|
||||||
|
if self.activate_end is None and not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
|
|
||||||
async def emit_activate_chunk(
|
async def emit_activate_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -256,7 +343,7 @@ class StreamOutputCoordinator:
|
|||||||
final_chunk = ''
|
final_chunk = ''
|
||||||
current_segment = end_info.outputs[end_info.cursor]
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
if not current_segment.activate and not force:
|
if not current_segment.activate and not force and not end_info.force:
|
||||||
# Stop processing until this segment becomes active
|
# Stop processing until this segment becomes active
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -285,8 +372,7 @@ class StreamOutputCoordinator:
|
|||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
self.end_outputs.pop(self.activate_end)
|
self.pop_current_activate_end()
|
||||||
self.activate_end = None
|
|
||||||
|
|
||||||
async def flush_remaining_chunk(
|
async def flush_remaining_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -325,6 +411,8 @@ class StreamOutputCoordinator:
|
|||||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
|
if not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
# Move to next active End node if current one is done
|
# Move to next active End node if current one is done
|
||||||
if not self.activate_end and self.end_outputs:
|
if not self.activate_end and self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
|||||||
"""
|
"""
|
||||||
if isinstance(var, str):
|
if isinstance(var, str):
|
||||||
return cls.STRING
|
return cls.STRING
|
||||||
elif isinstance(var, (int, float)):
|
|
||||||
return cls.NUMBER
|
|
||||||
elif isinstance(var, bool):
|
elif isinstance(var, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
|
elif isinstance(var, (int, float)):
|
||||||
|
return cls.NUMBER
|
||||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
|
|||||||
content_cache: dict = Field(default_factory=dict)
|
content_cache: dict = Field(default_factory=dict)
|
||||||
is_file: bool
|
is_file: bool
|
||||||
|
|
||||||
_byte_content: bytes | None = None
|
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_content(self):
|
def get_content(self):
|
||||||
return self._byte_content
|
return self._byte_content
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
|||||||
|
|
||||||
|
|
||||||
class StringVariable(BaseVariable):
|
class StringVariable(BaseVariable):
|
||||||
|
value: str
|
||||||
type = 'str'
|
type = 'str'
|
||||||
|
|
||||||
def valid_value(self, value) -> str:
|
def valid_value(self, value) -> str:
|
||||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class NumberVariable(BaseVariable):
|
class NumberVariable(BaseVariable):
|
||||||
|
value: int | float
|
||||||
type = 'number'
|
type = 'number'
|
||||||
|
|
||||||
def valid_value(self, value) -> int | float:
|
def valid_value(self, value) -> int | float:
|
||||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class BooleanVariable(BaseVariable):
|
class BooleanVariable(BaseVariable):
|
||||||
|
value: bool
|
||||||
type = 'boolean'
|
type = 'boolean'
|
||||||
|
|
||||||
def valid_value(self, value) -> bool:
|
def valid_value(self, value) -> bool:
|
||||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class DictVariable(BaseVariable):
|
class DictVariable(BaseVariable):
|
||||||
|
value: dict
|
||||||
type = 'object'
|
type = 'object'
|
||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class FileVariable(BaseVariable):
|
class FileVariable(BaseVariable):
|
||||||
|
value: FileObject
|
||||||
type = 'file'
|
type = 'file'
|
||||||
|
|
||||||
def valid_value(self, value) -> FileObject:
|
def valid_value(self, value) -> FileObject:
|
||||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class ArrayVariable(BaseVariable, Generic[T]):
|
class ArrayVariable(BaseVariable, Generic[T]):
|
||||||
|
value: list[T]
|
||||||
type = 'array'
|
type = 'array'
|
||||||
|
|
||||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class NestedArrayVariable(BaseVariable):
|
class NestedArrayVariable(BaseVariable):
|
||||||
|
value: list[ArrayVariable]
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
|||||||
category=RuntimeWarning
|
category=RuntimeWarning
|
||||||
)
|
)
|
||||||
class AnyVariable(BaseVariable):
|
class AnyVariable(BaseVariable):
|
||||||
|
value: Any
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
def valid_value(self, value: Any) -> Any:
|
def valid_value(self, value: Any) -> Any:
|
||||||
|
|||||||
Reference in New Issue
Block a user