From c0b29dd9384437dd905c98141e4afe8dc8b69d48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Sun, 4 Jan 2026 19:06:51 +0800 Subject: [PATCH] feat(workflow_node): question classifier node optimization --- api/app/core/workflow/executor.py | 8 +- .../nodes/question_classifier/config.py | 1 - .../nodes/question_classifier/node.py | 107 +++++++++++------- 3 files changed, 70 insertions(+), 46 deletions(-) diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 0d0879d7..fe75eace 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -219,17 +219,13 @@ class WorkflowExecutor: # 创建节点实例(现在 start 和 end 也会被创建) node_instance = NodeFactory.create_node(node, self.workflow_config) - if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]: - expressions = node_instance.build_conditional_edge_expressions() - - # Number of branches, usually matches the number of conditional expressions - branch_number = len(expressions) + if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]: # Find all edges whose source is the current node related_edge = [edge for edge in self.edges if edge.get("source") == node_id] # Iterate over each branch - for idx in range(branch_number): + for idx in range(len(related_edge)): # Generate a condition expression for each edge # Used later to determine which branch to take based on the node's output # Assumes node output `node..output` matches the edge's label diff --git a/api/app/core/workflow/nodes/question_classifier/config.py b/api/app/core/workflow/nodes/question_classifier/config.py index f3b2cc20..998e2fb4 100644 --- a/api/app/core/workflow/nodes/question_classifier/config.py +++ b/api/app/core/workflow/nodes/question_classifier/config.py @@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig): default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", description="用户提示词模板" ) - output_variable: str = Field(default="class_name", description="输出分类结果的变量名") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index bd3c8752..67f53801 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -12,6 +12,9 @@ from app.services.model_service import ModelConfigService logger = logging.getLogger(__name__) +DEFAULT_CASE_PREFIX = "CASE" +DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1" + class QuestionClassifierNode(BaseNode): """问题分类器节点""" @@ -19,6 +22,7 @@ class QuestionClassifierNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config = QuestionClassifierNodeConfig(**self.config) + self.category_to_case_map = self._build_category_case_map() def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -47,48 +51,73 @@ class QuestionClassifierNode(BaseNode): ), type=ModelType(model_type) ) + + def _build_category_case_map(self) -> dict[str, str]: + """ + 预构建 分类名称 -> CASE标识 的映射字典 + 示例:{"产品咨询": "CASE1", "售后问题": "CASE2"} + """ + category_map = {} + categories = self.typed_config.categories or [] + for idx, class_item in enumerate(categories, start=1): + category_name = class_item.class_name.strip() + case_tag = f"{DEFAULT_CASE_PREFIX}{idx}" + category_map[category_name] = case_tag + return category_map - async def execute(self, state: WorkflowState) -> dict[str, Any]: + async def execute(self, state: WorkflowState) -> str: """执行问题分类""" question = self.typed_config.input_variable - - supplement_prompt = "" - if self.typed_config.user_supplement_prompt is not None: - supplement_prompt = self.typed_config.user_supplement_prompt - - category_names = [class_item.class_name for class_item in self.typed_config.categories] + supplement_prompt = self.typed_config.user_supplement_prompt or "" + categories = self.typed_config.categories or [] + category_names = [class_item.class_name.strip() for class_item in categories] + category_count = len(category_names) if not question: - logger.warning(f"节点 {self.node_id} 未获取到输入问题") - return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} - - llm = self._get_llm_instance() - - # 渲染用户提示词模板,支持工作流变量 - user_prompt = self._render_template( - self.typed_config.user_prompt.format( - question=question, - categories=", ".join(category_names), - supplement_prompt=supplement_prompt - ), - state - ) - - messages = [ - ("system", self.typed_config.system_prompt), - ("user", user_prompt), - ] - - response = await llm.ainvoke(messages) - result = response.content.strip() - - if result in category_names: - category = result - else: - logger.warning(f"LLM返回了未知类别: {result}") - category = category_names[0] if category_names else "unknown" + logger.warning( + f"节点 {self.node_id} 未获取到输入问题,使用默认分支" + f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" + ) + # 若分类列表为空,返回默认unknown分支,否则返回CASE1 + return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" - log_supplement = supplement_prompt if supplement_prompt else "无" - logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - - return {self.typed_config.output_variable: category} \ No newline at end of file + try: + llm = self._get_llm_instance() + + # 渲染用户提示词模板,支持工作流变量 + user_prompt = self._render_template( + self.typed_config.user_prompt.format( + question=question, + categories=", ".join(category_names), + supplement_prompt=supplement_prompt + ), + state + ) + + messages = [ + ("system", self.typed_config.system_prompt), + ("user", user_prompt), + ] + + response = await llm.ainvoke(messages) + result = response.content.strip() + + if result in category_names: + category = result + else: + logger.warning(f"LLM返回了未知类别: {result}") + category = category_names[0] if category_names else "unknown" + + log_supplement = supplement_prompt if supplement_prompt else "无" + logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") + + return f"CASE{category_names.index(category) + 1}" + except Exception as e: + logger.error( + f"节点 {self.node_id} 分类执行异常:{str(e)}", + exc_info=True # 打印堆栈信息,便于调试 + ) + # 异常时返回默认分支,保证工作流容错性 + if category_count > 0: + return DEFAULT_EMPTY_QUESTION_CASE + return "unknown"