feat(workflow): update workflow conditional logic

This commit is contained in:
mengyonghao
2026-01-05 10:57:44 +08:00
parent eaf2437633
commit b56994b999
9 changed files with 436 additions and 186 deletions

View File

@@ -3,10 +3,11 @@ from typing import Any
from langgraph.graph.state import CompiledStateGraph
from app.core.workflow.expression_evaluator import evaluate_condition, evaluate_expression
from app.core.workflow.expression_evaluator import evaluate_expression
from app.core.workflow.nodes import WorkflowState
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -14,11 +15,13 @@ logger = logging.getLogger(__name__)
class LoopRuntime:
"""
Runtime executor for loop nodes in a workflow.
Runtime executor for a loop node in a workflow graph.
Handles iterative execution of a loop node according to defined loop variables
and conditional expressions. Supports maximum loop count and loop control
through the workflow state.
This class is responsible for executing a loop node at runtime:
- Initializing loop-scoped variables
- Evaluating loop continuation conditions
- Repeatedly invoking a compiled sub-graph
- Enforcing maximum loop count and external stop signals
"""
def __init__(
@@ -29,13 +32,13 @@ class LoopRuntime:
state: WorkflowState,
):
"""
Initialize the loop runtime.
Initialize the loop runtime executor.
Args:
graph: Compiled workflow graph capable of async invocation.
node_id: Unique identifier of the loop node.
config: Dictionary containing loop node configuration.
state: Current workflow state at the point of loop execution.
graph: A compiled LangGraph state graph representing the loop body.
node_id: The unique identifier of the loop node in the workflow.
config: Raw configuration dictionary for the loop node.
state: The current workflow state before entering the loop.
"""
self.graph = graph
self.state = state
@@ -46,12 +49,15 @@ class LoopRuntime:
"""
Initialize workflow state for loop execution.
- Evaluates initial values of loop variables.
- Stores loop variables in runtime_vars and node_outputs.
- Marks the loop as active by setting 'looping' to True.
This method:
- Evaluates initial values of loop variables
- Stores loop variables into both `runtime_vars` and `node_outputs`
under the current loop node's scope
- Creates a shallow copy of the workflow state
- Marks the loop as active by setting `looping = True`
Returns:
A copy of the workflow state prepared for the loop execution.
WorkflowState: A prepared workflow state used for loop execution.
"""
pool = VariablePool(self.state)
# 循环变量
@@ -61,7 +67,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
self.state["node_outputs"][self.node_id] = {
@@ -70,7 +76,7 @@ class LoopRuntime:
variables=pool.get_all_conversation_vars(),
node_outputs=pool.get_all_node_outputs(),
system_vars=pool.get_all_system_vars(),
)
) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
}
loopstate = WorkflowState(
@@ -79,49 +85,93 @@ class LoopRuntime:
loopstate["looping"] = True
return loopstate
def _get_loop_expression(self):
@staticmethod
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
"""
Build the Python boolean expression for evaluating the loop condition.
Dispatch and execute a comparison operator against a resolved
CompareOperatorInstance.
- Converts each condition in the loop configuration into a Python expression string.
- Combines multiple conditions with the configured logical operator (AND/OR).
Args:
operator: A ComparisonOperator enum value.
instance: A CompareOperatorInstance bound to concrete operands.
Returns:
A string representing the combined loop condition expression.
Any: The evaluation result, typically a boolean.
"""
branch_conditions = [
ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
for condition in self.typed_config.condition.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {self.typed_config.condition.logical_operator} '.join(branch_conditions)
else:
combined_condition = branch_conditions[0]
match operator:
case ComparisonOperator.EMPTY:
return instance.empty()
case ComparisonOperator.NOT_EMPTY:
return instance.not_empty()
case ComparisonOperator.CONTAINS:
return instance.contains()
case ComparisonOperator.NOT_CONTAINS:
return instance.not_contains()
case ComparisonOperator.START_WITH:
return instance.startswith()
case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
return combined_condition
def evaluate_conditional(self, state) -> bool:
"""
Evaluate the loop continuation condition at runtime.
This method:
- Resolves all condition expressions against the current workflow state
- Evaluates each comparison expression immediately
- Combines results using the configured logical operator (AND / OR)
Args:
state: The current workflow state during loop execution.
Returns:
bool: True if the loop should continue, False otherwise.
"""
conditions = []
for expression in self.typed_config.condition.expressions:
left_value = VariablePool(state).get(expression.left)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
VariablePool(state),
expression.left,
expression.right,
expression.input_type
)
conditions.append(self._evaluate(expression.operator, evaluator))
if self.typed_config.condition.logical_operator == LogicOperator.AND:
return all(conditions)
else:
return any(conditions)
async def run(self):
"""
Execute the loop node until the condition is no longer met, the loop is
manually stopped, or the maximum loop count is reached.
Execute the loop node until termination conditions are met.
The loop terminates when any of the following occurs:
- The loop condition evaluates to False
- The `looping` flag in the workflow state is set to False
- The maximum loop count is reached
Returns:
The final runtime variables of this loop node after completion.
dict[str, Any]: The final runtime variables of this loop node.
"""
loopstate = self._init_loop_state()
expression = self._get_loop_expression()
loop_variable_pool = VariablePool(loopstate)
loop_time = self.typed_config.max_loop
while evaluate_condition(
expression=expression,
variables=loop_variable_pool.get_all_conversation_vars(),
node_outputs=loop_variable_pool.get_all_node_outputs(),
system_vars=loop_variable_pool.get_all_system_vars(),
) and loopstate["looping"] and loop_time > 0:
while self.evaluate_conditional(loopstate) and loopstate["looping"] and loop_time > 0:
logger.info(f"loop node {self.node_id}: running")
await self.graph.ainvoke(loopstate)
loop_time -= 1

View File

@@ -61,7 +61,7 @@ class EndNode(BaseNode):
引用的节点 ID 列表
"""
# 匹配 {{node_id.xxx}} 格式
pattern = r'\{\{([a-zA-Z0-9_-]+)\.[a-zA-Z0-9_]+\}\}'
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
matches = re.findall(pattern, template)
return list(set(matches)) # 去重

View File

@@ -45,7 +45,8 @@ class LogicOperator(StrEnum):
class AssignmentOperator(StrEnum):
ASSIGN = "assign"
COVER = "cover" # 覆盖
ASSIGN = "assign" # 设置
CLEAR = "clear"
ADD = "add" # +=
@@ -87,3 +88,8 @@ class HttpErrorHandle(StrEnum):
NONE = "none"
DEFAULT = "default"
BRANCH = "branch"
class ValueInputType(StrEnum):
VARIABLE = "Variable"
CONSTANT = "Constant"

View File

@@ -1,12 +1,13 @@
"""Condition Configuration"""
from typing import Any
from pydantic import Field, BaseModel, field_validator
from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
class ConditionDetail(BaseModel):
comparison_operator: ComparisonOperator = Field(
operator: ComparisonOperator = Field(
...,
description="Comparison operator used to evaluate the condition"
)
@@ -16,17 +17,22 @@ class ConditionDetail(BaseModel):
description="Value to compare against"
)
right: str = Field(
right: Any = Field(
...,
description="Value to compare with"
)
input_type: ValueInputType = Field(
...,
description="Value input type for comparison"
)
class ConditionBranchConfig(BaseModel):
"""Configuration for a conditional branch"""
logical_operator: LogicOperator = Field(
default=LogicOperator.AND.value,
default=LogicOperator.AND,
description="Logical operator used to combine multiple condition expressions"
)

View File

@@ -1,10 +1,11 @@
import logging
import re
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
from app.core.workflow.nodes.if_else import IfElseNodeConfig
from app.core.workflow.nodes.if_else.config import ConditionDetail
from app.core.workflow.nodes.operators import ConditionExpressionBuilder
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
logger = logging.getLogger(__name__)
@@ -15,30 +16,36 @@ class IfElseNode(BaseNode):
self.typed_config = IfElseNodeConfig(**self.config)
@staticmethod
def _build_condition_expression(
condition: ConditionDetail,
) -> str:
"""
Build a single boolean condition expression string.
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
match operator:
case ComparisonOperator.EMPTY:
return instance.empty()
case ComparisonOperator.NOT_EMPTY:
return instance.not_empty()
case ComparisonOperator.CONTAINS:
return instance.contains()
case ComparisonOperator.NOT_CONTAINS:
return instance.not_contains()
case ComparisonOperator.START_WITH:
return instance.startswith()
case ComparisonOperator.END_WITH:
return instance.endswith()
case ComparisonOperator.EQ:
return instance.eq()
case ComparisonOperator.NE:
return instance.ne()
case ComparisonOperator.LT:
return instance.lt()
case ComparisonOperator.LE:
return instance.le()
case ComparisonOperator.GT:
return instance.gt()
case ComparisonOperator.GE:
return instance.ge()
case _:
raise ValueError(f"Invalid condition: {operator}")
This method does NOT evaluate the condition.
It only generates a valid Python boolean expression string
(e.g. "x > 10", "'a' in name") that can later be used
in a conditional edge or evaluated by the workflow engine.
Args:
condition (ConditionDetail): Definition of a single comparison condition.
Returns:
str: A Python boolean expression string.
"""
return ConditionExpressionBuilder(
left=condition.left,
operator=condition.comparison_operator,
right=condition.right
).build()
def build_conditional_edge_expressions(self) -> list[str]:
def evaluate_conditional_edge_expressions(self, state) -> list[bool]:
"""
Build conditional edge expressions for the If-Else node.
@@ -60,19 +67,28 @@ class IfElseNode(BaseNode):
for case_branch in self.typed_config.cases:
branch_index += 1
branch_conditions = [
self._build_condition_expression(condition)
for condition in case_branch.expressions
]
if len(branch_conditions) > 1:
combined_condition = f' {case_branch.logical_operator} '.join(branch_conditions)
branch_result = []
for expression in case_branch.expressions:
pattern = r"\{\{\s*(.*?)\s*\}\}"
left_string = re.sub(pattern, r"\1", expression.left).strip()
left_value = self.get_variable(left_string, state)
evaluator = ConditionExpressionResolver.resolve_by_value(left_value)(
self.get_variable_pool(state),
expression.left,
expression.right,
expression.input_type
)
branch_result.append(self._evaluate(expression.operator, evaluator))
if case_branch.logical_operator == LogicOperator.AND:
conditions.append(all(branch_result))
else:
combined_condition = branch_conditions[0]
conditions.append(combined_condition)
condition_res = any(branch_result)
conditions.append(condition_res)
if condition_res:
return conditions
# Default fallback branch
conditions.append("True")
conditions.append(True)
return conditions
@@ -90,10 +106,10 @@ class IfElseNode(BaseNode):
Returns:
str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions.
"""
expressions = self.build_conditional_edge_expressions()
expressions = self.evaluate_conditional_edge_expressions(state)
# TODO: 变量类型及文本类型解析
for i in range(len(expressions)):
logger.info(expressions[i])
if self._evaluate_condition(expressions[i], state):
if expressions[i]:
logger.info(f"Node {self.node_id}: switched to branch CASE {i + 1}")
return f'CASE{i + 1}'
return f'CASE{len(expressions)}'

View File

@@ -8,6 +8,7 @@ from app.core.workflow.template_renderer import TemplateRenderer
logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)

View File

@@ -1,10 +1,73 @@
import json
import re
from abc import ABC
from typing import Union, Type
from typing import Union, Type, NoReturn
from app.core.workflow.nodes.enums import ComparisonOperator
from app.core.workflow.nodes.base_config import VariableType
from app.core.workflow.nodes.enums import ValueInputType
from app.core.workflow.variable_pool import VariablePool
class TypeTransformer:
@classmethod
def _fail(cls, value, target) -> NoReturn:
raise TypeError(f"Cannot convert {value!r} to {target} type")
@classmethod
def _json_load(cls, value, target):
try:
return json.loads(value)
except Exception:
cls._fail(value, target)
@classmethod
def transform(cls, variable_literal: str | bool, target_type: VariableType):
match target_type:
case VariableType.STRING:
return str(variable_literal)
case VariableType.NUMBER:
for caster in (int, float):
try:
return caster(variable_literal)
except Exception:
pass
cls._fail(variable_literal, target_type)
case VariableType.BOOLEAN:
if isinstance(variable_literal, bool):
return variable_literal
cls._fail(variable_literal, target_type)
case VariableType.OBJECT:
obj = cls._json_load(variable_literal, target_type)
if isinstance(obj, dict):
return obj
cls._fail(variable_literal, target_type)
case VariableType.ARRAY_BOOLEAN:
return cls._parse_list(variable_literal, bool, target_type)
case VariableType.ARRAY_NUMBER:
return cls._parse_list(variable_literal, (int, float), target_type)
case VariableType.ARRAY_STRING:
return cls._parse_list(variable_literal, str, target_type)
case VariableType.ARRAY_OBJECT:
return cls._parse_list(variable_literal, dict, target_type)
case _:
raise TypeError("Invalid type")
@classmethod
def _parse_list(cls, value, item_type, target):
arr = cls._json_load(value, target)
if isinstance(arr, list) and all(isinstance(i, item_type) for i in arr):
return arr
cls._fail(value, target)
class OperatorBase(ABC):
def __init__(self, pool: VariablePool, left_selector, right):
self.pool = pool
@@ -19,7 +82,9 @@ class OperatorBase(ABC):
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
if not no_right and not isinstance(self.right, self.type_limit):
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
raise TypeError(
f"The value assigned must be of {self.type_limit} type"
)
class StringOperator(OperatorBase):
@@ -126,7 +191,7 @@ class ArrayOperator(OperatorBase):
class ObjectOperator(OperatorBase):
def __init__(self, pool: VariablePool, left_selector, right):
super().__init__(pool, left_selector, right)
self.type_limit = object
self.type_limit = dict
def assign(self) -> None:
self.check()
@@ -138,20 +203,21 @@ class ObjectOperator(OperatorBase):
class AssignmentOperatorResolver:
OPERATOR_MAP = {
str: StringOperator,
bool: BooleanOperator,
int: NumberOperator,
float: NumberOperator,
list: ArrayOperator,
dict: ObjectOperator,
}
@classmethod
def resolve_by_value(cls, value):
if isinstance(value, str):
return StringOperator
elif isinstance(value, bool):
return BooleanOperator
elif isinstance(value, (int, float)):
return NumberOperator
elif isinstance(value, list):
return ArrayOperator
elif isinstance(value, dict):
return ObjectOperator
else:
raise TypeError(f"Unsupported variable type: {type(value)}")
for t, op in cls.OPERATOR_MAP.items():
if isinstance(value, t):
return op
raise TypeError(f"Unsupported variable type: {type(value)}")
AssignmentOperatorInstance = Union[
@@ -164,81 +230,186 @@ AssignmentOperatorInstance = Union[
AssignmentOperatorType = Type[AssignmentOperatorInstance]
class ConditionExpressionBuilder:
"""
Build a Python boolean expression string based on a comparison operator.
class ConditionBase(ABC):
type_limit: type[str, int, dict, list] = None
This class does not evaluate the expression.
It only generates a valid Python expression string
that can be evaluated later in a workflow context.
"""
def __init__(
self,
pool: VariablePool,
left_selector,
right_selector: str,
input_type: ValueInputType
):
self.pool = pool
self.left_selector = left_selector
self.right_selector = right_selector
self.input_type = input_type
def __init__(self, left: str, operator: ComparisonOperator, right: str):
self.left = left
self.operator = operator
self.right = right
self.left_value = self.pool.get(self.left_selector)
self.right_value = self.resolve_right_literal_value()
def _empty(self):
return f"{self.left} == ''"
self.type_limit = getattr(self, "type_limit", None)
def _not_empty(self):
return f"{self.left} != ''"
def resolve_right_literal_value(self):
if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
return self.pool.get(right_expression)
elif self.input_type == ValueInputType.CONSTANT:
return self.right_selector
raise RuntimeError("Unsupported variable type")
def _contains(self):
return f"{self.right} in {self.left}"
def check(self, no_right=False):
left = self.pool.get(self.left_selector.variable_selector)
if not isinstance(left, self.type_limit):
raise TypeError(f"The variable to be compared on must be of {self.type_limit} type")
if not no_right:
right = self.resolve_right_literal_value()
if not isinstance(right, self.type_limit):
raise TypeError(
f"The compared variable must be of {self.type_limit} type"
)
def _not_contains(self):
return f"{self.right} not in {self.left}"
def _startswith(self):
return f'{self.left}.startswith({self.right})'
class StringComparisonOperator(ConditionBase):
type_limit = str
def _endswith(self):
return f'{self.left}.endswith({self.right})'
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def _eq(self):
return f"{self.left} == {self.right}"
def empty(self):
self.check(no_right=True)
return self.left_value == ""
def _ne(self):
return f"{self.left} != {self.right}"
def not_empty(self):
return not self.empty()
def _lt(self):
return f"{self.left} < {self.right}"
def contains(self):
self.check()
return self.right_value in self.left_value
def _le(self):
return f"{self.left} <= {self.right}"
def not_contains(self):
return self.right_value not in self.left_value
def _gt(self):
return f"{self.left} > {self.right}"
def startswith(self):
self.check()
return self.left_value.startswith(self.right_value)
def _ge(self):
return f"{self.left} >= {self.right}"
def endswith(self):
return self.left_value.endswith(self.right_value)
def build(self):
match self.operator:
case ComparisonOperator.EMPTY:
return self._empty()
case ComparisonOperator.NOT_EMPTY:
return self._not_empty()
case ComparisonOperator.CONTAINS:
return self._contains()
case ComparisonOperator.NOT_CONTAINS:
return self._not_contains()
case ComparisonOperator.START_WITH:
return self._startswith()
case ComparisonOperator.END_WITH:
return self._endswith()
case ComparisonOperator.EQ:
return self._eq()
case ComparisonOperator.NE:
return self._ne()
case ComparisonOperator.LT:
return self._lt()
case ComparisonOperator.LE:
return self._le()
case ComparisonOperator.GT:
return self._gt()
case ComparisonOperator.GE:
return self._ge()
case _:
raise ValueError(f"Invalid condition: {self.operator}")
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
class NumberComparisonOperator(ConditionBase):
type_limit = (int, float)
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def empty(self):
return self.left_value == 0
def not_empty(self):
return self.left_value != 0
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
def lt(self):
return self.left_value < self.right_value
def le(self):
return self.left_value <= self.right_value
def gt(self):
return self.left_value > self.right_value
def ge(self):
return self.left_value >= self.right_value
class BooleanComparisonOperator(ConditionBase):
type_limit = bool
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
class ObjectComparisonOperator(ConditionBase):
type_limit = dict
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def eq(self):
return self.left_value == self.right_value
def ne(self):
return self.left_value != self.right_value
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
class ArrayComparisonOperator(ConditionBase):
type_limit = list
def __init__(self, pool: VariablePool, left_selector, right_selector, input_type):
super().__init__(pool, left_selector, right_selector, input_type)
def empty(self):
return not self.left_value
def not_empty(self):
return bool(self.left_value)
def contains(self):
return self.right_value in self.left_value
def not_contains(self):
return self.right_value not in self.left_value
CompareOperatorInstance = Union[
StringComparisonOperator,
NumberComparisonOperator,
BooleanComparisonOperator,
ArrayComparisonOperator,
ObjectComparisonOperator
]
CompareOperatorType = Type[CompareOperatorInstance]
class ConditionExpressionResolver:
CONDITION_OPERATOR_MAP = {
str: StringComparisonOperator,
bool: BooleanComparisonOperator,
int: NumberComparisonOperator,
float: NumberComparisonOperator,
list: ArrayComparisonOperator,
dict: ObjectComparisonOperator,
}
@classmethod
def resolve_by_value(cls, value) -> CompareOperatorType:
for t, op in cls.CONDITION_OPERATOR_MAP.items():
if isinstance(value, t):
return op
raise TypeError(f"Unsupported variable type: {type(value)}")

View File

@@ -15,29 +15,29 @@ logger = logging.getLogger(__name__)
class QuestionClassifierNode(BaseNode):
"""问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config = QuestionClassifierNodeConfig(**self.config)
def _get_llm_instance(self) -> RedBearLLM:
"""获取LLM实例"""
with get_db_read() as db:
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
if not config:
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
if not config.api_keys or len(config.api_keys) == 0:
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
api_config = config.api_keys[0]
model_name = api_config.model_name
provider = api_config.provider
api_key = api_config.api_key
base_url = api_config.api_base
model_type = config.type
return RedBearLLM(
RedBearModelConfig(
model_name=model_name,
@@ -47,7 +47,7 @@ class QuestionClassifierNode(BaseNode):
),
type=ModelType(model_type)
)
async def execute(self, state: WorkflowState) -> dict[str, Any]:
"""执行问题分类"""
question = self.typed_config.input_variable
@@ -55,15 +55,15 @@ class QuestionClassifierNode(BaseNode):
supplement_prompt = ""
if self.typed_config.user_supplement_prompt is not None:
supplement_prompt = self.typed_config.user_supplement_prompt
category_names = [class_item.class_name for class_item in self.typed_config.categories]
if not question:
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
llm = self._get_llm_instance()
# 渲染用户提示词模板,支持工作流变量
user_prompt = self._render_template(
self.typed_config.user_prompt.format(
@@ -73,15 +73,15 @@ class QuestionClassifierNode(BaseNode):
),
state
)
messages = [
("system", self.typed_config.system_prompt),
("user", user_prompt),
]
response = await llm.ainvoke(messages)
result = response.content.strip()
if result in category_names:
category = result
else:
@@ -90,5 +90,5 @@ class QuestionClassifierNode(BaseNode):
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return {self.typed_config.output_variable: category}
return {self.typed_config.output_variable: category}

View File

@@ -1,4 +1,4 @@
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
from app.core.workflow.nodes.variable_aggregator.node import VariableAggregatorNode
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]
__all__ = ["VariableAggregatorNode", "VariableAggregatorNodeConfig"]