From 4685fd14adcac42d8cf990cfc779feb6353b66b0 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 11:06:21 +0800 Subject: [PATCH] refactor(workflow): refactor graph construction to support subgraph building --- api/app/core/workflow/executor.py | 162 +---------- api/app/core/workflow/graph_builder.py | 253 ++++++++++++++++++ .../core/workflow/nodes/cycle_graph/config.py | 22 +- .../core/workflow/nodes/cycle_graph/node.py | 164 +++--------- api/app/core/workflow/nodes/node_factory.py | 1 + 5 files changed, 321 insertions(+), 281 deletions(-) create mode 100644 api/app/core/workflow/graph_builder.py diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 0d0879d7..7274764a 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -10,11 +10,10 @@ import logging from typing import Any from langchain_core.messages import HumanMessage -from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition -from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.graph_builder import GraphBuilder +from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.enums import NodeType # from app.core.tools.registry import ToolRegistry @@ -191,159 +190,10 @@ class WorkflowExecutor: 编译后的状态图 """ logger.info(f"开始构建工作流图: execution_id={self.execution_id}") - - # 分析 End 节点的前缀配置和相邻且被引用的节点 - end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set()) - - # 1. 创建状态图 - workflow = StateGraph(WorkflowState) - - # 2. 添加所有节点(包括 start 和 end) - start_node_id = None - end_node_ids = [] - - for node in self.nodes: - node_type = node.get("type") - node_id = node.get("id") - cycle_node = node.get("cycle") - if cycle_node: - # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 - continue - - # 记录 start 和 end 节点 ID - if node_type == NodeType.START: - start_node_id = node_id - elif node_type == NodeType.END: - end_node_ids.append(node_id) - - # 创建节点实例(现在 start 和 end 也会被创建) - node_instance = NodeFactory.create_node(node, self.workflow_config) - - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) - - # Find all edges whose source is the current node - related_edge = [edge for edge in self.edges if edge.get("source") == node_id] - - # Iterate over each branch - for idx in range(branch_number): - # Generate a condition expression for each edge - # Used later to determine which branch to take based on the node's output - # Assumes node output `node..output` matches the edge's label - # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" - - if node_instance: - # 如果是流式模式,且节点有 End 前缀配置,注入配置 - if stream and node_id in end_prefixes: - # 将 End 前缀配置注入到节点实例 - node_instance._end_node_prefix = end_prefixes[node_id] - logger.info(f"为节点 {node_id} 注入 End 前缀配置") - - # 如果是流式模式,标记节点是否与 End 相邻且被引用 - if stream: - node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced - if node_id in adjacent_and_referenced: - logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") - - # 包装节点的 run 方法 - # 使用函数工厂避免闭包问题 - if stream: - # 流式模式:创建 async generator 函数 - # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state - def make_stream_func(inst): - async def node_func(state: WorkflowState): - # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") - async for item in inst.run_stream(state): - yield item - - return node_func - - workflow.add_node(node_id, make_stream_func(node_instance)) - else: - # 非流式模式:创建 async function - def make_func(inst): - async def node_func(state: WorkflowState): - return await inst.run(state) - - return node_func - - workflow.add_node(node_id, make_func(node_instance)) - - logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})") - - # 3. 添加边 - # 从 START 连接到 start 节点 - if start_node_id: - workflow.add_edge(START, start_node_id) - logger.debug(f"添加边: START -> {start_node_id}") - - for edge in self.workflow_config.get("edges", []): - source = edge.get("source") - target = edge.get("target") - edge_type = edge.get("type") - condition = edge.get("condition") - - # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) - if source == start_node_id: - # 但要连接 start 到下一个节点 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - continue - - # # 处理到 end 节点的边 - # if target in end_node_ids: - # # 连接到 end 节点 - # workflow.add_edge(source, target) - # logger.debug(f"添加边: {source} -> {target}") - # continue - - # 跳过错误边(在节点内部处理) - if edge_type == "error": - continue - - if condition: - # 条件边 - def make_router(cond, tgt): - """Dynamically generate a conditional router function to ensure each branch has a unique name.""" - - - def router_fn(state: WorkflowState): - if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } - ): - return tgt - return END - - # 动态修改函数名,避免重复 - router_fn.__name__ = f"router_{tgt}" - return router_fn - - router_fn = make_router(condition, target) - workflow.add_conditional_edges(source, router_fn) - logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") - else: - # 普通边 - workflow.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - - # 从 end 节点连接到 END - for end_node_id in end_node_ids: - workflow.add_edge(end_node_id, END) - logger.debug(f"添加边: {end_node_id} -> END") - - # 4. 编译图 - graph = workflow.compile() + graph = GraphBuilder( + self.workflow_config, + stream=stream, + ).build() logger.info(f"工作流图构建完成: execution_id={self.execution_id}") return graph diff --git a/api/app/core/workflow/graph_builder.py b/api/app/core/workflow/graph_builder.py new file mode 100644 index 00000000..9e80db33 --- /dev/null +++ b/api/app/core/workflow/graph_builder.py @@ -0,0 +1,253 @@ +import logging +import uuid +from typing import Any + +from langgraph.graph.state import CompiledStateGraph, StateGraph +from langgraph.graph import START, END + +from app.core.workflow.expression_evaluator import evaluate_condition +from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.nodes.enums import NodeType + +logger = logging.getLogger(__name__) + + +# TODO: 子图拆解支持 +class GraphBuilder: + def __init__( + self, + workflow_config: dict[str, Any], + stream: bool = False, + subgraph: bool = False, + ): + self.workflow_config = workflow_config + + self.stream = stream + self.subgraph = subgraph + + self.start_node_id = None + self.end_node_ids = [] + + self.graph: StateGraph | CompiledStateGraph | None = None + + @property + def nodes(self) -> list[dict[str, Any]]: + return self.workflow_config.get("nodes", []) + + @property + def edges(self) -> list[dict[str, Any]]: + return self.workflow_config.get("edges", []) + + def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]: + """分析 End 节点的前缀配置 + + 检查每个 End 节点的模板,找到直接上游节点的引用, + 提取该引用之前的前缀部分。 + + Returns: + 元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合}) + """ + import re + + prefixes = {} + adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点 + + # 找到所有 End 节点 + end_nodes = [node for node in self.nodes if node.get("type") == "end"] + logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点") + + for end_node in end_nodes: + end_node_id = end_node.get("id") + output_template = end_node.get("config", {}).get("output") + + logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}") + + if not output_template: + continue + + # 查找模板中引用了哪些节点 + # 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格) + pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}' + matches = list(re.finditer(pattern, output_template)) + + logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用") + + # 找到所有直接连接到 End 节点的上游节点 + direct_upstream_nodes = [] + for edge in self.edges: + if edge.get("target") == end_node_id: + source_node_id = edge.get("source") + direct_upstream_nodes.append(source_node_id) + + logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}") + + # 找到第一个直接上游节点的引用 + for match in matches: + referenced_node_id = match.group(1) + logger.info(f"[前缀分析] 检查引用: {referenced_node_id}") + + if referenced_node_id in direct_upstream_nodes: + # 这是直接上游节点的引用,提取前缀 + prefix = output_template[:match.start()] + + logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'") + + # 标记这个节点为"相邻且被引用" + adjacent_and_referenced.add(referenced_node_id) + + if prefix: + prefixes[referenced_node_id] = prefix + logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'") + + # 只处理第一个直接上游节点的引用 + break + + logger.info(f"[前缀分析] 最终配置: {prefixes}") + logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}") + return prefixes, adjacent_and_referenced + + def add_nodes(self): + end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set()) + + for node in self.nodes: + node_type = node.get("type") + node_id = node.get("id") + cycle_node = node.get("cycle") + if cycle_node: + # 处于循环子图中的节点由 CycleGraphNode 进行构建处理 + if not self.subgraph: + continue + + # 记录 start 和 end 节点 ID + if node_type in [NodeType.START, NodeType.CYCLE_START]: + self.start_node_id = node_id + elif node_type == NodeType.END: + self.end_node_ids.append(node_id) + + # 创建节点实例(现在 start 和 end 也会被创建) + # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph + node_instance = NodeFactory.create_node(node, self.workflow_config) + + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: + + # Find all edges whose source is the current node + related_edge = [edge for edge in self.edges if edge.get("source") == node_id] + + # Iterate over each branch + for idx in range(len(related_edge)): + # Generate a condition expression for each edge + # Used later to determine which branch to take based on the node's output + # Assumes node output `node..output` matches the edge's label + # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' + related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + + if node_instance: + # 如果是流式模式,且节点有 End 前缀配置,注入配置 + if self.stream and node_id in end_prefixes: + # 将 End 前缀配置注入到节点实例 + node_instance._end_node_prefix = end_prefixes[node_id] + logger.info(f"为节点 {node_id} 注入 End 前缀配置") + + # 如果是流式模式,标记节点是否与 End 相邻且被引用 + if self.stream: + node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced + if node_id in adjacent_and_referenced: + logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用") + + # 包装节点的 run 方法 + # 使用函数工厂避免闭包问题 + if self.stream: + # 流式模式:创建 async generator 函数 + # LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state + def make_stream_func(inst): + async def node_func(state: WorkflowState): + # logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}") + async for item in inst.run_stream(state): + yield item + + return node_func + + self.graph.add_node(node_id, make_stream_func(node_instance)) + else: + # 非流式模式:创建 async function + def make_func(inst): + async def node_func(state: WorkflowState): + return await inst.run(state) + + return node_func + + self.graph.add_node(node_id, make_func(node_instance)) + + logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})") + + def add_edges(self): + if self.start_node_id: + self.graph.add_edge(START, self.start_node_id) + logger.debug(f"添加边: START -> {self.start_node_id}") + + for edge in self.edges: + source = edge.get("source") + target = edge.get("target") + edge_type = edge.get("type") + condition = edge.get("condition") + + # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) + if source == self.start_node_id: + # 但要连接 start 到下一个节点 + self.graph.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + continue + + # # 处理到 end 节点的边 + # if target in end_node_ids: + # # 连接到 end 节点 + # workflow.add_edge(source, target) + # logger.debug(f"添加边: {source} -> {target}") + # continue + + # 跳过错误边(在节点内部处理) + if edge_type == "error": + continue + + if condition: + # 条件边 + def make_router(cond, tgt): + """Dynamically generate a conditional router function to ensure each branch has a unique name.""" + + def router_fn(state: WorkflowState): + if evaluate_condition( + cond, + state.get("variables", {}), + state.get("runtime_vars", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } + ): + return tgt + return END + + # 动态修改函数名,避免重复 + router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}" + return router_fn + + router_fn = make_router(condition, target) + self.graph.add_conditional_edges(source, router_fn) + logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") + else: + # 普通边 + self.graph.add_edge(source, target) + logger.debug(f"添加边: {source} -> {target}") + + # 从 end 节点连接到 END + for end_node_id in self.end_node_ids: + self.graph.add_edge(end_node_id, END) + logger.debug(f"添加边: {end_node_id} -> END") + return + + def build(self) -> CompiledStateGraph: + self.graph = StateGraph(WorkflowState) + self.add_nodes() + self.add_edges() # 添加边必须在添加节点之后 + return self.graph.compile() diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index b1b613a4..fcf65717 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -1,7 +1,9 @@ +from typing import Any + from pydantic import Field, BaseModel from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType class CycleVariable(BaseNodeConfig): @@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig): ..., description="Name of the loop variable" ) + type: VariableType = Field( ..., description="Data type of the loop variable" ) - value: str = Field( + + input_type: ValueInputType = Field( + ..., + description="Input type of the loop variable" + ) + + value: Any = Field( ..., description="Initial or current value of the loop variable" ) class ConditionDetail(BaseModel): - comparison_operator: ComparisonOperator = Field( + operator: ComparisonOperator = Field( ..., description="Operator used to compare the left and right operands" ) @@ -30,11 +39,16 @@ class ConditionDetail(BaseModel): description="Left-hand operand of the comparison expression" ) - right: str = Field( + right: Any = Field( ..., description="Right-hand operand of the comparison expression" ) + input_type: ValueInputType = Field( + ..., + description="Input type of the loop variable" + ) + class ConditionsConfig(BaseModel): """Configuration for loop condition evaluation""" diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index 2428ef46..fb062f39 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -1,10 +1,9 @@ import logging from typing import Any -from langgraph.graph import StateGraph, START, END +from langgraph.graph import StateGraph from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig @@ -17,12 +16,18 @@ logger = logging.getLogger(__name__) class CycleGraphNode(BaseNode): """ - Node representing a cycle (loop) subgraph within the workflow. + Node representing a cyclic (loop or iteration) subgraph within the workflow. - This node manages internal loop/iteration nodes, builds a subgraph - for execution, handles conditional routing, and executes loop - or iteration logic based on node type. + A CycleGraphNode is a structural node that: + - Extracts a group of nodes marked as belonging to the same cycle + - Builds an isolated internal StateGraph (subgraph) + - Delegates runtime execution to LoopRuntime or IterationRuntime + depending on the node type + + This node itself does NOT execute business logic directly. + It acts as a container and execution controller for a subgraph. """ + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None @@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode): def pure_cycle_graph(self) -> tuple[list, list]: """ - Extract cycle nodes and internal edges from the workflow configuration, - removing them from the global workflow. + Extract cycle-scoped nodes and internal edges from the workflow configuration. - Raises: - ValueError: If cycle nodes are connected to external nodes improperly. + This method: + - Identifies all nodes marked with `cycle == self.node_id` + - Collects edges that fully connect cycle nodes + - Removes extracted nodes and edges from the global workflow configuration + + Safety check: + - Raises an error if a cycle node is connected to an external node Returns: - Tuple containing: - - cycle_nodes: List of removed nodes - - cycle_edges: List of removed edges + tuple[list, list]: + - cycle_nodes: Nodes belonging to this cycle + - cycle_edges: Edges connecting nodes within the cycle + + Raises: + ValueError: If a cycle node is improperly connected to an external node. """ nodes = self.workflow_config.get("nodes", []) edges = self.workflow_config.get("edges", []) @@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode): return cycle_nodes, cycle_edges - def create_node(self): - """ - Instantiate node objects for each node in the cycle subgraph and add them to the graph. - - Special handling is applied for conditional nodes to generate - edge conditions based on node outputs. - """ - from app.core.workflow.nodes import NodeFactory - for node in self.cycle_nodes: - node_type = node.get("type") - node_id = node.get("id") - - if node_type == NodeType.CYCLE_START: - self.start_node_id = node_id - continue - elif node_type == NodeType.END: - self.end_node_ids.append(node_id) - - node_instance = NodeFactory.create_node(node, self.workflow_config) - - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) - - # Find all edges whose source is the current node - related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id] - - # Iterate over each branch - for idx in range(branch_number): - # Generate a condition expression for each edge - # Used later to determine which branch to take based on the node's output - # Assumes node output `node..output` matches the edge's label - # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" - - def make_func(inst): - async def node_func(state: WorkflowState): - return await inst.run(state) - - return node_func - - self.graph.add_node(node_id, make_func(node_instance)) - - def create_edge(self): - """ - Connect nodes within the cycle subgraph by adding edges to the internal graph. - - Conditional edges are routed based on evaluated expressions. - Start and end nodes are connected to global START and END nodes. - """ - for edge in self.cycle_edges: - source = edge.get("source") - target = edge.get("target") - edge_type = edge.get("type") - condition = edge.get("condition") - - # 跳过从 start 节点出发的边(因为已经从 START 连接到 start) - if source == self.start_node_id: - # 但要连接 start 到下一个节点 - self.graph.add_edge(START, target) - logger.debug(f"添加边: {source} -> {target}") - continue - - if condition: - # 条件边 - def router(state: WorkflowState, cond=condition, tgt=target): - """条件路由函数""" - if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } - ): - return tgt - return END # 条件不满足,结束 - - self.graph.add_conditional_edges(source, router) - logger.debug(f"添加条件边: {source} -> {target} (condition={condition})") - else: - # 普通边 - self.graph.add_edge(source, target) - logger.debug(f"添加边: {source} -> {target}") - - # 从 end 节点连接到 END - for end_node_id in self.end_node_ids: - self.graph.add_edge(end_node_id, END) - logger.debug(f"添加边: {end_node_id} -> END") - def build_graph(self): """ - Build the internal subgraph for the cycle node. + Build and compile the internal subgraph for this cycle node. Steps: - 1. Extract cycle nodes and edges. - 2. Create node instances and add them to the graph. - 3. Connect edges and conditional routes. - 4. Compile the graph for execution. + 1. Extract cycle nodes and internal edges from the workflow + 2. Construct a StateGraph using GraphBuilder in subgraph mode + 3. Compile the graph for runtime execution """ - self.graph = StateGraph(WorkflowState) + from app.core.workflow.graph_builder import GraphBuilder self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() - self.create_node() - self.create_edge() - self.graph = self.graph.compile() + self.graph = GraphBuilder( + { + "nodes": self.cycle_nodes, + "edges": self.cycle_edges, + }, + subgraph=True + ).build() async def execute(self, state: WorkflowState) -> Any: """ Execute the cycle node at runtime. - Depending on the node type, runs either a loop (LoopRuntime) - or an iteration (IterationRuntime) over the internal subgraph. + Based on the node type: + - LOOP: Executes LoopRuntime, repeatedly invoking the subgraph + - ITERATION: Executes IterationRuntime, iterating over a collection Args: - state: Current workflow state. + state: The current workflow state when entering the cycle node. Returns: - Runtime result of the cycle, typically the final loop/iteration variables. + Any: The runtime result produced by the loop or iteration executor. Raises: - RuntimeError: If node type is unrecognized. + RuntimeError: If the node type is unsupported. """ if self.node_type == NodeType.LOOP: return await LoopRuntime( diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index ed26533d..f86a2b9b 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -72,6 +72,7 @@ class NodeFactory: NodeType.LOOP: CycleGraphNode, NodeType.ITERATION: CycleGraphNode, NodeType.BREAK: BreakNode, + NodeType.CYCLE_START: StartNode, } @classmethod