perf(workflow): Optimize downstream node activation method to reduce performance overhead
This commit is contained in:
@@ -1099,7 +1099,6 @@ class ExtractionOrchestrator:
|
|||||||
metadata=chunk.metadata,
|
metadata=chunk.metadata,
|
||||||
)
|
)
|
||||||
chunk_nodes.append(chunk_node)
|
chunk_nodes.append(chunk_node)
|
||||||
logger.error(f"chunk file: {chunk.files}")
|
|
||||||
|
|
||||||
for p, file_type in chunk.files:
|
for p, file_type in chunk.files:
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import re
|
|||||||
import uuid
|
import uuid
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Any, Iterable
|
from typing import Any, Iterable, Callable
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import InMemorySaver
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
from langgraph.graph import START, END
|
from langgraph.graph import START, END
|
||||||
@@ -41,39 +41,31 @@ class GraphBuilder:
|
|||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
subgraph: bool = False,
|
cycle: str = '',
|
||||||
variable_pool: VariablePool | None = None
|
variable_pool: VariablePool | None = None
|
||||||
):
|
):
|
||||||
self.workflow_config = workflow_config
|
self.workflow_config = workflow_config
|
||||||
|
|
||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.cycle = cycle
|
||||||
|
|
||||||
self.start_node_id: str | None = None
|
self.start_node_id: str | None = None
|
||||||
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map: dict[str, dict] = {}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_activation_dep = lru_cache(
|
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep
|
||||||
maxsize=len(self.nodes) * 2
|
|
||||||
)(self._find_upstream_activation_dep)
|
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
self.variable_pool = VariablePool()
|
self.variable_pool = VariablePool()
|
||||||
|
|
||||||
self.graph: StateGraph | None = None
|
self.graph: StateGraph | None = None
|
||||||
|
self.nodes: list = []
|
||||||
|
self.edges: list = []
|
||||||
self.reachable_nodes: set[str] | None = None
|
self.reachable_nodes: set[str] | None = None
|
||||||
self.end_nodes: list[dict] = []
|
self.end_nodes: list[dict] = []
|
||||||
self._reverse_adj: dict[str, list[dict]] | None = defaultdict(list)
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
self._adj: dict[str, list[str]] | None = defaultdict(list)
|
self._adj: dict[str, list[str]] = defaultdict(list)
|
||||||
|
|
||||||
@property
|
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("nodes", [])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def edges(self) -> list[dict[str, Any]]:
|
|
||||||
return self.workflow_config.get("edges", [])
|
|
||||||
|
|
||||||
def get_node_type(self, node_id: str) -> str:
|
def get_node_type(self, node_id: str) -> str:
|
||||||
"""Retrieve the type of node given its ID.
|
"""Retrieve the type of node given its ID.
|
||||||
@@ -294,22 +286,13 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
node_type = node.get("type")
|
node_type = node.get("type")
|
||||||
if node_type == NodeType.NOTES:
|
|
||||||
continue
|
|
||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
cycle_node = node.get("cycle")
|
if node_id not in self.reachable_nodes:
|
||||||
if cycle_node:
|
continue
|
||||||
# Nodes within a loop subgraph are constructed by CycleGraphNode
|
|
||||||
if not self.subgraph:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Record start and end node IDs
|
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
|
||||||
self.start_node_id = node_id
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id])
|
||||||
|
|
||||||
if node_type in BRANCH_NODES:
|
if node_type in BRANCH_NODES:
|
||||||
|
|
||||||
@@ -503,21 +486,46 @@ class GraphBuilder:
|
|||||||
return
|
return
|
||||||
|
|
||||||
def build(self) -> CompiledStateGraph:
|
def build(self) -> CompiledStateGraph:
|
||||||
self.graph = StateGraph(WorkflowState)
|
nodes = self.workflow_config.get("nodes", [])
|
||||||
self.add_nodes()
|
edges = self.workflow_config.get("edges", [])
|
||||||
|
|
||||||
|
for node in nodes:
|
||||||
|
if (node.get("cycle") or '') == self.cycle:
|
||||||
|
node_type = node.get("type")
|
||||||
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
|
self.start_node_id = node.get("id")
|
||||||
|
elif node_type == NodeType.NOTES:
|
||||||
|
continue
|
||||||
|
self.nodes.append(node)
|
||||||
|
self.node_map[node.get("id")] = node
|
||||||
|
|
||||||
|
for edge in edges:
|
||||||
|
source_in = edge.get("source") in self.node_map
|
||||||
|
target_in = edge.get("target") in self.node_map
|
||||||
|
if source_in ^ target_in:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cycle node is connected to external node, "
|
||||||
|
f"source: {edge.get('source')}, target: {edge.get('target')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if source_in and target_in:
|
||||||
|
self.edges.append(edge)
|
||||||
|
|
||||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
self.end_nodes = [
|
self.end_nodes = [
|
||||||
node
|
node
|
||||||
for node in self.nodes
|
for node in self.nodes
|
||||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
]
|
]
|
||||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
|
||||||
self._adj: dict[str, list[str]] = defaultdict(list)
|
|
||||||
self._build_adj()
|
self._build_adj()
|
||||||
|
self._find_upstream_activation_dep: Callable = lru_cache(
|
||||||
|
maxsize=len(self.nodes)*2
|
||||||
|
)(self._find_upstream_activation_dep)
|
||||||
|
|
||||||
|
self.graph = StateGraph(WorkflowState)
|
||||||
|
self.add_nodes()
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
|
||||||
|
|
||||||
self._analyze_end_node_output()
|
self._analyze_end_node_output()
|
||||||
checkpointer = InMemorySaver()
|
checkpointer = InMemorySaver()
|
||||||
return self.graph.compile(checkpointer=checkpointer)
|
return self.graph.compile(checkpointer=checkpointer)
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class AssignerNode(BaseNode):
|
class AssignerNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.variable_updater = True
|
self.variable_updater = True
|
||||||
self.typed_config: AssignerNodeConfig | None = None
|
self.typed_config: AssignerNodeConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class BaseNode(ABC):
|
|||||||
All node types should inherit from this class and implement the `execute` method.
|
All node types should inherit from this class and implement the `execute` method.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""Initialize the node.
|
"""Initialize the node.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -41,6 +41,7 @@ class BaseNode(ABC):
|
|||||||
self.node_type = node_config["type"]
|
self.node_type = node_config["type"]
|
||||||
self.cycle = node_config.get("cycle")
|
self.cycle = node_config.get("cycle")
|
||||||
self.node_name = node_config.get("name", self.node_id)
|
self.node_name = node_config.get("name", self.node_id)
|
||||||
|
self.down_stream_nodes = down_stream_nodes
|
||||||
# 使用 or 运算符处理 None 值
|
# 使用 or 运算符处理 None 值
|
||||||
self.config = node_config.get("config") or {}
|
self.config = node_config.get("config") or {}
|
||||||
self.error_handling = node_config.get("error_handling") or {}
|
self.error_handling = node_config.get("error_handling") or {}
|
||||||
@@ -93,18 +94,16 @@ class BaseNode(ABC):
|
|||||||
dict: A dict with a single key 'activate', mapping node IDs to
|
dict: A dict with a single key 'activate', mapping node IDs to
|
||||||
their activation status (True/False).
|
their activation status (True/False).
|
||||||
"""
|
"""
|
||||||
edges = self.workflow_config.get("edges")
|
activate_flag = self.check_activate(state)
|
||||||
under_stream_nodes = [
|
|
||||||
edge.get("target")
|
if self.node_type not in BRANCH_NODES:
|
||||||
for edge in edges
|
activate = {node_id: activate_flag for node_id in self.down_stream_nodes}
|
||||||
if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
|
else:
|
||||||
]
|
activate = {}
|
||||||
return {
|
|
||||||
"activate": {
|
activate[self.node_id] = activate_flag
|
||||||
node_id: self.check_activate(state)
|
|
||||||
for node_id in under_stream_nodes
|
return {"activate": activate}
|
||||||
} | {self.node_id: self.check_activate(state)}
|
|
||||||
}
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
@@ -428,8 +427,8 @@ class BaseNode(ABC):
|
|||||||
when an error edge exists. If no error edge exists, this method
|
when an error edge exists. If no error edge exists, this method
|
||||||
raises an exception to stop the workflow.
|
raises an exception to stop the workflow.
|
||||||
"""
|
"""
|
||||||
# Check if the node has an error edge defined
|
# # Check if the node has an error edge defined
|
||||||
error_edge = self._find_error_edge()
|
# error_edge = self._find_error_edge()
|
||||||
|
|
||||||
# Extract input data (for logging or audit purposes)
|
# Extract input data (for logging or audit purposes)
|
||||||
input_data = self._extract_input(state, variable_pool)
|
input_data = self._extract_input(state, variable_pool)
|
||||||
@@ -447,27 +446,26 @@ class BaseNode(ABC):
|
|||||||
"error": error_message
|
"error": error_message
|
||||||
}
|
}
|
||||||
|
|
||||||
if error_edge:
|
# if error_edge:
|
||||||
# If an error edge exists, log a warning and continue to error node
|
# # If an error edge exists, log a warning and continue to error node
|
||||||
logger.warning(
|
# logger.warning(
|
||||||
f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
|
||||||
)
|
# )
|
||||||
return {
|
# return {
|
||||||
"node_outputs": {
|
# "node_outputs": {
|
||||||
self.node_id: node_output
|
# self.node_id: node_output
|
||||||
},
|
# },
|
||||||
"error": error_message,
|
# "error": error_message,
|
||||||
"error_node": self.node_id
|
# "error_node": self.node_id
|
||||||
}
|
# }
|
||||||
else:
|
# else:
|
||||||
# If no error edge, send the error via stream writer and stop the workflow
|
writer = get_stream_writer()
|
||||||
writer = get_stream_writer()
|
writer({
|
||||||
writer({
|
"type": "node_error",
|
||||||
"type": "node_error",
|
**node_output
|
||||||
**node_output
|
})
|
||||||
})
|
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
||||||
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
|
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
||||||
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
|
|
||||||
|
|
||||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""Extracts the input data for this node (used for logging or audit).
|
"""Extracts the input data for this node (used for logging or audit).
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ console.log(result)
|
|||||||
|
|
||||||
|
|
||||||
class CodeNode(BaseNode):
|
class CodeNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: CodeNodeConfig | None = None
|
self.typed_config: CodeNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -30,8 +30,8 @@ class CycleGraphNode(BaseNode):
|
|||||||
It acts as a container and execution controller for a subgraph.
|
It acts as a container and execution controller for a subgraph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
|
||||||
self.start_node_id = None # ID of the start node within the cycle
|
self.start_node_id = None # ID of the start node within the cycle
|
||||||
|
|
||||||
@@ -115,11 +115,11 @@ class CycleGraphNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
remain_edges.append(edge)
|
remain_edges.append(edge)
|
||||||
|
|
||||||
# Update workflow_config by removing cycle nodes and internal edges
|
# # Update workflow_config by removing cycle nodes and internal edges
|
||||||
self.workflow_config["nodes"] = [
|
# self.workflow_config["nodes"] = [
|
||||||
node for node in nodes if node.get("cycle") != self.node_id
|
# node for node in nodes if node.get("cycle") != self.node_id
|
||||||
]
|
# ]
|
||||||
self.workflow_config["edges"] = remain_edges
|
# self.workflow_config["edges"] = remain_edges
|
||||||
|
|
||||||
return cycle_nodes, cycle_edges
|
return cycle_nodes, cycle_edges
|
||||||
|
|
||||||
@@ -140,8 +140,8 @@ class CycleGraphNode(BaseNode):
|
|||||||
"nodes": self.cycle_nodes,
|
"nodes": self.cycle_nodes,
|
||||||
"edges": self.cycle_edges,
|
"edges": self.cycle_edges,
|
||||||
},
|
},
|
||||||
subgraph=True,
|
variable_pool=self.child_variable_pool,
|
||||||
variable_pool=self.child_variable_pool
|
cycle=self.node_id
|
||||||
)
|
)
|
||||||
self.graph = builder.build()
|
self.graph = builder.build()
|
||||||
self.start_node_id = builder.start_node_id
|
self.start_node_id = builder.start_node_id
|
||||||
|
|||||||
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
|
|||||||
or a branch identifier string when error branching is enabled.
|
or a branch identifier string when error branching is enabled.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: HttpRequestNodeConfig | None = None
|
self.typed_config: HttpRequestNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class IfElseNode(BaseNode):
|
class IfElseNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: IfElseNodeConfig | None = None
|
self.typed_config: IfElseNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: JinjaRenderNodeConfig | None = None
|
self.typed_config: JinjaRenderNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class KnowledgeRetrievalNode(BaseNode):
|
class KnowledgeRetrievalNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
self.vector_service: ElasticSearchVector | None = None
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
|
|||||||
- ai/assistant: AI 消息(AIMessage)
|
- ai/assistant: AI 消息(AIMessage)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: LLMNodeConfig | None = None
|
self.typed_config: LLMNodeConfig | None = None
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
|
|||||||
@@ -14,8 +14,8 @@ from app.tasks import write_message_task
|
|||||||
|
|
||||||
|
|
||||||
class MemoryReadNode(BaseNode):
|
class MemoryReadNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryReadNodeConfig | None = None
|
self.typed_config: MemoryReadNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
@@ -45,8 +45,8 @@ class MemoryReadNode(BaseNode):
|
|||||||
|
|
||||||
|
|
||||||
class MemoryWriteNode(BaseNode):
|
class MemoryWriteNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: MemoryWriteNodeConfig | None = None
|
self.typed_config: MemoryWriteNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -104,13 +104,15 @@ class NodeFactory:
|
|||||||
def create_node(
|
def create_node(
|
||||||
cls,
|
cls,
|
||||||
node_config: dict[str, Any],
|
node_config: dict[str, Any],
|
||||||
workflow_config: dict[str, Any]
|
workflow_config: dict[str, Any],
|
||||||
|
down_stream_nodes: list[str]
|
||||||
) -> WorkflowNode | None:
|
) -> WorkflowNode | None:
|
||||||
"""创建节点实例
|
"""创建节点实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_config: 节点配置
|
node_config: 节点配置
|
||||||
workflow_config: 工作流配置
|
workflow_config: 工作流配置
|
||||||
|
down_stream_nodes: 下游节点
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点实例或 None(对于不支持的节点类型)
|
节点实例或 None(对于不支持的节点类型)
|
||||||
@@ -127,7 +129,7 @@ class NodeFactory:
|
|||||||
|
|
||||||
# 创建节点实例
|
# 创建节点实例
|
||||||
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
|
||||||
return node_class(node_config, workflow_config)
|
return node_class(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_supported_types(cls) -> list[str]:
|
def get_supported_types(cls) -> list[str]:
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ParameterExtractorNode(BaseNode):
|
class ParameterExtractorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ParameterExtractorNodeConfig | None = None
|
self.typed_config: ParameterExtractorNodeConfig | None = None
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
|||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: QuestionClassifierNodeConfig | None = None
|
self.typed_config: QuestionClassifierNodeConfig | None = None
|
||||||
self.category_to_case_map = {}
|
self.category_to_case_map = {}
|
||||||
self.response_metadata = {}
|
self.response_metadata = {}
|
||||||
|
|||||||
@@ -27,14 +27,8 @@ class StartNode(BaseNode):
|
|||||||
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
注意:变量的验证和默认值处理由 Executor 在初始化时完成。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
"""初始化 Start 节点
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
|
|
||||||
Args:
|
|
||||||
node_config: 节点配置
|
|
||||||
workflow_config: 工作流配置
|
|
||||||
"""
|
|
||||||
super().__init__(node_config, workflow_config)
|
|
||||||
|
|
||||||
# 解析并验证配置
|
# 解析并验证配置
|
||||||
self.typed_config: StartNodeConfig | None = None
|
self.typed_config: StartNodeConfig | None = None
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ TEMPLATE_PATTERN = re.compile(r"\{\{.*?}}")
|
|||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: ToolNodeConfig | None = None
|
self.typed_config: ToolNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VariableAggregatorNode(BaseNode):
|
class VariableAggregatorNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: VariableAggregatorNodeConfig | None = None
|
self.typed_config: VariableAggregatorNodeConfig | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
|
|||||||
@@ -153,7 +153,8 @@ class TemplateRenderer:
|
|||||||
|
|
||||||
|
|
||||||
# 全局渲染器实例(严格模式)
|
# 全局渲染器实例(严格模式)
|
||||||
_default_renderer = TemplateRenderer(strict=True)
|
_strict_renderer = TemplateRenderer(strict=True)
|
||||||
|
_lenient_renderer = TemplateRenderer(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def render_template(
|
def render_template(
|
||||||
@@ -184,7 +185,7 @@ def render_template(
|
|||||||
... )
|
... )
|
||||||
'请分析: 这是一段文本'
|
'请分析: 这是一段文本'
|
||||||
"""
|
"""
|
||||||
renderer = TemplateRenderer(strict=strict)
|
renderer = _strict_renderer if strict else _lenient_renderer
|
||||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
@@ -197,4 +198,4 @@ def validate_template(template: str) -> list[str]:
|
|||||||
Returns:
|
Returns:
|
||||||
错误列表
|
错误列表
|
||||||
"""
|
"""
|
||||||
return _default_renderer.validate(template)
|
return _strict_renderer.validate(template)
|
||||||
|
|||||||
Reference in New Issue
Block a user