[MODIFY] Code optimization
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import subprocess
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -32,7 +33,6 @@ from app.controllers import (
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
|
||||
|
||||
app = FastAPI(title="Data Config API", version="1.0.0")
|
||||
router = APIRouter(prefix="/memory", tags=["Memory"])
|
||||
|
||||
@@ -46,16 +46,16 @@ from app.controllers.service import service_router
|
||||
LoggingConfig.setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""使用 FastAPI lifespan 替代 on_event 处理启动/关闭事件"""
|
||||
# 应用启动事件
|
||||
|
||||
|
||||
# 检查是否需要自动升级数据库
|
||||
if settings.DB_AUTO_UPGRADE:
|
||||
logger.info("开始自动升级数据库...")
|
||||
try:
|
||||
import subprocess
|
||||
result = subprocess.run(
|
||||
["alembic", "upgrade", "head"],
|
||||
capture_output=True,
|
||||
@@ -71,11 +71,12 @@ async def lifespan(app: FastAPI):
|
||||
raise
|
||||
else:
|
||||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||||
|
||||
|
||||
logger.info("应用程序启动完成")
|
||||
yield
|
||||
# 应用关闭事件
|
||||
logger.info("应用程序正在关闭")
|
||||
logger.info("应用程序正在关闭")
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="redbera-mem",
|
||||
@@ -120,10 +121,8 @@ app.include_router(manager_router, prefix="/api")
|
||||
# 服务端 API (API Key 认证)
|
||||
app.include_router(service_router, prefix="/v1")
|
||||
|
||||
|
||||
logger.info("所有路由已注册完成")
|
||||
|
||||
|
||||
# Import additional exception types for specific handling
|
||||
from app.core.exceptions import (
|
||||
ValidationException,
|
||||
@@ -131,7 +130,8 @@ from app.core.exceptions import (
|
||||
PermissionDeniedException,
|
||||
AuthenticationException,
|
||||
AuthorizationException,
|
||||
FileUploadException
|
||||
FileUploadException,
|
||||
RateLimitException
|
||||
)
|
||||
from app.core.sensitive_filter import SensitiveDataFilter
|
||||
import traceback
|
||||
@@ -143,7 +143,7 @@ async def validation_exception_handler(request: Request, exc: ValidationExceptio
|
||||
"""处理验证异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"Validation error: {filtered_message}",
|
||||
extra={
|
||||
@@ -169,7 +169,7 @@ async def not_found_exception_handler(request: Request, exc: ResourceNotFoundExc
|
||||
"""处理资源不存在异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Resource not found: {filtered_message}",
|
||||
extra={
|
||||
@@ -194,7 +194,7 @@ async def permission_denied_handler(request: Request, exc: PermissionDeniedExcep
|
||||
"""处理权限拒绝异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"Permission denied: {filtered_message}",
|
||||
extra={
|
||||
@@ -220,7 +220,7 @@ async def authentication_exception_handler(request: Request, exc: Authentication
|
||||
"""处理认证异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"Authentication error: {filtered_message}",
|
||||
extra={
|
||||
@@ -245,7 +245,7 @@ async def authorization_exception_handler(request: Request, exc: AuthorizationEx
|
||||
"""处理授权异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"Authorization error: {filtered_message}",
|
||||
extra={
|
||||
@@ -270,7 +270,7 @@ async def file_upload_exception_handler(request: Request, exc: FileUploadExcepti
|
||||
"""处理文件上传异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.error(
|
||||
f"File upload error: {filtered_message}",
|
||||
extra={
|
||||
@@ -290,13 +290,48 @@ async def file_upload_exception_handler(request: Request, exc: FileUploadExcepti
|
||||
)
|
||||
|
||||
|
||||
# 处理限流异常
|
||||
@app.exception_handler(RateLimitException)
|
||||
async def rate_limit_exception_handler(request: Request, exc: RateLimitException):
|
||||
"""处理限流异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
logger.warning(
|
||||
f"Rate limit exceeded: {filtered_message}",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"method": request.method,
|
||||
"context": filtered_context,
|
||||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||||
"cause": str(exc.cause) if exc.cause else None
|
||||
}
|
||||
)
|
||||
|
||||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.RATE_LIMITED
|
||||
status_code = HTTP_MAPPING.get(biz_code, 429)
|
||||
|
||||
# 创建响应对象并添加限流头信息
|
||||
response = JSONResponse(
|
||||
status_code=status_code,
|
||||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||||
)
|
||||
|
||||
# 添加限流相关的响应头
|
||||
rate_headers = exc.context.get("rate_limit_headers", {}) if exc.context else {}
|
||||
for header_name, header_value in rate_headers.items():
|
||||
response.headers[header_name] = str(header_value)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
# 业务异常统一处理(使用业务错误码)
|
||||
@app.exception_handler(BusinessException)
|
||||
async def business_exception_handler(request: Request, exc: BusinessException):
|
||||
"""处理通用业务异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||||
|
||||
|
||||
logger.error(
|
||||
f"Business error: {filtered_message}",
|
||||
extra={
|
||||
@@ -332,7 +367,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
"""处理HTTP异常"""
|
||||
# 过滤敏感信息
|
||||
filtered_detail = SensitiveDataFilter.filter_string(str(exc.detail))
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"HTTP exception: {filtered_detail}",
|
||||
extra={
|
||||
@@ -362,7 +397,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
},
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
|
||||
# 生产环境隐藏详细错误信息
|
||||
environment = os.getenv("ENVIRONMENT", "development")
|
||||
if environment == "production":
|
||||
@@ -370,7 +405,7 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
else:
|
||||
# 开发环境也要过滤敏感信息
|
||||
message = SensitiveDataFilter.filter_string(str(exc))
|
||||
|
||||
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message)
|
||||
@@ -379,4 +414,5 @@ async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
Reference in New Issue
Block a user