import os
import subprocess
from contextlib import asynccontextmanager

from fastapi import FastAPI, APIRouter
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from app.repositories.neo4j.create_indexes import create_all_indexes
# Management API (JWT authentication)
from app.controllers import manager_router
# Service API (API key authentication)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
from app.core.models.scripts.loader import load_models
from app.db import get_db_context

# Initialize logging system
LoggingConfig.setup_logging()
logger = get_logger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work via FastAPI's lifespan hook (replaces on_event).

    Startup steps, in order:

    1. When ``settings.DB_AUTO_UPGRADE`` is truthy, run ``alembic upgrade head``
       in a subprocess.
    2. When ``settings.LOAD_MODEL`` is truthy, load the predefined models.
       A failure here is logged as a warning and does not abort startup.
    3. Await ``create_all_indexes()``.

    Raises:
        RuntimeError: if the alembic command exits with a non-zero status.
        Exception: any other error raised while launching the upgrade is
            logged and re-raised unchanged.
    """
    # --- Startup ---
    # Check whether the database should be upgraded automatically
    if settings.DB_AUTO_UPGRADE:
        logger.info("开始自动升级数据库...")
        try:
            result = subprocess.run(
                ["alembic", "upgrade", "head"],
                capture_output=True,
                text=True,
                check=True,  # non-zero exit status raises CalledProcessError
            )
            logger.info(f"数据库升级成功: {result.stdout}")
        except subprocess.CalledProcessError as e:
            logger.error(f"数据库升级失败: {e.stderr}")
            # Chain the original error so the failing command and exit status
            # remain visible in the traceback.
            raise RuntimeError(f"数据库升级失败: {e.stderr}") from e
        except Exception as e:
            logger.error(f"运行数据库升级时出错: {str(e)}")
            raise
    else:
        logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")

    # Load predefined models
    if settings.LOAD_MODEL:
        logger.info("开始加载预定义模型...")
        try:
            with get_db_context() as db:
                result = load_models(db, silent=True)
                logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
        except Exception as e:
            # Deliberately best-effort: a model-loading failure must not block startup.
            logger.warning(f"加载预定义模型时出错: {str(e)}")
    else:
        logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")

    await create_all_indexes()
    logger.info("应用程序启动完成")

    yield

    # --- Shutdown ---
    logger.info("应用程序正在关闭")


app = FastAPI(
    title="redbera-mem",
    description="redbera-mem",
    version="1.0.0",
    lifespan=lifespan,
)

# Enable
# CORS for frontend access with environment-extendable origins
default_origins = [
    settings.WEB_URL
]
# De-duplicate while preserving order and dropping empty entries.
allowed_origins = list(
    dict.fromkeys(o for o in (default_origins + settings.CORS_ORIGINS) if o)
)
# If CORS_ORIGINS contains "*", allow every origin
if "*" in settings.CORS_ORIGINS:
    allowed_origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    # Credentials cannot be used together with a wildcard origin
    allow_credentials="*" not in allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Add i18n language detection middleware
from app.i18n.middleware import LanguageMiddleware

app.add_middleware(LanguageMiddleware)

logger.info("FastAPI应用程序启动")


@app.get("/", tags=["General"])
def read_root():
    """Simple health check endpoint; returns a fixed JSON message."""
    logger.debug("健康检查端点被访问")
    return {"message": "FastAPI is running"}


# Lifecycle events are managed by lifespan; on_event is not needed

# Register routers
# Management API (JWT authentication)
app.include_router(manager_router, prefix="/api")

# Service API (API key authentication)
app.include_router(service_router, prefix="/v1")

logger.info("所有路由已注册完成")


# Import additional exception types for specific handling
from app.core.exceptions import (
    ValidationException,
    ResourceNotFoundException,
    PermissionDeniedException,
    AuthenticationException,
    AuthorizationException,
    FileUploadException,
    RateLimitException
)
from app.core.sensitive_filter import SensitiveDataFilter
import traceback

# Import i18n exception support
from app.i18n.exceptions import I18nException
from app.i18n.service import get_translation_service
from pydantic import ValidationError as PydanticValidationError


# Pydantic error type -> translation key. Both the v1 names
# ("value_error.*") and their v2 equivalents are listed so the handler
# works regardless of the installed Pydantic major version.
_VALIDATION_ERROR_KEYS = {
    "value_error.missing": "errors.validation.missing_field",
    "value_error.any_str.max_length": "errors.validation.field_too_long",
    "value_error.any_str.min_length": "errors.validation.field_too_short",
    "missing": "errors.validation.missing_field",
    "string_too_long": "errors.validation.field_too_long",
    "string_too_short": "errors.validation.field_too_short",
}

# HTTP status code -> translation key for standard HTTP errors.
_HTTP_ERROR_KEYS = {
    400: "errors.common.bad_request",
    401: "errors.common.unauthorized",
    403: "errors.common.forbidden",
    404: "errors.common.not_found",
    405: "errors.common.method_not_allowed",
    409: "errors.common.conflict",
    413: "errors.common.payload_too_large",
    422: "errors.common.validation_failed",
    429: "errors.common.too_many_requests",
    500: "errors.common.internal_error",
    502: "errors.common.bad_gateway",
    503: "errors.common.service_unavailable",
    504: "errors.common.gateway_timeout",
}


def _error_log_extra(request: Request, exc, filtered_context, **fields) -> dict:
    """Build the structured ``extra`` dict logged for a coded business exception.

    ``fields`` are additional entries (e.g. ``user``) inserted after the
    request path and method.
    """
    return {
        "path": request.url.path,
        "method": request.method,
        **fields,
        "context": filtered_context,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }


def _resolve_biz_code(exc, default_code: BizCode) -> BizCode:
    """Return ``exc.code`` when it is a ``BizCode``, otherwise ``default_code``."""
    return exc.code if isinstance(exc.code, BizCode) else default_code


def _fail_response(biz_code: BizCode, default_status: int, message) -> JSONResponse:
    """Build the unified failure response for ``biz_code``.

    The HTTP status comes from ``HTTP_MAPPING``; ``default_status`` is used
    when the code has no mapping. ``message`` is sent as both ``msg`` and ``error``.
    """
    return JSONResponse(
        status_code=HTTP_MAPPING.get(biz_code, default_status),
        content=fail(code=biz_code.value, msg=message, error=message),
    )


# Handle validation exceptions
@app.exception_handler(ValidationException)
async def validation_exception_handler(request: Request, exc: ValidationException):
    """Handle validation exceptions (falls back to VALIDATION_FAILED / 400)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.warning(
        f"Validation error: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context),
        exc_info=exc.cause is not None
    )
    biz_code = _resolve_biz_code(exc, BizCode.VALIDATION_FAILED)
    return _fail_response(biz_code, 400, filtered_message)


