Merge pull request #427 from SuanmoSuanyangTechnology/fix/workflow-variable

fix(workflow): handle non-stream output field changes, add file type support to HTTP node, fix iteration node flattening bug
This commit is contained in:
Mark
2026-03-02 17:55:54 +08:00
committed by GitHub
10 changed files with 239 additions and 107 deletions

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModel
ExceptionType ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig from app.core.workflow.nodes.assigner import AssignerNodeConfig
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
@@ -69,9 +69,27 @@ class DifyConverter(BaseConverter):
} }
def get_node_convert(self, node_type): def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, None) func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
return func return func
def config_validate(
self,
node_id: str,
node_name: str,
config: type[BaseNodeConfig],
value: dict
):
try:
return config.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
@staticmethod @staticmethod
def is_variable(expression) -> bool: def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression)) return bool(re.match(r"\{\{#(.*?)#}}", expression))
@@ -80,14 +98,16 @@ class DifyConverter(BaseConverter):
if not var_selector: if not var_selector:
return "" return ""
selector = var_selector.split('.') selector = var_selector.split('.')
if len(selector) != 2: if len(selector) not in [2, 3]:
raise Exception(f"invalid variable selector: {var_selector}") raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation": if selector[0] == "conversation":
selector[0] = "conv" selector[0] = "conv"
var_selector = ".".join(selector) var_selector = ".".join(selector)
mapping = { mapping = {
"sys.query": "sys.message" "sys.query": "sys.message"
} | self.node_output_map } | self.node_output_map
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
@@ -237,7 +257,7 @@ class DifyConverter(BaseConverter):
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
name=var["variable"], name=var["variable"],
detail=f"Unsupport Variable type for start node: {var_type}" detail=f"Unsupported Variable type for start node: {var_type}"
) )
) )
continue continue
@@ -253,9 +273,11 @@ class DifyConverter(BaseConverter):
max_length=var.get("max_length"), max_length=var.get("max_length"),
) )
start_vars.append(var_def) start_vars.append(var_def)
return StartNodeConfig( result = StartNodeConfig.model_construct(
variables=start_vars variables=start_vars
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
return result
def convert_question_classifier_node_config(self, node: dict) -> dict: def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
@@ -270,16 +292,18 @@ class DifyConverter(BaseConverter):
for category in node_data["classes"]: for category in node_data["classes"]:
self.branch_node_cache[node["id"]].append(category["id"]) self.branch_node_cache[node["id"]].append(category["id"])
categories.append( categories.append(
ClassifierConfig( ClassifierConfig.model_construct(
class_name=category["name"], class_name=category["name"],
) )
) )
return QuestionClassifierNodeConfig.model_construct( result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]), input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories, categories=categories,
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
return result
def convert_llm_node_config(self, node: dict) -> dict: def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
@@ -315,7 +339,7 @@ class DifyConverter(BaseConverter):
vision_input = self._process_list_variable_litearl( vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"] node_data["vision"]["configs"]["variable_selector"]
) if vision else None ) if vision else None
return LLMNodeConfig.model_construct( result = LLMNodeConfig.model_construct(
model_id=None, model_id=None,
context=context, context=context,
memory=memory, memory=memory,
@@ -323,12 +347,16 @@ class DifyConverter(BaseConverter):
vision_input=vision_input, vision_input=vision_input,
messages=messages messages=messages
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
return result
def convert_end_node_config(self, node: dict) -> dict: def convert_end_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
return EndNodeConfig( result = EndNodeConfig.model_construct(
output=self.trans_variable_format(node_data["answer"]), output=self.trans_variable_format(node_data.get("answer", "")),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
return result
def convert_if_else_node_config(self, node: dict) -> dict: def convert_if_else_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
@@ -359,9 +387,11 @@ class DifyConverter(BaseConverter):
) )
) )
self.branch_node_cache[node["id"]].append(case_id) self.branch_node_cache[node["id"]].append(case_id)
return IfElseNodeConfig( result = IfElseNodeConfig.model_construct(
cases=cases cases=cases
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
return result
def convert_loop_node_config(self, node: dict) -> dict: def convert_loop_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
@@ -370,7 +400,7 @@ class DifyConverter(BaseConverter):
for condition in node_data["break_conditions"]: for condition in node_data["break_conditions"]:
right_value = condition["value"] right_value = condition["value"]
conditions.append( conditions.append(
LoopConditionDetail( LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]), operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]), left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format( right=self.trans_variable_format(
@@ -383,7 +413,7 @@ class DifyConverter(BaseConverter):
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
) )
) )
condition_config = ConditionsConfig( condition_config = ConditionsConfig.model_construct(
logical_operator=logical_operator, logical_operator=logical_operator,
expressions=conditions expressions=conditions
) )
@@ -392,9 +422,9 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"] right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"]) right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE: if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable["value"]) right_value = self._process_list_variable_litearl(variable.get("value", ""))
else: else:
right_value = self.convert_variable_type(right_value_type, variable["value"]) right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append( loop_variables.append(
CycleVariable( CycleVariable(
name=variable["label"], name=variable["label"],
@@ -403,23 +433,28 @@ class DifyConverter(BaseConverter):
input_type=right_input_type input_type=right_input_type
) )
) )
return LoopNodeConfig( result = LoopNodeConfig.model_construct(
condition=condition_config, condition=condition_config,
cycle_vars=loop_variables, cycle_vars=loop_variables,
max_loop=node_data["loop_count"] max_loop=node_data.get("loop_count", 10)
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
return result
def convert_iteration_node_config(self, node: dict) -> dict: def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
return IterationNodeConfig( result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]), input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"], parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"], parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]), output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data["output_type"]), output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"], flatten=node_data["flatten_output"],
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
return result
def convert_assigner_node_config(self, node: dict) -> dict: def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
assignments = [] assignments = []
@@ -435,16 +470,18 @@ class DifyConverter(BaseConverter):
operation=self.convert_assignment_operator(assignment["operation"]) operation=self.convert_assignment_operator(assignment["operation"])
) )
) )
return AssignerNodeConfig( result = AssignerNodeConfig.model_construct(
assignments=assignments assignments=assignments
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
return result
def convert_code_node_config(self, node: dict) -> dict: def convert_code_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
input_variables = [] input_variables = []
for input_variable in node_data["variables"]: for input_variable in node_data["variables"]:
input_variables.append( input_variables.append(
InputVariable( InputVariable.model_construct(
name=input_variable["variable"], name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]), variable=self._process_list_variable_litearl(input_variable["value_selector"]),
) )
@@ -453,7 +490,7 @@ class DifyConverter(BaseConverter):
output_variables = [] output_variables = []
for output_variable in node_data["outputs"]: for output_variable in node_data["outputs"]:
output_variables.append( output_variables.append(
OutputVariable( OutputVariable.model_construct(
name=output_variable, name=output_variable,
type=node_data["outputs"][output_variable]["type"], type=node_data["outputs"][output_variable]["type"],
) )
@@ -461,18 +498,20 @@ class DifyConverter(BaseConverter):
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8") code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
return CodeNodeConfig( result = CodeNodeConfig.model_construct(
input_variables=input_variables, input_variables=input_variables,
language=node_data["code_language"], language=node_data["code_language"],
output_variables=output_variables, output_variables=output_variables,
code=code code=code
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
return result
def convert_http_node_config(self, node: dict) -> dict: def convert_http_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
if node_data["authorization"] != 'no-auth': if node_data["authorization"]["type"] != 'no-auth':
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"]) auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
auth_config = HttpAuthConfig( auth_config = HttpAuthConfig.model_construct(
auth_type=auth_type, auth_type=auth_type,
header=node_data["authorization"]["config"].get("header"), header=node_data["authorization"]["config"].get("header"),
api_key=node_data["authorization"]["config"].get("api_key"), api_key=node_data["authorization"]["config"].get("api_key"),
@@ -504,7 +543,7 @@ class DifyConverter(BaseConverter):
body_content = "" body_content = ""
headers = {} headers = {}
for header in node_data["headers"].split("\n"): for header in node_data.get("headers", "").split("\n"):
if not header: if not header:
continue continue
@@ -522,7 +561,7 @@ class DifyConverter(BaseConverter):
)) ))
params = {} params = {}
for param in node_data["params"].split("\n"): for param in node_data.get("params", "").split("\n"):
if not param: if not param:
continue continue
@@ -547,7 +586,7 @@ class DifyConverter(BaseConverter):
default_body = "" default_body = ""
default_header = {} default_header = {}
default_status_code = 0 default_status_code = 0
for var in node_data["default_value"]: for var in node_data.get("default_value") or []:
if var["key"] == "body": if var["key"] == "body":
default_body = var["value"] default_body = var["value"]
elif var["key"] == "header": elif var["key"] == "header":
@@ -561,45 +600,50 @@ class DifyConverter(BaseConverter):
) )
self.error_branch_node_cache.append(node['id']) self.error_branch_node_cache.append(node['id'])
return HttpRequestNodeConfig( result = HttpRequestNodeConfig.model_construct(
method=node_data["method"].upper(), method=node_data["method"].upper(),
url=node_data["url"], url=node_data["url"],
auth=auth_config, auth=auth_config,
body=HttpContentTypeConfig( body=HttpContentTypeConfig.model_construct(
content_type=self.convert_http_content_type(node_data["body"]["type"]), content_type=self.convert_http_content_type(node_data["body"]["type"]),
data=body_content, data=body_content,
), ),
headers=headers, headers=headers,
params=params, params=params,
verify_ssl=node_data["ssl_verify"], verify_ssl=node_data["ssl_verify"],
timeouts=HttpTimeOutConfig( timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5, read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
write_timeout=node_data["timeout"]["max_write_timeout"] or 5, write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
), ),
retry=HttpRetryConfig( retry=HttpRetryConfig.model_construct(
enable=node_data["retry_config"]["retry_enabled"], enable=node_data["retry_config"]["retry_enabled"],
max_attempts=node_data["retry_config"]["max_retries"], max_attempts=node_data["retry_config"]["max_retries"],
retry_interval=node_data["retry_config"]["retry_interval"], retry_interval=node_data["retry_config"]["retry_interval"],
), ),
error_handle=HttpErrorHandleConfig( error_handle=HttpErrorHandleConfig.model_construct(
method=error_handle_type, method=error_handle_type,
default=default_value, default=default_value,
) )
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
return result
def convert_jinja_render_node_config(self, node: dict) -> dict: def convert_jinja_render_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
mapping = [] mapping = []
for variable in node_data["variables"]: for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig( mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"], name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"]) value=self._process_list_variable_litearl(variable["value_selector"])
)) ))
return JinjaRenderNodeConfig( result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"], template=node_data["template"],
mapping=mapping, mapping=mapping,
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
return result
def convert_knowledge_node_config(self, node: dict) -> dict: def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
@@ -609,10 +653,13 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.", detail=f"Please reconfigure the Knowledge Retrieval node.",
)) ))
return KnowledgeRetrievalNodeConfig.model_construct( result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]), query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
return result
def convert_parameter_extractor_node_config(self, node: dict) -> dict: def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
@@ -623,46 +670,53 @@ class DifyConverter(BaseConverter):
) )
) )
params = [] params = []
for param in node_data["parameters"]: for param in node_data.get("parameters", []):
params.append( params.append(
ParamsConfig( ParamsConfig.model_construct(
name=param["name"], name=param["name"],
desc=param["description"], desc=param["description"],
required=param["required"], required=param["required"],
type=param["type"], type=param["type"],
) )
) )
return ParameterExtractorNodeConfig.model_construct( result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]), text=self._process_list_variable_litearl(node_data["query"]),
params=params, params=params,
prompt=node_data["instruction"] prompt=node_data.get("instruction")
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
return result
def convert_variable_aggregator_node_config(self, node: dict) -> dict: def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"] advanced_settings = node_data.get("advanced_settings", {})
group_variables = {} group_variables = {}
group_type = {} group_type = {}
if not group_enable: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables["output"] = [ group_variables["output"] = [
self._process_list_variable_litearl(variable) self._process_list_variable_litearl(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
group_type["output"] = node_data["output_type"] group_type["output"] = node_data["output_type"]
else: else:
for group in node_data["advanced_settings"]["groups"]: for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [ group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable) self._process_list_variable_litearl(variable)
for variable in group["variables"] for variable in group["variables"]
] ]
group_type[group["group_name"]] = group["output_type"] group_type[group["group_name"]] = group["output_type"]
return VariableAggregatorNodeConfig( result = VariableAggregatorNodeConfig.model_construct(
group=group_enable, group=advanced_settings.get("group_enabled", False),
group_variables=group_variables, group_variables=group_variables,
group_type=group_type, group_type=group_type,
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
return result
def convert_tool_node_config(self, node: dict) -> dict: def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefineition( self.warnings.append(ExceptionDefineition(
@@ -671,4 +725,4 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.", detail=f"Please reconfigure the tool node.",
)) ))
return {} return {}

View File

@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
) )
def map_node_type(self, platform_node_type) -> str: def map_node_type(self, platform_node_type) -> str:
return self.NODE_TYPE_MAPPING.get(platform_node_type) return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property @property
def origin_nodes(self): def origin_nodes(self):
@@ -179,8 +179,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_type = node_data["type"] node_type = node_data["type"]
try: try:
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if converter is None: if node_type not in self.CONFIG_CONVERT_MAP:
raise Exception(f"node type not supported - {node_type}") self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"node type {node_type} is unsupported",
))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(

View File

@@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]):
instance: instance:
The concrete variable object. The actual Python type is The concrete variable object. The actual Python type is
represented by the generic parameter ``T`` (e.g. StringVariable, represented by the generic parameter ``T`` (e.g. StringVariable,
NumberVariable, ArrayObject[StringVariable]). NumberVariable, ArrayVariable[StringVariable]).
mut: mut:
Whether the variable is mutable. Whether the variable is mutable.
""" """
@@ -152,6 +152,36 @@ class VariablePool:
return None return None
return var_instance return var_instance
def get_instance(
self,
selector: str,
default: Any = None,
strict: bool = True
):
"""Retrieve a variable instance from the variable pool.
Args:
selector:
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
Returns:
The variable instance object if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:
raise KeyError(f"{selector} not exist")
return default
return variable_struct.instance
def get_value( def get_value(
self, self,
selector: str, selector: str,

View File

@@ -132,24 +132,24 @@ class WorkflowExecutor:
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
# Execute the workflow # Execute the workflow
try: try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes # Aggregate output from all End nodes
@@ -231,23 +231,23 @@ class WorkflowExecutor:
} }
} }
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
try: try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
full_content = '' full_content = ''
self.stream_coordinator.update_scope_activation("sys") self.stream_coordinator.update_scope_activation("sys")

View File

@@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode):
if config.flatten: if config.flatten:
outputs['output'] = config.output_type outputs['output'] = config.output_type
else: else:
outputs['output'] = VariableType.ARRAY_STRING outputs['output'] = VariableType.NESTED_ARRAY
else: else:
outputs['output'] = VariableType(f"array[{config.output_type}]") outputs['output'] = VariableType(f"array[{config.output_type}]")
return outputs return outputs

View File

@@ -24,6 +24,8 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read" MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import json import json
import logging import logging
import uuid
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
import httpx import httpx
@@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
logger = logging.getLogger(__file__) logger = logging.getLogger(__file__)
@@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode):
params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool)
return params return params
def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]:
""" """
Build HTTP request body arguments for httpx request methods. Build HTTP request body arguments for httpx request methods.
@@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode):
)) ))
case HttpContentType.FROM_DATA: case HttpContentType.FROM_DATA:
data = {} data = {}
content["files"] = {}
for item in self.typed_config.body.data: for item in self.typed_config.body.data:
if item.type == "text": if item.type == "text":
data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) data[self._render_template(item.key, variable_pool)] = self._render_template(item.value,
variable_pool)
elif item.type == "file": elif item.type == "file":
# TODO: File support (Feature) content["files"][self._render_template(item.key, variable_pool)] = (
pass uuid.uuid4().hex,
await variable_pool.get_instance(item.value).get_content()
)
content["data"] = data content["data"] = data
case HttpContentType.BINARY: case HttpContentType.BINARY:
# TODO: File support (Feature) content["files"] = []
pass file_instence = variable_pool.get_instance(self.typed_config.body.data)
if isinstance(file_instence, ArrayVariable):
for v in file_instence.value:
if isinstance(v, FileVariable):
content["files"].append(
(
"files", (uuid.uuid4().hex, await v.get_content())
)
)
elif isinstance(file_instence, FileVariable):
content["files"].append(
(
"file", (uuid.uuid4().hex, await file_instence.get_content())
)
)
case HttpContentType.WWW_FORM: case HttpContentType.WWW_FORM:
content["data"] = json.loads(self._render_template( content["data"] = json.loads(self._render_template(
json.dumps(self.typed_config.body.data), variable_pool json.dumps(self.typed_config.body.data), variable_pool
@@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode):
request_func = self._get_client_method(client) request_func = self._get_client_method(client)
resp = await request_func( resp = await request_func(
url=self._render_template(self.typed_config.url, variable_pool), url=self._render_template(self.typed_config.url, variable_pool),
**self._build_content(variable_pool) **(await self._build_content(variable_pool))
) )
resp.raise_for_status() resp.raise_for_status()
logger.info(f"Node {self.node_id}: HTTP request succeeded") logger.info(f"Node {self.node_id}: HTTP request succeeded")

View File

@@ -123,10 +123,10 @@ class NodeFactory:
# 获取节点类 # 获取节点类
node_class = cls._node_types.get(node_type) node_class = cls._node_types.get(node_type)
if not node_class: if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}") raise ValueError(f"Unsupported node type: {node_type}")
# 创建节点实例 # 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config) return node_class(node_config, workflow_config)
@classmethod @classmethod

View File

@@ -1,8 +1,10 @@
from typing import Any, TypeVar, Type, Generic from typing import Any, TypeVar, Type, Generic
import httpx
from deprecated import deprecated from deprecated import deprecated
from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType
from app.core.config import settings
T = TypeVar("T", bound=BaseVariable) T = TypeVar("T", bound=BaseVariable)
@@ -80,8 +82,23 @@ class FileVariable(BaseVariable):
def get_value(self) -> Any: def get_value(self) -> Any:
return self.value.model_dump() return self.value.model_dump()
async def get_content(self):
total_bytes = 0
chunks = []
class ArrayObject(BaseVariable, Generic[T]): async with httpx.AsyncClient() as client:
async with client.stream("GET", self.value.url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes(8192):
total_bytes += len(chunk)
if total_bytes > settings.MAX_FILE_SIZE:
raise ValueError(f"File too large: {total_bytes} bytes")
chunks.append(chunk)
return b"".join(chunks)
class ArrayVariable(BaseVariable, Generic[T]):
type = 'array' type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]): def __init__(self, child_type: Type[T], value: list[Any]):
@@ -108,7 +125,7 @@ class ArrayObject(BaseVariable, Generic[T]):
return [v.get_value() for v in self.value] return [v.get_value() for v in self.value]
class NestedArrayObject(BaseVariable): class NestedArrayVariable(BaseVariable):
type = 'array_nest' type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]: def valid_value(self, value: list[T]) -> list[T]:
@@ -116,23 +133,23 @@ class NestedArrayObject(BaseVariable):
raise TypeError(f"Value must be a list - {type(value)}:{value}") raise TypeError(f"Value must be a list - {type(value)}:{value}")
final_value = [] final_value = []
for v in value: for v in value:
if not isinstance(v, ArrayObject): if not isinstance(v, list):
raise TypeError("All elements must be of type list") raise TypeError("All elements must be of type list")
final_value.append(v) final_value.append(make_array(AnyVariable, v))
return final_value return final_value
def to_literal(self) -> str: def to_literal(self) -> str:
return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value])
def get_value(self) -> Any: def get_value(self) -> Any:
return [[item.get_value() for item in row] for row in self.value] return [[item for item in row.get_value()] for row in self.value]
@deprecated( @deprecated(
reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.",
category=RuntimeWarning category=RuntimeWarning
) )
class AnyObject(BaseVariable): class AnyVariable(BaseVariable):
type = 'any' type = 'any'
def valid_value(self, value: Any) -> Any: def valid_value(self, value: Any) -> Any:
@@ -142,10 +159,10 @@ class AnyObject(BaseVariable):
return str(self.value) return str(self.value)
def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]:
"""简化 ArrayObject 创建,不需要重复写类型""" """简化 ArrayVariable 创建,不需要重复写类型"""
return ArrayObject(child_type, value) return ArrayVariable(child_type, value)
def create_variable_instance(var_type: VariableType, value: Any) -> T: def create_variable_instance(var_type: VariableType, value: Any) -> T:
@@ -168,7 +185,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T:
return make_array(DictVariable, value) return make_array(DictVariable, value)
case VariableType.ARRAY_FILE: case VariableType.ARRAY_FILE:
return make_array(FileVariable, value) return make_array(FileVariable, value)
case VariableType.NESTED_ARRAY:
return NestedArrayVariable(value)
case VariableType.ANY: case VariableType.ANY:
return AnyObject(value) return AnyVariable(value)
case _: case _:
raise TypeError(f"Invalid type - {var_type}") raise TypeError(f"Invalid type - {var_type}")

View File

@@ -580,6 +580,7 @@ class WorkflowService:
# "variables": result.get("variables"), # "variables": result.get("variables"),
# "messages": result.get("messages"), # "messages": result.get("messages"),
"output": result.get("output"), # 最终输出(字符串) "output": result.get("output"), # 最终输出(字符串)
"message": result.get("output"), # 最终输出(字符串)
# "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据)
"conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID "conversation_id": result.get("conversation_id"), # 所有节点输出详细数据payload., # 会话 ID
"error_message": result.get("error"), "error_message": result.get("error"),