fix(db): fix database connection handling

This commit is contained in:
mengyonghao
2025-12-24 12:21:12 +08:00
parent b99671e04a
commit 38220006a6

View File

@@ -20,78 +20,74 @@ class KnowledgeRetrievalNode(BaseNode):
async def execute(self, state: WorkflowState) -> Any: async def execute(self, state: WorkflowState) -> Any:
query = self._render_template(self.typed_config.query, state) query = self._render_template(self.typed_config.query, state)
db_gen = get_db() db = next(get_db())
db = next(db_gen) filters = [
try: knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
existing_ids = knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [ filters = [
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
] ]
existing_ids = knowledge_repository.get_chunked_knowledgeids( items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db, db=db,
filters=filters filters=filters
) )
filters = [ existing_ids.extend(items)
knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids),
knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share,
knowledge_model.Knowledge.chunk_num > 0,
knowledge_model.Knowledge.status == 1
]
share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids(
db=db,
filters=filters
)
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids)
]
items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db,
filters=filters
)
existing_ids.extend(items)
if not existing_ids: if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
kb_id = existing_ids[0] kb_id = existing_ids[0]
uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids]
indices = ",".join(uuid_strs) indices = ",".join(uuid_strs)
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id)
if not db_knowledge: if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.") raise RuntimeError("The knowledge base does not exist or access is denied.")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
match self.typed_config.retrieve_type: match self.typed_config.retrieve_type:
case RetrieveType.PARTICIPLE: case RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.similarity_threshold) score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs] return [chunk.model_dump() for chunk in rs]
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
return [chunk.model_dump() for chunk in rs]
case _:
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k,
indices=indices,
score_threshold=self.typed_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k,
indices=indices, indices=indices,
score_threshold=self.typed_config.vector_similarity_weight) score_threshold=self.typed_config.similarity_threshold)
return [chunk.model_dump() for chunk in rs] # Efficient deduplication
case _: seen_ids = set()
rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, unique_rs = []
indices=indices, for doc in rs1 + rs2:
score_threshold=self.typed_config.vector_similarity_weight) if doc.metadata["doc_id"] not in seen_ids:
rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, seen_ids.add(doc.metadata["doc_id"])
indices=indices, unique_rs.append(doc)
score_threshold=self.typed_config.similarity_threshold) rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
# Efficient deduplication return [chunk.model_dump() for chunk in rs]
seen_ids = set()
unique_rs = []
for doc in rs1 + rs2:
if doc.metadata["doc_id"] not in seen_ids:
seen_ids.add(doc.metadata["doc_id"])
unique_rs.append(doc)
rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k)
return [chunk.model_dump() for chunk in rs]
finally:
next(db_gen)