From 38220006a6ec5b39a4d00f67f1d0ae976f6468da Mon Sep 17 00:00:00 2001 From: mengyonghao <1533512157@qq.com> Date: Wed, 24 Dec 2025 12:21:12 +0800 Subject: [PATCH] fix(db): fix database connection handling --- api/app/core/workflow/nodes/knowledge/node.py | 124 +++++++++--------- 1 file changed, 60 insertions(+), 64 deletions(-) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 72e8750f..a9a76743 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -20,78 +20,74 @@ class KnowledgeRetrievalNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: query = self._render_template(self.typed_config.query, state) - db_gen = get_db() - db = next(db_gen) - try: + db = next(get_db()) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + existing_ids = knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + filters = [ + knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), + knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, + knowledge_model.Knowledge.chunk_num > 0, + knowledge_model.Knowledge.status == 1 + ] + share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( + db=db, + filters=filters + ) + if share_ids: filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 + knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) ] - existing_ids = knowledge_repository.get_chunked_knowledgeids( + items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( db=db, filters=filters ) - filters = [ - knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), - knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Share, - knowledge_model.Knowledge.chunk_num > 0, - knowledge_model.Knowledge.status == 1 - ] - share_ids = knowledge_service.knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(self.typed_config.kb_ids) - ] - items = knowledgeshare_service.knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters - ) - existing_ids.extend(items) + existing_ids.extend(items) - if not existing_ids: - raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") + if not existing_ids: + raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.") - kb_id = existing_ids[0] - uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] - indices = ",".join(uuid_strs) + kb_id = existing_ids[0] + uuid_strs = [f"Vector_index_{kb_id}_Node".lower() for kb_id in existing_ids] + indices = ",".join(uuid_strs) - db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) - if not db_knowledge: - raise RuntimeError("The knowledge base does not exist or access is denied.") + db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_id) + if not db_knowledge: + raise RuntimeError("The knowledge base does not exist or access is denied.") - vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) - match self.typed_config.retrieve_type: - case RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - return [chunk.model_dump() for chunk in rs] - case RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + match self.typed_config.retrieve_type: + case RetrieveType.PARTICIPLE: + rs = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.similarity_threshold) + return [chunk.model_dump() for chunk in rs] + case RetrieveType.SEMANTIC: + rs = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + return [chunk.model_dump() for chunk in rs] + case _: + rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, + indices=indices, + score_threshold=self.typed_config.vector_similarity_weight) + rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - return [chunk.model_dump() for chunk in rs] - case _: - rs1 = vector_service.search_by_vector(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=query, top_k=self.typed_config.top_k, - indices=indices, - score_threshold=self.typed_config.similarity_threshold) - # Efficient deduplication - seen_ids = set() - unique_rs = [] - for doc in rs1 + rs2: - if doc.metadata["doc_id"] not in seen_ids: - seen_ids.add(doc.metadata["doc_id"]) - unique_rs.append(doc) - rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) - return [chunk.model_dump() for chunk in rs] - finally: - next(db_gen) + score_threshold=self.typed_config.similarity_threshold) + # Efficient deduplication + seen_ids = set() + unique_rs = [] + for doc in rs1 + rs2: + if doc.metadata["doc_id"] not in seen_ids: + seen_ids.add(doc.metadata["doc_id"]) + unique_rs.append(doc) + rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) + return [chunk.model_dump() for chunk in rs]