diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index 67f53801..b0f2c28d 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -65,7 +65,7 @@ class QuestionClassifierNode(BaseNode): category_map[category_name] = case_tag return category_map - async def execute(self, state: WorkflowState) -> str: + async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" @@ -79,7 +79,15 @@ class QuestionClassifierNode(BaseNode): f"(默认分支:{DEFAULT_EMPTY_QUESTION_CASE},分类总数:{category_count})" ) # 若分类列表为空,返回默认unknown分支,否则返回CASE1 - return DEFAULT_EMPTY_QUESTION_CASE if category_count > 0 else "unknown" + if category_count > 0: + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } try: llm = self._get_llm_instance() @@ -111,7 +119,10 @@ class QuestionClassifierNode(BaseNode): log_supplement = supplement_prompt if supplement_prompt else "无" logger.info(f"节点 {self.node_id} 分类结果: {category}, 用户补充提示词:{log_supplement}") - return f"CASE{category_names.index(category) + 1}" + return { + "class_name": category, + "output": f"CASE{category_names.index(category) + 1}", + } except Exception as e: logger.error( f"节点 {self.node_id} 分类执行异常:{str(e)}", @@ -119,5 +130,11 @@ class QuestionClassifierNode(BaseNode): ) # 异常时返回默认分支,保证工作流容错性 if category_count > 0: - return DEFAULT_EMPTY_QUESTION_CASE - return "unknown" + return { + "class_name": category_names[0], + "output": DEFAULT_EMPTY_QUESTION_CASE + } + return { + "class_name": "unknown", + "output": DEFAULT_EMPTY_QUESTION_CASE + } diff --git a/api/app/core/workflow/nodes/tool/config.py b/api/app/core/workflow/nodes/tool/config.py index 487efae2..d3b1a644 100644 --- a/api/app/core/workflow/nodes/tool/config.py +++ b/api/app/core/workflow/nodes/tool/config.py @@ -1,4 +1,6 @@ from pydantic import Field +from typing import Any + from app.core.workflow.nodes.base_config import BaseNodeConfig @@ -6,4 +8,4 @@ class ToolNodeConfig(BaseNodeConfig): """工具节点配置""" tool_id: str = Field(..., description="工具ID") - tool_parameters: dict[str, str] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") + tool_parameters: dict[str, Any] = Field(default_factory=dict, description="工具参数映射,支持工作流变量") diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index 993a3804..e1b5f380 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -1,5 +1,5 @@ import logging -import uuid +import re from typing import Any from app.core.workflow.nodes.base_node import BaseNode, WorkflowState @@ -9,6 +9,8 @@ from app.db import get_db_read logger = logging.getLogger(__name__) +TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}") + class ToolNode(BaseNode): """工具节点""" @@ -25,25 +27,33 @@ class ToolNode(BaseNode): # 如果没有租户ID,尝试从工作流ID获取 if not tenant_id: - workflow_id = self.get_variable("sys.workflow_id", state) - if workflow_id: + workspace_id = self.get_variable("sys.workspace_id", state) + if workspace_id: from app.repositories.tool_repository import ToolRepository with get_db_read() as db: - tenant_id = ToolRepository.get_tenant_id_by_workflow_id(db, workflow_id) + tenant_id = ToolRepository.get_tenant_id_by_workspace_id(db, workspace_id) if not tenant_id: - tenant_id = uuid.UUID("6c2c91b0-3f49-4489-9157-2208aa56a097") - # logger.error(f"节点 {self.node_id} 缺少租户ID") - # return {"error": "缺少租户ID"} + logger.error(f"节点 {self.node_id} 缺少租户ID") + return { + "success": False, + "data": "缺少租户ID" + } # 渲染工具参数 rendered_parameters = {} for param_name, param_template in self.typed_config.tool_parameters.items(): - rendered_value = self._render_template(param_template, state) + if isinstance(param_template, str) and TEMPLATE_PATTERN.search(param_template): + try: + rendered_value = self._render_template(param_template, state) + except Exception as e: + raise ValueError(f"模板渲染失败:参数 {param_name} 的模板 {param_template} 解析错误") from e + else: + # 非模板参数(数字/布尔/普通字符串)直接保留原值 + rendered_value = param_template rendered_parameters[param_name] = rendered_value logger.info(f"节点 {self.node_id} 执行工具 {self.typed_config.tool_id},参数: {rendered_parameters}") - print(self.typed_config.tool_id) # 执行工具 with get_db_read() as db: @@ -54,7 +64,7 @@ class ToolNode(BaseNode): tenant_id=tenant_id, user_id=user_id ) - print(result) + if result.success: logger.info(f"节点 {self.node_id} 工具执行成功") return { @@ -66,7 +76,7 @@ class ToolNode(BaseNode): logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}") return { "success": False, - "error": result.error, + "data": result.error, "error_code": result.error_code, "execution_time": result.execution_time } \ No newline at end of file diff --git a/api/app/repositories/tool_repository.py b/api/app/repositories/tool_repository.py index 3aa7b16e..257910c3 100644 --- a/api/app/repositories/tool_repository.py +++ b/api/app/repositories/tool_repository.py @@ -38,6 +38,33 @@ class ToolRepository: return result[0] if result else None + @staticmethod + def get_tenant_id_by_workspace_id(db: Session, workspace_id: str) -> Optional[uuid.UUID]: + """ + 根据空间ID获取tenant_id + + Args: + db: 数据库会话 + workspace_id: 空间ID + + Returns: + tenant_id或None + """ + from app.models.workspace_model import Workspace + + tenant_id = db.query(Workspace.tenant_id).filter( + Workspace.id == workspace_id + ).scalar() + + if tenant_id is not None and not isinstance(tenant_id, uuid.UUID): + # 兼容数据库中字段类型不匹配的情况(比如存储为字符串) + try: + tenant_id = uuid.UUID(tenant_id) + except (ValueError, TypeError): + return None + + return tenant_id + @staticmethod def find_by_tenant( db: Session, diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 50cca957..ab5128fd 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -344,14 +344,16 @@ class ToolService: break if operation_param: - # 有多个操作 + # 有多个操作,为每个操作生成具体参数 methods = [] for operation in operation_param.enum: + # 获取该操作的具体参数 + operation_params = self._get_operation_specific_params(tool_instance, operation) methods.append({ "method_id": f"{config.name}_{operation}", "name": operation, "description": f"{config.description} - {operation}", - "parameters": [p for p in tool_instance.parameters if p.name != "operation"] + "parameters": operation_params }) return methods else: @@ -362,6 +364,243 @@ class ToolService: "description": config.description, "parameters": [p for p in tool_instance.parameters if p.name != "operation"] }] + + def _get_operation_specific_params(self, tool_instance: BaseTool, operation: str) -> List[Dict[str, Any]]: + """获取特定操作的参数列表""" + # 对于datetime_tool,根据操作类型返回相关参数 + if hasattr(tool_instance, 'name') and tool_instance.name == 'datetime_tool': + return self._get_datetime_tool_params(operation) + # 对于json_tool,根据操作类型返回相关参数 + elif hasattr(tool_instance, 'name') and tool_instance.name == 'json_tool': + return self._get_json_tool_params(operation) + + # 其他工具的默认处理:返回除operation外的所有参数 + return [{ + "name": param.name, + "type": param.type.value, + "description": param.description, + "required": param.required, + "default": param.default, + "enum": param.enum, + "minimum": param.minimum, + "maximum": param.maximum, + "pattern": param.pattern + } for param in tool_instance.parameters if param.name != "operation"] + + def _get_datetime_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取datetime_tool特定操作的参数""" + if operation == "now": + return [ + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "format": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + } + ] + elif operation == "convert_timezone": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + elif operation == "timestamp_to_datetime": + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": True + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + } + ] + else: + # 默认返回所有参数(除了operation) + return [ + { + "name": "input_value", + "type": "string", + "description": "输入值(时间字符串或时间戳)", + "required": False + }, + { + "name": "input_format", + "type": "string", + "description": "输入时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "output_format", + "type": "string", + "description": "输出时间格式(如:%Y-%m-%d %H:%M:%S)", + "required": False, + "default": "%Y-%m-%d %H:%M:%S" + }, + { + "name": "from_timezone", + "type": "string", + "description": "源时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "to_timezone", + "type": "string", + "description": "目标时区(如:UTC, Asia/Shanghai)", + "required": False, + "default": "Asia/Shanghai" + }, + { + "name": "calculation", + "type": "string", + "description": "时间计算表达式(如:+1d, -2h, +30m)", + "required": False + } + ] + + def _get_json_tool_params(self, operation: str) -> List[Dict[str, Any]]: + """获取json_tool特定操作的参数""" + base_params = [ + { + "name": "input_data", + "type": "string", + "description": "输入数据(JSON字符串、YAML字符串或XML字符串)", + "required": True + } + ] + + if operation == "insert": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "new_value", + "type": "string", + "description": "新值(用于insert操作)", + "required": True + } + ] + elif operation == "replace": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + }, + { + "name": "old_text", + "type": "string", + "description": "要替换的原文本(用于replace操作)", + "required": True + }, + { + "name": "new_text", + "type": "string", + "description": "替换后的新文本(用于replace操作)", + "required": True + } + ] + elif operation == "delete": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + elif operation == "parse": + return base_params + [ + { + "name": "json_path", + "type": "string", + "description": "JSON路径表达式(如:$.user.name或users[0].name)", + "required": True + } + ] + + return base_params async def _get_custom_tool_methods(self, config: ToolConfig) -> List[Dict[str, Any]]: """获取自定义工具的方法"""