fix(workflow): fix exceptions when importing configs from Dify

This commit is contained in:
Eternity
2026-02-28 17:27:07 +08:00
parent 5e512df3d4
commit 54700e6fbe
8 changed files with 68 additions and 20 deletions

View File

@@ -68,6 +68,8 @@ class BasePlatformAdapter(ABC):
self.branch_node_cache = defaultdict(list) self.branch_node_cache = defaultdict(list)
self.error_branch_node_cache = [] self.error_branch_node_cache = []
self.node_output_map = {}
@abstractmethod @abstractmethod
def get_metadata(self) -> PlatformMetadata: def get_metadata(self) -> PlatformMetadata:
"""get platform metadata""" """get platform metadata"""

View File

@@ -44,6 +44,7 @@ class DifyConverter(BaseConverter):
warnings: list warnings: list
branch_node_cache: dict branch_node_cache: dict
error_branch_node_cache: list error_branch_node_cache: list
node_output_map: dict
def __init__(self): def __init__(self):
self.CONFIG_CONVERT_MAP = { self.CONFIG_CONVERT_MAP = {
@@ -60,7 +61,8 @@ class DifyConverter(BaseConverter):
"knowledge-retrieval": self.convert_knowledge_node_config, "knowledge-retrieval": self.convert_knowledge_node_config,
"parameter-extractor": self.convert_parameter_extractor_node_config, "parameter-extractor": self.convert_parameter_extractor_node_config,
"question-classifier": self.convert_question_classifier_node_config, "question-classifier": self.convert_question_classifier_node_config,
"variable-aggregator": self.convert_variable_aggregator, "variable-aggregator": self.convert_variable_aggregator_node_config,
"tool": self.convert_tool_node_config,
"loop-start": lambda x: {}, "loop-start": lambda x: {},
"iteration-start": lambda x: {}, "iteration-start": lambda x: {},
"loop-end": lambda x: {}, "loop-end": lambda x: {},
@@ -74,8 +76,7 @@ class DifyConverter(BaseConverter):
def is_variable(expression) -> bool: def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression)) return bool(re.match(r"\{\{#(.*?)#}}", expression))
@staticmethod def process_var_selector(self, var_selector):
def process_var_selector(var_selector):
if not var_selector: if not var_selector:
return "" return ""
selector = var_selector.split('.') selector = var_selector.split('.')
@@ -86,7 +87,7 @@ class DifyConverter(BaseConverter):
var_selector = ".".join(selector) var_selector = ".".join(selector)
mapping = { mapping = {
"sys.query": "sys.message" "sys.query": "sys.message"
} } | self.node_output_map
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
@@ -124,6 +125,8 @@ class DifyConverter(BaseConverter):
"checkbox": VariableType.BOOLEAN, "checkbox": VariableType.BOOLEAN,
"file-list": VariableType.ARRAY_FILE, "file-list": VariableType.ARRAY_FILE,
"select": VariableType.STRING, "select": VariableType.STRING,
"integer": VariableType.NUMBER,
"float": VariableType.NUMBER,
} }
var_type = type_map.get(source_type, source_type) var_type = type_map.get(source_type, source_type)
return var_type return var_type
@@ -160,6 +163,8 @@ class DifyConverter(BaseConverter):
"": ComparisonOperator.GE, "": ComparisonOperator.GE,
"": ComparisonOperator.LE, "": ComparisonOperator.LE,
"not empty": ComparisonOperator.NOT_EMPTY, "not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH,
} }
return operator_map.get(operator, operator) return operator_map.get(operator, operator)
@@ -633,7 +638,7 @@ class DifyConverter(BaseConverter):
prompt=node_data["instruction"] prompt=node_data["instruction"]
).model_dump() ).model_dump()
def convert_variable_aggregator(self, node: dict) -> dict: def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"] group_enable = node_data["advanced_settings"]["group_enabled"]
group_variables = {} group_variables = {}
@@ -657,3 +662,13 @@ class DifyConverter(BaseConverter):
group_variables=group_variables, group_variables=group_variables,
group_type=group_type, group_type=group_type,
).model_dump() ).model_dump()
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.",
))
return {}

View File

