From bca43fcc75e41efa68f4fd997ee8acb0587c334d Mon Sep 17 00:00:00 2001
From: Eternity <1533512157@qq.com>
Date: Fri, 27 Mar 2026 12:02:36 +0800
Subject: [PATCH] perf(workflow): expose extract_document_text as instance
method, optimize knowledge base parallel search
- Change extract_document_text from private to instance method in multimodal service for external access
- Optimize knowledge base search logic to improve parallel retrieval performance
---
.../workflow/nodes/document_extractor/node.py | 2 +-
api/app/core/workflow/nodes/knowledge/node.py | 175 ++++++++++++------
.../core/workflow/utils/template_renderer.py | 2 +-
api/app/services/multimodal_service.py | 6 +-
4 files changed, 120 insertions(+), 65 deletions(-)
diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py
index 40641f3c..bd828760 100644
--- a/api/app/core/workflow/nodes/document_extractor/node.py
+++ b/api/app/core/workflow/nodes/document_extractor/node.py
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
- text = await svc._extract_document_text(file_input)
+ text = await svc.extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py
index 92699cb4..d0b6d098 100644
--- a/api/app/core/workflow/nodes/knowledge/node.py
+++ b/api/app/core/workflow/nodes/knowledge/node.py
@@ -1,19 +1,23 @@
+import asyncio
import logging
import uuid
from typing import Any
+from langchain_core.documents import Document
+
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
-from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
+from app.core.rag.models.chunk import DocumentChunk
+from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read
-from app.models import knowledge_model, knowledgeshare_model, ModelType
-from app.repositories import knowledge_repository, knowledgeshare_repository
+from app.models import knowledge_model, ModelType
+from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
- self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc)
return unique
- def _get_existing_kb_ids(self, db, kb_ids):
+ def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
"""
- Resolve all accessible and valid knowledge base IDs for retrieval.
-
- This includes:
- - Private knowledge bases owned by the user
- - Shared knowledge bases
- - Source knowledge bases mapped via knowledge sharing relationships
-
+ Reorder the list of document blocks and return the top_k results most relevant to the query
Args:
- db: Database session.
- kb_ids (list[UUID]): Knowledge base IDs from node configuration.
+ query: query string
+ docs: List of document chunk to be rearranged
+ top_k: The number of top-level documents returned
Returns:
- list[UUID]: Final list of valid knowledge base IDs.
+ Rearranged document chunk list (sorted in descending order of relevance)
+
+ Raises:
+ ValueError: If the input document list is empty or top_k is invalid
"""
- filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
-
- existing_ids = knowledge_repository.get_chunked_knowledgeids(
- db=db,
- filters=filters
- )
-
- filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
-
- share_ids = knowledge_repository.get_chunked_knowledgeids(
- db=db,
- filters=filters
- )
-
- if share_ids:
- filters = [
- knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
+ reranker = self.get_reranker_model()
+ # parameter validation
+ if not docs:
+ raise ValueError("retrieval chunks be empty")
+ if top_k <= 0:
+ raise ValueError("top_k must be a positive integer")
+ try:
+ # Convert to LangChain Document object
+ documents = [
+ Document(
+ page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
+ metadata=doc.metadata or {} # Deal with possible None metadata
+ )
+ for doc in docs
]
- items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
- db=db,
- filters=filters
+
+ # Perform reordering (compress_documents will automatically handle relevance scores and indexing)
+ reranked_docs = list(reranker.compress_documents(documents, query))
+
+ # Sort in descending order based on relevance score
+ reranked_docs.sort(
+ key=lambda x: x.metadata.get("relevance_score", 0),
+ reverse=True
)
- existing_ids.extend(items)
- return existing_ids
+ # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
+ result = []
+ for item in reranked_docs[:top_k]:
+ for doc in docs:
+ if doc.page_content == item.page_content:
+ doc.metadata["score"] = item.metadata["relevance_score"]
+ result.append(doc)
+ return result
+ except Exception as e:
+ raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank:
"""
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
- def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
+ async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
+ rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
+ tasks = []
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
- kb_config.kb_id = child.id
- self.knowledge_retrieval(db, query, rs, child, kb_config)
- return
- self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
+ child_kb_config = kb_config.model_copy()
+ child_kb_config.kb_id = child.id
+ tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
+ if tasks:
+ result = await asyncio.gather(*tasks)
+ for _ in result:
+ rs.extend(_)
+ return rs
+ vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
- rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.similarity_threshold))
+ rs.extend(
+ await asyncio.to_thread(
+ vector_service.search_by_full_text, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.similarity_threshold
+ }
+ )
+ )
case RetrieveType.SEMANTIC:
- rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.vector_similarity_weight))
+ rs.extend(
+ await asyncio.to_thread(
+ vector_service.search_by_vector, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.vector_similarity_weight
+ }
+ )
+ )
case RetrieveType.HYBRID:
- rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.vector_similarity_weight)
- rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
- indices=indices,
- score_threshold=kb_config.similarity_threshold)
+ rs1_task = asyncio.to_thread(
+ vector_service.search_by_vector, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.vector_similarity_weight
+ }
+ )
+ rs2_task = asyncio.to_thread(
+ vector_service.search_by_full_text, **{
+ "query": query,
+ "top_k": kb_config.top_k,
+ "indices": indices,
+ "score_threshold": kb_config.similarity_threshold
+ }
+ )
+ rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
- return
+ return []
if self.typed_config.reranker_id:
- self.vector_service.reranker = self.get_reranker_model()
- rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
+ rs.extend(
+ await asyncio.to_thread(
+ self.rerank,
+ **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
+ )
+ )
else:
rs.extend(sorted(
unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
+ return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases
rs = []
+ tasks = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
- self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
+ tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
+ if tasks:
+ result = await asyncio.gather(*tasks)
+ for _ in result:
+ rs.extend(_)
if not rs:
return []
if self.typed_config.reranker_id:
- self.vector_service.reranker = self.get_reranker_model()
- final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
+ final_rs = await asyncio.to_thread(
+ self.rerank,
+ **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
+ )
else:
final_rs = sorted(
rs,
diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py
index bb1e18bf..2c2d0f67 100644
--- a/api/app/core/workflow/utils/template_renderer.py
+++ b/api/app/core/workflow/utils/template_renderer.py
@@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False)
def render_template(
template: str,
conv_vars: dict[str, Any] | LazyVariableDict,
- node_outputs: dict[str, Any] | LazyVariableDict,
+ node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True
) -> str:
diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py
index 4cf3d89d..120cccb7 100644
--- a/api/app/services/multimodal_service.py
+++ b/api/app/services/multimodal_service.py
@@ -438,13 +438,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL:
return True, {
"type": "text",
- "text": f"\n{await self._extract_document_text(file)}\n"
+ "text": f"\n{await self.extract_document_text(file)}\n"
}
else:
# 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
- text = await self._extract_document_text(file)
+ text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id
).first()
@@ -542,7 +542,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}"
- async def _extract_document_text(self, file: FileInput) -> str:
+ async def extract_document_text(self, file: FileInput) -> str:
"""
提取文档文本内容