diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 72e8750f..97ebaa82 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -4,7 +4,7 @@ from typing import Any from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.nodes.base_node import BaseNode, WorkflowState from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig -from app.db import get_db +from app.db import get_db_context from app.models import knowledge_model, knowledgeshare_model from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType @@ -20,9 +20,7 @@ class KnowledgeRetrievalNode(BaseNode): async def execute(self, state: WorkflowState) -> Any: query = self._render_template(self.typed_config.query, state) - db_gen = get_db() - db = next(db_gen) - try: + with get_db_context() as db: filters = [ knowledge_model.Knowledge.id.in_(self.typed_config.kb_ids), knowledge_model.Knowledge.permission_id == knowledge_model.PermissionType.Private, @@ -93,5 +91,3 @@ class KnowledgeRetrievalNode(BaseNode): unique_rs.append(doc) rs = vector_service.rerank(query=query, docs=unique_rs, top_k=self.typed_config.top_k) return [chunk.model_dump() for chunk in rs] - finally: - next(db_gen)