From e849fed5c174d9b1d16a5335cc08288a7e1db821 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Thu, 19 Mar 2026 19:46:20 +0800 Subject: [PATCH] fix(workflow): enable nested search in knowledge base retrieval node --- api/app/core/workflow/nodes/knowledge/node.py | 90 ++++++++++--------- 1 file changed, 49 insertions(+), 41 deletions(-) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 14f789a9..d3e9efd9 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -5,7 +5,7 @@ from typing import Any from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode @@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]): super().__init__(node_config, workflow_config) self.typed_config: KnowledgeRetrievalNodeConfig | None = None + self.vector_service: ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker + def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: + children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + for child in children: + if not (child and child.chunk_num > 0 and child.status == 1): + continue + kb_config.kb_id = child.id + self.knowledge_retrieval(db, query, rs, child, kb_config) + return + self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + indices = f"Vector_index_{kb_config.kb_id}_Node".lower() + match kb_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold)) + case RetrieveType.SEMANTIC: + rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight)) + case RetrieveType.HYBRID: + rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.vector_similarity_weight) + rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, + indices=indices, + score_threshold=kb_config.similarity_threshold) + + # Deduplicate hybrid retrieval results + unique_rs = self._deduplicate_docs(rs1, rs2) + if not unique_rs: + return + if self.typed_config.reranker_id: + self.vector_service.reranker = self.get_reranker_model() + rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + else: + rs.extend(sorted( + unique_rs, + key=lambda d: d.metadata.get("score", 0), + reverse=True + )[:kb_config.top_k]) + case _: + raise RuntimeError("Unknown retrieval type") + async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ Execute the knowledge retrieval workflow node. @@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode): query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases - existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases]) - - if not existing_ids: - raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") rs = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") + self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - indices = f"Vector_index_{kb_config.kb_id}_Node".lower() - match kb_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) - case RetrieveType.SEMANTIC: - rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) - case RetrieveType.HYBRID: - rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) - - # Deduplicate hy brid retrieval results - unique_rs = self._deduplicate_docs(rs1, rs2) - if not unique_rs: - continue - if self.typed_config.reranker_id: - vector_service.reranker = self.get_reranker_model() - rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) - else: - rs.extend(sorted( - unique_rs, - key=lambda d: d.metadata.get("score", 0), - reverse=True - )[:kb_config.top_k]) - case _: - raise RuntimeError("Unknown retrieval type") if not rs: return [] if self.typed_config.reranker_id: - vector_service.reranker = self.get_reranker_model() - final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + self.vector_service.reranker = self.get_reranker_model() + final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) else: final_rs = sorted( rs,