# Handle i18n (internationalized) exceptions
@app.exception_handler(I18nException)
async def i18n_exception_handler(request: Request, exc: I18nException):
    """Handle internationalized exceptions.

    The original code treats ``exc.detail`` as already carrying the translated
    message, so the detail is returned after sensitive-data filtering. A dict
    detail is merged into the response body; any other detail is stringified
    and returned under ``"message"``.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    # Exception detail (expected to already contain the translated message)
    detail = exc.detail

    # Filter sensitive information
    if isinstance(detail, dict):
        body = {
            **detail,
            "message": SensitiveDataFilter.filter_string(detail.get("message", "")),
        }
    else:
        # A non-dict detail cannot be unpacked with ``**``; wrap it so the
        # response is still a JSON object instead of raising TypeError.
        body = {"message": SensitiveDataFilter.filter_string(str(detail))}

    logger.warning(
        f"I18n exception: {exc.error_key}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "error_code": exc.error_code,
            "error_key": exc.error_key,
            "language": language,
            "status_code": exc.status_code,
            "params": exc.params
        }
    )

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            **body
        },
        headers=exc.headers
    )


# Handle Pydantic validation errors (with i18n support)
@app.exception_handler(PydanticValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
    """Handle Pydantic validation errors and translate each field error.

    Returns a 422 response listing, per error, the dotted field path, the
    translated message and the raw Pydantic error type.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    # Translation service
    translation_service = get_translation_service()

    # Translate each validation error
    errors = []
    for error in exc.errors():
        field = ".".join(str(loc) for loc in error["loc"])
        error_type = error["type"]

        # Unknown error types fall back to the generic "invalid field" message
        message_key = _VALIDATION_ERROR_KEYS.get(
            error_type, "errors.validation.invalid_field"
        )
        message = translation_service.translate(message_key, language, field=field)

        errors.append({
            "field": field,
            "message": message,
            "type": error_type
        })

    # Translate the top-level error message
    main_message = translation_service.translate(
        "errors.common.validation_failed",
        language
    )

    logger.warning(
        f"Pydantic validation error: {len(errors)} errors",
        extra={
            "path": request.url.path,
            "method": request.method,
            "language": language,
            "errors": errors
        }
    )

    return JSONResponse(
        status_code=422,
        content={
            "success": False,
            "error_code": "VALIDATION_FAILED",
            "message": main_message,
            "errors": errors
        }
    )


