diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index 40641f3c..bd828760 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode): # Reuse cached bytes if already fetched if f.get_content(): file_input.set_content(f.get_content()) - text = await svc._extract_document_text(file_input) + text = await svc.extract_document_text(file_input) chunks.append(text) except Exception as e: logger.error( diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 92699cb4..d0b6d098 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -1,19 +1,23 @@ +import asyncio import logging import uuid from typing import Any +from langchain_core.documents import Document + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_read -from app.models import knowledge_model, knowledgeshare_model, ModelType -from app.repositories import knowledge_repository, knowledgeshare_repository +from app.models import knowledge_model, ModelType +from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType from app.services.model_service import ModelConfigService @@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None - self.vector_service: ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode): unique.append(doc) return unique - def _get_existing_kb_ids(self, db, kb_ids): + def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: """ - Resolve all accessible and valid knowledge base IDs for retrieval. - - This includes: - - Private knowledge bases owned by the user - - Shared knowledge bases - - Source knowledge bases mapped via knowledge sharing relationships - + Reorder the list of document blocks and return the top_k results most relevant to the query Args: - db: Database session. - kb_ids (list[UUID]): Knowledge base IDs from node configuration. + query: query string + docs: List of document chunk to be rearranged + top_k: The number of top-level documents returned Returns: - list[UUID]: Final list of valid knowledge base IDs. + Rearranged document chunk list (sorted in descending order of relevance) + + Raises: + ValueError: If the input document list is empty or top_k is invalid """ - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) - - existing_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) - - share_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids) + reranker = self.get_reranker_model() + # parameter validation + if not docs: + raise ValueError("retrieval chunks be empty") + if top_k <= 0: + raise ValueError("top_k must be a positive integer") + try: + # Convert to LangChain Document object + documents = [ + Document( + page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute + metadata=doc.metadata or {} # Deal with possible None metadata + ) + for doc in docs ] - items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters + + # Perform reordering (compress_documents will automatically handle relevance scores and indexing) + reranked_docs = list(reranker.compress_documents(documents, query)) + + # Sort in descending order based on relevance score + reranked_docs.sort( + key=lambda x: x.metadata.get("relevance_score", 0), + reverse=True ) - existing_ids.extend(items) - return existing_ids + # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] + result = [] + for item in reranked_docs[:top_k]: + for doc in docs: + if doc.page_content == item.page_content: + doc.metadata["score"] = item.metadata["relevance_score"] + result.append(doc) + return result + except Exception as e: + raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e def get_reranker_model(self) -> RedBearRerank: """ @@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + async def knowledge_retrieval(self, db, query, db_knowledge, kb_config): + rs = [] if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + tasks = [] for child in children: if not (child and child.chunk_num > 0 and child.status == 1): continue - kb_config.kb_id = child.id - self.knowledge_retrieval(db, query, rs, child, kb_config) - return - self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + child_kb_config = kb_config.model_copy() + child_kb_config.kb_id = child.id + tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) + return rs + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) indices = f"Vector_index_{kb_config.kb_id}_Node".lower() match kb_config.retrieve_type: case RetrieveType.PARTICIPLE: - rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + ) case RetrieveType.SEMANTIC: - rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + ) case RetrieveType.HYBRID: - rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) + rs1_task = asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + rs2_task = asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) if not unique_rs: - return + return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + rs.extend( + await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k} + ) + ) else: rs.extend(sorted( unique_rs, @@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode): )[:kb_config.top_k]) case _: raise RuntimeError("Unknown retrieval type") + return rs async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ @@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode): knowledge_bases = self.typed_config.knowledge_bases rs = [] + tasks = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") - self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) + tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) if not rs: return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + final_rs = await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k} + ) else: final_rs = sorted( rs, diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index bb1e18bf..2c2d0f67 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, conv_vars: dict[str, Any] | LazyVariableDict, - node_outputs: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 4cf3d89d..120cccb7 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -438,13 +438,13 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return True, { "type": "text", - "text": f"\n{await self._extract_document_text(file)}\n" + "text": f"\n{await self.extract_document_text(file)}\n" } else: # 本地文件,提取文本内容 server_url = settings.FILE_LOCAL_SERVER_URL file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" - text = await self._extract_document_text(file) + text = await self.extract_document_text(file) file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file.upload_file_id ).first() @@ -542,7 +542,7 @@ class MultimodalService: server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" - async def _extract_document_text(self, file: FileInput) -> str: + async def extract_document_text(self, file: FileInput) -> str: """ 提取文档文本内容