fix(workflow): add backward compatibility for any-value variable type

This commit is contained in:
Eternity
2026-02-04 12:11:22 +08:00
parent bd8a451879
commit c6ea31c296
7 changed files with 32 additions and 5 deletions

View File

@@ -5,12 +5,12 @@ import re
from string import Template
from textwrap import dedent
from typing import Any
import urllib.parse
import httpx
from app.core.workflow.nodes import BaseNode, WorkflowState
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.nodes.code.config import CodeNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable_pool import VariablePool
logger = logging.getLogger(__name__)
@@ -108,6 +108,7 @@ class CodeNode(BaseNode):
code = base64.b64decode(
self.typed_config.code
).decode("utf-8")
code = urllib.parse.unquote(code, encoding='utf-8')
input_variable_dict = base64.b64encode(
json.dumps(input_variable_dict).encode("utf-8")

View File

@@ -129,7 +129,7 @@ class IterationNodeConfig(BaseNodeConfig):
)
output_type: VariableType = Field(
...,
default=None,
description="Data type of the loop iteration output"
)

View File

@@ -53,6 +53,9 @@ class CycleGraphNode(BaseNode):
elif self.node_type == NodeType.ITERATION:
# Iteration node outputs the processed collection
config = IterationNodeConfig(**self.config)
if not config.output_type:
outputs['output'] = VariableType.ANY
return outputs
if config.output_type in [
VariableType.ARRAY_FILE,
VariableType.ARRAY_STRING,

View File

@@ -16,7 +16,7 @@ class VariableAggregatorNodeConfig(BaseNodeConfig):
)
group_type: dict[str, VariableType] = Field(
...,
default=None,
description="每个分组的变量类型"
)

View File

@@ -19,6 +19,10 @@ class VariableAggregatorNode(BaseNode):
def _output_types(self) -> dict[str, VariableType]:
config = VariableAggregatorNodeConfig(**self.config)
output = {}
if not config.group_type:
for group_name in config.group_variables.keys():
output[group_name] = VariableType.ANY
return output
for var_type in config.group_type:
output[var_type] = config.group_type[var_type]
return output
@@ -64,6 +68,8 @@ class VariableAggregatorNode(BaseNode):
return value
logger.info("No variable found in non-group mode; returning empty string.")
if not self.typed_config.group_type:
return ""
return DEFAULT_VALUE(self.typed_config.group_type["output"])
# --------------------------
@@ -83,7 +89,10 @@ class VariableAggregatorNode(BaseNode):
result[group_name] = value
break
else:
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
if not self.typed_config.group_type:
result[group_name] = ""
else:
result[group_name] = DEFAULT_VALUE(self.typed_config.group_type[group_name])
logger.info(f"No variable found for group '{group_name}'; set empty string.")
logger.info(f"Node: {self.node_id} variable aggregation result: {result}")
return result

View File

@@ -20,6 +20,8 @@ class VariableType(StrEnum):
NESTED_ARRAY = "array_nest"
ANY = 'any'
@classmethod
def type_map(cls, var: Any) -> "VariableType":
"""Maps a Python value to a corresponding VariableType.

View File

@@ -107,6 +107,16 @@ class NestedArrayObject(BaseVariable):
return [[item.get_value() for item in row] for row in self.value]
class AnyObject(BaseVariable):
type = 'any'
def valid_value(self, value: Any) -> Any:
return value
def to_literal(self) -> str:
return str(self.value)
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
"""简化 ArrayObject 创建,不需要重复写类型"""
@@ -133,5 +143,7 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return make_array(DictObject, value)
case VariableType.ARRAY_FILE:
return make_array(FileObject, value)
case VariableType.ANY:
return AnyObject(value)
case _:
raise TypeError(f"Invalid type - {var_type}")