feat(sandbox): add Node.js code execution support to sandbox
This commit is contained in:
@@ -1,48 +1,66 @@
|
||||
"""Concurrency control middleware"""
|
||||
"""
|
||||
Concurrency control middleware
|
||||
"""
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from app.config import get_config
|
||||
from app.models import error_response
|
||||
from app.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
# Global semaphores
|
||||
_worker_semaphore: None | asyncio.Semaphore = None
|
||||
_request_counter = 0
|
||||
_request_lock = asyncio.Lock()
|
||||
class ConcurrencyController:
|
||||
def __init__(self):
|
||||
self._worker_semaphore: asyncio.Semaphore | None = None
|
||||
self._request_counter = 0
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
config = get_config()
|
||||
self.max_requests = config.max_requests
|
||||
|
||||
def init(self):
|
||||
config = get_config()
|
||||
self._worker_semaphore = asyncio.Semaphore(config.max_workers)
|
||||
|
||||
async def _acquire_worker(self):
|
||||
if self._worker_semaphore is None:
|
||||
self.init()
|
||||
async with self._worker_semaphore:
|
||||
yield
|
||||
|
||||
async def _limit_requests(self):
|
||||
async with self._lock:
|
||||
logger.info(f"Current requests: {self._request_counter}/{self.max_requests}")
|
||||
if self._request_counter >= self.max_requests:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail={
|
||||
"code": 503,
|
||||
"message": "Too many requests",
|
||||
"data": None,
|
||||
}
|
||||
)
|
||||
self._request_counter += 1
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with self._lock:
|
||||
self._request_counter -= 1
|
||||
|
||||
def acquire_worker(self):
|
||||
return asynccontextmanager(self._acquire_worker)()
|
||||
|
||||
def limit_requests(self):
|
||||
return asynccontextmanager(self._limit_requests)()
|
||||
|
||||
|
||||
def init_concurrency_control():
|
||||
"""Initialize concurrency control"""
|
||||
global _worker_semaphore
|
||||
config = get_config()
|
||||
_worker_semaphore = asyncio.Semaphore(config.max_workers)
|
||||
concurrency = ConcurrencyController()
|
||||
|
||||
|
||||
async def check_max_requests():
|
||||
"""Check if max requests limit is reached"""
|
||||
global _request_counter
|
||||
config = get_config()
|
||||
|
||||
async with _request_lock:
|
||||
if _request_counter >= config.max_requests:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=error_response(-503, "Too many requests")
|
||||
)
|
||||
_request_counter += 1
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
async with _request_lock:
|
||||
_request_counter -= 1
|
||||
|
||||
|
||||
async def acquire_worker():
|
||||
"""Acquire a worker slot"""
|
||||
if _worker_semaphore is None:
|
||||
init_concurrency_control()
|
||||
|
||||
async with _worker_semaphore:
|
||||
yield
|
||||
async def concurrency_guard():
|
||||
async with concurrency.limit_requests():
|
||||
async with concurrency.acquire_worker():
|
||||
yield
|
||||
|
||||
Reference in New Issue
Block a user