fix(workflow node): Workflow nodes and question classifier nodes - bug fixes
This commit is contained in:
@@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
category_map[category_name] = case_tag
|
category_map[category_name] = case_tag
|
||||||
return category_map
|
return category_map
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> str:
|
async def execute(self, state: WorkflowState) -> dict:
|
||||||
"""执行问题分类"""
|
"""执行问题分类"""
|
||||||
question = self.typed_config.input_variable
|
question = self.typed_config.input_variable
|
||||||
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
supplement_prompt = self.typed_config.user_supplement_prompt or ""
|
||||||
@@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})"
|
||||||
)
|
)
|
||||||
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
# 若分类列表为空,返回默认unknown分支,否则返回CASE1
|
||||||
return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown"
|
if category_count > 0:
|
||||||
|
return {
|
||||||
|
"class_name": category_names[0],
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"class_name": "unknown",
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm = self._get_llm_instance()
|
llm = self._get_llm_instance()
|
||||||
@@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
log_supplement = supplement_prompt if supplement_prompt else "无"
|
log_supplement = supplement_prompt if supplement_prompt else "无"
|
||||||
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}")
|
||||||
|
|
||||||
return f"CASE{category_names.index(category) + 1}"
|
return {
|
||||||
|
"class_name": category,
|
||||||
|
"output": f"CASE{category_names.index(category) + 1}",
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
f"节点 {self.node_id} 分类执行异常:{str(e)}",
|
||||||
@@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode):
|
|||||||
)
|
)
|
||||||
# 异常时返回默认分支,保证工作流容错性
|
# 异常时返回默认分支,保证工作流容错性
|
||||||
if category_count > 0:
|
if category_count > 0:
|
||||||
return DEFAULT_EMPTY_QUESTION_CASE
|
return {
|
||||||
return "unknown"
|
"class_name": category_names[0],
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"class_name": "unknown",
|
||||||
|
"output": DEFAULT_EMPTY_QUESTION_CASE
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig):
|
|||||||
"""工具节点配置"""
|
"""工具节点配置"""
|
||||||
|
|
||||||
tool_id: str = Field(..., description="工具ID")
|
tool_id: str = Field(..., description="工具ID")
|
||||||
tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量")
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||||
@@ -9,6 +9,8 @@ from app.db import get_db_read
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(BaseNode):
|
class ToolNode(BaseNode):
|
||||||
"""工具节点"""
|
"""工具节点"""
|
||||||
@@ -25,25 +27,33 @@ class ToolNode(BaseNode):
|
|||||||
|
|
||||||
# 如果没有租户ID,尝试从工作流ID获取
|
# 如果没有租户ID,尝试从工作流ID获取
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
workflow_id = self.get_variable("sys.workflow_id", state)
|
workspace_id = self.get_variable("sys.workspace_id", state)
|
||||||
if workflow_id:
|
if workspace_id:
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id)
|
tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id)
|
||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097")
|
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
# logger.error(f"节点 {self.node_id} 缺少租户ID")
|
return {
|
||||||
# return {"error": "缺少租户ID"}
|
"success": False,
|
||||||
|
"data": "缺少租户ID"
|
||||||
|
}
|
||||||
|
|
||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
for param_name, param_template in self.typed_config.tool_parameters.items():
|
for param_name, param_template in self.typed_config.tool_parameters.items():
|
||||||
rendered_value = self._render_template(param_template, state)
|
if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template):
|
||||||
|
try:
|
||||||
|
rendered_value = self._render_template(param_template, state)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e
|
||||||
|
else:
|
||||||
|
# 非模板参数(数字/布尔/普通字符串)直接保留原值
|
||||||
|
rendered_value = param_template
|
||||||
rendered_parameters[param_name] = rendered_value
|
rendered_parameters[param_name] = rendered_value
|
||||||
|
|
||||||
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}")
|
||||||
print(self.typed_config.tool_id)
|
|
||||||
|
|
||||||
# 执行工具
|
# 执行工具
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
@@ -54,7 +64,7 @@ class ToolNode(BaseNode):
|
|||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=user_id
|
user_id=user_id
|
||||||
)
|
)
|
||||||
print(result)
|
|
||||||
if result.success:
|
if result.success:
|
||||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||||
return {
|
return {
|
||||||
@@ -66,7 +76,7 @@ class ToolNode(BaseNode):
|
|||||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": result.error,
|
"data": result.error,
|
||||||
"error_code": result.error_code,
|
"error_code": result.error_code,
|
||||||
"execution_time": result.execution_time
|
"execution_time": result.execution_time
|
||||||
}
|
}
|
||||||
@@ -38,6 +38,33 @@ class ToolRepository:
|
|||||||
|
|
||||||
return result[0] if result else None
|
return result[0] if result else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]:
|
||||||
|
"""
|
||||||
|
根据空间ID获取tenant_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
workspace_id: 空间ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tenant_id或None
|
||||||
|
"""
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
|
||||||
|
tenant_id = db.query(Workspace.tenant_id).filter(
|
||||||
|
Workspace.id == workspace_id
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
if tenant_id is not None and not isinstance(tenant_id, uuid.UUID):
|
||||||
|
# 兼容数据库中字段类型不匹配的情况(比如存储为字符串)
|
||||||
|
try:
|
||||||
|
tenant_id = uuid.UUID(tenant_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return tenant_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_by_tenant(
|
def find_by_tenant(
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -344,14 +344,16 @@ class ToolService:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if operation_param:
|
if operation_param:
|
||||||
# 有多个操作
|
# 有多个操作,为每个操作生成具体参数
|
||||||
methods = []
|
methods = []
|
||||||
for operation in operation_param.enum:
|
for operation in operation_param.enum:
|
||||||
|
# 获取该操作的具体参数
|
||||||
|
operation_params = self._get_operation_specific_params(tool_instance, operation)
|
||||||
methods.append({
|
methods.append({
|
||||||
"method_id": f"{config.name}_{operation}",
|
"method_id": f"{config.name}_{operation}",
|
||||||
"name": operation,
|
"name": operation,
|
||||||
"description": f"{config.description} - {operation}",
|
"description": f"{config.description} - {operation}",
|
||||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
"parameters": operation_params
|
||||||
})
|
})
|
||||||
return methods
|
return methods
|
||||||
else:
|
else:
|
||||||
@@ -362,6 +364,243 @@ class ToolService:
|
|||||||
"description": config.description,
|
"description": config.description,
|
||||||
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
"parameters": [p for p in tool_instance.parameters if p.name != "operation"]
|
||||||
}]
|
}]
|
||||||
|
|
||||||
|
def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取特定操作的参数列表"""
|
||||||
|
# 对于datetime_tool,根据操作类型返回相关参数
|
||||||
|
if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool':
|
||||||
|
return self._get_datetime_tool_params(operation)
|
||||||
|
# 对于json_tool,根据操作类型返回相关参数
|
||||||
|
elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool':
|
||||||
|
return self._get_json_tool_params(operation)
|
||||||
|
|
||||||
|
# 其他工具的默认处理:返回除operation外的所有参数
|
||||||
|
return [{
|
||||||
|
"name": param.name,
|
||||||
|
"type": param.type.value,
|
||||||
|
"description": param.description,
|
||||||
|
"required": param.required,
|
||||||
|
"default": param.default,
|
||||||
|
"enum": param.enum,
|
||||||
|
"minimum": param.minimum,
|
||||||
|
"maximum": param.maximum,
|
||||||
|
"pattern": param.pattern
|
||||||
|
} for param in tool_instance.parameters if param.name != "operation"]
|
||||||
|
|
||||||
|
def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取datetime_tool特定操作的参数"""
|
||||||
|
if operation == "now":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "format":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "convert_timezone":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "from_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "timestamp_to_datetime":
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# 默认返回所有参数(除了operation)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "input_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入值(时间字符串或时间戳)",
|
||||||
|
"required": False
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "input_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "output_format",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)",
|
||||||
|
"required": False,
|
||||||
|
"default": "%Y-%m-%d %H:%M:%S"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "from_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "源时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "to_timezone",
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标时区(如:UTC, Asia/Shanghai)",
|
||||||
|
"required": False,
|
||||||
|
"default": "Asia/Shanghai"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "calculation",
|
||||||
|
"type": "string",
|
||||||
|
"description": "时间计算表达式(如:+1d, -2h, +30m)",
|
||||||
|
"required": False
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取json_tool特定操作的参数"""
|
||||||
|
base_params = [
|
||||||
|
{
|
||||||
|
"name": "input_data",
|
||||||
|
"type": "string",
|
||||||
|
"description": "输入数据(JSON字符串、YAML字符串或XML字符串)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
if operation == "insert":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "new_value",
|
||||||
|
"type": "string",
|
||||||
|
"description": "新值(用于insert操作)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "replace":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "old_text",
|
||||||
|
"type": "string",
|
||||||
|
"description": "要替换的原文本(用于replace操作)",
|
||||||
|
"required": True
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "new_text",
|
||||||
|
"type": "string",
|
||||||
|
"description": "替换后的新文本(用于replace操作)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "delete":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
elif operation == "parse":
|
||||||
|
return base_params + [
|
||||||
|
{
|
||||||
|
"name": "json_path",
|
||||||
|
"type": "string",
|
||||||
|
"description": "JSON路径表达式(如:$.user.name或users[0].name)",
|
||||||
|
"required": True
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return base_params
|
||||||
|
|
||||||
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]:
|
||||||
"""获取自定义工具的方法"""
|
"""获取自定义工具的方法"""
|
||||||
|
|||||||
Reference in New Issue
Block a user