feat(workflow): update workflow conditional logic
This commit is contained in:
@@ -3,10 +3,11 @@ from typing import Any
|
|||||||
|
|
||||||
from langgraph.graph.state import CompiledStateGraph
|
from langgraph.graph.state import CompiledStateGraph
|
||||||
|
|
||||||
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
|
from app.core.workflow.expression_evaluator import evaluate_expression
|
||||||
from app.core.workflow.nodes import WorkflowState
|
from app.core.workflow.nodes import WorkflowState
|
||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
|
||||||
|
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class LoopRuntime:
|
class LoopRuntime:
|
||||||
"""
|
"""
|
||||||
Runtime executor for loop nodes in a workflow.
|
Runtime executor for a loop node in a workflow graph.
|
||||||
|
|
||||||
Handles iterative execution of a loop node according to defined loop variables
|
This class is responsible for executing a loop node at runtime:
|
||||||
and conditional expressions. Supports maximum loop count and loop control
|
- Initializing loop-scoped variables
|
||||||
through the workflow state.
|
- Evaluating loop continuation conditions
|
||||||
|
- Repeatedly invoking a compiled sub-graph
|
||||||
|
- Enforcing maximum loop count and external stop signals
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -29,13 +32,13 @@ class LoopRuntime:
|
|||||||
state: WorkflowState,
|
state: WorkflowState,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the loop runtime.
|
Initialize the loop runtime executor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Compiled workflow graph capable of async invocation.
|
graph: A compiled LangGraph state graph representing the loop body.
|
||||||
node_id: Unique identifier of the loop node.
|
node_id: The unique identifier of the loop node in the workflow.
|
||||||
config: Dictionary containing loop node configuration.
|
config: Raw configuration dictionary for the loop node.
|
||||||
state: Current workflow state at the point of loop execution.
|
state: The current workflow state before entering the loop.
|
||||||
"""
|
"""
|
||||||
self.graph = graph
|
self.graph = graph
|
||||||
self.state = state
|
self.state = state
|
||||||
@@ -46,12 +49,15 @@ class LoopRuntime:
|
|||||||
"""
|
"""
|
||||||
Initialize workflow state for loop execution.
|
Initialize workflow state for loop execution.
|
||||||
|
|
||||||
- Evaluates initial values of loop variables.
|
This method:
|
||||||
- Stores loop variables in runtime_vars and node_outputs.
|
- Evaluates initial values of loop variables
|
||||||
- Marks the loop as active by setting 'looping' to True.
|
- Stores loop variables into both `runtime_vars` and `node_outputs`
|
||||||
|
under the current loop node's scope
|
||||||
|
- Creates a shallow copy of the workflow state
|
||||||
|
- Marks the loop as active by setting `looping = True`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A copy of the workflow state prepared for the loop execution.
|
WorkflowState: A prepared workflow state used for loop execution.
|
||||||
"""
|
"""
|
||||||
pool = VariablePool(self.state)
|
pool = VariablePool(self.state)
|
||||||
# 循环变量
|
# 循环变量
|
||||||
@@ -61,7 +67,7 @@ class LoopRuntime:
|
|||||||
variables=pool.get_all_conversation_vars(),
|
variables=pool.get_all_conversation_vars(),
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
system_vars=pool.get_all_system_vars(),
|
system_vars=pool.get_all_system_vars(),
|
||||||
)
|
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||||
for variable in self.typed_config.cycle_vars
|
for variable in self.typed_config.cycle_vars
|
||||||
}
|
}
|
||||||
self.state["node_outputs"][self.node_id] = {
|
self.state["node_outputs"][self.node_id] = {
|
||||||
@@ -70,7 +76,7 @@ class LoopRuntime:
|
|||||||
variables=pool.get_all_conversation_vars(),
|
variables=pool.get_all_conversation_vars(),
|
||||||
node_outputs=pool.get_all_node_outputs(),
|
node_outputs=pool.get_all_node_outputs(),
|
||||||
system_vars=pool.get_all_system_vars(),
|
system_vars=pool.get_all_system_vars(),
|
||||||
)
|
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
|
||||||
for variable in self.typed_config.cycle_vars
|
for variable in self.typed_config.cycle_vars
|
||||||
}
|
}
|
||||||
loopstate = WorkflowState(
|
loopstate = WorkflowState(
|
||||||
@@ -79,49 +85,93 @@ class LoopRuntime:
|
|||||||
loopstate["looping"] = True
|
loopstate["looping"] = True
|
||||||
return loopstate
|
return loopstate
|
||||||
|
|
||||||
def _get_loop_expression(self):
|
@staticmethod
|
||||||
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
"""
|
"""
|
||||||
Build the Python boolean expression for evaluating the loop condition.
|
Dispatch and execute a comparison operator against a resolved
|
||||||
|
CompareOperatorInstance.
|
||||||
|
|
||||||
- Converts each condition in the loop configuration into a Python expression string.
|
Args:
|
||||||
- Combines multiple conditions with the configured logical operator (AND/OR).
|
operator: A ComparisonOperator enum value.
|
||||||
|
instance: A CompareOperatorInstance bound to concrete operands.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A string representing the combined loop condition expression.
|
Any: The evaluation result, typically a boolean.
|
||||||
"""
|
"""
|
||||||
branch_conditions = [
|
match operator:
|
||||||
ConditionExpressionBuilder(
|
case ComparisonOperator.EMPTY:
|
||||||
left=condition.left,
|
return instance.empty()
|
||||||
operator=condition.comparison_operator,
|
case ComparisonOperator.NOT_EMPTY:
|
||||||
right=condition.right
|
return instance.not_empty()
|
||||||
).build()
|
case ComparisonOperator.CONTAINS:
|
||||||
for condition in self.typed_config.condition.expressions
|
return instance.contains()
|
||||||
]
|
case ComparisonOperator.NOT_CONTAINS:
|
||||||
if len(branch_conditions) > 1:
|
return instance.not_contains()
|
||||||
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
|
case ComparisonOperator.START_WITH:
|
||||||
else:
|
return instance.startswith()
|
||||||
combined_condition = branch_conditions[0]
|
case ComparisonOperator.END_WITH:
|
||||||
|
return instance.endswith()
|
||||||
|
case ComparisonOperator.EQ:
|
||||||
|
return instance.eq()
|
||||||
|
case ComparisonOperator.NE:
|
||||||
|
return instance.ne()
|
||||||
|
case ComparisonOperator.LT:
|
||||||
|
return instance.lt()
|
||||||
|
case ComparisonOperator.LE:
|
||||||
|
return instance.le()
|
||||||
|
case ComparisonOperator.GT:
|
||||||
|
return instance.gt()
|
||||||
|
case ComparisonOperator.GE:
|
||||||
|
return instance.ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {operator}")
|
||||||
|
|
||||||
return combined_condition
|
def evaluate_conditional(self, state) -> bool:
|
||||||
|
"""
|
||||||
|
Evaluate the loop continuation condition at runtime.
|
||||||
|
|
||||||
|
This method:
|
||||||
|
- Resolves all condition expressions against the current workflow state
|
||||||
|
- Evaluates each comparison expression immediately
|
||||||
|
- Combines results using the configured logical operator (AND / OR)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: The current workflow state during loop execution.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if the loop should continue, False otherwise.
|
||||||
|
"""
|
||||||
|
conditions = []
|
||||||
|
|
||||||
|
for expression in self.typed_config.condition.expressions:
|
||||||
|
left_value = VariablePool(state).get(expression.left)
|
||||||
|
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||||
|
VariablePool(state),
|
||||||
|
expression.left,
|
||||||
|
expression.right,
|
||||||
|
expression.input_type
|
||||||
|
)
|
||||||
|
conditions.append(self._evaluate(expression.operator, evaluator))
|
||||||
|
if self.typed_config.condition.logical_operator == LogicOperator.AND:
|
||||||
|
return all(conditions)
|
||||||
|
else:
|
||||||
|
return any(conditions)
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""
|
"""
|
||||||
Execute the loop node until the condition is no longer met, the loop is
|
Execute the loop node until termination conditions are met.
|
||||||
manually stopped, or the maximum loop count is reached.
|
|
||||||
|
The loop terminates when any of the following occurs:
|
||||||
|
- The loop condition evaluates to False
|
||||||
|
- The `looping` flag in the workflow state is set to False
|
||||||
|
- The maximum loop count is reached
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The final runtime variables of this loop node after completion.
|
dict[str, Any]: The final runtime variables of this loop node.
|
||||||
"""
|
"""
|
||||||
loopstate = self._init_loop_state()
|
loopstate = self._init_loop_state()
|
||||||
expression = self._get_loop_expression()
|
|
||||||
loop_variable_pool = VariablePool(loopstate)
|
|
||||||
loop_time = self.typed_config.max_loop
|
loop_time = self.typed_config.max_loop
|
||||||
while evaluate_condition(
|
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
|
||||||
expression=expression,
|
|
||||||
variables=loop_variable_pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=loop_variable_pool.get_all_node_outputs(),
|
|
||||||
system_vars=loop_variable_pool.get_all_system_vars(),
|
|
||||||
) and loopstate["looping"] and loop_time > 0:
|
|
||||||
logger.info(f"loop node {self.node_id}: running")
|
logger.info(f"loop node {self.node_id}: running")
|
||||||
await self.graph.ainvoke(loopstate)
|
await self.graph.ainvoke(loopstate)
|
||||||
loop_time -= 1
|
loop_time -= 1
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class EndNode(BaseNode):
|
|||||||
引用的节点 ID 列表
|
引用的节点 ID 列表
|
||||||
"""
|
"""
|
||||||
# 匹配 {{node_id.xxx}} 格式
|
# 匹配 {{node_id.xxx}} 格式
|
||||||
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}'
|
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||||
matches = re.findall(pattern, template)
|
matches = re.findall(pattern, template)
|
||||||
return list(set(matches)) # 去重
|
return list(set(matches)) # 去重
|
||||||
|
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
|
|||||||
|
|
||||||
|
|
||||||
class AssignmentOperator(StrEnum):
|
class AssignmentOperator(StrEnum):
|
||||||
ASSIGN = "assign"
|
COVER = "cover" # 覆盖
|
||||||
|
ASSIGN = "assign" # 设置
|
||||||
CLEAR = "clear"
|
CLEAR = "clear"
|
||||||
|
|
||||||
ADD = "add" # +=
|
ADD = "add" # +=
|
||||||
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
|
|||||||
NONE = "none"
|
NONE = "none"
|
||||||
DEFAULT = "default"
|
DEFAULT = "default"
|
||||||
BRANCH = "branch"
|
BRANCH = "branch"
|
||||||
|
|
||||||
|
|
||||||
|
class ValueInputType(StrEnum):
|
||||||
|
VARIABLE = "Variable"
|
||||||
|
CONSTANT = "Constant"
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
"""Condition Configuration"""
|
"""Condition Configuration"""
|
||||||
|
from typing import Any
|
||||||
from pydantic import Field, BaseModel, field_validator
|
from pydantic import Field, BaseModel, field_validator
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
|
|
||||||
|
|
||||||
class ConditionDetail(BaseModel):
|
class ConditionDetail(BaseModel):
|
||||||
comparison_operator: ComparisonOperator = Field(
|
operator: ComparisonOperator = Field(
|
||||||
...,
|
...,
|
||||||
description="Comparison operator used to evaluate the condition"
|
description="Comparison operator used to evaluate the condition"
|
||||||
)
|
)
|
||||||
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
|
|||||||
description="Value to compare against"
|
description="Value to compare against"
|
||||||
)
|
)
|
||||||
|
|
||||||
right: str = Field(
|
right: Any = Field(
|
||||||
...,
|
...,
|
||||||
description="Value to compare with"
|
description="Value to compare with"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
input_type: ValueInputType = Field(
|
||||||
|
...,
|
||||||
|
description="Value input type for comparison"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ConditionBranchConfig(BaseModel):
|
class ConditionBranchConfig(BaseModel):
|
||||||
"""Configuration for a conditional branch"""
|
"""Configuration for a conditional branch"""
|
||||||
|
|
||||||
logical_operator: LogicOperator = Field(
|
logical_operator: LogicOperator = Field(
|
||||||
default=LogicOperator.AND.value,
|
default=LogicOperator.AND,
|
||||||
description="Logical operator used to combine multiple condition expressions"
|
description="Logical operator used to combine multiple condition expressions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
|
|||||||
self.typed_config = IfElseNodeConfig(**self.config)
|
self.typed_config = IfElseNodeConfig(**self.config)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_condition_expression(
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
condition: ConditionDetail,
|
match operator:
|
||||||
) -> str:
|
case ComparisonOperator.EMPTY:
|
||||||
"""
|
return instance.empty()
|
||||||
Build a single boolean condition expression string.
|
case ComparisonOperator.NOT_EMPTY:
|
||||||
|
return instance.not_empty()
|
||||||
|
case ComparisonOperator.CONTAINS:
|
||||||
|
return instance.contains()
|
||||||
|
case ComparisonOperator.NOT_CONTAINS:
|
||||||
|
return instance.not_contains()
|
||||||
|
case ComparisonOperator.START_WITH:
|
||||||
|
return instance.startswith()
|
||||||
|
case ComparisonOperator.END_WITH:
|
||||||
|
return instance.endswith()
|
||||||
|
case ComparisonOperator.EQ:
|
||||||
|
return instance.eq()
|
||||||
|
case ComparisonOperator.NE:
|
||||||
|
return instance.ne()
|
||||||
|
case ComparisonOperator.LT:
|
||||||
|
return instance.lt()
|
||||||
|
case ComparisonOperator.LE:
|
||||||
|
return instance.le()
|
||||||
|
case ComparisonOperator.GT:
|
||||||
|
return instance.gt()
|
||||||
|
case ComparisonOperator.GE:
|
||||||
|
return instance.ge()
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Invalid condition: {operator}")
|
||||||
|
|
||||||
This method does NOT evaluate the condition.
|
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
|
||||||
It only generates a valid Python boolean expression string
|
|
||||||
(e.g. "x > 10", "'a' in name") that can later be used
|
|
||||||
in a conditional edge or evaluated by the workflow engine.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
condition (ConditionDetail): Definition of a single comparison condition.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A Python boolean expression string.
|
|
||||||
"""
|
|
||||||
return ConditionExpressionBuilder(
|
|
||||||
left=condition.left,
|
|
||||||
operator=condition.comparison_operator,
|
|
||||||
right=condition.right
|
|
||||||
).build()
|
|
||||||
|
|
||||||
def build_conditional_edge_expressions(self) -> list[str]:
|
|
||||||
"""
|
"""
|
||||||
Build conditional edge expressions for the If-Else node.
|
Build conditional edge expressions for the If-Else node.
|
||||||
|
|
||||||
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
|
|||||||
|
|
||||||
for case_branch in self.typed_config.cases:
|
for case_branch in self.typed_config.cases:
|
||||||
branch_index += 1
|
branch_index += 1
|
||||||
|
branch_result = []
|
||||||
branch_conditions = [
|
for expression in case_branch.expressions:
|
||||||
self._build_condition_expression(condition)
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
for condition in case_branch.expressions
|
left_string = re.sub(pattern, r"\1", expression.left).strip()
|
||||||
]
|
left_value = self.get_variable(left_string, state)
|
||||||
if len(branch_conditions) > 1:
|
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
|
||||||
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
|
self.get_variable_pool(state),
|
||||||
|
expression.left,
|
||||||
|
expression.right,
|
||||||
|
expression.input_type
|
||||||
|
)
|
||||||
|
branch_result.append(self._evaluate(expression.operator, evaluator))
|
||||||
|
if case_branch.logical_operator == LogicOperator.AND:
|
||||||
|
conditions.append(all(branch_result))
|
||||||
else:
|
else:
|
||||||
combined_condition = branch_conditions[0]
|
condition_res = any(branch_result)
|
||||||
conditions.append(combined_condition)
|
conditions.append(condition_res)
|
||||||
|
if condition_res:
|
||||||
|
return conditions
|
||||||
|
|
||||||
# Default fallback branch
|
# Default fallback branch
|
||||||
conditions.append("True")
|
conditions.append(True)
|
||||||
|
|
||||||
return conditions
|
return conditions
|
||||||
|
|
||||||
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
|
|||||||
Returns:
|
Returns:
|
||||||
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
|
||||||
"""
|
"""
|
||||||
expressions = self.build_conditional_edge_expressions()
|
expressions = self.evaluate_conditional_edge_expressions(state)
|
||||||
|
# TODO: 变量类型及文本类型解析
|
||||||
for i in range(len(expressions)):
|
for i in range(len(expressions)):
|
||||||
logger.info(expressions[i])
|
if expressions[i]:
|
||||||
if self._evaluate_condition(expressions[i], state):
|
|
||||||
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
|
||||||
return f'CASE{i + 1}'
|
return f'CASE{i + 1}'
|
||||||
return f'CASE{len(expressions)}'
|
return f'CASE{len(expressions)}'
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class JinjaRenderNode(BaseNode):
|
class JinjaRenderNode(BaseNode):
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
|
|||||||
@@ -1,10 +1,73 @@
|
|||||||
|
import json
|
||||||
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Union, Type
|
from typing import Union, Type, NoReturn
|
||||||
|
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
from app.core.workflow.nodes.base_config import VariableType
|
||||||
|
from app.core.workflow.nodes.enums import ValueInputType
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
|
class TypeTransformer:
|
||||||
|
@classmethod
|
||||||
|
def _fail(cls, value, target) -> NoReturn:
|
||||||
|
raise TypeError(f"Cannot convert {value!r} to {target} type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _json_load(cls, value, target):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except Exception:
|
||||||
|
cls._fail(value, target)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def transform(cls, variable_literal: str | bool, target_type: VariableType):
|
||||||
|
match target_type:
|
||||||
|
case VariableType.STRING:
|
||||||
|
return str(variable_literal)
|
||||||
|
|
||||||
|
case VariableType.NUMBER:
|
||||||
|
for caster in (int, float):
|
||||||
|
try:
|
||||||
|
return caster(variable_literal)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.BOOLEAN:
|
||||||
|
if isinstance(variable_literal, bool):
|
||||||
|
return variable_literal
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.OBJECT:
|
||||||
|
obj = cls._json_load(variable_literal, target_type)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return obj
|
||||||
|
cls._fail(variable_literal, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_BOOLEAN:
|
||||||
|
return cls._parse_list(variable_literal, bool, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_NUMBER:
|
||||||
|
return cls._parse_list(variable_literal, (int, float), target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_STRING:
|
||||||
|
return cls._parse_list(variable_literal, str, target_type)
|
||||||
|
|
||||||
|
case VariableType.ARRAY_OBJECT:
|
||||||
|
return cls._parse_list(variable_literal, dict, target_type)
|
||||||
|
|
||||||
|
case _:
|
||||||
|
raise TypeError("Invalid type")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_list(cls, value, item_type, target):
|
||||||
|
arr = cls._json_load(value, target)
|
||||||
|
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
|
||||||
|
return arr
|
||||||
|
cls._fail(value, target)
|
||||||
|
|
||||||
|
|
||||||
class OperatorBase(ABC):
|
class OperatorBase(ABC):
|
||||||
def __init__(self, pool: VariablePool, left_selector, right):
|
def __init__(self, pool: VariablePool, left_selector, right):
|
||||||
self.pool = pool
|
self.pool = pool
|
||||||
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
|
|||||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||||
|
|
||||||
if not no_right and not isinstance(self.right, self.type_limit):
|
if not no_right and not isinstance(self.right, self.type_limit):
|
||||||
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
|
raise TypeError(
|
||||||
|
f"The value assigned must be of {self.type_limit} type"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StringOperator(OperatorBase):
|
class StringOperator(OperatorBase):
|
||||||
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
|
|||||||
class ObjectOperator(OperatorBase):
|
class ObjectOperator(OperatorBase):
|
||||||
def __init__(self, pool: VariablePool, left_selector, right):
|
def __init__(self, pool: VariablePool, left_selector, right):
|
||||||
super().__init__(pool, left_selector, right)
|
super().__init__(pool, left_selector, right)
|
||||||
self.type_limit = object
|
self.type_limit = dict
|
||||||
|
|
||||||
def assign(self) -> None:
|
def assign(self) -> None:
|
||||||
self.check()
|
self.check()
|
||||||
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
|
|||||||
|
|
||||||
|
|
||||||
class AssignmentOperatorResolver:
|
class AssignmentOperatorResolver:
|
||||||
|
OPERATOR_MAP = {
|
||||||
|
str: StringOperator,
|
||||||
|
bool: BooleanOperator,
|
||||||
|
int: NumberOperator,
|
||||||
|
float: NumberOperator,
|
||||||
|
list: ArrayOperator,
|
||||||
|
dict: ObjectOperator,
|
||||||
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def resolve_by_value(cls, value):
|
def resolve_by_value(cls, value):
|
||||||
if isinstance(value, str):
|
for t, op in cls.OPERATOR_MAP.items():
|
||||||
return StringOperator
|
if isinstance(value, t):
|
||||||
elif isinstance(value, bool):
|
return op
|
||||||
return BooleanOperator
|
raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||||
elif isinstance(value, (int, float)):
|
|
||||||
return NumberOperator
|
|
||||||
elif isinstance(value, list):
|
|
||||||
return ArrayOperator
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
return ObjectOperator
|
|
||||||
else:
|
|
||||||
raise TypeError(f"Unsupported variable type: {type(value)}")
|
|
||||||
|
|
||||||
|
|
||||||
AssignmentOperatorInstance = Union[
|
AssignmentOperatorInstance = Union[
|
||||||
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
|
|||||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||||
|
|
||||||
|
|
||||||
class ConditionExpressionBuilder:
|
class ConditionBase(ABC):
|
||||||
"""
|
type_limit: type[str, int, dict, list] = None
|
||||||
Build a Python boolean expression string based on a comparison operator.
|
|
||||||
|
|
||||||
This class does not evaluate the expression.
|
def __init__(
|
||||||
It only generates a valid Python expression string
|
self,
|
||||||
that can be evaluated later in a workflow context.
|
pool: VariablePool,
|
||||||
"""
|
left_selector,
|
||||||
|
right_selector: str,
|
||||||
|
input_type: ValueInputType
|
||||||
|
):
|
||||||
|
self.pool = pool
|
||||||
|
self.left_selector = left_selector
|
||||||
|
self.right_selector = right_selector
|
||||||
|
self.input_type = input_type
|
||||||
|
|
||||||
def __init__(self, left: str, operator: ComparisonOperator, right: str):
|
self.left_value = self.pool.get(self.left_selector)
|
||||||
self.left = left
|
self.right_value = self.resolve_right_literal_value()
|
||||||
self.operator = operator
|
|
||||||
self.right = right
|
|
||||||
|
|
||||||
def _empty(self):
|
self.type_limit = getattr(self, "type_limit", None)
|
||||||
return f"{self.left} == ''"
|
|
||||||
|
|
||||||
def _not_empty(self):
|
def resolve_right_literal_value(self):
|
||||||
return f"{self.left} != ''"
|
if self.input_type == ValueInputType.VARIABLE:
|
||||||
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
|
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||||
|
return self.pool.get(right_expression)
|
||||||
|
elif self.input_type == ValueInputType.CONSTANT:
|
||||||
|
return self.right_selector
|
||||||
|
raise RuntimeError("Unsupported variable type")
|
||||||
|
|
||||||
def _contains(self):
|
def check(self, no_right=False):
|
||||||
return f"{self.right} in {self.left}"
|
left = self.pool.get(self.left_selector.variable_selector)
|
||||||
|
if not isinstance(left, self.type_limit):
|
||||||
|
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
|
||||||
|
if not no_right:
|
||||||
|
right = self.resolve_right_literal_value()
|
||||||
|
if not isinstance(right, self.type_limit):
|
||||||
|
raise TypeError(
|
||||||
|
f"The compared variable must be of {self.type_limit} type"
|
||||||
|
)
|
||||||
|
|
||||||
def _not_contains(self):
|
|
||||||
return f"{self.right} not in {self.left}"
|
|
||||||
|
|
||||||
def _startswith(self):
|
class StringComparisonOperator(ConditionBase):
|
||||||
return f'{self.left}.startswith({self.right})'
|
type_limit = str
|
||||||
|
|
||||||
def _endswith(self):
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
return f'{self.left}.endswith({self.right})'
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
def _eq(self):
|
def empty(self):
|
||||||
return f"{self.left} == {self.right}"
|
self.check(no_right=True)
|
||||||
|
return self.left_value == ""
|
||||||
|
|
||||||
def _ne(self):
|
def not_empty(self):
|
||||||
return f"{self.left} != {self.right}"
|
return not self.empty()
|
||||||
|
|
||||||
def _lt(self):
|
def contains(self):
|
||||||
return f"{self.left} < {self.right}"
|
self.check()
|
||||||
|
return self.right_value in self.left_value
|
||||||
|
|
||||||
def _le(self):
|
def not_contains(self):
|
||||||
return f"{self.left} <= {self.right}"
|
return self.right_value not in self.left_value
|
||||||
|
|
||||||
def _gt(self):
|
def startswith(self):
|
||||||
return f"{self.left} > {self.right}"
|
self.check()
|
||||||
|
return self.left_value.startswith(self.right_value)
|
||||||
|
|
||||||
def _ge(self):
|
def endswith(self):
|
||||||
return f"{self.left} >= {self.right}"
|
return self.left_value.endswith(self.right_value)
|
||||||
|
|
||||||
def build(self):
|
def eq(self):
|
||||||
match self.operator:
|
return self.left_value == self.right_value
|
||||||
case ComparisonOperator.EMPTY:
|
|
||||||
return self._empty()
|
def ne(self):
|
||||||
case ComparisonOperator.NOT_EMPTY:
|
return self.left_value != self.right_value
|
||||||
return self._not_empty()
|
|
||||||
case ComparisonOperator.CONTAINS:
|
|
||||||
return self._contains()
|
class NumberComparisonOperator(ConditionBase):
|
||||||
case ComparisonOperator.NOT_CONTAINS:
|
type_limit = (int, float)
|
||||||
return self._not_contains()
|
|
||||||
case ComparisonOperator.START_WITH:
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
return self._startswith()
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
case ComparisonOperator.END_WITH:
|
|
||||||
return self._endswith()
|
def empty(self):
|
||||||
case ComparisonOperator.EQ:
|
return self.left_value == 0
|
||||||
return self._eq()
|
|
||||||
case ComparisonOperator.NE:
|
def not_empty(self):
|
||||||
return self._ne()
|
return self.left_value != 0
|
||||||
case ComparisonOperator.LT:
|
|
||||||
return self._lt()
|
def eq(self):
|
||||||
case ComparisonOperator.LE:
|
return self.left_value == self.right_value
|
||||||
return self._le()
|
|
||||||
case ComparisonOperator.GT:
|
def ne(self):
|
||||||
return self._gt()
|
return self.left_value != self.right_value
|
||||||
case ComparisonOperator.GE:
|
|
||||||
return self._ge()
|
def lt(self):
|
||||||
case _:
|
return self.left_value < self.right_value
|
||||||
raise ValueError(f"Invalid condition: {self.operator}")
|
|
||||||
|
def le(self):
|
||||||
|
return self.left_value <= self.right_value
|
||||||
|
|
||||||
|
def gt(self):
|
||||||
|
return self.left_value > self.right_value
|
||||||
|
|
||||||
|
def ge(self):
|
||||||
|
return self.left_value >= self.right_value
|
||||||
|
|
||||||
|
|
||||||
|
class BooleanComparisonOperator(ConditionBase):
|
||||||
|
type_limit = bool
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def eq(self):
|
||||||
|
return self.left_value == self.right_value
|
||||||
|
|
||||||
|
def ne(self):
|
||||||
|
return self.left_value != self.right_value
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectComparisonOperator(ConditionBase):
|
||||||
|
type_limit = dict
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def eq(self):
|
||||||
|
return self.left_value == self.right_value
|
||||||
|
|
||||||
|
def ne(self):
|
||||||
|
return self.left_value != self.right_value
|
||||||
|
|
||||||
|
def empty(self):
|
||||||
|
return not self.left_value
|
||||||
|
|
||||||
|
def not_empty(self):
|
||||||
|
return bool(self.left_value)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayComparisonOperator(ConditionBase):
|
||||||
|
type_limit = list
|
||||||
|
|
||||||
|
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
|
||||||
|
super().__init__(pool, left_selector, right_selector, input_type)
|
||||||
|
|
||||||
|
def empty(self):
|
||||||
|
return not self.left_value
|
||||||
|
|
||||||
|
def not_empty(self):
|
||||||
|
return bool(self.left_value)
|
||||||
|
|
||||||
|
def contains(self):
|
||||||
|
return self.right_value in self.left_value
|
||||||
|
|
||||||
|
def not_contains(self):
|
||||||
|
return self.right_value not in self.left_value
|
||||||
|
|
||||||
|
|
||||||
|
CompareOperatorInstance = Union[
|
||||||
|
StringComparisonOperator,
|
||||||
|
NumberComparisonOperator,
|
||||||
|
BooleanComparisonOperator,
|
||||||
|
ArrayComparisonOperator,
|
||||||
|
ObjectComparisonOperator
|
||||||
|
]
|
||||||
|
CompareOperatorType = Type[CompareOperatorInstance]
|
||||||
|
|
||||||
|
|
||||||
|
class ConditionExpressionResolver:
|
||||||
|
CONDITION_OPERATOR_MAP = {
|
||||||
|
str: StringComparisonOperator,
|
||||||
|
bool: BooleanComparisonOperator,
|
||||||
|
int: NumberComparisonOperator,
|
||||||
|
float: NumberComparisonOperator,
|
||||||
|
list: ArrayComparisonOperator,
|
||||||
|
dict: ObjectComparisonOperator,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def resolve_by_value(cls, value) -> CompareOperatorType:
|
||||||
|
for t, op in cls.CONDITION_OPERATOR_MAP.items():
|
||||||
|
if isinstance(value, t):
|
||||||
|
return op
|
||||||
|
raise TypeError(f"Unsupported variable type: {type(value)}")
|
||||||
|
|||||||
@@ -15,29 +15,29 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
|
|
||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
||||||
|
|
||||||
if not config:
|
if not config:
|
||||||
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
if not config.api_keys or len(config.api_keys) == 0:
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
api_config = config.api_keys[0]
|
api_config = config.api_keys[0]
|
||||||
model_name = api_config.model_name
|
model_name = api_config.model_name
|
||||||
provider = api_config.provider
|
provider = api_config.provider
|
||||||
api_key = api_config.api_key
|
api_key = api_config.api_key
|
||||||
base_url = api_config.api_base
|
base_url = api_config.api_base
|
||||||
model_type = config.type
|
model_type = config.type
|
||||||
|
|
||||||
return RedBearLLM(
|
return RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@@ -47,7 +47,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
question = self.typed_config.input_variable
|
question = self.typed_config.input_variable
|
||||||
@@ -55,15 +55,15 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
supplement_prompt = ""
|
supplement_prompt = ""
|
||||||
if self.typed_config.user_supplement_prompt is not None:
|
if self.typed_config.user_supplement_prompt is not None:
|
||||||
supplement_prompt = self.typed_config.user_supplement_prompt
|
supplement_prompt = self.typed_config.user_supplement_prompt
|
||||||
|
|
||||||
category_names = [class_item.class_name for class_item in self.typed_config.categories]
|
category_names = [class_item.class_name for class_item in self.typed_config.categories]
|
||||||
|
|
||||||
if not question:
|
if not question:
|
||||||
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
|
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
|
||||||
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
|
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
|
||||||
|
|
||||||
llm = self._get_llm_instance()
|
llm = self._get_llm_instance()
|
||||||
|
|
||||||
# 渲染用户提示词模板,支持工作流变量
|
# 渲染用户提示词模板,支持工作流变量
|
||||||
user_prompt = self._render_template(
|
user_prompt = self._render_template(
|
||||||
self.typed_config.user_prompt.format(
|
self.typed_config.user_prompt.format(
|
||||||
@@ -73,15 +73,15 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
),
|
),
|
||||||
state
|
state
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
("system", self.typed_config.system_prompt),
|
("system", self.typed_config.system_prompt),
|
||||||
("user", user_prompt),
|
("user", user_prompt),
|
||||||
]
|
]
|
||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
response = await llm.ainvoke(messages)
|
||||||
result = response.content.strip()
|
result = response.content.strip()
|
||||||
|
|
||||||
if result in category_names:
|
if result in category_names:
|
||||||
category = result
|
category = result
|
||||||
else:
|
else:
|
||||||
@@ -90,5 +90,5 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
|
|
||||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||||
|
|
||||||
return {self.typed_config.output_variable: category}
|
return {self.typed_config.output_variable: category}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
|
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
|
||||||
|
|
||||||
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]
|
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]
|
||||||
|
|||||||
Reference in New Issue
Block a user