Merge pull request #708 from SuanmoSuanyangTechnology/pref/workflow-engine
pref(workflow): performance optimization
This commit is contained in:
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
    """Merge two activation-state mappings with a per-key logical OR.

    Args:
        x: Existing mapping of key -> activation flag. It is copied, not mutated.
        y: Incoming mapping of key -> activation flag.

    Returns:
        A new dict containing every key of ``x`` and ``y``. Keys present in
        ``y`` hold ``x.get(key, False) or y[key]``; keys present only in ``x``
        keep their value from ``x``.
    """
    merged = dict(x)
    for key, flag in y.items():
        # Single pass over y; avoids building set(x) | set(y) on every merge.
        merged[key] = merged.get(key, False) or flag
    return merged
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
|
||||
@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
|
||||
|
||||
|
||||
class LazyVariableDict:
    """Read-only, dict-like view over one namespace of variable structs.

    Values are resolved only when first requested and are then cached on the
    view, so building the view costs nothing per variable. Membership, length
    and iteration reflect the keys of the underlying mapping.

    Note: the cache is never invalidated, so a view returns the value it saw
    at first access even if the underlying variable changes afterwards.
    """

    def __init__(self, source, literal):
        """Wrap ``source`` without resolving any value.

        Args:
            source: Mapping of variable name -> struct exposing ``.instance``.
            literal: When true, values come from ``instance.to_literal()``;
                otherwise from ``instance.get_value()``.
        """
        self._source: dict[str, VariableStruct[Any]] = source
        self._literal: bool = literal
        self._cache = {}

    def keys(self):
        """Return the key view of the underlying mapping."""
        return self._source.keys()

    def _resolve(self, key):
        """Return the (cached) value for ``key``.

        Raises:
            KeyError: If ``key`` is absent or mapped to ``None``.
        """
        if key in self._cache:
            return self._cache[key]
        var_struct = self._source.get(key)
        if var_struct is None:
            raise KeyError(key)
        instance = var_struct.instance
        value = instance.to_literal() if self._literal else instance.get_value()
        self._cache[key] = value
        return value

    def get(self, key, default=None):
        """Return the value for ``key``, or ``default`` when it cannot be found."""
        try:
            return self._resolve(key)
        except KeyError:
            return default

    def __getitem__(self, key):
        return self._resolve(key)

    def __getattr__(self, key):
        """Expose variables as attributes (``view.name``).

        Only called when normal attribute lookup fails. A missing variable
        must surface as ``AttributeError`` (not ``KeyError``): ``hasattr``,
        ``getattr(obj, name, default)`` and template engines that fall back
        from attribute access to item access all rely on that contract.
        """
        missing = f"'{self.__class__.__name__}' object has no attribute '{key}'"
        # Private/dunder names are never variables; this also prevents
        # recursion while ``_source``/``_cache`` are not yet set.
        if key.startswith('_'):
            raise AttributeError(missing)
        # Check for absence explicitly so a KeyError raised while computing an
        # existing variable's value is not misreported as a missing attribute.
        if key not in self._cache and self._source.get(key) is None:
            raise AttributeError(missing)
        return self._resolve(key)

    def __contains__(self, key):
        return key in self._source

    def __iter__(self):
        return iter(self._source)

    def __len__(self):
        return len(self._source)
