Merge #34 into develop from feature/20251219_myh

feat(workflow): add assigner node and fix circular imports with minor code style cleanup

* feature/20251219_myh: (7 commits)
  style(service): workflow
  style(workflow): remove unnecessary indentation
  revert(workflow): read conversation variables from database instead of API input
  feat(workflow): add assigner node and fix circular imports with minor code style cleanup
  fix(workflow): fix incorrect list append/pop logic in assigner node
  fix(workflow): fix incorrect list extend logic in assigner node
  fix(workflow): fix incorrect list append logic in assigner node

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/34
This commit is contained in:
朱文辉
2025-12-23 17:06:43 +08:00
16 changed files with 466 additions and 181 deletions

View File

@@ -5,9 +5,11 @@
"""
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
@@ -23,5 +25,7 @@ __all__ = [
"StartNode",
"EndNode",
"NodeFactory",
"WorkflowNode"
"WorkflowNode",
# "KnowledgeRetrievalNode",
"AssignerNode",
]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.assigner.node import AssignerNode
__all__ = ["AssignerNode", "AssignerNodeConfig"]

View File

@@ -0,0 +1,21 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import AssignmentOperator
class AssignerNodeConfig(BaseNodeConfig):
variable_selector: str | list[str] = Field(
...,
description="Variables to be assigned",
)
operation: AssignmentOperator = Field(
...,
description="Operator to assign",
)
value: str | list[str] = Field(
...,
description="Values to assign",
)

View File

@@ -0,0 +1,80 @@
import logging
from typing import Any
from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = AssignerNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the assignment operation defined by this node.
Args:
state: The current workflow state, including conversation variables,
node outputs, and system variables.
Returns:
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
pool = VariablePool(state)
# Get the target variable selector (e.g., "conv.test")
variable_selector = self.typed_config.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
# Get the value or expression to assign
value = self.typed_config.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
pool, variable_selector, value
)
# Execute the configured assignment operation
match self.typed_config.operation:
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
case AssignmentOperator.ADD:
operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
case AssignmentOperator.APPEND:
operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")

View File

@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [
# 基础类
@@ -28,4 +30,6 @@ __all__ = [
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
]

View File

@@ -1,5 +1,14 @@
from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum):
START = "start"
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"
ASSIGNER = "assigner"
class ComparisonOperator(StrEnum):
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
class LogicOperator(StrEnum):
AND = "and"
OR = "or"
class AssignmentOperator(StrEnum):
ASSIGN = "assign"
CLEAR = "clear"
ADD = "add" # +=
SUBTRACT = "subtract" # -=
MULTIPLY = "multiply" # *=
DIVIDE = "divide" # /=
APPEND = "append"
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail

View File

@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.models import RedBearLLM, RedBearModelConfig
from app.db import get_db_context
from app.models import ModelType
from app.services.model_service import ModelConfigService
from app.core.exceptions import BusinessException
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
base_url=api_base,
extra_params=extra_params
),
type=model_type
type=ModelType(model_type)
)
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")

View File

@@ -7,6 +7,7 @@
import logging
from typing import Any, Union
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.assigner import AssignerNode
logger = logging.getLogger(__name__)
@@ -26,6 +28,8 @@ WorkflowNode = Union[
IfElseNode,
AgentNode,
TransformNode,
AssignerNode,
# KnowledgeRetrievalNode,
]
@@ -42,7 +46,9 @@ class NodeFactory:
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode
NodeType.IF_ELSE: IfElseNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
}
@classmethod
@@ -82,10 +88,6 @@ class NodeFactory:
"""
node_type = node_config.get("type")
# 跳过条件节点(由 LangGraph 处理)
if node_type == "condition":
return None
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:

View File

@@ -0,0 +1,146 @@
from abc import ABC
from typing import Union, Type
from app.core.workflow.variable_pool import VariablePool
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool
self.left_selector = left_selector
self.right = right
self.type_limit: type[str, int, dict, list] = None
def check(self, no_right=False):
left = self.pool.get(self.left_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
class StringOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = str
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, '')
class NumberOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = (float, int)
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, 0)
def add(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin + self.right)
def subtract(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin - self.right)
def multiply(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin * self.right)
def divide(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin / self.right)
class BooleanOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = bool
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, False)
class ArrayOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = list
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, list())
def append(self) -> None:
self.check(no_right=True)
# TODOrequire type limit in list
origin = self.pool.get(self.left_selector)
origin.append(self.right)
self.pool.set(self.left_selector, origin)
def extend(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.extend(self.right)
self.pool.set(self.left_selector, origin)
def remove_last(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.pop()
self.pool.set(self.left_selector, origin)
def remove_first(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
origin.pop(0)
self.pool.set(self.left_selector, origin)
class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = object
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, dict())
AssignmentOperatorInstance = Union[
StringOperator,
NumberOperator,
BooleanOperator,
ArrayOperator,
ObjectOperator
]
AssignmentOperatorType = Type[AssignmentOperatorInstance]

View File

@@ -10,7 +10,10 @@
"""
import logging
from typing import Any
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from app.core.workflow.nodes import WorkflowState
logger = logging.getLogger(__name__)
@@ -82,7 +85,7 @@ class VariablePool:
>>> pool.set(["conv", "user_name"], "张三")
"""
def __init__(self, state: dict[str, Any]):
def __init__(self, state: "WorkflowState"):
"""初始化变量池
Args:

View File

@@ -15,25 +15,6 @@ class ModelType(StrEnum):
EMBEDDING = "embedding"
RERANK = "rerank"
@classmethod
def from_str(cls, value: str) -> "ModelType":
"""
Get a ModelType enum instance from a string value.
Args:
value (str): The string representation of the model type.
Returns:
ModelType: The corresponding ModelType enum object.
Raises:
ValueError: If the given value does not match any ModelType.
"""
try:
return cls(value)
except ValueError:
raise ValueError(f"Invalid ModelType: {value}")
class ModelProvider(StrEnum):
"""模型提供商枚举"""

View File

@@ -1,6 +1,7 @@
import uuid
import datetime
from typing import Optional, Any, List, Dict, TYPE_CHECKING
import uuid
from typing import Optional, Any, List, Dict
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
@@ -33,7 +34,6 @@ class KnowledgeRetrievalConfig(BaseModel):
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
class ToolConfig(BaseModel):
"""工具配置"""
enabled: bool = Field(default=False, description="是否启用该工具")

View File

@@ -169,7 +169,7 @@ class PromptOptimizerService:
provider=api_config.provider,
api_key=api_config.api_key,
base_url=api_config.api_base
), type=ModelType.from_str(model_config.type))
), type=ModelType(model_config.type))
# build message
messages = [