refactor(workflow): refactor graph construction to support subgraph building
This commit is contained in:
@@ -10,11 +10,10 @@ import logging
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.graph_builder import GraphBuilder
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
# from app.core.tools.registry import ToolRegistry
|
# from app.core.tools.registry import ToolRegistry
|
||||||
@@ -191,159 +190,10 @@ class WorkflowExecutor:
|
|||||||
编译后的状态图
|
编译后的状态图
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
logger.info(f"开始构建工作流图: execution_id={self.execution_id}")
|
||||||
|
graph = GraphBuilder(
|
||||||
# 分析 End 节点的前缀配置和相邻且被引用的节点
|
self.workflow_config,
|
||||||
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if stream else ({}, set())
|
stream=stream,
|
||||||
|
).build()
|
||||||
# 1. 创建状态图
|
|
||||||
workflow = StateGraph(WorkflowState)
|
|
||||||
|
|
||||||
# 2. 添加所有节点(包括 start 和 end)
|
|
||||||
start_node_id = None
|
|
||||||
end_node_ids = []
|
|
||||||
|
|
||||||
for node in self.nodes:
|
|
||||||
node_type = node.get("type")
|
|
||||||
node_id = node.get("id")
|
|
||||||
cycle_node = node.get("cycle")
|
|
||||||
if cycle_node:
|
|
||||||
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 记录 start 和 end 节点 ID
|
|
||||||
if node_type == NodeType.START:
|
|
||||||
start_node_id = node_id
|
|
||||||
elif node_type == NodeType.END:
|
|
||||||
end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
|
||||||
|
|
||||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
|
||||||
expressions = node_instance.build_conditional_edge_expressions()
|
|
||||||
|
|
||||||
# Number of branches, usually matches the number of conditional expressions
|
|
||||||
branch_number = len(expressions)
|
|
||||||
|
|
||||||
# Find all edges whose source is the current node
|
|
||||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
|
||||||
|
|
||||||
# Iterate over each branch
|
|
||||||
for idx in range(branch_number):
|
|
||||||
# Generate a condition expression for each edge
|
|
||||||
# Used later to determine which branch to take based on the node's output
|
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
|
||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
|
||||||
|
|
||||||
if node_instance:
|
|
||||||
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
|
||||||
if stream and node_id in end_prefixes:
|
|
||||||
# 将 End 前缀配置注入到节点实例
|
|
||||||
node_instance._end_node_prefix = end_prefixes[node_id]
|
|
||||||
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
|
||||||
|
|
||||||
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
|
||||||
if stream:
|
|
||||||
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
|
||||||
if node_id in adjacent_and_referenced:
|
|
||||||
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
|
||||||
|
|
||||||
# 包装节点的 run 方法
|
|
||||||
# 使用函数工厂避免闭包问题
|
|
||||||
if stream:
|
|
||||||
# 流式模式:创建 async generator 函数
|
|
||||||
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
|
||||||
def make_stream_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
|
||||||
async for item in inst.run_stream(state):
|
|
||||||
yield item
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
workflow.add_node(node_id, make_stream_func(node_instance))
|
|
||||||
else:
|
|
||||||
# 非流式模式:创建 async function
|
|
||||||
def make_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
return await inst.run(state)
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
workflow.add_node(node_id, make_func(node_instance))
|
|
||||||
|
|
||||||
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={stream})")
|
|
||||||
|
|
||||||
# 3. 添加边
|
|
||||||
# 从 START 连接到 start 节点
|
|
||||||
if start_node_id:
|
|
||||||
workflow.add_edge(START, start_node_id)
|
|
||||||
logger.debug(f"添加边: START -> {start_node_id}")
|
|
||||||
|
|
||||||
for edge in self.workflow_config.get("edges", []):
|
|
||||||
source = edge.get("source")
|
|
||||||
target = edge.get("target")
|
|
||||||
edge_type = edge.get("type")
|
|
||||||
condition = edge.get("condition")
|
|
||||||
|
|
||||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
|
||||||
if source == start_node_id:
|
|
||||||
# 但要连接 start 到下一个节点
|
|
||||||
workflow.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# # 处理到 end 节点的边
|
|
||||||
# if target in end_node_ids:
|
|
||||||
# # 连接到 end 节点
|
|
||||||
# workflow.add_edge(source, target)
|
|
||||||
# logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# 跳过错误边(在节点内部处理)
|
|
||||||
if edge_type == "error":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if condition:
|
|
||||||
# 条件边
|
|
||||||
def make_router(cond, tgt):
|
|
||||||
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
|
||||||
|
|
||||||
|
|
||||||
def router_fn(state: WorkflowState):
|
|
||||||
if evaluate_condition(
|
|
||||||
cond,
|
|
||||||
state.get("variables", {}),
|
|
||||||
state.get("node_outputs", {}),
|
|
||||||
{
|
|
||||||
"execution_id": state.get("execution_id"),
|
|
||||||
"workspace_id": state.get("workspace_id"),
|
|
||||||
"user_id": state.get("user_id")
|
|
||||||
}
|
|
||||||
):
|
|
||||||
return tgt
|
|
||||||
return END
|
|
||||||
|
|
||||||
# 动态修改函数名,避免重复
|
|
||||||
router_fn.__name__ = f"router_{tgt}"
|
|
||||||
return router_fn
|
|
||||||
|
|
||||||
router_fn = make_router(condition, target)
|
|
||||||
workflow.add_conditional_edges(source, router_fn)
|
|
||||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
|
||||||
else:
|
|
||||||
# 普通边
|
|
||||||
workflow.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
|
|
||||||
# 从 end 节点连接到 END
|
|
||||||
for end_node_id in end_node_ids:
|
|
||||||
workflow.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"添加边: {end_node_id} -> END")
|
|
||||||
|
|
||||||
# 4. 编译图
|
|
||||||
graph = workflow.compile()
|
|
||||||
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
logger.info(f"工作流图构建完成: execution_id={self.execution_id}")
|
||||||
|
|
||||||
return graph
|
return graph
|
||||||
|
|||||||
253
api/app/core/workflow/graph_builder.py
Normal file
253
api/app/core/workflow/graph_builder.py
Normal file
@@ -0,0 +1,253 @@
|
|||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langgraph.graph.state import CompiledStateGraph, StateGraph
|
||||||
|
from langgraph.graph import START, END
|
||||||
|
|
||||||
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: 子图拆解支持
|
||||||
|
class GraphBuilder:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
workflow_config: dict[str, Any],
|
||||||
|
stream: bool = False,
|
||||||
|
subgraph: bool = False,
|
||||||
|
):
|
||||||
|
self.workflow_config = workflow_config
|
||||||
|
|
||||||
|
self.stream = stream
|
||||||
|
self.subgraph = subgraph
|
||||||
|
|
||||||
|
self.start_node_id = None
|
||||||
|
self.end_node_ids = []
|
||||||
|
|
||||||
|
self.graph: StateGraph | CompiledStateGraph | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
|
return self.workflow_config.get("nodes", [])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def edges(self) -> list[dict[str, Any]]:
|
||||||
|
return self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
def _analyze_end_node_prefixes(self) -> tuple[dict[str, str], set[str]]:
|
||||||
|
"""分析 End 节点的前缀配置
|
||||||
|
|
||||||
|
检查每个 End 节点的模板,找到直接上游节点的引用,
|
||||||
|
提取该引用之前的前缀部分。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
元组:({上游节点ID: End节点前缀}, {与End相邻且被引用的节点ID集合})
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
prefixes = {}
|
||||||
|
adjacent_and_referenced = set() # 记录与 End 节点相邻且被引用的节点
|
||||||
|
|
||||||
|
# 找到所有 End 节点
|
||||||
|
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||||
|
logger.info(f"[前缀分析] 找到 {len(end_nodes)} 个 End 节点")
|
||||||
|
|
||||||
|
for end_node in end_nodes:
|
||||||
|
end_node_id = end_node.get("id")
|
||||||
|
output_template = end_node.get("config", {}).get("output")
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点 {end_node_id} 模板: {output_template}")
|
||||||
|
|
||||||
|
if not output_template:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 查找模板中引用了哪些节点
|
||||||
|
# 匹配 {{node_id.xxx}} 或 {{ node_id.xxx }} 格式(支持空格)
|
||||||
|
pattern = r'\{\{\s*([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\s*\}\}'
|
||||||
|
matches = list(re.finditer(pattern, output_template))
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 模板中找到 {len(matches)} 个节点引用")
|
||||||
|
|
||||||
|
# 找到所有直接连接到 End 节点的上游节点
|
||||||
|
direct_upstream_nodes = []
|
||||||
|
for edge in self.edges:
|
||||||
|
if edge.get("target") == end_node_id:
|
||||||
|
source_node_id = edge.get("source")
|
||||||
|
direct_upstream_nodes.append(source_node_id)
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] End 节点的直接上游节点: {direct_upstream_nodes}")
|
||||||
|
|
||||||
|
# 找到第一个直接上游节点的引用
|
||||||
|
for match in matches:
|
||||||
|
referenced_node_id = match.group(1)
|
||||||
|
logger.info(f"[前缀分析] 检查引用: {referenced_node_id}")
|
||||||
|
|
||||||
|
if referenced_node_id in direct_upstream_nodes:
|
||||||
|
# 这是直接上游节点的引用,提取前缀
|
||||||
|
prefix = output_template[:match.start()]
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] ✅ 找到直接上游节点 {referenced_node_id} 的引用,前缀: '{prefix}'")
|
||||||
|
|
||||||
|
# 标记这个节点为"相邻且被引用"
|
||||||
|
adjacent_and_referenced.add(referenced_node_id)
|
||||||
|
|
||||||
|
if prefix:
|
||||||
|
prefixes[referenced_node_id] = prefix
|
||||||
|
logger.info(f"✅ [前缀分析] 为节点 {referenced_node_id} 配置前缀: '{prefix[:50]}...'")
|
||||||
|
|
||||||
|
# 只处理第一个直接上游节点的引用
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"[前缀分析] 最终配置: {prefixes}")
|
||||||
|
logger.info(f"[前缀分析] 与 End 相邻且被引用的节点: {adjacent_and_referenced}")
|
||||||
|
return prefixes, adjacent_and_referenced
|
||||||
|
|
||||||
|
def add_nodes(self):
|
||||||
|
end_prefixes, adjacent_and_referenced = self._analyze_end_node_prefixes() if self.stream else ({}, set())
|
||||||
|
|
||||||
|
for node in self.nodes:
|
||||||
|
node_type = node.get("type")
|
||||||
|
node_id = node.get("id")
|
||||||
|
cycle_node = node.get("cycle")
|
||||||
|
if cycle_node:
|
||||||
|
# 处于循环子图中的节点由 CycleGraphNode 进行构建处理
|
||||||
|
if not self.subgraph:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 记录 start 和 end 节点 ID
|
||||||
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
self.start_node_id = node_id
|
||||||
|
elif node_type == NodeType.END:
|
||||||
|
self.end_node_ids.append(node_id)
|
||||||
|
|
||||||
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
|
||||||
|
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
||||||
|
|
||||||
|
# Find all edges whose source is the current node
|
||||||
|
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||||
|
|
||||||
|
# Iterate over each branch
|
||||||
|
for idx in range(len(related_edge)):
|
||||||
|
# Generate a condition expression for each edge
|
||||||
|
# Used later to determine which branch to take based on the node's output
|
||||||
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
|
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||||
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
|
if node_instance:
|
||||||
|
# 如果是流式模式,且节点有 End 前缀配置,注入配置
|
||||||
|
if self.stream and node_id in end_prefixes:
|
||||||
|
# 将 End 前缀配置注入到节点实例
|
||||||
|
node_instance._end_node_prefix = end_prefixes[node_id]
|
||||||
|
logger.info(f"为节点 {node_id} 注入 End 前缀配置")
|
||||||
|
|
||||||
|
# 如果是流式模式,标记节点是否与 End 相邻且被引用
|
||||||
|
if self.stream:
|
||||||
|
node_instance._is_adjacent_to_end = node_id in adjacent_and_referenced
|
||||||
|
if node_id in adjacent_and_referenced:
|
||||||
|
logger.info(f"节点 {node_id} 标记为与 End 相邻且被引用")
|
||||||
|
|
||||||
|
# 包装节点的 run 方法
|
||||||
|
# 使用函数工厂避免闭包问题
|
||||||
|
if self.stream:
|
||||||
|
# 流式模式:创建 async generator 函数
|
||||||
|
# LangGraph 会收集所有 yield 的值,最后一个 yield 的字典会被合并到 state
|
||||||
|
def make_stream_func(inst):
|
||||||
|
async def node_func(state: WorkflowState):
|
||||||
|
# logger.debug(f"流式执行节点: {inst.node_id}, 支持流式: {inst.supports_streaming()}")
|
||||||
|
async for item in inst.run_stream(state):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
return node_func
|
||||||
|
|
||||||
|
self.graph.add_node(node_id, make_stream_func(node_instance))
|
||||||
|
else:
|
||||||
|
# 非流式模式:创建 async function
|
||||||
|
def make_func(inst):
|
||||||
|
async def node_func(state: WorkflowState):
|
||||||
|
return await inst.run(state)
|
||||||
|
|
||||||
|
return node_func
|
||||||
|
|
||||||
|
self.graph.add_node(node_id, make_func(node_instance))
|
||||||
|
|
||||||
|
logger.debug(f"添加节点: {node_id} (type={node_type}, stream={self.stream})")
|
||||||
|
|
||||||
|
def add_edges(self):
|
||||||
|
if self.start_node_id:
|
||||||
|
self.graph.add_edge(START, self.start_node_id)
|
||||||
|
logger.debug(f"添加边: START -> {self.start_node_id}")
|
||||||
|
|
||||||
|
for edge in self.edges:
|
||||||
|
source = edge.get("source")
|
||||||
|
target = edge.get("target")
|
||||||
|
edge_type = edge.get("type")
|
||||||
|
condition = edge.get("condition")
|
||||||
|
|
||||||
|
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
||||||
|
if source == self.start_node_id:
|
||||||
|
# 但要连接 start 到下一个节点
|
||||||
|
self.graph.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# # 处理到 end 节点的边
|
||||||
|
# if target in end_node_ids:
|
||||||
|
# # 连接到 end 节点
|
||||||
|
# workflow.add_edge(source, target)
|
||||||
|
# logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
# continue
|
||||||
|
|
||||||
|
# 跳过错误边(在节点内部处理)
|
||||||
|
if edge_type == "error":
|
||||||
|
continue
|
||||||
|
|
||||||
|
if condition:
|
||||||
|
# 条件边
|
||||||
|
def make_router(cond, tgt):
|
||||||
|
"""Dynamically generate a conditional router function to ensure each branch has a unique name."""
|
||||||
|
|
||||||
|
def router_fn(state: WorkflowState):
|
||||||
|
if evaluate_condition(
|
||||||
|
cond,
|
||||||
|
state.get("variables", {}),
|
||||||
|
state.get("runtime_vars", {}),
|
||||||
|
{
|
||||||
|
"execution_id": state.get("execution_id"),
|
||||||
|
"workspace_id": state.get("workspace_id"),
|
||||||
|
"user_id": state.get("user_id")
|
||||||
|
}
|
||||||
|
):
|
||||||
|
return tgt
|
||||||
|
return END
|
||||||
|
|
||||||
|
# 动态修改函数名,避免重复
|
||||||
|
router_fn.__name__ = f"router_{uuid.uuid4().hex[:8]}_{tgt}"
|
||||||
|
return router_fn
|
||||||
|
|
||||||
|
router_fn = make_router(condition, target)
|
||||||
|
self.graph.add_conditional_edges(source, router_fn)
|
||||||
|
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
||||||
|
else:
|
||||||
|
# 普通边
|
||||||
|
self.graph.add_edge(source, target)
|
||||||
|
logger.debug(f"添加边: {source} -> {target}")
|
||||||
|
|
||||||
|
# 从 end 节点连接到 END
|
||||||
|
for end_node_id in self.end_node_ids:
|
||||||
|
self.graph.add_edge(end_node_id, END)
|
||||||
|
logger.debug(f"添加边: {end_node_id} -> END")
|
||||||
|
return
|
||||||
|
|
||||||
|
def build(self) -> CompiledStateGraph:
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
|
self.add_edges() # 添加边必须在添加节点之后
|
||||||
|
return self.graph.compile()
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import Field, BaseModel
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableType
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
|
|
||||||
|
|
||||||
class CycleVariable(BaseNodeConfig):
|
class CycleVariable(BaseNodeConfig):
|
||||||
@@ -9,18 +11,25 @@ class CycleVariable(BaseNodeConfig):
|
|||||||
...,
|
...,
|
||||||
description="Name of the loop variable"
|
description="Name of the loop variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
type: VariableType = Field(
|
type: VariableType = Field(
|
||||||
...,
|
...,
|
||||||
description="Data type of the loop variable"
|
description="Data type of the loop variable"
|
||||||
)
|
)
|
||||||
value: str = Field(
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Input type of the loop variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
value: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Initial or current value of the loop variable"
|
description="Initial or current value of the loop variable"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionDetail(BaseModel):
|
class ConditionDetail(BaseModel):
|
||||||
comparison_operator: ComparisonOperator = Field(
|
operator: ComparisonOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Operator used to compare the left and right operands"
|
description="Operator used to compare the left and right operands"
|
||||||
)
|
)
|
||||||
@@ -30,11 +39,16 @@ class ConditionDetail(BaseModel):
|
|||||||
description="Left-hand operand of the comparison expression"
|
description="Left-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
right: str = Field(
|
right: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Right-hand operand of the comparison expression"
|
description="Right-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Input type of the loop variable"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionsConfig(BaseModel):
|
class ConditionsConfig(BaseModel):
|
||||||
"""Configuration for loop condition evaluation"""
|
"""Configuration for loop condition evaluation"""
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
from app.core.workflow.nodes.cycle_graph.config import LoopNodeConfig, IterationNodeConfig
|
||||||
@@ -17,12 +16,18 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class CycleGraphNode(BaseNode):
|
class CycleGraphNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
Node representing a cycle (loop) subgraph within the workflow.
|
Node representing a cyclic (loop or iteration) subgraph within the workflow.
|
||||||
|
|
||||||
This node manages internal loop/iteration nodes, builds a subgraph
|
A CycleGraphNode is a structural node that:
|
||||||
for execution, handles conditional routing, and executes loop
|
- Extracts a group of nodes marked as belonging to the same cycle
|
||||||
or iteration logic based on node type.
|
- Builds an isolated internal StateGraph (subgraph)
|
||||||
|
- Delegates runtime execution to LoopRuntime or IterationRuntime
|
||||||
|
depending on the node type
|
||||||
|
|
||||||
|
This node itself does NOT execute business logic directly.
|
||||||
|
It acts as a container and execution controller for a subgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None
|
||||||
@@ -38,16 +43,23 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
def pure_cycle_graph(self) -> tuple[list, list]:
|
def pure_cycle_graph(self) -> tuple[list, list]:
|
||||||
"""
|
"""
|
||||||
Extract cycle nodes and internal edges from the workflow configuration,
|
Extract cycle-scoped nodes and internal edges from the workflow configuration.
|
||||||
removing them from the global workflow.
|
|
||||||
|
|
||||||
Raises:
|
This method:
|
||||||
ValueError: If cycle nodes are connected to external nodes improperly.
|
- Identifies all nodes marked with `cycle == self.node_id`
|
||||||
|
- Collects edges that fully connect cycle nodes
|
||||||
|
- Removes extracted nodes and edges from the global workflow configuration
|
||||||
|
|
||||||
|
Safety check:
|
||||||
|
- Raises an error if a cycle node is connected to an external node
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple containing:
|
tuple[list, list]:
|
||||||
- cycle_nodes: List of removed nodes
|
- cycle_nodes: Nodes belonging to this cycle
|
||||||
- cycle_edges: List of removed edges
|
- cycle_edges: Edges connecting nodes within the cycle
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If a cycle node is improperly connected to an external node.
|
||||||
"""
|
"""
|
||||||
nodes = self.workflow_config.get("nodes", [])
|
nodes = self.workflow_config.get("nodes", [])
|
||||||
edges = self.workflow_config.get("edges", [])
|
edges = self.workflow_config.get("edges", [])
|
||||||
@@ -83,131 +95,41 @@ class CycleGraphNode(BaseNode):
|
|||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
def create_node(self):
|
|
||||||
"""
|
|
||||||
Instantiate node objects for each node in the cycle subgraph and add them to the graph.
|
|
||||||
|
|
||||||
Special handling is applied for conditional nodes to generate
|
|
||||||
edge conditions based on node outputs.
|
|
||||||
"""
|
|
||||||
from app.core.workflow.nodes import NodeFactory
|
|
||||||
for node in self.cycle_nodes:
|
|
||||||
node_type = node.get("type")
|
|
||||||
node_id = node.get("id")
|
|
||||||
|
|
||||||
if node_type == NodeType.CYCLE_START:
|
|
||||||
self.start_node_id = node_id
|
|
||||||
continue
|
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
|
||||||
|
|
||||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
|
||||||
expressions = node_instance.build_conditional_edge_expressions()
|
|
||||||
|
|
||||||
# Number of branches, usually matches the number of conditional expressions
|
|
||||||
branch_number = len(expressions)
|
|
||||||
|
|
||||||
# Find all edges whose source is the current node
|
|
||||||
related_edge = [edge for edge in self.cycle_edges if edge.get("source") == node_id]
|
|
||||||
|
|
||||||
# Iterate over each branch
|
|
||||||
for idx in range(branch_number):
|
|
||||||
# Generate a condition expression for each edge
|
|
||||||
# Used later to determine which branch to take based on the node's output
|
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
|
||||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
|
||||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
|
||||||
|
|
||||||
def make_func(inst):
|
|
||||||
async def node_func(state: WorkflowState):
|
|
||||||
return await inst.run(state)
|
|
||||||
|
|
||||||
return node_func
|
|
||||||
|
|
||||||
self.graph.add_node(node_id, make_func(node_instance))
|
|
||||||
|
|
||||||
def create_edge(self):
|
|
||||||
"""
|
|
||||||
Connect nodes within the cycle subgraph by adding edges to the internal graph.
|
|
||||||
|
|
||||||
Conditional edges are routed based on evaluated expressions.
|
|
||||||
Start and end nodes are connected to global START and END nodes.
|
|
||||||
"""
|
|
||||||
for edge in self.cycle_edges:
|
|
||||||
source = edge.get("source")
|
|
||||||
target = edge.get("target")
|
|
||||||
edge_type = edge.get("type")
|
|
||||||
condition = edge.get("condition")
|
|
||||||
|
|
||||||
# 跳过从 start 节点出发的边(因为已经从 START 连接到 start)
|
|
||||||
if source == self.start_node_id:
|
|
||||||
# 但要连接 start 到下一个节点
|
|
||||||
self.graph.add_edge(START, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if condition:
|
|
||||||
# 条件边
|
|
||||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
|
||||||
"""条件路由函数"""
|
|
||||||
if evaluate_condition(
|
|
||||||
cond,
|
|
||||||
state.get("variables", {}),
|
|
||||||
state.get("node_outputs", {}),
|
|
||||||
{
|
|
||||||
"execution_id": state.get("execution_id"),
|
|
||||||
"workspace_id": state.get("workspace_id"),
|
|
||||||
"user_id": state.get("user_id")
|
|
||||||
}
|
|
||||||
):
|
|
||||||
return tgt
|
|
||||||
return END # 条件不满足,结束
|
|
||||||
|
|
||||||
self.graph.add_conditional_edges(source, router)
|
|
||||||
logger.debug(f"添加条件边: {source} -> {target} (condition={condition})")
|
|
||||||
else:
|
|
||||||
# 普通边
|
|
||||||
self.graph.add_edge(source, target)
|
|
||||||
logger.debug(f"添加边: {source} -> {target}")
|
|
||||||
|
|
||||||
# 从 end 节点连接到 END
|
|
||||||
for end_node_id in self.end_node_ids:
|
|
||||||
self.graph.add_edge(end_node_id, END)
|
|
||||||
logger.debug(f"添加边: {end_node_id} -> END")
|
|
||||||
|
|
||||||
def build_graph(self):
|
def build_graph(self):
|
||||||
"""
|
"""
|
||||||
Build the internal subgraph for the cycle node.
|
Build and compile the internal subgraph for this cycle node.
|
||||||
|
|
||||||
Steps:
|
Steps:
|
||||||
1. Extract cycle nodes and edges.
|
1. Extract cycle nodes and internal edges from the workflow
|
||||||
2. Create node instances and add them to the graph.
|
2. Construct a StateGraph using GraphBuilder in subgraph mode
|
||||||
3. Connect edges and conditional routes.
|
3. Compile the graph for runtime execution
|
||||||
4. Compile the graph for execution.
|
|
||||||
"""
|
"""
|
||||||
self.graph = StateGraph(WorkflowState)
|
from app.core.workflow.graph_builder import GraphBuilder
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.create_node()
|
self.graph = GraphBuilder(
|
||||||
self.create_edge()
|
{
|
||||||
self.graph = self.graph.compile()
|
"nodes": self.cycle_nodes,
|
||||||
|
"edges": self.cycle_edges,
|
||||||
|
},
|
||||||
|
subgraph=True
|
||||||
|
).build()
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the cycle node at runtime.
|
Execute the cycle node at runtime.
|
||||||
|
|
||||||
Depending on the node type, runs either a loop (LoopRuntime)
|
Based on the node type:
|
||||||
or an iteration (IterationRuntime) over the internal subgraph.
|
- LOOP: Executes LoopRuntime, repeatedly invoking the subgraph
|
||||||
|
- ITERATION: Executes IterationRuntime, iterating over a collection
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
state: Current workflow state.
|
state: The current workflow state when entering the cycle node.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Runtime result of the cycle, typically the final loop/iteration variables.
|
Any: The runtime result produced by the loop or iteration executor.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
RuntimeError: If node type is unrecognized.
|
RuntimeError: If the node type is unsupported.
|
||||||
"""
|
"""
|
||||||
if self.node_type == NodeType.LOOP:
|
if self.node_type == NodeType.LOOP:
|
||||||
return await LoopRuntime(
|
return await LoopRuntime(
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ class NodeFactory:
|
|||||||
NodeType.LOOP: CycleGraphNode,
|
NodeType.LOOP: CycleGraphNode,
|
||||||
NodeType.ITERATION: CycleGraphNode,
|
NodeType.ITERATION: CycleGraphNode,
|
||||||
NodeType.BREAK: BreakNode,
|
NodeType.BREAK: BreakNode,
|
||||||
|
NodeType.CYCLE_START: StartNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user