127 lines
3.4 KiB
Python
127 lines
3.4 KiB
Python
from datetime import datetime, timedelta, timezone
|
||
from typing import Any, Union, Optional
|
||
import uuid
|
||
|
||
from jose import jwt, JWTError
|
||
from passlib.context import CryptContext
|
||
|
||
from app.core.config import settings
|
||
|
||
# Shared passlib context used by verify_password() and get_password_hash();
# "scrypt" is the only hashing scheme configured here.
pwd_context = CryptContext(schemes=["scrypt"], deprecated="auto")
|
||
|
||
|
||
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> tuple[str, str]:
    """Create a signed access token.

    Args:
        subject: Token subject (usually the username); stored as
            ``str(subject)`` in the ``sub`` claim.
        expires_delta: Lifetime of the token. When ``None``, the lifetime
            falls back to ``settings.ACCESS_TOKEN_EXPIRE_MINUTES`` minutes.

    Returns:
        tuple: ``(encoded_token, token_id)`` where ``token_id`` is the UUID4
        string that is also stored in the ``jti`` claim.
    """
    now = datetime.now(timezone.utc)
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently replace an explicit zero lifetime with
    # the configured default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = now + expires_delta

    token_id = str(uuid.uuid4())
    to_encode = {
        "exp": expire,
        "iat": now,  # Issued-at time
        "sub": str(subject),
        "type": "access",  # Lets verify_token() tell access and refresh tokens apart
        "jti": token_id,  # JWT ID
    }
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt, token_id
|
||
|
||
|
||
def create_refresh_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> tuple[str, str]:
    """Create a signed refresh token.

    Args:
        subject: Token subject (usually the username); stored as
            ``str(subject)`` in the ``sub`` claim.
        expires_delta: Lifetime of the token. When ``None``, the lifetime
            falls back to ``settings.REFRESH_TOKEN_EXPIRE_DAYS`` days.

    Returns:
        tuple: ``(encoded_token, token_id)`` where ``token_id`` is the UUID4
        string that is also stored in the ``jti`` claim.
    """
    now = datetime.now(timezone.utc)
    # Compare against None explicitly: timedelta(0) is falsy, so a plain
    # truthiness test would silently replace an explicit zero lifetime with
    # the configured default.
    if expires_delta is None:
        expires_delta = timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
    expire = now + expires_delta

    token_id = str(uuid.uuid4())
    to_encode = {
        "exp": expire,
        "iat": now,  # Issued-at time
        "sub": str(subject),
        "type": "refresh",  # Lets verify_token() tell access and refresh tokens apart
        "jti": token_id,  # JWT ID
    }
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM)
    return encoded_jwt, token_id
|
||
|
||
|
||
def verify_token(token: str, token_type: str = "access") -> Union[str, None]:
    """Decode a token and return its subject if it is of the expected type.

    Args:
        token: Encoded JWT string.
        token_type: Value the ``type`` claim must equal (``"access"`` by
            default).

    Returns:
        The ``sub`` claim when decoding succeeds, ``sub`` is present and the
        ``type`` claim matches ``token_type``; otherwise ``None``. Any
        ``JWTError`` raised while decoding also results in ``None``.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None

    subject = claims.get("sub")
    claimed_type = claims.get("type")
    if subject is not None and claimed_type == token_type:
        return subject
    return None
|
||
|
||
|
||
def get_token_id(token: str) -> Optional[str]:
    """Extract the token ID (``jti`` claim) from a token.

    Args:
        token: Encoded JWT string.

    Returns:
        The ``jti`` claim, or ``None`` when the claim is absent or decoding
        raises ``JWTError``.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None
    else:
        return claims.get("jti")
|
||
|
||
|
||
def get_token_expiry(token: str) -> Optional[datetime]:
    """Extract the expiry time (``exp`` claim) from a token.

    Args:
        token: Encoded JWT string.

    Returns:
        The ``exp`` claim converted to a timezone-aware UTC ``datetime``, or
        ``None`` when the claim is missing/falsy or decoding raises
        ``JWTError``.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
    except JWTError:
        return None

    expires_at = claims.get("exp")
    if not expires_at:
        return None
    return datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a stored hash via ``pwd_context``.

    Args:
        plain_password: Password as entered by the user.
        hashed_password: Previously stored password hash.

    Returns:
        The result of ``pwd_context.verify`` for the given pair.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||
|
||
|
||
def get_password_hash(password: str) -> str:
    """Hash a plain-text password via ``pwd_context``.

    Args:
        password: Password to hash.

    Returns:
        The hash string produced by ``pwd_context.hash``.
    """
    hashed = pwd_context.hash(password)
    return hashed
|