Merge pull request #939 from wanxunyang/feature/add-quota-check-decorator
feat(quota): refactor quota management and rate limiting services
This commit is contained in:
@@ -96,40 +96,8 @@ def require_api_key(
|
|||||||
resource_id=api_key_obj.resource_id,
|
resource_id=api_key_obj.resource_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit)──────────
|
|
||||||
try:
|
|
||||||
from app.models.workspace_model import Workspace
|
|
||||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
|
||||||
|
|
||||||
workspace = db.query(Workspace).filter(
|
|
||||||
Workspace.id == api_key_obj.workspace_id
|
|
||||||
).first()
|
|
||||||
if workspace:
|
|
||||||
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
|
|
||||||
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
|
|
||||||
if tenant_qps_limit:
|
|
||||||
rate_limiter = RateLimiterService()
|
|
||||||
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
|
|
||||||
workspace.tenant_id, tenant_qps_limit
|
|
||||||
)
|
|
||||||
if not tenant_ok:
|
|
||||||
raise RateLimitException(
|
|
||||||
"租户 API 调用速率超限",
|
|
||||||
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
|
|
||||||
rate_headers={
|
|
||||||
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
|
|
||||||
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
|
|
||||||
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except RateLimitException:
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
|
|
||||||
# ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
rate_limiter = RateLimiterService()
|
rate_limiter = RateLimiterService()
|
||||||
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj)
|
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
logger.warning("API Key 限流触发", extra={
|
logger.warning("API Key 限流触发", extra={
|
||||||
"api_key_id": str(api_key_obj.id),
|
"api_key_id": str(api_key_obj.id),
|
||||||
@@ -138,10 +106,12 @@ def require_api_key(
|
|||||||
"error_msg": error_msg
|
"error_msg": error_msg
|
||||||
})
|
})
|
||||||
# 根据错误消息判断限流类型
|
# 根据错误消息判断限流类型
|
||||||
if "QPS" in error_msg:
|
if "Daily" in error_msg:
|
||||||
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
|
||||||
elif "Daily" in error_msg:
|
|
||||||
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
|
||||||
|
elif "Tenant" in error_msg:
|
||||||
|
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
|
||||||
|
elif "QPS" in error_msg:
|
||||||
|
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
|
||||||
else:
|
else:
|
||||||
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
code = BizCode.API_KEY_QUOTA_EXCEEDED
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class BizCode(IntEnum):
|
|||||||
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
API_KEY_QPS_LIMIT_EXCEEDED = 3014
|
||||||
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
API_KEY_DAILY_LIMIT_EXCEEDED = 3015
|
||||||
API_KEY_QUOTA_EXCEEDED = 3016
|
API_KEY_QUOTA_EXCEEDED = 3016
|
||||||
|
API_KEY_RATE_LIMIT_EXCEEDED = 3017
|
||||||
# 资源(4xxx)
|
# 资源(4xxx)
|
||||||
NOT_FOUND = 4000
|
NOT_FOUND = 4000
|
||||||
USER_NOT_FOUND = 4001
|
USER_NOT_FOUND = 4001
|
||||||
|
|||||||
@@ -6,7 +6,6 @@
|
|||||||
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
2. 降级到 default_free_plan.py 配置文件(社区版兜底)
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, Callable, Dict, Any
|
from typing import Optional, Callable, Dict, Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -15,10 +14,13 @@ from sqlalchemy import func
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_auth_logger
|
from app.core.logging_config import get_auth_logger
|
||||||
from app.i18n.exceptions import QuotaExceededError
|
from app.i18n.exceptions import QuotaExceededError, InternalServerError
|
||||||
|
|
||||||
logger = get_auth_logger()
|
logger = get_auth_logger()
|
||||||
|
|
||||||
|
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致(per api_key 独立计数)
|
||||||
|
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
|
||||||
|
|
||||||
|
|
||||||
def _get_user_from_kwargs(kwargs: dict):
|
def _get_user_from_kwargs(kwargs: dict):
|
||||||
"""从 kwargs 中获取 user 对象"""
|
"""从 kwargs 中获取 user 对象"""
|
||||||
@@ -65,7 +67,9 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
|
|||||||
if share_record:
|
if share_record:
|
||||||
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
|
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
|
||||||
if app:
|
if app:
|
||||||
return app.workspace.tenant_id
|
workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
return workspace.tenant_id
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -73,31 +77,52 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
|
|||||||
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
|
def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取租户的配额配置
|
获取租户的配额配置
|
||||||
|
|
||||||
优先级:
|
优先级:
|
||||||
1. premium 模块的 tenant_subscriptions(SaaS 版)
|
1. premium 模块的 tenant_subscriptions(SaaS 版)
|
||||||
2. default_free_plan.py 配置文件(社区版兜底)
|
2. default_free_plan.py 配置文件(社区版兜底)
|
||||||
"""
|
"""
|
||||||
# 尝试从 premium 模块获取
|
# 尝试从 premium 模块获取(SaaS 版)
|
||||||
try:
|
try:
|
||||||
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
from premium.platform_admin.package_plan_service import TenantSubscriptionService
|
||||||
|
# premium 模块存在,运行时错误不应被静默降级,直接抛出
|
||||||
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
|
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
|
||||||
if quota_config:
|
if quota_config:
|
||||||
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
|
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
|
||||||
return quota_config
|
return quota_config
|
||||||
except (ModuleNotFoundError, ImportError, Exception) as e:
|
# premium 存在但该租户无订阅记录,降级到免费套餐
|
||||||
logger.debug(f"无法从 premium 模块获取配额配置: {e}")
|
logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
|
||||||
|
except (ModuleNotFoundError, ImportError):
|
||||||
|
# 社区版:premium 包不存在,正常降级
|
||||||
|
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
|
||||||
|
|
||||||
# 降级到配置文件
|
# 降级到社区版配置文件
|
||||||
try:
|
try:
|
||||||
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
from app.config.default_free_plan import DEFAULT_FREE_PLAN
|
||||||
logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}")
|
logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
|
||||||
return DEFAULT_FREE_PLAN.get("quotas")
|
return DEFAULT_FREE_PLAN.get("quotas")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法从配置文件获取配额: {e}")
|
logger.error(f"无法从配置文件获取配额: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
获取租户套餐的 API 操作速率限制(QPS 上限)
|
||||||
|
|
||||||
|
该函数兼容社区版和 SaaS 版:
|
||||||
|
- SaaS 版:从 premium 模块的套餐配额读取
|
||||||
|
- 社区版:从 default_free_plan.py 配置文件读取
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: api_ops_rate_limit 值,如果未配置则返回 None
|
||||||
|
"""
|
||||||
|
quota_config = _get_quota_config(db, tenant_id)
|
||||||
|
if quota_config:
|
||||||
|
return quota_config.get("api_ops_rate_limit")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class QuotaUsageRepository:
|
class QuotaUsageRepository:
|
||||||
"""配额使用量数据访问层"""
|
"""配额使用量数据访问层"""
|
||||||
|
|
||||||
@@ -247,41 +272,74 @@ def _check_quota(
|
|||||||
|
|
||||||
def check_workspace_quota(func: Callable) -> Callable:
|
def check_workspace_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_skill_quota(func: Callable) -> Callable:
|
def check_skill_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
_check_quota(db, user.tenant_id, "skill_quota", "skill")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_app_quota(func: Callable) -> Callable:
|
def check_app_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "app_quota", "app")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "app_quota", "app")
|
_check_quota(db, user.tenant_id, "app_quota", "app")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
||||||
@@ -289,12 +347,12 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
|||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
if not db:
|
if not db:
|
||||||
logger.warning("配额检查失败:缺少 db 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||||
return await func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.warning("配额检查失败:无法获取 tenant_id")
|
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||||
return await func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -303,8 +361,8 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
|||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -313,15 +371,26 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
|
|||||||
|
|
||||||
def check_memory_engine_quota(func: Callable) -> Callable:
|
def check_memory_engine_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
|
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_end_user_quota(func: Callable) -> Callable:
|
def check_end_user_quota(func: Callable) -> Callable:
|
||||||
@@ -329,12 +398,12 @@ def check_end_user_quota(func: Callable) -> Callable:
|
|||||||
async def async_wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
if not db:
|
if not db:
|
||||||
logger.warning("配额检查失败:缺少 db 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||||
return await func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.warning("配额检查失败:无法获取 tenant_id")
|
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||||
return await func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
||||||
return await func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -342,12 +411,12 @@ def check_end_user_quota(func: Callable) -> Callable:
|
|||||||
def sync_wrapper(*args, **kwargs):
|
def sync_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
if not db:
|
if not db:
|
||||||
logger.warning("配额检查失败:缺少 db 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.warning("配额检查失败:无法获取 tenant_id")
|
logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
_check_quota(db, tenant_id, "end_user_quota", "end_user")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -356,88 +425,155 @@ def check_end_user_quota(func: Callable) -> Callable:
|
|||||||
|
|
||||||
def check_ontology_project_quota(func: Callable) -> Callable:
|
def check_ontology_project_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
|
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_model_quota(func: Callable) -> Callable:
|
def check_model_quota(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_model_activation_quota(func: Callable) -> Callable:
|
def check_model_activation_quota(func: Callable) -> Callable:
|
||||||
"""模型激活时的配额检查装饰器"""
|
"""模型激活时的配额检查装饰器"""
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
|
||||||
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||||
model_data = kwargs.get("model_data")
|
model_data = kwargs.get("model_data")
|
||||||
|
|
||||||
if not model_id or not model_data:
|
if not model_id or not model_data:
|
||||||
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||||
return func(*args, **kwargs)
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
if model_data.is_active is True:
|
if model_data.is_active:
|
||||||
try:
|
try:
|
||||||
from app.models.models_model import ModelConfig
|
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
existing_model = ModelConfigService.get_model_by_id(
|
existing_model = ModelConfigService.get_model_by_id(
|
||||||
db=db,
|
db=db,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
tenant_id=user.tenant_id
|
tenant_id=user.tenant_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not existing_model.is_active:
|
if not existing_model.is_active:
|
||||||
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||||
_check_quota(db, user.tenant_id, "model_quota", "model")
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
|
|
||||||
|
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
|
||||||
|
model_data = kwargs.get("model_data")
|
||||||
|
|
||||||
|
if not model_id or not model_data:
|
||||||
|
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if model_data.is_active:
|
||||||
|
try:
|
||||||
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
|
existing_model = ModelConfigService.get_model_by_id(
|
||||||
|
db=db,
|
||||||
|
model_id=model_id,
|
||||||
|
tenant_id=user.tenant_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not existing_model.is_active:
|
||||||
|
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
|
||||||
|
_check_quota(db, user.tenant_id, "model_quota", "model")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
|
|
||||||
|
|
||||||
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
|
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
|
||||||
"""通用配额检查装饰器,支持自定义使用量获取函数"""
|
"""通用配额检查装饰器,支持自定义使用量获取函数"""
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
async def async_wrapper(*args, **kwargs):
|
||||||
db: Session = kwargs.get("db")
|
db: Session = kwargs.get("db")
|
||||||
user = _get_user_from_kwargs(kwargs)
|
user = _get_user_from_kwargs(kwargs)
|
||||||
if not db or not user:
|
if not db or not user:
|
||||||
logger.warning("配额检查失败:缺少 db 或 user 参数")
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
return func(*args, **kwargs)
|
raise InternalServerError()
|
||||||
|
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args, **kwargs):
|
||||||
|
db: Session = kwargs.get("db")
|
||||||
|
user = _get_user_from_kwargs(kwargs)
|
||||||
|
if not db or not user:
|
||||||
|
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
|
||||||
|
raise InternalServerError()
|
||||||
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
return wrapper
|
|
||||||
|
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
# ─── 配额使用统计 ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
||||||
"""获取租户所有配额的使用情况"""
|
"""获取租户所有配额的使用情况"""
|
||||||
quota_config = _get_quota_config(db, tenant_id)
|
quota_config = _get_quota_config(db, tenant_id)
|
||||||
if not quota_config:
|
if not quota_config:
|
||||||
@@ -459,18 +595,25 @@ def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
|
|||||||
|
|
||||||
api_ops_current = 0
|
api_ops_current = 0
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.aioRedis import aio_redis as _aio_redis
|
||||||
import redis
|
from app.models.api_key_model import ApiKey
|
||||||
_now = time.time()
|
from app.models.workspace_model import Workspace
|
||||||
_rk = f"rate_limit:tenant_qps:{tenant_id}"
|
# api_ops_rate_limit 限的是每个 api_key 每秒最高限额
|
||||||
_r = redis.StrictRedis(
|
# 展示当前最接近触发限流的 key 的 QPS(取最大值)
|
||||||
host=settings.REDIS_HOST, port=settings.REDIS_PORT,
|
api_key_ids = db.query(ApiKey.id).join(
|
||||||
db=settings.REDIS_DB, password=settings.REDIS_PASSWORD,
|
Workspace, ApiKey.workspace_id == Workspace.id
|
||||||
decode_responses=True
|
).filter(
|
||||||
)
|
Workspace.tenant_id == tenant_id,
|
||||||
api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf"))
|
ApiKey.is_active.is_(True)
|
||||||
except Exception:
|
).all()
|
||||||
pass
|
for (key_id,) in api_key_ids:
|
||||||
|
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
|
||||||
|
val = await _aio_redis.get(_rk)
|
||||||
|
count = int(val) if val else 0
|
||||||
|
if count > api_ops_current:
|
||||||
|
api_ops_current = count
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
|
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from app.core.quota_manager import (
|
|||||||
get_quota_usage,
|
get_quota_usage,
|
||||||
_check_quota,
|
_check_quota,
|
||||||
QuotaUsageRepository,
|
QuotaUsageRepository,
|
||||||
|
API_KEY_QPS_REDIS_KEY,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@@ -33,4 +34,5 @@ __all__ = [
|
|||||||
"get_quota_usage",
|
"get_quota_usage",
|
||||||
"_check_quota",
|
"_check_quota",
|
||||||
"QuotaUsageRepository",
|
"QuotaUsageRepository",
|
||||||
|
"API_KEY_QPS_REDIS_KEY",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -482,14 +482,39 @@ class RateLimitExceededError(I18nException):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class QuotaExceededError(ForbiddenError):
|
class QuotaExceededError(I18nException):
|
||||||
"""Quota exceeded error."""
|
"""Quota exceeded error (402)."""
|
||||||
|
|
||||||
|
# resource key -> i18n display key
|
||||||
|
_RESOURCE_KEY_MAP = {
|
||||||
|
"workspace": "errors.quota_resources.workspace",
|
||||||
|
"app": "errors.quota_resources.app",
|
||||||
|
"skill": "errors.quota_resources.skill",
|
||||||
|
"knowledge_capacity": "errors.quota_resources.knowledge_capacity",
|
||||||
|
"memory_engine": "errors.quota_resources.memory_engine",
|
||||||
|
"end_user": "errors.quota_resources.end_user",
|
||||||
|
"model": "errors.quota_resources.model",
|
||||||
|
"ontology_project": "errors.quota_resources.ontology_project",
|
||||||
|
"api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(self, resource: Optional[str] = None, **params):
|
def __init__(self, resource: Optional[str] = None, **params):
|
||||||
|
# Translate resource key to a localized display name before calling super()
|
||||||
if resource:
|
if resource:
|
||||||
params["resource"] = resource
|
resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource)
|
||||||
|
if resource_i18n_key:
|
||||||
|
try:
|
||||||
|
from app.i18n.service import get_translation_service
|
||||||
|
from app.core.config import settings
|
||||||
|
_locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
params["resource"] = get_translation_service().translate(resource_i18n_key, _locale)
|
||||||
|
except Exception:
|
||||||
|
params["resource"] = resource
|
||||||
|
else:
|
||||||
|
params["resource"] = resource
|
||||||
super().__init__(
|
super().__init__(
|
||||||
error_key="errors.api.quota_exceeded",
|
error_key="errors.api.quota_exceeded",
|
||||||
|
status_code=402,
|
||||||
error_code="QUOTA_EXCEEDED",
|
error_code="QUOTA_EXCEEDED",
|
||||||
**params
|
**params
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -106,7 +106,7 @@
|
|||||||
},
|
},
|
||||||
"api": {
|
"api": {
|
||||||
"rate_limit_exceeded": "API rate limit exceeded",
|
"rate_limit_exceeded": "API rate limit exceeded",
|
||||||
"quota_exceeded": "API quota exceeded",
|
"quota_exceeded": "{resource} quota exceeded",
|
||||||
"invalid_api_key": "Invalid API key",
|
"invalid_api_key": "Invalid API key",
|
||||||
"api_key_expired": "API key has expired",
|
"api_key_expired": "API key has expired",
|
||||||
"api_key_revoked": "API key has been revoked",
|
"api_key_revoked": "API key has been revoked",
|
||||||
@@ -114,7 +114,8 @@
|
|||||||
"method_not_allowed": "Method not allowed",
|
"method_not_allowed": "Method not allowed",
|
||||||
"invalid_request": "Invalid request",
|
"invalid_request": "Invalid request",
|
||||||
"missing_parameter": "Missing required parameter: {param}",
|
"missing_parameter": "Missing required parameter: {param}",
|
||||||
"invalid_parameter": "Invalid parameter: {param}"
|
"invalid_parameter": "Invalid parameter: {param}",
|
||||||
|
"api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})"
|
||||||
},
|
},
|
||||||
"database": {
|
"database": {
|
||||||
"connection_failed": "Database connection failed",
|
"connection_failed": "Database connection failed",
|
||||||
@@ -134,5 +135,16 @@
|
|||||||
"invalid_format": "Invalid format: {field}",
|
"invalid_format": "Invalid format: {field}",
|
||||||
"invalid_value": "Invalid value: {field}",
|
"invalid_value": "Invalid value: {field}",
|
||||||
"out_of_range": "Value out of range: {field}"
|
"out_of_range": "Value out of range: {field}"
|
||||||
|
},
|
||||||
|
"quota_resources": {
|
||||||
|
"workspace": "Workspace",
|
||||||
|
"app": "App",
|
||||||
|
"skill": "Skill",
|
||||||
|
"knowledge_capacity": "Knowledge capacity",
|
||||||
|
"memory_engine": "Memory engine",
|
||||||
|
"end_user": "End user",
|
||||||
|
"model": "Model",
|
||||||
|
"ontology_project": "Ontology project",
|
||||||
|
"api_ops_rate_limit": "API ops rate limit"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@
|
|||||||
},
|
},
|
||||||
"api": {
|
"api": {
|
||||||
"rate_limit_exceeded": "API调用频率超限",
|
"rate_limit_exceeded": "API调用频率超限",
|
||||||
"quota_exceeded": "API调用配额已用完",
|
"quota_exceeded": "{resource} 配额已超限",
|
||||||
"invalid_api_key": "无效的API密钥",
|
"invalid_api_key": "无效的API密钥",
|
||||||
"api_key_expired": "API密钥已过期",
|
"api_key_expired": "API密钥已过期",
|
||||||
"api_key_revoked": "API密钥已被撤销",
|
"api_key_revoked": "API密钥已被撤销",
|
||||||
@@ -114,7 +114,8 @@
|
|||||||
"method_not_allowed": "不支持的请求方法",
|
"method_not_allowed": "不支持的请求方法",
|
||||||
"invalid_request": "无效的请求",
|
"invalid_request": "无效的请求",
|
||||||
"missing_parameter": "缺少必需参数:{param}",
|
"missing_parameter": "缺少必需参数:{param}",
|
||||||
"invalid_parameter": "参数无效:{param}"
|
"invalid_parameter": "参数无效:{param}",
|
||||||
|
"api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit})"
|
||||||
},
|
},
|
||||||
"database": {
|
"database": {
|
||||||
"connection_failed": "数据库连接失败",
|
"connection_failed": "数据库连接失败",
|
||||||
@@ -134,5 +135,16 @@
|
|||||||
"invalid_format": "格式不正确:{field}",
|
"invalid_format": "格式不正确:{field}",
|
||||||
"invalid_value": "值无效:{field}",
|
"invalid_value": "值无效:{field}",
|
||||||
"out_of_range": "值超出范围:{field}"
|
"out_of_range": "值超出范围:{field}"
|
||||||
|
},
|
||||||
|
"quota_resources": {
|
||||||
|
"workspace": "工作空间",
|
||||||
|
"app": "应用",
|
||||||
|
"skill": "技能",
|
||||||
|
"knowledge_capacity": "知识库容量",
|
||||||
|
"memory_engine": "记忆引擎",
|
||||||
|
"end_user": "终端用户",
|
||||||
|
"model": "模型",
|
||||||
|
"ontology_project": "本体工程",
|
||||||
|
"api_ops_rate_limit": "API 操作速率"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from app.core.exceptions import (
|
|||||||
)
|
)
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.i18n.exceptions import I18nException
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -51,6 +52,22 @@ class ApiKeyService:
|
|||||||
if existing:
|
if existing:
|
||||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||||
|
|
||||||
|
# 校验 rate_limit 不能超过租户套餐的 api_ops_rate_limit
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
from app.core.quota_manager import get_api_ops_rate_limit
|
||||||
|
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
|
||||||
|
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
|
||||||
|
raise I18nException(
|
||||||
|
error_key="errors.api.api_key_rate_limit_exceeded",
|
||||||
|
status_code=400,
|
||||||
|
error_code="API_KEY_RATE_LIMIT_EXCEEDED",
|
||||||
|
rate_limit=data.rate_limit,
|
||||||
|
limit=tenant_api_ops_limit,
|
||||||
|
)
|
||||||
|
|
||||||
# 生成 API Key
|
# 生成 API Key
|
||||||
api_key = generate_api_key(data.type)
|
api_key = generate_api_key(data.type)
|
||||||
|
|
||||||
@@ -152,6 +169,23 @@ class ApiKeyService:
|
|||||||
if existing:
|
if existing:
|
||||||
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
|
||||||
|
|
||||||
|
# 校验 rate_limit 不能超过租户套餐的 api_ops_rate_limit
|
||||||
|
if data.rate_limit is not None:
|
||||||
|
from app.models.workspace_model import Workspace
|
||||||
|
from app.core.quota_manager import get_api_ops_rate_limit
|
||||||
|
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
|
||||||
|
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
|
||||||
|
raise I18nException(
|
||||||
|
error_key="errors.api.api_key_rate_limit_exceeded",
|
||||||
|
status_code=400,
|
||||||
|
error_code="API_KEY_RATE_LIMIT_EXCEEDED",
|
||||||
|
rate_limit=data.rate_limit,
|
||||||
|
limit=tenant_api_ops_limit,
|
||||||
|
)
|
||||||
|
|
||||||
update_data = data.model_dump(exclude_unset=True)
|
update_data = data.model_dump(exclude_unset=True)
|
||||||
ApiKeyRepository.update(db, api_key_id, update_data)
|
ApiKeyRepository.update(db, api_key_id, update_data)
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -248,42 +282,14 @@ class RateLimiterService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.redis = aio_redis
|
self.redis = aio_redis
|
||||||
|
|
||||||
async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
|
||||||
"""
|
|
||||||
按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit
|
|
||||||
"""
|
|
||||||
now = time.time()
|
|
||||||
window_start = now - 1 # 1 秒窗口
|
|
||||||
key = f"rate_limit:tenant_qps:{tenant_id}"
|
|
||||||
|
|
||||||
async with self.redis.pipeline() as pipe:
|
|
||||||
# 清理 1 秒前的旧记录
|
|
||||||
pipe.zremrangebyscore(key, 0, window_start)
|
|
||||||
# 加入当前请求(score=时间戳,member=时间戳+随机数保证唯一)
|
|
||||||
pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
|
|
||||||
# 统计窗口内请求数
|
|
||||||
pipe.zcard(key)
|
|
||||||
# 设置 key 过期(2 秒后自动清理)
|
|
||||||
pipe.expire(key, 2)
|
|
||||||
results = await pipe.execute()
|
|
||||||
|
|
||||||
current = results[2]
|
|
||||||
remaining = max(0, limit - current)
|
|
||||||
reset_time = int(now) + 1
|
|
||||||
|
|
||||||
return current <= limit, {
|
|
||||||
"limit": limit,
|
|
||||||
"remaining": remaining,
|
|
||||||
"reset": reset_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
|
||||||
"""
|
"""检查QPS限制
|
||||||
检查QPS限制
|
|
||||||
Returns:
|
Returns:
|
||||||
(is_allowed, rate_limit_info)
|
(is_allowed, rate_limit_info)
|
||||||
"""
|
"""
|
||||||
key = f"rate_limit:qps:{api_key_id}"
|
key = f"rate_limit:qps:{api_key_id}"
|
||||||
|
|
||||||
async with self.redis.pipeline() as pipe:
|
async with self.redis.pipeline() as pipe:
|
||||||
pipe.incr(key)
|
pipe.incr(key)
|
||||||
pipe.expire(key, 1, nx=True) # 1 秒过期
|
pipe.expire(key, 1, nx=True) # 1 秒过期
|
||||||
@@ -295,8 +301,9 @@ class RateLimiterService:
|
|||||||
|
|
||||||
return current <= limit, {
|
return current <= limit, {
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
|
"current": current,
|
||||||
"remaining": remaining,
|
"remaining": remaining,
|
||||||
"reset": reset_time
|
"reset": reset_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def check_daily_requests(
|
async def check_daily_requests(
|
||||||
@@ -304,7 +311,9 @@ class RateLimiterService:
|
|||||||
api_key_id: uuid.UUID,
|
api_key_id: uuid.UUID,
|
||||||
limit: int
|
limit: int
|
||||||
) -> Tuple[bool, dict]:
|
) -> Tuple[bool, dict]:
|
||||||
"""检查日调用量限制"""
|
"""检查日调用量限制。
|
||||||
|
使用原子 INCR,先写后判断,极低概率下允许轻微超限(并发场景下可接受)。
|
||||||
|
"""
|
||||||
today = datetime.now().strftime("%Y%m%d")
|
today = datetime.now().strftime("%Y%m%d")
|
||||||
key = f"rate_limit:daily:{api_key_id}:{today}"
|
key = f"rate_limit:daily:{api_key_id}:{today}"
|
||||||
|
|
||||||
@@ -313,6 +322,7 @@ class RateLimiterService:
|
|||||||
hour=0, minute=0, second=0, microsecond=0
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
)
|
)
|
||||||
expire_seconds = int((tomorrow_0 - now).total_seconds())
|
expire_seconds = int((tomorrow_0 - now).total_seconds())
|
||||||
|
reset_time = int(tomorrow_0.timestamp())
|
||||||
|
|
||||||
async with self.redis.pipeline() as pipe:
|
async with self.redis.pipeline() as pipe:
|
||||||
pipe.incr(key)
|
pipe.incr(key)
|
||||||
@@ -320,36 +330,74 @@ class RateLimiterService:
|
|||||||
results = await pipe.execute()
|
results = await pipe.execute()
|
||||||
|
|
||||||
current = results[0]
|
current = results[0]
|
||||||
remaining = max(0, limit - current)
|
|
||||||
reset_time = int(tomorrow_0.timestamp())
|
|
||||||
|
|
||||||
return current <= limit, {
|
if current > limit:
|
||||||
|
return False, {
|
||||||
|
"limit": limit,
|
||||||
|
"remaining": 0,
|
||||||
|
"reset": reset_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
return True, {
|
||||||
"limit": limit,
|
"limit": limit,
|
||||||
"remaining": remaining,
|
"remaining": max(0, limit - current),
|
||||||
"reset": reset_time
|
"reset": reset_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def check_all_limits(
|
async def check_all_limits(
|
||||||
self,
|
self,
|
||||||
api_key: ApiKey
|
api_key: ApiKey,
|
||||||
|
db: Optional[Session] = None,
|
||||||
) -> Tuple[bool, str, dict]:
|
) -> Tuple[bool, str, dict]:
|
||||||
"""
|
"""
|
||||||
检查所有限制
|
检查所有限制,按以下顺序:
|
||||||
Returns:
|
1. API Key QPS:取 api_key.rate_limit 与套餐 api_ops_rate_limit 的最小值作为限额
|
||||||
(is_allowed, error_message, rate_limit_headers)
|
2. API Key 日调用量
|
||||||
"""
|
"""
|
||||||
# Check QPS
|
# 1. 取套餐限额与 api_key 自身限额的最小值
|
||||||
qps_ok, qps_info = await self.check_qps(
|
effective_limit = api_key.rate_limit
|
||||||
api_key.id,
|
if db is not None:
|
||||||
api_key.rate_limit
|
try:
|
||||||
)
|
from app.models.workspace_model import Workspace
|
||||||
|
from app.core.quota_manager import get_api_ops_rate_limit
|
||||||
|
|
||||||
|
cache_key = f"tenant_api_ops_limit:{api_key.workspace_id}"
|
||||||
|
cached = await self.redis.get(cache_key)
|
||||||
|
if cached is not None:
|
||||||
|
try:
|
||||||
|
tenant_limit = int(cached) if cached != "0" else None
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
cached = None
|
||||||
|
tenant_limit = None
|
||||||
|
|
||||||
|
if cached is None:
|
||||||
|
workspace = db.query(Workspace).filter(Workspace.id == api_key.workspace_id).first()
|
||||||
|
if workspace:
|
||||||
|
tenant_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
|
||||||
|
await self.redis.set(cache_key, str(tenant_limit) if tenant_limit else "0", ex=60)
|
||||||
|
else:
|
||||||
|
tenant_limit = None
|
||||||
|
|
||||||
|
if tenant_limit:
|
||||||
|
effective_limit = min(api_key.rate_limit, tenant_limit)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"获取套餐限额失败,使用 api_key 自身限额: {e}")
|
||||||
|
|
||||||
|
# 用最终有效限额做 QPS 检查
|
||||||
|
qps_ok, qps_info = await self.check_qps(api_key.id, effective_limit)
|
||||||
if not qps_ok:
|
if not qps_ok:
|
||||||
return False, "QPS limit exceeded", {
|
# 判断是套餐限额触发还是 api_key 自身限额触发
|
||||||
|
if tenant_limit and effective_limit == tenant_limit and api_key.rate_limit > tenant_limit:
|
||||||
|
error_msg = "Tenant limit exceeded"
|
||||||
|
else:
|
||||||
|
error_msg = "QPS limit exceeded"
|
||||||
|
return False, error_msg, {
|
||||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||||
"X-RateLimit-Reset": str(qps_info["reset"])
|
"X-RateLimit-Reset": str(qps_info["reset"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# 2. 检查日调用量
|
||||||
daily_ok, daily_info = await self.check_daily_requests(
|
daily_ok, daily_info = await self.check_daily_requests(
|
||||||
api_key.id,
|
api_key.id,
|
||||||
api_key.daily_request_limit
|
api_key.daily_request_limit
|
||||||
@@ -361,14 +409,13 @@ class RateLimiterService:
|
|||||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
"X-RateLimit-Reset": str(daily_info["reset"])
|
||||||
}
|
}
|
||||||
|
|
||||||
headers = {
|
return True, "", {
|
||||||
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
|
||||||
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
|
||||||
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
|
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
|
||||||
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
|
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
|
||||||
"X-RateLimit-Reset": str(daily_info["reset"])
|
"X-RateLimit-Reset": str(daily_info["reset"]),
|
||||||
}
|
}
|
||||||
return True, "", headers
|
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyAuthService:
|
class ApiKeyAuthService:
|
||||||
|
|||||||
Reference in New Issue
Block a user