# Handle resource-not-found exceptions
@app.exception_handler(ResourceNotFoundException)
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
    """Handle resource-not-found exceptions (falls back to FILE_NOT_FOUND / 404)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.info(
        f"Resource not found: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context)
    )
    biz_code = _resolve_biz_code(exc, BizCode.FILE_NOT_FOUND)
    return _fail_response(biz_code, 404, filtered_message)


# Handle permission-denied exceptions
@app.exception_handler(PermissionDeniedException)
async def permission_denied_handler(request: Request, exc: PermissionDeniedException):
    """Handle permission-denied exceptions (falls back to FORBIDDEN / 403)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.warning(
        f"Permission denied: {filtered_message}",
        extra=_error_log_extra(
            request,
            exc,
            filtered_context,
            user=getattr(request.state, "user_id", None),
        )
    )
    biz_code = _resolve_biz_code(exc, BizCode.FORBIDDEN)
    return _fail_response(biz_code, 403, filtered_message)


# Handle authentication exceptions
@app.exception_handler(AuthenticationException)
async def authentication_exception_handler(request: Request, exc: AuthenticationException):
    """Handle authentication exceptions (falls back to UNAUTHORIZED / 401)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.warning(
        f"Authentication error: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context)
    )
    biz_code = _resolve_biz_code(exc, BizCode.UNAUTHORIZED)
    return _fail_response(biz_code, 401, filtered_message)


# Handle authorization exceptions
@app.exception_handler(AuthorizationException)
async def authorization_exception_handler(request: Request, exc: AuthorizationException):
    """Handle authorization exceptions (falls back to FORBIDDEN / 403)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.warning(
        f"Authorization error: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context)
    )
    biz_code = _resolve_biz_code(exc, BizCode.FORBIDDEN)
    return _fail_response(biz_code, 403, filtered_message)


# Handle file upload exceptions
@app.exception_handler(FileUploadException)
async def file_upload_exception_handler(request: Request, exc: FileUploadException):
    """Handle file upload exceptions (falls back to FILE_READ_ERROR / 500)."""
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.error(
        f"File upload error: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context),
        exc_info=exc.cause is not None
    )
    biz_code = _resolve_biz_code(exc, BizCode.FILE_READ_ERROR)
    return _fail_response(biz_code, 500, filtered_message)


# Handle rate-limit exceptions
@app.exception_handler(RateLimitException)
async def rate_limit_exception_handler(request: Request, exc: RateLimitException):
    """Handle rate-limit exceptions (falls back to RATE_LIMITED / 429).

    Entries under ``exc.context["rate_limit_headers"]`` are copied onto the
    response as headers, with values converted to strings.
    """
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.warning(
        f"Rate limit exceeded: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context)
    )
    biz_code = _resolve_biz_code(exc, BizCode.RATE_LIMITED)
    response = _fail_response(biz_code, 429, filtered_message)

    # Attach rate-limit response headers; tolerate a missing or None entry.
    rate_headers = (exc.context.get("rate_limit_headers") or {}) if exc.context else {}
    for header_name, header_value in rate_headers.items():
        response.headers[header_name] = str(header_value)

    return response