|
||||
|
||||
|
||||
class VariableSelector:
|
||||
"""变量选择器
|
||||
@@ -117,8 +162,7 @@ class VariablePool:
|
||||
|
||||
@staticmethod
|
||||
def transform_selector(selector):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
@@ -303,6 +347,16 @@ class VariablePool:
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
    """Return a lazy view over the variables stored under ``namespace``.

    Args:
        namespace: Key into ``self.variables`` (callers in this change pass
            ``"conv"`` and ``"sys"``).
        literal: Forwarded to ``LazyVariableDict``; selects literal values.

    Returns:
        A ``LazyVariableDict`` over that namespace, or over an empty mapping
        when the namespace does not exist.
    """
    namespace_vars = self.variables.get(namespace, {})
    return LazyVariableDict(namespace_vars, literal)
|
||||
|
||||
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
    """Return one lazy view per namespace other than ``"sys"`` and ``"conv"``.

    Args:
        literal: Forwarded to each ``LazyVariableDict``; selects literal values.

    Returns:
        Dict mapping namespace name -> ``LazyVariableDict`` over its variables.
    """
    lazy_outputs: dict[str, LazyVariableDict] = {}
    for ns, vars_dict in self.variables.items():
        # "sys" and "conv" are served separately through lazy_namespace().
        if ns in ("sys", "conv"):
            continue
        lazy_outputs[ns] = LazyVariableDict(vars_dict, literal)
    return lazy_outputs
|
||||
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -552,9 +552,9 @@ class BaseNode(ABC):
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.lazy_namespace("sys", literal=True),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
@@ -579,9 +579,9 @@ class BaseNode(ABC):
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
conv_var=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars()
|
||||
conv_var=variable_pool.lazy_namespace("conv"),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(),
|
||||
system_vars=variable_pool.lazy_namespace("sys")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -85,12 +84,7 @@ class LoopRuntime:
|
||||
|
||||
for variable in self.typed_config.cycle_vars:
|
||||
if variable.input_type == ValueInputType.VARIABLE:
|
||||
value = evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
value = self.variable_pool.get_value(variable.value)
|
||||
else:
|
||||
value = TypeTransformer.transform(variable.value, variable.type)
|
||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||
@@ -98,12 +92,7 @@ class LoopRuntime:
|
||||
**self.state
|
||||
)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
variable.name: self.variable_pool.get_value(variable.value)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
|
||||
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
|
||||
# Reuse cached bytes if already fetched
|
||||
if f.get_content():
|
||||
file_input.set_content(f.get_content())
|
||||
text = await svc._extract_document_text(file_input)
|
||||
text = await svc.extract_document_text(file_input)
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.models import knowledge_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
unique.append(doc)
|
||||
return unique
|
||||
|
||||
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
    """Reorder document chunks by relevance to ``query`` and keep the best ``top_k``.

    Args:
        query: Query string.
        docs: Document chunks to rerank.
        top_k: Maximum number of chunks to return.

    Returns:
        At most ``top_k`` of the original chunks, in descending order of
        relevance, each with the relevance score stored in ``metadata["score"]``.

    Raises:
        ValueError: If ``docs`` is empty or ``top_k`` is not positive.
        RuntimeError: If the reranking itself fails (original error chained).
    """
    # Validate first so bad input fails before the reranker model is loaded.
    if not docs:
        raise ValueError("retrieval chunks be empty")
    if top_k <= 0:
        raise ValueError("top_k must be a positive integer")
    reranker = self.get_reranker_model()
    try:
        # Convert to LangChain Document objects; metadata may be None.
        documents = [
            Document(page_content=doc.page_content, metadata=doc.metadata or {})
            for doc in docs
        ]

        reranked_docs = list(reranker.compress_documents(documents, query))
        reranked_docs.sort(
            key=lambda item: item.metadata.get("relevance_score", 0),
            reverse=True
        )

        # Group the original chunks by content so that each reranked item
        # claims exactly one chunk. Matching every chunk with equal content
        # would duplicate results and could return more than top_k entries.
        pending: dict[str, list[DocumentChunk]] = {}
        for doc in docs:
            pending.setdefault(doc.page_content, []).append(doc)

        result = []
        for item in reranked_docs[:top_k]:
            candidates = pending.get(item.page_content)
            if not candidates:
                continue
            doc = candidates.pop(0)
            if doc.metadata is None:
                # NOTE(review): assumes DocumentChunk.metadata is assignable — confirm.
                doc.metadata = {}
            # Same default as the sort key, so a missing score cannot raise.
            doc.metadata["score"] = item.metadata.get("relevance_score", 0)
            result.append(doc)
        return result
    except Exception as e:
        raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
|
||||
"""
|
||||
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
|
||||
rs = []
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
tasks = []
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
child_kb_config = kb_config.model_copy()
|
||||
child_kb_config.kb_id = child.id
|
||||
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
return rs
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
|
||||
)
|
||||
)
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
rs = []
|
||||
tasks = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
final_rs = await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
|
||||
)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
|
||||
@@ -4,32 +4,33 @@ from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||
|
||||
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@classmethod
def normalize_template(cls, template: str) -> str:
    """Rewrite ``{{ <digits>.<name> }}`` as ``{{ node["<digits>"].<name> }}``.

    Uses the module-level precompiled ``_NORMALIZE_PATTERN`` instead of
    compiling the regex on every call.

    Args:
        template: Expression or template text.

    Returns:
        The text with every numeric node reference converted to dict access
        on ``node``; text without such references is returned unchanged.
    """
    return _NORMALIZE_PATTERN.sub(
        r'{{ node["\1"].\2 }}',
        template
    )
