Files
MemoryBear/api/app/main.py
2026-03-30 18:14:31 +08:00

591 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import subprocess
from app.repositories.neo4j.create_indexes import create_all_indexes
from contextlib import asynccontextmanager
from fastapi import FastAPI, APIRouter
from fastapi import HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# 管理端 API (JWT 认证)
from app.controllers import manager_router
# 服务端 API (API Key 认证)
from app.controllers.service import service_router
from app.core.config import settings
from app.core.error_codes import BizCode, HTTP_MAPPING
from app.core.exceptions import BusinessException
from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail
from app.core.models.scripts.loader import load_models
from app.db import get_db_context
# Configure the logging system once at import time, then create the
# module-level logger used by the startup code and exception handlers below.
LoggingConfig.setup_logging()
logger = get_logger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work via FastAPI's lifespan hook (replaces on_event).

    Startup steps, in order:
      1. If ``settings.DB_AUTO_UPGRADE`` is truthy, run ``alembic upgrade head``
         as a subprocess. Any failure is logged and re-raised, aborting startup.
      2. If ``settings.LOAD_MODEL`` is truthy, load the predefined models inside
         a ``get_db_context()`` session. A failure here is only logged as a
         warning; startup continues.
      3. Await ``create_all_indexes()`` (imported from
         ``app.repositories.neo4j.create_indexes``).

    Raises:
        RuntimeError: when ``alembic`` exits with a non-zero status; the
            original ``CalledProcessError`` is attached as ``__cause__``.
    """
    # --- Startup ---------------------------------------------------------
    # Optionally upgrade the database schema before serving requests.
    if settings.DB_AUTO_UPGRADE:
        logger.info("开始自动升级数据库...")
        try:
            # NOTE: subprocess.run is synchronous, so the event loop is blocked
            # while the migration runs; this happens before the app yields.
            result = subprocess.run(
                ["alembic", "upgrade", "head"],
                capture_output=True,
                text=True,
                check=True,
            )
            logger.info(f"数据库升级成功: {result.stdout}")
        except subprocess.CalledProcessError as e:
            logger.error(f"数据库升级失败: {e.stderr}")
            # Chain explicitly so the traceback shows the failed command as the cause.
            raise RuntimeError(f"数据库升级失败: {e.stderr}") from e
        except Exception as e:
            # e.g. the alembic executable could not be launched at all.
            logger.error(f"运行数据库升级时出错: {str(e)}")
            raise
    else:
        logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")

    # Optionally load the predefined models (best effort: errors do not stop startup).
    if settings.LOAD_MODEL:
        logger.info("开始加载预定义模型...")
        try:
            with get_db_context() as db:
                result = load_models(db, silent=True)
                logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}")
        except Exception as e:
            logger.warning(f"加载预定义模型时出错: {str(e)}")
    else:
        logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")

    await create_all_indexes()
    logger.info("应用程序启动完成")

    yield

    # --- Shutdown --------------------------------------------------------
    logger.info("应用程序正在关闭")
# Create the application object; startup/shutdown work is delegated to the
# `lifespan` context manager passed below.
app = FastAPI(
title="redbera-mem",
description="redbera-mem",
version="1.0.0",
lifespan=lifespan,
)
# Enable CORS for frontend access; the origin list is WEB_URL plus whatever
# settings.CORS_ORIGINS adds (empty entries dropped, duplicates removed).
default_origins = [settings.WEB_URL]
candidate_origins = default_origins + settings.CORS_ORIGINS
allowed_origins = list(set(filter(None, candidate_origins)))

# A "*" entry in CORS_ORIGINS means every origin is allowed.
if "*" in settings.CORS_ORIGINS:
    allowed_origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=allowed_origins,
    # Credentials cannot be used when all origins are allowed.
    allow_credentials="*" not in allowed_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Add the i18n language detection middleware. The exception handlers in this
# module read `request.state.language` (falling back to
# settings.I18N_DEFAULT_LANGUAGE) — presumably this middleware is what sets
# it; confirm against app.i18n.middleware.
from app.i18n.middleware import LanguageMiddleware
app.add_middleware(LanguageMiddleware)
logger.info("FastAPI应用程序启动")
@app.get("/", tags=["General"])
def read_root():
    """Health-check endpoint: report that the service is up.

    Returns:
        A dict with a single ``message`` key.
    """
    logger.debug("健康检查端点被访问")
    payload = {"message": "FastAPI is running"}
    return payload
# Lifecycle events are handled by `lifespan`; no on_event hooks are needed.
# Register routers.
# Management API (JWT authentication), mounted under /api
app.include_router(manager_router, prefix="/api")
# Service API (API key authentication), mounted under /v1
app.include_router(service_router, prefix="/v1")
logger.info("所有路由已注册完成")
# Import additional exception types for specific handling
from app.core.exceptions import (
ValidationException,
ResourceNotFoundException,
PermissionDeniedException,
AuthenticationException,
AuthorizationException,
FileUploadException,
RateLimitException
)
from app.core.sensitive_filter import SensitiveDataFilter
import traceback
# Import i18n exception support
from app.i18n.exceptions import I18nException
from app.i18n.service import get_translation_service
from pydantic import ValidationError as PydanticValidationError
# Handle validation exceptions
@app.exception_handler(ValidationException)
async def validation_exception_handler(request: Request, exc: ValidationException):
    """Turn a ValidationException into the unified failure response.

    The exception's message and context go through ``SensitiveDataFilter``
    before they are logged or returned. The business code is ``exc.code`` when
    it is a ``BizCode``, otherwise ``BizCode.VALIDATION_FAILED``; the HTTP
    status is looked up in ``HTTP_MAPPING`` (default 400).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    # Attach the traceback only when an underlying cause was recorded.
    logger.warning(
        f"Validation error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.VALIDATION_FAILED
    http_status = HTTP_MAPPING.get(resolved_code, 400)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle i18n (internationalized) exceptions
@app.exception_handler(I18nException)
async def i18n_exception_handler(request: Request, exc: I18nException):
    """Turn an I18nException into a JSON error response.

    Per the original author's note, ``exc.detail`` already carries the
    translated message, so no translation happens here (not verifiable from
    this file). The message is passed through ``SensitiveDataFilter`` before
    being returned.

    The response body is ``{"success": False, **detail}``. When ``exc.detail``
    is not a dict it is wrapped as ``{"message": <filtered text>}``; previously
    the bare string was ``**``-unpacked, which raised ``TypeError`` inside the
    handler.
    """
    # Current request language (used for logging only)
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    # Exception detail supplied by the raiser
    detail = exc.detail

    # Filter sensitive data; always end up with a mapping so it can be unpacked.
    if isinstance(detail, dict):
        filtered_detail = {
            **detail,
            "message": SensitiveDataFilter.filter_string(detail.get("message", "")),
        }
    else:
        filtered_detail = {
            "message": SensitiveDataFilter.filter_string(str(detail)),
        }

    logger.warning(
        f"I18n exception: {exc.error_key}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "error_code": exc.error_code,
            "error_key": exc.error_key,
            "language": language,
            "status_code": exc.status_code,
            "params": exc.params,
        },
    )

    return JSONResponse(
        status_code=exc.status_code,
        content={
            "success": False,
            **filtered_detail,
        },
        headers=exc.headers,
    )
