perf(workflow): Optimize downstream node activation method to reduce performance overhead

This commit is contained in:
Eternity
2026-03-25 17:03:12 +08:00
parent 45eef12842
commit 85daf576e9
19 changed files with 122 additions and 120 deletions

View File

@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
metadata=chunk.metadata,
)
chunk_nodes.append(chunk_node)
logger.error(f"chunk file: {chunk.files}")
for p, file_type in chunk.files:

View File

@@ -7,7 +7,7 @@ import re
import uuid
from collections import defaultdict
from functools import lru_cache
from typing import Any, Iterable
from typing import Any, Iterable, Callable
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END
@@ -41,39 +41,31 @@ class GraphBuilder:
self,
workflow_config: dict[str, Any],
stream: bool = False,
subgraph: bool = False,
cycle: str = '',
variable_pool: VariablePool | None = None
):
self.workflow_config = workflow_config
self.stream = stream
self.subgraph = subgraph
self.cycle = cycle
self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes}
self.node_map: dict[str, dict] = {}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_activation_dep)
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
if variable_pool:
self.variable_pool = variable_pool
else:
self.variable_pool = VariablePool()
self.graph: StateGraph | None = None
self.nodes: list = []
self.edges: list = []
self.reachable_nodes: set[str] | None = None
self.end_nodes: list[dict] = []
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
self._adj: dict[str, list[str]] | None = defaultdict(list)
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID.
@@ -294,22 +286,13 @@ class GraphBuilder:
"""
for node in self.nodes:
node_type = node.get("type")
if node_type == NodeType.NOTES:
continue
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
if node_id not in self.reachable_nodes:
continue
# Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config)
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
if node_type in BRANCH_NODES:
@@ -503,21 +486,46 @@ class GraphBuilder:
return
def build(self) -> CompiledStateGraph:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list)
self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._analyze_end_node_output()
checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer)

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
"""Initialize the node.
Args:
@@ -41,6 +41,7 @@ class BaseNode(ABC):
self.node_type = node_config["type"]
self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {}
@@ -93,18 +94,16 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False).
"""
edges = self.workflow_config.get("edges")
under_stream_nodes = [
edge.get("target")
for edge in edges
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
]
return {
"activate": {
node_id: self.check_activate(state)
for node_id in under_stream_nodes
} | {self.node_id: self.check_activate(state)}
}
activate_flag = self.check_activate(state)
if self.node_type not in BRANCH_NODES:
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
else:
activate = {}
activate[self.node_id] = activate_flag
return {"activate": activate}
@abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -428,8 +427,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow.
"""
# Check if the node has an error edge defined
error_edge = self._find_error_edge()
# # Check if the node has an error edge defined
# error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool)
@@ -447,27 +446,26 @@ class BaseNode(ABC):
"error": error_message
}
if error_edge:
# If an error edge exists, log a warning and continue to error node
logger.warning(
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
)
return {
"node_outputs": {
self.node_id: node_output
},
"error": error_message,
"error_node": self.node_id
}
else:
# If no error edge, send the error via stream writer and stop the workflow
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
# if error_edge:
# # If an error edge exists, log a warning and continue to error node
# logger.warning(
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
# )
# return {
# "node_outputs": {
# self.node_id: node_output
# },
# "error": error_message,
# "error_node": self.node_id
# }
# else:
writer = get_stream_writer()
writer({
"type": "node_error",
**node_output
})
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit).

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -30,8 +30,8 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.start_node_id = None # ID of the start node within the cycle
@@ -115,11 +115,11 @@ class CycleGraphNode(BaseNode):
else:
remain_edges.append(edge)
# Update workflow_config by removing cycle nodes and internal edges
self.workflow_config["nodes"] = [
node for node in nodes if node.get("cycle") != self.node_id
]
self.workflow_config["edges"] = remain_edges
# # Update workflow_config by removing cycle nodes and internal edges
# self.workflow_config["nodes"] = [
# node for node in nodes if node.get("cycle") != self.node_id
# ]
# self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges
@@ -140,8 +140,8 @@ class CycleGraphNode(BaseNode):
"nodes": self.cycle_nodes,
"edges": self.cycle_edges,
},
subgraph=True,
variable_pool=self.child_variable_pool
variable_pool=self.child_variable_pool,
cycle=self.node_id
)
self.graph = builder.build()
self.start_node_id = builder.start_node_id

View File

@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled.
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: LLMNodeConfig | None = None
self.messages = []

View File

@@ -14,8 +14,8 @@ from app.tasks import write_message_task
class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -104,13 +104,15 @@ class NodeFactory:
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
workflow_config: dict[str, Any],
down_stream_nodes: list[str]
) -> WorkflowNode | None:
"""创建节点实例
Args:
node_config: 节点配置
workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns:
节点实例或 None对于不支持的节点类型
@@ -127,7 +129,7 @@ class NodeFactory:
# 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
return node_class(node_config, workflow_config, down_stream_nodes)
@classmethod
def get_supported_types(cls) -> list[str]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {}
self.response_metadata = {}

View File

@@ -27,14 +27,8 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
# 解析并验证配置
self.typed_config: StartNodeConfig | None = None

View File

@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
class ToolNode(BaseNode):
"""工具节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: ToolNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: VariableAggregatorNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]:

View File

@@ -153,7 +153,8 @@ class TemplateRenderer:
# 全局渲染器实例(严格模式)
_default_renderer = TemplateRenderer(strict=True)
_strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
@@ -184,7 +185,7 @@ def render_template(
... )
'请分析: 这是一段文本'
"""
renderer = TemplateRenderer(strict=strict)
renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
Returns:
错误列表
"""
return _default_renderer.validate(template)
return _strict_renderer.validate(template)