Merge pull request #708 from SuanmoSuanyangTechnology/pref/workflow-engine

pref(workflow): performance optimization
This commit is contained in:
Ke Sun
2026-03-31 11:39:34 +08:00
committed by GitHub
11 changed files with 357 additions and 249 deletions

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y):
return {
k: x.get(k, False) or y.get(k, False)
for k in set(x) | set(y)
}
merged = dict(x)
for k, v in y.items():
merged[k] = merged.get(k, False) or v
return merged
def merge_looping_state(x, y):

View File

@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
def __getitem__(self, key):
return self._resolve(key)
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
class VariableSelector:
"""变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod
def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}"
variable_literal = re.sub(pattern, r"\1", selector).strip()
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}")
@@ -303,6 +347,16 @@ class VariablePool:
"""
return self._get_variable_struct(selector) is not None
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
var_type=var_type,
mut=False
)

View File

@@ -552,9 +552,9 @@ class BaseNode(ABC):
return render_template(
template=template,
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True),
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict
)
@@ -579,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition(
expression=expression,
conv_var=variable_pool.get_all_conversation_vars(),
node_outputs=variable_pool.get_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars()
conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.lazy_namespace("sys")
)
@staticmethod

View File

@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
value = self.variable_pool.get_value(variable.value)
else:
value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state
)
loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression(
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
variable.name: self.variable_pool.get_value(variable.value)
if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars

View File

@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input)
text = await svc.extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(

View File

@@ -1,19 +1,23 @@
import asyncio
import logging
import uuid
from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository
from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
def _get_existing_kb_ids(self, db, kb_ids):
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
Resolve all accessible and valid knowledge base IDs for retrieval.
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Reorder the list of document blocks and return the top_k results most relevant to the query
Args:
db: Database session.
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
query: query string
docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
Returns:
list[UUID]: Final list of valid knowledge base IDs.
Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
"""
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
share_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
reranker = self.get_reranker_model()
# parameter validation
if not docs:
raise ValueError("retrieval chunks be empty")
if top_k <= 0:
raise ValueError("top_k must be a positive integer")
try:
# Convert to LangChain Document object
documents = [
Document(
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
metadata=doc.metadata or {} # Deal with possible None metadata
)
for doc in docs
]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
)
existing_ids.extend(items)
return existing_ids
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
kb_config.kb_id = child.id
self.knowledge_retrieval(db, query, rs, child, kb_config)
return
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
child_kb_config = kb_config.model_copy()
child_kb_config.kb_id = child.id
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
rs.extend(
await asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
rs1_task = asyncio.to_thread(
vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
return
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
rs.extend(
await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else:
rs.extend(sorted(
unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
if not rs:
return []
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
final_rs = await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else:
final_rs = sorted(
rs,

View File

@@ -4,32 +4,33 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs."""
# Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@classmethod
def normalize_template(cls, template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@classmethod
def evaluate(
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
cls,
expression: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> Any:
"""
Safely evaluate an expression using workflow variables.
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present
expression = expression.strip()
expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}"
expression = re.sub(pattern, r"\1", expression).strip()
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
# Build context for evaluation
context = {
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
"conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables
}
context.update(conv_vars)
context["nodes"] = node_outputs
# context.update(conv_vars)
# context["nodes"] = node_outputs
context.update(node_outputs)
try:
# simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context)
return result
except NameNotDefined as e:
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e:
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e:
logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}")
except Exception as e:
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod
def evaluate_bool(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
"""
Evaluate a boolean expression (for conditions).
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
expression, conv_var, node_outputs, system_vars
)
return bool(result)
@staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]:
"""
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
list[str]: List of error messages. Empty if all names are valid.
"""
errors = []
for var in variables:
var_name = var.get("name", "")
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
errors.append(
f"Variable name '{var_name}' is not a valid Python identifier"
)
return errors
# 便捷函数
def evaluate_expression(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any]
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition(
expression: str,
conv_var: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
) -> bool:
expression: str,
conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict
) -> Any:
"""Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars

View File

@@ -1,7 +1,8 @@
"""
模板渲染器
Template Renderer
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
Provides safe template rendering using Jinja2, supporting variable references
and expressions.
"""
import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
from app.core.workflow.engine.variable_pool import LazyVariableDict
logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串"""
"""Return empty string instead of raising error when accessing undefined variables"""
__slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True):
"""初始化渲染器
"""Initialize renderer
Args:
strict: 是否使用严格模式(未定义变量会抛出异常)
strict: Whether to enable strict mode (raise error on undefined variables)
"""
self.strict = strict
self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
)
@staticmethod
def normalize_template(template: str) -> str:
pattern = re.compile(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
"""Normalize template syntax (convert numeric node reference to dict access)"""
return _NORMALIZE_PATTERN.sub(
r'{{ node["\1"].\2 }}',
template
)
@@ -53,24 +54,24 @@ class TemplateRenderer:
def render(
self,
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str:
"""渲染模板
"""Render template
Args:
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出结果
system_vars: 系统变量
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Raises:
ValueError: 模板语法错误或变量未定义
ValueError: If template syntax is invalid or variables are undefined
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.render(
@@ -80,122 +81,119 @@ class TemplateRenderer:
... {}
... )
'Hello World!'
>>> renderer.render(
... "分析结果: {{node.analyze.output}}",
... "Analysis result: {{node.analyze.output}}",
... {},
... {"analyze": {"output": "正面情绪"}},
... {"analyze": {"output": "positive sentiment"}},
... {}
... )
'分析结果: 正面情绪'
'Analysis result: positive sentiment'
"""
# 构建命名空间上下文
# Build namespace context
context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}}
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
"node": node_outputs, # Node outputs: {{node.node_1.output}}
"sys": system_vars, # System variables: {{sys.execution_id}}
}
# 支持直接通过节点ID访问节点输出{{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
if node_outputs:
context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars:
context.update(conv_vars)
context["nodes"] = node_outputs or {} # 旧语法兼容
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
# if conv_vars:
# context.update(conv_vars)
#
# context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template)
try:
tmpl = self.env.from_string(template)
return tmpl.render(**context)
except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}")
raise ValueError(f"模板语法错误: {e}")
logger.error(f"Template syntax error: {template}, error: {e}")
raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
raise ValueError(f"未定义的变量: {e}")
logger.error(f"Undefined variable in template: {template}, error: {e}")
raise ValueError(f"Undefined variable: {e}")
except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}")
raise ValueError(f"模板渲染失败: {e}")
logger.error(f"Template rendering error: {template}, error: {e}")
raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]:
"""验证模板语法
"""Validate template syntax
Args:
template: 模板字符串
template: Template string
Returns:
错误列表,如果为空则验证通过
List of errors (empty if valid)
Examples:
>>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!")
[]
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
['模板语法错误: ...']
>>> renderer.validate("Hello {{var.name") # Missing closing tag
['Template syntax error: ...']
"""
errors = []
try:
self.env.from_string(template)
except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}")
errors.append(f"Template syntax error: {e}")
except Exception as e:
errors.append(f"模板验证失败: {e}")
errors.append(f"Template validation failed: {e}")
return errors
# 全局渲染器实例(严格模式)
# Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
conv_vars: dict[str, Any],
node_outputs: dict[str, Any],
system_vars: dict[str, Any],
conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
"""渲染模板(便捷函数)
"""Render template (convenience function)
Args:
strict: 严格模式
template: 模板字符串
conv_vars: 会话变量
node_outputs: 节点输出
system_vars: 系统变量
strict: Whether to use strict mode
template: Template string
conv_vars: Conversation variables
node_outputs: Node outputs
system_vars: System variables
Returns:
渲染后的字符串
Rendered string
Examples:
>>> render_template(
... "请分析: {{var.text}}",
... {"text": "这是一段文本"},
... "Analyze: {{var.text}}",
... {"text": "This is a text"},
... {},
... {}
... )
'请分析: 这是一段文本'
'Analyze: This is a text'
"""
renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数)
"""Validate template syntax (convenience function)
Args:
template: 模板字符串
template: Template string
Returns:
错误列表
List of errors
"""
return _strict_renderer.validate(template)