# Handle Pydantic validation errors (with i18n support)

# Maps a Pydantic error ``type`` identifier to the translation key of its
# message. Both naming schemes are listed so the specific messages are used
# regardless of the installed Pydantic major version.
_PYDANTIC_ERROR_TRANSLATION_KEYS = {
    # Pydantic v1 identifiers
    "value_error.missing": "errors.validation.missing_field",
    "value_error.any_str.max_length": "errors.validation.field_too_long",
    "value_error.any_str.min_length": "errors.validation.field_too_short",
    # Pydantic v2 identifiers
    "missing": "errors.validation.missing_field",
    "string_too_long": "errors.validation.field_too_long",
    "string_too_short": "errors.validation.field_too_short",
}
# Used for every error type not listed above.
_PYDANTIC_FALLBACK_TRANSLATION_KEY = "errors.validation.invalid_field"


@app.exception_handler(PydanticValidationError)
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
    """Turn a Pydantic ValidationError into a translated 422 response.

    Each entry of ``exc.errors()`` becomes ``{"field", "message", "type"}``,
    where ``field`` is the dotted ``loc`` path and ``message`` is the
    translation selected by the error type (generic "invalid field" message
    for unknown types).

    Returns:
        A 422 ``JSONResponse`` with ``success``, ``error_code``, ``message``
        and the per-field ``errors`` list.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    translation_service = get_translation_service()

    # Translate every individual validation error
    errors = []
    for error in exc.errors():
        field = ".".join(str(loc) for loc in error["loc"])
        error_type = error["type"]
        message_key = _PYDANTIC_ERROR_TRANSLATION_KEYS.get(
            error_type, _PYDANTIC_FALLBACK_TRANSLATION_KEY
        )
        errors.append({
            "field": field,
            "message": translation_service.translate(message_key, language, field=field),
            "type": error_type,
        })

    # Translate the top-level message
    main_message = translation_service.translate(
        "errors.common.validation_failed",
        language,
    )

    logger.warning(
        f"Pydantic validation error: {len(errors)} errors",
        extra={
            "path": request.url.path,
            "method": request.method,
            "language": language,
            "errors": errors,
        },
    )

    return JSONResponse(
        status_code=422,
        content={
            "success": False,
            "error_code": "VALIDATION_FAILED",
            "message": main_message,
            "errors": errors,
        },
    )
# Handle resource-not-found exceptions
@app.exception_handler(ResourceNotFoundException)
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
    """Turn a ResourceNotFoundException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. The
    business code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.FILE_NOT_FOUND``; the HTTP status comes from ``HTTP_MAPPING``
    (default 404). Logged at INFO level.
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.info(f"Resource not found: {safe_msg}", extra=log_extra)

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.FILE_NOT_FOUND
    http_status = HTTP_MAPPING.get(resolved_code, 404)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle permission-denied exceptions
