perf(workflow): Adjust graph construction timing, adopting a lazy strategy for constructing cyclic subgraphs within nodes
This commit is contained in:
@@ -61,21 +61,11 @@ class GraphBuilder:
|
|||||||
else:
|
else:
|
||||||
self.variable_pool = VariablePool()
|
self.variable_pool = VariablePool()
|
||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph: StateGraph | None = None
|
||||||
self.add_nodes()
|
self.reachable_nodes: set[str] | None = None
|
||||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
self.end_nodes: list[dict] = []
|
||||||
self.end_nodes = [
|
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
|
||||||
node
|
self._adj: dict[str, list[str]] | None = defaultdict(list)
|
||||||
for node in self.nodes
|
|
||||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
|
||||||
]
|
|
||||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
|
||||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
|
||||||
self._build_reverse_adj()
|
|
||||||
self.add_edges()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
|
||||||
|
|
||||||
self._analyze_end_node_output()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
@@ -109,7 +99,7 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _build_reverse_adj(self):
|
def _build_adj(self):
|
||||||
for edge in self.edges:
|
for edge in self.edges:
|
||||||
if edge["source"] not in self.reachable_nodes:
|
if edge["source"] not in self.reachable_nodes:
|
||||||
continue
|
continue
|
||||||
@@ -513,6 +503,21 @@ class GraphBuilder:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||||
|
self._build_adj()
|
||||||
|
self.add_edges()
|
||||||
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
|
self._analyze_end_node_output()
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
self.graph = self.graph.compile(checkpointer=checkpointer)
|
return self.graph.compile(checkpointer=checkpointer)
|
||||||
return self.graph
|
|
||||||
|
|||||||
@@ -88,9 +88,10 @@ class WorkflowExecutor:
|
|||||||
self.workflow_config,
|
self.workflow_config,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.graph = builder.build()
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
self.variable_pool = builder.variable_pool
|
self.variable_pool = builder.variable_pool
|
||||||
self.graph = builder.build()
|
|
||||||
|
|
||||||
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
|
||||||
self.event_handler = EventStreamHandler(
|
self.event_handler = EventStreamHandler(
|
||||||
|
|||||||
@@ -32,15 +32,11 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.cycle_nodes = list() # Nodes belonging to this cycle
|
|
||||||
self.cycle_edges = list() # Edges connecting nodes within the cycle
|
|
||||||
self.start_node_id = None # ID of the start node within the cycle
|
self.start_node_id = None # ID of the start node within the cycle
|
||||||
|
|
||||||
self.graph: StateGraph | CompiledStateGraph | None = None
|
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||||
self.child_variable_pool: VariablePool | None = None
|
self.child_variable_pool: VariablePool | None = None
|
||||||
self.build_graph()
|
|
||||||
self.iteration_flag = True
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
outputs = {"__child_state": VariableType.ARRAY_OBJECT}
|
||||||
@@ -137,7 +133,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
3. Compile the graph for runtime execution
|
3. Compile the graph for runtime execution
|
||||||
"""
|
"""
|
||||||
from app.core.workflow.engine.graph_builder import GraphBuilder
|
from app.core.workflow.engine.graph_builder import GraphBuilder
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
|
||||||
self.child_variable_pool = VariablePool()
|
self.child_variable_pool = VariablePool()
|
||||||
builder = GraphBuilder(
|
builder = GraphBuilder(
|
||||||
{
|
{
|
||||||
@@ -147,8 +143,8 @@ class CycleGraphNode(BaseNode):
|
|||||||
subgraph=True,
|
subgraph=True,
|
||||||
variable_pool=self.child_variable_pool
|
variable_pool=self.child_variable_pool
|
||||||
)
|
)
|
||||||
self.start_node_id = builder.start_node_id
|
|
||||||
self.graph = builder.build()
|
self.graph = builder.build()
|
||||||
|
self.start_node_id = builder.start_node_id
|
||||||
self.child_variable_pool = builder.variable_pool
|
self.child_variable_pool = builder.variable_pool
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -169,6 +165,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If the node type is unsupported.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
start_id=self.start_node_id,
|
start_id=self.start_node_id,
|
||||||
@@ -194,6 +191,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
raise RuntimeError("Unknown cycle node type")
|
raise RuntimeError("Unknown cycle node type")
|
||||||
|
|
||||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||||
|
self.build_graph()
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
yield {
|
yield {
|
||||||
"__final__": True,
|
"__final__": True,
|
||||||
|
|||||||
@@ -183,7 +183,7 @@ class WorkflowValidator:
|
|||||||
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
has_cycle, cycle_path = WorkflowValidator._has_cycle(nodes, edges)
|
||||||
if has_cycle:
|
if has_cycle:
|
||||||
errors.append(
|
errors.append(
|
||||||
f"工作流存在循环依赖(请使用 loop 节点实现循环): {' -> '.join(cycle_path)}"
|
f"工作流存在循环依赖(请使用 loop/iteration 节点实现循环): {' -> '.join(cycle_path)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 8. 验证变量名
|
# 8. 验证变量名
|
||||||
@@ -229,10 +229,6 @@ class WorkflowValidator:
|
|||||||
Returns:
|
Returns:
|
||||||
(has_cycle, cycle_path): 是否有循环和循环路径
|
(has_cycle, cycle_path): 是否有循环和循环路径
|
||||||
"""
|
"""
|
||||||
# 排除 loop 类型的节点
|
|
||||||
loop_nodes = {n["id"] for n in nodes if n.get("type") == "loop"}
|
|
||||||
|
|
||||||
# 构建邻接表(排除 loop 节点的边和错误边)
|
|
||||||
graph: dict[str, list[str]] = {}
|
graph: dict[str, list[str]] = {}
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
source = edge.get("source")
|
source = edge.get("source")
|
||||||
@@ -243,10 +239,6 @@ class WorkflowValidator:
|
|||||||
if edge_type == "error":
|
if edge_type == "error":
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果涉及 loop 节点,跳过
|
|
||||||
if source in loop_nodes or target in loop_nodes:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if source and target:
|
if source and target:
|
||||||
if source not in graph:
|
if source not in graph:
|
||||||
graph[source] = []
|
graph[source] = []
|
||||||
|
|||||||
Reference in New Issue
Block a user