fix(workflow node): Workflow nodes and question classifier nodes - bug fixes

This commit is contained in:
谢俊男
2026-01-06 12:09:55 +08:00
parent 2fadf88a93
commit 962b74a68a
5 changed files with 314 additions and 19 deletions

View File

@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
category_map[category_name] = case_tag
return category_map
async def execute(self, state: WorkflowState) -> str:
async def execute(self, state: WorkflowState) -> dict:
"""执行问题分类"""
question = self.typed_config.input_variable
supplement_prompt = self.typed_config.user_supplement_prompt or ""
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count}"
)
# 若分类列表为空返回默认unknown分支否则返回CASE1
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
if category_count > 0:
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}
try:
llm = self._get_llm_instance()
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
log_supplement = supplement_prompt if supplement_prompt else ""
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
return f"CASE{category_names.index(category) + 1}"
return {
"class_name": category,
"output": f"CASE{category_names.index(category) + 1}",
}
except Exception as e:
logger.error(
f"节点 {self.node_id} 分类执行异常:{str(e)}",
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
)
# 异常时返回默认分支,保证工作流容错性
if category_count > 0:
return DEFAULT_EMPTY_QUESTION_CASE
return "unknown"
return {
"class_name": category_names[0],
"output": DEFAULT_EMPTY_QUESTION_CASE
}
return {
"class_name": "unknown",
"output": DEFAULT_EMPTY_QUESTION_CASE
}

View File

@@ -1,4 +1,6 @@
from pydantic import Field
from typing import Any
from app.core.workflow.nodes.base_config import BaseNodeConfig
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
"""工具节点配置"""
tool_id: str = Field(..., description="工具ID")
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")

View File

@@ -1,5 +1,5 @@
import logging
import uuid
import re
from typing import Any
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
@@ -9,6 +9,8 @@ from app.db import get_db_read
logger = logging.getLogger(__name__)
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
class ToolNode(BaseNode):
"""工具节点"""
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
# 如果没有租户ID尝试从工作流ID获取
if not tenant_id:
workflow_id = self.get_variable("sys.workflow_id", state)
if workflow_id:
workspace_id = self.get_variable("sys.workspace_id", state)
if workspace_id:
from app.repositories.tool_repository import ToolRepository
with get_db_read() as db:
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
if not tenant_id:
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
# logger.error(f"节点 {self.node_id} 缺少租户ID")
# return {"error": "缺少租户ID"}
logger.error(f"节点 {self.node_id} 缺少租户ID")
return {
"success": False,
"data": "缺少租户ID"
}
# 渲染工具参数
rendered_parameters = {}
for param_name, param_template in self.typed_config.tool_parameters.items():
rendered_value = self._render_template(param_template, state)
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
try:
rendered_value = self._render_template(param_template, state)
except Exception as e:
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
else:
# 非模板参数(数字/布尔/普通字符串)直接保留原值
rendered_value = param_template
rendered_parameters[param_name] = rendered_value
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
print(self.typed_config.tool_id)
# 执行工具
with get_db_read() as db:
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
tenant_id=tenant_id,
user_id=user_id
)
print(result)
if result.success:
logger.info(f"节点 {self.node_id} 工具执行成功")
return {
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
return {
"success": False,
"error": result.error,
"data": result.error,
"error_code": result.error_code,
"execution_time": result.execution_time
}