feat(workflow): add placeholder node for unknown types

This commit is contained in:
Eternity
2026-03-02 16:05:25 +08:00
parent 5cf2b08777
commit 574ab4506b
6 changed files with 163 additions and 86 deletions

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModel
ExceptionType
from app.core.workflow.nodes.assigner import AssignerNodeConfig
from app.core.workflow.nodes.assigner.config import AssignmentItem
from app.core.workflow.nodes.base_config import VariableDefinition
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
from app.core.workflow.nodes.code import CodeNodeConfig
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
@@ -69,9 +69,27 @@ class DifyConverter(BaseConverter):
}
def get_node_convert(self, node_type):
func = self.CONFIG_CONVERT_MAP.get(node_type, None)
func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {})
return func
def config_validate(
self,
node_id: str,
node_name: str,
config: type[BaseNodeConfig],
value: dict
):
try:
return config.model_validate(value)
except Exception as e:
self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG,
node_id=node_id,
node_name=node_name,
detail=str(e)
))
return None
@staticmethod
def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression))
@@ -80,14 +98,16 @@ class DifyConverter(BaseConverter):
if not var_selector:
return ""
selector = var_selector.split('.')
if len(selector) != 2:
if len(selector) not in [2, 3]:
raise Exception(f"invalid variable selector: {var_selector}")
if len(selector) == 3:
selector = selector[1:]
if selector[0] == "conversation":
selector[0] = "conv"
var_selector = ".".join(selector)
mapping = {
"sys.query": "sys.message"
} | self.node_output_map
"sys.query": "sys.message"
} | self.node_output_map
var_selector = mapping.get(var_selector, var_selector)
return var_selector
@@ -237,7 +257,7 @@ class DifyConverter(BaseConverter):
node_id=node["id"],
node_name=node_data["title"],
name=var["variable"],
detail=f"Unsupport Variable type for start node: {var_type}"
detail=f"Unsupported Variable type for start node: {var_type}"
)
)
continue
@@ -253,9 +273,11 @@ class DifyConverter(BaseConverter):
max_length=var.get("max_length"),
)
start_vars.append(var_def)
return StartNodeConfig(
result = StartNodeConfig.model_construct(
variables=start_vars
).model_dump()
self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result)
return result
def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -270,16 +292,18 @@ class DifyConverter(BaseConverter):
for category in node_data["classes"]:
self.branch_node_cache[node["id"]].append(category["id"])
categories.append(
ClassifierConfig(
ClassifierConfig.model_construct(
class_name=category["name"],
)
)
return QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data["query_variable_selector"]),
user_supplement_prompt=self.trans_variable_format(node_data["instructions"]),
result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result)
return result
def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -315,7 +339,7 @@ class DifyConverter(BaseConverter):
vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"]
) if vision else None
return LLMNodeConfig.model_construct(
result = LLMNodeConfig.model_construct(
model_id=None,
context=context,
memory=memory,
@@ -323,12 +347,16 @@ class DifyConverter(BaseConverter):
vision_input=vision_input,
messages=messages
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result)
return result
def convert_end_node_config(self, node: dict) -> dict:
node_data = node["data"]
return EndNodeConfig(
output=self.trans_variable_format(node_data["answer"]),
result = EndNodeConfig.model_construct(
output=self.trans_variable_format(node_data.get("answer", "")),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result)
return result
def convert_if_else_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -359,9 +387,11 @@ class DifyConverter(BaseConverter):
)
)
self.branch_node_cache[node["id"]].append(case_id)
return IfElseNodeConfig(
result = IfElseNodeConfig.model_construct(
cases=cases
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result)
return result
def convert_loop_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -370,7 +400,7 @@ class DifyConverter(BaseConverter):
for condition in node_data["break_conditions"]:
right_value = condition["value"]
conditions.append(
LoopConditionDetail(
LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format(
@@ -383,7 +413,7 @@ class DifyConverter(BaseConverter):
if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT,
)
)
condition_config = ConditionsConfig(
condition_config = ConditionsConfig.model_construct(
logical_operator=logical_operator,
expressions=conditions
)
@@ -392,9 +422,9 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_litearl(variable["value"])
right_value = self._process_list_variable_litearl(variable.get("value", ""))
else:
right_value = self.convert_variable_type(right_value_type, variable["value"])
right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append(
CycleVariable(
name=variable["label"],
@@ -403,23 +433,28 @@ class DifyConverter(BaseConverter):
input_type=right_input_type
)
)
return LoopNodeConfig(
result = LoopNodeConfig.model_construct(
condition=condition_config,
cycle_vars=loop_variables,
max_loop=node_data["loop_count"]
max_loop=node_data.get("loop_count", 10)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result)
return result
def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"]
return IterationNodeConfig(
result = IterationNodeConfig.model_construct(
input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data["output_type"]),
output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"],
).model_dump()
self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result)
return result
def convert_assigner_node_config(self, node: dict) -> dict:
node_data = node["data"]
assignments = []
@@ -435,16 +470,18 @@ class DifyConverter(BaseConverter):
operation=self.convert_assignment_operator(assignment["operation"])
)
)
return AssignerNodeConfig(
result = AssignerNodeConfig.model_construct(
assignments=assignments
).model_dump()
self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result)
return result
def convert_code_node_config(self, node: dict) -> dict:
node_data = node["data"]
input_variables = []
for input_variable in node_data["variables"]:
input_variables.append(
InputVariable(
InputVariable.model_construct(
name=input_variable["variable"],
variable=self._process_list_variable_litearl(input_variable["value_selector"]),
)
@@ -453,7 +490,7 @@ class DifyConverter(BaseConverter):
output_variables = []
for output_variable in node_data["outputs"]:
output_variables.append(
OutputVariable(
OutputVariable.model_construct(
name=output_variable,
type=node_data["outputs"][output_variable]["type"],
)
@@ -461,18 +498,20 @@ class DifyConverter(BaseConverter):
code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8")
return CodeNodeConfig(
result = CodeNodeConfig.model_construct(
input_variables=input_variables,
language=node_data["code_language"],
output_variables=output_variables,
code=code
).model_dump()
self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result)
return result
def convert_http_node_config(self, node: dict) -> dict:
node_data = node["data"]
if node_data["authorization"] != 'no-auth':
if node_data["authorization"]["type"] != 'no-auth':
auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"])
auth_config = HttpAuthConfig(
auth_config = HttpAuthConfig.model_construct(
auth_type=auth_type,
header=node_data["authorization"]["config"].get("header"),
api_key=node_data["authorization"]["config"].get("api_key"),
@@ -504,7 +543,7 @@ class DifyConverter(BaseConverter):
body_content = ""
headers = {}
for header in node_data["headers"].split("\n"):
for header in node_data.get("headers", "").split("\n"):
if not header:
continue
@@ -522,7 +561,7 @@ class DifyConverter(BaseConverter):
))
params = {}
for param in node_data["params"].split("\n"):
for param in node_data.get("params", "").split("\n"):
if not param:
continue
@@ -547,7 +586,7 @@ class DifyConverter(BaseConverter):
default_body = ""
default_header = {}
default_status_code = 0
for var in node_data["default_value"]:
for var in node_data.get("default_value") or []:
if var["key"] == "body":
default_body = var["value"]
elif var["key"] == "header":
@@ -561,45 +600,50 @@ class DifyConverter(BaseConverter):
)
self.error_branch_node_cache.append(node['id'])
return HttpRequestNodeConfig(
result = HttpRequestNodeConfig.model_construct(
method=node_data["method"].upper(),
url=node_data["url"],
auth=auth_config,
body=HttpContentTypeConfig(
body=HttpContentTypeConfig.model_construct(
content_type=self.convert_http_content_type(node_data["body"]["type"]),
data=body_content,
),
headers=headers,
params=params,
verify_ssl=node_data["ssl_verify"],
timeouts=HttpTimeOutConfig(
timeouts=HttpTimeOutConfig.model_construct(
connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5,
read_timeout=node_data["timeout"]["max_read_timeout"] or 5,
write_timeout=node_data["timeout"]["max_write_timeout"] or 5,
),
retry=HttpRetryConfig(
retry=HttpRetryConfig.model_construct(
enable=node_data["retry_config"]["retry_enabled"],
max_attempts=node_data["retry_config"]["max_retries"],
retry_interval=node_data["retry_config"]["retry_interval"],
),
error_handle=HttpErrorHandleConfig(
error_handle=HttpErrorHandleConfig.model_construct(
method=error_handle_type,
default=default_value,
)
).model_dump()
self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result)
return result
def convert_jinja_render_node_config(self, node: dict) -> dict:
node_data = node["data"]
mapping = []
for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig(
mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"],
value=self._process_list_variable_litearl(variable["value_selector"])
))
return JinjaRenderNodeConfig(
result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"],
mapping=mapping,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result)
return result
def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"]
@@ -609,10 +653,13 @@ class DifyConverter(BaseConverter):
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.",
))
return KnowledgeRetrievalNodeConfig.model_construct(
result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
return result
def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(
@@ -623,46 +670,53 @@ class DifyConverter(BaseConverter):
)
)
params = []
for param in node_data["parameters"]:
for param in node_data.get("parameters", []):
params.append(
ParamsConfig(
ParamsConfig.model_construct(
name=param["name"],
desc=param["description"],
required=param["required"],
type=param["type"],
)
)
return ParameterExtractorNodeConfig.model_construct(
result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_litearl(node_data["query"]),
params=params,
prompt=node_data["instruction"]
prompt=node_data.get("instruction")
).model_dump()
self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result)
return result
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"]
advanced_settings = node_data.get("advanced_settings", {})
group_variables = {}
group_type = {}
if not group_enable:
if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables["output"] = [
self._process_list_variable_litearl(variable)
for variable in node_data["variables"]
]
group_type["output"] = node_data["output_type"]
else:
for group in node_data["advanced_settings"]["groups"]:
for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [
self._process_list_variable_litearl(variable)
for variable in group["variables"]
]
group_type[group["group_name"]] = group["output_type"]
return VariableAggregatorNodeConfig(
group=group_enable,
result = VariableAggregatorNodeConfig.model_construct(
group=advanced_settings.get("group_enabled", False),
group_variables=group_variables,
group_type=group_type,
).model_dump()
self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result)
return result
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(

View File

@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)
def map_node_type(self, platform_node_type) -> str:
return self.NODE_TYPE_MAPPING.get(platform_node_type)
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
@property
def origin_nodes(self):
@@ -179,8 +179,13 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_type = node_data["type"]
try:
converter = self.get_node_convert(node_type)
if converter is None:
raise Exception(f"node type not supported - {node_type}")
if node_type not in self.CONFIG_CONVERT_MAP:
self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE,
node_id=node["id"],
node_name=node["data"]["title"],
detail=f"node type {node_type} is unsupported",
))
return converter(node)
except Exception as e:
self.errors.append(ExceptionDefineition(

View File

@@ -158,6 +158,22 @@ class VariablePool:
default: Any = None,
strict: bool = True
):
"""Retrieve a variable instance from the variable pool.
Args:
selector:
Variable selector as a string variable literal (e.g. "{{ sys.message }}").
default:
The value to return if the variable does not exist.
strict:
If True, raises KeyError when the variable does not exist.
Returns:
The variable instance object if it exists; otherwise returns `default`.
Raises:
KeyError: If strict is True and the variable does not exist.
"""
variable_struct = self._get_variable_struct(selector)
if variable_struct is None:
if strict:

View File

@@ -132,24 +132,24 @@ class WorkflowExecutor:
start_time = datetime.datetime.now()
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
# Execute the workflow
try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
@@ -231,23 +231,23 @@ class WorkflowExecutor:
}
}
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
# Initialize the variable pool and system variables
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
full_content = ''
self.stream_coordinator.update_scope_activation("sys")

View File

@@ -24,6 +24,8 @@ class NodeType(StrEnum):
MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write"
UNKNOWN = "unknown"
BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]

View File

@@ -123,10 +123,10 @@ class NodeFactory:
# 获取节点类
node_class = cls._node_types.get(node_type)
if not node_class:
raise ValueError(f"不支持的节点类型: {node_type}")
raise ValueError(f"Unsupported node type: {node_type}")
# 创建节点实例
logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})")
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config)
@classmethod