From 45eef128427cca589830f9a0290699b7f5ac7c5f Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 25 Mar 2026 14:11:55 +0800 Subject: [PATCH] perf(workflow): Adjust graph construction timing, adopting a lazy strategy for constructing cyclic subgraphs within nodes --- api/app/core/workflow/engine/graph_builder.py | 41 +++++++++++-------- api/app/core/workflow/executor.py | 3 +- .../core/workflow/nodes/cycle_graph/node.py | 12 +++--- api/app/core/workflow/validator.py | 10 +---- 4 files changed, 31 insertions(+), 35 deletions(-) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 29f46765..d092db5b 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -61,21 +61,11 @@ class GraphBuilder: else: self.variable_pool = VariablePool() - self.graph = StateGraph(WorkflowState) - self.add_nodes() - self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) - self.end_nodes = [ - node - for node in self.nodes - if node.get("type") == "end" and node.get("id") in self.reachable_nodes - ] - self._reverse_adj: dict[str, list[dict]] = defaultdict(list) - self._adj: dict[str, list[str]] = defaultdict(list) - self._build_reverse_adj() - self.add_edges() - # EDGES MUST BE ADDED AFTER NODES ARE ADDED. - - self._analyze_end_node_output() + self.graph: StateGraph | None = None + self.reachable_nodes: set[str] | None = None + self.end_nodes: list[dict] = [] + self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list) + self._adj: dict[str, list[str]] | None = defaultdict(list) @property def nodes(self) -> list[dict[str, Any]]: @@ -109,7 +99,7 @@ class GraphBuilder: result[node[0]].append(node[1]) return result - def _build_reverse_adj(self): + def _build_adj(self): for edge in self.edges: if edge["source"] not in self.reachable_nodes: continue @@ -513,6 +503,21 @@ class GraphBuilder: return def build(self) -> CompiledStateGraph: + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) + self.end_nodes = [ + node + for node in self.nodes + if node.get("type") == "end" and node.get("id") in self.reachable_nodes + ] + self._reverse_adj: dict[str, list[dict]] = defaultdict(list) + self._adj: dict[str, list[str]] = defaultdict(list) + self._build_adj() + self.add_edges() + # EDGES MUST BE ADDED AFTER NODES ARE ADDED. + + self._analyze_end_node_output() checkpointer = InMemorySaver() - self.graph = self.graph.compile(checkpointer=checkpointer) - return self.graph + return self.graph.compile(checkpointer=checkpointer) + diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 1170d66c..0a820826 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -88,9 +88,10 @@ class WorkflowExecutor: self.workflow_config, stream=stream, ) + + self.graph = builder.build() self.start_node_id = builder.start_node_id self.variable_pool = builder.variable_pool - self.graph = builder.build() self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.event_handler = EventStreamHandler( diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 71e0dbdb..16939bac 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - - self.cycle_nodes = list() # Nodes belonging to this cycle - self.cycle_edges = list() # Edges connecting nodes within the cycle + self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.start_node_id = None # ID of the start node within the cycle self.graph: StateGraph | CompiledStateGraph | None = None self.child_variable_pool: VariablePool | None = None - self.build_graph() - self.iteration_flag = True def _output_types(self) -> dict[str, VariableType]: outputs = {"__child_state": VariableType.ARRAY_OBJECT} @@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode): 3. Compile the graph for runtime execution """ from app.core.workflow.engine.graph_builder import GraphBuilder - self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() + self.child_variable_pool = VariablePool() builder = GraphBuilder( { @@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode): subgraph=True, variable_pool=self.child_variable_pool ) - self.start_node_id = builder.start_node_id self.graph = builder.build() + self.start_node_id = builder.start_node_id self.child_variable_pool = builder.variable_pool async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: @@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode): Raises: RuntimeError: If the node type is unsupported. """ + self.build_graph() if self.node_type == NodeType.LOOP: return await LoopRuntime( start_id=self.start_node_id, @@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode): raise RuntimeError("Unknown cycle node type") async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): + self.build_graph() if self.node_type == NodeType.LOOP: yield { "__final__": True, diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 683ccb98..0ad74865 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -183,7 +183,7 @@ class WorkflowValidator: has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) if has_cycle: errors.append( - f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" + f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}" ) # 8. 验证变量名 @@ -229,10 +229,6 @@ class WorkflowValidator: Returns: (has_cycle, cycle_path): 是否有循环和循环路径 """ - # 排除 loop 类型的节点 - loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"} - - # 构建邻接表(排除 loop 节点的边和错误边) graph: dict[str, list[str]] = {} for edge in edges: source = edge.get("source") @@ -243,10 +239,6 @@ class WorkflowValidator: if edge_type == "error": continue - # 如果涉及 loop 节点,跳过 - if source in loop_nodes or target in loop_nodes: - continue - if source and target: if source not in graph: graph[source] = []