Merge #76 into develop from feature/20251219_xjn
feat(workflow-node): question_classfier node development * feature/20251219_xjn: (1 commits) feat(workflow-node): question_classfier node development Signed-off-by: 谢俊男 <accounts_6853d0ea6f8174722fb0c8f1@mail.teambition.com> Reviewed-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> Merged-by: zhuwenhui5566@163.com <zhuwenhui5566@163.com> CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/76
This commit is contained in:
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.start.config import StartNodeConfig
|
|||||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||||
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
from app.core.workflow.nodes.parameter_extractor.config import ParameterExtractorNodeConfig
|
||||||
|
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础类
|
# 基础类
|
||||||
@@ -40,4 +41,5 @@ __all__ = [
|
|||||||
"JinjaRenderNodeConfig",
|
"JinjaRenderNodeConfig",
|
||||||
"VariableAggregatorNodeConfig",
|
"VariableAggregatorNodeConfig",
|
||||||
"ParameterExtractorNodeConfig",
|
"ParameterExtractorNodeConfig",
|
||||||
|
"QuestionClassifierNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
|||||||
from app.core.workflow.nodes.start import StartNode
|
from app.core.workflow.nodes.start import StartNode
|
||||||
from app.core.workflow.nodes.transform import TransformNode
|
from app.core.workflow.nodes.transform import TransformNode
|
||||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||||
|
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -37,7 +38,8 @@ WorkflowNode = Union[
|
|||||||
KnowledgeRetrievalNode,
|
KnowledgeRetrievalNode,
|
||||||
JinjaRenderNode,
|
JinjaRenderNode,
|
||||||
VariableAggregatorNode,
|
VariableAggregatorNode,
|
||||||
ParameterExtractorNode
|
ParameterExtractorNode,
|
||||||
|
QuestionClassifierNode
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -61,6 +63,7 @@ class NodeFactory:
|
|||||||
NodeType.JINJARENDER: JinjaRenderNode,
|
NodeType.JINJARENDER: JinjaRenderNode,
|
||||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
NodeType.VAR_AGGREGATOR: VariableAggregatorNode,
|
||||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||||
|
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||||
|
from app.core.workflow.nodes.question_classifier.node import QuestionClassifierNode
|
||||||
|
|
||||||
|
__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeConfig"]
|
||||||
|
|
||||||
|
|
||||||
29
api/app/core/workflow/nodes/question_classifier/config.py
Normal file
29
api/app/core/workflow/nodes/question_classifier/config.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field, BaseModel
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
class ClassifierConfig(BaseModel):
|
||||||
|
"""分类器节点配置"""
|
||||||
|
|
||||||
|
class_name: str = Field(..., description="分类类别名称")
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionClassifierNodeConfig(BaseNodeConfig):
|
||||||
|
"""问题分类器节点配置"""
|
||||||
|
|
||||||
|
model_id: uuid.UUID = Field(..., description="LLM模型ID")
|
||||||
|
input_variable: str = Field(default="{{sys.message}}", description="输入变量选择器(用户问题)")
|
||||||
|
user_supplement_prompt: Optional[str] = Field(default=None, description="用户补充提示词,额外分类指令")
|
||||||
|
categories: list[ClassifierConfig] = Field(..., description="分类类别列表")
|
||||||
|
system_prompt: str = Field(
|
||||||
|
default="你是一个问题分类器,请根据用户问题选择最合适的分类。只返回分类名称,不要其他内容。",
|
||||||
|
description="系统提示词"
|
||||||
|
)
|
||||||
|
user_prompt: str = Field(
|
||||||
|
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||||
|
description="用户提示词模板"
|
||||||
|
)
|
||||||
|
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")
|
||||||
94
api/app/core/workflow/nodes/question_classifier/node.py
Normal file
94
api/app/core/workflow/nodes/question_classifier/node.py
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
|
from app.core.workflow.nodes.question_classifier.config import QuestionClassifierNodeConfig
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.db import get_db_read
|
||||||
|
from app.models import ModelType
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class QuestionClassifierNode(BaseNode):
|
||||||
|
"""问题分类器节点"""
|
||||||
|
|
||||||
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
|
super().__init__(node_config, workflow_config)
|
||||||
|
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||||
|
|
||||||
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
|
"""获取LLM实例"""
|
||||||
|
with get_db_read() as db:
|
||||||
|
config = ModelConfigService.get_model_by_id(db=db, model_id=self.typed_config.model_id)
|
||||||
|
|
||||||
|
if not config:
|
||||||
|
raise BusinessException("配置的模型不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
|
if not config.api_keys or len(config.api_keys) == 0:
|
||||||
|
raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
|
api_config = config.api_keys[0]
|
||||||
|
model_name = api_config.model_name
|
||||||
|
provider = api_config.provider
|
||||||
|
api_key = api_config.api_key
|
||||||
|
base_url = api_config.api_base
|
||||||
|
model_type = config.type
|
||||||
|
|
||||||
|
return RedBearLLM(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=model_name,
|
||||||
|
provider=provider,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
),
|
||||||
|
type=ModelType(model_type)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
||||||
|
"""执行问题分类"""
|
||||||
|
question = self.typed_config.input_variable
|
||||||
|
|
||||||
|
supplement_prompt = ""
|
||||||
|
if self.typed_config.user_supplement_prompt is not None:
|
||||||
|
supplement_prompt = self.typed_config.user_supplement_prompt
|
||||||
|
|
||||||
|
category_names = [class_item.class_name for class_item in self.typed_config.categories]
|
||||||
|
|
||||||
|
if not question:
|
||||||
|
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
|
||||||
|
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
|
||||||
|
|
||||||
|
llm = self._get_llm_instance()
|
||||||
|
|
||||||
|
# 渲染用户提示词模板,支持工作流变量
|
||||||
|
user_prompt = self._render_template(
|
||||||
|
self.typed_config.user_prompt.format(
|
||||||
|
question=question,
|
||||||
|
categories=", ".join(category_names),
|
||||||
|
supplement_prompt=supplement_prompt
|
||||||
|
),
|
||||||
|
state
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
("system", self.typed_config.system_prompt),
|
||||||
|
("user", user_prompt),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
result = response.content.strip()
|
||||||
|
|
||||||
|
if result in category_names:
|
||||||
|
category = result
|
||||||
|
else:
|
||||||
|
logger.warning(f"LLM返回了未知类别: {result}")
|
||||||
|
category = category_names[0] if category_names else "unknown"
|
||||||
|
|
||||||
|
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||||
|
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||||
|
|
||||||
|
return {self.typed_config.output_variable: category}
|
||||||
Reference in New Issue
Block a user