fix(workflow): fix streaming output issues caused by unreachable nodes
This commit is contained in:
@@ -20,6 +20,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import NodeFactory
|
from app.core.workflow.nodes import NodeFactory
|
||||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -48,8 +49,8 @@ class GraphBuilder:
|
|||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.subgraph = subgraph
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id: str | None = None
|
||||||
self.end_node_ids = []
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_activation_dep = lru_cache(
|
self._find_upstream_activation_dep = lru_cache(
|
||||||
@@ -62,11 +63,18 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
self._build_reverse_adj()
|
self._build_reverse_adj()
|
||||||
self._analyze_end_node_output()
|
self._analyze_end_node_output()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
@@ -102,6 +110,8 @@ class GraphBuilder:
|
|||||||
|
|
||||||
def _build_reverse_adj(self):
|
def _build_reverse_adj(self):
|
||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
|
if edge["source"] not in self.reachable_nodes:
|
||||||
|
continue
|
||||||
self._reverse_adj[edge.get("target")].append({
|
self._reverse_adj[edge.get("target")].append({
|
||||||
"id": edge["source"], "branch": edge.get("label")
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
})
|
})
|
||||||
@@ -137,10 +147,8 @@ class GraphBuilder:
|
|||||||
complete before this node activates.
|
complete before this node activates.
|
||||||
"""
|
"""
|
||||||
source_nodes = self._reverse_adj[target_node]
|
source_nodes = self._reverse_adj[target_node]
|
||||||
if not source_nodes:
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
if self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
return tuple(), tuple()
|
||||||
return tuple(), tuple()
|
|
||||||
raise RuntimeError(f"Node {target_node} is not reachable from the Start node")
|
|
||||||
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
output_nodes = []
|
output_nodes = []
|
||||||
@@ -189,11 +197,10 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Collect all End nodes in the workflow
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
|
||||||
|
|
||||||
# Iterate through each End node to analyze its output
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in self.end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
config = end_node.get("config", {})
|
config = end_node.get("config", {})
|
||||||
output = config.get("output")
|
output = config.get("output")
|
||||||
@@ -307,8 +314,6 @@ class GraphBuilder:
|
|||||||
# Record start and end node IDs
|
# Record start and end node IDs
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
self.start_node_id = node_id
|
self.start_node_id = node_id
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
@@ -497,9 +502,11 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node_id in self.end_node_ids:
|
for end_node in self.end_nodes:
|
||||||
self.graph.add_edge(end_node_id, END)
|
end_node_id = end_node.get("id")
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
if end_node_id:
|
||||||
|
self.graph.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
|||||||
@@ -363,7 +363,7 @@ class StreamOutputCoordinator:
|
|||||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||||
|
|
||||||
if final_chunk:
|
if final_chunk:
|
||||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
|
|||||||
@@ -33,11 +33,11 @@ class IfElseNode(BaseNode):
|
|||||||
"right": expression.right
|
"right": expression.right
|
||||||
if expression.input_type == ValueInputType.CONSTANT
|
if expression.input_type == ValueInputType.CONSTANT
|
||||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
"operator": expression.operator,
|
"operator": str(expression.operator),
|
||||||
})
|
})
|
||||||
result.append({
|
result.append({
|
||||||
"expressions": expressions,
|
"expressions": expressions,
|
||||||
"logical_operator": case.logical_operator,
|
"logical_operator": str(case.logical_operator),
|
||||||
})
|
})
|
||||||
return {
|
return {
|
||||||
"cases": result
|
"cases": result
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
|||||||
# 仅在发布时验证所有节点可达
|
# 仅在发布时验证所有节点可达
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
reachable = WorkflowValidator.get_reachable_nodes(
|
||||||
start_nodes[0]["id"],
|
start_nodes[0]["id"],
|
||||||
edges
|
edges
|
||||||
)
|
)
|
||||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
|||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
"""获取从 start 节点可达的所有节点
|
"""获取从 start 节点可达的所有节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
from app.schemas import DraftRunRequest, FileInput
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|||||||
Reference in New Issue
Block a user