Files
MemoryBear/api/app/main.py
2025-12-24 20:35:04 +08:00

393 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import subprocess
from contextlib import asynccontextmanager
from fastapi import FastAPI, APIRouter
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# 管理端 API (JWT 认证)
from app.controllers import manager_router
# 服务端 API (API Key 认证)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
# Initialize the logging system before the module logger is obtained.
# NOTE(review): get_logger presumably relies on the configuration applied by
# setup_logging (TODO confirm), so keep these two statements in this order.
LoggingConfig.setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle application startup/shutdown via FastAPI lifespan (replaces on_event).

    Startup: when ``settings.DB_AUTO_UPGRADE`` is truthy, run
    ``alembic upgrade head`` in a subprocess and abort startup if it fails;
    otherwise just log that auto-upgrade is disabled.
    Shutdown: everything after ``yield`` runs when the application stops.

    Raises:
        RuntimeError: if the alembic command exits with a non-zero status
            (the ``CalledProcessError`` is chained as ``__cause__``).
        Exception: any other error while launching the command (for
            example ``FileNotFoundError`` when ``alembic`` is not on PATH)
            is logged and re-raised unchanged.
    """
    # --- Startup ---
    # Optionally bring the database schema up to date before serving.
    if settings.DB_AUTO_UPGRADE:
        logger.info("开始自动升级数据库...")
        try:
            result = subprocess.run(
                ["alembic", "upgrade", "head"],
                capture_output=True,
                text=True,
                check=True,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"数据库升级失败: {e.stderr}")
            # Chain explicitly so the traceback shows the failed command and
            # its return code as the direct cause.
            raise RuntimeError(f"数据库升级失败: {e.stderr}") from e
        except Exception as e:
            logger.error(f"运行数据库升级时出错: {str(e)}")
            raise
        # Success path kept outside the try so only subprocess.run is guarded.
        logger.info(f"数据库升级成功: {result.stdout}")
    else:
        logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
    logger.info("应用程序启动完成")
    yield
    # --- Shutdown ---
    logger.info("应用程序正在关闭")
# The ASGI application instance; startup/shutdown work is delegated to
# the `lifespan` context manager defined above.
app = FastAPI(
    lifespan=lifespan,
    version="1.0.0",
    title="redbera-mem",
    description="redbera-mem",
)
# Enable CORS for frontend access. settings.WEB_URL is always allowed;
# settings.CORS_ORIGINS lets the environment add further origins.
default_origins = [
    settings.WEB_URL
]
# Browsers send the Origin header as scheme://host[:port] with no trailing
# slash, and Starlette's CORSMiddleware compares origins by exact string
# match. Strip whitespace and trailing slashes so an entry like
# "https://app.example.com/" still matches. Empty/None entries are dropped.
# dict.fromkeys de-duplicates while keeping the configured order (a set
# would make the order vary between runs).
# NOTE(review): assumes every configured origin is a str (or None).
_normalized_origins = (
    (origin or "").strip().rstrip("/")
    for origin in (default_origins + settings.CORS_ORIGINS)
)
allowed_origins = list(dict.fromkeys(o for o in _normalized_origins if o))
app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
logger.info("FastAPI应用程序启动")
@app.get("/", tags=["General"])
def read_root():
    """Liveness probe: report that the API process is up and answering."""
    logger.debug("健康检查端点被访问")
    status_payload = {"message": "FastAPI is running"}
    return status_payload
# Lifecycle events are handled by `lifespan`; no on_event hooks are needed.
# Route registration:
#   - management API (JWT authentication) mounted under /api
#   - service API (API key authentication) mounted under /v1
for _router, _prefix in ((manager_router, "/api"), (service_router, "/v1")):
    app.include_router(_router, prefix=_prefix)
logger.info("所有路由已注册完成")
# Import additional exception types for specific handling
from app.core.exceptions import (
ValidationException,
ResourceNotFoundException,
PermissionDeniedException,
AuthenticationException,
AuthorizationException,
FileUploadException,
RateLimitException
)
from app.core.sensitive_filter import SensitiveDataFilter
import traceback
# Validation errors
@app.exception_handler(ValidationException)
async def validation_exception_handler(request: Request, exc: ValidationException):
    """Turn a ValidationException into the unified error response.

    Message and context are scrubbed by SensitiveDataFilter before they are
    logged (WARNING, with traceback only when ``exc.cause`` is set) or sent
    back. A non-BizCode ``exc.code`` falls back to BizCode.VALIDATION_FAILED;
    the HTTP status is taken from HTTP_MAPPING, defaulting to 400.
    """
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    code = exc.code
    code_is_enum = isinstance(code, BizCode)
    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": code.value if code_is_enum else code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(
        f"Validation error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    resolved = code if code_is_enum else BizCode.VALIDATION_FAILED
    body = fail(code=resolved.value, msg=safe_msg, error=safe_msg)
    return JSONResponse(status_code=HTTP_MAPPING.get(resolved, 400), content=body)
# Resource-not-found errors
@app.exception_handler(ResourceNotFoundException)
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
    """Turn a ResourceNotFoundException into the unified error response.

    Logged at INFO level without a traceback, after SensitiveDataFilter has
    scrubbed message and context. A non-BizCode ``exc.code`` falls back to
    BizCode.FILE_NOT_FOUND; the HTTP status is taken from HTTP_MAPPING,
    defaulting to 404.
    """
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    code = exc.code
    code_is_enum = isinstance(code, BizCode)
    logger.info(
        f"Resource not found: {safe_msg}",
        extra=dict(
            path=request.url.path,
            method=request.method,
            context=safe_ctx,
            error_code=code.value if code_is_enum else code,
            cause=str(exc.cause) if exc.cause else None,
        ),
    )

    resolved = code if code_is_enum else BizCode.FILE_NOT_FOUND
    body = fail(code=resolved.value, msg=safe_msg, error=safe_msg)
    return JSONResponse(status_code=HTTP_MAPPING.get(resolved, 404), content=body)
# Permission-denied errors
@app.exception_handler(PermissionDeniedException)
async def permission_denied_handler(request: Request, exc: PermissionDeniedException):
    """Turn a PermissionDeniedException into the unified error response.

    Logged at WARNING level together with ``request.state.user_id`` (None
    when absent), after SensitiveDataFilter has scrubbed message and
    context. A non-BizCode ``exc.code`` falls back to BizCode.FORBIDDEN;
    the HTTP status is taken from HTTP_MAPPING, defaulting to 403.
    """
    msg, ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    err = exc.code
    is_biz = isinstance(err, BizCode)
    details = {
        "path": request.url.path,
        "method": request.method,
        # May be missing on request.state; getattr falls back to None.
        "user": getattr(request.state, "user_id", None),
        "context": ctx,
        "error_code": err.value if is_biz else err,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Permission denied: {msg}", extra=details)

    final_code = err if is_biz else BizCode.FORBIDDEN
    return JSONResponse(
        status_code=HTTP_MAPPING.get(final_code, 403),
        content=fail(code=final_code.value, msg=msg, error=msg),
    )
# Authentication errors
@app.exception_handler(AuthenticationException)
async def authentication_exception_handler(request: Request, exc: AuthenticationException):
    """Turn an AuthenticationException into the unified error response.

    Logged at WARNING level after SensitiveDataFilter has scrubbed message
    and context. A non-BizCode ``exc.code`` falls back to
    BizCode.UNAUTHORIZED; the HTTP status is taken from HTTP_MAPPING,
    defaulting to 401.
    """
    msg, ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    err = exc.code
    is_biz = isinstance(err, BizCode)
    details = {
        "path": request.url.path,
        "method": request.method,
        "context": ctx,
        "error_code": err.value if is_biz else err,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Authentication error: {msg}", extra=details)

    final_code = err if is_biz else BizCode.UNAUTHORIZED
    return JSONResponse(
        status_code=HTTP_MAPPING.get(final_code, 401),
        content=fail(code=final_code.value, msg=msg, error=msg),
    )
# Authorization errors
@app.exception_handler(AuthorizationException)
async def authorization_exception_handler(request: Request, exc: AuthorizationException):
    """Turn an AuthorizationException into the unified error response.

    Logged at WARNING level after SensitiveDataFilter has scrubbed message
    and context. A non-BizCode ``exc.code`` falls back to BizCode.FORBIDDEN;
    the HTTP status is taken from HTTP_MAPPING, defaulting to 403.
    """
    msg, ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    err = exc.code
    is_biz = isinstance(err, BizCode)
    details = {
        "path": request.url.path,
        "method": request.method,
        "context": ctx,
        "error_code": err.value if is_biz else err,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Authorization error: {msg}", extra=details)

    final_code = err if is_biz else BizCode.FORBIDDEN
    return JSONResponse(
        status_code=HTTP_MAPPING.get(final_code, 403),
        content=fail(code=final_code.value, msg=msg, error=msg),
    )
# File upload errors
@app.exception_handler(FileUploadException)
async def file_upload_exception_handler(request: Request, exc: FileUploadException):
    """Turn a FileUploadException into the unified error response.

    Logged at ERROR level (with traceback only when ``exc.cause`` is set)
    after SensitiveDataFilter has scrubbed message and context. A non-BizCode
    ``exc.code`` falls back to BizCode.FILE_READ_ERROR; the HTTP status is
    taken from HTTP_MAPPING, defaulting to 500.
    """
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    code = exc.code
    code_is_enum = isinstance(code, BizCode)
    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": code.value if code_is_enum else code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.error(
        f"File upload error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    resolved = code if code_is_enum else BizCode.FILE_READ_ERROR
    body = fail(code=resolved.value, msg=safe_msg, error=safe_msg)
    return JSONResponse(status_code=HTTP_MAPPING.get(resolved, 500), content=body)
# Rate-limit errors
@app.exception_handler(RateLimitException)
async def rate_limit_exception_handler(request: Request, exc: RateLimitException):
    """Turn a RateLimitException into the unified error response.

    Message and context are scrubbed by SensitiveDataFilter before being
    logged (WARNING) or returned. A non-BizCode ``exc.code`` falls back to
    BizCode.RATE_LIMITED; the HTTP status is taken from HTTP_MAPPING,
    defaulting to 429. Any mapping found under
    ``exc.context["rate_limit_headers"]`` is copied onto the response as
    headers, with values converted to ``str``.
    """
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
    logger.warning(
        f"Rate limit exceeded: {filtered_message}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "context": filtered_context,
            "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
            "cause": str(exc.cause) if exc.cause else None,
        },
    )
    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.RATE_LIMITED
    status_code = HTTP_MAPPING.get(biz_code, 429)
    response = JSONResponse(
        status_code=status_code,
        content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message),
    )
    # Treat a missing context, a missing key, and an explicit None value the
    # same way: no extra headers. A None value would otherwise raise
    # AttributeError here and turn the 429 into an unhandled 500.
    rate_headers = (exc.context or {}).get("rate_limit_headers") or {}
    for header_name, header_value in rate_headers.items():
        # Header values must be strings on the wire.
        response.headers[header_name] = str(header_value)
    return response
# Generic business exceptions (mapped through business error codes)
@app.exception_handler(BusinessException)
async def business_exception_handler(request: Request, exc: BusinessException):
    """Turn any BusinessException into the unified error response.

    Message and context are scrubbed by SensitiveDataFilter before being
    logged (ERROR, with traceback only when ``exc.cause`` is set) or
    returned. ``exc.code`` is resolved to a BizCode as follows: a BizCode is
    used as-is; an int is converted via ``BizCode(int)``; anything else, or
    an int with no matching member, becomes BizCode.BAD_REQUEST. The HTTP
    status is taken from HTTP_MAPPING, defaulting to 400.
    """
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)
    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.error(
        f"Business error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    incoming = exc.code
    resolved = BizCode.BAD_REQUEST  # fallback for unknown / non-numeric codes
    if isinstance(incoming, BizCode):
        resolved = incoming
    elif isinstance(incoming, int):
        try:
            resolved = BizCode(incoming)
        except ValueError:
            # Unknown numeric code: keep the BAD_REQUEST fallback.
            pass

    body = fail(code=resolved.value, msg=safe_msg, error=safe_msg)
    return JSONResponse(status_code=HTTP_MAPPING.get(resolved, 400), content=body)
# Unified handling: convert HTTP exceptions into the unified response
# structure. The handler is registered on Starlette's base HTTPException
# (fastapi.HTTPException subclasses it), so framework-raised errors such as
# the unknown-route 404 or 405 Method Not Allowed are converted as well.
# Registering on fastapi.HTTPException alone lets those fall through to
# FastAPI's default {"detail": ...} body.
from starlette.exceptions import HTTPException as StarletteHTTPException


@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Turn an HTTP exception into the unified error response.

    The detail is scrubbed by SensitiveDataFilter before being logged
    (WARNING) or returned. The exception's status code is kept, and any
    headers it carries (for example ``WWW-Authenticate`` on a 401 or
    ``Allow`` on a 405) are forwarded on the response.
    """
    filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))
    logger.warning(
        f"HTTP exception: {filtered_detail}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "status_code": exc.status_code,
        },
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail),
        # Dropping these would break auth challenges and Allow lists.
        headers=getattr(exc, "headers", None),
    )
