From 01ac36195aefc1c7b3ccef2c233bb5878b68640f Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Fri, 19 Dec 2025 14:19:18 +0800 Subject: [PATCH] feat(workflow): add conditional branch (If-Else) node - Introduce a new conditional branch node for workflows. - Supports multiple case branches with logical operators (AND/OR). - Enables workflow routing based on evaluated conditions. --- api/app/core/workflow/executor.py | 81 +++++---- .../core/workflow/nodes/if_else/__init__.py | 5 + api/app/core/workflow/nodes/if_else/config.py | 122 +++++++++++++ api/app/core/workflow/nodes/if_else/node.py | 168 ++++++++++++++++++ 4 files changed, 343 insertions(+), 33 deletions(-) create mode 100644 api/app/core/workflow/nodes/if_else/__init__.py create mode 100644 api/app/core/workflow/nodes/if_else/config.py create mode 100644 api/app/core/workflow/nodes/if_else/node.py diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 9cf711db..3710e4ed 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -13,8 +13,9 @@ from langchain_core.messages import HumanMessage from langgraph.graph import StateGraph, START, END from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.nodes import WorkflowState, NodeFactory from app.core.workflow.expression_evaluator import evaluate_condition +from app.core.workflow.nodes import WorkflowState, NodeFactory +from app.core.workflow.nodes.enums import NodeType from app.core.tools.registry import ToolRegistry from app.core.tools.executor import ToolExecutor from app.core.tools.langchain_adapter import LangchainAdapter @@ -30,11 +31,11 @@ class WorkflowExecutor: """ def __init__( - self, - workflow_config: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + self, + workflow_config: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ): """初始化执行器 @@ -95,8 +96,6 @@ class WorkflowExecutor: "error_node": None } - - def build_graph(self) -> CompiledStateGraph: """构建 LangGraph @@ -117,19 +116,36 @@ class WorkflowExecutor: node_id = node.get("id") # 记录 start 和 end 节点 ID - if node_type == "start": + if node_type == NodeType.START: start_node_id = node_id - elif node_type == "end": + elif node_type == NodeType.END: end_node_ids.append(node_id) # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) + + if node_type in [NodeType.IF_ELSE]: + # Build ordered boolean expression strings for each branch. + # These expressions will be attached to outgoing edges as + # LangGraph conditional routing rules. + expressions = node_instance.build_conditional_edge_expressions() + + # Collect all outgoing edges from the current node. + # The order of edges must match the order of generated expressions. + related_edge = [edge for edge in self.edges if edge.get("source") == node_id] + + # Attach each condition expression to the corresponding edge + # based on branch priority + for idx in range(len(expressions)): + related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + if node_instance: # 包装节点的 run 方法 # 使用函数工厂避免闭包问题 def make_node_func(inst): async def node_func(state: WorkflowState): return await inst.run(state) + return node_func workflow.add_node(node_id, make_node_func(node_instance)) @@ -170,14 +186,14 @@ class WorkflowExecutor: def router(state: WorkflowState, cond=condition, tgt=target): """条件路由函数""" if evaluate_condition( - cond, - state.get("variables", {}), - state.get("node_outputs", {}), - { - "execution_id": state.get("execution_id"), - "workspace_id": state.get("workspace_id"), - "user_id": state.get("user_id") - } + cond, + state.get("variables", {}), + state.get("node_outputs", {}), + { + "execution_id": state.get("execution_id"), + "workspace_id": state.get("workspace_id"), + "user_id": state.get("user_id") + } ): return tgt return END # 条件不满足,结束 @@ -201,8 +217,8 @@ class WorkflowExecutor: return graph async def execute( - self, - input_data: dict[str, Any] + self, + input_data: dict[str, Any] ) -> dict[str, Any]: """执行工作流(非流式) @@ -276,8 +292,8 @@ class WorkflowExecutor: } async def execute_stream( - self, - input_data: dict[str, Any] + self, + input_data: dict[str, Any] ): """执行工作流(流式) @@ -331,7 +347,6 @@ class WorkflowExecutor: "token_usage": None } - def _extract_final_output(self, node_outputs: dict[str, Any]) -> str | None: """从节点输出中提取最终输出 @@ -391,11 +406,11 @@ class WorkflowExecutor: async def execute_workflow( - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ) -> dict[str, Any]: """执行工作流(便捷函数) @@ -419,11 +434,11 @@ async def execute_workflow( async def execute_workflow_stream( - workflow_config: dict[str, Any], - input_data: dict[str, Any], - execution_id: str, - workspace_id: str, - user_id: str + workflow_config: dict[str, Any], + input_data: dict[str, Any], + execution_id: str, + workspace_id: str, + user_id: str ): """执行工作流(流式,便捷函数) diff --git a/api/app/core/workflow/nodes/if_else/__init__.py b/api/app/core/workflow/nodes/if_else/__init__.py new file mode 100644 index 00000000..ffdf3b5b --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/__init__.py @@ -0,0 +1,5 @@ +"""Condition Node""" +from app.core.workflow.nodes.if_else.config import IfElseNodeConfig +from app.core.workflow.nodes.if_else.node import IfElseNode + +__all__ = ["IfElseNode", "IfElseNodeConfig"] diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py new file mode 100644 index 00000000..1a9adbbb --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -0,0 +1,122 @@ +"""Condition Configuration""" +from pydantic import Field, BaseModel, field_validator +from enum import StrEnum +from app.core.workflow.nodes.base_config import BaseNodeConfig + + +class LogicOperator(StrEnum): + AND = "and" + OR = "or" + + +class ComparisonOpeartor(StrEnum): + EMPTY = "empty" + NOT_EMPTY = "not_empty" + CONTAINS = "contains" + NOT_CONTAINS = "not_contains" + START_WITH = "startwith" + END_WITH = "endwith" + EQ = "eq" + NE = "ne" + LT = "lt" + LE = "le" + GT = "gt" + GE = "ge" + + +class ConditionDetail(BaseModel): + comparison_operator: ComparisonOpeartor = Field( + ..., + description="Comparison operator used to evaluate the condition" + ) + + left: str = Field( + ..., + description="Value to compare against" + ) + + right: str = Field( + ..., + description="Value to compare with" + ) + + +class ConditionBranchConfig(BaseModel): + """Configuration for a conditional branch""" + + logical_operator: LogicOperator = Field( + default=LogicOperator.AND.value, + description="Logical operator used to combine multiple condition expressions" + ) + + conditions: list[ConditionDetail] = Field( + ..., + description="List of condition expressions within this branch" + ) + + +class IfElseNodeConfig(BaseNodeConfig): + cases: list[ConditionBranchConfig] = Field( + ..., + description="List of branch conditions or expressions" + ) + + @field_validator("cases") + @classmethod + def validate_case_number(cls, v, info): + if len(v) < 1: + raise ValueError("At least one cases are required") + return v + + class Config: + json_schema_extra = { + "examples": [ + { + "cases": [ + # if/CASE1 + { + "logical_operator": "and", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + } + ] + }, + ] + }, + { + "case_number": 3, + "cases": [ + # if/CASE1 + { + "logic": "or", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + } + ] + }, + # elif/CASE2 + { + "logic": "and", + "conditions": [ + { + "left": "sys.message", + "comparison_operator": "eq", + "right": "'test'" + }, + { + "left": "sys.message", + "comparison_operator": "contains", + "right": "'test'" + } + ] + }, + ] + } + ] + } diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py new file mode 100644 index 00000000..3219edae --- /dev/null +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -0,0 +1,168 @@ +import logging +from typing import Any + +from simpleeval import NameNotDefined, InvalidExpression + +from app.core.workflow.nodes import BaseNode, WorkflowState +from app.core.workflow.nodes.if_else import IfElseNodeConfig +from app.core.workflow.nodes.if_else.config import LogicOperator, ConditionDetail, ComparisonOpeartor + +logger = logging.getLogger(__name__) + + +class ConditionExpressionBuilder: + """ + Build a Python boolean expression string based on a comparison operator. + + This class does not evaluate the expression. + It only generates a valid Python expression string + that can be evaluated later in a workflow context. + """ + + def __init__(self, left: str, operator: ComparisonOpeartor, right: str): + self.left = left + self.operator = operator + self.right = right + + def _empty(self): + return f"{self.left} == ''" + + def _not_empty(self): + return f"{self.left} != ''" + + def _contains(self): + return f"{self.right} in {self.left}" + + def _not_contains(self): + return f"{self.right} not in {self.left}" + + def _startwith(self): + return f'{self.left}.startswith({self.right})' + + def _endwith(self): + return f'{self.left}.endswith({self.right})' + + def _eq(self): + return f"{self.left} == {self.right}" + + def _ne(self): + return f"{self.left} != {self.right}" + + def _lt(self): + return f"{self.left} < {self.right}" + + def _le(self): + return f"{self.left} <= {self.right}" + + def _gt(self): + return f"{self.left} > {self.right}" + + def _ge(self): + return f"{self.left} >= {self.right}" + + def build(self): + match self.operator: + case ComparisonOpeartor.EMPTY: + return self._empty() + case ComparisonOpeartor.NOT_EMPTY: + return self._not_empty() + case ComparisonOpeartor.CONTAINS: + return self._contains() + case ComparisonOpeartor.NOT_CONTAINS: + return self._not_contains() + case ComparisonOpeartor.START_WITH: + return self._startwith() + case ComparisonOpeartor.END_WITH: + return self._endwith() + case ComparisonOpeartor.EQ: + return self._eq() + case ComparisonOpeartor.NE: + return self._ne() + case ComparisonOpeartor.LT: + return self._lt() + case ComparisonOpeartor.LE: + return self._le() + case ComparisonOpeartor.GT: + return self._gt() + case ComparisonOpeartor.GE: + return self._ge() + case _: + raise ValueError(f"Invalid condition: {self.operator}") + + +class IfElseNode(BaseNode): + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = IfElseNodeConfig(**self.config) + + @staticmethod + def _build_condition_expression( + condition: ConditionDetail, + ) -> str: + """ + Build a single boolean condition expression string. + + This method does NOT evaluate the condition. + It only generates a valid Python boolean expression string + (e.g. "x > 10", "'a' in name") that can later be used + in a conditional edge or evaluated by the workflow engine. + + Args: + condition (ConditionDetail): Definition of a single comparison condition. + + Returns: + str: A Python boolean expression string. + """ + return ConditionExpressionBuilder( + left=condition.left, + operator=condition.comparison_operator, + right=condition.right + ).build() + + def build_conditional_edge_expressions(self) -> list[str]: + """ + Build conditional edge expressions for the If-Else node. + + This method does NOT evaluate any condition at runtime. + Instead, it converts each case branch into a Python boolean + expression string, which will later be attached to LangGraph + as conditional edges. + + Each returned expression corresponds to one branch and is + evaluated in order. A fallback 'True' condition is appended + to ensure a default branch when no previous conditions match. + + Returns: + list[str]: A list of Python boolean expression strings, + ordered by branch priority. + """ + branch_index = 0 + conditions = [] + + for case_branch in self.typed_config.cases: + branch_index += 1 + + branch_conditions = [ + self._build_condition_expression(condition) + for condition in case_branch.conditions + ] + if len(branch_conditions) > 1: + combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) + else: + combined_condition = branch_conditions[0] + conditions.append(combined_condition) + + # Default fallback branch + conditions.append("True") + + return conditions + + async def execute(self, state: WorkflowState) -> Any: + """ + """ + expressions = self.build_conditional_edge_expressions() + for i in range(len(expressions)): + logger.info(expressions[i]) + if self._evaluate_condition(expressions[i], state): + return f'CASE{i+1}' + return f'CASE{len(expressions)}'