Merge #62 into develop from feature/20251219_myh
feat(workflow): add VariableAggregatorNode for aggregating workflow variables * feature/20251219_myh: (1 commits) feat(workflow): add VariableAggregatorNode for aggregating workflow variables Signed-off-by: Eternity <1533512157@qq.com> Commented-by: Eternity <1533512157@qq.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/62
This commit is contained in:
@@ -18,6 +18,7 @@ from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfi
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -36,4 +37,5 @@ __all__ = [
|
||||
"AssignerNodeConfig",
|
||||
"HttpRequestNodeConfig",
|
||||
"JinjaRenderNodeConfig",
|
||||
"VariableAggregatorNodeConfig",
|
||||
]
|
||||
|
||||
@@ -25,6 +25,7 @@ class NodeType(StrEnum):
|
||||
AGENT = "agent"
|
||||
ASSIGNER = "assigner"
|
||||
JINJARENDER = "jinja-render"
|
||||
VAR_AGGREGATOR = "var-aggregator"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
@@ -19,6 +19,7 @@ from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +35,7 @@ WorkflowNode = Union[
|
||||
HttpRequestNode,
|
||||
KnowledgeRetrievalNode,
|
||||
JinjaRenderNode,
|
||||
VariableAggregatorNode
|
||||
]
|
||||
|
||||
|
||||
@@ -55,6 +57,7 @@ class NodeFactory:
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.JINJARENDER: JinjaRenderNode,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
|
||||
|
||||
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]
|
||||
74
api/app/core/workflow/nodes/variable_aggregator/config.py
Normal file
74
api/app/core/workflow/nodes/variable_aggregator/config.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class VariableAggregatorNodeConfig(BaseNodeConfig):
|
||||
group: bool = Field(
|
||||
...,
|
||||
description="输出变量是否需要分组",
|
||||
)
|
||||
|
||||
group_names: list[str] = Field(
|
||||
default_factory=lambda: ["output"],
|
||||
description="各个分组的名称"
|
||||
)
|
||||
|
||||
group_variables: list[str] | list[list[str]] = Field(
|
||||
...,
|
||||
description="需要被聚合的变量"
|
||||
)
|
||||
|
||||
@field_validator("group_names", mode="before")
|
||||
@classmethod
|
||||
def group_names_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
if not group_status or not v:
|
||||
return ["output"]
|
||||
return v
|
||||
|
||||
@field_validator("group_variables")
|
||||
@classmethod
|
||||
def group_variables_validator(cls, v, info):
|
||||
group_status = info.data.get("group")
|
||||
group_names = info.data.get("group_names")
|
||||
if not isinstance(v, list):
|
||||
raise ValueError("group_variables must be a list")
|
||||
|
||||
if not group_status:
|
||||
for variable in v:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("When group=False, group_variables must be a list of strings")
|
||||
else:
|
||||
if len(group_names) != len(v):
|
||||
raise ValueError("group_names and group_variables length mismatch")
|
||||
for group in v:
|
||||
if not isinstance(group, list):
|
||||
raise ValueError("When group=True, each element of group_variables must be a list")
|
||||
for variable in group:
|
||||
if not isinstance(variable, str):
|
||||
raise ValueError("Each element inside group_variables lists must be a string")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"group": True,
|
||||
"group_names": [
|
||||
"user_message",
|
||||
"conv_var"
|
||||
],
|
||||
"group_variables": [
|
||||
[
|
||||
"{{start.test_none}}",
|
||||
"{{start.test}}"
|
||||
],
|
||||
[
|
||||
"{{conv.test_1}}",
|
||||
"{{conv.test_2}}"
|
||||
]
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
78
api/app/core/workflow/nodes/variable_aggregator/node.py
Normal file
78
api/app/core/workflow/nodes/variable_aggregator/node.py
Normal file
@@ -0,0 +1,78 @@
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = VariableAggregatorNodeConfig(**self.config)
|
||||
|
||||
@staticmethod
|
||||
def _get_express(variable_string: str) -> Any:
|
||||
"""
|
||||
Extract the variable name from a template string '{{ var }}'.
|
||||
|
||||
Args:
|
||||
variable_string: A string containing the variable template.
|
||||
|
||||
Returns:
|
||||
The extracted variable name, or the stripped original string if no template is found.
|
||||
"""
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", variable_string).strip()
|
||||
return expression
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the variable aggregation logic.
|
||||
|
||||
Returns:
|
||||
- str: In non-group mode, returns the first non-None variable value.
|
||||
- dict: In group mode, returns a mapping of group_name -> first non-None variable value.
|
||||
"""
|
||||
if not self.typed_config.group:
|
||||
# --------------------------
|
||||
# Non-group mode
|
||||
# --------------------------
|
||||
for variable in self.typed_config.group_variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}': {e}")
|
||||
continue
|
||||
|
||||
if value is not None:
|
||||
return value
|
||||
|
||||
logger.info("No variable found in non-group mode; returning empty string.")
|
||||
return ""
|
||||
|
||||
# --------------------------
|
||||
# Group mode
|
||||
# --------------------------
|
||||
result = {}
|
||||
for group_name, variables in zip(self.typed_config.group_names, self.typed_config.group_variables):
|
||||
for variable in variables:
|
||||
var_express = self._get_express(variable)
|
||||
try:
|
||||
value = self.get_variable(var_express, state)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get variable '{var_express}' in group '{group_name}': {e}")
|
||||
continue
|
||||
|
||||
if value is not None:
|
||||
result[group_name] = value
|
||||
break
|
||||
else:
|
||||
result[group_name] = ""
|
||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||
|
||||
return result
|
||||
Reference in New Issue
Block a user