Merge pull request #417 from SuanmoSuanyangTechnology/fix/workflow-adapter
fix(workflow): enhance Dify import types, templates and tool nodes
This commit is contained in:
@@ -68,6 +68,8 @@ class BasePlatformAdapter(ABC):
|
||||
self.branch_node_cache = defaultdict(list)
|
||||
self.error_branch_node_cache = []
|
||||
|
||||
self.node_output_map = {}
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
"""get platform metadata"""
|
||||
|
||||
@@ -44,6 +44,7 @@ class DifyConverter(BaseConverter):
|
||||
warnings: list
|
||||
branch_node_cache: dict
|
||||
error_branch_node_cache: list
|
||||
node_output_map: dict
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
@@ -60,7 +61,8 @@ class DifyConverter(BaseConverter):
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator,
|
||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
||||
"tool": self.convert_tool_node_config,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
@@ -74,8 +76,7 @@ class DifyConverter(BaseConverter):
|
||||
def is_variable(expression) -> bool:
|
||||
return bool(re.match(r"\{\{#(.*?)#}}", expression))
|
||||
|
||||
@staticmethod
|
||||
def process_var_selector(var_selector):
|
||||
def process_var_selector(self, var_selector):
|
||||
if not var_selector:
|
||||
return ""
|
||||
selector = var_selector.split('.')
|
||||
@@ -86,7 +87,7 @@ class DifyConverter(BaseConverter):
|
||||
var_selector = ".".join(selector)
|
||||
mapping = {
|
||||
"sys.query": "sys.message"
|
||||
}
|
||||
} | self.node_output_map
|
||||
|
||||
var_selector = mapping.get(var_selector, var_selector)
|
||||
return var_selector
|
||||
@@ -124,6 +125,8 @@ class DifyConverter(BaseConverter):
|
||||
"checkbox": VariableType.BOOLEAN,
|
||||
"file-list": VariableType.ARRAY_FILE,
|
||||
"select": VariableType.STRING,
|
||||
"integer": VariableType.NUMBER,
|
||||
"float": VariableType.NUMBER,
|
||||
}
|
||||
var_type = type_map.get(source_type, source_type)
|
||||
return var_type
|
||||
@@ -160,6 +163,8 @@ class DifyConverter(BaseConverter):
|
||||
"≥": ComparisonOperator.GE,
|
||||
"≤": ComparisonOperator.LE,
|
||||
"not empty": ComparisonOperator.NOT_EMPTY,
|
||||
"start with": ComparisonOperator.START_WITH,
|
||||
"end with": ComparisonOperator.END_WITH,
|
||||
}
|
||||
return operator_map.get(operator, operator)
|
||||
|
||||
@@ -633,7 +638,7 @@ class DifyConverter(BaseConverter):
|
||||
prompt=node_data["instruction"]
|
||||
).model_dump()
|
||||
|
||||
def convert_variable_aggregator(self, node: dict) -> dict:
|
||||
def convert_variable_aggregator_node_config(self, node: dict) -> dict:
|
||||
node_data = node["data"]
|
||||
group_enable = node_data["advanced_settings"]["group_enabled"]
|
||||
group_variables = {}
|
||||
@@ -657,3 +662,13 @@ class DifyConverter(BaseConverter):
|
||||
group_variables=group_variables,
|
||||
group_type=group_type,
|
||||
).model_dump()
|
||||
|
||||
def convert_tool_node_config(self, node: dict) -> dict:
    """Convert a Dify tool node's configuration.

    Tool nodes cannot be migrated automatically, so this records a
    configuration warning identifying the node and returns an empty
    config, leaving the user to reconfigure the tool manually.

    Args:
        node: Raw Dify node dict; must contain ``id`` and ``data.title``.

    Returns:
        An empty dict — no config is carried over for tool nodes.
    """
    node_data = node["data"]
    # NOTE(review): `ExceptionDefineition` is the project's (misspelled)
    # warning record type, defined elsewhere — name kept for compatibility.
    self.warnings.append(ExceptionDefineition(
        node_id=node["id"],
        node_name=node_data["title"],
        type=ExceptionType.CONFIG,
        # Plain string: the original was an f-string with no placeholders (F541).
        detail="Please reconfigure the tool node.",
    ))
    return {}
|
||||
@@ -43,7 +43,8 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
"knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL,
|
||||
"parameter-extractor": NodeType.PARAMETER_EXTRACTOR,
|
||||
"question-classifier": NodeType.QUESTION_CLASSIFIER,
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR
|
||||
"variable-aggregator": NodeType.VAR_AGGREGATOR,
|
||||
"tool": NodeType.TOOL
|
||||
}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
@@ -89,6 +90,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
return True
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self._init_node_output_map()
|
||||
for node in self.origin_nodes:
|
||||
node = self._convert_node(node)
|
||||
if node:
|
||||
@@ -128,6 +130,11 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
errors=self.errors
|
||||
)
|
||||
|
||||
def _init_node_output_map(self):
    """Pre-populate the output-key remap for LLM nodes.

    Dify LLM nodes expose their result under ``<id>.text``; the target
    platform expects ``<id>.output``, so each LLM node gets one entry
    in ``self.node_output_map`` translating the former to the latter.
    """
    llm_nodes = (
        origin
        for origin in self.origin_nodes
        if self.map_node_type(origin["data"]["type"]) == NodeType.LLM
    )
    for origin in llm_nodes:
        node_id = origin["id"]
        self.node_output_map[f"{node_id}.text"] = f"{node_id}.output"
|
||||
|
||||
def _convert_cycle_node_position(self, node_id: str, position: dict):
|
||||
for node in self.origin_nodes:
|
||||
if node["id"] == node_id:
|
||||
@@ -214,6 +221,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}",
|
||||
))
|
||||
logger.debug(f"convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable) -> VariableDefinition | None:
|
||||
@@ -221,7 +229,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
return VariableDefinition(
|
||||
name=variable["name"],
|
||||
default=variable["value"],
|
||||
type=variable["value_type"],
|
||||
type=self.variable_type_map(variable["value_type"]),
|
||||
)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
|
||||
@@ -175,7 +175,7 @@ class WorkflowExecutor:
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s")
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
@@ -322,7 +322,7 @@ class WorkflowExecutor:
|
||||
)
|
||||
logger.info(
|
||||
f"Workflow execution completed (streaming), "
|
||||
f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}"
|
||||
f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}"
|
||||
)
|
||||
|
||||
yield {
|
||||
|
||||
@@ -196,7 +196,7 @@ class BaseNode(ABC):
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Extract processed outputs using subclass-defined logic.
|
||||
extracted_output = self._extract_output(business_result)
|
||||
@@ -219,7 +219,7 @@ class BaseNode(ABC):
|
||||
} | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution timed out ({timeout} seconds)."
|
||||
)
|
||||
@@ -230,7 +230,7 @@ class BaseNode(ABC):
|
||||
variable_pool,
|
||||
)
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(
|
||||
f"Node {self.node_id} execution failed: {e}",
|
||||
exc_info=True,
|
||||
@@ -307,10 +307,10 @@ class BaseNode(ABC):
|
||||
"done": done
|
||||
})
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
|
||||
logger.info(f"Node {self.node_id} streaming execution finished, "
|
||||
f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}")
|
||||
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
|
||||
|
||||
# Extract processed output (call subclass's _extract_output)
|
||||
extracted_output = self._extract_output(final_result)
|
||||
@@ -337,7 +337,7 @@ class BaseNode(ABC):
|
||||
yield state_update | self.trans_activate(state)
|
||||
|
||||
except TimeoutError:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(f"Node {self.node_id} execution timed out ({timeout}s)")
|
||||
error_output = self._wrap_error(
|
||||
f"Node execution timed out ({timeout}s)",
|
||||
@@ -347,7 +347,7 @@ class BaseNode(ABC):
|
||||
)
|
||||
yield error_output
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
elapsed_time = (time.time() - start_time) * 1000
|
||||
logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True)
|
||||
error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool)
|
||||
yield error_output
|
||||
|
||||
@@ -12,9 +12,20 @@ class ExpressionEvaluator:
|
||||
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@staticmethod
|
||||
|
||||
@classmethod
def normalize_template(cls, template: str) -> str:
    """Rewrite legacy ``{{ <id>.<field> }}`` refs to ``{{ node["<id>"].<field> }}``.

    Numeric-id node references from the old expression syntax are turned
    into the dict-access form the evaluator understands; any other text
    (including non-numeric ``{{ … }}`` expressions) is left untouched.
    """
    legacy_ref = r"\{\{\s*(\d+)\.(\w+)\s*}}"
    dict_access = r'{{ node["\1"].\2 }}'
    # re.sub caches compiled patterns, so behavior matches the explicit
    # re.compile(...).sub(...) form.
    return re.sub(legacy_ref, dict_access, template)
|
||||
|
||||
@classmethod
|
||||
def evaluate(
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
@@ -37,6 +48,7 @@ class ExpressionEvaluator:
|
||||
"""
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
expression = cls.normalize_template(expression)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
@@ -39,6 +40,16 @@ class TemplateRenderer:
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
)
|
||||
|
||||
@staticmethod
def normalize_template(template: str) -> str:
    """Rewrite legacy ``{{ <id>.<field> }}`` refs to ``{{ node["<id>"].<field> }}``.

    Converts old-style numeric node references into the dict-access form
    Jinja2 can render; all other template text passes through unchanged.
    """
    # re.sub caches compiled patterns, so this is equivalent to the
    # explicit re.compile(...).sub(...) form.
    return re.sub(
        r"\{\{\s*(\d+)\.(\w+)\s*}}",
        r'{{ node["\1"].\2 }}',
        template,
    )
|
||||
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
@@ -95,7 +106,7 @@ class TemplateRenderer:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
template = self.normalize_template(template)
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
@@ -68,7 +68,7 @@ class WorkflowImportSave(BaseModel):
|
||||
"""工作流导入请求"""
|
||||
temp_id: str
|
||||
name: str
|
||||
description: str
|
||||
description: str | None = Field(default=None)
|
||||
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
|
||||
Reference in New Issue
Block a user