From cf26c9f39c721ae80beb56f6b8516aa5c767cb24 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 20 Mar 2026 19:09:41 +0800 Subject: [PATCH] fix(workflow): allow right-hand operand to be optional when not required by comparison --- api/app/core/workflow/engine/graph_builder.py | 2 +- api/app/core/workflow/engine/variable_pool.py | 6 +++--- api/app/core/workflow/nodes/cycle_graph/config.py | 2 +- api/app/core/workflow/nodes/cycle_graph/loop.py | 4 ++-- api/app/core/workflow/nodes/if_else/config.py | 2 +- api/app/core/workflow/nodes/if_else/node.py | 2 +- api/app/core/workflow/nodes/operators.py | 2 ++ api/app/db.py | 1 + api/tests/workflow/executor/test_vairable_pool.py | 2 +- 9 files changed, 13 insertions(+), 10 deletions(-) diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 813a543f..674c45d0 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -456,7 +456,7 @@ class GraphBuilder: branch_activate = [] new_state = state.copy() new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate - node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False) + node_output = variable_pool.get_node_output(src, default=dict(), strict=False) for label, branch in unique_branch.items(): if node_output and evaluate_condition( branch["condition"], diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index bc88df19..cf6f4a7b 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -351,12 +351,12 @@ class VariablePool: } return runtime_vars - def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: + def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None: """获取指定节点的输出(运行时变量) Args: node_id: 节点 ID - defalut: 默认值 + default: 默认值 strict: 是否严格模式 Returns: @@ -368,7 +368,7 @@ class VariablePool: if strict: raise KeyError(f"node {node_id} output not exist") else: - return defalut + return default def copy(self, pool: 'VariablePool'): self.variables = deepcopy(pool.variables) diff --git a/api/app/core/workflow/nodes/cycle_graph/config.py b/api/app/core/workflow/nodes/cycle_graph/config.py index 52aca1d9..75358c47 100644 --- a/api/app/core/workflow/nodes/cycle_graph/config.py +++ b/api/app/core/workflow/nodes/cycle_graph/config.py @@ -51,7 +51,7 @@ class ConditionDetail(BaseModel): ) right: Any = Field( - ..., + default=None, description="Right-hand operand of the comparison expression" ) diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index d3ada1ec..84901bad 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -158,7 +158,7 @@ class LoopRuntime: self.variable_pool.variables["conv"].update( self.child_variable_pool.variables["conv"] ) - loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) + loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars def evaluate_conditional(self) -> bool: @@ -261,4 +261,4 @@ class LoopRuntime: idx += 1 logger.info(f"loop node {self.node_id}: execution completed") - return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} + return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state} diff --git a/api/app/core/workflow/nodes/if_else/config.py b/api/app/core/workflow/nodes/if_else/config.py index 894898f0..638e4b2d 100644 --- a/api/app/core/workflow/nodes/if_else/config.py +++ b/api/app/core/workflow/nodes/if_else/config.py @@ -18,7 +18,7 @@ class ConditionDetail(BaseModel): ) right: Any = Field( - ..., + default=None, description="Value to compare with" ) diff --git a/api/app/core/workflow/nodes/if_else/node.py b/api/app/core/workflow/nodes/if_else/node.py index 16782488..5d2bdf9a 100644 --- a/api/app/core/workflow/nodes/if_else/node.py +++ b/api/app/core/workflow/nodes/if_else/node.py @@ -31,7 +31,7 @@ class IfElseNode(BaseNode): expressions.append({ "left": self.get_variable(expression.left, variable_pool, strict=False), "right": expression.right - if expression.input_type == ValueInputType.CONSTANT + if expression.input_type == ValueInputType.CONSTANT or expression.right is None else self.get_variable(expression.right, variable_pool, strict=False), "operator": str(expression.operator), }) diff --git a/api/app/core/workflow/nodes/operators.py b/api/app/core/workflow/nodes/operators.py index be33d35a..14fc9d9f 100644 --- a/api/app/core/workflow/nodes/operators.py +++ b/api/app/core/workflow/nodes/operators.py @@ -250,6 +250,8 @@ class ConditionBase(ABC): self.type_limit = getattr(self, "type_limit", None) def resolve_right_literal_value(self): + if self.right_selector is None: + return None if self.input_type == ValueInputType.VARIABLE: pattern = r"\{\{\s*(.*?)\s*\}\}" right_expression = re.sub(pattern, r"\1", self.right_selector).strip() diff --git a/api/app/db.py b/api/app/db.py index 80ab2756..32261c46 100644 --- a/api/app/db.py +++ b/api/app/db.py @@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]: yield db finally: db.rollback() # 只读任务无需 commit + db.close() def get_pool_status(): diff --git a/api/tests/workflow/executor/test_vairable_pool.py b/api/tests/workflow/executor/test_vairable_pool.py index 3404eb79..0ba4d259 100644 --- a/api/tests/workflow/executor/test_vairable_pool.py +++ b/api/tests/workflow/executor/test_vairable_pool.py @@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default(): """测试获取不存在的节点输出(使用默认值)""" pool = VariablePool() - result = pool.get_node_output("nonexistent_node", defalut=None, strict=False) + result = pool.get_node_output("nonexistent_node", default=None, strict=False) assert result is None