67 lines
1.8 KiB
Python
67 lines
1.8 KiB
Python
"""
|
|
Concurrency control middleware
|
|
"""
|
|
import asyncio
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import HTTPException, status
|
|
|
|
from app.config import get_config
|
|
from app.logger import get_logger
|
|
|
|
# Module-level logger obtained from the project's ``app.logger.get_logger`` helper.
logger = get_logger()
|
|
|
|
|
|
class ConcurrencyController:
    """Limit admitted requests and concurrently running workers.

    Two independent limits are enforced:

    * ``max_requests`` -- upper bound on requests that are inside
      ``limit_requests`` at the same time. A request arriving once the bound
      is reached is rejected immediately with HTTP 503.
    * ``max_workers`` -- size of the semaphore behind ``acquire_worker``.
      Callers beyond that size wait until a slot is released.
    """

    def __init__(self):
        """Read the request cap from configuration and reset the counters."""
        # Built on first use by ``acquire_worker`` (or by an explicit ``init()``).
        self._worker_semaphore: asyncio.Semaphore | None = None
        # Number of requests currently inside ``limit_requests``.
        self._request_counter = 0
        # No longer used by this class: the counter updates below contain no
        # ``await``, so they need no lock. Kept in case other code refers to it.
        self._lock = asyncio.Lock()

        config = get_config()
        self.max_requests = config.max_requests

    def init(self):
        """Create the worker semaphore sized by ``config.max_workers``.

        Calling it again replaces the semaphore with a fresh one.
        """
        config = get_config()
        self._worker_semaphore = asyncio.Semaphore(config.max_workers)

    @asynccontextmanager
    async def acquire_worker(self):
        """Hold one worker slot for the duration of the ``async with`` body.

        The semaphore is created on first use if ``init()`` has not been
        called yet. Waits while all slots are taken; the slot is released
        when the body exits, normally or with an exception.
        """
        if self._worker_semaphore is None:
            self.init()

        async with self._worker_semaphore:
            yield

    @asynccontextmanager
    async def limit_requests(self):
        """Count the caller as an in-flight request for the ``async with`` body.

        Raises:
            HTTPException: 503 with a ``"Too many requests"`` payload when
                ``max_requests`` requests are already in flight. The counter
                is not incremented in that case.
        """
        logger.info(f"Current requests: {self._request_counter}/{self.max_requests}")

        # The check and the increment run without an ``await`` between them,
        # so no other task on the same event loop can interleave here.
        if self._request_counter >= self.max_requests:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail={
                    "code": 503,
                    "message": "Too many requests",
                    "data": None,
                },
            )
        self._request_counter += 1

        try:
            yield
        finally:
            # Plain decrement: the release path must not await, otherwise a
            # cancellation during cleanup could skip it and leak a slot.
            self._request_counter -= 1
|
|
|
|
|
|
# Shared module-level controller instance; constructing it reads
# ``max_requests`` from the configuration when this module is imported.
concurrency = ConcurrencyController()
|
|
|
|
|
|
async def concurrency_guard():
    """Run the caller's work under both limits of the shared controller.

    Enters ``concurrency.limit_requests()`` first and then
    ``concurrency.acquire_worker()``, yields once while both are held, and
    leaves them in reverse order afterwards.

    NOTE(review): the single-``yield`` async-generator shape suggests this is
    meant to be used as a FastAPI dependency -- confirm against the callers.
    """
    async with concurrency.limit_requests(), concurrency.acquire_worker():
        yield
|