# Unified handling for business exceptions (uses business error codes)
@app.exception_handler(BusinessException)
async def business_exception_handler(request: Request, exc: BusinessException):
    """Handle generic business exceptions.

    ``exc.code`` may be a ``BizCode`` or a raw int; an int that is not a valid
    ``BizCode`` value, or any other type, resolves to ``BizCode.BAD_REQUEST``.
    """
    # Filter sensitive information
    filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)

    logger.error(
        f"Business error: {filtered_message}",
        extra=_error_log_extra(request, exc, filtered_context),
        exc_info=exc.cause is not None
    )

    raw_code = exc.code
    if isinstance(raw_code, BizCode):
        biz_code = raw_code
    elif isinstance(raw_code, int):
        try:
            biz_code = BizCode(raw_code)
        except ValueError:
            biz_code = BizCode.BAD_REQUEST
    else:
        biz_code = BizCode.BAD_REQUEST

    return _fail_response(biz_code, 400, filtered_message)


# Unified handling: convert HTTPException into the unified response structure (with i18n)
# NOTE(review): this is registered for fastapi.HTTPException; errors raised as
# Starlette's HTTPException (e.g. router-level 404/405) presumably bypass it —
# confirm whether the handler should be registered for the Starlette class.
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Handle HTTP exceptions with i18n support.

    Status codes listed in ``_HTTP_ERROR_KEYS`` get a translated message; any
    other status uses the filtered exception detail. Headers attached to the
    exception (e.g. ``WWW-Authenticate``) are forwarded on the response.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    # Translation service
    translation_service = get_translation_service()

    # Use the translation when a key exists for this status code
    if exc.status_code in _HTTP_ERROR_KEYS:
        translated_message = translation_service.translate(
            _HTTP_ERROR_KEYS[exc.status_code],
            language
        )
    else:
        # Otherwise fall back to the filtered original message
        translated_message = SensitiveDataFilter.filter_string(str(exc.detail))

    logger.warning(
        f"HTTP exception: {translated_message}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "status_code": exc.status_code,
            "language": language
        }
    )

    # String details are filtered like every other handler's output;
    # structured details (dict/list) are passed through unchanged.
    error_detail = exc.detail
    if isinstance(error_detail, str):
        error_detail = SensitiveDataFilter.filter_string(error_detail)

    return JSONResponse(
        status_code=exc.status_code,
        content=fail(code=exc.status_code, msg=translated_message, error=error_detail),
        headers=getattr(exc, "headers", None)
    )


# Catch unhandled exceptions and return the unified error structure (with i18n)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Handle otherwise uncaught exceptions with i18n support.

    Always responds with HTTP 500. When the ``ENVIRONMENT`` variable equals
    ``"production"`` the client sees a translated generic message; otherwise
    it sees the filtered exception text.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    # Translation service
    translation_service = get_translation_service()

    # Log the full stack trace, built from the exception object itself so it
    # does not depend on an exception being active in this frame.
    logger.error(
        f"Unhandled exception: {exc}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "exception_type": type(exc).__name__,
            "language": language,
            "traceback": "".join(
                traceback.format_exception(type(exc), exc, exc.__traceback__)
            )
        },
        exc_info=exc
    )

    # Hide detailed error information in production
    environment = os.getenv("ENVIRONMENT", "development")
    if environment == "production":
        # Use the translated generic error message
        message = translation_service.translate(
            "errors.common.internal_error",
            language
        )
    else:
        # Sensitive information is filtered in non-production environments too
        message = SensitiveDataFilter.filter_string(str(exc))

    return JSONResponse(
        status_code=500,
        content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message)
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)