@app.exception_handler(PermissionDeniedException)
async def permission_denied_handler(request: Request, exc: PermissionDeniedException):
    """Turn a PermissionDeniedException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. The log
    record also carries ``request.state.user_id`` when present. The business
    code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.FORBIDDEN``; the HTTP status comes from ``HTTP_MAPPING``
    (default 403).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "user": getattr(request.state, "user_id", None),
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Permission denied: {safe_msg}", extra=log_extra)

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.FORBIDDEN
    http_status = HTTP_MAPPING.get(resolved_code, 403)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle authentication exceptions
@app.exception_handler(AuthenticationException)
async def authentication_exception_handler(request: Request, exc: AuthenticationException):
    """Turn an AuthenticationException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. The
    business code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.UNAUTHORIZED``; the HTTP status comes from ``HTTP_MAPPING``
    (default 401).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Authentication error: {safe_msg}", extra=log_extra)

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.UNAUTHORIZED
    http_status = HTTP_MAPPING.get(resolved_code, 401)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle authorization exceptions
@app.exception_handler(AuthorizationException)
async def authorization_exception_handler(request: Request, exc: AuthorizationException):
    """Turn an AuthorizationException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. The
    business code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.FORBIDDEN``; the HTTP status comes from ``HTTP_MAPPING``
    (default 403).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Authorization error: {safe_msg}", extra=log_extra)

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.FORBIDDEN
    http_status = HTTP_MAPPING.get(resolved_code, 403)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle file-upload exceptions
@app.exception_handler(FileUploadException)
async def file_upload_exception_handler(request: Request, exc: FileUploadException):
    """Turn a FileUploadException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. Logged
    at ERROR level, with the traceback attached when ``exc.cause`` is set. The
    business code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.FILE_READ_ERROR``; the HTTP status comes from ``HTTP_MAPPING``
    (default 500).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.error(
        f"File upload error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.FILE_READ_ERROR
    http_status = HTTP_MAPPING.get(resolved_code, 500)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Handle rate-limit exceptions
@app.exception_handler(RateLimitException)
async def rate_limit_exception_handler(request: Request, exc: RateLimitException):
    """Turn a RateLimitException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. The
    business code is ``exc.code`` when it is a ``BizCode``, otherwise
    ``BizCode.RATE_LIMITED``; the HTTP status comes from ``HTTP_MAPPING``
    (default 429). Any entries under ``exc.context["rate_limit_headers"]`` are
    copied onto the response as headers, stringified.
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.warning(f"Rate limit exceeded: {safe_msg}", extra=log_extra)

    if isinstance(exc.code, BizCode):
        resolved_code = exc.code
    else:
        resolved_code = BizCode.RATE_LIMITED
    http_status = HTTP_MAPPING.get(resolved_code, 429)

    # Build the response, then attach the rate-limit headers
    response = JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )

    # Headers are read from the unfiltered context of the exception.
    if exc.context:
        extra_headers = exc.context.get("rate_limit_headers", {})
    else:
        extra_headers = {}
    for name, value in extra_headers.items():
        response.headers[name] = str(value)

    return response
# Unified handling of business exceptions (uses business error codes)
@app.exception_handler(BusinessException)
async def business_exception_handler(request: Request, exc: BusinessException):
    """Turn a generic BusinessException into the unified failure response.

    Message and context are filtered by ``SensitiveDataFilter`` first. Logged
    at ERROR level, with the traceback attached when ``exc.cause`` is set.

    Code resolution: a ``BizCode`` is used as-is; an ``int`` is converted with
    ``BizCode(int)`` and falls back to ``BizCode.BAD_REQUEST`` if it is not a
    valid member; anything else becomes ``BizCode.BAD_REQUEST``. The HTTP
    status comes from ``HTTP_MAPPING`` (default 400).
    """
    # Filter sensitive data
    safe_msg, safe_ctx = SensitiveDataFilter.filter_message(exc.message, exc.context)

    log_extra = {
        "path": request.url.path,
        "method": request.method,
        "context": safe_ctx,
        "error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
        "cause": str(exc.cause) if exc.cause else None,
    }
    logger.error(
        f"Business error: {safe_msg}",
        extra=log_extra,
        exc_info=exc.cause is not None,
    )

    resolved_code = exc.code
    if not isinstance(resolved_code, BizCode):
        if isinstance(resolved_code, int):
            try:
                resolved_code = BizCode(resolved_code)
            except ValueError:
                # Integer that is not a known business code
                resolved_code = BizCode.BAD_REQUEST
        else:
            resolved_code = BizCode.BAD_REQUEST

    http_status = HTTP_MAPPING.get(resolved_code, 400)
    return JSONResponse(
        status_code=http_status,
        content=fail(code=resolved_code.value, msg=safe_msg, error=safe_msg),
    )
# Convert HTTPException into the unified response structure (with i18n support)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
    """Turn an HTTPException into the unified failure response.

    For the status codes listed in ``error_key_map`` the message is the
    translated standard text for the request language; for any other status
    the filtered ``exc.detail`` is used.

    Compared with the previous version:
      * headers set on the exception (for example ``WWW-Authenticate``) are
        forwarded to the response instead of being dropped;
      * a string ``exc.detail`` is passed through ``SensitiveDataFilter``
        before being returned in the ``error`` field (non-string details are
        returned unchanged, as before).
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    translation_service = get_translation_service()

    # Translation keys for the standard HTTP errors
    error_key_map = {
        400: "errors.common.bad_request",
        401: "errors.common.unauthorized",
        403: "errors.common.forbidden",
        404: "errors.common.not_found",
        405: "errors.common.method_not_allowed",
        409: "errors.common.conflict",
        413: "errors.common.payload_too_large",
        422: "errors.common.validation_failed",
        429: "errors.common.too_many_requests",
        500: "errors.common.internal_error",
        502: "errors.common.bad_gateway",
        503: "errors.common.service_unavailable",
        504: "errors.common.gateway_timeout",
    }

    if exc.status_code in error_key_map:
        # Known status: use the translated standard message
        translated_message = translation_service.translate(
            error_key_map[exc.status_code],
            language,
        )
    else:
        # Unknown status: fall back to the filtered original detail
        translated_message = SensitiveDataFilter.filter_string(str(exc.detail))

    # Never echo an unfiltered string back to the client.
    if isinstance(exc.detail, str):
        error_detail = SensitiveDataFilter.filter_string(exc.detail)
    else:
        error_detail = exc.detail

    logger.warning(
        f"HTTP exception: {translated_message}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "status_code": exc.status_code,
            "language": language,
        },
    )

    return JSONResponse(
        status_code=exc.status_code,
        content=fail(code=exc.status_code, msg=translated_message, error=error_detail),
        # getattr keeps this safe if the exception object carries no headers attribute.
        headers=getattr(exc, "headers", None),
    )
