feat(workflow_node): question classifier node optimization
This commit is contained in:
@@ -219,17 +219,13 @@ class WorkflowExecutor:
|
|||||||
# 创建节点实例(现在 start 和 end 也会被创建)
|
# 创建节点实例(现在 start 和 end 也会被创建)
|
||||||
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
node_instance = NodeFactory.create_node(node, self.workflow_config)
|
||||||
|
|
||||||
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST]:
|
if node_type in [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]:
|
||||||
expressions = node_instance.build_conditional_edge_expressions()
|
|
||||||
|
|
||||||
# Number of branches, usually matches the number of conditional expressions
|
|
||||||
branch_number = len(expressions)
|
|
||||||
|
|
||||||
# Find all edges whose source is the current node
|
# Find all edges whose source is the current node
|
||||||
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
related_edge = [edge for edge in self.edges if edge.get("source") == node_id]
|
||||||
|
|
||||||
# Iterate over each branch
|
# Iterate over each branch
|
||||||
for idx in range(branch_number):
|
for idx in range(len(related_edge)):
|
||||||
# Generate a condition expression for each edge
|
# Generate a condition expression for each edge
|
||||||
# Used later to determine which branch to take based on the node's output
|
# Used later to determine which branch to take based on the node's output
|
||||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||||
|
|||||||
@@ -26,4 +26,3 @@ class QuestionClassifierNodeConfig(BaseNodeConfig):
|
|||||||
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
default="问题:{question}\n\n可选分类:{categories}\n\n补充指令:{supplement_prompt}\n\n请选择最合适的分类。",
|
||||||
description="用户提示词模板"
|
description="用户提示词模板"
|
||||||
)
|
)
|
||||||
output_variable: str = Field(default="class_name", description="输出分类结果的变量名")
|
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ from app.services.model_service import ModelConfigService
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_CASE_PREFIX = "CASE"
|
||||||
|
DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
|
||||||
|
|
||||||
|
|
||||||
class QuestionClassifierNode(BaseNode):
|
class QuestionClassifierNode(BaseNode):
|
||||||
"""问题分类器节点"""
|
"""问题分类器节点"""
|
||||||
@@ -19,6 +22,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
self.typed_config = QuestionClassifierNodeConfig(**self.config)
|
||||||
|
self.category_to_case_map = self._build_category_case_map()
|
||||||
|
|
||||||
def _get_llm_instance(self) -> RedBearLLM:
|
def _get_llm_instance(self) -> RedBearLLM:
|
||||||
"""获取LLM实例"""
|
"""获取LLM实例"""
|
||||||
@@ -47,48 +51,73 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=ModelType(model_type)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _build_category_case_map(self) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
预构建 分类名称 -> CASE标识 的映射字典
|
||||||
|
示例:{"产品咨询": "CASE1", "售后问题": "CASE2"}
|
||||||
|
"""
|
||||||
|
category_map = {}
|
||||||
|
categories = self.typed_config.categories or []
|
||||||
|
for idx, class_item in enumerate(categories, start=1):
|
||||||
|
category_name = class_item.class_name.strip()
|
||||||
|
case_tag = f"{DEFAULT_CASE_PREFIX}{idx}"
|
||||||
|
category_map[category_name] = case_tag
|
||||||
|
return category_map
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> dict[str, Any]:
|
async def execute(self, state: WorkflowState) -> str:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
question = self.typed_config.input_variable
|
question = self.typed_config.input_variable
|
||||||
|
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||||
supplement_prompt = ""
|
categories = self.typed_config.categories or []
|
||||||
if self.typed_config.user_supplement_prompt is not None:
|
category_names = [class_item.class_name.strip() for class_item in categories]
|
||||||
supplement_prompt = self.typed_config.user_supplement_prompt
|
category_count = len(category_names)
|
||||||
|
|
||||||
category_names = [class_item.class_name for class_item in self.typed_config.categories]
|
|
||||||
|
|
||||||
if not question:
|
if not question:
|
||||||
logger.warning(f"节点 {self.node_id} 未获取到输入问题")
|
logger.warning(
|
||||||
return {self.typed_config.output_variable: category_names[0] if category_names else "unknown"}
|
f"节点 {self.node_id} 未获取到输入问题,使用默认分支"
|
||||||
|
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||||
llm = self._get_llm_instance()
|
)
|
||||||
|
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||||
# 渲染用户提示词模板,支持工作流变量
|
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||||
user_prompt = self._render_template(
|
|
||||||
self.typed_config.user_prompt.format(
|
|
||||||
question=question,
|
|
||||||
categories=", ".join(category_names),
|
|
||||||
supplement_prompt=supplement_prompt
|
|
||||||
),
|
|
||||||
state
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
("system", self.typed_config.system_prompt),
|
|
||||||
("user", user_prompt),
|
|
||||||
]
|
|
||||||
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
result = response.content.strip()
|
|
||||||
|
|
||||||
if result in category_names:
|
|
||||||
category = result
|
|
||||||
else:
|
|
||||||
logger.warning(f"LLM返回了未知类别: {result}")
|
|
||||||
category = category_names[0] if category_names else "unknown"
|
|
||||||
|
|
||||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
try:
|
||||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
llm = self._get_llm_instance()
|
||||||
|
|
||||||
return {self.typed_config.output_variable: category}
|
# 渲染用户提示词模板,支持工作流变量
|
||||||
|
user_prompt = self._render_template(
|
||||||
|
self.typed_config.user_prompt.format(
|
||||||
|
question=question,
|
||||||
|
categories=", ".join(category_names),
|
||||||
|
supplement_prompt=supplement_prompt
|
||||||
|
),
|
||||||
|
state
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
("system", self.typed_config.system_prompt),
|
||||||
|
("user", user_prompt),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
result = response.content.strip()
|
||||||
|
|
||||||
|
if result in category_names:
|
||||||
|
category = result
|
||||||
|
else:
|
||||||
|
logger.warning(f"LLM返回了未知类别: {result}")
|
||||||
|
category = category_names[0] if category_names else "unknown"
|
||||||
|
|
||||||
|
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||||
|
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||||
|
|
||||||
|
return f"CASE{category_names.index(category) + 1}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||||
|
exc_info=True # 打印堆栈信息,便于调试
|
||||||
|
)
|
||||||
|
# 异常时返回默认分支,保证工作流容错性
|
||||||
|
if category_count > 0:
|
||||||
|
return DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
return "unknown"
|
||||||
|
|||||||
Reference in New Issue
Block a user