feat(workflow): add assigner node and fix circular imports with minor code style cleanup

This commit is contained in:
mengyonghao
2025-12-22 20:04:18 +08:00
parent 92c62bb46f
commit 054a5976f5
9 changed files with 299 additions and 5 deletions

View File

@@ -5,9 +5,11 @@
"""
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.assigner import AssignerNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.if_else import IfElseNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
@@ -23,5 +25,7 @@ __all__ = [
"StartNode",
"EndNode",
"NodeFactory",
"WorkflowNode"
"WorkflowNode",
# "KnowledgeRetrievalNode",
"AssignerNode",
]

View File

@@ -0,0 +1,4 @@
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.assigner.node import AssignerNode
__all__ = ["AssignerNode", "AssignerNodeConfig"]

View File

@@ -0,0 +1,21 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import AssignmentOperator
class AssignerNodeConfig(BaseNodeConfig):
variable_selector: str | list[str] = Field(
...,
description="Variables to be assigned",
)
operation: AssignmentOperator = Field(
...,
description="Operator to assign",
)
value: str | list[str] = Field(
...,
description="Values to assign",
)

View File

@@ -0,0 +1,80 @@
import logging
from typing import Any
from app.core.workflow.expression_evaluator import ExpressionEvaluator
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import AssignmentOperator
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = AssignerNodeConfig(**self.config)
async def execute(self, state: WorkflowState) -> Any:
"""
Execute the assignment operation defined by this node.
Args:
state: The current workflow state, including conversation variables,
node outputs, and system variables.
Returns:
None or the result of the assignment operation.
"""
# Initialize a variable pool for accessing conversation, node, and system variables
pool = VariablePool(state)
# Get the target variable selector (e.g., "conv.test")
variable_selector = self.typed_config.variable_selector
if isinstance(variable_selector, str):
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
variable_selector = variable_selector.split('.')
# Only conversation variables ('conv') are allowed
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
raise ValueError("Only conversation variables can be assigned.")
# Get the value or expression to assign
value = self.typed_config.value
if isinstance(value, list):
value = '.'.join(value)
value = ExpressionEvaluator.evaluate(
expression=value,
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
# Select the appropriate assignment operator instance based on the target variable type
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
pool, variable_selector, value
)
# Execute the configured assignment operation
match self.typed_config.operation:
case AssignmentOperator.ASSIGN:
operator.assign()
case AssignmentOperator.CLEAR:
operator.clear()
case AssignmentOperator.ADD:
operator.add()
case AssignmentOperator.SUBTRACT:
operator.subtract()
case AssignmentOperator.MULTIPLY:
operator.multiply()
case AssignmentOperator.DIVIDE:
operator.divide()
case AssignmentOperator.APPEND:
operator.append()
case AssignmentOperator.REMOVE_FIRST:
operator.remove_first()
case AssignmentOperator.REMOVE_LAST:
operator.remove_last()
case _:
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")

View File

@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
__all__ = [
# 基础类
@@ -28,4 +30,6 @@ __all__ = [
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
# "KnowledgeRetrievalNodeConfig",
"AssignerNodeConfig",
]

View File

@@ -1,5 +1,14 @@
from enum import StrEnum
from app.core.workflow.nodes.operators import (
StringOperator,
NumberOperator,
AssignmentOperatorType,
BooleanOperator,
ArrayOperator,
ObjectOperator
)
class NodeType(StrEnum):
START = "start"
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"
ASSIGNER = "assigner"
class ComparisonOperator(StrEnum):
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
class LogicOperator(StrEnum):
AND = "and"
OR = "or"
class AssignmentOperator(StrEnum):
ASSIGN = "assign"
CLEAR = "clear"
ADD = "add" # +=
SUBTRACT = "subtract" # -=
MULTIPLY = "multiply" # *=
DIVIDE = "divide" # /=
APPEND = "append"
REMOVE_LAST = "remove_last"
REMOVE_FIRST = "remove_first"
@classmethod
def get_operator(cls, obj) -> AssignmentOperatorType:
if isinstance(obj, str):
return StringOperator
elif isinstance(obj, bool):
return BooleanOperator
elif isinstance(obj, (int, float)):
return NumberOperator
elif isinstance(obj, list):
return ArrayOperator
elif isinstance(obj, dict):
return ObjectOperator
raise TypeError(f"Unsupported variable type ({type(obj)})")

View File

@@ -1,7 +1,7 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail

View File

@@ -7,7 +7,7 @@
import logging
from typing import Any, Union
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
@@ -29,7 +29,7 @@ WorkflowNode = Union[
AgentNode,
TransformNode,
AssignerNode,
KnowledgeRetrievalNode,
# KnowledgeRetrievalNode,
]
@@ -47,7 +47,7 @@ class NodeFactory:
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode,
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
NodeType.ASSIGNER: AssignerNode,
}

View File

@@ -0,0 +1,142 @@
from abc import ABC
from typing import Union, Type
from app.core.workflow.variable_pool import VariablePool
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool
self.left_selector = left_selector
self.right = right
self.type_limit: type[str, int, dict, list] = None
def check(self, no_right=False):
left = self.pool.get(self.left_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
class StringOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = str
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, '')
class NumberOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = (float, int)
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, 0)
def add(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin + self.right)
def subtract(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin - self.right)
def multiply(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin * self.right)
def divide(self) -> None:
self.check()
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin / self.right)
class BooleanOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = bool
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, False)
class ArrayOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = list
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, list())
def append(self) -> None:
self.check()
# TODOrequire type limit in list
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin.append(self.right))
def extend(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin.extend(self.right))
def remove_last(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin.pop())
def remove_first(self) -> None:
self.check(no_right=True)
origin = self.pool.get(self.left_selector)
self.pool.set(self.left_selector, origin.pop(0))
class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = object
def assign(self) -> None:
self.check()
self.pool.set(self.left_selector, self.right)
def clear(self) -> None:
self.check(no_right=True)
self.pool.set(self.left_selector, dict())
AssignmentOperatorInstance = Union[
StringOperator,
NumberOperator,
BooleanOperator,
ArrayOperator,
ObjectOperator
]
AssignmentOperatorType = Type[AssignmentOperatorInstance]