Merge #62 into develop from feature/20251219_myh

feat(workflow): add VariableAggregatorNode for aggregating workflow variables

* feature/20251219_myh: (1 commits)
  feat(workflow): add VariableAggregatorNode for aggregating workflow variables

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/62
This commit is contained in:
朱文辉
2025-12-25 18:44:30 +08:00
6 changed files with 162 additions and 0 deletions

View File

@@ -18,6 +18,7 @@ from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfi
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
__all__ = [ __all__ = [
# 基础类 # 基础类
@@ -36,4 +37,5 @@ __all__ = [
"AssignerNodeConfig", "AssignerNodeConfig",
"HttpRequestNodeConfig", "HttpRequestNodeConfig",
"JinjaRenderNodeConfig", "JinjaRenderNodeConfig",
"VariableAggregatorNodeConfig",
] ]

View File

@@ -25,6 +25,7 @@ class NodeType(StrEnum):
AGENT = "agent" AGENT = "agent"
ASSIGNER = "assigner" ASSIGNER = "assigner"
JINJARENDER = "jinja-render" JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator"
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -19,6 +19,7 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ WorkflowNode = Union[
HttpRequestNode, HttpRequestNode,
KnowledgeRetrievalNode, KnowledgeRetrievalNode,
JinjaRenderNode, JinjaRenderNode,
VariableAggregatorNode
] ]
@@ -55,6 +57,7 @@ class NodeFactory:
NodeType.ASSIGNER: AssignerNode, NodeType.ASSIGNER: AssignerNode,
NodeType.HTTP_REQUEST: HttpRequestNode, NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.JINJARENDER: JinjaRenderNode, NodeType.JINJARENDER: JinjaRenderNode,
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
} }
@classmethod @classmethod

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]

View File

@@ -0,0 +1,74 @@
from pydantic import Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
class VariableAggregatorNodeConfig(BaseNodeConfig):
group: bool = Field(
...,
description="输出变量是否需要分组",
)
group_names: list[str] = Field(
default_factory=lambda: ["output"],
description="各个分组的名称"
)
group_variables: list[str] | list[list[str]] = Field(
...,
description="需要被聚合的变量"
)
@field_validator("group_names", mode="before")
@classmethod
def group_names_validator(cls, v, info):
group_status = info.data.get("group")
if not group_status or not v:
return ["output"]
return v
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):
group_status = info.data.get("group")
group_names = info.data.get("group_names")
if not isinstance(v, list):
raise ValueError("group_variables must be a list")
if not group_status:
for variable in v:
if not isinstance(variable, str):
raise ValueError("When group=False, group_variables must be a list of strings")
else:
if len(group_names) != len(v):
raise ValueError("group_names and group_variables length mismatch")
for group in v:
if not isinstance(group, list):
raise ValueError("When group=True, each element of group_variables must be a list")
for variable in group:
if not isinstance(variable, str):
raise ValueError("Each element inside group_variables lists must be a string")
return v
class Config:
json_schema_extra = {
"examples": [
{
"group": True,
"group_names": [
"user_message",
"conv_var"
],
"group_variables": [
[
"{{start.test_none}}",
"{{start.test}}"
],
[
"{{conv.test_1}}",
"{{conv.test_2}}"
]
]
}
]
}

View File

@@ -0,0 +1,78 @@
import logging
import re
from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = VariableAggregatorNodeConfig(**self.config)
@staticmethod
def _get_express(variable_string: str) -> Any:
"""
Extract the variable name from a template string '{{ var }}'.
Args:
variable_string: A string containing the variable template.
Returns:
The extracted variable name, or the stripped original string if no template is found.
"""
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", variable_string).strip()
return expression
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the variable aggregation logic.
Returns:
- str: In non-group mode, returns the first non-None variable value.
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
"""
if not self.typed_config.group:
# --------------------------
# Non-group mode
# --------------------------
for variable in self.typed_config.group_variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}': {e}")
continue
if value is not None:
return value
logger.info("No variable found in non-group mode; returning empty string.")
return ""
# --------------------------
# Group mode
# --------------------------
result = {}
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
for variable in variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
continue
if value is not None:
result[group_name] = value
break
else:
result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.")
return result