# Viewer metadata: 139 lines, 4.5 KiB, Python
"""API Key Repository"""
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import select, func, and_
|
|
from typing import Optional, List, Tuple
|
|
import uuid
|
|
import datetime
|
|
|
|
from app.models.api_key_model import ApiKey, ApiKeyLog
|
|
from app.schemas import api_key_schema
|
|
|
|
|
|
class ApiKeyRepository:
    """Data-access layer for API keys.

    Every method is static and works on the SQLAlchemy ``Session`` it is
    given. Writes are flushed to the database but never committed here;
    committing (or rolling back) is left to the caller.
    """

    @staticmethod
    def create(db: Session, api_key_data: dict) -> ApiKey:
        """Create an API key row from ``api_key_data`` and flush it.

        Args:
            db: Active SQLAlchemy session.
            api_key_data: Keyword arguments forwarded to the ``ApiKey``
                constructor.

        Returns:
            The newly added ``ApiKey`` instance (flushed, not committed).
        """
        api_key = ApiKey(**api_key_data)
        db.add(api_key)
        db.flush()
        return api_key

    @staticmethod
    def get_by_id(db: Session, api_key_id: uuid.UUID) -> Optional[ApiKey]:
        """Return the API key with primary key ``api_key_id``, or ``None``."""
        return db.get(ApiKey, api_key_id)

    @staticmethod
    def get_by_hash(db: Session, key_hash: str) -> Optional[ApiKey]:
        """Return the first API key whose ``key_hash`` matches, or ``None``."""
        stmt = select(ApiKey).where(ApiKey.key_hash == key_hash)
        return db.scalars(stmt).first()

    @staticmethod
    def list_by_workspace(
        db: Session,
        workspace_id: uuid.UUID,
        query: api_key_schema.ApiKeyQuery
    ) -> Tuple[List[ApiKey], int]:
        """List the API keys of a workspace, filtered and paginated.

        Args:
            db: Active SQLAlchemy session.
            workspace_id: Workspace whose keys are listed.
            query: Filter and paging options. The fields read here are
                ``type``, ``is_active``, ``resource_id``, ``page`` and
                ``pagesize``.

        Returns:
            A ``(items, total)`` tuple: the requested page of keys, newest
            first, and the number of rows matching the filters before
            pagination.
        """
        stmt = select(ApiKey).where(ApiKey.workspace_id == workspace_id)

        # Optional filters. ``is_active`` is compared against None so that an
        # explicit False still filters; the other two are skipped when falsy.
        if query.type:
            stmt = stmt.where(ApiKey.type == query.type)
        if query.is_active is not None:
            stmt = stmt.where(ApiKey.is_active == query.is_active)
        if query.resource_id:
            stmt = stmt.where(ApiKey.resource_id == query.resource_id)

        # Total is counted over the filtered statement before ordering and
        # paging are applied. ``scalar()`` is typed Optional, so normalise to
        # an int to honour the declared return type.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total = db.execute(count_stmt).scalar() or 0

        # Pagination: ``page`` is treated as 1-based.
        # NOTE(review): assumes page >= 1 and pagesize >= 0 are enforced by
        # the schema -- confirm; a smaller page yields a negative offset.
        stmt = stmt.order_by(ApiKey.created_at.desc())
        stmt = stmt.offset((query.page - 1) * query.pagesize).limit(query.pagesize)

        items = db.scalars(stmt).all()
        return list(items), total

    @staticmethod
    def update(db: Session, api_key_id: uuid.UUID, update_data: dict) -> Optional[ApiKey]:
        """Apply ``update_data`` to an API key and flush.

        Entries whose value is ``None`` are skipped, so this method cannot be
        used to reset an attribute to ``None``.

        Args:
            db: Active SQLAlchemy session.
            api_key_id: Primary key of the row to update.
            update_data: Mapping of attribute name to new value.

        Returns:
            The updated ``ApiKey``, or ``None`` when no row has that id.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            for key, value in update_data.items():
                if value is not None:
                    setattr(api_key, key, value)
            # Naive local timestamp, as produced by datetime.now().
            api_key.updated_at = datetime.datetime.now()
            db.flush()
        return api_key

    @staticmethod
    def delete(db: Session, api_key_id: uuid.UUID) -> bool:
        """Delete an API key.

        Returns:
            ``True`` when the row existed and was deleted (flushed),
            ``False`` when no row has that id.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            db.delete(api_key)
            db.flush()
            return True
        return False

    @staticmethod
    def update_usage(db: Session, api_key_id: uuid.UUID) -> bool:
        """Record one use of an API key.

        Increments ``usage_count`` and ``quota_used`` by one and sets
        ``last_used_at`` to the current naive local time.

        Returns:
            ``True`` when the row existed and was updated, ``False``
            otherwise.
        """
        api_key = db.get(ApiKey, api_key_id)
        if api_key:
            # NOTE(review): the counters are read and incremented in Python,
            # so the increment is based on the value this session loaded.
            api_key.usage_count += 1
            api_key.quota_used += 1
            api_key.last_used_at = datetime.datetime.now()
            db.flush()
            return True
        return False

    @staticmethod
    def get_stats(db: Session, api_key_id: uuid.UUID) -> dict:
        """Return usage statistics for an API key.

        Returns:
            An empty dict when no row has that id. Otherwise a dict with the
            keys ``total_requests``, ``requests_today``, ``quota_used``,
            ``quota_limit``, ``last_used_at`` and ``avg_response_time``.
            ``avg_response_time`` is a float, or ``None`` when the average
            query yields no value.
        """
        api_key = db.get(ApiKey, api_key_id)
        if not api_key:
            return {}

        # Requests logged since local midnight (naive datetime).
        today_start = datetime.datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
        today_count_stmt = select(func.count()).select_from(ApiKeyLog).where(
            and_(
                ApiKeyLog.api_key_id == api_key_id,
                ApiKeyLog.created_at >= today_start
            )
        )
        requests_today = db.execute(today_count_stmt).scalar() or 0

        # Average response time over every log row of this key.
        avg_time_stmt = select(func.avg(ApiKeyLog.response_time)).where(
            ApiKeyLog.api_key_id == api_key_id
        )
        avg_response_time = db.execute(avg_time_stmt).scalar()

        return {
            "total_requests": api_key.usage_count,
            "requests_today": requests_today,
            "quota_used": api_key.quota_used,
            "quota_limit": api_key.quota_limit,
            "last_used_at": api_key.last_used_at,
            # Compare against None explicitly: an average of exactly 0 is a
            # real value and must not be reported as "no data".
            "avg_response_time": (
                float(avg_response_time) if avg_response_time is not None else None
            ),
        }
|
|
|
|
|
|
class ApiKeyLogRepository:
    """Data-access layer for API key log entries."""

    @staticmethod
    def create(db: Session, log_data: dict) -> ApiKeyLog:
        """Create a log entry from ``log_data`` and flush it.

        Args:
            db: Active SQLAlchemy session.
            log_data: Keyword arguments forwarded to the ``ApiKeyLog``
                constructor.

        Returns:
            The newly added ``ApiKeyLog`` instance (flushed, not committed).
        """
        entry = ApiKeyLog(**log_data)
        db.add(entry)
        db.flush()
        return entry
|