fix(workflow): allow right-hand operand to be optional when not required by comparison

This commit is contained in:
Eternity
2026-03-20 19:09:41 +08:00
parent fabc8936ab
commit cf26c9f39c
9 changed files with 13 additions and 10 deletions

View File

@@ -456,7 +456,7 @@ class GraphBuilder:
branch_activate = [] branch_activate = []
new_state = state.copy() new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False) node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
for label, branch in unique_branch.items(): for label, branch in unique_branch.items():
if node_output and evaluate_condition( if node_output and evaluate_condition(
branch["condition"], branch["condition"],

View File

@@ -351,12 +351,12 @@ class VariablePool:
} }
return runtime_vars return runtime_vars
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量) """获取指定节点的输出(运行时变量)
Args: Args:
node_id: 节点 ID node_id: 节点 ID
defalut: 默认值 default: 默认值
strict: 是否严格模式 strict: 是否严格模式
Returns: Returns:
@@ -368,7 +368,7 @@ class VariablePool:
if strict: if strict:
raise KeyError(f"node {node_id} output not exist") raise KeyError(f"node {node_id} output not exist")
else: else:
return defalut return default
def copy(self, pool: 'VariablePool'): def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)

View File

@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
) )
right: Any = Field( right: Any = Field(
..., default=None,
description="Right-hand operand of the comparison expression" description="Right-hand operand of the comparison expression"
) )

View File

@@ -158,7 +158,7 @@ class LoopRuntime:
self.variable_pool.variables["conv"].update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"] self.child_variable_pool.variables["conv"]
) )
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool: def evaluate_conditional(self) -> bool:
@@ -261,4 +261,4 @@ class LoopRuntime:
idx += 1 idx += 1
logger.info(f"loop node {self.node_id}: execution completed") logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state} return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}

View File

@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
) )
right: Any = Field( right: Any = Field(
..., default=None,
description="Value to compare with" description="Value to compare with"
) )

View File

@@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
expressions.append({ expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False), "left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right "right": expression.right
if expression.input_type == ValueInputType.CONSTANT if expression.input_type == ValueInputType.CONSTANT or expression.right is None
else self.get_variable(expression.right, variable_pool, strict=False), else self.get_variable(expression.right, variable_pool, strict=False),
"operator": str(expression.operator), "operator": str(expression.operator),
}) })

View File

@@ -250,6 +250,8 @@ class ConditionBase(ABC):
self.type_limit = getattr(self, "type_limit", None) self.type_limit = getattr(self, "type_limit", None)
def resolve_right_literal_value(self): def resolve_right_literal_value(self):
if self.right_selector is None:
return None
if self.input_type == ValueInputType.VARIABLE: if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}" pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip() right_expression = re.sub(pattern, r"\1", self.right_selector).strip()

View File

@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
yield db yield db
finally: finally:
db.rollback() # 只读任务无需 commit db.rollback() # 只读任务无需 commit
db.close()
def get_pool_status(): def get_pool_status():

View File

@@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default():
"""测试获取不存在的节点输出(使用默认值)""" """测试获取不存在的节点输出(使用默认值)"""
pool = VariablePool() pool = VariablePool()
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False) result = pool.get_node_output("nonexistent_node", default=None, strict=False)
assert result is None assert result is None