feat(workflow): add conditional branch (If-Else) node

- Introduce a new conditional branch node for workflows.
- Supports multiple case branches with logical operators (AND/OR).
- Enables workflow routing based on evaluated conditions.
This commit is contained in:
mengyonghao
2025-12-19 14:19:18 +08:00
committed by 谢俊男
parent 5cd46e441e
commit 01ac36195a
4 changed files with 343 additions and 33 deletions

View File

@@ -13,8 +13,9 @@ from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, END
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.expression_evaluator import evaluate_condition
from app.core.workflow.nodes import WorkflowState, NodeFactory
from app.core.workflow.nodes.enums import NodeType
from app.core.tools.registry import ToolRegistry
from app.core.tools.executor import ToolExecutor
from app.core.tools.langchain_adapter import LangchainAdapter
@@ -30,11 +31,11 @@ class WorkflowExecutor:
"""
def __init__(
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
self,
workflow_config: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""初始化执行器
@@ -95,8 +96,6 @@ class WorkflowExecutor:
"error_node": None
}
def build_graph(self) -> CompiledStateGraph:
"""构建 LangGraph
@@ -117,19 +116,36 @@ class WorkflowExecutor:
node_id = node.get("id")
# 记录 start 和 end 节点 ID
if node_type == "start":
if node_type == NodeType.START:
start_node_id = node_id
elif node_type == "end":
elif node_type == NodeType.END:
end_node_ids.append(node_id)
# 创建节点实例(现在 start 和 end 也会被创建)
node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in [NodeType.IF_ELSE]:
# Build ordered boolean expression strings for each branch.
# These expressions will be attached to outgoing edges as
# LangGraph conditional routing rules.
expressions = node_instance.build_conditional_edge_expressions()
# Collect all outgoing edges from the current node.
# The order of edges must match the order of generated expressions.
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
# Attach each condition expression to the corresponding edge
# based on branch priority
for idx in range(len(expressions)):
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
if node_instance:
# 包装节点的 run 方法
# 使用函数工厂避免闭包问题
def make_node_func(inst):
async def node_func(state: WorkflowState):
return await inst.run(state)
return node_func
workflow.add_node(node_id, make_node_func(node_instance))
@@ -170,14 +186,14 @@ class WorkflowExecutor:
def router(state: WorkflowState, cond=condition, tgt=target):
"""条件路由函数"""
if evaluate_condition(
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
cond,
state.get("variables", {}),
state.get("node_outputs", {}),
{
"execution_id": state.get("execution_id"),
"workspace_id": state.get("workspace_id"),
"user_id": state.get("user_id")
}
):
return tgt
return END # 条件不满足,结束
@@ -201,8 +217,8 @@ class WorkflowExecutor:
return graph
async def execute(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
) -> dict[str, Any]:
"""执行工作流(非流式)
@@ -276,8 +292,8 @@ class WorkflowExecutor:
}
async def execute_stream(
self,
input_data: dict[str, Any]
self,
input_data: dict[str, Any]
):
"""执行工作流(流式)
@@ -331,7 +347,6 @@ class WorkflowExecutor:
"token_usage": None
}
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
"""从节点输出中提取最终输出
@@ -391,11 +406,11 @@ class WorkflowExecutor:
async def execute_workflow(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
) -> dict[str, Any]:
"""执行工作流(便捷函数)
@@ -419,11 +434,11 @@ async def execute_workflow(
async def execute_workflow_stream(
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
workflow_config: dict[str, Any],
input_data: dict[str, Any],
execution_id: str,
workspace_id: str,
user_id: str
):
"""执行工作流(流式,便捷函数)

View File

@@ -0,0 +1,5 @@
"""Condition Node"""
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
from app.core.workflow.nodes.if_else.node import IfElseNode
__all__ = ["IfElseNode", "IfElseNodeConfig"]

View File

@@ -0,0 +1,122 @@
"""Condition Configuration"""
from pydantic import Field, BaseModel, field_validator
from enum import StrEnum
from app.core.workflow.nodes.base_config import BaseNodeConfig
class LogicOperator(StrEnum):
AND = "and"
OR = "or"
class ComparisonOpeartor(StrEnum):
EMPTY = "empty"
NOT_EMPTY = "not_empty"
CONTAINS = "contains"
NOT_CONTAINS = "not_contains"
START_WITH = "startwith"
END_WITH = "endwith"
EQ = "eq"
NE = "ne"
LT = "lt"
LE = "le"
GT = "gt"
GE = "ge"
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOpeartor = Field(
...,
description="Comparison operator used to evaluate the condition"
)
left: str = Field(
...,
description="Value to compare against"
)
right: str = Field(
...,
description="Value to compare with"
)
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
description="Logical operator used to combine multiple condition expressions"
)
conditions: list[ConditionDetail] = Field(
...,
description="List of condition expressions within this branch"
)
class IfElseNodeConfig(BaseNodeConfig):
cases: list[ConditionBranchConfig] = Field(
...,
description="List of branch conditions or expressions"
)
@field_validator("cases")
@classmethod
def validate_case_number(cls, v, info):
if len(v) < 1:
raise ValueError("At least one cases are required")
return v
class Config:
json_schema_extra = {
"examples": [
{
"cases": [
# if/CASE1
{
"logical_operator": "and",
"conditions": [
{
"left": "sys.message",
"comparison_operator": "eq",
"right": "'test'"
}
]
},
]
},
{
"case_number": 3,
"cases": [
# if/CASE1
{
"logic": "or",
"conditions": [
{
"left": "sys.message",
"comparison_operator": "eq",
"right": "'test'"
}
]
},
# elif/CASE2
{
"logic": "and",
"conditions": [
{
"left": "sys.message",
"comparison_operator": "eq",
"right": "'test'"
},
{
"left": "sys.message",
"comparison_operator": "contains",
"right": "'test'"
}
]
},
]
}
]
}

View File

@@ -0,0 +1,168 @@
import logging
from typing import Any
from simpleeval import NameNotDefined, InvalidExpression
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor
logger = logging.getLogger(__name__)
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(self, left: str, operator: ComparisonOpeartor, right: str):
self.left = left
self.operator = operator
self.right = right
def _empty(self):
return f"{self.left} == ''"
def _not_empty(self):
return f"{self.left} != ''"
def _contains(self):
return f"{self.right} in {self.left}"
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startwith(self):
return f'{self.left}.startswith({self.right})'
def _endwith(self):
return f'{self.left}.endswith({self.right})'
def _eq(self):
return f"{self.left} == {self.right}"
def _ne(self):
return f"{self.left} != {self.right}"
def _lt(self):
return f"{self.left} < {self.right}"
def _le(self):
return f"{self.left} <= {self.right}"
def _gt(self):
return f"{self.left} > {self.right}"
def _ge(self):
return f"{self.left} >= {self.right}"
def build(self):
match self.operator:
case ComparisonOpeartor.EMPTY:
return self._empty()
case ComparisonOpeartor.NOT_EMPTY:
return self._not_empty()
case ComparisonOpeartor.CONTAINS:
return self._contains()
case ComparisonOpeartor.NOT_CONTAINS:
return self._not_contains()
case ComparisonOpeartor.START_WITH:
return self._startwith()
case ComparisonOpeartor.END_WITH:
return self._endwith()
case ComparisonOpeartor.EQ:
return self._eq()
case ComparisonOpeartor.NE:
return self._ne()
case ComparisonOpeartor.LT:
return self._lt()
case ComparisonOpeartor.LE:
return self._le()
case ComparisonOpeartor.GT:
return self._gt()
case ComparisonOpeartor.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod
def _build_condition_expression(
condition: ConditionDetail,
) -> str:
"""
Build a single boolean condition expression string.
This method does NOT evaluate the condition.
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
"""
Build conditional edge expressions for the If-Else node.
This method does NOT evaluate any condition at runtime.
Instead, it converts each case branch into a Python boolean
expression string, which will later be attached to LangGraph
as conditional edges.
Each returned expression corresponds to one branch and is
evaluated in order. A fallback 'True' condition is appended
to ensure a default branch when no previous conditions match.
Returns:
list[str]: A list of Python boolean expression strings,
ordered by branch priority.
"""
branch_index = 0
conditions = []
for case_branch in self.typed_config.cases:
branch_index += 1
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.conditions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
conditions.append(combined_condition)
# Default fallback branch
conditions.append("True")
return conditions
async def execute(self, state: WorkflowState) -> Any:
"""
"""
expressions = self.build_conditional_edge_expressions()
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
return f'CASE{i+1}'
return f'CASE{len(expressions)}'