diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 90668ad9..61896574 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -23,6 +23,17 @@ from app.core.workflow.utils.expression_evaluator import evaluate_condition logger = logging.getLogger(__name__) +# Regex to split output into: +# - variable placeholders: {{ ... }} +# - normal literal text +# +# Example: +# "Hello {{user.name}}!" -> +# ["Hello ", "{{user.name}}", "!"] +_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+') +# Strict variable format: {{ node_id.field_name }} +_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}') + class GraphBuilder: def __init__( @@ -41,17 +52,20 @@ class GraphBuilder: self.end_node_ids = [] self.node_map = {node["id"]: node for node in self.nodes} self.end_node_map: dict[str, StreamOutputConfig] = {} - self._find_upstream_branch_node = lru_cache( + self._find_upstream_activation_dep = lru_cache( maxsize=len(self.nodes) * 2 - )(self._find_upstream_branch_node) + )(self._find_upstream_activation_dep) if variable_pool: self.variable_pool = variable_pool else: self.variable_pool = VariablePool() + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self.graph = StateGraph(WorkflowState) self.add_nodes() self.add_edges() + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._build_reverse_adj() self._analyze_end_node_output() # EDGES MUST BE ADDED AFTER NODES ARE ADDED. @@ -87,60 +101,48 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]: - """ - Recursively find all upstream branch (control) nodes that influence the execution - of the given target node. + def _build_reverse_adj(self): + for edge in self.edges: + self._reverse_adj[edge.get("target")].append({ + "id": edge["source"], "branch": edge.get("label") + }) - This method walks upstream along the workflow graph starting from `target_node`. - It distinguishes between: - - branch nodes (node types listed in `BRANCH_NODES`) - - non-branch nodes (ordinary processing nodes) + def _find_upstream_activation_dep( + self, + target_node: str + ) -> tuple[tuple[tuple[str, str]], tuple[str]]: + """Find upstream dependencies that affect the activation of a target node. - Traversal rules: - 1. For each immediate upstream node: - - If it is a branch node, it is recorded as an affecting control node. - - If it is a non-branch node, the traversal continues recursively upstream. - 2. If ANY upstream path reaches a START / CYCLE_START node without encountering - a branch node, the traversal is considered invalid: - - `has_branch` will be False - - no branch nodes are returned. - 3. Only when ALL upstream non-branch paths eventually lead to at least one - branch node will `has_branch` be True. + Walks upstream along the workflow graph from the target node, collecting + two types of dependencies: + - Branch control nodes: upstream branch nodes (e.g. if-else) whose + routing outcome determines whether the target node executes. + - Output nodes: upstream END nodes that must complete their output + before the target node can activate. - Special case: - - If `target_node` has no upstream nodes AND its type is START or CYCLE_START, - it is considered directly reachable from the workflow entry, and therefore - has no controlling branch nodes. + The traversal terminates early and returns empty tuples if any upstream + path reaches START/CYCLE_START without encountering a branch or output + node, indicating the target node is directly reachable and should be + activated immediately. Args: - target_node (str): - The identifier of the node whose upstream control branches - are to be resolved. + target_node: The ID of the node whose upstream activation + dependencies are to be resolved. Returns: - tuple[bool, tuple[tuple[str, str]]]: - - has_branch (bool): - True if every upstream path from `target_node` encounters - at least one branch node. - False if any path reaches a start node without a branch. - - branch_nodes (tuple[tuple[str, str]]): - A deduplicated tuple of `(branch_node_id, branch_label)` pairs - representing all branch nodes that can influence `target_node`. - Returns an empty tuple if `has_branch` is False. + A tuple of two elements: + - A deduplicated tuple of (branch_node_id, branch_label) pairs + representing upstream branch control dependencies. Empty if + any clean path to START exists. + - A deduplicated tuple of upstream output node IDs that must + complete before this node activates. """ - source_nodes = [ - { - "id": edge.get("source"), - "branch": edge.get("label") - } - for edge in self.edges - if edge.get("target") == target_node - ] - if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: - return False, tuple() + source_nodes = self._reverse_adj[target_node] + if not source_nodes or self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: + return tuple(), tuple() branch_nodes = [] + output_nodes = [] non_branch_nodes = [] for node_info in source_nodes: @@ -149,19 +151,23 @@ class GraphBuilder: (node_info["id"], node_info["branch"]) ) else: + if self.get_node_type(node_info["id"]) == NodeType.END: + output_nodes.append(node_info["id"]) non_branch_nodes.append(node_info["id"]) has_branch = True for node_id in non_branch_nodes: - node_has_branch, nodes = self._find_upstream_branch_node(node_id) - has_branch = has_branch and node_has_branch - if not has_branch: - break - branch_nodes.extend(nodes) - if not has_branch: - branch_nodes = [] + upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id) + if not upstream_control_nodes: + if not upstream_output_nodes and node_id not in output_nodes: + return tuple(), tuple() + branch_nodes = [] + has_branch = False + if has_branch: + branch_nodes.extend(upstream_control_nodes) + output_nodes.extend(upstream_output_nodes) - return has_branch, tuple(set(branch_nodes)) + return tuple(set(branch_nodes)), tuple(set(output_nodes)) def _analyze_end_node_output(self): """ @@ -195,42 +201,33 @@ class GraphBuilder: if not output: continue - # Regex to split output into: - # - variable placeholders: {{ ... }} - # - normal literal text - # - # Example: - # "Hello {{user.name}}!" -> - # ["Hello ", "{{user.name}}", "!"] - pattern = r'\{\{.*?\}\}|[^{}]+' - - # Strict variable format: {{ node_id.field_name }} - variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}' - variable_pattern = re.compile(variable_pattern_string) - # Split output into ordered segments - output_template = list(re.findall(pattern, output)) + output_template = list(_OUTPUT_PATTERN.findall(output)) # Determine whether each segment is literal text # True -> literal (can be directly output) # False -> variable placeholder (needs runtime value) output_flag = [ - not bool(variable_pattern.match(item)) + not bool(_VARIABLE_PATTERN.match(item)) for item in output_template ] # Stream mode: output activation depends on upstream branch nodes if self.stream: # Find upstream branch nodes that can control this End node - has_branch, control_nodes = self._find_upstream_branch_node(end_node_id) - + upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id) + activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes) # Build StreamOutputConfig for this End node self.end_node_map[end_node_id] = StreamOutputConfig( + id=end_node_id, # If there is no upstream branch, output is active immediately - activate=not has_branch, + activate=activate, # Branch nodes that control activation of this End node - control_nodes=self._merge_control_nodes(control_nodes), + control_nodes=self._merge_control_nodes(upstream_control_nodes), + upstream_output_nodes=list(upstream_output_nodes), + control_resolved=not bool(upstream_control_nodes), + output_resolved=not bool(upstream_output_nodes), # Convert output segments into OutputContent objects outputs=list( @@ -249,14 +246,16 @@ class GraphBuilder: cursor=0 ) logger.info(f"[Stream Analysis] end_id: {end_node_id}, " - f"activate: {not has_branch}, " - f"control_nodes: {control_nodes}," + f"activate: {activate}, " + f"control_nodes: {upstream_control_nodes}," + f"ref_outputs: {upstream_output_nodes}," f"output: {output_template}," f"output_activate: {output_flag}") # Non-stream mode: all outputs are activated by default else: self.end_node_map[end_node_id] = StreamOutputConfig( + id=end_node_id, activate=True, control_nodes={}, outputs=list( @@ -269,7 +268,10 @@ class GraphBuilder: for output_string, activate in zip(output_template, output_flag) ] ), - cursor=0 + cursor=0, + upstream_output_nodes=[], + control_resolved=True, + output_resolved=True, ) def add_nodes(self): diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index ddee9adc..8184545c 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -3,6 +3,7 @@ # @Email: 1533512157@qq.com # @Time : 2026/2/9 15:11 import re +from queue import Queue from typing import AsyncGenerator from pydantic import BaseModel, Field, PrivateAttr @@ -37,8 +38,8 @@ class OutputContent(BaseModel): activate: bool = Field( ..., description=( - "Whether this output segment is currently active.\n" - "- True: allowed to be emitted/output\n" + "Whether this output segment is currently active." + "- True: allowed to be emitted/output" "- False: blocked until activated by branch control" ) ) @@ -46,8 +47,8 @@ class OutputContent(BaseModel): is_variable: bool = Field( ..., description=( - "Whether this segment represents a variable placeholder.\n" - "True -> variable (e.g. {{ node.field }})\n" + "Whether this segment represents a variable placeholder." + "True -> variable (e.g. {{ node.field }})" "False -> literal text" ) ) @@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel): - which upstream branch/control nodes gate the activation - how each parsed output segment is streamed and activated """ + id: str = Field( + ..., + description="ID of the End node this configuration belongs to." + ) activate: bool = Field( ..., description=( - "Global activation flag for the End node output.\n" - "When False, output segments should not be emitted even if available.\n" + "Global activation flag for the End node output." + "When False, output segments should not be emitted even if available." "This flag typically becomes True once required control branch conditions " "are satisfied." ) @@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel): control_nodes: dict[str, list[str]] = Field( ..., description=( - "Control branch conditions for this End node output.\n" - "Mapping of `branch_node_id -> expected_branch_label`.\n" + "Control branch conditions for this End node output." + "Mapping of `branch_node_id -> expected_branch_label`." "The End node output becomes globally active when a controlling branch node " "reports a matching completion status." ) ) + upstream_output_nodes: list[str] = Field( + ..., + description=( + "Upstream output node dependencies (data flow)." + "Represents END/output nodes that this output depends on." + "These nodes provide data sources required before this output can be activated " + "or streamed." + "Used to ensure correct ordering and dependency resolution in streaming mode." + ) + ) + + control_resolved: bool = Field( + ..., + description=( + "Whether all upstream branch control dependencies have been satisfied." + "True if no upstream branch nodes exist or the required branch " + "conditions have been met." + ) + ) + + output_resolved: bool = Field( + ..., + description=( + "Whether all upstream output node dependencies have been completed." + "True if no upstream output nodes exist or all upstream output " + "nodes have finished their output." + ) + ) + outputs: list[OutputContent] = Field( ..., description=( - "Ordered list of output segments parsed from the output template.\n" + "Ordered list of output segments parsed from the output template." "Each segment represents either a literal text block or a variable placeholder " "that may be activated independently." ) @@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel): cursor: int = Field( ..., description=( - "Streaming cursor index.\n" - "Indicates the next output segment index to be emitted.\n" + "Streaming cursor index." + "Indicates the next output segment index to be emitted." "Segments with index < cursor are considered already streamed." ) ) + force: bool = Field( + default=False, + description=( + "Force flag for output emission." + "When True, all output segments are emitted regardless of activation state." + "Triggered when this output node has finished execution." + ) + ) + def update_activate(self, scope: str, status=None): """ - Update streaming activation state based on an upstream node or special variable. + Update streaming activation state based on upstream events. Args: scope (str): Identifier of the completed upstream entity. - If a control branch node, it should match a key in `control_nodes`. - - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments. + - If an upstream output node, it should match an entry in `upstream_output_nodes`. + - If a variable placeholder (e.g., "sys.xxx" or "node_id.field"), + it may appear in output segments. + status (optional): Completion status of the control branch node. Required when `scope` refers to a control node. Behavior: - 1. Control branch nodes: - - If `scope` matches a key in `control_nodes` and `status` matches the expected - branch label, the End node output becomes globally active (`activate = True`). + 1. Force activation: + - If `self.force` is True, the method returns immediately. + - If `scope == self.id`, the node marks itself as completed: + - `activate = True` + - `force = True` + This is typically used for final flushing when the node finishes execution. - 2. Variable output segments: - - For each segment that is a variable (`is_variable=True`): - - If the segment literal references `scope`, mark the segment as active. - - This applies both to regular node variables (e.g., "node_id.field") - and special system variables (e.g., "sys.xxx"). + 2. Control dependency resolution: + - If `scope` matches a key in `control_nodes`: + - `status` must be provided. + - If `status` matches expected branch labels, mark control as resolved + (`control_resolved = True`). + + 3. Upstream output dependency resolution: + - If `scope` is in `upstream_output_nodes`, + mark data dependency as resolved (`output_resolved = True`). + + 4. Global activation condition: + - The node becomes active when BOTH conditions are satisfied: + - control_resolved == True + - output_resolved == True + - Once activated, `activate` remains True. + + 5. Variable segment activation: + - For each output segment that is a variable (`is_variable=True`): + - If the segment depends on the given `scope`, + mark the segment as active. + - This applies to both node variables (e.g., "node_id.field") + and system variables (e.g., "sys.xxx"). Notes: - - This method does not emit output or advance the streaming cursor. - - It only updates activation flags based on upstream events or special variables. + - This method does NOT emit output or advance the streaming cursor. + - It only updates activation and dependency resolution states. + - Activation is driven by both control flow (branch nodes) and + data flow (upstream output nodes). """ + if self.force: + return - # Case 1: resolve control branch dependency + if scope == self.id: + self.activate = True + self.force = True + return + + # resolve control branch dependency if scope in self.control_nodes: if status is None: raise RuntimeError("[Stream Output] Control node activation status not provided") if status in self.control_nodes[scope]: - self.activate = True + self.control_resolved = True - # Case 2: activate variable segments related to this node + if scope in self.upstream_output_nodes: + self.upstream_output_nodes.remove(scope) + if not self.upstream_output_nodes: + self.output_resolved = True + + self.activate = self.activate or (self.control_resolved and self.output_resolved) + + # activate variable segments related to this node for i in range(len(self.outputs)): if ( self.outputs[i].is_variable @@ -174,6 +256,8 @@ class StreamOutputCoordinator: def __init__(self): self.end_outputs: dict[str, StreamOutputConfig] = {} self.activate_end: str | None = None + self.output_queue: Queue = Queue() + self.processed_outputs = [] def initialize_end_outputs( self, @@ -211,8 +295,11 @@ class StreamOutputCoordinator: """ for node in self.end_outputs.keys(): self.end_outputs[node].update_activate(scope, status) - if self.end_outputs[node].activate and self.activate_end is None: - self.activate_end = node + if self.end_outputs[node].activate and node not in self.processed_outputs: + self.output_queue.put(node) + self.processed_outputs.append(node) + if self.activate_end is None and not self.output_queue.empty(): + self.activate_end = self.output_queue.get_nowait() async def emit_activate_chunk( self, @@ -256,7 +343,7 @@ class StreamOutputCoordinator: final_chunk = '' current_segment = end_info.outputs[end_info.cursor] - if not current_segment.activate and not force: + if not current_segment.activate and not force and not end_info.force: # Stop processing until this segment becomes active break @@ -285,8 +372,7 @@ class StreamOutputCoordinator: end_info.cursor += 1 if end_info.cursor >= len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None + self.pop_current_activate_end() async def flush_remaining_chunk( self, @@ -325,6 +411,8 @@ class StreamOutputCoordinator: async for msg_event in self.emit_activate_chunk(variable_pool, force=True): yield msg_event + if not self.output_queue.empty(): + self.activate_end = self.output_queue.get_nowait() # Move to next active End node if current one is done if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index aea40cf6..f5d8ff8f 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, PrivateAttr from app.schemas import FileType @@ -41,10 +41,10 @@ class VariableType(StrEnum): """ if isinstance(var, str): return cls.STRING - elif isinstance(var, (int, float)): - return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN + elif isinstance(var, (int, float)): + return cls.NUMBER elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): @@ -116,7 +116,7 @@ class FileObject(BaseModel): content_cache: dict = Field(default_factory=dict) is_file: bool - _byte_content: bytes | None = None + _byte_content: bytes | None = PrivateAttr(default=None) def get_content(self): return self._byte_content diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 63437fd9..5e8e3f1e 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable) class StringVariable(BaseVariable): + value: str type = 'str' def valid_value(self, value) -> str: @@ -22,6 +23,7 @@ class StringVariable(BaseVariable): class NumberVariable(BaseVariable): + value: int | float type = 'number' def valid_value(self, value) -> int | float: @@ -34,6 +36,7 @@ class NumberVariable(BaseVariable): class BooleanVariable(BaseVariable): + value: bool type = 'boolean' def valid_value(self, value) -> bool: @@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable): class DictVariable(BaseVariable): + value: dict type = 'object' def valid_value(self, value) -> dict: @@ -58,6 +62,7 @@ class DictVariable(BaseVariable): class FileVariable(BaseVariable): + value: FileObject type = 'file' def valid_value(self, value) -> FileObject: @@ -102,6 +107,7 @@ class FileVariable(BaseVariable): class ArrayVariable(BaseVariable, Generic[T]): + value: list[T] type = 'array' def __init__(self, child_type: Type[T], value: list[Any]): @@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]): class NestedArrayVariable(BaseVariable): + value: list[ArrayVariable] type = 'array_nest' def valid_value(self, value: list[T]) -> list[T]: @@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable): category=RuntimeWarning ) class AnyVariable(BaseVariable): + value: Any type = 'any' def valid_value(self, value: Any) -> Any: