perf(workflow): Optimize downstream node activation method to reduce performance overhead

This commit is contained in:
Eternity
2026-03-25 17:03:12 +08:00
parent 45eef12842
commit 85daf576e9
19 changed files with 122 additions and 120 deletions

View File

@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
metadata=chunk.metadata, metadata=chunk.metadata,
) )
chunk_nodes.append(chunk_node) chunk_nodes.append(chunk_node)
logger.error(f"chunk file: {chunk.files}")
for p, file_type in chunk.files: for p, file_type in chunk.files:

View File

@@ -7,7 +7,7 @@ import re
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Any, Iterable from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END from langgraph.graph import START, END
@@ -41,39 +41,31 @@ class GraphBuilder:
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
stream: bool = False, stream: bool = False,
subgraph: bool = False, cycle: str = '',
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
): ):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.stream = stream self.stream = stream
self.subgraph = subgraph self.cycle = cycle
self.start_node_id: str | None = None self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes} self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache( self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
maxsize=len(self.nodes) * 2
)(self._find_upstream_activation_dep)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph: StateGraph | None = None self.graph: StateGraph | None = None
self.nodes: list = []
self.edges: list = []
self.reachable_nodes: set[str] | None = None self.reachable_nodes: set[str] | None = None
self.end_nodes: list[dict] = [] self.end_nodes: list[dict] = []
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list) self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] | None = defaultdict(list) self._adj: dict[str, list[str]] = defaultdict(list)
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def get_node_type(self, node_id: str) -> str: def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID. """Retrieve the type of node given its ID.
@@ -294,22 +286,13 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id") node_id = node.get("id")
cycle_node = node.get("cycle") if node_id not in self.reachable_nodes:
if cycle_node: continue
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config) node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES: if node_type in BRANCH_NODES:
@@ -503,21 +486,46 @@ class GraphBuilder:
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState) nodes = self.workflow_config.get("nodes", [])
self.add_nodes() edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges) self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [ self.end_nodes = [
node node
for node in self.nodes for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes if node.get("type") == "end" and node.get("id") in self.reachable_nodes
] ]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_adj() self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges() self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output() self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer) return self.graph.compile(checkpointer=checkpointer)

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode): class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method. All node types should inherit from this class and implement the `execute` method.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node. """Initialize the node.
Args: Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle") self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False). their activation status (True/False).
""" """
edges = self.workflow_config.get("edges") activate_flag = self.check_activate(state)
under_stream_nodes = [
edge.get("target") if self.node_type not in BRANCH_NODES:
for edge in edges activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES else:
] activate = {}
return {
"activate": { activate[self.node_id] = activate_flag
node_id: self.check_activate(state)
for node_id in under_stream_nodes return {"activate": activate}
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow. raises an exception to stop the workflow.
""" """
# Check if the node has an error edge defined # # Check if the node has an error edge defined
error_edge = self._find_error_edge() # error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes) # Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool) input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message "error": error_message
} }
if error_edge: # if error_edge:
# If an error edge exists, log a warning and continue to error node # # If an error edge exists, log a warning and continue to error node
logger.warning( # logger.warning(
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" # f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
) # )
return { # return {
"node_outputs": { # "node_outputs": {
self.node_id: node_output # self.node_id: node_output
}, # },
"error": error_message, # "error": error_message,
"error_node": self.node_id # "error_node": self.node_id
} # }
else: # else:
# If no error edge, send the error via stream writer and stop the workflow writer = get_stream_writer()
writer = get_stream_writer() writer({
writer({ "type": "node_error",
"type": "node_error", **node_output
**node_output })
}) logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") raise Exception(f"Node {self.node_id} execution failed: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode): class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -30,8 +30,8 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph. It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph() self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
@@ -115,11 +115,11 @@ class CycleGraphNode(BaseNode):
else: else:
remain_edges.append(edge) remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges # # Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [ # self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id # node for node in nodes if node.get("cycle") != self.node_id
] # ]
self.workflow_config["edges"] = remain_edges # self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
@@ -140,8 +140,8 @@ class CycleGraphNode(BaseNode):
"nodes": self.cycle_nodes, "nodes": self.cycle_nodes,
"edges": self.cycle_edges, "edges": self.cycle_edges,
}, },
subgraph=True, variable_pool=self.child_variable_pool,
variable_pool=self.child_variable_pool cycle=self.node_id
) )
self.graph = builder.build() self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id

View File

@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled. or a branch identifier string when error branching is enabled.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None self.typed_config: LLMNodeConfig | None = None
self.messages = [] self.messages = []

View File

@@ -14,8 +14,8 @@ from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode): class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -104,13 +104,15 @@ class NodeFactory:
def create_node( def create_node(
cls, cls,
node_config: dict[str, Any], node_config: dict[str, Any],
workflow_config: dict[str, Any] workflow_config: dict[str, Any],
down_stream_nodes: list[str]
) -> WorkflowNode | None: ) -> WorkflowNode | None:
"""创建节点实例 """创建节点实例
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns: Returns:
节点实例或 None对于不支持的节点类型 节点实例或 None对于不支持的节点类型
@@ -127,7 +129,7 @@ class NodeFactory:
# 创建节点实例 # 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config) return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod @classmethod
def get_supported_types(cls) -> list[str]: def get_supported_types(cls) -> list[str]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {} self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {} self.response_metadata = {}

View File

@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。 注意:变量的验证和默认值处理由 Executor 在初始化时完成。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""初始化 Start 节点 super().__init__(node_config, workflow_config, down_stream_nodes)
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置 # 解析并验证配置
self.typed_config: StartNodeConfig | None = None self.typed_config: StartNodeConfig | None = None

View File

@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode): class ToolNode(BaseNode):
"""工具节点""" """工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode): class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式) # 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
@@ -184,7 +185,7 @@ def render_template(
... ) ... )
'请分析: 这是一段文本' '请分析: 这是一段文本'
""" """
renderer = TemplateRenderer(strict=strict) renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns: Returns:
错误列表 错误列表
""" """
return _default_renderer.validate(template) return _strict_renderer.validate(template)