From b56994b9994dab22c3fe2d4e2b9cf6d873d11dbc Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Mon, 5 Jan 2026 10:57:44 +0800 Subject: [PATCH] feat(workflow): update workflow conditional logic --- .../core/workflow/nodes/cycle_graph/loop.py | 142 +++++--- api/app/core/workflow/nodes/end/node.py | 2 +- api/app/core/workflow/nodes/enums.py | 8 +- api/app/core/workflow/nodes/if_else/config.py | 14 +- api/app/core/workflow/nodes/if_else/node.py | 92 +++-- .../core/workflow/nodes/jinja_render/node.py | 1 + api/app/core/workflow/nodes/operators.py | 329 +++++++++++++----- .../nodes/question_classifier/node.py | 32 +- .../nodes/variable_aggregator/__init__.py | 2 +- 9 files changed, 436 insertions(+), 186 deletions(-) diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index af75d372..2e2ab4fb 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -3,10 +3,11 @@ from typing import Any from langgraph.graph.state import CompiledStateGraph -from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression +from app.core.workflow.expression_evaluator import evaluate_expression from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes.cycle_graph import LoopNodeConfig -from app.core.workflow.nodes.operators import ConditionExpressionBuilder +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator +from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.variable_pool import VariablePool logger = logging.getLogger(__name__) @@ -14,11 +15,13 @@ logger = logging.getLogger(__name__) class LoopRuntime: """ - Runtime executor for loop nodes in a workflow. + Runtime executor for a loop node in a workflow graph. - Handles iterative execution of a loop node according to defined loop variables - and conditional expressions. Supports maximum loop count and loop control - through the workflow state. + This class is responsible for executing a loop node at runtime: + - Initializing loop-scoped variables + - Evaluating loop continuation conditions + - Repeatedly invoking a compiled sub-graph + - Enforcing maximum loop count and external stop signals """ def __init__( @@ -29,13 +32,13 @@ class LoopRuntime: state: WorkflowState, ): """ - Initialize the loop runtime. + Initialize the loop runtime executor. Args: - graph: Compiled workflow graph capable of async invocation. - node_id: Unique identifier of the loop node. - config: Dictionary containing loop node configuration. - state: Current workflow state at the point of loop execution. + graph: A compiled LangGraph state graph representing the loop body. + node_id: The unique identifier of the loop node in the workflow. + config: Raw configuration dictionary for the loop node. + state: The current workflow state before entering the loop. """ self.graph = graph self.state = state @@ -46,12 +49,15 @@ class LoopRuntime: """ Initialize workflow state for loop execution. - - Evaluates initial values of loop variables. - - Stores loop variables in runtime_vars and node_outputs. - - Marks the loop as active by setting 'looping' to True. + This method: + - Evaluates initial values of loop variables + - Stores loop variables into both `runtime_vars` and `node_outputs` + under the current loop node's scope + - Creates a shallow copy of the workflow state + - Marks the loop as active by setting `looping = True` Returns: - A copy of the workflow state prepared for the loop execution. + WorkflowState: A prepared workflow state used for loop execution. """ pool = VariablePool(self.state) # 循环变量 @@ -61,7 +67,7 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) + ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } self.state["node_outputs"][self.node_id] = { @@ -70,7 +76,7 @@ class LoopRuntime: variables=pool.get_all_conversation_vars(), node_outputs=pool.get_all_node_outputs(), system_vars=pool.get_all_system_vars(), - ) + ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars } loopstate = WorkflowState( @@ -79,49 +85,93 @@ class LoopRuntime: loopstate["looping"] = True return loopstate - def _get_loop_expression(self): + @staticmethod + def _evaluate(operator, instance: CompareOperatorInstance) -> Any: """ - Build the Python boolean expression for evaluating the loop condition. + Dispatch and execute a comparison operator against a resolved + CompareOperatorInstance. - - Converts each condition in the loop configuration into a Python expression string. - - Combines multiple conditions with the configured logical operator (AND/OR). + Args: + operator: A ComparisonOperator enum value. + instance: A CompareOperatorInstance bound to concrete operands. Returns: - A string representing the combined loop condition expression. + Any: The evaluation result, typically a boolean. """ - branch_conditions = [ - ConditionExpressionBuilder( - left=condition.left, - operator=condition.comparison_operator, - right=condition.right - ).build() - for condition in self.typed_config.condition.expressions - ] - if len(branch_conditions) > 1: - combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions) - else: - combined_condition = branch_conditions[0] + match operator: + case ComparisonOperator.EMPTY: + return instance.empty() + case ComparisonOperator.NOT_EMPTY: + return instance.not_empty() + case ComparisonOperator.CONTAINS: + return instance.contains() + case ComparisonOperator.NOT_CONTAINS: + return instance.not_contains() + case ComparisonOperator.START_WITH: + return instance.startswith() + case ComparisonOperator.END_WITH: + return instance.endswith() + case ComparisonOperator.EQ: + return instance.eq() + case ComparisonOperator.NE: + return instance.ne() + case ComparisonOperator.LT: + return instance.lt() + case ComparisonOperator.LE: + return instance.le() + case ComparisonOperator.GT: + return instance.gt() + case ComparisonOperator.GE: + return instance.ge() + case _: + raise ValueError(f"Invalid condition: {operator}") - return combined_condition + def evaluate_conditional(self, state) -> bool: + """ + Evaluate the loop continuation condition at runtime. + + This method: + - Resolves all condition expressions against the current workflow state + - Evaluates each comparison expression immediately + - Combines results using the configured logical operator (AND / OR) + + Args: + state: The current workflow state during loop execution. + + Returns: + bool: True if the loop should continue, False otherwise. + """ + conditions = [] + + for expression in self.typed_config.condition.expressions: + left_value = VariablePool(state).get(expression.left) + evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( + VariablePool(state), + expression.left, + expression.right, + expression.input_type + ) + conditions.append(self._evaluate(expression.operator, evaluator)) + if self.typed_config.condition.logical_operator == LogicOperator.AND: + return all(conditions) + else: + return any(conditions) async def run(self): """ - Execute the loop node until the condition is no longer met, the loop is - manually stopped, or the maximum loop count is reached. + Execute the loop node until termination conditions are met. + + The loop terminates when any of the following occurs: + - The loop condition evaluates to False + - The `looping` flag in the workflow state is set to False + - The maximum loop count is reached Returns: - The final runtime variables of this loop node after completion. + dict[str, Any]: The final runtime variables of this loop node. """ loopstate = self._init_loop_state() - expression = self._get_loop_expression() - loop_variable_pool = VariablePool(loopstate) loop_time = self.typed_config.max_loop - while evaluate_condition( - expression=expression, - variables=loop_variable_pool.get_all_conversation_vars(), - node_outputs=loop_variable_pool.get_all_node_outputs(), - system_vars=loop_variable_pool.get_all_system_vars(), - ) and loopstate["looping"] and loop_time > 0: + while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0: logger.info(f"loop node {self.node_id}: running") await self.graph.ainvoke(loopstate) loop_time -= 1 diff --git a/api/app/core/workflow/nodes/end/node.py b/api/app/core/workflow/nodes/end/node.py index 65bb6cb5..efc62dc5 100644 --- a/api/app/core/workflow/nodes/end/node.py +++ b/api/app/core/workflow/nodes/end/node.py @@ -61,7 +61,7 @@ class EndNode(BaseNode): 引用的节点 ID 列表 """ # 匹配 {{node_id.xxx}} 格式 - pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}' + pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}' matches = re.findall(pattern, template) return list(set(matches)) # 去重 diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 0492a7bf..b1c9d687 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -45,7 +45,8 @@ class LogicOperator(StrEnum): class AssignmentOperator(StrEnum): - ASSIGN = "assign" + COVER = "cover" # 覆盖 + ASSIGN = "assign" # 设置 CLEAR = "clear" ADD = "add" # += @@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum): NONE = "none" DEFAULT = "default" BRANCH = "branch" + + +class ValueInputType(StrEnum): + VARIABLE = "Variable" + CONSTANT = "Constant" diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 9eddb473..4dcb00d1 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -1,12 +1,13 @@ """Condition Configuration""" +from typing import Any from pydantic import Field, BaseModel, field_validator from app.core.workflow.nodes.base_config import BaseNodeConfig -from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType class ConditionDetail(BaseModel): - comparison_operator: ComparisonOperator = Field( + operator: ComparisonOperator = Field( ..., description="Comparison operator used to evaluate the condition" ) @@ -16,17 +17,22 @@ class ConditionDetail(BaseModel): description="Value to compare against" ) - right: str = Field( + right: Any = Field( ..., description="Value to compare with" ) + input_type: ValueInputType = Field( + ..., + description="Value input type for comparison" + ) + class ConditionBranchConfig(BaseModel): """Configuration for a conditional branch""" logical_operator: LogicOperator = Field( - default=LogicOperator.AND.value, + default=LogicOperator.AND, description="Logical operator used to combine multiple condition expressions" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 1450a28f..fd5864a8 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -1,10 +1,11 @@ import logging +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.if_else import IfElseNodeConfig -from app.core.workflow.nodes.if_else.config import ConditionDetail -from app.core.workflow.nodes.operators import ConditionExpressionBuilder +from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance logger = logging.getLogger(__name__) @@ -15,30 +16,36 @@ class IfElseNode(BaseNode): self.typed_config = IfElseNodeConfig(**self.config) @staticmethod - def _build_condition_expression( - condition: ConditionDetail, - ) -> str: - """ - Build a single boolean condition expression string. + def _evaluate(operator, instance: CompareOperatorInstance) -> Any: + match operator: + case ComparisonOperator.EMPTY: + return instance.empty() + case ComparisonOperator.NOT_EMPTY: + return instance.not_empty() + case ComparisonOperator.CONTAINS: + return instance.contains() + case ComparisonOperator.NOT_CONTAINS: + return instance.not_contains() + case ComparisonOperator.START_WITH: + return instance.startswith() + case ComparisonOperator.END_WITH: + return instance.endswith() + case ComparisonOperator.EQ: + return instance.eq() + case ComparisonOperator.NE: + return instance.ne() + case ComparisonOperator.LT: + return instance.lt() + case ComparisonOperator.LE: + return instance.le() + case ComparisonOperator.GT: + return instance.gt() + case ComparisonOperator.GE: + return instance.ge() + case _: + raise ValueError(f"Invalid condition: {operator}") - This method does NOT evaluate the condition. - It only generates a valid Python boolean expression string - (e.g. "x > 10", "'a' in name") that can later be used - in a conditional edge or evaluated by the workflow engine. - - Args: - condition (ConditionDetail): Definition of a single comparison condition. - - Returns: - str: A Python boolean expression string. - """ - return ConditionExpressionBuilder( - left=condition.left, - operator=condition.comparison_operator, - right=condition.right - ).build() - - def build_conditional_edge_expressions(self) -> list[str]: + def evaluate_conditional_edge_expressions(self, state) -> list[bool]: """ Build conditional edge expressions for the If-Else node. @@ -60,19 +67,28 @@ class IfElseNode(BaseNode): for case_branch in self.typed_config.cases: branch_index += 1 - - branch_conditions = [ - self._build_condition_expression(condition) - for condition in case_branch.expressions - ] - if len(branch_conditions) > 1: - combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) + branch_result = [] + for expression in case_branch.expressions: + pattern = r"\{\{\s*(.*?)\s*\}\}" + left_string = re.sub(pattern, r"\1", expression.left).strip() + left_value = self.get_variable(left_string, state) + evaluator = ConditionExpressionResolver.resolve_by_value(left_value)( + self.get_variable_pool(state), + expression.left, + expression.right, + expression.input_type + ) + branch_result.append(self._evaluate(expression.operator, evaluator)) + if case_branch.logical_operator == LogicOperator.AND: + conditions.append(all(branch_result)) else: - combined_condition = branch_conditions[0] - conditions.append(combined_condition) + condition_res = any(branch_result) + conditions.append(condition_res) + if condition_res: + return conditions # Default fallback branch - conditions.append("True") + conditions.append(True) return conditions @@ -90,10 +106,10 @@ class IfElseNode(BaseNode): Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ - expressions = self.build_conditional_edge_expressions() + expressions = self.evaluate_conditional_edge_expressions(state) + # TODO: 变量类型及文本类型解析 for i in range(len(expressions)): - logger.info(expressions[i]) - if self._evaluate_condition(expressions[i], state): + if expressions[i]: logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") return f'CASE{i + 1}' return f'CASE{len(expressions)}' diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 6130c30a..e18a2001 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer logger = logging.getLogger(__name__) + class JinjaRenderNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index 70668b6a..fc856aee 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -1,10 +1,73 @@ +import json +import re from abc import ABC -from typing import Union, Type +from typing import Union, Type, NoReturn -from app.core.workflow.nodes.enums import ComparisonOperator +from app.core.workflow.nodes.base_config import VariableType +from app.core.workflow.nodes.enums import ValueInputType from app.core.workflow.variable_pool import VariablePool +class TypeTransformer: + @classmethod + def _fail(cls, value, target) -> NoReturn: + raise TypeError(f"Cannot convert {value!r} to {target} type") + + @classmethod + def _json_load(cls, value, target): + try: + return json.loads(value) + except Exception: + cls._fail(value, target) + + @classmethod + def transform(cls, variable_literal: str | bool, target_type: VariableType): + match target_type: + case VariableType.STRING: + return str(variable_literal) + + case VariableType.NUMBER: + for caster in (int, float): + try: + return caster(variable_literal) + except Exception: + pass + cls._fail(variable_literal, target_type) + + case VariableType.BOOLEAN: + if isinstance(variable_literal, bool): + return variable_literal + cls._fail(variable_literal, target_type) + + case VariableType.OBJECT: + obj = cls._json_load(variable_literal, target_type) + if isinstance(obj, dict): + return obj + cls._fail(variable_literal, target_type) + + case VariableType.ARRAY_BOOLEAN: + return cls._parse_list(variable_literal, bool, target_type) + + case VariableType.ARRAY_NUMBER: + return cls._parse_list(variable_literal, (int, float), target_type) + + case VariableType.ARRAY_STRING: + return cls._parse_list(variable_literal, str, target_type) + + case VariableType.ARRAY_OBJECT: + return cls._parse_list(variable_literal, dict, target_type) + + case _: + raise TypeError("Invalid type") + + @classmethod + def _parse_list(cls, value, item_type, target): + arr = cls._json_load(value, target) + if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr): + return arr + cls._fail(value, target) + + class OperatorBase(ABC): def __init__(self, pool: VariablePool, left_selector, right): self.pool = pool @@ -19,7 +82,9 @@ class OperatorBase(ABC): raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") if not no_right and not isinstance(self.right, self.type_limit): - raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") + raise TypeError( + f"The value assigned must be of {self.type_limit} type" + ) class StringOperator(OperatorBase): @@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase): class ObjectOperator(OperatorBase): def __init__(self, pool: VariablePool, left_selector, right): super().__init__(pool, left_selector, right) - self.type_limit = object + self.type_limit = dict def assign(self) -> None: self.check() @@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase): class AssignmentOperatorResolver: + OPERATOR_MAP = { + str: StringOperator, + bool: BooleanOperator, + int: NumberOperator, + float: NumberOperator, + list: ArrayOperator, + dict: ObjectOperator, + } + @classmethod def resolve_by_value(cls, value): - if isinstance(value, str): - return StringOperator - elif isinstance(value, bool): - return BooleanOperator - elif isinstance(value, (int, float)): - return NumberOperator - elif isinstance(value, list): - return ArrayOperator - elif isinstance(value, dict): - return ObjectOperator - else: - raise TypeError(f"Unsupported variable type: {type(value)}") + for t, op in cls.OPERATOR_MAP.items(): + if isinstance(value, t): + return op + raise TypeError(f"Unsupported variable type: {type(value)}") AssignmentOperatorInstance = Union[ @@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[ AssignmentOperatorType = Type[AssignmentOperatorInstance] -class ConditionExpressionBuilder: - """ - Build a Python boolean expression string based on a comparison operator. +class ConditionBase(ABC): + type_limit: type[str, int, dict, list] = None - This class does not evaluate the expression. - It only generates a valid Python expression string - that can be evaluated later in a workflow context. - """ + def __init__( + self, + pool: VariablePool, + left_selector, + right_selector: str, + input_type: ValueInputType + ): + self.pool = pool + self.left_selector = left_selector + self.right_selector = right_selector + self.input_type = input_type - def __init__(self, left: str, operator: ComparisonOperator, right: str): - self.left = left - self.operator = operator - self.right = right + self.left_value = self.pool.get(self.left_selector) + self.right_value = self.resolve_right_literal_value() - def _empty(self): - return f"{self.left} == ''" + self.type_limit = getattr(self, "type_limit", None) - def _not_empty(self): - return f"{self.left} != ''" + def resolve_right_literal_value(self): + if self.input_type == ValueInputType.VARIABLE: + pattern = r"\{\{\s*(.*?)\s*\}\}" + right_expression = re.sub(pattern, r"\1", self.right_selector).strip() + return self.pool.get(right_expression) + elif self.input_type == ValueInputType.CONSTANT: + return self.right_selector + raise RuntimeError("Unsupported variable type") - def _contains(self): - return f"{self.right} in {self.left}" + def check(self, no_right=False): + left = self.pool.get(self.left_selector.variable_selector) + if not isinstance(left, self.type_limit): + raise TypeError(f"The variable to be compared on must be of {self.type_limit} type") + if not no_right: + right = self.resolve_right_literal_value() + if not isinstance(right, self.type_limit): + raise TypeError( + f"The compared variable must be of {self.type_limit} type" + ) - def _not_contains(self): - return f"{self.right} not in {self.left}" - def _startswith(self): - return f'{self.left}.startswith({self.right})' +class StringComparisonOperator(ConditionBase): + type_limit = str - def _endswith(self): - return f'{self.left}.endswith({self.right})' + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) - def _eq(self): - return f"{self.left} == {self.right}" + def empty(self): + self.check(no_right=True) + return self.left_value == "" - def _ne(self): - return f"{self.left} != {self.right}" + def not_empty(self): + return not self.empty() - def _lt(self): - return f"{self.left} < {self.right}" + def contains(self): + self.check() + return self.right_value in self.left_value - def _le(self): - return f"{self.left} <= {self.right}" + def not_contains(self): + return self.right_value not in self.left_value - def _gt(self): - return f"{self.left} > {self.right}" + def startswith(self): + self.check() + return self.left_value.startswith(self.right_value) - def _ge(self): - return f"{self.left} >= {self.right}" + def endswith(self): + return self.left_value.endswith(self.right_value) - def build(self): - match self.operator: - case ComparisonOperator.EMPTY: - return self._empty() - case ComparisonOperator.NOT_EMPTY: - return self._not_empty() - case ComparisonOperator.CONTAINS: - return self._contains() - case ComparisonOperator.NOT_CONTAINS: - return self._not_contains() - case ComparisonOperator.START_WITH: - return self._startswith() - case ComparisonOperator.END_WITH: - return self._endswith() - case ComparisonOperator.EQ: - return self._eq() - case ComparisonOperator.NE: - return self._ne() - case ComparisonOperator.LT: - return self._lt() - case ComparisonOperator.LE: - return self._le() - case ComparisonOperator.GT: - return self._gt() - case ComparisonOperator.GE: - return self._ge() - case _: - raise ValueError(f"Invalid condition: {self.operator}") + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + +class NumberComparisonOperator(ConditionBase): + type_limit = (int, float) + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def empty(self): + return self.left_value == 0 + + def not_empty(self): + return self.left_value != 0 + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + def lt(self): + return self.left_value < self.right_value + + def le(self): + return self.left_value <= self.right_value + + def gt(self): + return self.left_value > self.right_value + + def ge(self): + return self.left_value >= self.right_value + + +class BooleanComparisonOperator(ConditionBase): + type_limit = bool + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + +class ObjectComparisonOperator(ConditionBase): + type_limit = dict + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def eq(self): + return self.left_value == self.right_value + + def ne(self): + return self.left_value != self.right_value + + def empty(self): + return not self.left_value + + def not_empty(self): + return bool(self.left_value) + + +class ArrayComparisonOperator(ConditionBase): + type_limit = list + + def __init__(self, pool: VariablePool, left_selector, right_selector, input_type): + super().__init__(pool, left_selector, right_selector, input_type) + + def empty(self): + return not self.left_value + + def not_empty(self): + return bool(self.left_value) + + def contains(self): + return self.right_value in self.left_value + + def not_contains(self): + return self.right_value not in self.left_value + + +CompareOperatorInstance = Union[ + StringComparisonOperator, + NumberComparisonOperator, + BooleanComparisonOperator, + ArrayComparisonOperator, + ObjectComparisonOperator +] +CompareOperatorType = Type[CompareOperatorInstance] + + +class ConditionExpressionResolver: + CONDITION_OPERATOR_MAP = { + str: StringComparisonOperator, + bool: BooleanComparisonOperator, + int: NumberComparisonOperator, + float: NumberComparisonOperator, + list: ArrayComparisonOperator, + dict: ObjectComparisonOperator, + } + + @classmethod + def resolve_by_value(cls, value) -> CompareOperatorType: + for t, op in cls.CONDITION_OPERATOR_MAP.items(): + if isinstance(value, t): + return op + raise TypeError(f"Unsupported variable type: {type(value)}") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index bd3c8752..7e6a40b2 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -15,29 +15,29 @@ logger = logging.getLogger(__name__) class QuestionClassifierNode(BaseNode): """问题分类器节点""" - + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config = QuestionClassifierNodeConfig(**self.config) - + def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" with get_db_read() as db: config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id) - + if not config: raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) - + if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - + api_config = config.api_keys[0] model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base model_type = config.type - + return RedBearLLM( RedBearModelConfig( model_name=model_name, @@ -47,7 +47,7 @@ class QuestionClassifierNode(BaseNode): ), type=ModelType(model_type) ) - + async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行问题分类""" question = self.typed_config.input_variable @@ -55,15 +55,15 @@ class QuestionClassifierNode(BaseNode): supplement_prompt = "" if self.typed_config.user_supplement_prompt is not None: supplement_prompt = self.typed_config.user_supplement_prompt - + category_names = [class_item.class_name for class_item in self.typed_config.categories] - + if not question: logger.warning(f"节点 {self.node_id} 未获取到输入问题") return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} - + llm = self._get_llm_instance() - + # 渲染用户提示词模板,支持工作流变量 user_prompt = self._render_template( self.typed_config.user_prompt.format( @@ -73,15 +73,15 @@ class QuestionClassifierNode(BaseNode): ), state ) - + messages = [ ("system", self.typed_config.system_prompt), ("user", user_prompt), ] - + response = await llm.ainvoke(messages) result = response.content.strip() - + if result in category_names: category = result else: @@ -90,5 +90,5 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - - return {self.typed_config.output_variable: category} \ No newline at end of file + + return {self.typed_config.output_variable: category} diff --git a/api/app/core/workflow/nodes/variable_aggregator/__init__.py b/api/app/core/workflow/nodes/variable_aggregator/__init__.py index 7bc9afa7..d7eda8f5 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/__init__.py +++ b/api/app/core/workflow/nodes/variable_aggregator/__init__.py @@ -1,4 +1,4 @@ from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode -__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"] \ No newline at end of file +__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]