1. add list operator node for filtering, sorting, limiting, and extracting list items; 2. Increase the session variable to the "file" type
204 lines
8.2 KiB
Python
204 lines
8.2 KiB
Python
from datetime import datetime, timedelta, timezone
|
||
from typing import Callable
|
||
from fastapi import APIRouter, Depends
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.response_utils import success
|
||
from app.db import get_db
|
||
from app.schemas.response_schema import ApiResponse
|
||
from app.schemas.token_schema import Token, RefreshTokenRequest, TokenRequest
|
||
from app.schemas.workspace_schema import InviteAcceptRequest
|
||
from app.services import auth_service, user_service, workspace_service
|
||
from app.core import security
|
||
from app.core.config import settings
|
||
from app.services.session_service import SessionService
|
||
from app.core.logging_config import get_auth_logger, get_security_logger
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.error_codes import BizCode
|
||
from app.dependencies import get_current_user, oauth2_scheme
|
||
from app.models.user_model import User
|
||
from app.i18n.dependencies import get_translator
|
||
|
||
# 获取专用日志器
|
||
auth_logger = get_auth_logger()
|
||
security_logger = get_security_logger()
|
||
|
||
router = APIRouter(tags=["Authentication"])
|
||
|
||
@router.post("/token", response_model=ApiResponse)
|
||
async def login_for_access_token(
|
||
form_data: TokenRequest,
|
||
db: Session = Depends(get_db),
|
||
t: Callable = Depends(get_translator)
|
||
):
|
||
"""用户登录获取token"""
|
||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||
|
||
# 验证邀请码(如果提供)
|
||
invite_info = None
|
||
# 验证用户凭据或注册新用户
|
||
user = None
|
||
if form_data.invite:
|
||
auth_logger.info(f"检测到邀请码: {form_data.invite[:8]}...")
|
||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||
|
||
if not invite_info.is_valid:
|
||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||
|
||
if invite_info.email != form_data.email:
|
||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||
try:
|
||
# 尝试认证用户
|
||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||
if form_data.invite:
|
||
auth_service.bind_workspace_with_invite(
|
||
db=db,
|
||
user=user,
|
||
invite_token=form_data.invite,
|
||
workspace_id=invite_info.workspace_id
|
||
)
|
||
except BusinessException as e:
|
||
# 用户不存在且有邀请码,尝试注册
|
||
if e.code == BizCode.USER_NOT_FOUND:
|
||
auth_logger.info(f"用户不存在,使用邀请码注册: {form_data.email}")
|
||
user = auth_service.register_user_with_invite(
|
||
db=db,
|
||
email=form_data.email,
|
||
username=form_data.username,
|
||
password=form_data.password,
|
||
invite_token=form_data.invite,
|
||
workspace_id=invite_info.workspace_id
|
||
)
|
||
elif e.code == BizCode.PASSWORD_ERROR:
|
||
# 用户存在但密码错误
|
||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||
else:
|
||
# 其他认证失败情况,直接抛出
|
||
raise
|
||
else:
|
||
try:
|
||
# 尝试认证用户
|
||
user = auth_service.authenticate_user_or_raise(db, form_data.email, form_data.password)
|
||
auth_logger.info(f"用户认证成功: {user.email} (ID: {user.id})")
|
||
|
||
except BusinessException as e:
|
||
|
||
# 其他认证失败情况,直接抛出
|
||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||
|
||
# 创建 tokens
|
||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||
refresh_token, refresh_token_id = security.create_refresh_token(subject=user.id)
|
||
|
||
# 计算过期时间
|
||
access_expires_at = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||
refresh_expires_at = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||
|
||
# 单点登录会话管理
|
||
if settings.ENABLE_SINGLE_SESSION:
|
||
await SessionService.invalidate_old_session(user.id, access_token_id)
|
||
await SessionService.set_user_active_session(user.id, access_token_id, access_expires_at)
|
||
|
||
# 更新最后登录时间
|
||
user_service.update_last_login_time(db, user.id)
|
||
|
||
auth_logger.info(f"用户 {user.username} 登录成功")
|
||
|
||
return success(
|
||
data=Token(
|
||
access_token=access_token,
|
||
refresh_token=refresh_token,
|
||
token_type="bearer",
|
||
expires_at=access_expires_at,
|
||
refresh_expires_at=refresh_expires_at
|
||
),
|
||
msg=t("auth.login.success")
|
||
)
|
||
|
||
|
||
@router.post("/refresh", response_model=ApiResponse)
|
||
async def refresh_token(
|
||
refresh_request: RefreshTokenRequest,
|
||
db: Session = Depends(get_db),
|
||
t: Callable = Depends(get_translator)
|
||
):
|
||
"""刷新token"""
|
||
auth_logger.info("收到token刷新请求")
|
||
|
||
# 验证 refresh token
|
||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||
if not userId:
|
||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||
|
||
# 检查用户是否存在
|
||
user = auth_service.get_user_by_id(db, userId)
|
||
if not user:
|
||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||
|
||
# 检查 refresh token 黑名单
|
||
if settings.ENABLE_SINGLE_SESSION:
|
||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||
|
||
# 生成新 tokens
|
||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||
new_refresh_token, new_refresh_token_id = security.create_refresh_token(subject=user.id)
|
||
|
||
# 计算过期时间
|
||
access_expires_at = datetime.now() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||
refresh_expires_at = datetime.now() + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||
|
||
# 单点登录会话管理
|
||
if settings.ENABLE_SINGLE_SESSION:
|
||
# 将旧 refresh token 加入黑名单
|
||
old_refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||
if old_refresh_token_id:
|
||
await SessionService.blacklist_token(old_refresh_token_id)
|
||
|
||
# 更新会话
|
||
await SessionService.invalidate_old_session(user.id, new_access_token_id)
|
||
await SessionService.set_user_active_session(user.id, new_access_token_id, access_expires_at)
|
||
|
||
auth_logger.info(f"用户 {user.id} token刷新成功")
|
||
|
||
return success(
|
||
data=Token(
|
||
access_token=new_access_token,
|
||
refresh_token=new_refresh_token,
|
||
token_type="bearer",
|
||
expires_at=access_expires_at,
|
||
refresh_expires_at=refresh_expires_at
|
||
),
|
||
msg=t("auth.token.refresh_success")
|
||
)
|
||
|
||
|
||
@router.post("/logout", response_model=ApiResponse)
|
||
async def logout(
|
||
token: str = Depends(oauth2_scheme),
|
||
current_user: User = Depends(get_current_user),
|
||
db: Session = Depends(get_db),
|
||
t: Callable = Depends(get_translator)
|
||
):
|
||
"""登出当前用户:加入token黑名单并清理会话"""
|
||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||
|
||
token_id = security.get_token_id(token)
|
||
if not token_id:
|
||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||
|
||
# 加入黑名单
|
||
await SessionService.blacklist_token(token_id)
|
||
|
||
# 清理会话
|
||
if settings.ENABLE_SINGLE_SESSION:
|
||
await SessionService.clear_user_session(current_user.username)
|
||
|
||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||
return success(msg=t("auth.logout.success"))
|
||
|