View File

@@ -3,9 +3,9 @@
"""
import uuid
from typing import Any, Annotated
from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session
from sqlalchemy import desc
from sqlalchemy import desc, select
from fastapi import Depends
from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id
).order_by(
desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all()
).limit(limit).offset(offset)
return list(self.db.execute(stmt).scalars())
def get_by_conversation_id(
self,
conversation_id: uuid.UUID
conversation_id: uuid.UUID,
status: Literal["running", "completed", "failed"] = None,
limit_count: int = 50
) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表
Args:
limit_count:
conversation_id: 会话 ID
status: 状态(可选)
Returns:
执行记录列表
"""
return self.db.query(WorkflowExecution).filter(
stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id
).order_by(
desc(WorkflowExecution.started_at)
).all()
)
if status:
stmt = stmt.filter(WorkflowExecution.status == status)
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns:
节点执行记录列表(按执行顺序排序)
"""
return self.db.query(WorkflowNodeExecution).filter(
stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id
).order_by(
WorkflowNodeExecution.execution_order
).all()
)
return list(self.db.execute(stmt).scalars())
def get_by_node_id(
self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns:
节点执行记录列表
"""
return self.db.query(WorkflowNodeExecution).filter(
stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id
).order_by(
WorkflowNodeExecution.retry_count
).all()
)
return list(self.db.execute(stmt).scalars())
# ==================== 依赖注入函数 ====================

View File

@@ -438,13 +438,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return True, {
"type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
}
else:
# 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file)
text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -542,7 +542,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str:
async def extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容

View File

@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j'
return storage_type, user_rag_memory_id
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
executions = self.execution_repo.get_by_conversation_id(
conversation_id=conversation_id,
status="completed",
limit_count=1
)
if executions:
last_state = executions[0].output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
# input_data["conv"] = conv_vars
# input_data["conv_messages"] = last_state.get("messages") or []
conv_messages = last_state.get("messages") or []
return conv_vars, conv_messages
return None
# ==================== 工作流执行 ====================
async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
history = self._get_history_info(conversation_id_uuid)
if history:
conv_vars, conv_messages = history
input_data["conv"] = conv_vars
input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
for exec_res in executions:
if exec_res.status == "completed":
last_state = exec_res.output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
history = self._get_history_info(conversation_id_uuid)
if history:
conv_vars, conv_messages = history
input_data["conv"] = conv_vars
input_data["conv_messages"] = conv_messages
init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4()
async for event in execute_workflow_stream(