fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug
This commit is contained in:
@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
|
|||||||
instance:
|
instance:
|
||||||
The concrete variable object. The actual Python type is
|
The concrete variable object. The actual Python type is
|
||||||
represented by the generic parameter ``T`` (e.g. StringVariable,
|
represented by the generic parameter ``T`` (e.g. StringVariable,
|
||||||
NumberVariable, ArrayObject[StringVariable]).
|
NumberVariable, ArrayVariable[StringVariable]).
|
||||||
mut:
|
mut:
|
||||||
Whether the variable is mutable.
|
Whether the variable is mutable.
|
||||||
"""
|
"""
|
||||||
@@ -152,6 +152,20 @@ class VariablePool:
|
|||||||
return None
|
return None
|
||||||
return var_instance
|
return var_instance
|
||||||
|
|
||||||
|
def get_instance(
|
||||||
|
self,
|
||||||
|
selector: str,
|
||||||
|
default: Any = None,
|
||||||
|
strict: bool = True
|
||||||
|
):
|
||||||
|
variable_struct = self._get_variable_struct(selector)
|
||||||
|
if variable_struct is None:
|
||||||
|
if strict:
|
||||||
|
raise KeyError(f"{selector} not exist")
|
||||||
|
return default
|
||||||
|
|
||||||
|
return variable_struct.instance
|
||||||
|
|
||||||
def get_value(
|
def get_value(
|
||||||
self,
|
self,
|
||||||
selector: str,
|
selector: str,
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
|
|||||||
if config.flatten:
|
if config.flatten:
|
||||||
outputs['output'] = config.output_type
|
outputs['output'] = config.output_type
|
||||||
else:
|
else:
|
||||||
outputs['output'] = VariableType.ARRAY_STRING
|
outputs['output'] = VariableType.NESTED_ARRAY
|
||||||
else:
|
else:
|
||||||
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
outputs['output'] = VariableType(f"array[{config.output_type}]")
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
|
|||||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
|
||||||
return params
|
return params
|
||||||
|
|
||||||
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Build HTTP request body arguments for httpx request methods.
|
Build HTTP request body arguments for httpx request methods.
|
||||||
|
|
||||||
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
|
|||||||
))
|
))
|
||||||
case HttpContentType.FROM_DATA:
|
case HttpContentType.FROM_DATA:
|
||||||
data = {}
|
data = {}
|
||||||
|
content["files"] = {}
|
||||||
for item in self.typed_config.body.data:
|
for item in self.typed_config.body.data:
|
||||||
if item.type == "text":
|
if item.type == "text":
|
||||||
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool)
|
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
|
||||||
|
variable_pool)
|
||||||
elif item.type == "file":
|
elif item.type == "file":
|
||||||
# TODO: File support (Feature)
|
content["files"][self._render_template(item.key, variable_pool)] = (
|
||||||
pass
|
uuid.uuid4().hex,
|
||||||
|
await variable_pool.get_instance(item.value).get_content()
|
||||||
|
)
|
||||||
content["data"] = data
|
content["data"] = data
|
||||||
case HttpContentType.BINARY:
|
case HttpContentType.BINARY:
|
||||||
# TODO: File support (Feature)
|
content["files"] = []
|
||||||
pass
|
file_instence = variable_pool.get_instance(self.typed_config.body.data)
|
||||||
|
if isinstance(file_instence, ArrayVariable):
|
||||||
|
for v in file_instence.value:
|
||||||
|
if isinstance(v, FileVariable):
|
||||||
|
content["files"].append(
|
||||||
|
(
|
||||||
|
"files", (uuid.uuid4().hex, await v.get_content())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif isinstance(file_instence, FileVariable):
|
||||||
|
content["files"].append(
|
||||||
|
(
|
||||||
|
"file", (uuid.uuid4().hex, await file_instence.get_content())
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
case HttpContentType.WWW_FORM:
|
case HttpContentType.WWW_FORM:
|
||||||
content["data"] = json.loads(self._render_template(
|
content["data"] = json.loads(self._render_template(
|
||||||
json.dumps(self.typed_config.body.data), variable_pool
|
json.dumps(self.typed_config.body.data), variable_pool
|
||||||
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
request_func = self._get_client_method(client)
|
request_func = self._get_client_method(client)
|
||||||
resp = await request_func(
|
resp = await request_func(
|
||||||
url=self._render_template(self.typed_config.url, variable_pool),
|
url=self._render_template(self.typed_config.url, variable_pool),
|
||||||
**self._build_content(variable_pool)
|
**(await self._build_content(variable_pool))
|
||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
from typing import Any, TypeVar, Type, Generic
|
from typing import Any, TypeVar, Type, Generic
|
||||||
|
|
||||||
|
import httpx
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
|
|
||||||
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
|
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseVariable)
|
T = TypeVar("T", bound=BaseVariable)
|
||||||
|
|
||||||
@@ -80,8 +82,23 @@ class FileVariable(BaseVariable):
|
|||||||
def get_value(self) -> Any:
|
def get_value(self) -> Any:
|
||||||
return self.value.model_dump()
|
return self.value.model_dump()
|
||||||
|
|
||||||
|
async def get_content(self):
|
||||||
|
total_bytes = 0
|
||||||
|
chunks = []
|
||||||
|
|
||||||
class ArrayObject(BaseVariable, Generic[T]):
|
async with httpx.AsyncClient() as client:
|
||||||
|
async with client.stream("GET", self.value.url) as resp:
|
||||||
|
resp.raise_for_status()
|
||||||
|
async for chunk in resp.aiter_bytes(8192):
|
||||||
|
total_bytes += len(chunk)
|
||||||
|
if total_bytes > settings.MAX_FILE_SIZE:
|
||||||
|
raise ValueError(f"File too large: {total_bytes} bytes")
|
||||||
|
chunks.append(chunk)
|
||||||
|
|
||||||
|
return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
class ArrayVariable(BaseVariable, Generic[T]):
|
||||||
type = 'array'
|
type = 'array'
|
||||||
|
|
||||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||||
@@ -108,7 +125,7 @@ class ArrayObject(BaseVariable, Generic[T]):
|
|||||||
return [v.get_value() for v in self.value]
|
return [v.get_value() for v in self.value]
|
||||||
|
|
||||||
|
|
||||||
class NestedArrayObject(BaseVariable):
|
class NestedArrayVariable(BaseVariable):
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
@@ -116,23 +133,23 @@ class NestedArrayObject(BaseVariable):
|
|||||||
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
raise TypeError(f"Value must be a list - {type(value)}:{value}")
|
||||||
final_value = []
|
final_value = []
|
||||||
for v in value:
|
for v in value:
|
||||||
if not isinstance(v, ArrayObject):
|
if not isinstance(v, list):
|
||||||
raise TypeError("All elements must be of type list")
|
raise TypeError("All elements must be of type list")
|
||||||
final_value.append(v)
|
final_value.append(make_array(AnyVariable, v))
|
||||||
return final_value
|
return final_value
|
||||||
|
|
||||||
def to_literal(self) -> str:
|
def to_literal(self) -> str:
|
||||||
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value])
|
return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
|
||||||
|
|
||||||
def get_value(self) -> Any:
|
def get_value(self) -> Any:
|
||||||
return [[item.get_value() for item in row] for row in self.value]
|
return [[item for item in row.get_value()] for row in self.value]
|
||||||
|
|
||||||
|
|
||||||
@deprecated(
|
@deprecated(
|
||||||
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
|
||||||
category=RuntimeWarning
|
category=RuntimeWarning
|
||||||
)
|
)
|
||||||
class AnyObject(BaseVariable):
|
class AnyVariable(BaseVariable):
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
def valid_value(self, value: Any) -> Any:
|
def valid_value(self, value: Any) -> Any:
|
||||||
@@ -142,10 +159,10 @@ class AnyObject(BaseVariable):
|
|||||||
return str(self.value)
|
return str(self.value)
|
||||||
|
|
||||||
|
|
||||||
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]:
|
def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
|
||||||
"""简化 ArrayObject 创建,不需要重复写类型"""
|
"""简化 ArrayVariable 创建,不需要重复写类型"""
|
||||||
|
|
||||||
return ArrayObject(child_type, value)
|
return ArrayVariable(child_type, value)
|
||||||
|
|
||||||
|
|
||||||
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
||||||
@@ -168,7 +185,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
|
|||||||
return make_array(DictVariable, value)
|
return make_array(DictVariable, value)
|
||||||
case VariableType.ARRAY_FILE:
|
case VariableType.ARRAY_FILE:
|
||||||
return make_array(FileVariable, value)
|
return make_array(FileVariable, value)
|
||||||
|
case VariableType.NESTED_ARRAY:
|
||||||
|
return NestedArrayVariable(value)
|
||||||
case VariableType.ANY:
|
case VariableType.ANY:
|
||||||
return AnyObject(value)
|
return AnyVariable(value)
|
||||||
case _:
|
case _:
|
||||||
raise TypeError(f"Invalid type - {var_type}")
|
raise TypeError(f"Invalid type - {var_type}")
|
||||||
|
|||||||
@@ -580,6 +580,7 @@ class WorkflowService:
|
|||||||
# "variables": result.get("variables"),
|
# "variables": result.get("variables"),
|
||||||
# "messages": result.get("messages"),
|
# "messages": result.get("messages"),
|
||||||
"output": result.get("output"), # 最终输出(字符串)
|
"output": result.get("output"), # 最终输出(字符串)
|
||||||
|
"message": result.get("output"), # 最终输出(字符串)
|
||||||
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
|
||||||
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
"conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID
|
||||||
"error_message": result.get("error"),
|
"error_message": result.get("error"),
|
||||||
|
|||||||
Reference in New Issue
Block a user