refactor(workflow): refactor graph construction to support subgraph building

This commit is contained in:
mengyonghao
2026-01-05 11:06:21 +08:00
parent 5957eb9c1a
commit 4685fd14ad
5 changed files with 321 additions and 281 deletions

View File

@@ -10,11 +10,10 @@ import logging
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.graph_builder import GraphBuilder
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.enums import NodeType
# from app.core.tools.registry import ToolRegistry
@@ -191,159 +190,10 @@ class WorkflowExecutor:
编译后的状态图
"""
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
# 分析 End 节点的前缀配置和相邻且被引用的节点
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
# 1. 创建状态图
workflow = StateGraph(WorkflowState)
# 2. 添加所有节点(包括 start 和 end
start_node_id = None
end_node_ids = []
for node in self.nodes:
node_type = node.get("type")
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
continue
# 记录 start 和 end 节点 ID
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 如果是流式模式,且节点有 End 前缀配置,注入配置
if stream and node_id in end_prefixes:
# 将 End 前缀配置注入到节点实例
node_instance._end_node_prefix = end_prefixes[node_id]
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
# 如果是流式模式,标记节点是否与 End 相邻且被引用
if stream:
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
if node_id in adjacent_and_referenced:
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
if stream:
# 流式模式:创建 async generator 函数
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
def make_stream_func(inst):
async def node_func(state: WorkflowState):
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
async for item in inst.run_stream(state):
yield item
return node_func
workflow.add_node(node_id, make_stream_func(node_instance))
else:
# 非流式模式:创建 async function
def make_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_func(node_instance))
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
# 3. 添加边
# 从 START 连接到 start 节点
if start_node_id:
workflow.add_edge(START, start_node_id)
logger.debug(f"添加边: START -> {start_node_id}")
for edge in self.workflow_config.get("edges", []):
source = edge.get("source")
target = edge.get("target")
edge_type = edge.get("type")
condition = edge.get("condition")
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start
if source == start_node_id:
# 但要连接 start 到下一个节点
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
continue
# # 处理到 end 节点的边
# if target in end_node_ids:
# # 连接到 end 节点
# workflow.add_edge(source, target)
# logger.debug(f"添加边: {source} -> {target}")
# continue
# 跳过错误边(在节点内部处理)
if edge_type == "error":
continue
if condition:
# 条件边
def make_router(cond, tgt):
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
def router_fn(state: WorkflowState):
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END
# 动态修改函数名,避免重复
router_fn.__name__ = f"router_{tgt}"
return router_fn
router_fn = make_router(condition, target)
workflow.add_conditional_edges(source, router_fn)
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
else:
# 普通边
workflow.add_edge(source, target)
logger.debug(f"添加边: {source} -> {target}")
# 从 end 节点连接到 END
for end_node_id in end_node_ids:
workflow.add_edge(end_node_id, END)
logger.debug(f"添加边: {end_node_id} -> END")
# 4. 编译图
graph = workflow.compile()
graph = GraphBuilder(
self.workflow_config,
stream=stream,
).build()
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
return graph

View File

@@ -0,0 +1,253 @@
import logging
import re
import uuid
from typing import Any

from langgraph.graph import START, END
from langgraph.graph.state import CompiledStateGraph, StateGraph

from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)

# Matches template references such as ``{{node_id.field}}`` or
# ``{{ node_id.field }}`` (inner whitespace allowed); group 1 is the node id.
_NODE_REF_PATTERN = re.compile(r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}')


# TODO: support decomposing subgraphs
class GraphBuilder:
    """Build a compiled LangGraph state graph from a workflow configuration.

    The configuration is a dict with a ``"nodes"`` list and an ``"edges"``
    list. The same builder is used for the top-level workflow and, with
    ``subgraph=True``, for the internal graph of a cycle (loop/iteration) node.
    """

    def __init__(
        self,
        workflow_config: dict[str, Any],
        stream: bool = False,
        subgraph: bool = False,
    ):
        """Store the configuration and build options.

        Args:
            workflow_config: Dict holding the ``"nodes"`` and ``"edges"`` lists.
            stream: When True, nodes are wrapped as async generators and the
                End-node prefix analysis is applied.
            subgraph: When True, nodes flagged with ``cycle`` are kept instead
                of being skipped.
        """
        self.workflow_config = workflow_config
        self.stream = stream
        self.subgraph = subgraph

        self.start_node_id = None
        self.end_node_ids = []

        self.graph: StateGraph | CompiledStateGraph | None = None

    @property
    def nodes(self) -> list[dict[str, Any]]:
        """Node definitions of the current workflow configuration."""
        return self.workflow_config.get("nodes", [])

    @property
    def edges(self) -> list[dict[str, Any]]:
        """Edge definitions of the current workflow configuration."""
        return self.workflow_config.get("edges", [])

    def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
        """Analyze the output templates of End nodes.

        For every End node, look for the first template reference that points
        at one of its direct upstream nodes and take the template text that
        precedes that reference as the prefix for that upstream node.

        Returns:
            A tuple ``(prefixes, adjacent_and_referenced)`` where ``prefixes``
            maps an upstream node id to its non-empty End-node prefix, and
            ``adjacent_and_referenced`` is the set of node ids that are both
            directly connected to an End node and referenced by its template.
        """
        prefixes: dict[str, str] = {}
        # Nodes that are adjacent to an End node and referenced by its template
        adjacent_and_referenced: set[str] = set()

        end_nodes = [node for node in self.nodes if node.get("type") == "end"]
        logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")

        for end_node in end_nodes:
            end_node_id = end_node.get("id")
            output_template = end_node.get("config", {}).get("output")

            logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")

            if not output_template:
                continue

            # Every ``{{node_id.field}}`` reference in the template, in order
            matches = list(_NODE_REF_PATTERN.finditer(output_template))
            logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")

            # Sources of all edges that point directly at this End node
            direct_upstream_nodes = [
                edge.get("source")
                for edge in self.edges
                if edge.get("target") == end_node_id
            ]
            logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")

            for match in matches:
                referenced_node_id = match.group(1)
                logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")

                if referenced_node_id not in direct_upstream_nodes:
                    continue

                # Reference to a direct upstream node: everything before it
                # in the template is the prefix.
                prefix = output_template[:match.start()]
                logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")

                adjacent_and_referenced.add(referenced_node_id)

                if prefix:
                    prefixes[referenced_node_id] = prefix
                    logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")

                # Only the first reference to a direct upstream node counts
                break

        logger.info(f"[前缀分析] 最终配置: {prefixes}")
        logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
        return prefixes, adjacent_and_referenced

    @staticmethod
    def _make_stream_func(inst):
        """Wrap ``inst.run_stream`` in an async generator bound to ``inst``."""
        async def node_func(state: WorkflowState):
            async for item in inst.run_stream(state):
                yield item
        return node_func

    @staticmethod
    def _make_func(inst):
        """Wrap ``inst.run`` in an async function bound to ``inst``."""
        async def node_func(state: WorkflowState):
            return await inst.run(state)
        return node_func

    @staticmethod
    def _make_router(cond, tgt):
        """Create a uniquely named router that returns ``tgt`` when ``cond`` holds, else END."""
        def router_fn(state: WorkflowState):
            if evaluate_condition(
                cond,
                state.get("variables", {}),
                state.get("runtime_vars", {}),
                {
                    "execution_id": state.get("execution_id"),
                    "workspace_id": state.get("workspace_id"),
                    "user_id": state.get("user_id")
                }
            ):
                return tgt
            return END

        # Rename the function so that each branch gets a distinct name,
        # even when two conditional edges share the same target.
        router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
        return router_fn

    def add_nodes(self):
        """Instantiate every node of the configuration and register it on ``self.graph``.

        Side effects: records ``start_node_id`` / ``end_node_ids`` and writes
        a ``condition`` expression onto each outgoing edge of IF_ELSE and
        HTTP_REQUEST nodes.
        """
        if self.stream:
            end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes()
        else:
            end_prefixes, adjacent_and_referenced = {}, set()

        # Iterate over a snapshot: per the NOTE below, creating a loop node
        # prunes its subgraph from the configuration, and mutating a list
        # while iterating it would skip elements.
        # NOTE(review): whether the pruning is in-place is not visible here.
        for node in list(self.nodes):
            node_type = node.get("type")
            node_id = node.get("id")

            # Nodes inside a cycle subgraph are built by CycleGraphNode,
            # unless this builder is itself building that subgraph.
            if node.get("cycle") and not self.subgraph:
                continue

            # Remember the start and end node ids for add_edges()
            if node_type in (NodeType.START, NodeType.CYCLE_START):
                self.start_node_id = node_id
            elif node_type == NodeType.END:
                self.end_node_ids.append(node_id)

            # Start and end nodes are instantiated like any other node.
            # NOTE: Loop node creation automatically removes the nodes and
            # edges of the subgraph from the current graph.
            node_instance = NodeFactory.create_node(node, self.workflow_config)

            if node_type in (NodeType.IF_ELSE, NodeType.HTTP_REQUEST):
                # Attach a condition to every outgoing edge of a branching
                # node. The node's output is expected to equal the edge label,
                # e.g. if node.123.output == 'CASE1' the 'CASE1' edge is taken.
                for edge in self.edges:
                    if edge.get("source") == node_id:
                        edge['condition'] = f"node.{node_id}.output == '{edge['label']}'"

            if not node_instance:
                continue

            if self.stream:
                # Inject the End-node prefix for nodes that have one
                if node_id in end_prefixes:
                    node_instance._end_node_prefix = end_prefixes[node_id]
                    logger.info(f"为节点 {node_id} 注入 End 前缀配置")

                # Flag whether the node is adjacent to an End node and
                # referenced by its template
                node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
                if node_id in adjacent_and_referenced:
                    logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")

                # Streaming mode: the node is an async generator
                node_func = self._make_stream_func(node_instance)
            else:
                # Non-streaming mode: the node is a plain async function
                node_func = self._make_func(node_instance)

            self.graph.add_node(node_id, node_func)
            logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")

    def add_edges(self):
        """Register all edges on ``self.graph``; must run after ``add_nodes()``."""
        if self.start_node_id:
            self.graph.add_edge(START, self.start_node_id)
            logger.debug(f"添加边: START -> {self.start_node_id}")

        for edge in self.edges:
            source = edge.get("source")
            target = edge.get("target")

            # Edges leaving the start node are always plain edges
            if source == self.start_node_id:
                self.graph.add_edge(source, target)
                logger.debug(f"添加边: {source} -> {target}")
                continue

            # Error edges are handled inside the nodes themselves
            if edge.get("type") == "error":
                continue

            condition = edge.get("condition")
            if condition:
                self.graph.add_conditional_edges(source, self._make_router(condition, target))
                logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
            else:
                self.graph.add_edge(source, target)
                logger.debug(f"添加边: {source} -> {target}")

        # Connect every end node to END
        for end_node_id in self.end_node_ids:
            self.graph.add_edge(end_node_id, END)
            logger.debug(f"添加边: {end_node_id} -> END")

    def build(self) -> "CompiledStateGraph":
        """Create a fresh StateGraph, add nodes then edges, and compile it.

        Bookkeeping from a previous call is reset so the builder can be
        reused without accumulating start/end node ids.
        """
        self.start_node_id = None
        self.end_node_ids = []

        self.graph = StateGraph(WorkflowState)
        self.add_nodes()
        self.add_edges()  # edges must be added after the nodes
        return self.graph.compile()

View File

@@ -1,7 +1,9 @@
from typing import Any
from pydantic import Field, BaseModel
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class CycleVariable(BaseNodeConfig):
@@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig):
...,
description="Name of the loop variable"
)
type: VariableType = Field(
...,
description="Data type of the loop variable"
)
value: str = Field(
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
value: Any = Field(
...,
description="Initial or current value of the loop variable"
)
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
operator: ComparisonOperator = Field(
...,
description="Operator used to compare the left and right operands"
)
@@ -30,11 +39,16 @@ class ConditionDetail(BaseModel):
description="Left-hand operand of the comparison expression"
)
right: str = Field(
right: Any = Field(
...,
description="Right-hand operand of the comparison expression"
)
input_type: ValueInputType = Field(
...,
description="Input type of the loop variable"
)
class ConditionsConfig(BaseModel):
"""Configuration for loop condition evaluation"""

View File

@@ -1,10 +1,9 @@
import logging
from typing import Any
from langgraph.graph import StateGraph, START, END
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
@@ -17,12 +16,18 @@ logger = logging.getLogger(__name__)
class CycleGraphNode(BaseNode):
"""
Node representing a cycle (loop) subgraph within the workflow.
Node representing a cyclic (loop or iteration) subgraph within the workflow.
This node manages internal loop/iteration nodes, builds a subgraph
for execution, handles conditional routing, and executes loop
or iteration logic based on node type.
A CycleGraphNode is a structural node that:
- Extracts a group of nodes marked as belonging to the same cycle
- Builds an isolated internal StateGraph (subgraph)
- Delegates runtime execution to LoopRuntime or IterationRuntime
depending on the node type
This node itself does NOT execute business logic directly.
It acts as a container and execution controller for a subgraph.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
@@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode):
def pure_cycle_graph(self) -> tuple[list, list]:
"""
Extract cycle nodes and internal edges from the workflow configuration,
removing them from the global workflow.
Extract cycle-scoped nodes and internal edges from the workflow configuration.
Raises:
ValueError: If cycle nodes are connected to external nodes improperly.
This method:
- Identifies all nodes marked with `cycle == self.node_id`
- Collects edges that fully connect cycle nodes
- Removes extracted nodes and edges from the global workflow configuration
Safety check:
- Raises an error if a cycle node is connected to an external node
Returns:
Tuple containing:
- cycle_nodes: List of removed nodes
- cycle_edges: List of removed edges
tuple[list, list]:
- cycle_nodes: Nodes belonging to this cycle
- cycle_edges: Edges connecting nodes within the cycle
Raises:
ValueError: If a cycle node is improperly connected to an external node.
"""
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
@@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode):
return cycle_nodes, cycle_edges
def create_node(self):
    """
    Instantiate the nodes of the cycle subgraph and add them to ``self.graph``.

    The CYCLE_START node is only recorded as ``self.start_node_id`` and is
    not added to the graph. END nodes are recorded in ``self.end_node_ids``
    and added like any other node.

    For IF_ELSE and HTTP_REQUEST nodes, every outgoing edge in
    ``self.cycle_edges`` receives a ``condition`` expression comparing the
    node's output with the edge label. All outgoing edges are labelled, which
    matches ``GraphBuilder.add_nodes``. The earlier index-based loop raised
    IndexError when a node had fewer outgoing edges than expressions.
    """
    # Function-scope import kept from the original
    # (presumably avoids a circular import; TODO confirm).
    from app.core.workflow.nodes import NodeFactory

    def make_func(inst):
        # Bind the instance per node so each wrapper runs its own node
        async def node_func(state: WorkflowState):
            return await inst.run(state)
        return node_func

    for node in self.cycle_nodes:
        node_type = node.get("type")
        node_id = node.get("id")

        if node_type == NodeType.CYCLE_START:
            self.start_node_id = node_id
            continue
        if node_type == NodeType.END:
            self.end_node_ids.append(node_id)

        node_instance = NodeFactory.create_node(node, self.workflow_config)

        if node_type in (NodeType.IF_ELSE, NodeType.HTTP_REQUEST):
            # The node's output is expected to equal the label of the edge
            # to follow, e.g. node.123.output == 'CASE1' takes the 'CASE1' edge.
            for edge in self.cycle_edges:
                if edge.get("source") == node_id:
                    edge['condition'] = f"node.{node_id}.output == '{edge['label']}'"

        self.graph.add_node(node_id, make_func(node_instance))
