Files
MemoryBear/api/app/db.py

83 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
from contextlib import contextmanager
from typing import Generator
from urllib.parse import quote_plus

from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, sessionmaker

from app.core.config import settings
# Percent-encode the credentials so reserved URL characters ("@", ":", "/",
# "#", "%", ...) in the user name or password cannot corrupt the URL; an
# unescaped "@" in the password would otherwise be parsed as the start of the
# host part. NOTE: settings must therefore hold the raw (unencoded) values.
SQLALCHEMY_DATABASE_URL = (
    "postgresql://"
    f"{quote_plus(str(settings.DB_USER))}:{quote_plus(str(settings.DB_PASSWORD))}"
    f"@{settings.DB_HOST}:{settings.DB_PORT}/{settings.DB_NAME}"
)

# Module-level engine; pool sizing/recycling behaviour comes from settings.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
    pool_pre_ping=settings.DB_POOL_PRE_PING,
    pool_recycle=settings.DB_POOL_RECYCLE,
    pool_timeout=settings.DB_POOL_TIMEOUT,
    connect_args={
        # Server options applied to every new connection: the session time
        # zone is UTC, and a single statement is cancelled after 60000
        # (PostgreSQL reads a unit-less statement_timeout as milliseconds).
        "options": "-c timezone=UTC -c statement_timeout=60000"
    },
)

# Session factory bound to the engine. autoflush is disabled, so pending
# changes are not flushed automatically before queries; callers commit
# explicitly.
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base class for ORM models.
# NOTE(review): sqlalchemy.ext.declarative.declarative_base is deprecated since
# SQLAlchemy 1.4 in favour of sqlalchemy.orm.declarative_base.
Base = declarative_base()
# Dependency to get a DB session
def get_db():
    """Yield one Session for the duration of a dependency scope.

    A fresh Session is created from ``SessionLocal`` and handed to the
    consumer. When the generator is finalised, any transaction that is still
    open is rolled back, and the Session is closed unconditionally, even if
    the rollback fails.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        try:
            still_open = session.in_transaction()
            if still_open:
                # Don't return the connection to the pool mid-transaction.
                session.rollback()
        finally:
            session.close()
@contextmanager
def get_db_context() -> Generator[Session, None, None]:
    """Provide a Session as a context manager; the caller owns commits.

    A new Session is created from ``SessionLocal`` on every call, so separate
    callers never share a Session. Do not pass the yielded Session to another
    thread.

    The context never commits. On exit, any transaction that is still open
    (i.e. uncommitted work) is rolled back so the pooled connection is not
    left "idle in transaction". The Session is then always closed, even if
    the rollback itself raises.

    Usage::

        with get_db_context() as db:
            db.add(obj)
            db.commit()  # or db.rollback()
    """
    db: Session = SessionLocal()
    try:
        yield db
    finally:
        try:
            # Discard leftover work so the connection goes back to the pool clean.
            if db.in_transaction():
                db.rollback()
        finally:
            # Release the connection even when rollback() failed (e.g. the
            # server connection was already dropped); otherwise it leaks.
            db.close()
@contextmanager
def get_db_read() -> Generator[Session, None, None]:
    """Session context for read-only work; always rolls back on exit.

    Wraps ``get_db_context``. Read-only tasks never need to commit, so the
    transaction is rolled back unconditionally when the block exits. No
    connection is left "idle in transaction", and any uncommitted writes made
    inside the block are discarded.
    """
    with get_db_context() as db:
        try:
            yield db
        finally:
            try:
                db.rollback()  # read-only: nothing to commit
            finally:
                # Close even if rollback() raised. get_db_context() also closes
                # the session on exit; calling Session.close() again is harmless.
                db.close()
def get_pool_status():
    """Return connection-pool counters of the module engine (for monitoring).

    Each pool counter is read exactly once. Every derived field (``total``,
    ``usage_percent``) is therefore computed from the same values that are
    reported alongside it.

    Returns:
        dict with keys ``pool_size``, ``checked_in``, ``checked_out``,
        ``overflow``, ``total`` (pool_size + overflow) and ``usage_percent``
        (checked_out / total * 100 rounded to 2 places, or 0 when total <= 0).
    """
    pool = engine.pool
    size = pool.size()
    checked_in = pool.checkedin()
    checked_out = pool.checkedout()
    overflow = pool.overflow()
    # NOTE(review): for SQLAlchemy's QueuePool, overflow() starts negative
    # (-pool_size) and rises as connections are opened. "total" is therefore
    # the number of connections currently open, not the configured capacity
    # (pool_size + max_overflow). Confirm this is the intended denominator.
    total = size + overflow
    usage_percent = round(checked_out / total * 100, 2) if total > 0 else 0
    return {
        "pool_size": size,
        "checked_in": checked_in,
        "checked_out": checked_out,
        "overflow": overflow,
        "total": total,
        "usage_percent": usage_percent,
    }