perf(workflow): Adjust graph construction timing, adopting a lazy strategy for constructing cyclic subgraphs within nodes
This commit is contained in:
@@ -61,21 +61,11 @@ class GraphBuilder:
|
||||
else:
|
||||
self.variable_pool = VariablePool()
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||
self._build_reverse_adj()
|
||||
self.add_edges()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self._analyze_end_node_output()
|
||||
self.graph: StateGraph | None = None
|
||||
self.reachable_nodes: set[str] | None = None
|
||||
self.end_nodes: list[dict] = []
|
||||
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
|
||||
self._adj: dict[str, list[str]] | None = defaultdict(list)
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
@@ -109,7 +99,7 @@ class GraphBuilder:
|
||||
result[node[0]].append(node[1])
|
||||
return result
|
||||
|
||||
def _build_reverse_adj(self):
|
||||
def _build_adj(self):
|
||||
for edge in self.edges:
|
||||
if edge["source"] not in self.reachable_nodes:
|
||||
continue
|
||||
@@ -513,6 +503,21 @@ class GraphBuilder:
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||
self._build_adj()
|
||||
self.add_edges()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self._analyze_end_node_output()
|
||||
checkpointer = InMemorySaver()
|
||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
||||
return self.graph
|
||||
return self.graph.compile(checkpointer=checkpointer)
|
||||
|
||||
|
||||
@@ -88,9 +88,10 @@ class WorkflowExecutor:
|
||||
self.workflow_config,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.variable_pool = builder.variable_pool
|
||||
self.graph = builder.build()
|
||||
|
||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||
self.event_handler = EventStreamHandler(
|
||||
|
||||
@@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode):
|
||||
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
|
||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
self.start_node_id = None # ID of the start node within the cycle
|
||||
|
||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||
self.child_variable_pool: VariablePool | None = None
|
||||
self.build_graph()
|
||||
self.iteration_flag = True
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||
@@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode):
|
||||
3. Compile the graph for runtime execution
|
||||
"""
|
||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||
|
||||
self.child_variable_pool = VariablePool()
|
||||
builder = GraphBuilder(
|
||||
{
|
||||
@@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode):
|
||||
subgraph=True,
|
||||
variable_pool=self.child_variable_pool
|
||||
)
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.graph = builder.build()
|
||||
self.start_node_id = builder.start_node_id
|
||||
self.child_variable_pool = builder.variable_pool
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
||||
Raises:
|
||||
RuntimeError: If the node type is unsupported.
|
||||
"""
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
return await LoopRuntime(
|
||||
start_id=self.start_node_id,
|
||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
||||
raise RuntimeError("Unknown cycle node type")
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
self.build_graph()
|
||||
if self.node_type == NodeType.LOOP:
|
||||
yield {
|
||||
"__final__": True,
|
||||
|
||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||
if has_cycle:
|
||||
errors.append(
|
||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||
)
|
||||
|
||||
# 8. 验证变量名
|
||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
||||
Returns:
|
||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||
"""
|
||||
# 排除 loop 类型的节点
|
||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
||||
|
||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
||||
graph: dict[str, list[str]] = {}
|
||||
for edge in edges:
|
||||
source = edge.get("source")
|
||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
||||
if edge_type == "error":
|
||||
continue
|
||||
|
||||
# 如果涉及 loop 节点,跳过
|
||||
if source in loop_nodes or target in loop_nodes:
|
||||
continue
|
||||
|
||||
if source and target:
|
||||
if source not in graph:
|
||||
graph[source] = []
|
||||
|
||||
Reference in New Issue
Block a user