def create_edge(self):
    """
    Connect the nodes of the cycle subgraph on ``self.graph``.

    - Edges leaving the cycle start node are attached to the global START.
    - Edges carrying a ``condition`` become conditional edges whose router
      returns the edge target when the condition holds and END otherwise.
    - All remaining edges are plain edges.
    - Every recorded end node is connected to the global END.

    Each router gets a unique ``__name__``. LangGraph keys conditional
    branches by the router's name, so several conditional edges from one
    source must not share it.
    """
    def make_router(cond, tgt, name):
        # Bind condition and target per edge and give the router its own name
        def router(state: WorkflowState):
            """Conditional routing function."""
            if evaluate_condition(
                cond,
                state.get("variables", {}),
                state.get("node_outputs", {}),
                {
                    "execution_id": state.get("execution_id"),
                    "workspace_id": state.get("workspace_id"),
                    "user_id": state.get("user_id")
                }
            ):
                return tgt
            return END  # condition not met: finish

        router.__name__ = name
        return router

    for idx, edge in enumerate(self.cycle_edges):
        source = edge.get("source")
        target = edge.get("target")
        condition = edge.get("condition")

        # The start node itself is not part of the graph, so its outgoing
        # edges originate from START instead.
        if source == self.start_node_id:
            self.graph.add_edge(START, target)
            logger.debug(f"添加边: {source} -> {target}")
            continue

        if condition:
            # The edge index keeps names unique even for identical targets
            router = make_router(condition, target, f"router_{idx}_{target}")
            self.graph.add_conditional_edges(source, router)
            logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
        else:
            self.graph.add_edge(source, target)
            logger.debug(f"添加边: {source} -> {target}")

    # Connect every end node to END
    for end_node_id in self.end_node_ids:
        self.graph.add_edge(end_node_id, END)
        logger.debug(f"添加边: {end_node_id} -> END")
