[MODIFY] Code optimization
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.orm import Session
|
||||
from jose import jwt, JWTError
|
||||
import uuid
|
||||
from functools import wraps
|
||||
|
||||
from app.db import get_db, SessionLocal
|
||||
from app.schemas import token_schema
|
||||
@@ -25,9 +26,10 @@ security_logger = get_security_logger()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
) -> User:
|
||||
"""
|
||||
获取当前认证用户
|
||||
@@ -37,81 +39,82 @@ async def get_current_user(
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
auth_logger.debug("开始解析JWT token")
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
|
||||
|
||||
if user_id is None:
|
||||
auth_logger.warning("JWT token中缺少用户ID")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
token_data = token_schema.TokenData(userId=user_id)
|
||||
auth_logger.debug(f"JWT解析成功,用户ID: {user_id}")
|
||||
|
||||
|
||||
except JWTError as e:
|
||||
auth_logger.warning(f"JWT解析失败: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# 检查单点登录黑名单和用户token失效
|
||||
try:
|
||||
auth_logger.debug("检查单点登录黑名单")
|
||||
token_id = get_token_id(token)
|
||||
session_service = SessionService()
|
||||
|
||||
|
||||
if await session_service.is_token_blacklisted(token_id):
|
||||
auth_logger.warning(f"Token已被列入黑名单: {token_id}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
# 检查用户是否重置了密码(所有旧token失效)
|
||||
invalidation_time_str = await session_service.get_user_token_invalidation_time(user_id)
|
||||
if invalidation_time_str:
|
||||
from datetime import datetime, timezone
|
||||
invalidation_time = datetime.fromisoformat(invalidation_time_str)
|
||||
token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), tz=timezone.utc) if payload.get("iat") else None
|
||||
|
||||
token_issued_at = datetime.fromtimestamp(payload.get("iat", 0), tz=timezone.utc) if payload.get(
|
||||
"iat") else None
|
||||
|
||||
if token_issued_at and token_issued_at < invalidation_time:
|
||||
auth_logger.warning(f"Token在密码重置前签发,已失效: user_id={user_id}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
auth_logger.debug("单点登录检查通过")
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
auth_logger.error(f"检查token有效性时发生错误: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
try:
|
||||
auth_logger.debug(f"查询用户信息: {token_data.userId}")
|
||||
user = user_repository.get_user_by_id(db, user_id=token_data.userId)
|
||||
|
||||
|
||||
if user is None:
|
||||
auth_logger.warning(f"用户不存在: {token_data.userId}")
|
||||
raise credentials_exception
|
||||
if not user.is_active:
|
||||
auth_logger.warning(f"用户已被停用: {user.username} (ID: {user.id})")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
auth_logger.info(f"用户认证成功: {user.username} (ID: {user.id})")
|
||||
return user
|
||||
|
||||
|
||||
except Exception as e:
|
||||
auth_logger.error(f"查询用户信息时发生错误: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
async def get_current_tenant(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> Tenants:
|
||||
"""
|
||||
获取当前用户的租户
|
||||
由于每个用户只属于一个租户,直接返回用户的租户
|
||||
"""
|
||||
auth_logger.debug(f"获取用户 {current_user.username} 的租户信息")
|
||||
|
||||
|
||||
try:
|
||||
# 直接从用户模型获取租户
|
||||
if current_user.tenant:
|
||||
@@ -123,7 +126,7 @@ async def get_current_tenant(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户没有关联的租户"
|
||||
)
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -135,15 +138,15 @@ async def get_current_tenant(
|
||||
|
||||
|
||||
async def get_user_tenants(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
) -> list[Tenants]:
|
||||
"""
|
||||
获取当前用户所属的所有租户
|
||||
由于每个用户只属于一个租户,返回包含该租户的列表
|
||||
"""
|
||||
auth_logger.debug(f"获取用户 {current_user.username} 的所有租户")
|
||||
|
||||
|
||||
try:
|
||||
if current_user.tenant:
|
||||
tenants = [current_user.tenant]
|
||||
@@ -152,7 +155,7 @@ async def get_user_tenants(
|
||||
else:
|
||||
auth_logger.info(f"用户 {current_user.username} 没有关联的租户")
|
||||
return []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
auth_logger.error(f"获取用户租户列表时发生错误: {str(e)}")
|
||||
raise HTTPException(
|
||||
@@ -162,20 +165,20 @@ async def get_user_tenants(
|
||||
|
||||
|
||||
async def get_current_superuser(
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user)
|
||||
) -> User:
|
||||
"""
|
||||
检查当前用户是否为超级管理员
|
||||
"""
|
||||
auth_logger.debug(f"检查用户 {current_user.username} 是否为超级管理员")
|
||||
|
||||
|
||||
if not current_user.is_superuser:
|
||||
auth_logger.warning(f"用户 {current_user.username} 尝试访问超管功能但不是超级管理员")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="只有超级管理员才能执行此操作"
|
||||
)
|
||||
|
||||
|
||||
auth_logger.info(f"超级管理员 {current_user.username} 访问超管功能")
|
||||
return current_user
|
||||
|
||||
@@ -246,13 +249,13 @@ async def get_current_superuser(
|
||||
def _check_workspace_access_sync(db: Session, user: User, workspace_id: uuid.UUID) -> Workspace:
|
||||
"""同步校验版本,供装饰器在同步端点中调用 - 使用权限服务"""
|
||||
auth_logger.debug(f"同步校验工作空间访问权限: workspace_id={workspace_id}, user={user.id}")
|
||||
|
||||
|
||||
# 1) 工作空间存在性
|
||||
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
|
||||
if not workspace:
|
||||
auth_logger.warning(f"工作空间不存在: {workspace_id}")
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workspace not found")
|
||||
|
||||
|
||||
# 2) 超级用户跳过成员检查,直接验证租户
|
||||
if user.is_superuser:
|
||||
if user.tenant_id == workspace.tenant_id:
|
||||
@@ -261,26 +264,26 @@ def _check_workspace_access_sync(db: Session, user: User, workspace_id: uuid.UUI
|
||||
else:
|
||||
auth_logger.warning(f"超级用户尝试访问其他租户工作空间: workspace_id={workspace_id}, user={user.id}")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
|
||||
# 3) 普通用户使用权限服务检查访问权限
|
||||
from app.core.permissions import permission_service, Subject, Resource, Action
|
||||
from app.core.permissions.policies import WorkspaceMemberPolicy, SameTenantSuperuserPolicy
|
||||
|
||||
|
||||
# Check if user is a member
|
||||
member = workspace_repository.get_member_in_workspace(
|
||||
db=db, user_id=user.id, workspace_id=workspace_id
|
||||
)
|
||||
workspace_memberships = {workspace_id} if member else set()
|
||||
|
||||
|
||||
subject = Subject.from_user(user, workspace_memberships=workspace_memberships)
|
||||
resource = Resource.from_workspace(workspace)
|
||||
|
||||
|
||||
# Add workspace member policy
|
||||
temp_service = permission_service
|
||||
if member:
|
||||
temp_service.add_policy(WorkspaceMemberPolicy(allowed_actions={Action.READ, Action.UPDATE, Action.MANAGE}))
|
||||
temp_service.add_policy(SameTenantSuperuserPolicy())
|
||||
|
||||
|
||||
try:
|
||||
permission_service.require_permission(
|
||||
subject,
|
||||
@@ -317,7 +320,8 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
if get_workspace_id_from_body:
|
||||
payload = kwargs.get("payload")
|
||||
if not payload or not hasattr(payload, "workspace_id"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id missing in body")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="workspace_id missing in body")
|
||||
workspace_id = payload.workspace_id
|
||||
else:
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
@@ -326,6 +330,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
|
||||
_check_workspace_access_sync(db, user, workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return _async_wrapper
|
||||
else:
|
||||
@wraps(func)
|
||||
@@ -336,7 +341,8 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
if get_workspace_id_from_body:
|
||||
payload = kwargs.get("payload")
|
||||
if not payload or not hasattr(payload, "workspace_id"):
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id missing in body")
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="workspace_id missing in body")
|
||||
workspace_id = payload.workspace_id
|
||||
else:
|
||||
workspace_id = kwargs.get("workspace_id")
|
||||
@@ -345,6 +351,7 @@ def workspace_access_guard(get_workspace_id_from_body: bool = False):
|
||||
|
||||
_check_workspace_access_sync(db, user, workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _sync_wrapper
|
||||
|
||||
return _decorator
|
||||
@@ -384,6 +391,7 @@ def cur_workspace_access_guard():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required")
|
||||
_check_workspace_access_sync(db, user, workspace_id)
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
return _async_wrapper
|
||||
else:
|
||||
@wraps(func)
|
||||
@@ -395,20 +403,23 @@ def cur_workspace_access_guard():
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="workspace_id is required")
|
||||
_check_workspace_access_sync(db, user, workspace_id)
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return _sync_wrapper
|
||||
|
||||
return _decorator
|
||||
|
||||
|
||||
class ShareTokenData:
|
||||
"""分享 token 数据"""
|
||||
|
||||
def __init__(self, user_id: str, share_token: str):
|
||||
self.user_id = user_id
|
||||
self.share_token = share_token
|
||||
|
||||
|
||||
async def get_share_user_id(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: Session = Depends(get_db)
|
||||
) -> ShareTokenData:
|
||||
"""
|
||||
从分享访问 token 中获取用户 ID 和 share_token
|
||||
@@ -422,38 +433,40 @@ async def get_share_user_id(
|
||||
from app.services.auth_service import decode_access_token
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
auth_logger.debug("开始解析分享访问 token")
|
||||
|
||||
|
||||
# 解码 token 获取 user_id 和 share_token
|
||||
payload = decode_access_token(token)
|
||||
user_id = payload["user_id"]
|
||||
share_token = payload["share_token"]
|
||||
|
||||
|
||||
auth_logger.debug(f"Token 解析成功,用户ID: {user_id}, share_token: {share_token}")
|
||||
|
||||
|
||||
# 验证 share_token 是否有效
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_token)
|
||||
|
||||
|
||||
if not share_info:
|
||||
auth_logger.warning(f"分享 token 无效: {share_token}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
auth_logger.info(f"分享访问验证成功: user_id={user_id}, share_token={share_token}")
|
||||
return ShareTokenData(user_id=user_id, share_token=share_token)
|
||||
|
||||
|
||||
except BusinessException as e:
|
||||
auth_logger.warning(f"分享访问验证失败: {str(e)}")
|
||||
raise credentials_exception
|
||||
except Exception as e:
|
||||
auth_logger.error(f"验证分享访问 token 时发生错误: {str(e)}")
|
||||
raise credentials_exception
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user