From 284951900d04c8aa393fa7c65ca0c127b69c2576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E4=BF=8A=E7=94=B7?= Date: Mon, 29 Dec 2025 17:01:19 +0800 Subject: [PATCH] feat(workflow-node): question_classfier node development --- api/app/core/workflow/nodes/configs.py | 2 + api/app/core/workflow/nodes/node_factory.py | 5 +- .../nodes/question_classifier/__init__.py | 6 ++ .../nodes/question_classifier/config.py | 29 ++++++ .../nodes/question_classifier/node.py | 94 +++++++++++++++++++ 5 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 api/app/core/workflow/nodes/question_classifier/__init__.py create mode 100644 api/app/core/workflow/nodes/question_classifier/config.py create mode 100644 api/app/core/workflow/nodes/question_classifier/node.py diff --git a/api/app/core/workflow/nodes/configs.py b/api/app/core/workflow/nodes/configs.py index a8363421..b1c64227 100644 --- a/api/app/core/workflow/nodes/configs.py +++ b/api/app/core/workflow/nodes/configs.py @@ -20,6 +20,7 @@ from app.core.workflow.nodes.start.config import StartNodeConfig from app.core.workflow.nodes.transform.config import TransformNodeConfig from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig +from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig __all__ = [ # 基础类 @@ -40,4 +41,5 @@ __all__ = [ "JinjaRenderNodeConfig", "VariableAggregatorNodeConfig", "ParameterExtractorNodeConfig", + "QuestionClassifierNodeConfig" ] diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 98c1468f..90c48ac0 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -21,6 +21,7 @@ from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode from app.core.workflow.nodes.start import StartNode from app.core.workflow.nodes.transform import TransformNode from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode +from app.core.workflow.nodes.question_classifier import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -37,7 +38,8 @@ WorkflowNode = Union[ KnowledgeRetrievalNode, JinjaRenderNode, VariableAggregatorNode, - ParameterExtractorNode + ParameterExtractorNode, + QuestionClassifierNode ] @@ -61,6 +63,7 @@ class NodeFactory: NodeType.JINJARENDER: JinjaRenderNode, NodeType.VAR_AGGREGATOR: VariableAggregatorNode, NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, + NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, } @classmethod diff --git a/api/app/core/workflow/nodes/question_classifier/__init__.py b/api/app/core/workflow/nodes/question_classifier/__init__.py new file mode 100644 index 00000000..4f042737 --- /dev/null +++ b/api/app/core/workflow/nodes/question_classifier/__init__.py @@ -0,0 +1,6 @@ +from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.workflow.nodes.question_classifier.node import QuestionClassifierNode + +__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeConfig"] + + diff --git a/api/app/core/workflow/nodes/question_classifier/config.py b/api/app/core/workflow/nodes/question_classifier/config.py new file mode 100644 index 00000000..f3b2cc20 --- /dev/null +++ b/api/app/core/workflow/nodes/question_classifier/config.py @@ -0,0 +1,29 @@ +import uuid +from typing import Optional + +from pydantic import Field, BaseModel + +from app.core.workflow.nodes.base_config import BaseNodeConfig + +class ClassifierConfig(BaseModel): + """分类器节点配置""" + + class_name: str = Field(..., description="分类类别名称") + + +class QuestionClassifierNodeConfig(BaseNodeConfig): + """问题分类器节点配置""" + + model_id: uuid.UUID = Field(..., description="LLM模型ID") + input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)") + user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令") + categories: list[ClassifierConfig] = Field(..., description="分类类别列表") + system_prompt: str = Field( + default="你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。", + description="系统提示词" + ) + user_prompt: str = Field( + default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。", + description="用户提示词模板" + ) + output_variable: str = Field(default="class_name", description="输出分类结果的变量名") diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py new file mode 100644 index 00000000..bd3c8752 --- /dev/null +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -0,0 +1,94 @@ +import logging +from typing import Any + +from app.core.workflow.nodes.base_node import BaseNode, WorkflowState +from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig +from app.core.models import RedBearLLM, RedBearModelConfig +from app.core.exceptions import BusinessException +from app.core.error_codes import BizCode +from app.db import get_db_read +from app.models import ModelType +from app.services.model_service import ModelConfigService + +logger = logging.getLogger(__name__) + + +class QuestionClassifierNode(BaseNode): + """问题分类器节点""" + + def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): + super().__init__(node_config, workflow_config) + self.typed_config = QuestionClassifierNodeConfig(**self.config) + + def _get_llm_instance(self) -> RedBearLLM: + """获取LLM实例""" + with get_db_read() as db: + config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id) + + if not config: + raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND) + + if not config.api_keys or len(config.api_keys) == 0: + raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) + + api_config = config.api_keys[0] + model_name = api_config.model_name + provider = api_config.provider + api_key = api_config.api_key + base_url = api_config.api_base + model_type = config.type + + return RedBearLLM( + RedBearModelConfig( + model_name=model_name, + provider=provider, + api_key=api_key, + base_url=base_url, + ), + type=ModelType(model_type) + ) + + async def execute(self, state: WorkflowState) -> dict[str, Any]: + """执行问题分类""" + question = self.typed_config.input_variable + + supplement_prompt = "" + if self.typed_config.user_supplement_prompt is not None: + supplement_prompt = self.typed_config.user_supplement_prompt + + category_names = [class_item.class_name for class_item in self.typed_config.categories] + + if not question: + logger.warning(f"节点 {self.node_id} 未获取到输入问题") + return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"} + + llm = self._get_llm_instance() + + # 渲染用户提示词模板,支持工作流变量 + user_prompt = self._render_template( + self.typed_config.user_prompt.format( + question=question, + categories=", ".join(category_names), + supplement_prompt=supplement_prompt + ), + state + ) + + messages = [ + ("system", self.typed_config.system_prompt), + ("user", user_prompt), + ] + + response = await llm.ainvoke(messages) + result = response.content.strip() + + if result in category_names: + category = result + else: + logger.warning(f"LLM返回了未知类别: {result}") + category = category_names[0] if category_names else "unknown" + + log_supplement = supplement_prompt if supplement_prompt else "无" + logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") + + return {self.typed_config.output_variable: category} \ No newline at end of file