# Catch-all for unhandled exceptions; returns the unified error structure.
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Log an unexpected exception and return a generic 500 response.

    The full traceback is logged. The comment in the original code states
    that a log filter scrubs sensitive data from log records; that filter is
    not visible in this file. When the ``ENVIRONMENT`` variable is
    ``production`` (case-insensitive, surrounding whitespace ignored), the
    client receives a fixed generic message. Otherwise it receives
    ``str(exc)`` scrubbed by SensitiveDataFilter.
    """
    # Build the traceback from the exception object itself instead of the
    # ambient sys.exc_info(), so it is correct even if the handler is not
    # invoked inside an active `except` block.
    formatted_tb = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    logger.error(
        f"Unhandled exception: {exc}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "exception_type": type(exc).__name__,
            "traceback": formatted_tb,
        },
        exc_info=exc,
    )
    # Hide error details in production. Normalizing case/whitespace keeps
    # values like "Production" from leaking exception text to clients.
    environment = os.getenv("ENVIRONMENT", "development").strip().lower()
    if environment == "production":
        message = "服务器内部错误,请稍后重试"
    else:
        # Outside production, still scrub sensitive data before responding.
        message = SensitiveDataFilter.filter_string(str(exc))
    return JSONResponse(
        status_code=500,
        content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message),
    )
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)