|
||||
|
||||
@classmethod
|
||||
def evaluate(
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Safely evaluate an expression using workflow variables.
|
||||
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
expression = cls.normalize_template(expression)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
|
||||
|
||||
# Build context for evaluation
|
||||
context = {
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
}
|
||||
|
||||
context.update(conv_vars)
|
||||
context["nodes"] = node_outputs
|
||||
# context.update(conv_vars)
|
||||
# context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
|
||||
try:
|
||||
# simpleeval supports safe operations:
|
||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||
raise ValueError(f"Invalid expression syntax: {e}")
|
||||
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Syntax error: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||
raise ValueError(f"Expression evaluation failed: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate a boolean expression (for conditions).
|
||||
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""
|
||||
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
|
||||
list[str]: List of error messages. Empty if all names are valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
|
||||
errors.append(
|
||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||
)
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# Convenience functions
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any]
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate an expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
@@ -152,11 +152,11 @@ def evaluate_expression(
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate a boolean condition expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
模板渲染器
|
||||
Template Renderer
|
||||
|
||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||
Provides safe template rendering using Jinja2, supporting variable references
|
||||
and expressions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -10,11 +11,15 @@ from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class SafeUndefined(Undefined):
|
||||
"""访问未定义属性不会报错,返回空字符串"""
|
||||
"""Return empty string instead of raising error when accessing undefined variables"""
|
||||
__slots__ = ()
|
||||
|
||||
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
def __init__(self, strict: bool = True):
    """Initialize the renderer.

    Args:
        strict: Whether to enable strict mode. When true the Jinja2
            environment uses ``StrictUndefined`` (undefined variables raise);
            otherwise it uses ``SafeUndefined``.
    """
    self.strict = strict
    self.env = Environment(
        undefined=StrictUndefined if strict else SafeUndefined,
        autoescape=False  # Disable auto-escaping since we handle plain text instead of HTML
    )
|
||||
|
||||
@staticmethod
def normalize_template(template: str) -> str:
    """Normalize template syntax (convert numeric node references to dict access).

    ``{{ <digits>.<name> }}`` becomes ``{{ node["<digits>"].<name> }}``, using
    the module-level precompiled ``_NORMALIZE_PATTERN`` rather than compiling
    the regex on every call.

    Args:
        template: Template string.

    Returns:
        The normalized template; text without numeric references is unchanged.
    """
    return _NORMALIZE_PATTERN.sub(
        r'{{ node["\1"].\2 }}',
        template
    )
|
||||
@@ -53,24 +54,24 @@ class TemplateRenderer:
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
"""Render template
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
template: Template string
|
||||
conv_vars: Conversation variables
|
||||
node_outputs: Node outputs
|
||||
system_vars: System variables
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Rendered string
|
||||
|
||||
Raises:
|
||||
ValueError: 模板语法错误或变量未定义
|
||||
|
||||
ValueError: If template syntax is invalid or variables are undefined
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.render(
|
||||
@@ -80,122 +81,119 @@ class TemplateRenderer:
|
||||
... {}
|
||||
... )
|
||||
'Hello World!'
|
||||
|
||||
|
||||
>>> renderer.render(
|
||||
... "分析结果: {{node.analyze.output}}",
|
||||
... "Analysis result: {{node.analyze.output}}",
|
||||
... {},
|
||||
... {"analyze": {"output": "正面情绪"}},
|
||||
... {"analyze": {"output": "positive sentiment"}},
|
||||
... {}
|
||||
... )
|
||||
'分析结果: 正面情绪'
|
||||
'Analysis result: positive sentiment'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# Build namespace context
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
||||
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
|
||||
"node": node_outputs, # Node outputs: {{node.node_1.output}}
|
||||
"sys": system_vars, # System variables: {{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
# if conv_vars:
|
||||
# context.update(conv_vars)
|
||||
#
|
||||
# context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
template = self.normalize_template(template)
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
logger.error(f"Template syntax error: {template}, error: {e}")
|
||||
raise ValueError(f"Template syntax error: {e}")
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
logger.error(f"Undefined variable in template: {template}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
logger.error(f"Template rendering error: {template}, error: {e}")
|
||||
raise ValueError(f"Template rendering failed: {e}")
|
||||
|
||||
def validate(self, template: str) -> list[str]:
    """Validate template syntax.

    The template is only parsed, never rendered, so references to variables
    that do not exist are not reported here.

    Args:
        template: Template string.

    Returns:
        List of error messages (empty if the template parses).

    Examples:
        >>> renderer = TemplateRenderer()
        >>> renderer.validate("Hello {{conv.name}}!")
        []

        >>> renderer.validate("Hello {{conv.name")  # Missing closing tag
        ['Template syntax error: ...']
    """
    errors = []
    try:
        self.env.from_string(template)
    except TemplateSyntaxError as e:
        errors.append(f"Template syntax error: {e}")
    except Exception as e:
        # Best-effort validation: report unexpected failures instead of raising.
        errors.append(f"Template validation failed: {e}")
    return errors
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
# Global renderer instances (strict / lenient)
|
||||
_strict_renderer = TemplateRenderer(strict=True)
|
||||
_lenient_renderer = TemplateRenderer(strict=False)
|
||||
|
||||
|
||||
def render_template(
        template: str,
        conv_vars: dict[str, Any] | LazyVariableDict,
        node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
        system_vars: dict[str, Any] | LazyVariableDict,
        strict: bool = True
) -> str:
    """Render a template (convenience function).

    Delegates to one of the two module-level renderer instances so no
    renderer is constructed per call.

    Args:
        template: Template string.
        conv_vars: Conversation variables, exposed as ``conv``.
        node_outputs: Node outputs, exposed as ``node``.
        system_vars: System variables, exposed as ``sys``.
        strict: Whether to use the strict renderer; otherwise the lenient one.

    Returns:
        Rendered string.

    Examples:
        >>> render_template(
        ...     "Analyze: {{conv.text}}",
        ...     {"text": "This is a text"},
        ...     {},
        ...     {}
        ... )
        'Analyze: This is a text'
    """
    renderer = _strict_renderer if strict else _lenient_renderer
    return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
    """Validate template syntax (convenience function).

    Uses the module-level strict renderer instance.

    Args:
        template: Template string.

    Returns:
        List of error messages; empty when the template parses.
    """
    return _strict_renderer.validate(template)
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Annotated
|
||||
from typing import Any, Annotated, Literal
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import desc, select
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import (
|
||||
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
stmt = select(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).limit(limit).offset(offset).all()
|
||||
).limit(limit).offset(offset)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_conversation_id(
        self,
        conversation_id: uuid.UUID,
        status: Literal["running", "completed", "failed"] | None = None,
        limit_count: int = 50
) -> list[WorkflowExecution]:
    """Return the executions of a conversation, most recently started first.

    Args:
        conversation_id: Conversation ID.
        status: Optional status filter; when falsy, all statuses are returned.
        limit_count: Maximum number of records to return.

    Returns:
        List of execution records ordered by ``started_at`` descending.
    """
    stmt = select(WorkflowExecution).filter(
        WorkflowExecution.conversation_id == conversation_id
    )
    if status:
        stmt = stmt.filter(WorkflowExecution.status == status)
    # Order before limiting so the cap keeps the newest executions.
    stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
    return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||
"""统计应用的执行次数
|
||||
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表(按执行顺序排序)
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.execution_order
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_node_id(
|
||||
self,
|
||||
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id,
|
||||
WorkflowNodeExecution.node_id == node_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.retry_count
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
@@ -438,13 +438,13 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -542,7 +542,7 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
async def extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
|
||||
@@ -561,6 +561,24 @@ class WorkflowService:
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
|
||||
executions = self.execution_repo.get_by_conversation_id(
|
||||
conversation_id=conversation_id,
|
||||
status="completed",
|
||||
limit_count=1
|
||||
)
|
||||
|
||||
if executions:
|
||||
last_state = executions[0].output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
# input_data["conv"] = conv_vars
|
||||
# input_data["conv_messages"] = last_state.get("messages") or []
|
||||
conv_messages = last_state.get("messages") or []
|
||||
return conv_vars, conv_messages
|
||||
return None
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -634,18 +652,11 @@ class WorkflowService:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
result = await execute_workflow(
|
||||
@@ -807,17 +818,11 @@ class WorkflowService:
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
message_id = uuid.uuid4()
|
||||
async for event in execute_workflow_stream(
|
||||
|
||||
Reference in New Issue
Block a user