chore(api): organize imports and refactor database context management

This commit is contained in:
Ke Sun
2026-03-03 12:30:09 +08:00
parent bdc22c892d
commit 304ccef101
2 changed files with 14 additions and 26 deletions

1
.gitignore vendored
View File

@@ -29,6 +29,7 @@ search_results.json
api/migrations/versions
tmp
files
powers/
# Exclude dep files
huggingface.co/

View File

@@ -1,16 +1,16 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
import os
import re
import shutil
import time
import uuid
from uuid import UUID
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timezone
from math import ceil
from pathlib import Path
import shutil
from typing import Any, Dict, List, Optional
from uuid import UUID
import redis
import requests
@@ -38,7 +38,7 @@ from app.db import get_db, get_db_context
from app.models.document_model import Document
from app.models.file_model import File
from app.models.knowledge_model import Knowledge
from app.schemas import file_schema, document_schema
from app.schemas import document_schema, file_schema
from app.services.memory_agent_service import MemoryAgentService
from app.utils.config_utils import resolve_config_id
@@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID):
Document parsing, vectorization, and storage
"""
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
import trio
import importlib
import trio
importlib.reload(trio)
db = next(get_db()) # Manually call the generator
db_document = None
@@ -297,8 +298,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
build knowledge graph
"""
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
import trio
import importlib
import trio
importlib.reload(trio)
db = next(get_db()) # Manually call the generator
db_documents = None
@@ -932,24 +934,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
with get_db_context() as db:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
db = next(get_db())
try:
with get_db_context() as db:
service = MemoryAgentService()
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
storage_type, user_rag_memory_id)
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -1049,19 +1045,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
if actual_config_id is None:
try:
from app.services.memory_agent_service import get_end_user_connected_config
db = next(get_db())
try:
with get_db_context() as db:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
finally:
db.close()
except Exception:
# Log but continue - will fail later with proper error
pass
async def _run() -> str:
db = next(get_db())
try:
with get_db_context() as db:
logger.info(
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
service = MemoryAgentService()
@@ -1069,11 +1061,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
user_rag_memory_id, language)
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result
except Exception as e:
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
raise
finally:
db.close()
try:
# 使用 nest_asyncio 来避免事件循环冲突
@@ -1328,9 +1315,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_api_logger
from app.models.workspace_model import Workspace
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all