From 304ccef1016a7e42ba1199fd32fe7133bb7d6a9d Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Tue, 3 Mar 2026 12:30:09 +0800 Subject: [PATCH] chore(api): organize imports and refactor database context management --- .gitignore | 1 + api/app/tasks.py | 39 +++++++++++++-------------------------- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/.gitignore b/.gitignore index 2fb41537..66d1beb2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ search_results.json api/migrations/versions tmp files +powers/ # Exclude dep files huggingface.co/ diff --git a/api/app/tasks.py b/api/app/tasks.py index 8e3aea85..299d188b 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,16 +1,16 @@ import asyncio -from concurrent.futures import ThreadPoolExecutor import json import os import re +import shutil import time import uuid -from uuid import UUID +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from math import ceil from pathlib import Path -import shutil from typing import Any, Dict, List, Optional +from uuid import UUID import redis import requests @@ -38,7 +38,7 @@ from app.db import get_db, get_db_context from app.models.document_model import Document from app.models.file_model import File from app.models.knowledge_model import Knowledge -from app.schemas import file_schema, document_schema +from app.schemas import document_schema, file_schema from app.services.memory_agent_service import MemoryAgentService from app.utils.config_utils import resolve_config_id @@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID): Document parsing, vectorization, and storage """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_document = None @@ -297,8 +298,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): build knowledge graph """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_documents = None @@ -932,24 +934,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: service = MemoryAgentService() return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1049,19 +1045,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: logger.info( f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() @@ -1069,11 +1061,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result - except Exception as e: - logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True) - raise - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1328,9 +1315,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: async def _run() -> Dict[str, Any]: from app.core.logging_config import get_api_logger - from app.models.workspace_model import Workspace from app.models.app_model import App from app.models.end_user_model import EndUser + from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all