chore(api): organize imports and refactor database context management
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,6 +29,7 @@ search_results.json
|
|||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
tmp
|
tmp
|
||||||
files
|
files
|
||||||
|
powers/
|
||||||
|
|
||||||
# Exclude dep files
|
# Exclude dep files
|
||||||
huggingface.co/
|
huggingface.co/
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import shutil
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from uuid import UUID
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from math import ceil
|
from math import ceil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import shutil
|
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
import requests
|
import requests
|
||||||
@@ -38,7 +38,7 @@ from app.db import get_db, get_db_context
|
|||||||
from app.models.document_model import Document
|
from app.models.document_model import Document
|
||||||
from app.models.file_model import File
|
from app.models.file_model import File
|
||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
from app.schemas import file_schema, document_schema
|
from app.schemas import document_schema, file_schema
|
||||||
from app.services.memory_agent_service import MemoryAgentService
|
from app.services.memory_agent_service import MemoryAgentService
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
@@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
Document parsing, vectorization, and storage
|
Document parsing, vectorization, and storage
|
||||||
"""
|
"""
|
||||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||||
import trio
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
import trio
|
||||||
importlib.reload(trio)
|
importlib.reload(trio)
|
||||||
db = next(get_db()) # Manually call the generator
|
db = next(get_db()) # Manually call the generator
|
||||||
db_document = None
|
db_document = None
|
||||||
@@ -297,8 +298,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID):
|
|||||||
build knowledge graph
|
build knowledge graph
|
||||||
"""
|
"""
|
||||||
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
# Force re-importing Trio in child processes (to avoid inheriting the state of the parent process)
|
||||||
import trio
|
|
||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
|
import trio
|
||||||
importlib.reload(trio)
|
importlib.reload(trio)
|
||||||
db = next(get_db()) # Manually call the generator
|
db = next(get_db()) # Manually call the generator
|
||||||
db_documents = None
|
db_documents = None
|
||||||
@@ -932,24 +934,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s
|
|||||||
if actual_config_id is None:
|
if actual_config_id is None:
|
||||||
try:
|
try:
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Log but continue - will fail later with proper error
|
# Log but continue - will fail later with proper error
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _run() -> str:
|
async def _run() -> str:
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db,
|
||||||
storage_type, user_rag_memory_id)
|
storage_type, user_rag_memory_id)
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
# 使用 nest_asyncio 来避免事件循环冲突
|
||||||
@@ -1049,19 +1045,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
if actual_config_id is None:
|
if actual_config_id is None:
|
||||||
try:
|
try:
|
||||||
from app.services.memory_agent_service import get_end_user_connected_config
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
actual_config_id = connected_config.get("memory_config_id")
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Log but continue - will fail later with proper error
|
# Log but continue - will fail later with proper error
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _run() -> str:
|
async def _run() -> str:
|
||||||
db = next(get_db())
|
with get_db_context() as db:
|
||||||
try:
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}")
|
||||||
service = MemoryAgentService()
|
service = MemoryAgentService()
|
||||||
@@ -1069,11 +1061,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
user_rag_memory_id, language)
|
user_rag_memory_id, language)
|
||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用 nest_asyncio 来避免事件循环冲突
|
# 使用 nest_asyncio 来避免事件循环冲突
|
||||||
@@ -1328,9 +1315,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
|||||||
|
|
||||||
async def _run() -> Dict[str, Any]:
|
async def _run() -> Dict[str, Any]:
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
from app.repositories.memory_increment_repository import write_memory_increment
|
from app.repositories.memory_increment_repository import write_memory_increment
|
||||||
from app.services.memory_storage_service import search_all
|
from app.services.memory_storage_service import search_all
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user