fix(db): fix database connection handling
This commit is contained in:
@@ -20,78 +20,74 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
|
|
||||||
async def execute(self, state: WorkflowState) -> Any:
|
async def execute(self, state: WorkflowState) -> Any:
|
||||||
query = self._render_template(self.typed_config.query, state)
|
query = self._render_template(self.typed_config.query, state)
|
||||||
db_gen = get_db()
|
db = next(get_db())
|
||||||
db = next(db_gen)
|
filters = [
|
||||||
try:
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
filters = [
|
||||||
|
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
||||||
|
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
||||||
|
knowledge_model.Knowledge.chunk_num > 0,
|
||||||
|
knowledge_model.Knowledge.status == 1
|
||||||
|
]
|
||||||
|
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
||||||
|
db=db,
|
||||||
|
filters=filters
|
||||||
|
)
|
||||||
|
if share_ids:
|
||||||
filters = [
|
filters = [
|
||||||
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
|
||||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
|
|
||||||
knowledge_model.Knowledge.chunk_num > 0,
|
|
||||||
knowledge_model.Knowledge.status == 1
|
|
||||||
]
|
]
|
||||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||||
db=db,
|
db=db,
|
||||||
filters=filters
|
filters=filters
|
||||||
)
|
)
|
||||||
filters = [
|
existing_ids.extend(items)
|
||||||
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
|
|
||||||
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
|
|
||||||
knowledge_model.Knowledge.chunk_num > 0,
|
|
||||||
knowledge_model.Knowledge.status == 1
|
|
||||||
]
|
|
||||||
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
|
|
||||||
db=db,
|
|
||||||
filters=filters
|
|
||||||
)
|
|
||||||
if share_ids:
|
|
||||||
filters = [
|
|
||||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
|
|
||||||
]
|
|
||||||
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
|
||||||
db=db,
|
|
||||||
filters=filters
|
|
||||||
)
|
|
||||||
existing_ids.extend(items)
|
|
||||||
|
|
||||||
if not existing_ids:
|
if not existing_ids:
|
||||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
||||||
|
|
||||||
kb_id = existing_ids[0]
|
kb_id = existing_ids[0]
|
||||||
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
|
||||||
indices = ",".join(uuid_strs)
|
indices = ",".join(uuid_strs)
|
||||||
|
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
match self.typed_config.retrieve_type:
|
match self.typed_config.retrieve_type:
|
||||||
case RetrieveType.PARTICIPLE:
|
case RetrieveType.PARTICIPLE:
|
||||||
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=self.typed_config.similarity_threshold)
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
return [chunk.model_dump() for chunk in rs]
|
return [chunk.model_dump() for chunk in rs]
|
||||||
case RetrieveType.SEMANTIC:
|
case RetrieveType.SEMANTIC:
|
||||||
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
return [chunk.model_dump() for chunk in rs]
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=self.typed_config.vector_similarity_weight)
|
||||||
|
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
||||||
indices=indices,
|
indices=indices,
|
||||||
score_threshold=self.typed_config.vector_similarity_weight)
|
score_threshold=self.typed_config.similarity_threshold)
|
||||||
return [chunk.model_dump() for chunk in rs]
|
# Efficient deduplication
|
||||||
case _:
|
seen_ids = set()
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
|
unique_rs = []
|
||||||
indices=indices,
|
for doc in rs1 + rs2:
|
||||||
score_threshold=self.typed_config.vector_similarity_weight)
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
indices=indices,
|
unique_rs.append(doc)
|
||||||
score_threshold=self.typed_config.similarity_threshold)
|
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
||||||
# Efficient deduplication
|
return [chunk.model_dump() for chunk in rs]
|
||||||
seen_ids = set()
|
|
||||||
unique_rs = []
|
|
||||||
for doc in rs1 + rs2:
|
|
||||||
if doc.metadata["doc_id"] not in seen_ids:
|
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
|
||||||
unique_rs.append(doc)
|
|
||||||
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
|
|
||||||
return [chunk.model_dump() for chunk in rs]
|
|
||||||
finally:
|
|
||||||
next(db_gen)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user