fix(workflow): allow right-hand operand to be optional when not required by comparison
This commit is contained in:
@@ -456,7 +456,7 @@ class GraphBuilder:
|
|||||||
branch_activate = []
|
branch_activate = []
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
|
||||||
for label, branch in unique_branch.items():
|
for label, branch in unique_branch.items():
|
||||||
if node_output and evaluate_condition(
|
if node_output and evaluate_condition(
|
||||||
branch["condition"],
|
branch["condition"],
|
||||||
|
|||||||
@@ -351,12 +351,12 @@ class VariablePool:
|
|||||||
}
|
}
|
||||||
return runtime_vars
|
return runtime_vars
|
||||||
|
|
||||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||||
"""获取指定节点的输出(运行时变量)
|
"""获取指定节点的输出(运行时变量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点 ID
|
node_id: 节点 ID
|
||||||
defalut: 默认值
|
default: 默认值
|
||||||
strict: 是否严格模式
|
strict: 是否严格模式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -368,7 +368,7 @@ class VariablePool:
|
|||||||
if strict:
|
if strict:
|
||||||
raise KeyError(f"node {node_id} output not exist")
|
raise KeyError(f"node {node_id} output not exist")
|
||||||
else:
|
else:
|
||||||
return defalut
|
return default
|
||||||
|
|
||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Right-hand operand of the comparison expression"
|
description="Right-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class LoopRuntime:
|
|||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.variables["conv"]
|
self.child_variable_pool.variables["conv"]
|
||||||
)
|
)
|
||||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
|
||||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||||
|
|
||||||
def evaluate_conditional(self) -> bool:
|
def evaluate_conditional(self) -> bool:
|
||||||
@@ -261,4 +261,4 @@ class LoopRuntime:
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
logger.info(f"loop node {self.node_id}: execution completed")
|
logger.info(f"loop node {self.node_id}: execution completed")
|
||||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Value to compare with"
|
description="Value to compare with"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ class IfElseNode(BaseNode):
|
|||||||
expressions.append({
|
expressions.append({
|
||||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||||
"right": expression.right
|
"right": expression.right
|
||||||
if expression.input_type == ValueInputType.CONSTANT
|
if expression.input_type == ValueInputType.CONSTANT or expression.right is None
|
||||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
"operator": str(expression.operator),
|
"operator": str(expression.operator),
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -250,6 +250,8 @@ class ConditionBase(ABC):
|
|||||||
self.type_limit = getattr(self, "type_limit", None)
|
self.type_limit = getattr(self, "type_limit", None)
|
||||||
|
|
||||||
def resolve_right_literal_value(self):
|
def resolve_right_literal_value(self):
|
||||||
|
if self.right_selector is None:
|
||||||
|
return None
|
||||||
if self.input_type == ValueInputType.VARIABLE:
|
if self.input_type == ValueInputType.VARIABLE:
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
|
|||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.rollback() # 只读任务无需 commit
|
db.rollback() # 只读任务无需 commit
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def get_pool_status():
|
def get_pool_status():
|
||||||
|
|||||||
@@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default():
|
|||||||
"""测试获取不存在的节点输出(使用默认值)"""
|
"""测试获取不存在的节点输出(使用默认值)"""
|
||||||
pool = VariablePool()
|
pool = VariablePool()
|
||||||
|
|
||||||
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False)
|
result = pool.get_node_output("nonexistent_node", default=None, strict=False)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user