Merge pull request #417 from SuanmoSuanyangTechnology/fix/workflow-adapter

fix(workflow): enhance Dify import types, templates and tool nodes
This commit is contained in:
Mark
2026-02-28 18:46:56 +08:00
committed by GitHub
8 changed files with 68 additions and 20 deletions

View File

@@ -68,6 +68,8 @@ class BasePlatformAdapter(ABC):
self.branch_node_cache = defaultdict(list)
self.error_branch_node_cache = []
self.node_output_map = {}
@abstractmethod
def get_metadata(self) -> PlatformMetadata:
"""get platform metadata"""

View File

@@ -44,6 +44,7 @@ class DifyConverter(BaseConverter):
warnings: list
branch_node_cache: dict
error_branch_node_cache: list
node_output_map: dict
def __init__(self):
self.CONFIG_CONVERT_MAP = {
@@ -60,7 +61,8 @@ class DifyConverter(BaseConverter):
"knowledge-retrieval": self.convert_knowledge_node_config,
"parameter-extractor": self.convert_parameter_extractor_node_config,
"question-classifier": self.convert_question_classifier_node_config,
"variable-aggregator": self.convert_variable_aggregator,
"variable-aggregator": self.convert_variable_aggregator_node_config,
"tool": self.convert_tool_node_config,
"loop-start": lambda x: {},
"iteration-start": lambda x: {},
"loop-end": lambda x: {},
@@ -74,8 +76,7 @@ class DifyConverter(BaseConverter):
def is_variable(expression) -> bool:
return bool(re.match(r"\{\{#(.*?)#}}", expression))
@staticmethod
def process_var_selector(var_selector):
def process_var_selector(self, var_selector):
if not var_selector:
return ""
selector = var_selector.split('.')
@@ -86,7 +87,7 @@ class DifyConverter(BaseConverter):
var_selector = ".".join(selector)
mapping = {
"sys.query": "sys.message"
}
} | self.node_output_map
var_selector = mapping.get(var_selector, var_selector)
return var_selector
@@ -124,6 +125,8 @@ class DifyConverter(BaseConverter):
"checkbox": VariableType.BOOLEAN,
"file-list": VariableType.ARRAY_FILE,
"select": VariableType.STRING,
"integer": VariableType.NUMBER,
"float": VariableType.NUMBER,
}
var_type = type_map.get(source_type, source_type)
return var_type
@@ -160,6 +163,8 @@ class DifyConverter(BaseConverter):
"": ComparisonOperator.GE,
"": ComparisonOperator.LE,
"not empty": ComparisonOperator.NOT_EMPTY,
"start with": ComparisonOperator.START_WITH,
"end with": ComparisonOperator.END_WITH,
}
return operator_map.get(operator, operator)
@@ -633,7 +638,7 @@ class DifyConverter(BaseConverter):
prompt=node_data["instruction"]
).model_dump()
def convert_variable_aggregator(self, node: dict) -> dict:
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
node_data = node["data"]
group_enable = node_data["advanced_settings"]["group_enabled"]
group_variables = {}
@@ -657,3 +662,13 @@ class DifyConverter(BaseConverter):
group_variables=group_variables,
group_type=group_type,
).model_dump()
def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"]
self.warnings.append(ExceptionDefineition(
node_id=node["id"],
node_name=node_data["title"],
type=ExceptionType.CONFIG,
detail=f"Please reconfigure the tool node.",
))
return {}

View File

@@ -43,7 +43,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
"question-classifier": NodeType.QUESTION_CLASSIFIER,
"variable-aggregator": NodeType.VAR_AGGREGATOR
"variable-aggregator": NodeType.VAR_AGGREGATOR,
"tool": NodeType.TOOL
}
def __init__(self, config: dict[str, Any]):
@@ -89,6 +90,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return True
def parse_workflow(self) -> WorkflowParserResult:
self._init_node_output_map()
for node in self.origin_nodes:
node = self._convert_node(node)
if node:
@@ -128,6 +130,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
errors=self.errors
)
def _init_node_output_map(self):
for node in self.origin_nodes:
if self.map_node_type(node["data"]["type"]) == NodeType.LLM:
self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output"
def _convert_cycle_node_position(self, node_id: str, position: dict):
for node in self.origin_nodes:
if node["id"] == node_id:
@@ -214,6 +221,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
type=ExceptionType.EDGE,
detail=f"convert edge error - {e}",
))
logger.debug(f"convert edge error - {e}", exc_info=True)
return None
def _convert_variable(self, variable) -> VariableDefinition | None:
@@ -221,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
return VariableDefinition(
name=variable["name"],
default=variable["value"],
type=variable["value_type"],
type=self.variable_type_map(variable["value_type"]),
)
except Exception as e:
self.errors.append(ExceptionDefineition(

View File

@@ -175,7 +175,7 @@ class WorkflowExecutor:
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
@@ -322,7 +322,7 @@ class WorkflowExecutor:
)
logger.info(
f"Workflow execution completed (streaming), "
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
)
yield {

View File

@@ -196,7 +196,7 @@ class BaseNode(ABC):
timeout=timeout
)
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
# Extract processed outputs using subclass-defined logic.
extracted_output = self._extract_output(business_result)
@@ -219,7 +219,7 @@ class BaseNode(ABC):
} | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution timed out ({timeout} seconds)."
)
@@ -230,7 +230,7 @@ class BaseNode(ABC):
variable_pool,
)
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(
f"Node {self.node_id} execution failed: {e}",
exc_info=True,
@@ -307,10 +307,10 @@ class BaseNode(ABC):
"done": done
})
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result)
@@ -337,7 +337,7 @@ class BaseNode(ABC):
yield state_update | self.trans_activate(state)
except TimeoutError:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
error_output = self._wrap_error(
f"Node execution timed out ({timeout}s)",
@@ -347,7 +347,7 @@ class BaseNode(ABC):
)
yield error_output
except Exception as e:
elapsed_time = time.time() - start_time
elapsed_time = (time.time() - start_time) * 1000
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
yield error_output

View File

@@ -12,9 +12,20 @@ class ExpressionEvaluator:
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@staticmethod
@classmethod
def normalize_template(cls, template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}',
template
)
@classmethod
def evaluate(
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
@@ -37,6 +48,7 @@ class ExpressionEvaluator:
"""
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()

View File

@@ -5,6 +5,7 @@
"""
import logging
import re
from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
@@ -39,6 +40,16 @@ class TemplateRenderer:
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
)
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}',
template
)
def render(
self,
template: str,
@@ -95,7 +106,7 @@ class TemplateRenderer:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)

View File

@@ -68,7 +68,7 @@ class WorkflowImportSave(BaseModel):
"""工作流导入请求"""
temp_id: str
name: str
description: str
description: str | None = Field(default=None)
# ==================== 工作流配置 ====================