Merge pull request #633 from SuanmoSuanyangTechnology/fix/knowledge-retrieval
fix(workflow): enable nested search in knowledge base retrieval node
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
@@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
return reranker
|
return reranker
|
||||||
|
|
||||||
|
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
kb_config.kb_id = child.id
|
||||||
|
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||||
|
return
|
||||||
|
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||||
|
match kb_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold))
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight))
|
||||||
|
case RetrieveType.HYBRID:
|
||||||
|
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight)
|
||||||
|
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold)
|
||||||
|
|
||||||
|
# Deduplicate hybrid retrieval results
|
||||||
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
|
if not unique_rs:
|
||||||
|
return
|
||||||
|
if self.typed_config.reranker_id:
|
||||||
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
|
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
|
else:
|
||||||
|
rs.extend(sorted(
|
||||||
|
unique_rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:kb_config.top_k])
|
||||||
|
case _:
|
||||||
|
raise RuntimeError("Unknown retrieval type")
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the knowledge retrieval workflow node.
|
Execute the knowledge retrieval workflow node.
|
||||||
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
query = self._render_template(self.typed_config.query, variable_pool)
|
query = self._render_template(self.typed_config.query, variable_pool)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
|
||||||
|
|
||||||
if not existing_ids:
|
|
||||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
|
||||||
|
|
||||||
rs = []
|
rs = []
|
||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
|
||||||
match kb_config.retrieve_type:
|
|
||||||
case RetrieveType.PARTICIPLE:
|
|
||||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold))
|
|
||||||
case RetrieveType.SEMANTIC:
|
|
||||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
|
||||||
case RetrieveType.HYBRID:
|
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight)
|
|
||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold)
|
|
||||||
|
|
||||||
# Deduplicate hy brid retrieval results
|
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
|
||||||
if not unique_rs:
|
|
||||||
continue
|
|
||||||
if self.typed_config.reranker_id:
|
|
||||||
vector_service.reranker = self.get_reranker_model()
|
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
|
||||||
else:
|
|
||||||
rs.extend(sorted(
|
|
||||||
unique_rs,
|
|
||||||
key=lambda d: d.metadata.get("score", 0),
|
|
||||||
reverse=True
|
|
||||||
)[:kb_config.top_k])
|
|
||||||
case _:
|
|
||||||
raise RuntimeError("Unknown retrieval type")
|
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
vector_service.reranker = self.get_reranker_model()
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
else:
|
else:
|
||||||
final_rs = sorted(
|
final_rs = sorted(
|
||||||
rs,
|
rs,
|
||||||
|
|||||||
Reference in New Issue
Block a user