perf(workflow): expose extract_document_text as instance method, optimize knowledge base parallel search

- Change extract_document_text from private to instance method in multimodal service for external access
- Optimize knowledge base search logic to improve parallel retrieval performance
This commit is contained in:
Eternity
2026-03-27 12:02:36 +08:00
parent 7fd00009a2
commit bca43fcc75
4 changed files with 120 additions and 65 deletions

View File

@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched # Reuse cached bytes if already fetched
if f.get_content(): if f.get_content():
file_input.set_content(f.get_content()) file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input) text = await svc.extract_document_text(file_input)
chunks.append(text) chunks.append(text)
except Exception as e: except Exception as e:
logger.error( logger.error(

View File

@@ -1,19 +1,23 @@
import asyncio
import logging import logging
import uuid import uuid
from typing import Any from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc) unique.append(doc)
return unique return unique
def _get_existing_kb_ids(self, db, kb_ids): def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
""" """
Resolve all accessible and valid knowledge base IDs for retrieval. Reorder the list of document blocks and return the top_k results most relevant to the query
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Args: Args:
db: Database session. query: query string
kb_ids (list[UUID]): Knowledge base IDs from node configuration. docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
Returns: Returns:
list[UUID]: Final list of valid knowledge base IDs. Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
""" """
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) reranker = self.get_reranker_model()
# parameter validation
existing_ids = knowledge_repository.get_chunked_knowledgeids( if not docs:
db=db, raise ValueError("retrieval chunks be empty")
filters=filters if top_k <= 0:
) raise ValueError("top_k must be a positive integer")
try:
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) # Convert to LangChain Document object
documents = [
share_ids = knowledge_repository.get_chunked_knowledgeids( Document(
db=db, page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
filters=filters metadata=doc.metadata or {} # Deal with possible None metadata
) )
for doc in docs
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
] ]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db, # Perform reordering (compress_documents will automatically handle relevance scores and indexing)
filters=filters reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
) )
existing_ids.extend(items) # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
return existing_ids result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank: def get_reranker_model(self) -> RedBearRerank:
""" """
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
) )
return reranker return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children: for child in children:
if not (child and child.chunk_num > 0 and child.status == 1): if not (child and child.chunk_num > 0 and child.status == 1):
continue continue
kb_config.kb_id = child.id child_kb_config = kb_config.model_copy()
self.knowledge_retrieval(db, query, rs, child, kb_config) child_kb_config.kb_id = child.id
return tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower() indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type: match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE: case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.similarity_threshold)) vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.vector_similarity_weight)) vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID: case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs1_task = asyncio.to_thread(
indices=indices, vector_service.search_by_vector, **{
score_threshold=kb_config.vector_similarity_weight) "query": query,
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, "top_k": kb_config.top_k,
indices=indices, "indices": indices,
score_threshold=kb_config.similarity_threshold) "score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs: if not unique_rs:
return return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() rs.extend(
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else: else:
rs.extend(sorted( rs.extend(sorted(
unique_rs, unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k]) )[:kb_config.top_k])
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases knowledge_bases = self.typed_config.knowledge_bases
rs = [] rs = []
tasks = []
for kb_config in knowledge_bases: for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge: if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.") raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
if not rs: if not rs:
return [] return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() final_rs = await asyncio.to_thread(
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else: else:
final_rs = sorted( final_rs = sorted(
rs, rs,

View File

@@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
template: str, template: str,
conv_vars: dict[str, Any] | LazyVariableDict, conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any] | LazyVariableDict, node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | LazyVariableDict, system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True strict: bool = True
) -> str: ) -> str:

View File

@@ -438,13 +438,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL: if file.transfer_method == TransferMethod.REMOTE_URL:
return True, { return True, {
"type": "text", "type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>" "text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
} }
else: else:
# 本地文件,提取文本内容 # 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file) text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter( file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id FileMetadata.id == file.upload_file_id
).first() ).first()
@@ -542,7 +542,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}" return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str: async def extract_document_text(self, file: FileInput) -> str:
""" """
提取文档文本内容 提取文档文本内容