feat(workflow): add conditional branch (If-Else) node
- Introduce a new conditional branch node for workflows. - Supports multiple case branches with logical operators (AND/OR). - Enables workflow routing based on evaluated conditions.
This commit is contained in:
@@ -13,8 +13,9 @@ from langchain_core.messages import HumanMessage
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.tools.registry import ToolRegistry
|
||||
from app.core.tools.executor import ToolExecutor
|
||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||
@@ -30,11 +31,11 @@ class WorkflowExecutor:
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""初始化执行器
|
||||
|
||||
@@ -95,8 +96,6 @@ class WorkflowExecutor:
|
||||
"error_node": None
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_graph(self) -> CompiledStateGraph:
|
||||
"""构建 LangGraph
|
||||
|
||||
@@ -117,19 +116,36 @@ class WorkflowExecutor:
|
||||
node_id = node.get("id")
|
||||
|
||||
# 记录 start 和 end 节点 ID
|
||||
if node_type == "start":
|
||||
if node_type == NodeType.START:
|
||||
start_node_id = node_id
|
||||
elif node_type == "end":
|
||||
elif node_type == NodeType.END:
|
||||
end_node_ids.append(node_id)
|
||||
|
||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||
|
||||
if node_type in [NodeType.IF_ELSE]:
|
||||
# Build ordered boolean expression strings for each branch.
|
||||
# These expressions will be attached to outgoing edges as
|
||||
# LangGraph conditional routing rules.
|
||||
expressions = node_instance.build_conditional_edge_expressions()
|
||||
|
||||
# Collect all outgoing edges from the current node.
|
||||
# The order of edges must match the order of generated expressions.
|
||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||
|
||||
# Attach each condition expression to the corresponding edge
|
||||
# based on branch priority
|
||||
for idx in range(len(expressions)):
|
||||
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# 包装节点的 run 方法
|
||||
# 使用函数工厂避免闭包问题
|
||||
def make_node_func(inst):
|
||||
async def node_func(state: WorkflowState):
|
||||
return await inst.run(state)
|
||||
|
||||
return node_func
|
||||
|
||||
workflow.add_node(node_id, make_node_func(node_instance))
|
||||
@@ -170,14 +186,14 @@ class WorkflowExecutor:
|
||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||
"""条件路由函数"""
|
||||
if evaluate_condition(
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
cond,
|
||||
state.get("variables", {}),
|
||||
state.get("node_outputs", {}),
|
||||
{
|
||||
"execution_id": state.get("execution_id"),
|
||||
"workspace_id": state.get("workspace_id"),
|
||||
"user_id": state.get("user_id")
|
||||
}
|
||||
):
|
||||
return tgt
|
||||
return END # 条件不满足,结束
|
||||
@@ -201,8 +217,8 @@ class WorkflowExecutor:
|
||||
return graph
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(非流式)
|
||||
|
||||
@@ -276,8 +292,8 @@ class WorkflowExecutor:
|
||||
}
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
self,
|
||||
input_data: dict[str, Any]
|
||||
):
|
||||
"""执行工作流(流式)
|
||||
|
||||
@@ -331,7 +347,6 @@ class WorkflowExecutor:
|
||||
"token_usage": None
|
||||
}
|
||||
|
||||
|
||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||
"""从节点输出中提取最终输出
|
||||
|
||||
@@ -391,11 +406,11 @@ class WorkflowExecutor:
|
||||
|
||||
|
||||
async def execute_workflow(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
) -> dict[str, Any]:
|
||||
"""执行工作流(便捷函数)
|
||||
|
||||
@@ -419,11 +434,11 @@ async def execute_workflow(
|
||||
|
||||
|
||||
async def execute_workflow_stream(
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str
|
||||
):
|
||||
"""执行工作流(流式,便捷函数)
|
||||
|
||||
|
||||
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Condition Node"""
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.node import IfElseNode
|
||||
|
||||
__all__ = ["IfElseNode", "IfElseNodeConfig"]
|
||||
122
api/app/core/workflow/nodes/if_else/config.py
Normal file
122
api/app/core/workflow/nodes/if_else/config.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Condition Configuration"""
|
||||
from pydantic import Field, BaseModel, field_validator
|
||||
from enum import StrEnum
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class LogicOperator(StrEnum):
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
|
||||
|
||||
class ComparisonOpeartor(StrEnum):
|
||||
EMPTY = "empty"
|
||||
NOT_EMPTY = "not_empty"
|
||||
CONTAINS = "contains"
|
||||
NOT_CONTAINS = "not_contains"
|
||||
START_WITH = "startwith"
|
||||
END_WITH = "endwith"
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
LT = "lt"
|
||||
LE = "le"
|
||||
GT = "gt"
|
||||
GE = "ge"
|
||||
|
||||
|
||||
class ConditionDetail(BaseModel):
|
||||
comparison_operator: ComparisonOpeartor = Field(
|
||||
...,
|
||||
description="Comparison operator used to evaluate the condition"
|
||||
)
|
||||
|
||||
left: str = Field(
|
||||
...,
|
||||
description="Value to compare against"
|
||||
)
|
||||
|
||||
right: str = Field(
|
||||
...,
|
||||
description="Value to compare with"
|
||||
)
|
||||
|
||||
|
||||
class ConditionBranchConfig(BaseModel):
|
||||
"""Configuration for a conditional branch"""
|
||||
|
||||
logical_operator: LogicOperator = Field(
|
||||
default=LogicOperator.AND.value,
|
||||
description="Logical operator used to combine multiple condition expressions"
|
||||
)
|
||||
|
||||
conditions: list[ConditionDetail] = Field(
|
||||
...,
|
||||
description="List of condition expressions within this branch"
|
||||
)
|
||||
|
||||
|
||||
class IfElseNodeConfig(BaseNodeConfig):
|
||||
cases: list[ConditionBranchConfig] = Field(
|
||||
...,
|
||||
description="List of branch conditions or expressions"
|
||||
)
|
||||
|
||||
@field_validator("cases")
|
||||
@classmethod
|
||||
def validate_case_number(cls, v, info):
|
||||
if len(v) < 1:
|
||||
raise ValueError("At least one cases are required")
|
||||
return v
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"examples": [
|
||||
{
|
||||
"cases": [
|
||||
# if/CASE1
|
||||
{
|
||||
"logical_operator": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"left": "sys.message",
|
||||
"comparison_operator": "eq",
|
||||
"right": "'test'"
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
||||
},
|
||||
{
|
||||
"case_number": 3,
|
||||
"cases": [
|
||||
# if/CASE1
|
||||
{
|
||||
"logic": "or",
|
||||
"conditions": [
|
||||
{
|
||||
"left": "sys.message",
|
||||
"comparison_operator": "eq",
|
||||
"right": "'test'"
|
||||
}
|
||||
]
|
||||
},
|
||||
# elif/CASE2
|
||||
{
|
||||
"logic": "and",
|
||||
"conditions": [
|
||||
{
|
||||
"left": "sys.message",
|
||||
"comparison_operator": "eq",
|
||||
"right": "'test'"
|
||||
},
|
||||
{
|
||||
"left": "sys.message",
|
||||
"comparison_operator": "contains",
|
||||
"right": "'test'"
|
||||
}
|
||||
]
|
||||
},
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
168
api/app/core/workflow/nodes/if_else/node.py
Normal file
168
api/app/core/workflow/nodes/if_else/node.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from simpleeval import NameNotDefined, InvalidExpression
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConditionExpressionBuilder:
|
||||
"""
|
||||
Build a Python boolean expression string based on a comparison operator.
|
||||
|
||||
This class does not evaluate the expression.
|
||||
It only generates a valid Python expression string
|
||||
that can be evaluated later in a workflow context.
|
||||
"""
|
||||
|
||||
def __init__(self, left: str, operator: ComparisonOpeartor, right: str):
|
||||
self.left = left
|
||||
self.operator = operator
|
||||
self.right = right
|
||||
|
||||
def _empty(self):
|
||||
return f"{self.left} == ''"
|
||||
|
||||
def _not_empty(self):
|
||||
return f"{self.left} != ''"
|
||||
|
||||
def _contains(self):
|
||||
return f"{self.right} in {self.left}"
|
||||
|
||||
def _not_contains(self):
|
||||
return f"{self.right} not in {self.left}"
|
||||
|
||||
def _startwith(self):
|
||||
return f'{self.left}.startswith({self.right})'
|
||||
|
||||
def _endwith(self):
|
||||
return f'{self.left}.endswith({self.right})'
|
||||
|
||||
def _eq(self):
|
||||
return f"{self.left} == {self.right}"
|
||||
|
||||
def _ne(self):
|
||||
return f"{self.left} != {self.right}"
|
||||
|
||||
def _lt(self):
|
||||
return f"{self.left} < {self.right}"
|
||||
|
||||
def _le(self):
|
||||
return f"{self.left} <= {self.right}"
|
||||
|
||||
def _gt(self):
|
||||
return f"{self.left} > {self.right}"
|
||||
|
||||
def _ge(self):
|
||||
return f"{self.left} >= {self.right}"
|
||||
|
||||
def build(self):
|
||||
match self.operator:
|
||||
case ComparisonOpeartor.EMPTY:
|
||||
return self._empty()
|
||||
case ComparisonOpeartor.NOT_EMPTY:
|
||||
return self._not_empty()
|
||||
case ComparisonOpeartor.CONTAINS:
|
||||
return self._contains()
|
||||
case ComparisonOpeartor.NOT_CONTAINS:
|
||||
return self._not_contains()
|
||||
case ComparisonOpeartor.START_WITH:
|
||||
return self._startwith()
|
||||
case ComparisonOpeartor.END_WITH:
|
||||
return self._endwith()
|
||||
case ComparisonOpeartor.EQ:
|
||||
return self._eq()
|
||||
case ComparisonOpeartor.NE:
|
||||
return self._ne()
|
||||
case ComparisonOpeartor.LT:
|
||||
return self._lt()
|
||||
case ComparisonOpeartor.LE:
|
||||
return self._le()
|
||||
case ComparisonOpeartor.GT:
|
||||
return self._gt()
|
||||
case ComparisonOpeartor.GE:
|
||||
return self._ge()
|
||||
case _:
|
||||
raise ValueError(f"Invalid condition: {self.operator}")
|
||||
|
||||
|
||||
class IfElseNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = IfElseNodeConfig(**self.config)
|
||||
|
||||
@staticmethod
|
||||
def _build_condition_expression(
|
||||
condition: ConditionDetail,
|
||||
) -> str:
|
||||
"""
|
||||
Build a single boolean condition expression string.
|
||||
|
||||
This method does NOT evaluate the condition.
|
||||
It only generates a valid Python boolean expression string
|
||||
(e.g. "x > 10", "'a' in name") that can later be used
|
||||
in a conditional edge or evaluated by the workflow engine.
|
||||
|
||||
Args:
|
||||
condition (ConditionDetail): Definition of a single comparison condition.
|
||||
|
||||
Returns:
|
||||
str: A Python boolean expression string.
|
||||
"""
|
||||
return ConditionExpressionBuilder(
|
||||
left=condition.left,
|
||||
operator=condition.comparison_operator,
|
||||
right=condition.right
|
||||
).build()
|
||||
|
||||
def build_conditional_edge_expressions(self) -> list[str]:
|
||||
"""
|
||||
Build conditional edge expressions for the If-Else node.
|
||||
|
||||
This method does NOT evaluate any condition at runtime.
|
||||
Instead, it converts each case branch into a Python boolean
|
||||
expression string, which will later be attached to LangGraph
|
||||
as conditional edges.
|
||||
|
||||
Each returned expression corresponds to one branch and is
|
||||
evaluated in order. A fallback 'True' condition is appended
|
||||
to ensure a default branch when no previous conditions match.
|
||||
|
||||
Returns:
|
||||
list[str]: A list of Python boolean expression strings,
|
||||
ordered by branch priority.
|
||||
"""
|
||||
branch_index = 0
|
||||
conditions = []
|
||||
|
||||
for case_branch in self.typed_config.cases:
|
||||
branch_index += 1
|
||||
|
||||
branch_conditions = [
|
||||
self._build_condition_expression(condition)
|
||||
for condition in case_branch.conditions
|
||||
]
|
||||
if len(branch_conditions) > 1:
|
||||
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
||||
else:
|
||||
combined_condition = branch_conditions[0]
|
||||
conditions.append(combined_condition)
|
||||
|
||||
# Default fallback branch
|
||||
conditions.append("True")
|
||||
|
||||
return conditions
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
"""
|
||||
expressions = self.build_conditional_edge_expressions()
|
||||
for i in range(len(expressions)):
|
||||
logger.info(expressions[i])
|
||||
if self._evaluate_condition(expressions[i], state):
|
||||
return f'CASE{i+1}'
|
||||
return f'CASE{len(expressions)}'
|
||||
Reference in New Issue
Block a user