feat(workflow): add conditional branch (If-Else) node
- Introduce a new conditional branch node for workflows. - Supports multiple case branches with logical operators (AND/OR). - Enables workflow routing based on evaluated conditions.
This commit is contained in:
@@ -13,8 +13,9 @@ from langchain_core.messages import HumanMessage
|
|||||||
from langgraph.graph import StateGraph, START, END
|
from langgraph.graph import StateGraph, START, END
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition
|
from app.core.workflow.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.nodes import WorkflowState, NodeFactory
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.tools.registry import ToolRegistry
|
from app.core.tools.registry import ToolRegistry
|
||||||
from app.core.tools.executor import ToolExecutor
|
from app.core.tools.executor import ToolExecutor
|
||||||
from app.core.tools.langchain_adapter import LangchainAdapter
|
from app.core.tools.langchain_adapter import LangchainAdapter
|
||||||
@@ -30,11 +31,11 @@ class WorkflowExecutor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
):
|
):
|
||||||
"""初始化执行器
|
"""初始化执行器
|
||||||
|
|
||||||
@@ -95,8 +96,6 @@ class WorkflowExecutor:
|
|||||||
"error_node": None
|
"error_node": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def build_graph(self) -> CompiledStateGraph:
|
def build_graph(self) -> CompiledStateGraph:
|
||||||
"""构建 LangGraph
|
"""构建 LangGraph
|
||||||
|
|
||||||
@@ -117,19 +116,36 @@ class WorkflowExecutor:
|
|||||||
node_id = node.get("id")
|
node_id = node.get("id")
|
||||||
|
|
||||||
# 记录 start 和 end 节点 ID
|
# 记录 start 和 end 节点 ID
|
||||||
if node_type == "start":
|
if node_type == NodeType.START:
|
||||||
start_node_id = node_id
|
start_node_id = node_id
|
||||||
elif node_type == "end":
|
elif node_type == NodeType.END:
|
||||||
end_node_ids.append(node_id)
|
end_node_ids.append(node_id)
|
||||||
|
|
||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
|
||||||
|
if node_type in [NodeType.IF_ELSE]:
|
||||||
|
# Build ordered boolean expression strings for each branch.
|
||||||
|
# These expressions will be attached to outgoing edges as
|
||||||
|
# LangGraph conditional routing rules.
|
||||||
|
expressions = node_instance.build_conditional_edge_expressions()
|
||||||
|
|
||||||
|
# Collect all outgoing edges from the current node.
|
||||||
|
# The order of edges must match the order of generated expressions.
|
||||||
|
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||||
|
|
||||||
|
# Attach each condition expression to the corresponding edge
|
||||||
|
# based on branch priority
|
||||||
|
for idx in range(len(expressions)):
|
||||||
|
related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'"
|
||||||
|
|
||||||
if node_instance:
|
if node_instance:
|
||||||
# 包装节点的 run 方法
|
# 包装节点的 run 方法
|
||||||
# 使用函数工厂避免闭包问题
|
# 使用函数工厂避免闭包问题
|
||||||
def make_node_func(inst):
|
def make_node_func(inst):
|
||||||
async def node_func(state: WorkflowState):
|
async def node_func(state: WorkflowState):
|
||||||
return await inst.run(state)
|
return await inst.run(state)
|
||||||
|
|
||||||
return node_func
|
return node_func
|
||||||
|
|
||||||
workflow.add_node(node_id, make_node_func(node_instance))
|
workflow.add_node(node_id, make_node_func(node_instance))
|
||||||
@@ -170,14 +186,14 @@ class WorkflowExecutor:
|
|||||||
def router(state: WorkflowState, cond=condition, tgt=target):
|
def router(state: WorkflowState, cond=condition, tgt=target):
|
||||||
"""条件路由函数"""
|
"""条件路由函数"""
|
||||||
if evaluate_condition(
|
if evaluate_condition(
|
||||||
cond,
|
cond,
|
||||||
state.get("variables", {}),
|
state.get("variables", {}),
|
||||||
state.get("node_outputs", {}),
|
state.get("node_outputs", {}),
|
||||||
{
|
{
|
||||||
"execution_id": state.get("execution_id"),
|
"execution_id": state.get("execution_id"),
|
||||||
"workspace_id": state.get("workspace_id"),
|
"workspace_id": state.get("workspace_id"),
|
||||||
"user_id": state.get("user_id")
|
"user_id": state.get("user_id")
|
||||||
}
|
}
|
||||||
):
|
):
|
||||||
return tgt
|
return tgt
|
||||||
return END # 条件不满足,结束
|
return END # 条件不满足,结束
|
||||||
@@ -201,8 +217,8 @@ class WorkflowExecutor:
|
|||||||
return graph
|
return graph
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
input_data: dict[str, Any]
|
input_data: dict[str, Any]
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""执行工作流(非流式)
|
"""执行工作流(非流式)
|
||||||
|
|
||||||
@@ -276,8 +292,8 @@ class WorkflowExecutor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
input_data: dict[str, Any]
|
input_data: dict[str, Any]
|
||||||
):
|
):
|
||||||
"""执行工作流(流式)
|
"""执行工作流(流式)
|
||||||
|
|
||||||
@@ -331,7 +347,6 @@ class WorkflowExecutor:
|
|||||||
"token_usage": None
|
"token_usage": None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None:
|
||||||
"""从节点输出中提取最终输出
|
"""从节点输出中提取最终输出
|
||||||
|
|
||||||
@@ -391,11 +406,11 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
|
|
||||||
async def execute_workflow(
|
async def execute_workflow(
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""执行工作流(便捷函数)
|
"""执行工作流(便捷函数)
|
||||||
|
|
||||||
@@ -419,11 +434,11 @@ async def execute_workflow(
|
|||||||
|
|
||||||
|
|
||||||
async def execute_workflow_stream(
|
async def execute_workflow_stream(
|
||||||
workflow_config: dict[str, Any],
|
workflow_config: dict[str, Any],
|
||||||
input_data: dict[str, Any],
|
input_data: dict[str, Any],
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
workspace_id: str,
|
workspace_id: str,
|
||||||
user_id: str
|
user_id: str
|
||||||
):
|
):
|
||||||
"""执行工作流(流式,便捷函数)
|
"""执行工作流(流式,便捷函数)
|
||||||
|
|
||||||
|
|||||||
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
5
api/app/core/workflow/nodes/if_else/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Condition Node"""
|
||||||
|
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.node import IfElseNode
|
||||||
|
|
||||||
|
__all__ = ["IfElseNode", "IfElseNodeConfig"]
|
||||||
122
api/app/core/workflow/nodes/if_else/config.py
Normal file
122
api/app/core/workflow/nodes/if_else/config.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Condition Configuration"""
|
||||||
|
from pydantic import Field, BaseModel, field_validator
|
||||||
|
from enum import StrEnum
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class LogicOperator(StrEnum):
|
||||||
|
AND = "and"
|
||||||
|
OR = "or"
|
||||||
|
|
||||||
|
|
||||||
|
class ComparisonOpeartor(StrEnum):
|
||||||
|
EMPTY = "empty"
|
||||||
|
NOT_EMPTY = "not_empty"
|
||||||
|
CONTAINS = "contains"
|
||||||
|
NOT_CONTAINS = "not_contains"
|
||||||
|
START_WITH = "startwith"
|
||||||
|
END_WITH = "endwith"
|
||||||
|
EQ = "eq"
|
||||||
|
NE = "ne"
|
||||||
|
LT = "lt"
|
||||||
|
LE = "le"
|
||||||
|
GT = "gt"
|
||||||
|
GE = "ge"
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionDetail(BaseModel):
|
||||||
|
comparison_operator: ComparisonOpeartor = Field(
|
||||||
|
...,
|
||||||
|
description="Comparison operator used to evaluate the condition"
|
||||||
|
)
|
||||||
|
|
||||||
|
left: str = Field(
|
||||||
|
...,
|
||||||
|
description="Value to compare against"
|
||||||
|
)
|
||||||
|
|
||||||
|
right: str = Field(
|
||||||
|
...,
|
||||||
|
description="Value to compare with"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionBranchConfig(BaseModel):
|
||||||
|
"""Configuration for a conditional branch"""
|
||||||
|
|
||||||
|
logical_operator: LogicOperator = Field(
|
||||||
|
default=LogicOperator.AND.value,
|
||||||
|
description="Logical operator used to combine multiple condition expressions"
|
||||||
|
)
|
||||||
|
|
||||||
|
conditions: list[ConditionDetail] = Field(
|
||||||
|
...,
|
||||||
|
description="List of condition expressions within this branch"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IfElseNodeConfig(BaseNodeConfig):
|
||||||
|
cases: list[ConditionBranchConfig] = Field(
|
||||||
|
...,
|
||||||
|
description="List of branch conditions or expressions"
|
||||||
|
)
|
||||||
|
|
||||||
|
@field_validator("cases")
|
||||||
|
@classmethod
|
||||||
|
def validate_case_number(cls, v, info):
|
||||||
|
if len(v) < 1:
|
||||||
|
raise ValueError("At least one cases are required")
|
||||||
|
return v
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
json_schema_extra = {
|
||||||
|
"examples": [
|
||||||
|
{
|
||||||
|
"cases": [
|
||||||
|
# if/CASE1
|
||||||
|
{
|
||||||
|
"logical_operator": "and",
|
||||||
|
"conditions": [
|
||||||
|
{
|
||||||
|
"left": "sys.message",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "'test'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"case_number": 3,
|
||||||
|
"cases": [
|
||||||
|
# if/CASE1
|
||||||
|
{
|
||||||
|
"logic": "or",
|
||||||
|
"conditions": [
|
||||||
|
{
|
||||||
|
"left": "sys.message",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "'test'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
# elif/CASE2
|
||||||
|
{
|
||||||
|
"logic": "and",
|
||||||
|
"conditions": [
|
||||||
|
{
|
||||||
|
"left": "sys.message",
|
||||||
|
"comparison_operator": "eq",
|
||||||
|
"right": "'test'"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"left": "sys.message",
|
||||||
|
"comparison_operator": "contains",
|
||||||
|
"right": "'test'"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
168
api/app/core/workflow/nodes/if_else/node.py
Normal file
168
api/app/core/workflow/nodes/if_else/node.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from simpleeval import NameNotDefined, InvalidExpression
|
||||||
|
|
||||||
|
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
|
from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionExpressionBuilder:
|
||||||
|
"""
|
||||||
|
Build a Python boolean expression string based on a comparison operator.
|
||||||
|
|
||||||
|
This class does not evaluate the expression.
|
||||||
|
It only generates a valid Python expression string
|
||||||
|
that can be evaluated later in a workflow context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, left: str, operator: ComparisonOpeartor, right: str):
|
||||||
|
self.left = left
|
||||||
|
self.operator = operator
|
||||||
|
self.right = right
|
||||||
|
|
||||||
|
def _empty(self):
|
||||||
|
return f"{self.left} == ''"
|
||||||
|
|
||||||
|
def _not_empty(self):
|
||||||
|
return f"{self.left} != ''"
|
||||||
|
|
||||||
|
def _contains(self):
|
||||||
|
return f"{self.right} in {self.left}"
|
||||||
|
|
||||||
|
def _not_contains(self):
|
||||||
|
return f"{self.right} not in {self.left}"
|
||||||
|
|
||||||
|
def _startwith(self):
|
||||||
|
return f'{self.left}.startswith({self.right})'
|
||||||
|
|
||||||
|
def _endwith(self):
|
||||||
|
return f'{self.left}.endswith({self.right})'
|
||||||
|
|
||||||
|
def _eq(self):
|
||||||
|
return f"{self.left} == {self.right}"
|
||||||
|
|
||||||
|
def _ne(self):
|
||||||
|
return f"{self.left} != {self.right}"
|
||||||
|
|
||||||
|
def _lt(self):
|
||||||
|
return f"{self.left} < {self.right}"
|
||||||
|
|
||||||
|
def _le(self):
|
||||||
|
return f"{self.left} <= {self.right}"
|
||||||
|
|
||||||
|
def _gt(self):
|
||||||
|
return f"{self.left} > {self.right}"
|
||||||
|
|
||||||
|
def _ge(self):
|
||||||
|
return f"{self.left} >= {self.right}"
|
||||||
|
|
||||||
|
def build(self):
|
||||||
|
match self.operator:
|
||||||
|
case ComparisonOpeartor.EMPTY:
|
||||||
|
return self._empty()
|
||||||
|
case ComparisonOpeartor.NOT_EMPTY:
|
||||||
|
return self._not_empty()
|
||||||
|
case ComparisonOpeartor.CONTAINS:
|
||||||
|
return self._contains()
|
||||||
|
case ComparisonOpeartor.NOT_CONTAINS:
|
||||||
|
return self._not_contains()
|
||||||
|
case ComparisonOpeartor.START_WITH:
|
||||||
|
return self._startwith()
|
||||||
|
case ComparisonOpeartor.END_WITH:
|
||||||
|
return self._endwith()
|
||||||
|
case ComparisonOpeartor.EQ:
|
||||||
|
return self._eq()
|
||||||
|
case ComparisonOpeartor.NE:
|
||||||
|
return self._ne()
|
||||||
|
case ComparisonOpeartor.LT:
|
||||||
|
return self._lt()
|
||||||
|
case ComparisonOpeartor.LE:
|
||||||
|
return self._le()
|
||||||
|
case ComparisonOpeartor.GT:
|
||||||
|
return self._gt()
|
||||||
|
case ComparisonOpeartor.GE:
|
||||||
|
return self._ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {self.operator}")
|
||||||
|
|
||||||
|
|
||||||
|
class IfElseNode(BaseNode):
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = IfElseNodeConfig(**self.config)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_condition_expression(
|
||||||
|
condition: ConditionDetail,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a single boolean condition expression string.
|
||||||
|
|
||||||
|
This method does NOT evaluate the condition.
|
||||||
|
It only generates a valid Python boolean expression string
|
||||||
|
(e.g. "x > 10", "'a' in name") that can later be used
|
||||||
|
in a conditional edge or evaluated by the workflow engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition (ConditionDetail): Definition of a single comparison condition.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: A Python boolean expression string.
|
||||||
|
"""
|
||||||
|
return ConditionExpressionBuilder(
|
||||||
|
left=condition.left,
|
||||||
|
operator=condition.comparison_operator,
|
||||||
|
right=condition.right
|
||||||
|
).build()
|
||||||
|
|
||||||
|
def build_conditional_edge_expressions(self) -> list[str]:
|
||||||
|
"""
|
||||||
|
Build conditional edge expressions for the If-Else node.
|
||||||
|
|
||||||
|
This method does NOT evaluate any condition at runtime.
|
||||||
|
Instead, it converts each case branch into a Python boolean
|
||||||
|
expression string, which will later be attached to LangGraph
|
||||||
|
as conditional edges.
|
||||||
|
|
||||||
|
Each returned expression corresponds to one branch and is
|
||||||
|
evaluated in order. A fallback 'True' condition is appended
|
||||||
|
to ensure a default branch when no previous conditions match.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: A list of Python boolean expression strings,
|
||||||
|
ordered by branch priority.
|
||||||
|
"""
|
||||||
|
branch_index = 0
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
for case_branch in self.typed_config.cases:
|
||||||
|
branch_index += 1
|
||||||
|
|
||||||
|
branch_conditions = [
|
||||||
|
self._build_condition_expression(condition)
|
||||||
|
for condition in case_branch.conditions
|
||||||
|
]
|
||||||
|
if len(branch_conditions) > 1:
|
||||||
|
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
||||||
|
else:
|
||||||
|
combined_condition = branch_conditions[0]
|
||||||
|
conditions.append(combined_condition)
|
||||||
|
|
||||||
|
# Default fallback branch
|
||||||
|
conditions.append("True")
|
||||||
|
|
||||||
|
return conditions
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
|
"""
|
||||||
|
"""
|
||||||
|
expressions = self.build_conditional_edge_expressions()
|
||||||
|
for i in range(len(expressions)):
|
||||||
|
logger.info(expressions[i])
|
||||||
|
if self._evaluate_condition(expressions[i], state):
|
||||||
|
return f'CASE{i+1}'
|
||||||
|
return f'CASE{len(expressions)}'
|
||||||
Reference in New Issue
Block a user