perf(workflow): Defer graph construction from GraphBuilder.__init__ to build(), and build the subgraphs of cycle (loop/iteration) nodes lazily at execution time instead of at node construction

This commit is contained in:
Eternity
2026-03-25 14:11:55 +08:00
parent e86d679ae5
commit 45eef12842
4 changed files with 31 additions and 35 deletions

View File

@@ -61,21 +61,11 @@ class GraphBuilder:
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState) self.graph: StateGraph | None = None
self.add_nodes() self.reachable_nodes: set[str] | None = None
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.end_nodes: list[dict] = []
self.end_nodes = [ self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
node self._adj: dict[str, list[str]] | None = defaultdict(list)
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_reverse_adj()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output()
@property @property
def nodes(self) -> list[dict[str, Any]]: def nodes(self) -> list[dict[str, Any]]:
@@ -109,7 +99,7 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _build_reverse_adj(self): def _build_adj(self):
for edge in self.edges: for edge in self.edges:
if edge["source"] not in self.reachable_nodes: if edge["source"] not in self.reachable_nodes:
continue continue
@@ -513,6 +503,21 @@ class GraphBuilder:
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_adj()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
self.graph = self.graph.compile(checkpointer=checkpointer) return self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -88,9 +88,10 @@ class WorkflowExecutor:
self.workflow_config, self.workflow_config,
stream=stream, stream=stream,
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler( self.event_handler = EventStreamHandler(

View File

@@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT} outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution 3. Compile the graph for runtime execution
""" """
from app.core.workflow.engine.graph_builder import GraphBuilder from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool() self.child_variable_pool = VariablePool()
builder = GraphBuilder( builder = GraphBuilder(
{ {
@@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode):
subgraph=True, subgraph=True,
variable_pool=self.child_variable_pool variable_pool=self.child_variable_pool
) )
self.start_node_id = builder.start_node_id
self.graph = builder.build() self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises: Raises:
RuntimeError: If the node type is unsupported. RuntimeError: If the node type is unsupported.
""" """
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
yield { yield {
"__final__": True, "__final__": True,

View File

@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges) has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle: if has_cycle:
errors.append( errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}" f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
) )
# 8. 验证变量名 # 8. 验证变量名
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns: Returns:
(has_cycle, cycle_path): 是否有循环和循环路径 (has_cycle, cycle_path): 是否有循环和循环路径
""" """
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {} graph: dict[str, list[str]] = {}
for edge in edges: for edge in edges:
source = edge.get("source") source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error": if edge_type == "error":
continue continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target: if source and target:
if source not in graph: if source not in graph:
graph[source] = [] graph[source] = []