fix(workflow): enable nested search in knowledge base retrieval node
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
@@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the knowledge retrieval workflow node.
|
||||
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
query = self._render_template(self.typed_config.query, variable_pool)
|
||||
with get_db_read() as db:
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
||||
|
||||
if not existing_ids:
|
||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||
|
||||
rs = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
|
||||
# Deduplicate hy brid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
continue
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
key=lambda d: d.metadata.get("score", 0),
|
||||
reverse=True
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
|
||||
Reference in New Issue
Block a user