Merge pull request #708 from SuanmoSuanyangTechnology/pref/workflow-engine

pref(workflow): performance optimization
This commit is contained in:
Ke Sun
2026-03-31 11:39:34 +08:00
committed by GitHub
11 changed files with 357 additions and 249 deletions

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y): def merge_activate_state(x, y):
return { merged = dict(x)
k: x.get(k, False) or y.get(k, False) for k, v in y.items():
for k in set(x) | set(y) merged[k] = merged.get(k, False) or v
} return merged
def merge_looping_state(x, y): def merge_looping_state(x, y):

View File

@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
def __getitem__(self, key):
return self._resolve(key)
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
class VariableSelector: class VariableSelector:
"""变量选择器 """变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod @staticmethod
def transform_selector(selector): def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}" variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2: if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}") raise ValueError(f"Selector not valid - {selector}")
@@ -303,6 +347,16 @@ class VariablePool:
""" """
return self._get_variable_struct(selector) is not None return self._get_variable_struct(selector) is not None
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
def get_all_system_vars(self, literal=False) -> dict[str, Any]: def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量 """获取所有系统变量
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
var_type=var_type, var_type=var_type,
mut=False mut=False
) )

View File

@@ -552,9 +552,9 @@ class BaseNode(ABC):
return render_template( return render_template(
template=template, template=template,
conv_vars=variable_pool.get_all_conversation_vars(literal=True), conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True), node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True), system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict strict=strict
) )
@@ -579,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition( return evaluate_condition(
expression=expression, expression=expression,
conv_var=variable_pool.get_all_conversation_vars(), conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.get_all_node_outputs(), node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars() system_vars=variable_pool.lazy_namespace("sys")
) )
@staticmethod @staticmethod

View File

@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars: for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE: if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression( value = self.variable_pool.get_value(variable.value)
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
else: else:
value = TypeTransformer.transform(variable.value, variable.type) value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state **self.state
) )
loopstate["node_outputs"][self.node_id] = { loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression( variable.name: self.variable_pool.get_value(variable.value)
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type) else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars

View File

@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched # Reuse cached bytes if already fetched
if f.get_content(): if f.get_content():
file_input.set_content(f.get_content()) file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input) text = await svc.extract_document_text(file_input)
chunks.append(text) chunks.append(text)
except Exception as e: except Exception as e:
logger.error( logger.error(

View File

@@ -1,19 +1,23 @@
import asyncio
import logging import logging
import uuid import uuid
from typing import Any from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc) unique.append(doc)
return unique return unique
def _get_existing_kb_ids(self, db, kb_ids): def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
""" """
Resolve all accessible and valid knowledge base IDs for retrieval. Reorder the list of document blocks and return the top_k results most relevant to the query
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Args: Args:
db: Database session. query: query string
kb_ids (list[UUID]): Knowledge base IDs from node configuration. docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
Returns: Returns:
list[UUID]: Final list of valid knowledge base IDs. Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
""" """
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) reranker = self.get_reranker_model()
# parameter validation
existing_ids = knowledge_repository.get_chunked_knowledgeids( if not docs:
db=db, raise ValueError("retrieval chunks be empty")
filters=filters if top_k <= 0:
) raise ValueError("top_k must be a positive integer")
try:
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) # Convert to LangChain Document object
documents = [
share_ids = knowledge_repository.get_chunked_knowledgeids( Document(
db=db, page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
filters=filters metadata=doc.metadata or {} # Deal with possible None metadata
) )
for doc in docs
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
] ]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db, # Perform reordering (compress_documents will automatically handle relevance scores and indexing)
filters=filters reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
) )
existing_ids.extend(items) # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
return existing_ids result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank: def get_reranker_model(self) -> RedBearRerank:
""" """
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
) )
return reranker return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children: for child in children:
if not (child and child.chunk_num > 0 and child.status == 1): if not (child and child.chunk_num > 0 and child.status == 1):
continue continue
kb_config.kb_id = child.id child_kb_config = kb_config.model_copy()
self.knowledge_retrieval(db, query, rs, child, kb_config) child_kb_config.kb_id = child.id
return tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower() indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type: match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE: case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.similarity_threshold)) vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.vector_similarity_weight)) vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID: case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs1_task = asyncio.to_thread(
indices=indices, vector_service.search_by_vector, **{
score_threshold=kb_config.vector_similarity_weight) "query": query,
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, "top_k": kb_config.top_k,
indices=indices, "indices": indices,
score_threshold=kb_config.similarity_threshold) "score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs: if not unique_rs:
return return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() rs.extend(
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else: else:
rs.extend(sorted( rs.extend(sorted(
unique_rs, unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k]) )[:kb_config.top_k])
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases knowledge_bases = self.typed_config.knowledge_bases
rs = [] rs = []
tasks = []
for kb_config in knowledge_bases: for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge: if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.") raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
if not rs: if not rs:
return [] return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() final_rs = await asyncio.to_thread(
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else: else:
final_rs = sorted( final_rs = sorted(
rs, rs,

View File

@@ -4,8 +4,12 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression from simpleeval import simple_eval, NameNotDefined, InvalidExpression
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class ExpressionEvaluator: class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs.""" """Safe expression evaluator for workflow variables and node outputs."""
@@ -15,21 +19,18 @@ class ExpressionEvaluator:
@classmethod @classmethod
def normalize_template(cls, template: str) -> str: def normalize_template(cls, template: str) -> str:
pattern = re.compile( return _NORMALIZE_PATTERN.sub(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}', r'{{ node["\1"].\2 }}',
template template
) )
@classmethod @classmethod
def evaluate( def evaluate(
cls, cls,
expression: str, expression: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | None = None
) -> Any: ) -> Any:
""" """
Safely evaluate an expression using workflow variables. Safely evaluate an expression using workflow variables.
@@ -49,18 +50,17 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present # Remove Jinja2-style brackets if present
expression = expression.strip() expression = expression.strip()
expression = cls.normalize_template(expression) expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}" expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
expression = re.sub(pattern, r"\1", expression).strip()
# Build context for evaluation # Build context for evaluation
context = { context = {
"conv": conv_vars, # conversation variables "conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs "node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables "sys": system_vars or {}, # system variables
} }
context.update(conv_vars) # context.update(conv_vars)
context["nodes"] = node_outputs # context["nodes"] = node_outputs
context.update(node_outputs) context.update(node_outputs)
try: try:
@@ -87,10 +87,10 @@ class ExpressionEvaluator:
@staticmethod @staticmethod
def evaluate_bool( def evaluate_bool(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | None = None
) -> bool: ) -> bool:
""" """
Evaluate a boolean expression (for conditions). Evaluate a boolean expression (for conditions).
@@ -140,10 +140,10 @@ class ExpressionEvaluator:
# 便捷函数 # 便捷函数
def evaluate_expression( def evaluate_expression(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] system_vars: dict[str, Any] | LazyVariableDict
) -> Any: ) -> Any:
"""Evaluate an expression (convenience function).""" """Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate( return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition( def evaluate_condition(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | LazyVariableDict
) -> bool: ) -> Any:
"""Evaluate a boolean condition expression (convenience function).""" """Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool( return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars expression, conv_var, node_outputs, system_vars

View File

@@ -1,7 +1,8 @@
""" """
模板渲染器 Template Renderer
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 Provides safe template rendering using Jinja2, supporting variable references
and expressions.
""" """
import logging import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
from app.core.workflow.engine.variable_pool import LazyVariableDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class SafeUndefined(Undefined): class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串""" """Return empty string instead of raising error when accessing undefined variables"""
__slots__ = () __slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs): def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer: class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True): def __init__(self, strict: bool = True):
"""初始化渲染器 """Initialize renderer
Args: Args:
strict: 是否使用严格模式(未定义变量会抛出异常) strict: Whether to enable strict mode (raise error on undefined variables)
""" """
self.strict = strict self.strict = strict
self.env = Environment( self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined, undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
) )
@staticmethod @staticmethod
def normalize_template(template: str) -> str: def normalize_template(template: str) -> str:
pattern = re.compile( """Normalize template syntax (convert numeric node reference to dict access)"""
r"\{\{\s*(\d+)\.(\w+)\s*}}" return _NORMALIZE_PATTERN.sub(
)
return pattern.sub(
r'{{ node["\1"].\2 }}', r'{{ node["\1"].\2 }}',
template template
) )
@@ -53,23 +54,23 @@ class TemplateRenderer:
def render( def render(
self, self,
template: str, template: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str: ) -> str:
"""渲染模板 """Render template
Args: Args:
template: 模板字符串 template: Template string
conv_vars: 会话变量 conv_vars: Conversation variables
node_outputs: 节点输出结果 node_outputs: Node outputs
system_vars: 系统变量 system_vars: System variables
Returns: Returns:
渲染后的字符串 Rendered string
Raises: Raises:
ValueError: 模板语法错误或变量未定义 ValueError: If template syntax is invalid or variables are undefined
Examples: Examples:
>>> renderer = TemplateRenderer() >>> renderer = TemplateRenderer()
@@ -82,120 +83,117 @@ class TemplateRenderer:
'Hello World!' 'Hello World!'
>>> renderer.render( >>> renderer.render(
... "分析结果: {{node.analyze.output}}", ... "Analysis result: {{node.analyze.output}}",
... {}, ... {},
... {"analyze": {"output": "正面情绪"}}, ... {"analyze": {"output": "positive sentiment"}},
... {} ... {}
... ) ... )
'分析结果: 正面情绪' 'Analysis result: positive sentiment'
""" """
# 构建命名空间上下文 # Build namespace context
context = { context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}} "conv": conv_vars, # Conversation variables: {{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}} "node": node_outputs, # Node outputs: {{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}} "sys": system_vars, # System variables: {{sys.execution_id}}
} }
# 支持直接通过节点ID访问节点输出{{llm_qa.output}} # Allow direct access to node outputs by node ID: {{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
if node_outputs: if node_outputs:
context.update(node_outputs) context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars: # if conv_vars:
context.update(conv_vars) # context.update(conv_vars)
#
context["nodes"] = node_outputs or {} # 旧语法兼容 # context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template) template = self.normalize_template(template)
try: try:
tmpl = self.env.from_string(template) tmpl = self.env.from_string(template)
return tmpl.render(**context) return tmpl.render(**context)
except TemplateSyntaxError as e: except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}") logger.error(f"Template syntax error: {template}, error: {e}")
raise ValueError(f"模板语法错误: {e}") raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e: except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") logger.error(f"Undefined variable in template: {template}, error: {e}")
raise ValueError(f"未定义的变量: {e}") raise ValueError(f"Undefined variable: {e}")
except Exception as e: except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}") logger.error(f"Template rendering error: {template}, error: {e}")
raise ValueError(f"模板渲染失败: {e}") raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]: def validate(self, template: str) -> list[str]:
"""验证模板语法 """Validate template syntax
Args: Args:
template: 模板字符串 template: Template string
Returns: Returns:
错误列表,如果为空则验证通过 List of errors (empty if valid)
Examples: Examples:
>>> renderer = TemplateRenderer() >>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!") >>> renderer.validate("Hello {{var.name}}!")
[] []
>>> renderer.validate("Hello {{var.name") # 缺少结束标记 >>> renderer.validate("Hello {{var.name") # Missing closing tag
['模板语法错误: ...'] ['Template syntax error: ...']
""" """
errors = [] errors = []
try: try:
self.env.from_string(template) self.env.from_string(template)
except TemplateSyntaxError as e: except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}") errors.append(f"Template syntax error: {e}")
except Exception as e: except Exception as e:
errors.append(f"模板验证失败: {e}") errors.append(f"Template validation failed: {e}")
return errors return errors
# 全局渲染器实例(严格模式) # Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False) _lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
template: str, template: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any], system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True strict: bool = True
) -> str: ) -> str:
"""渲染模板(便捷函数) """Render template (convenience function)
Args: Args:
strict: 严格模式 strict: Whether to use strict mode
template: 模板字符串 template: Template string
conv_vars: 会话变量 conv_vars: Conversation variables
node_outputs: 节点输出 node_outputs: Node outputs
system_vars: 系统变量 system_vars: System variables
Returns: Returns:
渲染后的字符串 Rendered string
Examples: Examples:
>>> render_template( >>> render_template(
... "请分析: {{var.text}}", ... "Analyze: {{var.text}}",
... {"text": "这是一段文本"}, ... {"text": "This is a text"},
... {}, ... {},
... {} ... {}
... ) ... )
'请分析: 这是一段文本' 'Analyze: This is a text'
""" """
renderer = _strict_renderer if strict else _lenient_renderer renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]: def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数) """Validate template syntax (convenience function)
Args: Args:
template: 模板字符串 template: Template string
Returns: Returns:
错误列表 List of errors
""" """
return _strict_renderer.validate(template) return _strict_renderer.validate(template)

View File

@@ -3,9 +3,9 @@
""" """
import uuid import uuid
from typing import Any, Annotated from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc from sqlalchemy import desc, select
from fastapi import Depends from fastapi import Depends
from app.models.workflow_model import ( from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id WorkflowExecution.app_id == app_id
).order_by( ).order_by(
desc(WorkflowExecution.started_at) desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all() ).limit(limit).offset(offset)
return list(self.db.execute(stmt).scalars())
def get_by_conversation_id( def get_by_conversation_id(
self, self,
conversation_id: uuid.UUID conversation_id: uuid.UUID,
status: Literal["running", "completed", "failed"] = None,
limit_count: int = 50
) -> list[WorkflowExecution]: ) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表 """根据会话 ID 获取执行记录列表
Args: Args:
limit_count:
conversation_id: 会话 ID conversation_id: 会话 ID
status: 状态(可选)
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id WorkflowExecution.conversation_id == conversation_id
).order_by( )
desc(WorkflowExecution.started_at) if status:
).all() stmt = stmt.filter(WorkflowExecution.status == status)
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int: def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数 """统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表(按执行顺序排序) 节点执行记录列表(按执行顺序排序)
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id WorkflowNodeExecution.execution_id == execution_id
).order_by( ).order_by(
WorkflowNodeExecution.execution_order WorkflowNodeExecution.execution_order
).all() )
return list(self.db.execute(stmt).scalars())
def get_by_node_id( def get_by_node_id(
self, self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表 节点执行记录列表
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id WorkflowNodeExecution.node_id == node_id
).order_by( ).order_by(
WorkflowNodeExecution.retry_count WorkflowNodeExecution.retry_count
).all() )
return list(self.db.execute(stmt).scalars())
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================

View File

@@ -438,13 +438,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL: if file.transfer_method == TransferMethod.REMOTE_URL:
return True, { return True, {
"type": "text", "type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>" "text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
} }
else: else:
# 本地文件,提取文本内容 # 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file) text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter( file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id FileMetadata.id == file.upload_file_id
).first() ).first()
@@ -542,7 +542,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}" return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str: async def extract_document_text(self, file: FileInput) -> str:
""" """
提取文档文本内容 提取文档文本内容

View File

@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j' storage_type = 'neo4j'
return storage_type, user_rag_memory_id return storage_type, user_rag_memory_id
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
executions = self.execution_repo.get_by_conversation_id(
conversation_id=conversation_id,
status="completed",
limit_count=1
)
if executions:
last_state = executions[0].output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
# input_data["conv"] = conv_vars
# input_data["conv_messages"] = last_state.get("messages") or []
conv_messages = last_state.get("messages") or []
return conv_vars, conv_messages
return None
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow( result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4() message_id = uuid.uuid4()
async for event in execute_workflow_stream( async for event in execute_workflow_stream(