From 06de54ebfd32ba328cb922544e8eb866be671ab5 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 13:58:20 +0800 Subject: [PATCH] fix(workflow): fix streaming output issues caused by unreachable nodes --- api/app/core/workflow/engine/graph_builder.py | 37 +++++++++++-------- .../engine/stream_output_coordinator.py | 2 +- api/app/core/workflow/nodes/if_else/node.py | 4 +- api/app/core/workflow/validator.py | 4 +- api/app/services/workflow_service.py | 2 +- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index c20fd0bb..813a543f 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -20,6 +20,7 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes import NodeFactory from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES from app.core.workflow.utils.expression_evaluator import evaluate_condition +from app.core.workflow.validator import WorkflowValidator logger = logging.getLogger(__name__) @@ -48,8 +49,8 @@ class GraphBuilder: self.stream = stream self.subgraph = subgraph - self.start_node_id = None - self.end_node_ids = [] + self.start_node_id: str | None = None + self.node_map = {node["id"]: node for node in self.nodes} self.end_node_map: dict[str, StreamOutputConfig] = {} self._find_upstream_activation_dep = lru_cache( @@ -62,11 +63,18 @@ class GraphBuilder: self.graph = StateGraph(WorkflowState) self.add_nodes() + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] self.add_edges() + # EDGES MUST BE ADDED AFTER NODES ARE ADDED. + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self._build_reverse_adj() self._analyze_end_node_output() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. @property def nodes(self) -> list[dict[str, Any]]: @@ -102,6 +110,8 @@ class GraphBuilder: def _build_reverse_adj(self): for edge in self.edges: + if edge["source"] not in self.reachable_nodes: + continue self._reverse_adj[edge.get("target")].append({ "id": edge["source"], "branch": edge.get("label") }) @@ -137,10 +147,8 @@ class GraphBuilder: complete before this node activates. """ source_nodes = self._reverse_adj[target_node] - if not source_nodes: - if self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: - return tuple(), tuple() - raise RuntimeError(f"Node {target_node} is not reachable from the Start node") + if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: + return tuple(), tuple() branch_nodes = [] output_nodes = [] @@ -189,11 +197,10 @@ class GraphBuilder: """ # Collect all End nodes in the workflow - end_nodes = [node for node in self.nodes if node.get("type") == "end"] - logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes") + logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes") # Iterate through each End node to analyze its output - for end_node in end_nodes: + for end_node in self.end_nodes: end_node_id = end_node.get("id") config = end_node.get("config", {}) output = config.get("output") @@ -307,8 +314,6 @@ class GraphBuilder: # Record start and end node IDs if node_type in [NodeType.START, NodeType.CYCLE_START]: self.start_node_id = node_id - elif node_type == NodeType.END: - self.end_node_ids.append(node_id) # Create node instance (start and end nodes are also created) # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph @@ -497,9 +502,11 @@ class GraphBuilder: logger.debug(f"Added waiting edge: {sources} -> {target}") # Connect End nodes to the global END node - for end_node_id in self.end_node_ids: - self.graph.add_edge(end_node_id, END) - logger.debug(f"Added edge: {end_node_id} -> END") + for end_node in self.end_nodes: + end_node_id = end_node.get("id") + if end_node_id: + self.graph.add_edge(end_node_id, END) + logger.debug(f"Added edge: {end_node_id} -> END") return def build(self) -> CompiledStateGraph: diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index ceffc7dc..6685a49e 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -363,7 +363,7 @@ class StreamOutputCoordinator: logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}") if final_chunk: - logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}") + logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}") yield { "event": "message", "data": { diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 7e98efab..16782488 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -33,11 +33,11 @@ class IfElseNode(BaseNode): "right": expression.right if expression.input_type == ValueInputType.CONSTANT else self.get_variable(expression.right, variable_pool, strict=False), - "operator": expression.operator, + "operator": str(expression.operator), }) result.append({ "expressions": expressions, - "logical_operator": case.logical_operator, + "logical_operator": str(case.logical_operator), }) return { "cases": result diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 3b6e9036..fe4aea19 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -170,7 +170,7 @@ class WorkflowValidator: # 仅在发布时验证所有节点可达 # 6. 验证所有节点可达(从 start 节点出发) if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 - reachable = WorkflowValidator._get_reachable_nodes( + reachable = WorkflowValidator.get_reachable_nodes( start_nodes[0]["id"], edges ) @@ -194,7 +194,7 @@ class WorkflowValidator: return len(errors) == 0, errors @staticmethod - def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: + def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: """获取从 start 节点可达的所有节点 Args: diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index 04a778a1..56f34496 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -25,7 +25,7 @@ from app.repositories.workflow_repository import ( WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest, FileInput, FileType +from app.schemas import DraftRunRequest, FileInput from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService