fix(workflow node): Workflow nodes and question classifier nodes - bug fixes
This commit is contained in:
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||
if category_count > 0:
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
try:
|
||||
llm = self._get_llm_instance()
|
||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
|
||||
return f"CASE{category_names.index(category) + 1}"
|
||||
return {
|
||||
"class_name": category,
|
||||
"output": f"CASE{category_names.index(category) + 1}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
# 异常时返回默认分支,保证工作流容错性
|
||||
if category_count > 0:
|
||||
return DEFAULT_EMPTY_QUESTION_CASE
|
||||
return "unknown"
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
||||
"""工具节点配置"""
|
||||
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||
if workflow_id:
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
if workspace_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
# return {"error": "缺少租户ID"}
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
print(self.typed_config.tool_id)
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
print(result)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"data": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
Reference in New Issue
Block a user