From 84e24ede046e13af4cc127b1cf0584f1d27f3582 Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 14 Jan 2026 10:47:38 +0800 Subject: [PATCH] fix(workflow): move node config validation to runtime for proper error handling --- api/app/core/workflow/nodes/assigner/node.py | 3 ++- api/app/core/workflow/nodes/cycle_graph/node.py | 1 - api/app/core/workflow/nodes/http_request/node.py | 3 ++- api/app/core/workflow/nodes/if_else/node.py | 3 ++- api/app/core/workflow/nodes/jinja_render/node.py | 3 ++- api/app/core/workflow/nodes/knowledge/node.py | 3 ++- api/app/core/workflow/nodes/llm/node.py | 3 ++- api/app/core/workflow/nodes/memory/node.py | 3 ++- api/app/core/workflow/nodes/parameter_extractor/node.py | 3 ++- api/app/core/workflow/nodes/question_classifier/node.py | 6 ++++-- api/app/core/workflow/nodes/start/node.py | 3 ++- api/app/core/workflow/nodes/tool/node.py | 3 ++- api/app/core/workflow/nodes/variable_aggregator/node.py | 3 ++- 13 files changed, 26 insertions(+), 14 deletions(-) diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index 7b9d645b..96f68ce8 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -14,7 +14,7 @@ logger = logging.getLogger(__name__) class AssignerNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = AssignerNodeConfig(**self.config) + self.typed_config: AssignerNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: """ @@ -28,6 +28,7 @@ class AssignerNode(BaseNode): None or the result of the assignment operation. """ # Initialize a variable pool for accessing conversation, node, and system variables + self.typed_config = AssignerNodeConfig(**self.config) logger.info(f"节点 {self.node_id} 开始执行") pool = VariablePool(state) for assignment in self.typed_config.assignments: diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index fb062f39..1659395e 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -30,7 +30,6 @@ class CycleGraphNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config: LoopNodeConfig | IterationNodeConfig | None = None self.cycle_nodes = list() # Nodes belonging to this cycle self.cycle_edges = list() # Edges connecting nodes within the cycle diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index 2e5de796..141cba79 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -32,7 +32,7 @@ class HttpRequestNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = HttpRequestNodeConfig(**self.config) + self.typed_config: HttpRequestNodeConfig | None = None def _build_timeout(self) -> Timeout: """ @@ -181,6 +181,7 @@ class HttpRequestNode(BaseNode): - dict: Serialized HttpRequestNodeOutput on success - str: Branch identifier (e.g. "ERROR") when branching is enabled """ + self.typed_config = HttpRequestNodeConfig(**self.config) async with httpx.AsyncClient( verify=self.typed_config.verify_ssl, timeout=self._build_timeout(), diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 8c6d222f..41f1138b 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class IfElseNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = IfElseNodeConfig(**self.config) + self.typed_config: IfElseNodeConfig | None= None @staticmethod def _evaluate(operator, instance: CompareOperatorInstance) -> Any: @@ -109,6 +109,7 @@ class IfElseNode(BaseNode): Returns: str: The matched branch identifier, e.g., 'CASE1', 'CASE2', ..., used for node transitions. """ + self.typed_config = IfElseNodeConfig(**self.config) expressions = self.evaluate_conditional_edge_expressions(state) # TODO: 变量类型及文本类型解析 for i in range(len(expressions)): diff --git a/api/app/core/workflow/nodes/jinja_render/node.py b/api/app/core/workflow/nodes/jinja_render/node.py index 70993573..822f1918 100644 --- a/api/app/core/workflow/nodes/jinja_render/node.py +++ b/api/app/core/workflow/nodes/jinja_render/node.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) class JinjaRenderNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = JinjaRenderNodeConfig(**self.config) + self.typed_config: JinjaRenderNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: """ @@ -34,6 +34,7 @@ class JinjaRenderNode(BaseNode): RuntimeError: If Jinja2 template rendering fails due to invalid template syntax or missing variables. """ + self.typed_config = JinjaRenderNodeConfig(**self.config) render = TemplateRenderer(strict=False) context = {} diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 061328e1..221ca079 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -21,7 +21,7 @@ logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + self.typed_config: KnowledgeRetrievalNodeConfig | None = None @staticmethod def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType): @@ -171,6 +171,7 @@ class KnowledgeRetrievalNode(BaseNode): Raises: RuntimeError: If no valid knowledge base is found or access is denied. """ + self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) query = self._render_template(self.typed_config.query, state) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index 5fb86ae2..6395d3b8 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -68,7 +68,7 @@ class LLMNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = LLMNodeConfig(**self.config) + self.typed_config: LLMNodeConfig | None = None def _render_context(self, message, state): context = f"{self._render_template(self.typed_config.context, state)}" @@ -164,6 +164,7 @@ class LLMNode(BaseNode): Returns: LLM 响应消息 """ + self.typed_config = LLMNodeConfig(**self.config) llm, prompt_or_messages = self._prepare_llm(state, True) logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)") diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 0d1b1fb4..f1c99ddb 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -10,9 +10,10 @@ from app.services.memory_agent_service import MemoryAgentService class MemoryReadNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = MemoryReadNodeConfig(**self.config) + self.typed_config: MemoryReadNodeConfig | None = None async def execute(self, state: WorkflowState) -> Any: + self.typed_config = MemoryReadNodeConfig(**self.config) with get_db_read() as db: workspace_id = self.get_variable('sys.workspace_id', state) end_user_id = self.get_variable("sys.user_id", state) diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 84d61aa9..ec58d96c 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) class ParameterExtractorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = ParameterExtractorNodeConfig(**self.config) + self.typed_config: ParameterExtractorNodeConfig | None = None @staticmethod def _get_prompt(): @@ -145,6 +145,7 @@ class ParameterExtractorNode(BaseNode): Raises: BusinessException: If LLM output cannot be parsed as valid JSON. """ + self.typed_config = ParameterExtractorNodeConfig(**self.config) llm = self._get_llm_instance() system_prompt, user_prompt = self._get_prompt() diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index b0f2c28d..aee72eda 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -21,8 +21,8 @@ class QuestionClassifierNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = QuestionClassifierNodeConfig(**self.config) - self.category_to_case_map = self._build_category_case_map() + self.typed_config: QuestionClassifierNodeConfig | None = None + self.category_to_case_map = {} def _get_llm_instance(self) -> RedBearLLM: """获取LLM实例""" @@ -67,6 +67,8 @@ class QuestionClassifierNode(BaseNode): async def execute(self, state: WorkflowState) -> dict: """执行问题分类""" + self.typed_config = QuestionClassifierNodeConfig(**self.config) + self.category_to_case_map = self._build_category_case_map() question = self.typed_config.input_variable supplement_prompt = self.typed_config.user_supplement_prompt or "" categories = self.typed_config.categories or [] diff --git a/api/app/core/workflow/nodes/start/node.py b/api/app/core/workflow/nodes/start/node.py index f9927f0c..69560422 100644 --- a/api/app/core/workflow/nodes/start/node.py +++ b/api/app/core/workflow/nodes/start/node.py @@ -35,7 +35,7 @@ class StartNode(BaseNode): super().__init__(node_config, workflow_config) # 解析并验证配置 - self.typed_config = StartNodeConfig(**self.config) + self.typed_config: StartNodeConfig | None = None async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行 start 节点业务逻辑 @@ -48,6 +48,7 @@ class StartNode(BaseNode): Returns: 包含系统参数、会话变量和自定义变量的字典 """ + self.typed_config = StartNodeConfig(**self.config) logger.info(f"节点 {self.node_id} (Start) 开始执行") # 创建变量池实例(在方法内复用) diff --git a/api/app/core/workflow/nodes/tool/node.py b/api/app/core/workflow/nodes/tool/node.py index e1b5f380..a83aea9f 100644 --- a/api/app/core/workflow/nodes/tool/node.py +++ b/api/app/core/workflow/nodes/tool/node.py @@ -17,10 +17,11 @@ class ToolNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = ToolNodeConfig(**self.config) + self.typed_config: ToolNodeConfig | None = None async def execute(self, state: WorkflowState) -> dict[str, Any]: """执行工具""" + self.typed_config = ToolNodeConfig(**self.config) # 获取租户ID和用户ID tenant_id = self.get_variable("sys.tenant_id", state) user_id = self.get_variable("sys.user_id", state) diff --git a/api/app/core/workflow/nodes/variable_aggregator/node.py b/api/app/core/workflow/nodes/variable_aggregator/node.py index e6cbf75b..5bff8e33 100644 --- a/api/app/core/workflow/nodes/variable_aggregator/node.py +++ b/api/app/core/workflow/nodes/variable_aggregator/node.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) class VariableAggregatorNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) - self.typed_config = VariableAggregatorNodeConfig(**self.config) + self.typed_config: VariableAggregatorNodeConfig | None = None @staticmethod def _get_express(variable_string: str) -> Any: @@ -37,6 +37,7 @@ class VariableAggregatorNode(BaseNode): - str: In non-group mode, returns the first non-None variable value. - dict: In group mode, returns a mapping of group_name -> first non-None variable value. """ + self.typed_config = VariableAggregatorNodeConfig(**self.config) if not self.typed_config.group: # -------------------------- # Non-group mode