fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug

This commit is contained in:
Eternity
2026-03-02 14:52:51 +08:00
parent 6718553bf4
commit 5cf2b08777
6 changed files with 76 additions and 21 deletions

View File

@@ -671,4 +671,4 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.", detail=f"Please reconfigure the tool node.",
)) ))
return {} return {}

View File

@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
instance: instance:
The concrete variable object. The actual Python type is The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable, represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]). NumberVariable, ArrayVariable[StringVariable]).
mut: mut:
Whether the variable is mutable. Whether the variable is mutable.
""" """
@@ -152,6 +152,20 @@ class VariablePool:
return None return None
return var_instance return var_instance
def get_instance(
self,
selector: str,
default: Any = None,
strict: bool = True
):
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance
def get_value( def get_value(
self, self,
selector: str, selector: str,

View File

@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
if config.flatten: if config.flatten:
outputs['output'] = config.output_type outputs['output'] = config.output_type
else: else:
outputs['output'] = VariableType.ARRAY_STRING outputs['output'] = VariableType.NESTED_ARRAY
else: else:
outputs['output'] = VariableType(f"array[{config.output_type}]") outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs return outputs

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import json import json
import logging import logging
import uuid
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
import httpx import httpx
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params return params
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
""" """
Build HTTP request body arguments for httpx request methods. Build HTTP request body arguments for httpx request methods.
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
)) ))
case HttpContentType.FROM_DATA: case HttpContentType.FROM_DATA:
data = {} data = {}
content["files"] = {}
for item in self.typed_config.body.data: for item in self.typed_config.body.data:
if item.type == "text": if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file": elif item.type == "file":
# TODO: File support (Feature) content["files"][self._render_template(item.key, variable_pool)] = (
pass uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data content["data"] = data
case HttpContentType.BINARY: case HttpContentType.BINARY:
# TODO: File support (Feature) content["files"] = []
pass file_instence = variable_pool.get_instance(self.typed_config.body.data)
if isinstance(file_instence, ArrayVariable):
for v in file_instence.value:
if isinstance(v, FileVariable):
content["files"].append(
(
"files", (uuid.uuid4().hex, await v.get_content())
)
)
elif isinstance(file_instence, FileVariable):
content["files"].append(
(
"file", (uuid.uuid4().hex, await file_instence.get_content())
)
)
case HttpContentType.WWW_FORM: case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template( content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), variable_pool json.dumps(self.typed_config.body.data), variable_pool
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
request_func = self._get_client_method(client) request_func = self._get_client_method(client)
resp = await request_func( resp = await request_func(
url=self._render_template(self.typed_config.url, variable_pool), url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool) **(await self._build_content(variable_pool))
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded") logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -1,8 +1,10 @@
from typing import Any, TypeVar, Type, Generic from typing import Any, TypeVar, Type, Generic
import httpx
from deprecated import deprecated from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
from app.core.config import settings
T = TypeVar("T", bound=BaseVariable) T = TypeVar("T", bound=BaseVariable)
@@ -80,8 +82,23 @@ class FileVariable(BaseVariable):
def get_value(self) -> Any: def get_value(self) -> Any:
return self.value.model_dump() return self.value.model_dump()
async def get_content(self):
total_bytes = 0
chunks = []
class ArrayObject(BaseVariable, Generic[T]): async with httpx.AsyncClient() as client:
async with client.stream("GET", self.value.url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(8192):
total_bytes += len(chunk)
if total_bytes > settings.MAX_FILE_SIZE:
raise ValueError(f"File too large: {total_bytes} bytes")
chunks.append(chunk)
return b"".join(chunks)
class ArrayVariable(BaseVariable, Generic[T]):
type = 'array' type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]): def __init__(self, child_type: Type[T], value: list[Any]):
@@ -108,7 +125,7 @@ class ArrayObject(BaseVariable, Generic[T]):
return [v.get_value() for v in self.value] return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable): class NestedArrayVariable(BaseVariable):
type = 'array_nest' type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]: def valid_value(self, value: list[T]) -> list[T]:
@@ -116,23 +133,23 @@ class NestedArrayObject(BaseVariable):
raise TypeError(f"Value must be a list - {type(value)}:{value}") raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = [] final_value = []
for v in value: for v in value:
if not isinstance(v, ArrayObject): if not isinstance(v, list):
raise TypeError("All elements must be of type list") raise TypeError("All elements must be of type list")
final_value.append(v) final_value.append(make_array(AnyVariable, v))
return final_value return final_value
def to_literal(self) -> str: def to_literal(self) -> str:
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
def get_value(self) -> Any: def get_value(self) -> Any:
return [[item.get_value() for item in row] for row in self.value] return [[item for item in row.get_value()] for row in self.value]
@deprecated( @deprecated(
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
category=RuntimeWarning category=RuntimeWarning
) )
class AnyObject(BaseVariable): class AnyVariable(BaseVariable):
type = 'any' type = 'any'
def valid_value(self, value: Any) -> Any: def valid_value(self, value: Any) -> Any:
@@ -142,10 +159,10 @@ class AnyObject(BaseVariable):
return str(self.value) return str(self.value)
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
"""简化 ArrayObject 创建,不需要重复写类型""" """简化 ArrayVariable 创建,不需要重复写类型"""
return ArrayObject(child_type, value) return ArrayVariable(child_type, value)
def create_variable_instance(var_type: VariableType, value: Any) -> T: def create_variable_instance(var_type: VariableType, value: Any) -> T:
@@ -168,7 +185,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return make_array(DictVariable, value) return make_array(DictVariable, value)
case VariableType.ARRAY_FILE: case VariableType.ARRAY_FILE:
return make_array(FileVariable, value) return make_array(FileVariable, value)
case VariableType.NESTED_ARRAY:
return NestedArrayVariable(value)
case VariableType.ANY: case VariableType.ANY:
return AnyObject(value) return AnyVariable(value)
case _: case _:
raise TypeError(f"Invalid type - {var_type}") raise TypeError(f"Invalid type - {var_type}")

View File

@@ -580,6 +580,7 @@ class WorkflowService:
# "variables": result.get("variables"), # "variables": result.get("variables"),
# "messages": result.get("messages"), # "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串) "output": result.get("output"), # 最终输出(字符串)
"message": result.get("output"), # 最终输出(字符串)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"), "error_message": result.get("error"),