Merge pull request #939 from wanxunyang/feature/add-quota-check-decorator

feat(quota): refactor quota management and rate limiting services
This commit is contained in:
山程漫悟
2026-04-20 18:36:33 +08:00
committed by GitHub
8 changed files with 385 additions and 173 deletions

View File

@@ -96,40 +96,8 @@ def require_api_key(
resource_id=api_key_obj.resource_id, resource_id=api_key_obj.resource_id,
) )
# ── Tenant 级别限速(来自套餐配额 api_ops_rate_limit──────────
try:
from app.models.workspace_model import Workspace
from premium.platform_admin.package_plan_service import TenantSubscriptionService
workspace = db.query(Workspace).filter(
Workspace.id == api_key_obj.workspace_id
).first()
if workspace:
quota = TenantSubscriptionService(db).get_effective_quota(workspace.tenant_id)
tenant_qps_limit = quota.get("api_ops_rate_limit") if quota else None
if tenant_qps_limit:
rate_limiter = RateLimiterService()
tenant_ok, tenant_info = await rate_limiter.check_tenant_rate_limit(
workspace.tenant_id, tenant_qps_limit
)
if not tenant_ok:
raise RateLimitException(
"租户 API 调用速率超限",
BizCode.API_KEY_QPS_LIMIT_EXCEEDED,
rate_headers={
"X-RateLimit-Tenant-Limit": str(tenant_info["limit"]),
"X-RateLimit-Tenant-Remaining": str(tenant_info["remaining"]),
"X-RateLimit-Tenant-Reset": str(tenant_info["reset"]),
}
)
except RateLimitException:
raise
except Exception as e:
logger.warning(f"Tenant 限速检查异常,跳过: {e}")
# ─────────────────────────────────────────────────────────────
rate_limiter = RateLimiterService() rate_limiter = RateLimiterService()
is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj) is_allowed, error_msg, rate_headers = await rate_limiter.check_all_limits(api_key_obj, db=db)
if not is_allowed: if not is_allowed:
logger.warning("API Key 限流触发", extra={ logger.warning("API Key 限流触发", extra={
"api_key_id": str(api_key_obj.id), "api_key_id": str(api_key_obj.id),
@@ -138,10 +106,12 @@ def require_api_key(
"error_msg": error_msg "error_msg": error_msg
}) })
# 根据错误消息判断限流类型 # 根据错误消息判断限流类型
if "QPS" in error_msg: if "Daily" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
elif "Daily" in error_msg:
code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED code = BizCode.API_KEY_DAILY_LIMIT_EXCEEDED
elif "Tenant" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED # 租户套餐速率超限,同属 QPS 类
elif "QPS" in error_msg:
code = BizCode.API_KEY_QPS_LIMIT_EXCEEDED
else: else:
code = BizCode.API_KEY_QUOTA_EXCEEDED code = BizCode.API_KEY_QUOTA_EXCEEDED

View File

@@ -31,6 +31,7 @@ class BizCode(IntEnum):
API_KEY_QPS_LIMIT_EXCEEDED = 3014 API_KEY_QPS_LIMIT_EXCEEDED = 3014
API_KEY_DAILY_LIMIT_EXCEEDED = 3015 API_KEY_DAILY_LIMIT_EXCEEDED = 3015
API_KEY_QUOTA_EXCEEDED = 3016 API_KEY_QUOTA_EXCEEDED = 3016
API_KEY_RATE_LIMIT_EXCEEDED = 3017
# 资源4xxx # 资源4xxx
NOT_FOUND = 4000 NOT_FOUND = 4000
USER_NOT_FOUND = 4001 USER_NOT_FOUND = 4001

View File

@@ -6,7 +6,6 @@
2. 降级到 default_free_plan.py 配置文件(社区版兜底) 2. 降级到 default_free_plan.py 配置文件(社区版兜底)
""" """
import asyncio import asyncio
import time
from functools import wraps from functools import wraps
from typing import Optional, Callable, Dict, Any from typing import Optional, Callable, Dict, Any
from uuid import UUID from uuid import UUID
@@ -15,10 +14,13 @@ from sqlalchemy import func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_auth_logger from app.core.logging_config import get_auth_logger
from app.i18n.exceptions import QuotaExceededError from app.i18n.exceptions import QuotaExceededError, InternalServerError
logger = get_auth_logger() logger = get_auth_logger()
# Redis key 格式常量,与 RateLimiterService.check_qps 保持一致per api_key 独立计数)
API_KEY_QPS_REDIS_KEY = "rate_limit:qps:{api_key_id}"
def _get_user_from_kwargs(kwargs: dict): def _get_user_from_kwargs(kwargs: dict):
"""从 kwargs 中获取 user 对象""" """从 kwargs 中获取 user 对象"""
@@ -65,7 +67,9 @@ def _get_tenant_id_from_kwargs(db: Session, kwargs: dict):
if share_record: if share_record:
app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first() app = db.query(App).filter(App.id == share_record.app_id, App.is_active.is_(True)).first()
if app: if app:
return app.workspace.tenant_id workspace = db.query(Workspace).filter(Workspace.id == app.workspace_id).first()
if workspace:
return workspace.tenant_id
return None return None
@@ -78,26 +82,47 @@ def _get_quota_config(db: Session, tenant_id: UUID) -> Optional[Dict[str, Any]]:
1. premium 模块的 tenant_subscriptionsSaaS 版) 1. premium 模块的 tenant_subscriptionsSaaS 版)
2. default_free_plan.py 配置文件(社区版兜底) 2. default_free_plan.py 配置文件(社区版兜底)
""" """
# 尝试从 premium 模块获取 # 尝试从 premium 模块获取SaaS 版)
try: try:
from premium.platform_admin.package_plan_service import TenantSubscriptionService from premium.platform_admin.package_plan_service import TenantSubscriptionService
# premium 模块存在,运行时错误不应被静默降级,直接抛出
quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id) quota_config = TenantSubscriptionService(db).get_effective_quota(tenant_id)
if quota_config: if quota_config:
logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置") logger.debug(f"从 premium 模块获取租户 {tenant_id} 配额配置")
return quota_config return quota_config
except (ModuleNotFoundError, ImportError, Exception) as e: # premium 存在但该租户无订阅记录,降级到免费套餐
logger.debug(f"无法从 premium 模块获取配额配置: {e}") logger.debug(f"租户 {tenant_id} 无 premium 订阅,降级到免费套餐")
except (ModuleNotFoundError, ImportError):
# 社区版premium 包不存在,正常降级
logger.debug("premium 模块不存在,使用社区版免费套餐配额")
# 降级到配置文件 # 降级到社区版配置文件
try: try:
from app.config.default_free_plan import DEFAULT_FREE_PLAN from app.config.default_free_plan import DEFAULT_FREE_PLAN
logger.info(f"使用配置文件中的免费套餐配额: tenant={tenant_id}") logger.debug(f"使用社区版免费套餐配额: tenant={tenant_id}")
return DEFAULT_FREE_PLAN.get("quotas") return DEFAULT_FREE_PLAN.get("quotas")
except Exception as e: except Exception as e:
logger.error(f"无法从配置文件获取配额: {e}") logger.error(f"无法从配置文件获取配额: {e}")
return None return None
def get_api_ops_rate_limit(db: Session, tenant_id: UUID) -> Optional[int]:
"""
获取租户套餐的 API 操作速率限制QPS 上限)
该函数兼容社区版和 SaaS 版:
- SaaS 版:从 premium 模块的套餐配额读取
- 社区版:从 default_free_plan.py 配置文件读取
Returns:
int: api_ops_rate_limit 值,如果未配置则返回 None
"""
quota_config = _get_quota_config(db, tenant_id)
if quota_config:
return quota_config.get("api_ops_rate_limit")
return None
class QuotaUsageRepository: class QuotaUsageRepository:
"""配额使用量数据访问层""" """配额使用量数据访问层"""
@@ -247,41 +272,74 @@ def _check_quota(
def check_workspace_quota(func: Callable) -> Callable: def check_workspace_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "workspace_quota", "workspace") _check_quota(db, user.tenant_id, "workspace_quota", "workspace")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_skill_quota(func: Callable) -> Callable: def check_skill_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "skill_quota", "skill") _check_quota(db, user.tenant_id, "skill_quota", "skill")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_app_quota(func: Callable) -> Callable: def check_app_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "app_quota", "app") _check_quota(db, user.tenant_id, "app_quota", "app")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_knowledge_capacity_quota(func: Callable) -> Callable: def check_knowledge_capacity_quota(func: Callable) -> Callable:
@@ -289,12 +347,12 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity") _check_quota(db, tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -303,8 +361,8 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity") _check_quota(db, user.tenant_id, "knowledge_capacity_quota", "knowledge_capacity")
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -313,15 +371,26 @@ def check_knowledge_capacity_quota(func: Callable) -> Callable:
def check_memory_engine_quota(func: Callable) -> Callable: def check_memory_engine_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine") _check_quota(db, user.tenant_id, "memory_engine_quota", "memory_engine")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_end_user_quota(func: Callable) -> Callable: def check_end_user_quota(func: Callable) -> Callable:
@@ -329,12 +398,12 @@ def check_end_user_quota(func: Callable) -> Callable:
async def async_wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return await func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user") _check_quota(db, tenant_id, "end_user_quota", "end_user")
return await func(*args, **kwargs) return await func(*args, **kwargs)
@@ -342,12 +411,12 @@ def check_end_user_quota(func: Callable) -> Callable:
def sync_wrapper(*args, **kwargs): def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
if not db: if not db:
logger.warning("配额检查失败:缺少 db 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
tenant_id = _get_tenant_id_from_kwargs(db, kwargs) tenant_id = _get_tenant_id_from_kwargs(db, kwargs)
if not tenant_id: if not tenant_id:
logger.warning("配额检查失败:无法获取 tenant_id") logger.error(f"配额检查失败:{func.__name__} 无法获取 tenant_id,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, tenant_id, "end_user_quota", "end_user") _check_quota(db, tenant_id, "end_user_quota", "end_user")
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -356,39 +425,95 @@ def check_end_user_quota(func: Callable) -> Callable:
def check_ontology_project_quota(func: Callable) -> Callable: def check_ontology_project_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project") _check_quota(db, user.tenant_id, "ontology_project_quota", "ontology_project")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_quota(func: Callable) -> Callable: def check_model_quota(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model")
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, "model_quota", "model") _check_quota(db, user.tenant_id, "model_quota", "model")
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_model_activation_quota(func: Callable) -> Callable: def check_model_activation_quota(func: Callable) -> Callable:
"""模型激活时的配额检查装饰器""" """模型激活时的配额检查装饰器"""
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data")
if not model_id or not model_data:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return await func(*args, **kwargs)
if model_data.is_active:
try:
from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id(
db=db,
model_id=model_id,
tenant_id=user.tenant_id
)
if not existing_model.is_active:
logger.info(f"模型激活操作,检查配额: model_id={model_id}, tenant_id={user.tenant_id}")
_check_quota(db, user.tenant_id, "model_quota", "model")
except Exception as e:
logger.error(f"模型激活配额检查异常: model_id={model_id}, error={str(e)}")
raise
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None) model_id = kwargs.get("model_id") or (args[1] if len(args) > 1 else None)
model_data = kwargs.get("model_data") model_data = kwargs.get("model_data")
@@ -397,9 +522,8 @@ def check_model_activation_quota(func: Callable) -> Callable:
logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数") logger.warning("模型激活配额检查失败:缺少 model_id 或 model_data 参数")
return func(*args, **kwargs) return func(*args, **kwargs)
if model_data.is_active is True: if model_data.is_active:
try: try:
from app.models.models_model import ModelConfig
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
existing_model = ModelConfigService.get_model_by_id( existing_model = ModelConfigService.get_model_by_id(
@@ -416,28 +540,40 @@ def check_model_activation_quota(func: Callable) -> Callable:
raise raise
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None): def check_quota(quota_type: str, resource_name: str, usage_func: Optional[Callable] = None):
"""通用配额检查装饰器,支持自定义使用量获取函数""" """通用配额检查装饰器,支持自定义使用量获取函数"""
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): async def async_wrapper(*args, **kwargs):
db: Session = kwargs.get("db") db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs) user = _get_user_from_kwargs(kwargs)
if not db or not user: if not db or not user:
logger.warning("配额检查失败:缺少 db 或 user 参数") logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
return func(*args, **kwargs) raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return await func(*args, **kwargs)
@wraps(func)
def sync_wrapper(*args, **kwargs):
db: Session = kwargs.get("db")
user = _get_user_from_kwargs(kwargs)
if not db or not user:
logger.error(f"配额检查失败:{func.__name__} 缺少 db 或 user 参数,拒绝请求")
raise InternalServerError()
_check_quota(db, user.tenant_id, quota_type, resource_name, usage_func) _check_quota(db, user.tenant_id, quota_type, resource_name, usage_func)
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapper
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
return decorator return decorator
# ─── 配额使用统计 ──────────────────────────────────────────────────────────── # ─── 配额使用统计 ────────────────────────────────────────────────────────────
def get_quota_usage(db: Session, tenant_id: UUID) -> dict: async def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
"""获取租户所有配额的使用情况""" """获取租户所有配额的使用情况"""
quota_config = _get_quota_config(db, tenant_id) quota_config = _get_quota_config(db, tenant_id)
if not quota_config: if not quota_config:
@@ -459,18 +595,25 @@ def get_quota_usage(db: Session, tenant_id: UUID) -> dict:
api_ops_current = 0 api_ops_current = 0
try: try:
from app.core.config import settings from app.aioRedis import aio_redis as _aio_redis
import redis from app.models.api_key_model import ApiKey
_now = time.time() from app.models.workspace_model import Workspace
_rk = f"rate_limit:tenant_qps:{tenant_id}" # api_ops_rate_limit 限的是每个 api_key 每秒最高限额
_r = redis.StrictRedis( # 展示当前最接近触发限流的 key 的 QPS取最大值
host=settings.REDIS_HOST, port=settings.REDIS_PORT, api_key_ids = db.query(ApiKey.id).join(
db=settings.REDIS_DB, password=settings.REDIS_PASSWORD, Workspace, ApiKey.workspace_id == Workspace.id
decode_responses=True ).filter(
) Workspace.tenant_id == tenant_id,
api_ops_current = int(_r.zcount(_rk, _now - 1, "+inf")) ApiKey.is_active.is_(True)
except Exception: ).all()
pass for (key_id,) in api_key_ids:
_rk = API_KEY_QPS_REDIS_KEY.format(api_key_id=key_id)
val = await _aio_redis.get(_rk)
count = int(val) if val else 0
if count > api_ops_current:
api_ops_current = count
except Exception as e:
logger.warning(f"获取 api_ops_current 失败,返回 0: {type(e).__name__}: {e}")
return { return {
"workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))}, "workspace": {"used": workspace_count, "limit": quota_config.get("workspace_quota"), "percentage": pct(workspace_count, quota_config.get("workspace_quota"))},

View File

@@ -18,6 +18,7 @@ from app.core.quota_manager import (
get_quota_usage, get_quota_usage,
_check_quota, _check_quota,
QuotaUsageRepository, QuotaUsageRepository,
API_KEY_QPS_REDIS_KEY,
) )
__all__ = [ __all__ = [
@@ -33,4 +34,5 @@ __all__ = [
"get_quota_usage", "get_quota_usage",
"_check_quota", "_check_quota",
"QuotaUsageRepository", "QuotaUsageRepository",
"API_KEY_QPS_REDIS_KEY",
] ]

View File

@@ -482,14 +482,39 @@ class RateLimitExceededError(I18nException):
) )
class QuotaExceededError(ForbiddenError): class QuotaExceededError(I18nException):
"""Quota exceeded error.""" """Quota exceeded error (402)."""
# resource key -> i18n display key
_RESOURCE_KEY_MAP = {
"workspace": "errors.quota_resources.workspace",
"app": "errors.quota_resources.app",
"skill": "errors.quota_resources.skill",
"knowledge_capacity": "errors.quota_resources.knowledge_capacity",
"memory_engine": "errors.quota_resources.memory_engine",
"end_user": "errors.quota_resources.end_user",
"model": "errors.quota_resources.model",
"ontology_project": "errors.quota_resources.ontology_project",
"api_ops_rate_limit": "errors.quota_resources.api_ops_rate_limit",
}
def __init__(self, resource: Optional[str] = None, **params): def __init__(self, resource: Optional[str] = None, **params):
# Translate resource key to a localized display name before calling super()
if resource: if resource:
params["resource"] = resource resource_i18n_key = self._RESOURCE_KEY_MAP.get(resource)
if resource_i18n_key:
try:
from app.i18n.service import get_translation_service
from app.core.config import settings
_locale = _current_locale.get() or settings.I18N_DEFAULT_LANGUAGE
params["resource"] = get_translation_service().translate(resource_i18n_key, _locale)
except Exception:
params["resource"] = resource
else:
params["resource"] = resource
super().__init__( super().__init__(
error_key="errors.api.quota_exceeded", error_key="errors.api.quota_exceeded",
status_code=402,
error_code="QUOTA_EXCEEDED", error_code="QUOTA_EXCEEDED",
**params **params
) )

View File

@@ -106,7 +106,7 @@
}, },
"api": { "api": {
"rate_limit_exceeded": "API rate limit exceeded", "rate_limit_exceeded": "API rate limit exceeded",
"quota_exceeded": "API quota exceeded", "quota_exceeded": "{resource} quota exceeded",
"invalid_api_key": "Invalid API key", "invalid_api_key": "Invalid API key",
"api_key_expired": "API key has expired", "api_key_expired": "API key has expired",
"api_key_revoked": "API key has been revoked", "api_key_revoked": "API key has been revoked",
@@ -114,7 +114,8 @@
"method_not_allowed": "Method not allowed", "method_not_allowed": "Method not allowed",
"invalid_request": "Invalid request", "invalid_request": "Invalid request",
"missing_parameter": "Missing required parameter: {param}", "missing_parameter": "Missing required parameter: {param}",
"invalid_parameter": "Invalid parameter: {param}" "invalid_parameter": "Invalid parameter: {param}",
"api_key_rate_limit_exceeded": "API Key rate limit ({rate_limit}) exceeds tenant plan limit ({limit})"
}, },
"database": { "database": {
"connection_failed": "Database connection failed", "connection_failed": "Database connection failed",
@@ -134,5 +135,16 @@
"invalid_format": "Invalid format: {field}", "invalid_format": "Invalid format: {field}",
"invalid_value": "Invalid value: {field}", "invalid_value": "Invalid value: {field}",
"out_of_range": "Value out of range: {field}" "out_of_range": "Value out of range: {field}"
},
"quota_resources": {
"workspace": "Workspace",
"app": "App",
"skill": "Skill",
"knowledge_capacity": "Knowledge capacity",
"memory_engine": "Memory engine",
"end_user": "End user",
"model": "Model",
"ontology_project": "Ontology project",
"api_ops_rate_limit": "API ops rate limit"
} }
} }

View File

@@ -106,7 +106,7 @@
}, },
"api": { "api": {
"rate_limit_exceeded": "API调用频率超限", "rate_limit_exceeded": "API调用频率超限",
"quota_exceeded": "API调用配额已用完", "quota_exceeded": "{resource} 配额已超限",
"invalid_api_key": "无效的API密钥", "invalid_api_key": "无效的API密钥",
"api_key_expired": "API密钥已过期", "api_key_expired": "API密钥已过期",
"api_key_revoked": "API密钥已被撤销", "api_key_revoked": "API密钥已被撤销",
@@ -114,7 +114,8 @@
"method_not_allowed": "不支持的请求方法", "method_not_allowed": "不支持的请求方法",
"invalid_request": "无效的请求", "invalid_request": "无效的请求",
"missing_parameter": "缺少必需参数:{param}", "missing_parameter": "缺少必需参数:{param}",
"invalid_parameter": "参数无效:{param}" "invalid_parameter": "参数无效:{param}",
"api_key_rate_limit_exceeded": "API Key 的 QPS 限制({rate_limit})超过租户套餐上限({limit}"
}, },
"database": { "database": {
"connection_failed": "数据库连接失败", "connection_failed": "数据库连接失败",
@@ -134,5 +135,16 @@
"invalid_format": "格式不正确:{field}", "invalid_format": "格式不正确:{field}",
"invalid_value": "值无效:{field}", "invalid_value": "值无效:{field}",
"out_of_range": "值超出范围:{field}" "out_of_range": "值超出范围:{field}"
},
"quota_resources": {
"workspace": "工作空间",
"app": "应用",
"skill": "技能",
"knowledge_capacity": "知识库容量",
"memory_engine": "记忆引擎",
"end_user": "终端用户",
"model": "模型",
"ontology_project": "本体工程",
"api_ops_rate_limit": "API 操作速率"
} }
} }

View File

@@ -19,6 +19,7 @@ from app.core.exceptions import (
) )
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.i18n.exceptions import I18nException
logger = get_business_logger() logger = get_business_logger()
@@ -51,6 +52,22 @@ class ApiKeyService:
if existing: if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 校验 rate_limit 不能超过租户套餐的 api_ops_rate_limit
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
raise I18nException(
error_key="errors.api.api_key_rate_limit_exceeded",
status_code=400,
error_code="API_KEY_RATE_LIMIT_EXCEEDED",
rate_limit=data.rate_limit,
limit=tenant_api_ops_limit,
)
# 生成 API Key # 生成 API Key
api_key = generate_api_key(data.type) api_key = generate_api_key(data.type)
@@ -152,6 +169,23 @@ class ApiKeyService:
if existing: if existing:
raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME) raise BusinessException(f"API Key 名称 {data.name} 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 校验 rate_limit 不能超过租户套餐的 api_ops_rate_limit
if data.rate_limit is not None:
from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
workspace = db.query(Workspace).filter(Workspace.id == workspace_id).first()
if workspace:
tenant_api_ops_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
if tenant_api_ops_limit and data.rate_limit > tenant_api_ops_limit:
raise I18nException(
error_key="errors.api.api_key_rate_limit_exceeded",
status_code=400,
error_code="API_KEY_RATE_LIMIT_EXCEEDED",
rate_limit=data.rate_limit,
limit=tenant_api_ops_limit,
)
update_data = data.model_dump(exclude_unset=True) update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data) ApiKeyRepository.update(db, api_key_id, update_data)
db.commit() db.commit()
@@ -248,42 +282,14 @@ class RateLimiterService:
def __init__(self): def __init__(self):
self.redis = aio_redis self.redis = aio_redis
async def check_tenant_rate_limit(self, tenant_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
"""
按 tenant_id 做 1 秒滑动窗口限速,限制值来自套餐配额 api_ops_rate_limit
"""
now = time.time()
window_start = now - 1 # 1 秒窗口
key = f"rate_limit:tenant_qps:{tenant_id}"
async with self.redis.pipeline() as pipe:
# 清理 1 秒前的旧记录
pipe.zremrangebyscore(key, 0, window_start)
# 加入当前请求score=时间戳member=时间戳+随机数保证唯一)
pipe.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
# 统计窗口内请求数
pipe.zcard(key)
# 设置 key 过期2 秒后自动清理)
pipe.expire(key, 2)
results = await pipe.execute()
current = results[2]
remaining = max(0, limit - current)
reset_time = int(now) + 1
return current <= limit, {
"limit": limit,
"remaining": remaining,
"reset": reset_time,
}
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]: async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
""" """检查QPS限制
检查QPS限制
Returns: Returns:
(is_allowed, rate_limit_info) (is_allowed, rate_limit_info)
""" """
key = f"rate_limit:qps:{api_key_id}" key = f"rate_limit:qps:{api_key_id}"
async with self.redis.pipeline() as pipe: async with self.redis.pipeline() as pipe:
pipe.incr(key) pipe.incr(key)
pipe.expire(key, 1, nx=True) # 1 秒过期 pipe.expire(key, 1, nx=True) # 1 秒过期
@@ -295,8 +301,9 @@ class RateLimiterService:
return current <= limit, { return current <= limit, {
"limit": limit, "limit": limit,
"current": current,
"remaining": remaining, "remaining": remaining,
"reset": reset_time "reset": reset_time,
} }
async def check_daily_requests( async def check_daily_requests(
@@ -304,7 +311,9 @@ class RateLimiterService:
api_key_id: uuid.UUID, api_key_id: uuid.UUID,
limit: int limit: int
) -> Tuple[bool, dict]: ) -> Tuple[bool, dict]:
"""检查日调用量限制""" """检查日调用量限制
使用原子 INCR先写后判断极低概率下允许轻微超限并发场景下可接受
"""
today = datetime.now().strftime("%Y%m%d") today = datetime.now().strftime("%Y%m%d")
key = f"rate_limit:daily:{api_key_id}:{today}" key = f"rate_limit:daily:{api_key_id}:{today}"
@@ -313,6 +322,7 @@ class RateLimiterService:
hour=0, minute=0, second=0, microsecond=0 hour=0, minute=0, second=0, microsecond=0
) )
expire_seconds = int((tomorrow_0 - now).total_seconds()) expire_seconds = int((tomorrow_0 - now).total_seconds())
reset_time = int(tomorrow_0.timestamp())
async with self.redis.pipeline() as pipe: async with self.redis.pipeline() as pipe:
pipe.incr(key) pipe.incr(key)
@@ -320,36 +330,74 @@ class RateLimiterService:
results = await pipe.execute() results = await pipe.execute()
current = results[0] current = results[0]
remaining = max(0, limit - current)
reset_time = int(tomorrow_0.timestamp())
return current <= limit, { if current > limit:
return False, {
"limit": limit,
"remaining": 0,
"reset": reset_time,
}
return True, {
"limit": limit, "limit": limit,
"remaining": remaining, "remaining": max(0, limit - current),
"reset": reset_time "reset": reset_time,
} }
async def check_all_limits( async def check_all_limits(
self, self,
api_key: ApiKey api_key: ApiKey,
db: Optional[Session] = None,
) -> Tuple[bool, str, dict]: ) -> Tuple[bool, str, dict]:
""" """
检查所有限制 检查所有限制,按以下顺序:
Returns: 1. API Key QPS取 api_key.rate_limit 与套餐 api_ops_rate_limit 的最小值作为限额
(is_allowed, error_message, rate_limit_headers) 2. API Key 日调用量
""" """
# Check QPS # 1. 取套餐限额与 api_key 自身限额的最小值
qps_ok, qps_info = await self.check_qps( effective_limit = api_key.rate_limit
api_key.id, if db is not None:
api_key.rate_limit try:
) from app.models.workspace_model import Workspace
from app.core.quota_manager import get_api_ops_rate_limit
cache_key = f"tenant_api_ops_limit:{api_key.workspace_id}"
cached = await self.redis.get(cache_key)
if cached is not None:
try:
tenant_limit = int(cached) if cached != "0" else None
except (ValueError, TypeError):
cached = None
tenant_limit = None
if cached is None:
workspace = db.query(Workspace).filter(Workspace.id == api_key.workspace_id).first()
if workspace:
tenant_limit = get_api_ops_rate_limit(db, workspace.tenant_id)
await self.redis.set(cache_key, str(tenant_limit) if tenant_limit else "0", ex=60)
else:
tenant_limit = None
if tenant_limit:
effective_limit = min(api_key.rate_limit, tenant_limit)
except Exception as e:
logger.warning(f"获取套餐限额失败,使用 api_key 自身限额: {e}")
# 用最终有效限额做 QPS 检查
qps_ok, qps_info = await self.check_qps(api_key.id, effective_limit)
if not qps_ok: if not qps_ok:
return False, "QPS limit exceeded", { # 判断是套餐限额触发还是 api_key 自身限额触发
if tenant_limit and effective_limit == tenant_limit and api_key.rate_limit > tenant_limit:
error_msg = "Tenant limit exceeded"
else:
error_msg = "QPS limit exceeded"
return False, error_msg, {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Reset": str(qps_info["reset"]) "X-RateLimit-Reset": str(qps_info["reset"])
} }
# 2. 检查日调用量
daily_ok, daily_info = await self.check_daily_requests( daily_ok, daily_info = await self.check_daily_requests(
api_key.id, api_key.id,
api_key.daily_request_limit api_key.daily_request_limit
@@ -361,14 +409,13 @@ class RateLimiterService:
"X-RateLimit-Reset": str(daily_info["reset"]) "X-RateLimit-Reset": str(daily_info["reset"])
} }
headers = { return True, "", {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]), "X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]), "X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Limit-Day": str(daily_info["limit"]), "X-RateLimit-Limit-Day": str(daily_info["limit"]),
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]), "X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
"X-RateLimit-Reset": str(daily_info["reset"]) "X-RateLimit-Reset": str(daily_info["reset"]),
} }
return True, "", headers
class ApiKeyAuthService: class ApiKeyAuthService: