feat(workflow): update workflow conditional logic

This commit is contained in:
mengyonghao
2026-01-05 10:57:44 +08:00
parent eaf2437633
commit b56994b999
9 changed files with 436 additions and 186 deletions

View File

@@ -3,10 +3,11 @@ from typing import Any
from langgraph.graph.state import CompiledStateGraph from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.nodes import WorkflowState from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
class LoopRuntime: class LoopRuntime:
""" """
Runtime executor for loop nodes in a workflow. Runtime executor for a loop node in a workflow graph.
Handles iterative execution of a loop node according to defined loop variables This class is responsible for executing a loop node at runtime:
and conditional expressions. Supports maximum loop count and loop control - Initializing loop-scoped variables
through the workflow state. - Evaluating loop continuation conditions
- Repeatedly invoking a compiled sub-graph
- Enforcing maximum loop count and external stop signals
""" """
def __init__( def __init__(
@@ -29,13 +32,13 @@ class LoopRuntime:
state: WorkflowState, state: WorkflowState,
): ):
""" """
Initialize the loop runtime. Initialize the loop runtime executor.
Args: Args:
graph: Compiled workflow graph capable of async invocation. graph: A compiled LangGraph state graph representing the loop body.
node_id: Unique identifier of the loop node. node_id: The unique identifier of the loop node in the workflow.
config: Dictionary containing loop node configuration. config: Raw configuration dictionary for the loop node.
state: Current workflow state at the point of loop execution. state: The current workflow state before entering the loop.
""" """
self.graph = graph self.graph = graph
self.state = state self.state = state
@@ -46,12 +49,15 @@ class LoopRuntime:
""" """
Initialize workflow state for loop execution. Initialize workflow state for loop execution.
- Evaluates initial values of loop variables. This method:
- Stores loop variables in runtime_vars and node_outputs. - Evaluates initial values of loop variables
- Marks the loop as active by setting 'looping' to True. - Stores loop variables into both `runtime_vars` and `node_outputs`
under the current loop node's scope
- Creates a shallow copy of the workflow state
- Marks the loop as active by setting `looping = True`
Returns: Returns:
A copy of the workflow state prepared for the loop execution. WorkflowState: A prepared workflow state used for loop execution.
""" """
pool = VariablePool(self.state) pool = VariablePool(self.state)
# 循环变量 # 循环变量
@@ -61,7 +67,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(), variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
) ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars
} }
self.state["node_outputs"][self.node_id] = { self.state["node_outputs"][self.node_id] = {
@@ -70,7 +76,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(), variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(), node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(), system_vars=pool.get_all_system_vars(),
) ) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars
} }
loopstate = WorkflowState( loopstate = WorkflowState(
@@ -79,49 +85,93 @@ class LoopRuntime:
loopstate["looping"] = True loopstate["looping"] = True
return loopstate return loopstate
def _get_loop_expression(self): @staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
""" """
Build the Python boolean expression for evaluating the loop condition. Dispatch and execute a comparison operator against a resolved
CompareOperatorInstance.
- Converts each condition in the loop configuration into a Python expression string. Args:
- Combines multiple conditions with the configured logical operator (AND/OR). operator: A ComparisonOperator enum value.
instance: A CompareOperatorInstance bound to concrete operands.
Returns: Returns:
A string representing the combined loop condition expression. Any: The evaluation result, typically a boolean.
""" """
branch_conditions = [ match operator:
ConditionExpressionBuilder( case ComparisonOperator.EMPTY:
left=condition.left, return instance.empty()
operator=condition.comparison_operator, case ComparisonOperator.NOT_EMPTY:
right=condition.right return instance.not_empty()
).build() case ComparisonOperator.CONTAINS:
for condition in self.typed_config.condition.expressions return instance.contains()
] case ComparisonOperator.NOT_CONTAINS:
if len(branch_conditions) > 1: return instance.not_contains()
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions) case ComparisonOperator.START_WITH:
else: return instance.startswith()
combined_condition = branch_conditions[0] case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
return combined_condition def evaluate_conditional(self, state) -> bool:
"""
Evaluate the loop continuation condition at runtime.
This method:
- Resolves all condition expressions against the current workflow state
- Evaluates each comparison expression immediately
- Combines results using the configured logical operator (AND / OR)
Args:
state: The current workflow state during loop execution.
Returns:
bool: True if the loop should continue, False otherwise.
"""
conditions = []
for expression in self.typed_config.condition.expressions:
left_value = VariablePool(state).get(expression.left)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
VariablePool(state),
expression.left,
expression.right,
expression.input_type
)
conditions.append(self._evaluate(expression.operator, evaluator))
if self.typed_config.condition.logical_operator == LogicOperator.AND:
return all(conditions)
else:
return any(conditions)
async def run(self): async def run(self):
""" """
Execute the loop node until the condition is no longer met, the loop is Execute the loop node until termination conditions are met.
manually stopped, or the maximum loop count is reached.
The loop terminates when any of the following occurs:
- The loop condition evaluates to False
- The `looping` flag in the workflow state is set to False
- The maximum loop count is reached
Returns: Returns:
The final runtime variables of this loop node after completion. dict[str, Any]: The final runtime variables of this loop node.
""" """
loopstate = self._init_loop_state() loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop loop_time = self.typed_config.max_loop
while evaluate_condition( while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running") logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate) await self.graph.ainvoke(loopstate)
loop_time -= 1 loop_time -= 1

View File

@@ -61,7 +61,7 @@ class EndNode(BaseNode):
引用的节点 ID 列表 引用的节点 ID 列表
""" """
# 匹配 {{node_id.xxx}} 格式 # 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}' pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template) matches = re.findall(pattern, template)
return list(set(matches)) # 去重 return list(set(matches)) # 去重

View File

@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
class AssignmentOperator(StrEnum): class AssignmentOperator(StrEnum):
ASSIGN = "assign" COVER = "cover" # 覆盖
ASSIGN = "assign" # 设置
CLEAR = "clear" CLEAR = "clear"
ADD = "add" # += ADD = "add" # +=
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
NONE = "none" NONE = "none"
DEFAULT = "default" DEFAULT = "default"
BRANCH = "branch" BRANCH = "branch"
class ValueInputType(StrEnum):
VARIABLE = "Variable"
CONSTANT = "Constant"

View File

@@ -1,12 +1,13 @@
"""Condition Configuration""" """Condition Configuration"""
from typing import Any
from pydantic import Field, BaseModel, field_validator from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class ConditionDetail(BaseModel): class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field( operator: ComparisonOperator = Field(
..., ...,
description="Comparison operator used to evaluate the condition" description="Comparison operator used to evaluate the condition"
) )
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
description="Value to compare against" description="Value to compare against"
) )
right: str = Field( right: Any = Field(
..., ...,
description="Value to compare with" description="Value to compare with"
) )
input_type: ValueInputType = Field(
...,
description="Value input type for comparison"
)
class ConditionBranchConfig(BaseModel): class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch""" """Configuration for a conditional branch"""
logical_operator: LogicOperator = Field( logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value, default=LogicOperator.AND,
description="Logical operator used to combine multiple condition expressions" description="Logical operator used to combine multiple condition expressions"
) )

View File

@@ -1,10 +1,11 @@
import logging import logging
import re
from typing import Any from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
self.typed_config = IfElseNodeConfig(**self.config) self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod @staticmethod
def _build_condition_expression( def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
condition: ConditionDetail, match operator:
) -> str: case ComparisonOperator.EMPTY:
""" return instance.empty()
Build a single boolean condition expression string. case ComparisonOperator.NOT_EMPTY:
return instance.not_empty()
case ComparisonOperator.CONTAINS:
return instance.contains()
case ComparisonOperator.NOT_CONTAINS:
return instance.not_contains()
case ComparisonOperator.START_WITH:
return instance.startswith()
case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
This method does NOT evaluate the condition. def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
""" """
Build conditional edge expressions for the If-Else node. Build conditional edge expressions for the If-Else node.
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
for case_branch in self.typed_config.cases: for case_branch in self.typed_config.cases:
branch_index += 1 branch_index += 1
branch_result = []
branch_conditions = [ for expression in case_branch.expressions:
self._build_condition_expression(condition) pattern = r"\{\{\s*(.*?)\s*\}\}"
for condition in case_branch.expressions left_string = re.sub(pattern, r"\1", expression.left).strip()
] left_value = self.get_variable(left_string, state)
if len(branch_conditions) > 1: evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions) self.get_variable_pool(state),
expression.left,
expression.right,
expression.input_type
)
branch_result.append(self._evaluate(expression.operator, evaluator))
if case_branch.logical_operator == LogicOperator.AND:
conditions.append(all(branch_result))
else: else:
combined_condition = branch_conditions[0] condition_res = any(branch_result)
conditions.append(combined_condition) conditions.append(condition_res)
if condition_res:
return conditions
# Default fallback branch # Default fallback branch
conditions.append("True") conditions.append(True)
return conditions return conditions
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
Returns: Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
""" """
expressions = self.build_conditional_edge_expressions() expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)): for i in range(len(expressions)):
logger.info(expressions[i]) if expressions[i]:
if self._evaluate_condition(expressions[i], state):
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}") logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}' return f'CASE{i + 1}'
return f'CASE{len(expressions)}' return f'CASE{len(expressions)}'

View File

@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)

View File

@@ -1,10 +1,73 @@
import json
import re
from abc import ABC from abc import ABC
from typing import Union, Type from typing import Union, Type, NoReturn
from app.core.workflow.nodes.enums import ComparisonOperator from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool from app.core.workflow.variable_pool import VariablePool
class TypeTransformer:
@classmethod
def _fail(cls, value, target) -> NoReturn:
raise TypeError(f"Cannot convert {value!r} to {target} type")
@classmethod
def _json_load(cls, value, target):
try:
return json.loads(value)
except Exception:
cls._fail(value, target)
@classmethod
def transform(cls, variable_literal: str | bool, target_type: VariableType):
match target_type:
case VariableType.STRING:
return str(variable_literal)
case VariableType.NUMBER:
for caster in (int, float):
try:
return caster(variable_literal)
except Exception:
pass
cls._fail(variable_literal, target_type)
case VariableType.BOOLEAN:
if isinstance(variable_literal, bool):
return variable_literal
cls._fail(variable_literal, target_type)
case VariableType.OBJECT:
obj = cls._json_load(variable_literal, target_type)
if isinstance(obj, dict):
return obj
cls._fail(variable_literal, target_type)
case VariableType.ARRAY_BOOLEAN:
return cls._parse_list(variable_literal, bool, target_type)
case VariableType.ARRAY_NUMBER:
return cls._parse_list(variable_literal, (int, float), target_type)
case VariableType.ARRAY_STRING:
return cls._parse_list(variable_literal, str, target_type)
case VariableType.ARRAY_OBJECT:
return cls._parse_list(variable_literal, dict, target_type)
case _:
raise TypeError("Invalid type")
@classmethod
def _parse_list(cls, value, item_type, target):
arr = cls._json_load(value, target)
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
return arr
cls._fail(value, target)
class OperatorBase(ABC): class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right): def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool self.pool = pool
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type") raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit): if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type") raise TypeError(
f"The value assigned must be of {self.type_limit} type"
)
class StringOperator(OperatorBase): class StringOperator(OperatorBase):
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
class ObjectOperator(OperatorBase): class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right): def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right) super().__init__(pool, left_selector, right)
self.type_limit = object self.type_limit = dict
def assign(self) -> None: def assign(self) -> None:
self.check() self.check()
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
class AssignmentOperatorResolver: class AssignmentOperatorResolver:
OPERATOR_MAP = {
str: StringOperator,
bool: BooleanOperator,
int: NumberOperator,
float: NumberOperator,
list: ArrayOperator,
dict: ObjectOperator,
}
@classmethod @classmethod
def resolve_by_value(cls, value): def resolve_by_value(cls, value):
if isinstance(value, str): for t, op in cls.OPERATOR_MAP.items():
return StringOperator if isinstance(value, t):
elif isinstance(value, bool): return op
return BooleanOperator raise TypeError(f"Unsupported variable type: {type(value)}")
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[ AssignmentOperatorInstance = Union[
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
AssignmentOperatorType = Type[AssignmentOperatorInstance] AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder: class ConditionBase(ABC):
""" type_limit: type[str, int, dict, list] = None
Build a Python boolean expression string based on a comparison operator.
This class does not evaluate the expression. def __init__(
It only generates a valid Python expression string self,
that can be evaluated later in a workflow context. pool: VariablePool,
""" left_selector,
right_selector: str,
input_type: ValueInputType
):
self.pool = pool
self.left_selector = left_selector
self.right_selector = right_selector
self.input_type = input_type
def __init__(self, left: str, operator: ComparisonOperator, right: str): self.left_value = self.pool.get(self.left_selector)
self.left = left self.right_value = self.resolve_right_literal_value()
self.operator = operator
self.right = right
def _empty(self): self.type_limit = getattr(self, "type_limit", None)
return f"{self.left} == ''"
def _not_empty(self): def resolve_right_literal_value(self):
return f"{self.left} != ''" if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
return self.pool.get(right_expression)
elif self.input_type == ValueInputType.CONSTANT:
return self.right_selector
raise RuntimeError("Unsupported variable type")
def _contains(self): def check(self, no_right=False):
return f"{self.right} in {self.left}" left = self.pool.get(self.left_selector.variable_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
if not no_right:
right = self.resolve_right_literal_value()
if not isinstance(right, self.type_limit):
raise TypeError(
f"The compared variable must be of {self.type_limit} type"
)
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self): class StringComparisonOperator(ConditionBase):
return f'{self.left}.startswith({self.right})' type_limit = str
def _endswith(self): def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
return f'{self.left}.endswith({self.right})' super().__init__(pool, left_selector, right_selector, input_type)
def _eq(self): def empty(self):
return f"{self.left} == {self.right}" self.check(no_right=True)
return self.left_value == ""
def _ne(self): def not_empty(self):
return f"{self.left} != {self.right}" return not self.empty()
def _lt(self): def contains(self):
return f"{self.left} < {self.right}" self.check()
return self.right_value in self.left_value
def _le(self): def not_contains(self):
return f"{self.left} <= {self.right}" return self.right_value not in self.left_value
def _gt(self): def startswith(self):
return f"{self.left} > {self.right}" self.check()
return self.left_value.startswith(self.right_value)
def _ge(self): def endswith(self):
return f"{self.left} >= {self.right}" return self.left_value.endswith(self.right_value)
def build(self): def eq(self):
match self.operator: return self.left_value == self.right_value
case ComparisonOperator.EMPTY:
return self._empty() def ne(self):
case ComparisonOperator.NOT_EMPTY: return self.left_value != self.right_value
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains() class NumberComparisonOperator(ConditionBase):
case ComparisonOperator.NOT_CONTAINS: type_limit = (int, float)
return self._not_contains()
case ComparisonOperator.START_WITH: def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
return self._startswith() super().__init__(pool, left_selector, right_selector, input_type)
case ComparisonOperator.END_WITH:
return self._endswith() def empty(self):
case ComparisonOperator.EQ: return self.left_value == 0
return self._eq()
case ComparisonOperator.NE: def not_empty(self):
return self._ne() return self.left_value != 0
case ComparisonOperator.LT:
return self._lt() def eq(self):
case ComparisonOperator.LE: return self.left_value == self.right_value
return self._le()
case ComparisonOperator.GT: def ne(self):
return self._gt() return self.left_value != self.right_value
case ComparisonOperator.GE:
return self._ge() def lt(self):
case _: return self.left_value < self.right_value
raise ValueError(f"Invalid condition: {self.operator}")
def le(self):
return self.left_value <= self.right_value
def gt(self):
return self.left_value > self.right_value
def ge(self):
return self.left_value >= self.right_value
class BooleanComparisonOperator(ConditionBase):
type_limit = bool
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
class ObjectComparisonOperator(ConditionBase):
type_limit = dict
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
class ArrayComparisonOperator(ConditionBase):
type_limit = list
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
def contains(self):
return self.right_value in self.left_value
def not_contains(self):
return self.right_value not in self.left_value
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
BooleanComparisonOperator,
ArrayComparisonOperator,
ObjectComparisonOperator
]
CompareOperatorType = Type[CompareOperatorInstance]
class ConditionExpressionResolver:
CONDITION_OPERATOR_MAP = {
str: StringComparisonOperator,
bool: BooleanComparisonOperator,
int: NumberComparisonOperator,
float: NumberComparisonOperator,
list: ArrayComparisonOperator,
dict: ObjectComparisonOperator,
}
@classmethod
def resolve_by_value(cls, value) -> CompareOperatorType:
for t, op in cls.CONDITION_OPERATOR_MAP.items():
if isinstance(value, t):
return op
raise TypeError(f"Unsupported variable type: {type(value)}")

View File

@@ -15,29 +15,29 @@ logger = logging.getLogger(__name__)
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config) super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config) self.typed_config = QuestionClassifierNodeConfig(**self.config)
def _get_llm_instance(self) -> RedBearLLM: def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例""" """获取LLM实例"""
with get_db_read() as db: with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id) config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config: if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0: if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0] api_config = config.api_keys[0]
model_name = api_config.model_name model_name = api_config.model_name
provider = api_config.provider provider = api_config.provider
api_key = api_config.api_key api_key = api_config.api_key
base_url = api_config.api_base base_url = api_config.api_base
model_type = config.type model_type = config.type
return RedBearLLM( return RedBearLLM(
RedBearModelConfig( RedBearModelConfig(
model_name=model_name, model_name=model_name,
@@ -47,7 +47,7 @@ class QuestionClassifierNode(BaseNode):
), ),
type=ModelType(model_type) type=ModelType(model_type)
) )
async def execute(self, state: WorkflowState) -> dict[str, Any]: async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行问题分类""" """执行问题分类"""
question = self.typed_config.input_variable question = self.typed_config.input_variable
@@ -55,15 +55,15 @@ class QuestionClassifierNode(BaseNode):
supplement_prompt = "" supplement_prompt = ""
if self.typed_config.user_supplement_prompt is not None: if self.typed_config.user_supplement_prompt is not None:
supplement_prompt = self.typed_config.user_supplement_prompt supplement_prompt = self.typed_config.user_supplement_prompt
category_names = [class_item.class_name for class_item in self.typed_config.categories] category_names = [class_item.class_name for class_item in self.typed_config.categories]
if not question: if not question:
logger.warning(f"节点 {self.node_id} 未获取到输入问题") logger.warning(f"节点 {self.node_id} 未获取到输入问题")
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
llm = self._get_llm_instance() llm = self._get_llm_instance()
# 渲染用户提示词模板,支持工作流变量 # 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template( user_prompt = self._render_template(
self.typed_config.user_prompt.format( self.typed_config.user_prompt.format(
@@ -73,15 +73,15 @@ class QuestionClassifierNode(BaseNode):
), ),
state state
) )
messages = [ messages = [
("system", self.typed_config.system_prompt), ("system", self.typed_config.system_prompt),
("user", user_prompt), ("user", user_prompt),
] ]
response = await llm.ainvoke(messages) response = await llm.ainvoke(messages)
result = response.content.strip() result = response.content.strip()
if result in category_names: if result in category_names:
category = result category = result
else: else:
@@ -90,5 +90,5 @@ class QuestionClassifierNode(BaseNode):
log_supplement = supplement_prompt if supplement_prompt else "" log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return {self.typed_config.output_variable: category} return {self.typed_config.output_variable: category}

View File

@@ -1,4 +1,4 @@
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"] __all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]