perf(workflow): expose extract_document_text as instance method, optimize knowledge base parallel search
- Change extract_document_text from private to instance method in multimodal service for external access - Optimize knowledge base search logic to improve parallel retrieval performance
This commit is contained in:
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
|
||||
# Reuse cached bytes if already fetched
|
||||
if f.get_content():
|
||||
file_input.set_content(f.get_content())
|
||||
text = await svc._extract_document_text(file_input)
|
||||
text = await svc.extract_document_text(file_input)
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.models import knowledge_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
unique.append(doc)
|
||||
return unique
|
||||
|
||||
def _get_existing_kb_ids(self, db, kb_ids):
|
||||
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Resolve all accessible and valid knowledge base IDs for retrieval.
|
||||
|
||||
This includes:
|
||||
- Private knowledge bases owned by the user
|
||||
- Shared knowledge bases
|
||||
- Source knowledge bases mapped via knowledge sharing relationships
|
||||
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
db: Database session.
|
||||
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
|
||||
query: query string
|
||||
docs: List of document chunk to be rearranged
|
||||
top_k: The number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
list[UUID]: Final list of valid knowledge base IDs.
|
||||
Rearranged document chunk list (sorted in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
|
||||
|
||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
|
||||
|
||||
share_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
|
||||
reranker = self.get_reranker_model()
|
||||
# parameter validation
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
return existing_ids
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
|
||||
"""
|
||||
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
|
||||
rs = []
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
tasks = []
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
child_kb_config = kb_config.model_copy()
|
||||
child_kb_config.kb_id = child.id
|
||||
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
return rs
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
|
||||
)
|
||||
)
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
rs = []
|
||||
tasks = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
final_rs = await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
|
||||
)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
|
||||
@@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False)
|
||||
def render_template(
|
||||
template: str,
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict,
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
|
||||
@@ -438,13 +438,13 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -542,7 +542,7 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
async def extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
|
||||
Reference in New Issue
Block a user