fix(workflow): fix streaming output issues caused by unreachable nodes

This commit is contained in:
Eternity
2026-03-20 13:58:20 +08:00
parent 7c6e48b04e
commit 06de54ebfd
5 changed files with 28 additions and 21 deletions

View File

@@ -20,6 +20,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import NodeFactory from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.utils.expression_evaluator import evaluate_condition from app.core.workflow.utils.expression_evaluator import evaluate_condition
from app.core.workflow.validator import WorkflowValidator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -48,8 +49,8 @@ class GraphBuilder:
self.stream = stream self.stream = stream
self.subgraph = subgraph self.subgraph = subgraph
self.start_node_id = None self.start_node_id: str | None = None
self.end_node_ids = []
self.node_map = {node["id"]: node for node in self.nodes} self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache( self._find_upstream_activation_dep = lru_cache(
@@ -62,11 +63,18 @@ class GraphBuilder:
self.graph = StateGraph(WorkflowState) self.graph = StateGraph(WorkflowState)
self.add_nodes() self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges() self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list) self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj() self._build_reverse_adj()
self._analyze_end_node_output() self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property @property
def nodes(self) -> list[dict[str, Any]]: def nodes(self) -> list[dict[str, Any]]:
@@ -102,6 +110,8 @@ class GraphBuilder:
def _build_reverse_adj(self): def _build_reverse_adj(self):
for edge in self.edges: for edge in self.edges:
if edge["source"] not in self.reachable_nodes:
continue
self._reverse_adj[edge.get("target")].append({ self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label") "id": edge["source"], "branch": edge.get("label")
}) })
@@ -137,10 +147,8 @@ class GraphBuilder:
complete before this node activates. complete before this node activates.
""" """
source_nodes = self._reverse_adj[target_node] source_nodes = self._reverse_adj[target_node]
if not source_nodes: if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
if self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: return tuple(), tuple()
return tuple(), tuple()
raise RuntimeError(f"Node {target_node} is not reachable from the Start node")
branch_nodes = [] branch_nodes = []
output_nodes = [] output_nodes = []
@@ -189,11 +197,10 @@ class GraphBuilder:
""" """
# Collect all End nodes in the workflow # Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"] logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
# Iterate through each End node to analyze its output # Iterate through each End node to analyze its output
for end_node in end_nodes: for end_node in self.end_nodes:
end_node_id = end_node.get("id") end_node_id = end_node.get("id")
config = end_node.get("config", {}) config = end_node.get("config", {})
output = config.get("output") output = config.get("output")
@@ -307,8 +314,6 @@ class GraphBuilder:
# Record start and end node IDs # Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]: if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
@@ -497,9 +502,11 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for end_node_id in self.end_node_ids: for end_node in self.end_nodes:
self.graph.add_edge(end_node_id, END) end_node_id = end_node.get("id")
logger.debug(f"Added edge: {end_node_id} -> END") if end_node_id:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:

View File

@@ -363,7 +363,7 @@ class StreamOutputCoordinator:
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}") logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
if final_chunk: if final_chunk:
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}") logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
yield { yield {
"event": "message", "event": "message",
"data": { "data": {

View File

@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
"right": expression.right "right": expression.right
if expression.input_type == ValueInputType.CONSTANT if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False), else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator, "operator": str(expression.operator),
}) })
result.append({ result.append({
"expressions": expressions, "expressions": expressions,
"logical_operator": case.logical_operator, "logical_operator": str(case.logical_operator),
}) })
return { return {
"cases": result "cases": result

View File

@@ -170,7 +170,7 @@ class WorkflowValidator:
# 仅在发布时验证所有节点可达 # 仅在发布时验证所有节点可达
# 6. 验证所有节点可达(从 start 节点出发) # 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性 if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes( reachable = WorkflowValidator.get_reachable_nodes(
start_nodes[0]["id"], start_nodes[0]["id"],
edges edges
) )
@@ -194,7 +194,7 @@ class WorkflowValidator:
return len(errors) == 0, errors return len(errors) == 0, errors
@staticmethod @staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]: def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点 """获取从 start 节点可达的所有节点
Args: Args:

View File

@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository, WorkflowExecutionRepository,
WorkflowNodeExecutionRepository WorkflowNodeExecutionRepository
) )
from app.schemas import DraftRunRequest, FileInput, FileType from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService