perf(workflow): Defer graph construction from GraphBuilder.__init__ to build(), and construct the cyclic subgraphs inside cycle nodes lazily at execution time instead of at node construction

This commit is contained in:
Eternity
2026-03-25 14:11:55 +08:00
parent e86d679ae5
commit 45eef12842
4 changed files with 31 additions and 35 deletions

View File

@@ -61,21 +61,11 @@ class GraphBuilder:
else:
self.variable_pool = VariablePool()
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_reverse_adj()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output()
self.graph: StateGraph | None = None
self.reachable_nodes: set[str] | None = None
self.end_nodes: list[dict] = []
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
self._adj: dict[str, list[str]] | None = defaultdict(list)
@property
def nodes(self) -> list[dict[str, Any]]:
@@ -109,7 +99,7 @@ class GraphBuilder:
result[node[0]].append(node[1])
return result
def _build_reverse_adj(self):
def _build_adj(self):
for edge in self.edges:
if edge["source"] not in self.reachable_nodes:
continue
@@ -513,6 +503,21 @@ class GraphBuilder:
return
def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_adj()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output()
checkpointer = InMemorySaver()
self.graph = self.graph.compile(checkpointer=checkpointer)
return self.graph
return self.graph.compile(checkpointer=checkpointer)

View File

@@ -88,9 +88,10 @@ class WorkflowExecutor:
self.workflow_config,
stream=stream,
)
self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler(

View File

@@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution
"""
from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool()
builder = GraphBuilder(
{
@@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode):
subgraph=True,
variable_pool=self.child_variable_pool
)
self.start_node_id = builder.start_node_id
self.graph = builder.build()
self.start_node_id = builder.start_node_id
self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
Raises:
RuntimeError: If the node type is unsupported.
"""
self.build_graph()
if self.node_type == NodeType.LOOP:
return await LoopRuntime(
start_id=self.start_node_id,
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP:
yield {
"__final__": True,

View File

@@ -183,7 +183,7 @@ class WorkflowValidator:
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
if has_cycle:
errors.append(
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
)
# 8. 验证变量名
@@ -229,10 +229,6 @@ class WorkflowValidator:
Returns:
(has_cycle, cycle_path): 是否有循环和循环路径
"""
# 排除 loop 类型的节点
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
# 构建邻接表(排除 loop 节点的边和错误边)
graph: dict[str, list[str]] = {}
for edge in edges:
source = edge.get("source")
@@ -243,10 +239,6 @@ class WorkflowValidator:
if edge_type == "error":
continue
# 如果涉及 loop 节点,跳过
if source in loop_nodes or target in loop_nodes:
continue
if source and target:
if source not in graph:
graph[source] = []