fix(workflow node): Workflow nodes and question classifier nodes - bug fixes
This commit is contained in:
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
||||
category_map[category_name] = case_tag
|
||||
return category_map
|
||||
|
||||
async def execute(self, state: WorkflowState) -> str:
|
||||
async def execute(self, state: WorkflowState) -> dict:
|
||||
"""执行问题分类"""
|
||||
question = self.typed_config.input_variable
|
||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||
)
|
||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
||||
if category_count > 0:
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
try:
|
||||
llm = self._get_llm_instance()
|
||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||
|
||||
return f"CASE{category_names.index(category) + 1}"
|
||||
return {
|
||||
"class_name": category,
|
||||
"output": f"CASE{category_names.index(category) + 1}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
||||
)
|
||||
# 异常时返回默认分支,保证工作流容错性
|
||||
if category_count > 0:
|
||||
return DEFAULT_EMPTY_QUESTION_CASE
|
||||
return "unknown"
|
||||
return {
|
||||
"class_name": category_names[0],
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
return {
|
||||
"class_name": "unknown",
|
||||
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pydantic import Field
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
||||
"""工具节点配置"""
|
||||
|
||||
tool_id: str = Field(..., description="工具ID")
|
||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
import uuid
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
"""工具节点"""
|
||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
||||
|
||||
# 如果没有租户ID,尝试从工作流ID获取
|
||||
if not tenant_id:
|
||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
||||
if workflow_id:
|
||||
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||
if workspace_id:
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
with get_db_read() as db:
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
||||
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||
|
||||
if not tenant_id:
|
||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
# return {"error": "缺少租户ID"}
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||
try:
|
||||
rendered_value = self._render_template(param_template, state)
|
||||
except Exception as e:
|
||||
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||
else:
|
||||
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||
rendered_value = param_template
|
||||
rendered_parameters[param_name] = rendered_value
|
||||
|
||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||
print(self.typed_config.tool_id)
|
||||
|
||||
# 执行工具
|
||||
with get_db_read() as db:
|
||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id
|
||||
)
|
||||
print(result)
|
||||
|
||||
if result.success:
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": result.error,
|
||||
"data": result.error,
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
@@ -38,6 +38,33 @@ class ToolRepository:
|
||||
|
||||
return result[0] if result else None
|
||||
|
||||
@staticmethod
|
||||
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
|
||||
"""
|
||||
根据空间ID获取tenant_id
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 空间ID
|
||||
|
||||
Returns:
|
||||
tenant_id或None
|
||||
"""
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
tenant_id = db.query(Workspace.tenant_id).filter(
|
||||
Workspace.id == workspace_id
|
||||
).scalar()
|
||||
|
||||
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
|
||||
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
|
||||
try:
|
||||
tenant_id = uuid.UUID(tenant_id)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
return tenant_id
|
||||
|
||||
@staticmethod
|
||||
def find_by_tenant(
|
||||
db: Session,
|
||||
|
||||
@@ -344,14 +344,16 @@ class ToolService:
|
||||
break
|
||||
|
||||
if operation_param:
|
||||
# 有多个操作
|
||||
# 有多个操作,为每个操作生成具体参数
|
||||
methods = []
|
||||
for operation in operation_param.enum:
|
||||
# 获取该操作的具体参数
|
||||
operation_params = self._get_operation_specific_params(tool_instance, operation)
|
||||
methods.append({
|
||||
"method_id": f"{config.name}_{operation}",
|
||||
"name": operation,
|
||||
"description": f"{config.description} - {operation}",
|
||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||
"parameters": operation_params
|
||||
})
|
||||
return methods
|
||||
else:
|
||||
@@ -362,6 +364,243 @@ class ToolService:
|
||||
"description": config.description,
|
||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||
}]
|
||||
|
||||
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取特定操作的参数列表"""
|
||||
# 对于datetime_tool,根据操作类型返回相关参数
|
||||
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
|
||||
return self._get_datetime_tool_params(operation)
|
||||
# 对于json_tool,根据操作类型返回相关参数
|
||||
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
|
||||
return self._get_json_tool_params(operation)
|
||||
|
||||
# 其他工具的默认处理:返回除operation外的所有参数
|
||||
return [{
|
||||
"name": param.name,
|
||||
"type": param.type.value,
|
||||
"description": param.description,
|
||||
"required": param.required,
|
||||
"default": param.default,
|
||||
"enum": param.enum,
|
||||
"minimum": param.minimum,
|
||||
"maximum": param.maximum,
|
||||
"pattern": param.pattern
|
||||
} for param in tool_instance.parameters if param.name != "operation"]
|
||||
|
||||
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取datetime_tool特定操作的参数"""
|
||||
if operation == "now":
|
||||
return [
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
}
|
||||
]
|
||||
elif operation == "format":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
}
|
||||
]
|
||||
elif operation == "convert_timezone":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "from_timezone",
|
||||
"type": "string",
|
||||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
}
|
||||
]
|
||||
elif operation == "timestamp_to_datetime":
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
}
|
||||
]
|
||||
else:
|
||||
# 默认返回所有参数(除了operation)
|
||||
return [
|
||||
{
|
||||
"name": "input_value",
|
||||
"type": "string",
|
||||
"description": "输入值(时间字符串或时间戳)",
|
||||
"required": False
|
||||
},
|
||||
{
|
||||
"name": "input_format",
|
||||
"type": "string",
|
||||
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "output_format",
|
||||
"type": "string",
|
||||
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||
"required": False,
|
||||
"default": "%Y-%m-%d %H:%M:%S"
|
||||
},
|
||||
{
|
||||
"name": "from_timezone",
|
||||
"type": "string",
|
||||
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "to_timezone",
|
||||
"type": "string",
|
||||
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||
"required": False,
|
||||
"default": "Asia/Shanghai"
|
||||
},
|
||||
{
|
||||
"name": "calculation",
|
||||
"type": "string",
|
||||
"description": "时间计算表达式(如:+1d, -2h, +30m)",
|
||||
"required": False
|
||||
}
|
||||
]
|
||||
|
||||
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||
"""获取json_tool特定操作的参数"""
|
||||
base_params = [
|
||||
{
|
||||
"name": "input_data",
|
||||
"type": "string",
|
||||
"description": "输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
|
||||
if operation == "insert":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "new_value",
|
||||
"type": "string",
|
||||
"description": "新值(用于insert操作)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "replace":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "old_text",
|
||||
"type": "string",
|
||||
"description": "要替换的原文本(用于replace操作)",
|
||||
"required": True
|
||||
},
|
||||
{
|
||||
"name": "new_text",
|
||||
"type": "string",
|
||||
"description": "替换后的新文本(用于replace操作)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "delete":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
elif operation == "parse":
|
||||
return base_params + [
|
||||
{
|
||||
"name": "json_path",
|
||||
"type": "string",
|
||||
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||
"required": True
|
||||
}
|
||||
]
|
||||
|
||||
return base_params
|
||||
|
||||
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||
"""获取自定义工具的方法"""
|
||||
|
||||
Reference in New Issue
Block a user