chore(api): organize imports and refactor database context management
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime, timezone
|
||||
from math import ceil
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
import redis
|
||||
import requests
|
||||
@@ -38,7 +38,7 @@ from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.file_model import File
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.schemas import file_schema, document_schema
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
@@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
@@ -297,8 +298,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
||||
build knowledge graph
|
||||
"""
|
||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||
import trio
|
||||
import importlib
|
||||
|
||||
import trio
|
||||
importlib.reload(trio)
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_documents = None
|
||||
@@ -932,24 +934,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
service = MemoryAgentService()
|
||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||
storage_type, user_rag_memory_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -1049,19 +1045,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
if actual_config_id is None:
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
async def _run() -> str:
|
||||
db = next(get_db())
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
logger.info(
|
||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||
service = MemoryAgentService()
|
||||
@@ -1069,11 +1061,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
||||
user_rag_memory_id, language)
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
@@ -1328,9 +1315,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.workspace_model import Workspace
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
|
||||
Reference in New Issue
Block a user