From 4534b65d6a67392c1a29ce0aac0452ec89e9fd8d Mon Sep 17 00:00:00 2001
From: Eternity <1533512157@qq.com>
Date: Fri, 27 Mar 2026 11:56:22 +0800
Subject: [PATCH 1/3] refactor(workflow): optimize workflow history queries and
migrate ORM to SQLAlchemy 2.0
- Migrate historical workflow queries from legacy ORM Query API to SQLAlchemy 2.0 select() + execute()
- Limit query fields and use pagination to reduce returned data, improving performance
- Preserve original ordering and filtering logic
---
api/app/repositories/workflow_repository.py | 35 ++++++++------
api/app/services/workflow_service.py | 51 +++++++++++----------
2 files changed, 50 insertions(+), 36 deletions(-)
diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py
index 4e24faa0..a783fe3f 100644
--- a/api/app/repositories/workflow_repository.py
+++ b/api/app/repositories/workflow_repository.py
@@ -3,9 +3,9 @@
"""
import uuid
-from typing import Any, Annotated
+from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session
-from sqlalchemy import desc
+from sqlalchemy import desc, select
from fastapi import Depends
from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns:
执行记录列表
"""
- return self.db.query(WorkflowExecution).filter(
+ stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).order_by(
desc(WorkflowExecution.started_at)
- ).limit(limit).offset(offset).all()
+ ).limit(limit).offset(offset)
+ return list(self.db.execute(stmt).scalars())
def get_by_conversation_id(
self,
- conversation_id: uuid.UUID
+ conversation_id: uuid.UUID,
+ status: Literal["running", "completed", "failed"] = None,
+ limit_count: int = 50
) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表
Args:
+ limit_count:
conversation_id: 会话 ID
+ status: 状态(可选)
Returns:
执行记录列表
"""
- return self.db.query(WorkflowExecution).filter(
+ stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id
- ).order_by(
- desc(WorkflowExecution.started_at)
- ).all()
+ )
+ if status:
+ stmt = stmt.filter(WorkflowExecution.status == status)
+ stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
+ return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns:
节点执行记录列表(按执行顺序排序)
"""
- return self.db.query(WorkflowNodeExecution).filter(
+ stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id
).order_by(
WorkflowNodeExecution.execution_order
- ).all()
+ )
+ return list(self.db.execute(stmt).scalars())
def get_by_node_id(
self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns:
节点执行记录列表
"""
- return self.db.query(WorkflowNodeExecution).filter(
+ stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id
).order_by(
WorkflowNodeExecution.retry_count
- ).all()
+ )
+ return list(self.db.execute(stmt).scalars())
# ==================== 依赖注入函数 ====================
diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py
index c7d7f2b1..13267078 100644
--- a/api/app/services/workflow_service.py
+++ b/api/app/services/workflow_service.py
@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j'
return storage_type, user_rag_memory_id
+ def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
+ executions = self.execution_repo.get_by_conversation_id(
+ conversation_id=conversation_id,
+ status="completed",
+ limit_count=1
+ )
+
+ if executions:
+ last_state = executions[0].output_data
+ if isinstance(last_state, dict):
+ variables = last_state.get("variables", {})
+ conv_vars = variables.get("conv", {})
+ # input_data["conv"] = conv_vars
+ # input_data["conv_messages"] = last_state.get("messages") or []
+ conv_messages = last_state.get("messages") or []
+ return conv_vars, conv_messages
+ return None
+
# ==================== 工作流执行 ====================
async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
- executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
-
- for exec_res in executions:
- if exec_res.status == "completed":
- last_state = exec_res.output_data
- if isinstance(last_state, dict):
- variables = last_state.get("variables", {})
- conv_vars = variables.get("conv", {})
- input_data["conv"] = conv_vars
- input_data["conv_messages"] = last_state.get("messages") or []
- break
-
+ history = self._get_history_info(conversation_id_uuid)
+ if history:
+ conv_vars, conv_messages = history
+ input_data["conv"] = conv_vars
+ input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
- executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
-
- for exec_res in executions:
- if exec_res.status == "completed":
- last_state = exec_res.output_data
- if isinstance(last_state, dict):
- variables = last_state.get("variables", {})
- conv_vars = variables.get("conv", {})
- input_data["conv"] = conv_vars
- input_data["conv_messages"] = last_state.get("messages") or []
- break
+ history = self._get_history_info(conversation_id_uuid)
+ if history:
+ conv_vars, conv_messages = history
+ input_data["conv"] = conv_vars
+ input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4()
async for event in execute_workflow_stream(
From 7fd00009a21ed3334f6c72b9068fd3a018ffd637 Mon Sep 17 00:00:00 2001
From: Eternity <1533512157@qq.com>
Date: Fri, 27 Mar 2026 12:00:30 +0800
Subject: [PATCH 2/3] perf(workflow): introduce LazyDict to reduce variable
serialization, optimize regex to reduce compilation
- Use LazyDict for deferred serialization, improving performance
- Reuse regex patterns to avoid repeated compilation
---
api/app/core/workflow/engine/state_manager.py | 8 +-
api/app/core/workflow/engine/variable_pool.py | 60 ++++++-
api/app/core/workflow/nodes/base_node.py | 12 +-
.../core/workflow/nodes/cycle_graph/loop.py | 15 +-
.../workflow/utils/expression_evaluator.py | 78 ++++-----
.../core/workflow/utils/template_renderer.py | 164 +++++++++---------
6 files changed, 188 insertions(+), 149 deletions(-)
diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py
index 2da0d3a8..eed44278 100644
--- a/api/app/core/workflow/engine/state_manager.py
+++ b/api/app/core/workflow/engine/state_manager.py
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y):
- return {
- k: x.get(k, False) or y.get(k, False)
- for k in set(x) | set(y)
- }
+ merged = dict(x)
+ for k, v in y.items():
+ merged[k] = merged.get(k, False) or v
+ return merged
def merge_looping_state(x, y):
diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py
index 60f1257e..7faca82d 100644
--- a/api/app/core/workflow/engine/variable_pool.py
+++ b/api/app/core/workflow/engine/variable_pool.py
@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__)
+VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
+
+
+class LazyVariableDict:
+ def __init__(self, source, literal):
+ self._source: dict[str, VariableStruct[Any]] = source
+ self._literal: bool = literal
+ self._cache = {}
+
+ def keys(self):
+ return self._source.keys()
+
+ def _resolve(self, key):
+ if key in self._cache:
+ return self._cache[key]
+ var_struct = self._source.get(key)
+ if var_struct is None:
+ raise KeyError(key)
+ value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
+ self._cache[key] = value
+ return value
+
+ def get(self, key, default=None):
+ try:
+ return self._resolve(key)
+ except KeyError:
+ return default
+
+ def __getitem__(self, key):
+ return self._resolve(key)
+
+ def __getattr__(self, key):
+ if key.startswith('_'):
+ raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
+ return self._resolve(key)
+
+ def __contains__(self, key):
+ return key in self._source
+
+ def __iter__(self):
+ return iter(self._source)
+
+ def __len__(self):
+ return len(self._source)
+
class VariableSelector:
"""变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod
def transform_selector(selector):
- pattern = r"\{\{\s*(.*?)\s*\}\}"
- variable_literal = re.sub(pattern, r"\1", selector).strip()
+ variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
@@ -303,6 +347,16 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
+ def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
+ return LazyVariableDict(self.variables.get(namespace, {}), literal)
+
+ def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
+ return {
+ ns: LazyVariableDict(vars_dict, literal)
+ for ns, vars_dict in self.variables.items()
+ if ns not in ("sys", "conv")
+ }
+
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
var_type=var_type,
mut=False
)
-
-
diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py
index 8567ebbe..bedf6165 100644
--- a/api/app/core/workflow/nodes/base_node.py
+++ b/api/app/core/workflow/nodes/base_node.py
@@ -552,9 +552,9 @@ class BaseNode(ABC):
return render_template(
template=template,
- conv_vars=variable_pool.get_all_conversation_vars(literal=True),
- node_outputs=variable_pool.get_all_node_outputs(literal=True),
- system_vars=variable_pool.get_all_system_vars(literal=True),
+ conv_vars=variable_pool.lazy_namespace("conv", literal=True),
+ node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
+ system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict
)
@@ -579,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition(
expression=expression,
- conv_var=variable_pool.get_all_conversation_vars(),
- node_outputs=variable_pool.get_all_node_outputs(),
- system_vars=variable_pool.get_all_system_vars()
+ conv_var=variable_pool.lazy_namespace("conv"),
+ node_outputs=variable_pool.lazy_all_node_outputs(),
+ system_vars=variable_pool.lazy_namespace("sys")
)
@staticmethod
diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py
index 84901bad..e555a228 100644
--- a/api/app/core/workflow/nodes/cycle_graph/loop.py
+++ b/api/app/core/workflow/nodes/cycle_graph/loop.py
@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
-from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
- value = evaluate_expression(
- expression=variable.value,
- conv_var=self.variable_pool.get_all_conversation_vars(),
- node_outputs=self.variable_pool.get_all_node_outputs(),
- system_vars=self.variable_pool.get_all_system_vars(),
- )
+ value = self.variable_pool.get_value(variable.value)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state
)
loopstate["node_outputs"][self.node_id] = {
- variable.name: evaluate_expression(
- expression=variable.value,
- conv_var=self.variable_pool.get_all_conversation_vars(),
- node_outputs=self.variable_pool.get_all_node_outputs(),
- system_vars=self.variable_pool.get_all_system_vars(),
- )
+ variable.name: self.variable_pool.get_value(variable.value)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars
diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py
index 4bc5fc4c..05a3294b 100644
--- a/api/app/core/workflow/utils/expression_evaluator.py
+++ b/api/app/core/workflow/utils/expression_evaluator.py
@@ -4,32 +4,33 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
+from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
+
logger = logging.getLogger(__name__)
+_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
+
class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs."""
-
+
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@classmethod
def normalize_template(cls, template: str) -> str:
- pattern = re.compile(
- r"\{\{\s*(\d+)\.(\w+)\s*}}"
- )
- return pattern.sub(
+ return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@classmethod
def evaluate(
- cls,
- expression: str,
- conv_vars: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any] | None = None
+ cls,
+ expression: str,
+ conv_vars: dict[str, Any],
+ node_outputs: dict[str, Any],
+ system_vars: dict[str, Any] | None = None
) -> Any:
"""
Safely evaluate an expression using workflow variables.
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
- pattern = r"\{\{\s*(.*?)\s*\}\}"
- expression = re.sub(pattern, r"\1", expression).strip()
+ expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
# Build context for evaluation
context = {
- "conv": conv_vars, # conversation variables
- "node": node_outputs, # node outputs
- "sys": system_vars or {}, # system variables
+ "conv": conv_vars, # conversation variables
+ "node": node_outputs, # node outputs
+ "sys": system_vars or {}, # system variables
}
- context.update(conv_vars)
- context["nodes"] = node_outputs
+ # context.update(conv_vars)
+ # context["nodes"] = node_outputs
context.update(node_outputs)
-
+
try:
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
-
+
except NameNotDefined as e:
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
-
+
except InvalidExpression as e:
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
-
+
except SyntaxError as e:
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
-
+
except Exception as e:
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
-
+
@staticmethod
def evaluate_bool(
- expression: str,
- conv_var: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any] | None = None
+ expression: str,
+ conv_var: dict[str, Any],
+ node_outputs: dict[str, Any],
+ system_vars: dict[str, Any] | None = None
) -> bool:
"""
Evaluate a boolean expression (for conditions).
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
expression, conv_var, node_outputs, system_vars
)
return bool(result)
-
+
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
-
+
for var in variables:
var_name = var.get("name", "")
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
errors.append(
f"Variable name '{var_name}' is not a valid Python identifier"
)
-
+
return errors
# 便捷函数
def evaluate_expression(
- expression: str,
- conv_var: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any]
+ expression: str,
+ conv_var: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
+ system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition(
- expression: str,
- conv_var: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any] | None = None
-) -> bool:
+ expression: str,
+ conv_var: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
+ system_vars: dict[str, Any] | LazyVariableDict
+) -> Any:
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars
diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py
index 6a73efc4..bb1e18bf 100644
--- a/api/app/core/workflow/utils/template_renderer.py
+++ b/api/app/core/workflow/utils/template_renderer.py
@@ -1,7 +1,8 @@
"""
-模板渲染器
+Template Renderer
-使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
+Provides safe template rendering using Jinja2, supporting variable references
+and expressions.
"""
import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
+from app.core.workflow.engine.variable_pool import LazyVariableDict
+
logger = logging.getLogger(__name__)
+_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
+
class SafeUndefined(Undefined):
- """访问未定义属性不会报错,返回空字符串"""
+ """Return empty string instead of raising error when accessing undefined variables"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer:
- """模板渲染器"""
-
def __init__(self, strict: bool = True):
- """初始化渲染器
-
+ """Initialize renderer
+
Args:
- strict: 是否使用严格模式(未定义变量会抛出异常)
+ strict: Whether to enable strict mode (raise error on undefined variables)
"""
self.strict = strict
self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined,
- autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
+ autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
)
@staticmethod
def normalize_template(template: str) -> str:
- pattern = re.compile(
- r"\{\{\s*(\d+)\.(\w+)\s*}}"
- )
- return pattern.sub(
+ """Normalize template syntax (convert numeric node reference to dict access)"""
+ return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@@ -53,24 +54,24 @@ class TemplateRenderer:
def render(
self,
template: str,
- conv_vars: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any] | None = None
+ conv_vars: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
+ system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str:
- """渲染模板
-
+ """Render template
+
Args:
- template: 模板字符串
- conv_vars: 会话变量
- node_outputs: 节点输出结果
- system_vars: 系统变量
-
+ template: Template string
+ conv_vars: Conversation variables
+ node_outputs: Node outputs
+ system_vars: System variables
+
Returns:
- 渲染后的字符串
-
+ Rendered string
+
Raises:
- ValueError: 模板语法错误或变量未定义
-
+ ValueError: If template syntax is invalid or variables are undefined
+
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
@@ -80,122 +81,119 @@ class TemplateRenderer:
... {}
... )
'Hello World!'
-
+
>>> renderer.render(
- ... "分析结果: {{node.analyze.output}}",
+ ... "Analysis result: {{node.analyze.output}}",
... {},
- ... {"analyze": {"output": "正面情绪"}},
+ ... {"analyze": {"output": "positive sentiment"}},
... {}
... )
- '分析结果: 正面情绪'
+ 'Analysis result: positive sentiment'
"""
- # 构建命名空间上下文
+ # Build namespace context
context = {
- "conv": conv_vars, # 会话变量:{{conv.user_name}}
- "node": node_outputs, # 节点输出:{{node.node_1.output}}
- "sys": system_vars, # 系统变量:{{sys.execution_id}}
+ "conv": conv_vars, # Conversation variables: {{conv.user_name}}
+ "node": node_outputs, # Node outputs: {{node.node_1.output}}
+ "sys": system_vars, # System variables: {{sys.execution_id}}
}
- # 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
- # 将所有节点输出添加到顶层上下文
+ # Allow direct access to node outputs by node ID: {{llm_qa.output}}
if node_outputs:
context.update(node_outputs)
- # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
- if conv_vars:
- context.update(conv_vars)
-
- context["nodes"] = node_outputs or {} # 旧语法兼容
+ # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
+ # if conv_vars:
+ # context.update(conv_vars)
+ #
+ # context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
- logger.error(f"模板语法错误: {template}, 错误: {e}")
- raise ValueError(f"模板语法错误: {e}")
-
+ logger.error(f"Template syntax error: {template}, error: {e}")
+ raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e:
- logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
- raise ValueError(f"未定义的变量: {e}")
-
+ logger.error(f"Undefined variable in template: {template}, error: {e}")
+ raise ValueError(f"Undefined variable: {e}")
except Exception as e:
- logger.error(f"模板渲染异常: {template}, 错误: {e}")
- raise ValueError(f"模板渲染失败: {e}")
+ logger.error(f"Template rendering error: {template}, error: {e}")
+ raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]:
- """验证模板语法
-
+ """Validate template syntax
+
Args:
- template: 模板字符串
-
+ template: Template string
+
Returns:
- 错误列表,如果为空则验证通过
-
+ List of errors (empty if valid)
+
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
-
- >>> renderer.validate("Hello {{var.name") # 缺少结束标记
- ['模板语法错误: ...']
+
+ >>> renderer.validate("Hello {{var.name") # Missing closing tag
+ ['Template syntax error: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
- errors.append(f"模板语法错误: {e}")
+ errors.append(f"Template syntax error: {e}")
except Exception as e:
- errors.append(f"模板验证失败: {e}")
+ errors.append(f"Template validation failed: {e}")
return errors
-# 全局渲染器实例(严格模式)
+# Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
- conv_vars: dict[str, Any],
- node_outputs: dict[str, Any],
- system_vars: dict[str, Any],
+ conv_vars: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, Any] | LazyVariableDict,
+ system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
- """渲染模板(便捷函数)
-
+ """Render template (convenience function)
+
Args:
- strict: 严格模式
- template: 模板字符串
- conv_vars: 会话变量
- node_outputs: 节点输出
- system_vars: 系统变量
-
+ strict: Whether to use strict mode
+ template: Template string
+ conv_vars: Conversation variables
+ node_outputs: Node outputs
+ system_vars: System variables
+
Returns:
- 渲染后的字符串
-
+ Rendered string
+
Examples:
>>> render_template(
- ... "请分析: {{var.text}}",
- ... {"text": "这是一段文本"},
+ ... "Analyze: {{var.text}}",
+ ... {"text": "This is a text"},
... {},
... {}
... )
- '请分析: 这是一段文本'
+ 'Analyze: This is a text'
"""
renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
- """验证模板语法(便捷函数)
-
+ """Validate template syntax (convenience function)
+
Args:
- template: 模板字符串
-
+ template: Template string
+
Returns:
- 错误列表
+ List of errors
"""
return _strict_renderer.validate(template)
From bca43fcc75e41efa68f4fd997ee8acb0587c334d Mon Sep 17 00:00:00 2001
From: Eternity <1533512157@qq.com>
Date: Fri, 27 Mar 2026 12:02:36 +0800
Subject: [PATCH 3/3] perf(workflow): expose extract_document_text as instance
method, optimize knowledge base parallel search
- Change extract_document_text from private to instance method in multimodal service for external access
- Optimize knowledge base search logic to improve parallel retrieval performance
---
.../workflow/nodes/document_extractor/node.py | 2 +-
api/app/core/workflow/nodes/knowledge/node.py | 175 ++++++++++++------
.../core/workflow/utils/template_renderer.py | 2 +-
api/app/services/multimodal_service.py | 6 +-
4 files changed, 120 insertions(+), 65 deletions(-)
diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py
index 40641f3c..bd828760 100644
--- a/api/app/core/workflow/nodes/document_extractor/node.py
+++ b/api/app/core/workflow/nodes/document_extractor/node.py
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
- text = await svc._extract_document_text(file_input)
+ text = await svc.extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py
index 92699cb4..d0b6d098 100644
--- a/api/app/core/workflow/nodes/knowledge/node.py
+++ b/api/app/core/workflow/nodes/knowledge/node.py
@@ -1,19 +1,23 @@
+import asyncio
import logging
import uuid
from typing import Any
+from langchain_core.documents import Document
+
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
-from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
+from app.core.rag.models.chunk import DocumentChunk
+from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
-from app.models import knowledge_model, knowledgeshare_model, ModelType
-from app.repositories import knowledge_repository, knowledgeshare_repository
+from app.models import knowledge_model, ModelType
+from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
- self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
- def _get_existing_kb_ids(self, db, kb_ids):
+ def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
- Resolve all accessible and valid knowledge base IDs for retrieval.
-
- This includes:
- - Private knowledge bases owned by the user
- - Shared knowledge bases
- - Source knowledge bases mapped via knowledge sharing relationships
-
+ Reorder the list of document blocks and return the top_k results most relevant to the query
Args:
- db: Database session.
- kb_ids (list[UUID]): Knowledge base IDs from node configuration.
+ query: query string
+ docs: List of document chunk to be rearranged
+ top_k: The number of top-level documents returned
Returns:
- list[UUID]: Final list of valid knowledge base IDs.
+ Rearranged document chunk list (sorted in descending order of relevance)
+
+ Raises:
+ ValueError: If the input document list is empty or top_k is invalid
"""
- filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
-
- existing_ids = knowledge_repository.get_chunked_knowledgeids(
- db=db,
- filters=filters
- )
-
- filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
-
- share_ids = knowledge_repository.get_chunked_knowledgeids(
- db=db,
- filters=filters
- )
-
- if share_ids:
- filters = [
- knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
+ reranker = self.get_reranker_model()
+ # parameter validation
+ if not docs:
+ raise ValueError("retrieval chunks be empty")
+ if top_k <= 0:
+ raise ValueError("top_k must be a positive integer")
+ try:
+ # Convert to LangChain Document object
+ documents = [
+ Document(
+ page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
+ metadata=doc.metadata or {} # Deal with possible None metadata
+ )
+ for doc in docs
]
- items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
- db=db,
- filters=filters
+
+ # Perform reordering (compress_documents will automatically handle relevance scores and indexing)
+ reranked_docs = list(reranker.compress_documents(documents, query))
+
+ # Sort in descending order based on relevance score
+ reranked_docs.sort(
+ key=lambda x: x.metadata.get("relevance_score", 0),
+ reverse=True
)
- existing_ids.extend(items)
- return existing_ids
+ # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
+ result = []
+ for item in reranked_docs[:top_k]:
+ for doc in docs:
+ if doc.page_content == item.page_content:
+ doc.metadata["score"] = item.metadata["relevance_score"]
+ result.append(doc)
+ return result
+ except Exception as e:
+ raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
- def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
+ async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
+ rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
+ tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
- kb_config.kb_id = child.id
- self.knowledge_retrieval(db, query, rs, child, kb_config)
- return
- self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
+ child_kb_config = kb_config.model_copy()
+ child_kb_config.kb_id = child.id
+ tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
+ if tasks:
+ result = await asyncio.gather(*tasks)
+ for _ in result:
+ rs.extend(_)
+ return rs
+ vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
- rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.similarity_threshold))
+ rs.extend(
+ await asyncio.to_thread(
+ vector_service.search_by_full_text, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.similarity_threshold
+ }
+ )
+ )
case RetrieveType.SEMANTIC:
- rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.vector_similarity_weight))
+ rs.extend(
+ await asyncio.to_thread(
+ vector_service.search_by_vector, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.vector_similarity_weight
+ }
+ )
+ )
case RetrieveType.HYBRID:
- rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.vector_similarity_weight)
- rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.similarity_threshold)
+ rs1_task = asyncio.to_thread(
+ vector_service.search_by_vector, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.vector_similarity_weight
+ }
+ )
+ rs2_task = asyncio.to_thread(
+ vector_service.search_by_full_text, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.similarity_threshold
+ }
+ )
+ rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
- return
+ return []
if self.typed_config.reranker_id:
- self.vector_service.reranker = self.get_reranker_model()
- rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
+ rs.extend(
+ await asyncio.to_thread(
+ self.rerank,
+ **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
+ )
+ )
else:
rs.extend(sorted(
unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
+ return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
+ tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
- self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
+ tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
+ if tasks:
+ result = await asyncio.gather(*tasks)
+ for _ in result:
+ rs.extend(_)
if not rs:
return []
if self.typed_config.reranker_id:
- self.vector_service.reranker = self.get_reranker_model()
- final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
+ final_rs = await asyncio.to_thread(
+ self.rerank,
+ **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
+ )
else:
final_rs = sorted(
rs,
diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py
index bb1e18bf..2c2d0f67 100644
--- a/api/app/core/workflow/utils/template_renderer.py
+++ b/api/app/core/workflow/utils/template_renderer.py
@@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
conv_vars: dict[str, Any] | LazyVariableDict,
- node_outputs: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py
index 4cf3d89d..120cccb7 100644
--- a/api/app/services/multimodal_service.py
+++ b/api/app/services/multimodal_service.py
@@ -438,13 +438,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return True, {
"type": "text",
- "text": f"\n{await self._extract_document_text(file)}\n"
+ "text": f"\n{await self.extract_document_text(file)}\n"
}
else:
# 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
- text = await self._extract_document_text(file)
+ text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -542,7 +542,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
- async def _extract_document_text(self, file: FileInput) -> str:
+ async def extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容