"""FastAPI application entry point.

Builds the ``app`` instance, wires CORS, registers the manager (``/api``) and
service (``/v1``) routers, and installs exception handlers that convert
project exceptions, ``HTTPException`` and any unhandled ``Exception`` into the
unified ``fail(...)`` response envelope.
"""

import os
import subprocess
import traceback
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any

from fastapi import FastAPI, APIRouter
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

# Management API (JWT authentication)
from app.controllers import manager_router
# Service API (API key authentication)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
# Additional exception types that get dedicated handlers below
from app.core.exceptions import (
    ValidationException,
    ResourceNotFoundException,
    PermissionDeniedException,
    AuthenticationException,
    AuthorizationException,
    FileUploadException,
    RateLimitException,
)
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
from app.core.sensitive_filter import SensitiveDataFilter

# Initialize logging system
LoggingConfig.setup_logging()
logger = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run startup and shutdown logic via FastAPI's lifespan hook.

    On startup, when ``settings.DB_AUTO_UPGRADE`` is truthy, runs
    ``alembic upgrade head`` as a subprocess and aborts startup if it fails.

    Raises:
        RuntimeError: if the alembic process exits with a non-zero status.
        Exception: any other error raised while launching alembic is logged
            and re-raised unchanged.
    """
    # --- startup ---
    if settings.DB_AUTO_UPGRADE:
        logger.info("开始自动升级数据库...")
        try:
            result = subprocess.run(
                ["alembic", "upgrade", "head"],
                capture_output=True,
                text=True,
                check=True,
            )
            logger.info(f"数据库升级成功: {result.stdout}")
        except subprocess.CalledProcessError as e:
            logger.error(f"数据库升级失败: {e.stderr}")
            # Chain explicitly so the original process error stays attached.
            raise RuntimeError(f"数据库升级失败: {e.stderr}") from e
        except Exception as e:
            logger.error(f"运行数据库升级时出错: {str(e)}")
            raise
    else:
        logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")

    logger.info("应用程序启动完成")

    yield

    # --- shutdown ---
    logger.info("应用程序正在关闭")


app = FastAPI(
    title="redbera-mem",
    description="redbera-mem",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable CORS for frontend access with environment-extendable origins
default_origins = [
    settings.WEB_URL
]
# Drop empty entries and de-duplicate; dict.fromkeys keeps first-seen order so
# the resulting list is deterministic across runs.
allowed_origins = list(
    dict.fromkeys(o for o in (default_origins + settings.CORS_ORIGINS) if o)
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

logger.info("FastAPI应用程序启动")


# NOTE: no return annotation on purpose - FastAPI would treat it as a
# response model and change the generated schema.
@app.get("/", tags=["General"])
def read_root():
    """
    A simple health check endpoint.
    """
    logger.debug("健康检查端点被访问")
    return {"message": "FastAPI is running"}


# Lifecycle events are handled by `lifespan`; no on_event hooks are needed.

# Register routers
# Management API (JWT authentication)
app.include_router(manager_router, prefix="/api")

# Service API (API key authentication)
app.include_router(service_router, prefix="/v1")

logger.info("所有路由已注册完成")


def _build_log_extra(
    request: Request,
    exc: Any,
    filtered_context: Any,
    **additional: Any,
) -> dict:
    """Build the structured ``extra`` dict shared by the business-error handlers.

    Args:
        request: The request being handled; its path and method are recorded.
        exc: The raised exception; must expose ``code`` and ``cause``.
        filtered_context: The already-filtered context to log.
        **additional: Extra handler-specific fields (e.g. ``user``), inserted
            right after ``method``.

    Returns:
        A dict with ``path``, ``method``, any additional fields, ``context``,
        ``error_code`` and ``cause``.
    """
    return {
        "path": request.url.path,
        "method": request.method,
        **additional,
        "context": filtered_context,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }


def _build_error_response(
    biz_code: BizCode,
    default_status: int,
    message: str,
) -> JSONResponse:
    """Build the unified failure response for a business error code.

    The HTTP status is looked up in ``HTTP_MAPPING`` and falls back to
    ``default_status`` when the code has no mapping. ``message`` is used for
    both the ``msg`` and ``error`` fields of the ``fail(...)`` envelope.
    """
    return JSONResponse(
        status_code=HTTP_MAPPING.get(biz_code, default_status),
        content=fail(code=biz_code.value, msg=message, error=message),
    )


# Validation errors
@app.exception_handler(ValidationException)
async def validation_exception_handler(
    request: Request, exc: ValidationException
) -> JSONResponse:
    """Handle ``ValidationException``: log a warning, respond (default 400)."""
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.warning(
        f"Validation error: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
        # Only attach a traceback when there is an underlying cause.
        exc_info=exc.cause is not None,
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.VALIDATION_FAILED
    return _build_error_response(biz_code, 400, filtered_message)


# Resource-not-found errors
@app.exception_handler(ResourceNotFoundException)
async def not_found_exception_handler(
    request: Request, exc: ResourceNotFoundException
) -> JSONResponse:
    """Handle ``ResourceNotFoundException``: log at info, respond (default 404)."""
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.info(
        f"Resource not found: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_NOT_FOUND
    return _build_error_response(biz_code, 404, filtered_message)


# Permission-denied errors
@app.exception_handler(PermissionDeniedException)
async def permission_denied_handler(
    request: Request, exc: PermissionDeniedException
) -> JSONResponse:
    """Handle ``PermissionDeniedException``: log a warning, respond (default 403).

    The log record additionally carries ``request.state.user_id`` (or ``None``
    when that attribute is not set) under the ``user`` key.
    """
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.warning(
        f"Permission denied: {filtered_message}",
        extra=_build_log_extra(
            request,
            exc,
            filtered_context,
            user=getattr(request.state, "user_id", None),
        ),
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN
    return _build_error_response(biz_code, 403, filtered_message)


# Authentication errors
@app.exception_handler(AuthenticationException)
async def authentication_exception_handler(
    request: Request, exc: AuthenticationException
) -> JSONResponse:
    """Handle ``AuthenticationException``: log a warning, respond (default 401)."""
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.warning(
        f"Authentication error: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.UNAUTHORIZED
    return _build_error_response(biz_code, 401, filtered_message)


# Authorization errors
@app.exception_handler(AuthorizationException)
async def authorization_exception_handler(
    request: Request, exc: AuthorizationException
) -> JSONResponse:
    """Handle ``AuthorizationException``: log a warning, respond (default 403)."""
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.warning(
        f"Authorization error: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN
    return _build_error_response(biz_code, 403, filtered_message)


# File-upload errors
@app.exception_handler(FileUploadException)
async def file_upload_exception_handler(
    request: Request, exc: FileUploadException
) -> JSONResponse:
    """Handle ``FileUploadException``: log an error, respond (default 500)."""
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.error(
        f"File upload error: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
        # Only attach a traceback when there is an underlying cause.
        exc_info=exc.cause is not None,
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_READ_ERROR
    return _build_error_response(biz_code, 500, filtered_message)


# Rate-limit errors
@app.exception_handler(RateLimitException)
async def rate_limit_exception_handler(
    request: Request, exc: RateLimitException
) -> JSONResponse:
    """Handle ``RateLimitException``: log a warning, respond (default 429).

    Any entries found under ``exc.context["rate_limit_headers"]`` are copied
    onto the response as headers, with values converted to ``str``.
    """
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.warning(
        f"Rate limit exceeded: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
    )

    biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.RATE_LIMITED
    response = _build_error_response(biz_code, 429, filtered_message)

    # Headers are read from the unfiltered context: they are protocol data
    # for the client, not log content.
    rate_headers = exc.context.get("rate_limit_headers", {}) if exc.context else {}
    for header_name, header_value in rate_headers.items():
        response.headers[header_name] = str(header_value)

    return response


# Generic business errors (uses business error codes)
@app.exception_handler(BusinessException)
async def business_exception_handler(
    request: Request, exc: BusinessException
) -> JSONResponse:
    """Handle generic ``BusinessException``: log an error, respond (default 400).

    ``exc.code`` may be a ``BizCode`` or a plain ``int``; an ``int`` is
    converted to ``BizCode`` when it matches a member. Anything else falls
    back to ``BizCode.BAD_REQUEST``.
    """
    # Strip sensitive data before it reaches logs or the client.
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(
        exc.message, exc.context
    )

    logger.error(
        f"Business error: {filtered_message}",
        extra=_build_log_extra(request, exc, filtered_context),
        # Only attach a traceback when there is an underlying cause.
        exc_info=exc.cause is not None,
    )

    raw_code = exc.code
    if isinstance(raw_code, BizCode):
        biz_code = raw_code
    elif isinstance(raw_code, int):
        try:
            biz_code = BizCode(raw_code)
        except ValueError:
            # Integer that is not a known business code.
            biz_code = BizCode.BAD_REQUEST
    else:
        biz_code = BizCode.BAD_REQUEST

    return _build_error_response(biz_code, 400, filtered_message)


# Unified handling: convert HTTPException into the unified response envelope.
# NOTE(review): this is registered for fastapi.HTTPException; exceptions raised
# by Starlette itself (e.g. routing 404s) may not reach it - confirm whether the
# Starlette base class should be registered instead.
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Handle ``HTTPException``: log a warning and mirror its status code.

    The exception's ``detail`` is stringified and passed through the
    sensitive-data filter. Headers attached to the exception (for example
    ``WWW-Authenticate`` on a 401) are forwarded on the response.
    """
    filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))

    logger.warning(
        f"HTTP exception: {filtered_detail}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "status_code": exc.status_code,
        },
    )
    return JSONResponse(
        status_code=exc.status_code,
        content=fail(code=exc.status_code, msg=filtered_detail, error=filtered_detail),
        # Preserve headers carried by the exception instead of dropping them.
        headers=getattr(exc, "headers", None),
    )


# Catch-all: return the unified error envelope for any unhandled exception.
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Handle any otherwise-unhandled exception with a 500 response.

    The full traceback is logged. When the ``ENVIRONMENT`` environment
    variable equals ``"production"`` the client receives a fixed generic
    message; otherwise it receives the filtered exception text.
    """
    # Build the traceback from the exception object itself rather than from
    # the interpreter's "currently handled exception" state.
    formatted_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    logger.error(
        f"Unhandled exception: {exc}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "exception_type": type(exc).__name__,
            "traceback": formatted_traceback,
        },
        exc_info=exc,
    )

    # Hide error details from clients in production.
    environment = os.getenv("ENVIRONMENT", "development")
    if environment == "production":
        message = "服务器内部错误,请稍后重试"
    else:
        # Non-production responses are still passed through the filter.
        message = SensitiveDataFilter.filter_string(str(exc))

    return JSONResponse(
        status_code=500,
        content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message),
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)