# Catch otherwise-unhandled exceptions and return the unified error structure (with i18n support)
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Turn any unhandled exception into a 500 response in the unified format.

    The full traceback is logged at ERROR level. It is built from ``exc``
    itself rather than from the interpreter's "currently handled exception"
    state, so the log is correct no matter how the handler is invoked.

    When the ``ENVIRONMENT`` environment variable equals ``"production"`` the
    client receives only the translated generic message; otherwise (including
    when the variable is unset, which defaults to ``"development"``) it
    receives the filtered ``str(exc)``.
    """
    # Current request language
    language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)

    translation_service = get_translation_service()

    # Log the complete stack trace. The original author notes that a log filter
    # scrubs sensitive data automatically — not verifiable from this file.
    formatted_traceback = "".join(
        traceback.format_exception(type(exc), exc, exc.__traceback__)
    )
    logger.error(
        f"Unhandled exception: {exc}",
        extra={
            "path": request.url.path,
            "method": request.method,
            "exception_type": type(exc).__name__,
            "language": language,
            "traceback": formatted_traceback,
        },
        exc_info=exc,
    )

    # Hide detailed error information in production
    environment = os.getenv("ENVIRONMENT", "development")
    if environment == "production":
        # Translated generic error message
        message = translation_service.translate(
            "errors.common.internal_error",
            language,
        )
    else:
        # Sensitive data is filtered in development as well
        message = SensitiveDataFilter.filter_string(str(exc))

    return JSONResponse(
        status_code=500,
        content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message),
    )
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )