Merge #17 into develop from feature/20251219_myh

feat(workflow): add conditional branch (If-Else) node

* feature/20251219_myh: (10 commits)
  fix(workflow): fix run_workflow streaming issues
  fix(prompt-optimizer): switch to built-in system prompt
  feat(workflow): add conditional branch (If-Else) node
  perf(types): add Union type declaration for workflow nodes
  fix(expression-eval): fix variable extraction issue in Jinja2 templates
  docs(samples): add config example for If-Else node
  style(workflow): update condition edge comments for conditional nodes
  style(enums): correct enum class name spelling
  refactor(workflow): unify all enum classes in one file and restructure workflow...
  feat(workflow): add import for if-else node configuration

Signed-off-by: Eternity <1533512157@qq.com>
Commented-by: Eternity <1533512157@qq.com>
Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>
Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/17
This commit is contained in:
朱文辉
2025-12-19 18:18:50 +08:00
16 changed files with 478 additions and 335 deletions

View File

@@ -4,16 +4,17 @@
基于 LangGraph 的工作流执行引擎。
"""
import logging
import datetime
import logging
from typing import Any
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
logger = logging.getLogger(__name__)
@@ -25,11 +26,11 @@ class WorkflowExecutor:
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
@@ -90,8 +91,6 @@ class WorkflowExecutor:
"error_node": None
}
def build_graph(self) -> CompiledStateGraph:
"""构建 LangGraph
@@ -112,19 +111,38 @@ class WorkflowExecutor:
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == "end":
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]:
expressions = node_instance.build_conditional_edge_expressions()
# Number of branches, usually matches the number of conditional expressions
branch_number = len(expressions)
# Find all edges whose source is the current node
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Iterate over each branch
for idx in range(branch_number):
# Generate a condition expression for each edge
# Used later to determine which branch to take based on the node's output
# Assumes node output `node.<node_id>.output` matches the edge's label
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
@@ -165,14 +183,14 @@ class WorkflowExecutor:
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
@@ -196,8 +214,8 @@ class WorkflowExecutor:
return graph
async def execute(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
@@ -271,8 +289,8 @@ class WorkflowExecutor:
}
async def execute_stream(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
@@ -305,7 +323,7 @@ class WorkflowExecutor:
try:
async for chunk in graph.astream(
initial_state,
# subgraphs=True,
# subgraphs=True,
stream_mode="updates",
):
# print(chunk)
@@ -326,7 +344,6 @@ class WorkflowExecutor:
"token_usage": None
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
@@ -386,11 +403,11 @@ class WorkflowExecutor:
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
@@ -414,11 +431,11 @@ async def execute_workflow(
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)

View File

@@ -5,6 +5,7 @@
"""
import logging
import re
from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
@@ -59,9 +60,10 @@ class ExpressionEvaluator:
"""
# 移除 Jinja2 模板语法的花括号(如果存在)
expression = expression.strip()
if expression.startswith("{{") and expression.endswith("}}"):
expression = expression[2:-2].strip()
# "{{system.message}} == {{ user.messge }}" -> "system.message == user.message"
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
# 构建命名空间上下文
context = {
"var": variables, # 用户变量

View File

@@ -4,13 +4,14 @@
提供各种类型的节点实现,用于工作流执行。
"""
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.node_factory import NodeFactory
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
__all__ = [
"BaseNode",
@@ -18,7 +19,9 @@ __all__ = [
"LLMNode",
"AgentNode",
"TransformNode",
"IfElseNode",
"StartNode",
"EndNode",
"NodeFactory",
"WorkflowNode"
]

View File

@@ -13,6 +13,7 @@ from app.core.workflow.nodes.end.config import EndNodeConfig
from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
from app.core.workflow.nodes.agent.config import AgentNodeConfig
from app.core.workflow.nodes.transform.config import TransformNodeConfig
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
__all__ = [
# 基础类
@@ -26,4 +27,5 @@ __all__ = [
"MessageConfig",
"AgentNodeConfig",
"TransformNodeConfig",
"IfElseNodeConfig",
]

View File

@@ -1,5 +1,6 @@
from enum import StrEnum
class NodeType(StrEnum):
START = "start"
END = "end"
@@ -13,3 +14,23 @@ class NodeType(StrEnum):
HTTP_REQUEST = "http-request"
TOOL = "tool"
AGENT = "agent"
class ComparisonOperator(StrEnum):
EMPTY = "empty"
NOT_EMPTY = "not_empty"
CONTAINS = "contains"
NOT_CONTAINS = "not_contains"
START_WITH = "startwith"
END_WITH = "endwith"
EQ = "eq"
NE = "ne"
LT = "lt"
LE = "le"
GT = "gt"
GE = "ge"
class LogicOperator(StrEnum):
AND = "and"
OR = "or"

View File

@@ -0,0 +1,5 @@
"""Condition Node"""
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.if_else.node import IfElseNode
__all__ = ["IfElseNode", "IfElseNodeConfig"]

View File

@@ -0,0 +1,97 @@
"""Condition Configuration"""
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
...,
description="Comparison operator used to evaluate the condition"
)
left: str = Field(
...,
description="Value to compare against"
)
right: str = Field(
...,
description="Value to compare with"
)
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
description="Logical operator used to combine multiple condition expressions"
)
conditions: list[ConditionDetail] = Field(
...,
description="List of condition expressions within this branch"
)
class IfElseNodeConfig(BaseNodeConfig):
cases: list[ConditionBranchConfig] = Field(
...,
description="List of branch conditions or expressions"
)
@field_validator("cases")
@classmethod
def validate_case_number(cls, v, info):
if len(v) < 1:
raise ValueError("At least one cases are required")
return v
class Config:
json_schema_extra = {
"examples": [
{
"cases": [
# CASE1 / IF Branch
{
"logical_operator": "and",
"conditions": [
[
{
"left": "node.userinput.message",
"comparison_operator": "eq",
"right": "'123'"
},
{
"left": "node.userinput.test",
"comparison_operator": "eq",
"right": "True"
}
]
]
},
# CASE1 / ELIF Branch
{
"logical_operator": "or",
"conditions": [
[
{
"left": "node.userinput.test",
"comparison_operator": "eq",
"right": "False"
},
{
"left": "node.userinput.message",
"comparison_operator": "contains",
"right": "'123'"
}
]
]
}
# CASE3 / ELSE Branch
]
}
]
}

View File

@@ -0,0 +1,167 @@
import logging
from typing import Any
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startwith(self):
return f'{self.left}.startswith({self.right})'
def _endwith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startwith()
case ComparisonOperator.END_WITH:
return self._endwith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod
def _build_condition_expression(
condition: ConditionDetail,
) -> str:
"""
Build a single boolean condition expression string.
This method does NOT evaluate the condition.
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
"""
Build conditional edge expressions for the If-Else node.
This method does NOT evaluate any condition at runtime.
Instead, it converts each case branch into a Python boolean
expression string, which will later be attached to LangGraph
as conditional edges.
Each returned expression corresponds to one branch and is
evaluated in order. A fallback 'True' condition is appended
to ensure a default branch when no previous conditions match.
Returns:
list[str]: A list of Python boolean expression strings,
ordered by branch priority.
"""
branch_index = 0
conditions = []
for case_branch in self.typed_config.cases:
branch_index += 1
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.conditions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
conditions.append(combined_condition)
# Default fallback branch
conditions.append("True")
return conditions
async def execute(self, state: WorkflowState) -> Any:
"""
"""
expressions = self.build_conditional_edge_expressions()
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
return f'CASE{i+1}'
return f'CASE{len(expressions)}'

View File

@@ -5,18 +5,29 @@
"""
import logging
from typing import Any
from typing import Any, Union
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.agent import AgentNode
from app.core.workflow.nodes.transform import TransformNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.end import EndNode
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.nodes.if_else import IfElseNode
from app.core.workflow.nodes.llm import LLMNode
from app.core.workflow.nodes.start import StartNode
from app.core.workflow.nodes.transform import TransformNode
logger = logging.getLogger(__name__)
WorkflowNode = Union[
BaseNode,
StartNode,
EndNode,
LLMNode,
IfElseNode,
AgentNode,
TransformNode,
]
class NodeFactory:
"""节点工厂
@@ -25,16 +36,17 @@ class NodeFactory:
"""
# 节点类型注册表
_node_types: dict[str, type[BaseNode]] = {
_node_types: dict[str, type[WorkflowNode]] = {
NodeType.START: StartNode,
NodeType.END: EndNode,
NodeType.LLM: LLMNode,
NodeType.AGENT: AgentNode,
NodeType.TRANSFORM: TransformNode,
NodeType.IF_ELSE: IfElseNode
}
@classmethod
def register_node_type(cls, node_type: str, node_class: type[BaseNode]):
def register_node_type(cls, node_type: str, node_class: type[WorkflowNode]):
"""注册新的节点类型
Args:
@@ -52,10 +64,10 @@ class NodeFactory:
@classmethod
def create_node(
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> BaseNode | None:
cls,
node_config: dict[str, Any],
workflow_config: dict[str, Any]
) -> WorkflowNode | None:
"""创建节点实例
Args: