From 7fd00009a21ed3334f6c72b9068fd3a018ffd637 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 12:00:30 +0800 Subject: [PATCH] perf(workflow): introduce LazyDict to reduce variable serialization, optimize regex to reduce compilation - Use LazyDict for deferred serialization, improving performance - Reuse regex patterns to avoid repeated compilation --- api/app/core/workflow/engine/state_manager.py | 8 +- api/app/core/workflow/engine/variable_pool.py | 60 ++++++- api/app/core/workflow/nodes/base_node.py | 12 +- .../core/workflow/nodes/cycle_graph/loop.py | 15 +- .../workflow/utils/expression_evaluator.py | 78 ++++----- .../core/workflow/utils/template_renderer.py | 164 +++++++++--------- 6 files changed, 188 insertions(+), 149 deletions(-) diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 2da0d3a8..eed44278 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType def merge_activate_state(x, y): - return { - k: x.get(k, False) or y.get(k, False) - for k in set(x) | set(y) - } + merged = dict(x) + for k, v in y.items(): + merged[k] = merged.get(k, False) or v + return merged def merge_looping_state(x, y): diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 60f1257e..7faca82d 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta logger = logging.getLogger(__name__) +VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}") + + +class LazyVariableDict: + def __init__(self, source, literal): + self._source: dict[str, VariableStruct[Any]] = source + self._literal: bool = literal + self._cache = {} + + def keys(self): + return self._source.keys() + + def _resolve(self, key): + if key in self._cache: + return self._cache[key] + var_struct = self._source.get(key) + if var_struct is None: + raise KeyError(key) + value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value() + self._cache[key] = value + return value + + def get(self, key, default=None): + try: + return self._resolve(key) + except KeyError: + return default + + def __getitem__(self, key): + return self._resolve(key) + + def __getattr__(self, key): + if key.startswith('_'): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + return self._resolve(key) + + def __contains__(self, key): + return key in self._source + + def __iter__(self): + return iter(self._source) + + def __len__(self): + return len(self._source) + class VariableSelector: """变量选择器 @@ -117,8 +162,7 @@ class VariablePool: @staticmethod def transform_selector(selector): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() + variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip() selector = VariableSelector.from_string(variable_literal).path if len(selector) != 2: raise ValueError(f"Selector not valid - {selector}") @@ -303,6 +347,16 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None + def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict: + return LazyVariableDict(self.variables.get(namespace, {}), literal) + + def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]: + return { + ns: LazyVariableDict(vars_dict, literal) + for ns, vars_dict in self.variables.items() + if ns not in ("sys", "conv") + } + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 @@ -479,5 +533,3 @@ class VariablePoolInitializer: var_type=var_type, mut=False ) - - diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8567ebbe..bedf6165 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -552,9 +552,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(literal=True), - node_outputs=variable_pool.get_all_node_outputs(literal=True), - system_vars=variable_pool.get_all_system_vars(literal=True), + conv_vars=variable_pool.lazy_namespace("conv", literal=True), + node_outputs=variable_pool.lazy_all_node_outputs(literal=True), + system_vars=variable_pool.lazy_namespace("sys", literal=True), strict=strict ) @@ -579,9 +579,9 @@ class BaseNode(ABC): return evaluate_condition( expression=expression, - conv_var=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars() + conv_var=variable_pool.lazy_namespace("conv"), + node_outputs=variable_pool.lazy_all_node_outputs(), + system_vars=variable_pool.lazy_namespace("sys") ) @staticmethod diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 84901bad..e555a228 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance -from app.core.workflow.utils.expression_evaluator import evaluate_expression logger = logging.getLogger(__name__) @@ -85,12 +84,7 @@ class LoopRuntime: for variable in self.typed_config.cycle_vars: if variable.input_type == ValueInputType.VARIABLE: - value = evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + value = self.variable_pool.get_value(variable.value) else: value = TypeTransformer.transform(variable.value, variable.type) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) @@ -98,12 +92,7 @@ class LoopRuntime: **self.state ) loopstate["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + variable.name: self.variable_pool.get_value(variable.value) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 4bc5fc4c..05a3294b 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -4,32 +4,33 @@ from typing import Any from simpleeval import simple_eval, NameNotDefined, InvalidExpression +from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class ExpressionEvaluator: """Safe expression evaluator for workflow variables and node outputs.""" - + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @classmethod def normalize_template(cls, template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @classmethod def evaluate( - cls, - expression: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + cls, + expression: str, + conv_vars: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> Any: """ Safely evaluate an expression using workflow variables. @@ -49,48 +50,47 @@ class ExpressionEvaluator: # Remove Jinja2-style brackets if present expression = expression.strip() expression = cls.normalize_template(expression) - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", expression).strip() + expression = VARIABLE_PATTERN.sub(r"\1", expression).strip() # Build context for evaluation context = { - "conv": conv_vars, # conversation variables - "node": node_outputs, # node outputs - "sys": system_vars or {}, # system variables + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - context.update(conv_vars) - context["nodes"] = node_outputs + # context.update(conv_vars) + # context["nodes"] = node_outputs context.update(node_outputs) - + try: # simpleeval supports safe operations: # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result - + except NameNotDefined as e: logger.error(f"Undefined variable in expression: {expression}, error: {e}") raise ValueError(f"Undefined variable: {e}") - + except InvalidExpression as e: logger.error(f"Invalid expression syntax: {expression}, error: {e}") raise ValueError(f"Invalid expression syntax: {e}") - + except SyntaxError as e: logger.error(f"Syntax error in expression: {expression}, error: {e}") raise ValueError(f"Syntax error: {e}") - + except Exception as e: logger.error(f"Expression evaluation failed: {expression}, error: {e}") raise ValueError(f"Expression evaluation failed: {e}") - + @staticmethod def evaluate_bool( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + expression: str, + conv_var: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> bool: """ Evaluate a boolean expression (for conditions). @@ -108,7 +108,7 @@ class ExpressionEvaluator: expression, conv_var, node_outputs, system_vars ) return bool(result) - + @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: """ @@ -121,7 +121,7 @@ class ExpressionEvaluator: list[str]: List of error messages. Empty if all names are valid. """ errors = [] - + for var in variables: var_name = var.get("name", "") @@ -134,16 +134,16 @@ class ExpressionEvaluator: errors.append( f"Variable name '{var_name}' is not a valid Python identifier" ) - + return errors # 便捷函数 def evaluate_expression( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict ) -> Any: """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( @@ -152,11 +152,11 @@ def evaluate_expression( def evaluate_condition( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None -) -> bool: + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict +) -> Any: """Evaluate a boolean condition expression (convenience function).""" return ExpressionEvaluator.evaluate_bool( expression, conv_var, node_outputs, system_vars diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 6a73efc4..bb1e18bf 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -1,7 +1,8 @@ """ -模板渲染器 +Template Renderer -使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 +Provides safe template rendering using Jinja2, supporting variable references +and expressions. """ import logging @@ -10,11 +11,15 @@ from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined +from app.core.workflow.engine.variable_pool import LazyVariableDict + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class SafeUndefined(Undefined): - """访问未定义属性不会报错,返回空字符串""" + """Return empty string instead of raising error when accessing undefined variables""" __slots__ = () def _fail_with_undefined_error(self, *args, **kwargs): @@ -26,26 +31,22 @@ class SafeUndefined(Undefined): class TemplateRenderer: - """模板渲染器""" - def __init__(self, strict: bool = True): - """初始化渲染器 - + """Initialize renderer + Args: - strict: 是否使用严格模式(未定义变量会抛出异常) + strict: Whether to enable strict mode (raise error on undefined variables) """ self.strict = strict self.env = Environment( undefined=StrictUndefined if strict else SafeUndefined, - autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML + autoescape=False # Disable auto-escaping since we handle plain text instead of HTML ) @staticmethod def normalize_template(template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + """Normalize template syntax (convert numeric node reference to dict access)""" + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @@ -53,24 +54,24 @@ class TemplateRenderer: def render( self, template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict | None = None ) -> str: - """渲染模板 - + """Render template + Args: - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Raises: - ValueError: 模板语法错误或变量未定义 - + ValueError: If template syntax is invalid or variables are undefined + Examples: >>> renderer = TemplateRenderer() >>> renderer.render( @@ -80,122 +81,119 @@ class TemplateRenderer: ... {} ... ) 'Hello World!' - + >>> renderer.render( - ... "分析结果: {{node.analyze.output}}", + ... "Analysis result: {{node.analyze.output}}", ... {}, - ... {"analyze": {"output": "正面情绪"}}, + ... {"analyze": {"output": "positive sentiment"}}, ... {} ... ) - '分析结果: 正面情绪' + 'Analysis result: positive sentiment' """ - # 构建命名空间上下文 + # Build namespace context context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": system_vars, # 系统变量:{{sys.execution_id}} + "conv": conv_vars, # Conversation variables: {{conv.user_name}} + "node": node_outputs, # Node outputs: {{node.node_1.output}} + "sys": system_vars, # System variables: {{sys.execution_id}} } - # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} - # 将所有节点输出添加到顶层上下文 + # Allow direct access to node outputs by node ID: {{llm_qa.output}} if node_outputs: context.update(node_outputs) - # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} - if conv_vars: - context.update(conv_vars) - - context["nodes"] = node_outputs or {} # 旧语法兼容 + # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} + # if conv_vars: + # context.update(conv_vars) + # + # context["nodes"] = node_outputs or {} # 旧语法兼容 template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) except TemplateSyntaxError as e: - logger.error(f"模板语法错误: {template}, 错误: {e}") - raise ValueError(f"模板语法错误: {e}") - + logger.error(f"Template syntax error: {template}, error: {e}") + raise ValueError(f"Template syntax error: {e}") except UndefinedError as e: - logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") - + logger.error(f"Undefined variable in template: {template}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except Exception as e: - logger.error(f"模板渲染异常: {template}, 错误: {e}") - raise ValueError(f"模板渲染失败: {e}") + logger.error(f"Template rendering error: {template}, error: {e}") + raise ValueError(f"Template rendering failed: {e}") def validate(self, template: str) -> list[str]: - """验证模板语法 - + """Validate template syntax + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表,如果为空则验证通过 - + List of errors (empty if valid) + Examples: >>> renderer = TemplateRenderer() >>> renderer.validate("Hello {{var.name}}!") [] - - >>> renderer.validate("Hello {{var.name") # 缺少结束标记 - ['模板语法错误: ...'] + + >>> renderer.validate("Hello {{var.name") # Missing closing tag + ['Template syntax error: ...'] """ errors = [] try: self.env.from_string(template) except TemplateSyntaxError as e: - errors.append(f"模板语法错误: {e}") + errors.append(f"Template syntax error: {e}") except Exception as e: - errors.append(f"模板验证失败: {e}") + errors.append(f"Template validation failed: {e}") return errors -# 全局渲染器实例(严格模式) +# Global renderer instances (strict / lenient) _strict_renderer = TemplateRenderer(strict=True) _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any], + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | LazyVariableDict, + system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: - """渲染模板(便捷函数) - + """Render template (convenience function) + Args: - strict: 严格模式 - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + strict: Whether to use strict mode + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Examples: >>> render_template( - ... "请分析: {{var.text}}", - ... {"text": "这是一段文本"}, + ... "Analyze: {{var.text}}", + ... {"text": "This is a text"}, ... {}, ... {} ... ) - '请分析: 这是一段文本' + 'Analyze: This is a text' """ renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: - """验证模板语法(便捷函数) - + """Validate template syntax (convenience function) + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表 + List of errors """ return _strict_renderer.validate(template)