def build_graph(self):
"""
Build the internal subgraph for the cycle node.
Build and compile the internal subgraph for this cycle node.
Steps:
1. Extract cycle nodes and edges.
2. Create node instances and add them to the graph.
3. Connect edges and conditional routes.
4. Compile the graph for execution.
1. Extract cycle nodes and internal edges from the workflow
2. Construct a StateGraph using GraphBuilder in subgraph mode
3. Compile the graph for runtime execution
"""
self.graph = StateGraph(WorkflowState)
from app.core.workflow.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.create_node()
self.create_edge()
self.graph = self.graph.compile()
self.graph = GraphBuilder(
{
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True
).build()
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the cycle node at runtime.
Depending on the node type, runs either a loop (LoopRuntime)
or an iteration (IterationRuntime) over the internal subgraph.
Based on the node type:
- LOOP: Executes LoopRuntime, repeatedly invoking the subgraph
- ITERATION: Executes IterationRuntime, iterating over a collection
Args:
state: Current workflow state.
state: The current workflow state when entering the cycle node.
Returns:
Runtime result of the cycle, typically the final loop/iteration variables.
Any: The runtime result produced by the loop or iteration executor.
Raises:
RuntimeError: If node type is unrecognized.
RuntimeError: If the node type is unsupported.
"""
if self.node_type == NodeType.LOOP:
return await LoopRuntime(

View File

@@ -72,6 +72,7 @@ class NodeFactory:
NodeType.LOOP: CycleGraphNode,
NodeType.ITERATION: CycleGraphNode,
NodeType.BREAK: BreakNode,
NodeType.CYCLE_START: StartNode,
}
@classmethod