feat(workflow): add assigner node and fix circular imports with minor code style cleanup
This commit is contained in:
@@ -5,9 +5,11 @@
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
@@ -23,5 +25,7 @@ __all__ = [
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
"WorkflowNode"
|
||||
"WorkflowNode",
|
||||
# "KnowledgeRetrievalNode",
|
||||
"AssignerNode",
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.node import AssignerNode
|
||||
|
||||
__all__ = ["AssignerNode", "AssignerNodeConfig"]
|
||||
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
|
||||
|
||||
class AssignerNodeConfig(BaseNodeConfig):
|
||||
variable_selector: str | list[str] = Field(
|
||||
...,
|
||||
description="Variables to be assigned",
|
||||
)
|
||||
|
||||
operation: AssignmentOperator = Field(
|
||||
...,
|
||||
description="Operator to assign",
|
||||
)
|
||||
|
||||
value: str | list[str] = Field(
|
||||
...,
|
||||
description="Values to assign",
|
||||
)
|
||||
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
|
||||
Args:
|
||||
state: The current workflow state, including conversation variables,
|
||||
node outputs, and system variables.
|
||||
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
pool = VariablePool(state)
|
||||
|
||||
# Get the target variable selector (e.g., "conv.test")
|
||||
variable_selector = self.typed_config.variable_selector
|
||||
if isinstance(variable_selector, str):
|
||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||
variable_selector = variable_selector.split('.')
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||
raise ValueError("Only conversation variables can be assigned.")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = self.typed_config.value
|
||||
if isinstance(value, list):
|
||||
value = '.'.join(value)
|
||||
value = ExpressionEvaluator.evaluate(
|
||||
expression=value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||
pool, variable_selector, value
|
||||
)
|
||||
|
||||
# Execute the configured assignment operation
|
||||
match self.typed_config.operation:
|
||||
case AssignmentOperator.ASSIGN:
|
||||
operator.assign()
|
||||
case AssignmentOperator.CLEAR:
|
||||
operator.clear()
|
||||
case AssignmentOperator.ADD:
|
||||
operator.add()
|
||||
case AssignmentOperator.SUBTRACT:
|
||||
operator.subtract()
|
||||
case AssignmentOperator.MULTIPLY:
|
||||
operator.multiply()
|
||||
case AssignmentOperator.DIVIDE:
|
||||
operator.divide()
|
||||
case AssignmentOperator.APPEND:
|
||||
operator.append()
|
||||
case AssignmentOperator.REMOVE_FIRST:
|
||||
operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
|
||||
@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -28,4 +30,6 @@ __all__ = [
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
# "KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
]
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.operators import (
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
AssignmentOperatorType,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
)
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
ASSIGNER = "assigner"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
|
||||
class LogicOperator(StrEnum):
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
|
||||
|
||||
class AssignmentOperator(StrEnum):
|
||||
ASSIGN = "assign"
|
||||
CLEAR = "clear"
|
||||
|
||||
ADD = "add" # +=
|
||||
SUBTRACT = "subtract" # -=
|
||||
MULTIPLY = "multiply" # *=
|
||||
DIVIDE = "divide" # /=
|
||||
|
||||
APPEND = "append"
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
|
||||
@classmethod
|
||||
def get_operator(cls, obj) -> AssignmentOperatorType:
|
||||
if isinstance(obj, str):
|
||||
return StringOperator
|
||||
elif isinstance(obj, bool):
|
||||
return BooleanOperator
|
||||
elif isinstance(obj, (int, float)):
|
||||
return NumberOperator
|
||||
elif isinstance(obj, list):
|
||||
return ArrayOperator
|
||||
elif isinstance(obj, dict):
|
||||
return ObjectOperator
|
||||
|
||||
raise TypeError(f"Unsupported variable type ({type(obj)})")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
@@ -29,7 +29,7 @@ WorkflowNode = Union[
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
KnowledgeRetrievalNode,
|
||||
# KnowledgeRetrievalNode,
|
||||
]
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ class NodeFactory:
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
}
|
||||
|
||||
|
||||
142
api/app/core/workflow/nodes/operators.py
Normal file
142
api/app/core/workflow/nodes/operators.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
class OperatorBase(ABC):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
self.pool = pool
|
||||
self.left_selector = left_selector
|
||||
self.right = right
|
||||
|
||||
self.type_limit: type[str, int, dict, list] = None
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||
|
||||
if not no_right and not isinstance(self.right, self.type_limit):
|
||||
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
|
||||
|
||||
|
||||
class StringOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = str
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, '')
|
||||
|
||||
|
||||
class NumberOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = (float, int)
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, 0)
|
||||
|
||||
def add(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin + self.right)
|
||||
|
||||
def subtract(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin - self.right)
|
||||
|
||||
def multiply(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin * self.right)
|
||||
|
||||
def divide(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin / self.right)
|
||||
|
||||
|
||||
class BooleanOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = bool
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, False)
|
||||
|
||||
|
||||
class ArrayOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = list
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, list())
|
||||
|
||||
def append(self) -> None:
|
||||
self.check()
|
||||
# TODO:require type limit in list
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin.append(self.right))
|
||||
|
||||
def extend(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin.extend(self.right))
|
||||
|
||||
def remove_last(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin.pop())
|
||||
|
||||
def remove_first(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin.pop(0))
|
||||
|
||||
|
||||
class ObjectOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = object
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
AssignmentOperatorInstance = Union[
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
]
|
||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||
Reference in New Issue
Block a user