@@ -43,7 +43,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL, "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR, "parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER, "question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR "variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL
} }
def __init__(self, config: dict[str, Any]): def __init__(self, config: dict[str, Any]):
@@ -89,6 +90,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True return True
def parse_workflow(self) -> WorkflowParserResult: def parse_workflow(self) -> WorkflowParserResult:
self._init_node_output_map()
for node in self.origin_nodes: for node in self.origin_nodes:
node = self._convert_node(node) node = self._convert_node(node)
if node: if node:
@@ -128,6 +130,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
errors=self.errors errors=self.errors
) )
def _init_node_output_map(self):
for node in self.origin_nodes:
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
def _convert_cycle_node_position(self, node_id: str, position: dict): def _convert_cycle_node_position(self, node_id: str, position: dict):
for node in self.origin_nodes: for node in self.origin_nodes:
if node["id"] == node_id: if node["id"] == node_id:
@@ -214,6 +221,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}", detail=f"convert edge error - {e}",
)) ))
logger.debug(f"convert edge error - {e}", exc_info=True)
return None return None
def _convert_variable(self, variable) -> VariableDefinition | None: def _convert_variable(self, variable) -> VariableDefinition | None:
@@ -221,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return VariableDefinition( return VariableDefinition(
name=variable["name"], name=variable["name"],
default=variable["value"], default=variable["value"],
type=variable["value_type"], type=self.variable_type_map(variable["value_type"]),
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefineition( self.errors.append(ExceptionDefineition(

View File

@@ -175,7 +175,7 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds() elapsed_time = (end_time - start_time).total_seconds()
logger.info( logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s") f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
@@ -322,7 +322,7 @@ class WorkflowExecutor:
) )
logger.info( logger.info(
f"Workflow execution completed (streaming), " f"Workflow execution completed (streaming), "
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}" f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
) )
yield { yield {

View File

@@ -196,7 +196,7 @@ class BaseNode(ABC):
timeout=timeout timeout=timeout
) )
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
# Extract processed outputs using subclass-defined logic. # Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result) extracted_output = self._extract_output(business_result)
@@ -219,7 +219,7 @@ class BaseNode(ABC):
} | self.trans_activate(state) } | self.trans_activate(state)
except TimeoutError: except TimeoutError:
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
logger.error( logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)." f"Node {self.node_id} execution timed out ({timeout} seconds)."
) )
@@ -230,7 +230,7 @@ class BaseNode(ABC):
variable_pool, variable_pool,
) )
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
logger.error( logger.error(
f"Node {self.node_id} execution failed: {e}", f"Node {self.node_id} execution failed: {e}",
exc_info=True, exc_info=True,
@@ -307,10 +307,10 @@ class BaseNode(ABC):
"done": done "done": done
}) })
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, " logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}") f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
@@ -337,7 +337,7 @@ class BaseNode(ABC):
yield state_update | self.trans_activate(state) yield state_update | self.trans_activate(state)
except TimeoutError: except TimeoutError:
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)") logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error( error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)", f"Node execution timed out ({timeout}s)",
@@ -347,7 +347,7 @@ class BaseNode(ABC):
) )
yield error_output yield error_output
except Exception as e: except Exception as e:
elapsed_time = time.time() - start_time elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True) logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool) error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output yield error_output

View File

@@ -12,9 +12,20 @@ class ExpressionEvaluator:
# Reserved namespaces # Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod @classmethod
def normalize_template(cls, template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}',
template
)
@classmethod
def evaluate( def evaluate(
cls,
expression: str, expression: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
@@ -37,6 +48,7 @@ class ExpressionEvaluator:
""" """
# Remove Jinja2-style brackets if present # Remove Jinja2-style brackets if present
expression = expression.strip() expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}" pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip() expression = re.sub(pattern, r"\1", expression).strip()

View File

@@ -5,6 +5,7 @@
""" """
import logging import logging
import re
from typing import Any from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -39,6 +40,16 @@ class TemplateRenderer:
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
) )
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}',
template
)
def render( def render(
self, self,
template: str, template: str,
@@ -95,7 +106,7 @@ class TemplateRenderer:
context.update(conv_vars) context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容 context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try: try:
tmpl = self.env.from_string(template) tmpl = self.env.from_string(template)
return tmpl.render(**context) return tmpl.render(**context)

View File

@@ -68,7 +68,7 @@ class WorkflowImportSave(BaseModel):
"""工作流导入请求""" """工作流导入请求"""
temp_id: str temp_id: str
name: str name: str
description: str description: str | None = Field(default=None)
# ==================== 工作流配置 ==================== # ==================== 工作流配置 ====================