diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index f0411ae3..b7abf659 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -16,7 +16,6 @@ from app.core.workflow.graph_builder import GraphBuilder, StreamOutputConfig from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_config import VariableType from app.core.workflow.nodes.enums import NodeType -from app.core.workflow.template_renderer import render_template logger = logging.getLogger(__name__) @@ -157,12 +156,137 @@ class WorkflowExecutor: "error": result.get("error"), } - def _update_end_activate(self, node_id): + def _update_scope_activate(self, scope, status=None): + """ + Update the activation state of all End nodes based on a completed scope (node or variable). + + Iterates over all End nodes in `self.end_outputs` and calls + `update_activate` on each, which may: + - Activate variable segments that depend on the completed node/scope. + - Activate the entire End node output if all control conditions are met. + + If any End node becomes active and `self.activate_end` is not yet set, + this node will be marked as the currently active End node. + + Args: + scope (str): The node ID or scope that has completed execution. + status (str | None): Optional status of the node (used for branch/control nodes). + """ for node in self.end_outputs.keys(): - self.end_outputs[node].update_activate(node_id) + self.end_outputs[node].update_activate(scope, status) if self.end_outputs[node].activate and self.activate_end is None: self.activate_end = node + def _update_stream_output_status(self, activate, data): + """ + Update the stream output state of End nodes based on workflow state updates. + + This method checks which nodes/scopes are activated and propagates + activation to End nodes accordingly. + + Args: + activate (dict): Mapping of node_id -> bool indicating which nodes/scopes are activated. + data (dict): Mapping of node_id -> node runtime data, including outputs. + + Behavior: + For each node in `data`: + 1. If the node is activated (`activate[node_id]` is True), + retrieve its output status from `runtime_vars`. + 2. Call `_update_scope_activate` to propagate the activation + to all relevant End nodes and update `self.activate_end`. + """ + for node_id in data.keys(): + if activate.get(node_id): + node_output_status = ( + data[node_id] + .get('runtime_vars', {}) + .get(node_id) + .get("output") + ) + self._update_scope_activate(node_id, status=node_output_status) + + async def _emit_active_chunks( + self, + node_outputs: dict, + variables: dict, + force=False + ): + """ + Process and yield all currently active output segments for the currently active End node. + + This method handles stream-mode output for an End node by iterating through its output segments + (`OutputContent`). Only segments marked as active (`activate=True`) are processed, unless + `force=True`, which allows all segments to be processed regardless of their activation state. + + Behavior: + 1. Iterates from the current `cursor` position to the end of the outputs list. + 2. For each segment: + - If the segment is literal text (`is_variable=False`), append it directly. + - If the segment is a variable (`is_variable=True`), evaluate it using + `evaluate_expression` with the given `node_outputs` and `variables`, + then transform the result with `_trans_output_string`. + 3. Yield a stream event of type "message" containing the processed chunk. + 4. Move the `cursor` forward after processing each segment. + 5. When all segments have been processed, remove this End node from `end_outputs` + and reset `activate_end` to None. + + Args: + node_outputs (dict): Current runtime node outputs, used for variable evaluation. + variables (dict): Current runtime variables, used for variable evaluation. + force (bool, default=False): If True, process segments even if `activate=False`. + + Yields: + dict: A stream event of type "message" containing the processed chunk. + + Notes: + - Segments that fail evaluation (ValueError) are skipped with a warning logged. + - This method only processes the currently active End node (`self.activate_end`). + - Use `force=True` for final emission regardless of activation state. + """ + + end_info = self.end_outputs[self.activate_end] + + while end_info.cursor < len(end_info.outputs): + final_chunk = '' + current_segment = end_info.outputs[end_info.cursor] + + if not current_segment.activate and not force: + # Stop processing until this segment becomes active + break + + # Literal segment + if not current_segment.is_variable: + final_chunk += current_segment.literal + else: + # Variable segment: evaluate and transform + try: + chunk = evaluate_expression( + current_segment.literal, + variables=variables, + node_outputs=node_outputs + ) + chunk = self._trans_output_string(chunk) + final_chunk += chunk + except ValueError: + # Log failed evaluation but continue streaming + logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}") + + if final_chunk: + yield { + "event": "message", + "data": { + "chunk": final_chunk + } + } + + # Advance cursor after processing + end_info.cursor += 1 + + # Remove End node from active tracking if all segments have been processed + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None + @staticmethod def _trans_output_string(content): if isinstance(content, str): @@ -218,14 +342,8 @@ class WorkflowExecutor: result = await graph.ainvoke(initial_state, config=self.checkpoint_config) full_content = '' - for end_info in self.end_outputs.values(): - output_template = "".join([output.literal for output in end_info.outputs]) - full_content += render_template( - output_template, - result.get("variables", {}), - result.get("runtime_vars", {}), - strict=False - ) + for end_id in self.end_outputs.keys(): + full_content += result.get('runtime_vars', {}).get(end_id, {}).get('output', '') result["messages"].extend( [ { @@ -306,7 +424,7 @@ class WorkflowExecutor: try: chunk_count = 0 full_content = '' - + self._update_scope_activate("sys") async for event in graph.astream( initial_state, stream_mode=["updates", "debug", "custom"], # Use updates + debug + custom mode @@ -333,9 +451,12 @@ class WorkflowExecutor: if not end_info or end_info.cursor >= len(end_info.outputs): continue current_output = end_info.outputs[end_info.cursor] - if current_output.is_variable and current_output.depends_on_node(node_id): + if current_output.is_variable and current_output.depends_on_scope(node_id): if data.get("done"): end_info.cursor += 1 + if end_info.cursor >= len(end_info.outputs): + self.end_outputs.pop(self.activate_end) + self.activate_end = None else: full_content += data.get("chunk") yield { @@ -415,91 +536,53 @@ class WorkflowExecutor: elif mode == "updates": # Handle state updates - store final state - for node_id in data.keys(): - self._update_end_activate(node_id) - wait = False - state = graph.get_state(config=self.checkpoint_config) - node_outputs = state.values.get("runtime_vars", {}) - for _ in data.keys(): - node_outputs = node_outputs | data.get(_).get("runtime_vars", {}) + state = graph.get_state(config=self.checkpoint_config).values + node_outputs = state.get("runtime_vars", {}) + variables = state.get("variables", {}) + activate = state.get("activate", {}) + for _, node_data in data.items(): + node_outputs |= node_data.get("runtime_vars", {}) + variables |= node_data.get("variables", {}) + self._update_stream_output_status(activate, data) + wait = False while self.activate_end and not wait: - message = '' - logger.info(self.activate_end) - end_info = self.end_outputs[self.activate_end] - content = end_info.outputs[end_info.cursor] - while content.activate: - if not content.is_variable: - full_content += content.literal - message += content.literal - else: - try: - chunk = evaluate_expression( - content.literal, - variables={}, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) - message += chunk - full_content += chunk - except ValueError: - pass - end_info.cursor += 1 - if end_info.cursor == len(end_info.outputs): - break - content = end_info.outputs[end_info.cursor] - if end_info.cursor != len(end_info.outputs): + async for msg_event in self._emit_active_chunks( + node_outputs=node_outputs, + variables=variables + ): + full_content += msg_event["data"]['chunk'] + yield msg_event + + if self.activate_end: wait = True else: - self.end_outputs.pop(self.activate_end) - self.activate_end = None - for node_id in data.keys(): - self._update_end_activate(node_id) - if message: - yield { - "event": "message", - "data": { - "chunk": message - } - } + self._update_stream_output_status(activate, data) logger.debug(f"[UPDATES] 收到 state 更新 from {list(data.keys())} " f"- execution_id: {self.execution_id}") result = graph.get_state(self.checkpoint_config).values - while self.activate_end: - message = '' - end_info = self.end_outputs[self.activate_end] - content = end_info.outputs[end_info.cursor] - if not content.is_variable: - message += content.literal - else: - node_outputs = result.get("runtime_vars", {}) - variables = result.get("variables", {}) - try: - chunk = evaluate_expression( - content.literal, + node_outputs = result.get("runtime_vars", {}) + variables = result.get("variables", {}) + self.end_outputs = { + node_id: node_info + for node_id, node_info in self.end_outputs.items() + if node_info.activate + } + + if self.end_outputs or self.activate_end: + while self.activate_end: + async for msg_event in self._emit_active_chunks( + node_outputs=node_outputs, variables=variables, - node_outputs=node_outputs - ) - chunk = self._trans_output_string(chunk) - message += chunk - full_content += chunk - except ValueError: - pass - end_info.cursor += 1 - if end_info.cursor == len(end_info.outputs): - self.end_outputs.pop(self.activate_end) - self.activate_end = None - if self.end_outputs: + force=True + ): + full_content += msg_event["data"]['chunk'] + yield msg_event + + if not self.activate_end and self.end_outputs: self.activate_end = list(self.end_outputs.keys())[0] - if message: - yield { - "event": "message", - "data": { - "chunk": message - } - } # 计算耗时 end_time = datetime.datetime.now() diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py index 9fa89fd2..b1d43e08 100644 --- a/api/app/core/workflow/graph_builder.py +++ b/api/app/core/workflow/graph_builder.py @@ -53,114 +53,110 @@ class OutputContent(BaseModel): ) ) - def depends_on_node(self, node_id: str) -> bool: + def depends_on_scope(self, scope: str) -> bool: """ - Check if this output segment depends on a specific node's variable. - - This method examines the `literal` of the output segment to see if it - contains a variable placeholder referencing the given node in the form: - - {{ node_id.field_name }} - - It uses a regular expression to match the exact node ID, avoiding - false positives from substring matches (e.g., 'node1' should not match 'node10'). + Check if this segment depends on a given scope. Args: - node_id (str): The ID of the node to check for in this segment's variable placeholders. + scope (str): Node ID or special variable prefix (e.g., "sys"). Returns: - bool: - - True if the segment contains a variable referencing the given node. - - False otherwise. - - Example: - literal = "{{node1.name}}" - - depends_on_node("node1") -> True - depends_on_node("node2") -> False - - Usage: - This method is primarily used in stream mode to determine whether - a particular variable output segment should be activated when a - specific upstream node completes execution. + bool: True if this segment references the given scope. """ - variable_pattern = rf"\{{\{{\s*{re.escape(node_id)}\.[a-zA-Z0-9_]+\s*\}}\}}" - pattern = re.compile(variable_pattern) - match = pattern.search(self.literal) - if match: - return True - return False + pattern = rf"\{{\{{\s*{re.escape(scope)}\.[a-zA-Z0-9_]+\s*\}}\}}" + return bool(re.search(pattern, self.literal)) class StreamOutputConfig(BaseModel): """ Streaming output configuration for an End node. - This structure controls: - - whether the End node output is globally active - - which upstream branch nodes are responsible for activation - - how each output segment behaves in streaming mode + This configuration describes how the End node output behaves in streaming mode, + including: + - whether output emission is globally activated + - which upstream branch/control nodes gate the activation + - how each parsed output segment is streamed and activated """ activate: bool = Field( ..., description=( - "Global activation state of the End node output.\n" - "If False, no output should be emitted until all control nodes are resolved." + "Global activation flag for the End node output.\n" + "When False, output segments should not be emitted even if available.\n" + "This flag typically becomes True once required control branch conditions " + "are satisfied." ) ) - control_nodes: list[str] = Field( + control_nodes: dict[str, str] = Field( ..., description=( - "List of upstream branch node IDs that control this End node.\n" - "Each node must signal completion before output becomes active." + "Control branch conditions for this End node output.\n" + "Mapping of `branch_node_id -> expected_branch_label`.\n" + "The End node output becomes globally active when a controlling branch node " + "reports a matching completion status." ) ) outputs: list[OutputContent] = Field( ..., - description="Ordered list of output segments parsed from the output template." + description=( + "Ordered list of output segments parsed from the output template.\n" + "Each segment represents either a literal text block or a variable placeholder " + "that may be activated independently." + ) ) cursor: int = Field( ..., description=( "Streaming cursor index.\n" - "Indicates how many output segments have already been emitted." + "Indicates the next output segment index to be emitted.\n" + "Segments with index < cursor are considered already streamed." ) ) - def update_activate(self, node_id): + def update_activate(self, scope: str, status=None): """ - Update activation state based on an upstream node completion. + Update streaming activation state based on an upstream node or special variable. - This method is typically called when a branch/control node finishes execution. + Args: + scope (str): + Identifier of the completed upstream entity. + - If a control branch node, it should match a key in `control_nodes`. + - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments. + status (optional): + Completion status of the control branch node. + Required when `scope` refers to a control node. Behavior: - 1. If the node is a control node: - - Remove it from `control_nodes` - - If all control nodes are resolved, activate the entire output + 1. Control branch nodes: + - If `scope` matches a key in `control_nodes` and `status` matches the expected + branch label, the End node output becomes globally active (`activate = True`). - 2. Activate variable output segments that depend on this node: - - If an output segment is a variable - - And its literal references the completed node_id - - Mark that segment as active + 2. Variable output segments: + - For each segment that is a variable (`is_variable=True`): + - If the segment literal references `scope`, mark the segment as active. + - This applies both to regular node variables (e.g., "node_id.field") + and special system variables (e.g., "sys.xxx"). + + Notes: + - This method does not emit output or advance the streaming cursor. + - It only updates activation flags based on upstream events or special variables. """ # Case 1: resolve control branch dependency - if node_id in self.control_nodes: - self.control_nodes.remove(node_id) - - # All branch constraints resolved → enable output - if not self.control_nodes: + if scope in self.control_nodes.keys(): + if status is None: + raise RuntimeError("[Stream Output] Control node activation status not provided") + if status == self.control_nodes[scope]: self.activate = True # Case 2: activate variable segments related to this node for i in range(len(self.outputs)): if ( self.outputs[i].is_variable - and self.outputs[i].depends_on_node(node_id) + and self.outputs[i].depends_on_scope(scope) ): self.outputs[i].activate = True @@ -184,11 +180,11 @@ class GraphBuilder: self._find_upstream_branch_node = lru_cache( maxsize=len(self.nodes) * 2 )(self._find_upstream_branch_node) - self._analyze_end_node_output() self.graph = StateGraph(WorkflowState) self.add_nodes() self.add_edges() + self._analyze_end_node_output() # EDGES MUST BE ADDED AFTER NODES ARE ADDED. @property @@ -216,30 +212,53 @@ class GraphBuilder: except KeyError: raise RuntimeError(f"Node not found: Id={node_id}") - def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[str]]: - """Find upstream branch nodes for a given target node in the workflow graph. + def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]: + """ + Recursively find all upstream branch (control) nodes that influence the execution + of the given target node. - This method identifies all upstream control (branch) nodes that can affect - the execution of `target_node`. If `target_node` is reachable from a start - node (i.e., a node with no upstream nodes), the method returns an empty tuple. + This method walks upstream along the workflow graph starting from `target_node`. + It distinguishes between: + - branch nodes (node types listed in `BRANCH_NODES`) + - non-branch nodes (ordinary processing nodes) - The function distinguishes between branch nodes (defined in `BRANCH_NODES`) - and non-branch nodes, recursively traversing upstream through non-branch - nodes. If any non-branch upstream path does not lead to a branch node, - the result will indicate that no valid upstream branch node exists. + Traversal rules: + 1. For each immediate upstream node: + - If it is a branch node, it is recorded as an affecting control node. + - If it is a non-branch node, the traversal continues recursively upstream. + 2. If ANY upstream path reaches a START / CYCLE_START node without encountering + a branch node, the traversal is considered invalid: + - `has_branch` will be False + - no branch nodes are returned. + 3. Only when ALL upstream non-branch paths eventually lead to at least one + branch node will `has_branch` be True. + + Special case: + - If `target_node` has no upstream nodes AND its type is START or CYCLE_START, + it is considered directly reachable from the workflow entry, and therefore + has no controlling branch nodes. Args: - target_node (str): The identifier of the target node. + target_node (str): + The identifier of the node whose upstream control branches + are to be resolved. Returns: - tuple[bool, tuple[str]]: - - has_branch (bool): True if all upstream non-branch paths lead to at least - one branch node; False if any path reaches a start node without a branch. - - branch_nodes (tuple[str]): A deduplicated tuple of upstream branch node IDs - affecting `target_node`. Returns an empty tuple if `has_branch` is False. + tuple[bool, tuple[tuple[str, str]]]: + - has_branch (bool): + True if every upstream path from `target_node` encounters + at least one branch node. + False if any path reaches a start node without a branch. + - branch_nodes (tuple[tuple[str, str]]): + A deduplicated tuple of `(branch_node_id, branch_label)` pairs + representing all branch nodes that can influence `target_node`. + Returns an empty tuple if `has_branch` is False. """ source_nodes = [ - edge.get("source") + { + "id": edge.get("source"), + "branch": edge.get("label") + } for edge in self.edges if edge.get("target") == target_node ] @@ -249,11 +268,13 @@ class GraphBuilder: branch_nodes = [] non_branch_nodes = [] - for node_id in source_nodes: - if self.get_node_type(node_id) in BRANCH_NODES: - branch_nodes.append(node_id) + for node_info in source_nodes: + if self.get_node_type(node_info["id"]) in BRANCH_NODES: + branch_nodes.append( + (node_info["id"], node_info["branch"]) + ) else: - non_branch_nodes.append(node_id) + non_branch_nodes.append(node_info["id"]) has_branch = True for node_id in non_branch_nodes: @@ -334,7 +355,7 @@ class GraphBuilder: activate=not has_branch, # Branch nodes that control activation of this End node - control_nodes=list(control_nodes), + control_nodes=dict(control_nodes), # Convert output segments into OutputContent objects outputs=list( @@ -362,7 +383,7 @@ class GraphBuilder: else: self.end_node_map[end_node_id] = StreamOutputConfig( activate=True, - control_nodes=[], + control_nodes={}, outputs=list( [ OutputContent( diff --git a/api/app/core/workflow/nodes/memory/config.py b/api/app/core/workflow/nodes/memory/config.py index 57ee6dc2..31881e24 100644 --- a/api/app/core/workflow/nodes/memory/config.py +++ b/api/app/core/workflow/nodes/memory/config.py @@ -25,6 +25,6 @@ class MemoryWriteNodeConfig(BaseNodeConfig): ... ) - config_id: UUID = Field( + config_id: UUID | int = Field( ... ) diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index f71c70ee..13860bec 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -36,9 +36,10 @@ class MemoryReadNode(BaseNode): class MemoryWriteNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = MemoryWriteNodeConfig(**self.config) + self.typed_config: MemoryWriteNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: + self.typed_config = MemoryWriteNodeConfig(**self.config) end_user_id = self.get_variable("sys.user_id", state) if not end_user_id: