Merge #62 into develop from feature/20251219_myh

feat(workflow): add VariableAggregatorNode for aggregating workflow variables

* feature/20251219_myh: (1 commits)
  feat(workflow): add VariableAggregatorNode for aggregating workflow variables

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/62
This commit is contained in:
朱文辉
2025-12-25 18:44:30 +08:00
6 changed files with 162 additions and 0 deletions

View File

@@ -18,6 +18,7 @@ from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfi
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.start.config import StartNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
__all__ = [
# 基础类
@@ -36,4 +37,5 @@ __all__ = [
"AssignerNodeConfig",
"HttpRequestNodeConfig",
"JinjaRenderNodeConfig",
"VariableAggregatorNodeConfig",
]

View File

@@ -25,6 +25,7 @@ class NodeType(StrEnum):
AGENT = "agent"
ASSIGNER = "assigner"
JINJARENDER = "jinja-render"
VAR_AGGREGATOR = "var-aggregator"
class ComparisonOperator(StrEnum):

View File

@@ -19,6 +19,7 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
logger = logging.getLogger(__name__)
@@ -34,6 +35,7 @@ WorkflowNode = Union[
HttpRequestNode,
KnowledgeRetrievalNode,
JinjaRenderNode,
VariableAggregatorNode
]
@@ -55,6 +57,7 @@ class NodeFactory:
NodeType.ASSIGNER: AssignerNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.JINJARENDER: JinjaRenderNode,
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
}
@classmethod

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]

View File

@@ -0,0 +1,74 @@
from pydantic import Field, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
class VariableAggregatorNodeConfig(BaseNodeConfig):
group: bool = Field(
...,
description="输出变量是否需要分组",
)
group_names: list[str] = Field(
default_factory=lambda: ["output"],
description="各个分组的名称"
)
group_variables: list[str] | list[list[str]] = Field(
...,
description="需要被聚合的变量"
)
@field_validator("group_names", mode="before")
@classmethod
def group_names_validator(cls, v, info):
group_status = info.data.get("group")
if not group_status or not v:
return ["output"]
return v
@field_validator("group_variables")
@classmethod
def group_variables_validator(cls, v, info):
group_status = info.data.get("group")
group_names = info.data.get("group_names")
if not isinstance(v, list):
raise ValueError("group_variables must be a list")
if not group_status:
for variable in v:
if not isinstance(variable, str):
raise ValueError("When group=False, group_variables must be a list of strings")
else:
if len(group_names) != len(v):
raise ValueError("group_names and group_variables length mismatch")
for group in v:
if not isinstance(group, list):
raise ValueError("When group=True, each element of group_variables must be a list")
for variable in group:
if not isinstance(variable, str):
raise ValueError("Each element inside group_variables lists must be a string")
return v
class Config:
json_schema_extra = {
"examples": [
{
"group": True,
"group_names": [
"user_message",
"conv_var"
],
"group_variables": [
[
"{{start.test_none}}",
"{{start.test}}"
],
[
"{{conv.test_1}}",
"{{conv.test_2}}"
]
]
}
]
}

View File

@@ -0,0 +1,78 @@
import logging
import re
from typing import Any
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
logger = logging.getLogger(__name__)
class VariableAggregatorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = VariableAggregatorNodeConfig(**self.config)
@staticmethod
def _get_express(variable_string: str) -> Any:
"""
Extract the variable name from a template string '{{ var }}'.
Args:
variable_string: A string containing the variable template.
Returns:
The extracted variable name, or the stripped original string if no template is found.
"""
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", variable_string).strip()
return expression
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the variable aggregation logic.
Returns:
- str: In non-group mode, returns the first non-None variable value.
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
"""
if not self.typed_config.group:
# --------------------------
# Non-group mode
# --------------------------
for variable in self.typed_config.group_variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}': {e}")
continue
if value is not None:
return value
logger.info("No variable found in non-group mode; returning empty string.")
return ""
# --------------------------
# Group mode
# --------------------------
result = {}
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
for variable in variables:
var_express = self._get_express(variable)
try:
value = self.get_variable(var_express, state)
except Exception as e:
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
continue
if value is not None:
result[group_name] = value
break
else:
result[group_name] = ""
logger.info(f"No variable found for group '{group_name}'; set empty string.")
return result