fix(workflow): fix streaming output issues caused by unreachable nodes

This commit is contained in:
Eternity
2026-03-20 13:58:20 +08:00
parent 7c6e48b04e
commit 06de54ebfd
5 changed files with 28 additions and 21 deletions

View File

@@ -20,6 +20,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.utils.expression_evaluator import evaluate_condition
from app.core.workflow.validator import WorkflowValidator
logger = logging.getLogger(__name__)
@@ -48,8 +49,8 @@ class GraphBuilder:
self.stream = stream
self.subgraph = subgraph
self.start_node_id = None
self.end_node_ids = []
self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache(
@@ -62,11 +63,18 @@ class GraphBuilder:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
@property
def nodes(self) -> list[dict[str, Any]]:
@@ -102,6 +110,8 @@ class GraphBuilder:
def _build_reverse_adj(self):
for edge in self.edges:
if edge["source"] not in self.reachable_nodes:
continue
self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label")
})
@@ -137,10 +147,8 @@ class GraphBuilder:
complete before this node activates.
"""
source_nodes = self._reverse_adj[target_node]
if not source_nodes:
if self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return tuple(), tuple()
raise RuntimeError(f"Node {target_node} is not reachable from the Start node")
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return tuple(), tuple()
branch_nodes = []
output_nodes = []
@@ -189,11 +197,10 @@ class GraphBuilder:
"""
# Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
# Iterate through each End node to analyze its output
for end_node in end_nodes:
for end_node in self.end_nodes:
end_node_id = end_node.get("id")
config = end_node.get("config", {})
output = config.get("output")
@@ -307,8 +314,6 @@ class GraphBuilder:
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
@@ -497,9 +502,11 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
for end_node in self.end_nodes:
end_node_id = end_node.get("id")
if end_node_id:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return
def build(self) -> CompiledStateGraph:

View File

@@ -363,7 +363,7 @@ class StreamOutputCoordinator:
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
if final_chunk:
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
yield {
"event": "message",
"data": {

View File

@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
"right": expression.right
if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator,
"operator": str(expression.operator),
})
result.append({
"expressions": expressions,
"logical_operator": case.logical_operator,
"logical_operator": str(case.logical_operator),
})
return {
"cases": result

View File

@@ -170,7 +170,7 @@ class WorkflowValidator:
# 仅在发布时验证所有节点可达
# 6. 验证所有节点可达(从 start 节点出发)
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
reachable = WorkflowValidator._get_reachable_nodes(
reachable = WorkflowValidator.get_reachable_nodes(
start_nodes[0]["id"],
edges
)
@@ -194,7 +194,7 @@ class WorkflowValidator:
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:

View File

@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest, FileInput, FileType
from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService