From 7181d41e51de947da92a1b88bf325367ab98502b Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Thu, 25 Dec 2025 18:38:53 +0800 Subject: [PATCH] feat(workflow): add VariableAggregatorNode for aggregating workflow variables --- api/app/core/workflow/nodes/configs.py | 2 + api/app/core/workflow/nodes/enums.py | 1 + api/app/core/workflow/nodes/node_factory.py | 3 + .../nodes/variable_aggregator/__init__.py | 4 + .../nodes/variable_aggregator/config.py | 74 ++++++++++++++++++ .../nodes/variable_aggregator/node.py | 78 +++++++++++++++++++ 6 files changed, 162 insertions(+) create mode 100644 api/app/core/workflow/nodes/variable_aggregator/__init__.py create mode 100644 api/app/core/workflow/nodes/variable_aggregator/config.py create mode 100644 api/app/core/workflow/nodes/variable_aggregator/node.py diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index 3a87c589..12bb18cb 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -18,6 +18,7 @@ from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfi from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig +from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig __all__ = [ # 基础类 @@ -36,4 +37,5 @@ __all__ = [ "AssignerNodeConfig", "HttpRequestNodeConfig", "JinjaRenderNodeConfig", + "VariableAggregatorNodeConfig", ] diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 40b9a7ef..3ca59695 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -25,6 +25,7 @@ class NodeType(StrEnum): AGENT = "agent" ASSIGNER = "assigner" JINJARENDER = "jinja-render" + VAR_AGGREGATOR = "var-aggregator" class ComparisonOperator(StrEnum): diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index a6d735d0..9a5b5093 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -19,6 +19,7 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode +from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode logger = logging.getLogger(__name__) @@ -34,6 +35,7 @@ WorkflowNode = Union[ HttpRequestNode, KnowledgeRetrievalNode, JinjaRenderNode, + VariableAggregatorNode ] @@ -55,6 +57,7 @@ class NodeFactory: NodeType.ASSIGNER: AssignerNode, NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.JINJARENDER: JinjaRenderNode, + NodeType.VAR_AGGREGATOR: VariableAggregatorNode } @classmethod diff --git a/api/app/core/workflow/nodes/variable_aggregator/__init__.py b/api/app/core/workflow/nodes/variable_aggregator/__init__.py new file mode 100644 index 00000000..7bc9afa7 --- /dev/null +++ b/api/app/core/workflow/nodes/variable_aggregator/__init__.py @@ -0,0 +1,4 @@ +from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig +from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode + +__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"] \ No newline at end of file diff --git a/api/app/core/workflow/nodes/variable_aggregator/config.py b/api/app/core/workflow/nodes/variable_aggregator/config.py new file mode 100644 index 00000000..84f82487 --- /dev/null +++ b/api/app/core/workflow/nodes/variable_aggregator/config.py @@ -0,0 +1,74 @@ +from pydantic import Field, field_validator + +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class VariableAggregatorNodeConfig(BaseNodeConfig): + group: bool = Field( + ..., + description="输出变量是否需要分组", + ) + + group_names: list[str] = Field( + default_factory=lambda: ["output"], + description="各个分组的名称" + ) + + group_variables: list[str] | list[list[str]] = Field( + ..., + description="需要被聚合的变量" + ) + + @field_validator("group_names", mode="before") + @classmethod + def group_names_validator(cls, v, info): + group_status = info.data.get("group") + if not group_status or not v: + return ["output"] + return v + + @field_validator("group_variables") + @classmethod + def group_variables_validator(cls, v, info): + group_status = info.data.get("group") + group_names = info.data.get("group_names") + if not isinstance(v, list): + raise ValueError("group_variables must be a list") + + if not group_status: + for variable in v: + if not isinstance(variable, str): + raise ValueError("When group=False, group_variables must be a list of strings") + else: + if len(group_names) != len(v): + raise ValueError("group_names and group_variables length mismatch") + for group in v: + if not isinstance(group, list): + raise ValueError("When group=True, each element of group_variables must be a list") + for variable in group: + if not isinstance(variable, str): + raise ValueError("Each element inside group_variables lists must be a string") + return v + + class Config: + json_schema_extra = { + "examples": [ + { + "group": True, + "group_names": [ + "user_message", + "conv_var" + ], + "group_variables": [ + [ + "{{start.test_none}}", + "{{start.test}}" + ], + [ + "{{conv.test_1}}", + "{{conv.test_2}}" + ] + ] + } + ] + } diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py new file mode 100644 index 00000000..f53f9269 --- /dev/null +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -0,0 +1,78 @@ +import logging +import re +from typing import Any + +from app.core.workflow.nodes import WorkflowState +from app.core.workflow.nodes.base_node import BaseNode +from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig + +logger = logging.getLogger(__name__) + + +class VariableAggregatorNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = VariableAggregatorNodeConfig(**self.config) + + @staticmethod + def _get_express(variable_string: str) -> Any: + """ + Extract the variable name from a template string '{{ var }}'. + + Args: + variable_string: A string containing the variable template. + + Returns: + The extracted variable name, or the stripped original string if no template is found. + """ + pattern = r"\{\{\s*(.*?)\s*\}\}" + expression = re.sub(pattern, r"\1", variable_string).strip() + return expression + + async def execute(self, state: WorkflowState) -> Any: + """ + Execute the variable aggregation logic. + + Returns: + - str: In non-group mode, returns the first non-None variable value. + - dict: In group mode, returns a mapping of group_name -> first non-None variable value. + """ + if not self.typed_config.group: + # -------------------------- + # Non-group mode + # -------------------------- + for variable in self.typed_config.group_variables: + var_express = self._get_express(variable) + try: + value = self.get_variable(var_express, state) + except Exception as e: + logger.warning(f"Failed to get variable '{var_express}': {e}") + continue + + if value is not None: + return value + + logger.info("No variable found in non-group mode; returning empty string.") + return "" + + # -------------------------- + # Group mode + # -------------------------- + result = {} + for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables): + for variable in variables: + var_express = self._get_express(variable) + try: + value = self.get_variable(var_express, state) + except Exception as e: + logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}") + continue + + if value is not None: + result[group_name] = value + break + else: + result[group_name] = "" + logger.info(f"No variable found for group '{group_name}'; set empty string.") + + return result