fix(workflow): add backward compatibility for any-value variable type
This commit is contained in:
@@ -5,12 +5,12 @@ import re
|
|||||||
from string import Template
|
from string import Template
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
import urllib.parse
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
|
||||||
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
from app.core.workflow.nodes.code.config import CodeNodeConfig
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.core.workflow.variable_pool import VariablePool
|
from app.core.workflow.variable_pool import VariablePool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -108,6 +108,7 @@ class CodeNode(BaseNode):
|
|||||||
code = base64.b64decode(
|
code = base64.b64decode(
|
||||||
self.typed_config.code
|
self.typed_config.code
|
||||||
).decode("utf-8")
|
).decode("utf-8")
|
||||||
|
code = urllib.parse.unquote(code, encoding='utf-8')
|
||||||
|
|
||||||
input_variable_dict = base64.b64encode(
|
input_variable_dict = base64.b64encode(
|
||||||
json.dumps(input_variable_dict).encode("utf-8")
|
json.dumps(input_variable_dict).encode("utf-8")
|
||||||
|
|||||||
@@ -129,7 +129,7 @@ class IterationNodeConfig(BaseNodeConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
output_type: VariableType = Field(
|
output_type: VariableType = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Data type of the loop iteration output"
|
description="Data type of the loop iteration output"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,9 @@ class CycleGraphNode(BaseNode):
|
|||||||
elif self.node_type == NodeType.ITERATION:
|
elif self.node_type == NodeType.ITERATION:
|
||||||
# Iteration node outputs the processed collection
|
# Iteration node outputs the processed collection
|
||||||
config = IterationNodeConfig(**self.config)
|
config = IterationNodeConfig(**self.config)
|
||||||
|
if not config.output_type:
|
||||||
|
outputs['output'] = VariableType.ANY
|
||||||
|
return outputs
|
||||||
if config.output_type in [
|
if config.output_type in [
|
||||||
VariableType.ARRAY_FILE,
|
VariableType.ARRAY_FILE,
|
||||||
VariableType.ARRAY_STRING,
|
VariableType.ARRAY_STRING,
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
group_type: dict[str, VariableType] = Field(
|
group_type: dict[str, VariableType] = Field(
|
||||||
...,
|
default=None,
|
||||||
description="每个分组的变量类型"
|
description="每个分组的变量类型"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -19,6 +19,10 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
config = VariableAggregatorNodeConfig(**self.config)
|
config = VariableAggregatorNodeConfig(**self.config)
|
||||||
output = {}
|
output = {}
|
||||||
|
if not config.group_type:
|
||||||
|
for group_name in config.group_variables.keys():
|
||||||
|
output[group_name] = VariableType.ANY
|
||||||
|
return output
|
||||||
for var_type in config.group_type:
|
for var_type in config.group_type:
|
||||||
output[var_type] = config.group_type[var_type]
|
output[var_type] = config.group_type[var_type]
|
||||||
return output
|
return output
|
||||||
@@ -64,6 +68,8 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
logger.info("No variable found in non-group mode; returning empty string.")
|
logger.info("No variable found in non-group mode; returning empty string.")
|
||||||
|
if not self.typed_config.group_type:
|
||||||
|
return ""
|
||||||
return DEFAULT_VALUE(self.typed_config.group_type["output"])
|
return DEFAULT_VALUE(self.typed_config.group_type["output"])
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
@@ -83,7 +89,10 @@ class VariableAggregatorNode(BaseNode):
|
|||||||
result[group_name] = value
|
result[group_name] = value
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
|
if not self.typed_config.group_type:
|
||||||
|
result[group_name] = ""
|
||||||
|
else:
|
||||||
|
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
|
||||||
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
logger.info(f"No variable found for group '{group_name}'; set empty string.")
|
||||||
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -20,6 +20,8 @@ class VariableType(StrEnum):
|
|||||||
|
|
||||||
NESTED_ARRAY = "array_nest"
|
NESTED_ARRAY = "array_nest"
|
||||||
|
|
||||||
|
ANY = 'any'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def type_map(cls, var: Any) -> "VariableType":
|
def type_map(cls, var: Any) -> "VariableType":
|
||||||
"""Maps a Python value to a corresponding VariableType.
|
"""Maps a Python value to a corresponding VariableType.
|
||||||
|
|||||||
@@ -107,6 +107,16 @@ class NestedArrayObject(BaseVariable):
|
|||||||
return [[item.get_value() for item in row] for row in self.value]
|
return [[item.get_value() for item in row] for row in self.value]
|
||||||
|
|
||||||
|
|
||||||
|
class AnyObject(BaseVariable):
|
||||||
|
type = 'any'
|
||||||
|
|
||||||
|
def valid_value(self, value: Any) -> Any:
|
||||||
|
return value
|
||||||
|
|
||||||
|
def to_literal(self) -> str:
|
||||||
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
||||||
"""简化 ArrayObject 创建,不需要重复写类型"""
|
"""简化 ArrayObject 创建,不需要重复写类型"""
|
||||||
|
|
||||||
@@ -133,5 +143,7 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
|||||||
return make_array(DictObject, value)
|
return make_array(DictObject, value)
|
||||||
case VariableType.ARRAY_FILE:
|
case VariableType.ARRAY_FILE:
|
||||||
return make_array(FileObject, value)
|
return make_array(FileObject, value)
|
||||||
|
case VariableType.ANY:
|
||||||
|
return AnyObject(value)
|
||||||
case _:
|
case _:
|
||||||
raise TypeError(f"Invalid type - {var_type}")
|
raise TypeError(f"Invalid type - {var_type}")
|
||||||
|
|||||||
Reference in New Issue
Block a user