596 lines
20 KiB
Python
596 lines
20 KiB
Python
import os
|
||
import subprocess
|
||
|
||
# 必须在导入任何使用 DashScope SDK 的模块之前应用补丁
|
||
import app.plugins.dashscope_patch # noqa: F401
|
||
|
||
from app.repositories.neo4j.create_indexes import create_all_indexes
|
||
from contextlib import asynccontextmanager
|
||
|
||
from fastapi import FastAPI, APIRouter
|
||
from fastapi import HTTPException, Request
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import JSONResponse
|
||
|
||
# 管理端 API (JWT 认证)
|
||
from app.controllers import manager_router
|
||
# 服务端 API (API Key 认证)
|
||
from app.controllers.service import service_router
|
||
from app.core.config import settings
|
||
from app.core.error_codes import BizCode, HTTP_MAPPING
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.logging_config import LoggingConfig, get_logger
|
||
from app.core.response_utils import fail
|
||
from app.core.models.scripts.loader import load_models
|
||
from app.db import get_db_context
|
||
|
||
# Initialize logging system
|
||
LoggingConfig.setup_logging()
|
||
logger = get_logger(__name__)
|
||
|
||
|
||
@asynccontextmanager
|
||
async def lifespan(app: FastAPI):
|
||
"""使用 FastAPI lifespan 替代 on_event 处理启动/关闭事件"""
|
||
# 应用启动事件
|
||
|
||
# 检查是否需要自动升级数据库
|
||
if settings.DB_AUTO_UPGRADE:
|
||
logger.info("开始自动升级数据库...")
|
||
try:
|
||
result = subprocess.run(
|
||
["alembic", "upgrade", "head"],
|
||
capture_output=True,
|
||
text=True,
|
||
check=True
|
||
)
|
||
logger.info(f"数据库升级成功: {result.stdout}")
|
||
except subprocess.CalledProcessError as e:
|
||
logger.error(f"数据库升级失败: {e.stderr}")
|
||
raise RuntimeError(f"数据库升级失败: {e.stderr}")
|
||
except Exception as e:
|
||
logger.error(f"运行数据库升级时出错: {str(e)}")
|
||
raise
|
||
else:
|
||
logger.info("自动数据库升级已禁用 (DB_AUTO_UPGRADE=false)")
|
||
|
||
# 加载预定义模型
|
||
if settings.LOAD_MODEL:
|
||
logger.info("开始加载预定义模型...")
|
||
try:
|
||
with get_db_context() as db:
|
||
result = load_models(db, silent=True)
|
||
logger.info(f"预定义模型加载完成: 成功{result['success']}个, 跳过{result['skipped']}个, 失败{result['failed']}个")
|
||
except Exception as e:
|
||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||
else:
|
||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||
await create_all_indexes()
|
||
logger.info("All neo4j indexes and constraints created successfully!")
|
||
logger.info("应用程序启动完成")
|
||
|
||
|
||
yield
|
||
# 应用关闭事件
|
||
logger.info("应用程序正在关闭")
|
||
|
||
|
||
app = FastAPI(
|
||
title="redbera-mem",
|
||
description="redbera-mem",
|
||
version="1.0.0",
|
||
lifespan=lifespan,
|
||
)
|
||
|
||
# Enable CORS for frontend access with environment-extendable origins
|
||
default_origins = [
|
||
settings.WEB_URL
|
||
]
|
||
allowed_origins = list({o for o in (default_origins + settings.CORS_ORIGINS) if o})
|
||
|
||
# 如果 CORS_ORIGINS 包含 "*",则允许所有来源
|
||
if "*" in settings.CORS_ORIGINS:
|
||
allowed_origins = ["*"]
|
||
|
||
app.add_middleware(
|
||
CORSMiddleware,
|
||
allow_origins=allowed_origins,
|
||
allow_credentials=True if "*" not in allowed_origins else False, # 允许所有来源时不能使用 credentials
|
||
allow_methods=["*"],
|
||
allow_headers=["*"],
|
||
)
|
||
|
||
# Add i18n language detection middleware
|
||
from app.i18n.middleware import LanguageMiddleware
|
||
app.add_middleware(LanguageMiddleware)
|
||
|
||
logger.info("FastAPI应用程序启动")
|
||
|
||
|
||
@app.get("/", tags=["General"])
|
||
def read_root():
|
||
"""
|
||
A simple health check endpoint.
|
||
"""
|
||
logger.debug("健康检查端点被访问")
|
||
return {"message": "FastAPI is running"}
|
||
|
||
|
||
# 生命周期事件由 lifespan 管理,无需 on_event
|
||
|
||
|
||
# 注册路由
|
||
# 管理端 API (JWT 认证)
|
||
app.include_router(manager_router, prefix="/api")
|
||
|
||
# 服务端 API (API Key 认证)
|
||
app.include_router(service_router, prefix="/v1")
|
||
|
||
logger.info("所有路由已注册完成")
|
||
|
||
# Import additional exception types for specific handling
|
||
from app.core.exceptions import (
|
||
ValidationException,
|
||
ResourceNotFoundException,
|
||
PermissionDeniedException,
|
||
AuthenticationException,
|
||
AuthorizationException,
|
||
FileUploadException,
|
||
RateLimitException
|
||
)
|
||
from app.core.sensitive_filter import SensitiveDataFilter
|
||
import traceback
|
||
|
||
# Import i18n exception support
|
||
from app.i18n.exceptions import I18nException
|
||
from app.i18n.service import get_translation_service
|
||
from pydantic import ValidationError as PydanticValidationError
|
||
|
||
|
||
# 处理验证异常
|
||
@app.exception_handler(ValidationException)
|
||
async def validation_exception_handler(request: Request, exc: ValidationException):
|
||
"""处理验证异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.warning(
|
||
f"Validation error: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
},
|
||
exc_info=exc.cause is not None
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.VALIDATION_FAILED
|
||
status_code = HTTP_MAPPING.get(biz_code, 400)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理 i18n 异常(国际化异常)
|
||
@app.exception_handler(I18nException)
|
||
async def i18n_exception_handler(request: Request, exc: I18nException):
|
||
"""
|
||
处理国际化异常
|
||
|
||
I18nException 已经自动翻译了错误消息,直接返回即可
|
||
"""
|
||
# 获取当前语言
|
||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||
|
||
# 获取异常详情(已经包含翻译后的消息)
|
||
detail = exc.detail
|
||
|
||
# 过滤敏感信息
|
||
if isinstance(detail, dict):
|
||
filtered_message = SensitiveDataFilter.filter_string(detail.get("message", ""))
|
||
filtered_detail = {
|
||
**detail,
|
||
"message": filtered_message
|
||
}
|
||
else:
|
||
filtered_detail = SensitiveDataFilter.filter_string(str(detail))
|
||
|
||
logger.warning(
|
||
f"I18n exception: {exc.error_key}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"error_code": exc.error_code,
|
||
"error_key": exc.error_key,
|
||
"language": language,
|
||
"status_code": exc.status_code,
|
||
"params": exc.params
|
||
}
|
||
)
|
||
|
||
return JSONResponse(
|
||
status_code=exc.status_code,
|
||
content={
|
||
"success": False,
|
||
**filtered_detail
|
||
},
|
||
headers=exc.headers
|
||
)
|
||
|
||
|
||
# 处理 Pydantic 验证错误(国际化支持)
|
||
@app.exception_handler(PydanticValidationError)
|
||
async def pydantic_validation_exception_handler(request: Request, exc: PydanticValidationError):
|
||
"""
|
||
处理 Pydantic 验证错误,支持国际化
|
||
"""
|
||
# 获取当前语言
|
||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||
|
||
# 获取翻译服务
|
||
translation_service = get_translation_service()
|
||
|
||
# 翻译验证错误消息
|
||
errors = []
|
||
for error in exc.errors():
|
||
field = ".".join(str(loc) for loc in error["loc"])
|
||
error_type = error["type"]
|
||
|
||
# 尝试翻译错误消息
|
||
if error_type == "value_error.missing":
|
||
message = translation_service.translate(
|
||
"errors.validation.missing_field",
|
||
language,
|
||
field=field
|
||
)
|
||
elif error_type == "value_error.any_str.max_length":
|
||
message = translation_service.translate(
|
||
"errors.validation.field_too_long",
|
||
language,
|
||
field=field
|
||
)
|
||
elif error_type == "value_error.any_str.min_length":
|
||
message = translation_service.translate(
|
||
"errors.validation.field_too_short",
|
||
language,
|
||
field=field
|
||
)
|
||
else:
|
||
# 使用通用验证错误消息
|
||
message = translation_service.translate(
|
||
"errors.validation.invalid_field",
|
||
language,
|
||
field=field
|
||
)
|
||
|
||
errors.append({
|
||
"field": field,
|
||
"message": message,
|
||
"type": error_type
|
||
})
|
||
|
||
# 翻译主错误消息
|
||
main_message = translation_service.translate(
|
||
"errors.common.validation_failed",
|
||
language
|
||
)
|
||
|
||
logger.warning(
|
||
f"Pydantic validation error: {len(errors)} errors",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"language": language,
|
||
"errors": errors
|
||
}
|
||
)
|
||
|
||
return JSONResponse(
|
||
status_code=422,
|
||
content={
|
||
"success": False,
|
||
"error_code": "VALIDATION_FAILED",
|
||
"message": main_message,
|
||
"errors": errors
|
||
}
|
||
)
|
||
|
||
|
||
# 处理资源不存在异常
|
||
@app.exception_handler(ResourceNotFoundException)
|
||
async def not_found_exception_handler(request: Request, exc: ResourceNotFoundException):
|
||
"""处理资源不存在异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.info(
|
||
f"Resource not found: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
}
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_NOT_FOUND
|
||
status_code = HTTP_MAPPING.get(biz_code, 404)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理权限拒绝异常
|
||
@app.exception_handler(PermissionDeniedException)
|
||
async def permission_denied_handler(request: Request, exc: PermissionDeniedException):
|
||
"""处理权限拒绝异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.warning(
|
||
f"Permission denied: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"user": getattr(request.state, "user_id", None),
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
}
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN
|
||
status_code = HTTP_MAPPING.get(biz_code, 403)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理认证异常
|
||
@app.exception_handler(AuthenticationException)
|
||
async def authentication_exception_handler(request: Request, exc: AuthenticationException):
|
||
"""处理认证异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.warning(
|
||
f"Authentication error: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
}
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.UNAUTHORIZED
|
||
status_code = HTTP_MAPPING.get(biz_code, 401)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理授权异常
|
||
@app.exception_handler(AuthorizationException)
|
||
async def authorization_exception_handler(request: Request, exc: AuthorizationException):
|
||
"""处理授权异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.warning(
|
||
f"Authorization error: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
}
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FORBIDDEN
|
||
status_code = HTTP_MAPPING.get(biz_code, 403)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理文件上传异常
|
||
@app.exception_handler(FileUploadException)
|
||
async def file_upload_exception_handler(request: Request, exc: FileUploadException):
|
||
"""处理文件上传异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.error(
|
||
f"File upload error: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
},
|
||
exc_info=exc.cause is not None
|
||
)
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.FILE_READ_ERROR
|
||
status_code = HTTP_MAPPING.get(biz_code, 500)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 处理限流异常
|
||
@app.exception_handler(RateLimitException)
|
||
async def rate_limit_exception_handler(request: Request, exc: RateLimitException):
|
||
"""处理限流异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.warning(
|
||
f"Rate limit exceeded: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
}
|
||
)
|
||
|
||
biz_code = exc.code if isinstance(exc.code, BizCode) else BizCode.RATE_LIMITED
|
||
status_code = HTTP_MAPPING.get(biz_code, 429)
|
||
|
||
# 创建响应对象并添加限流头信息
|
||
response = JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
# 添加限流相关的响应头
|
||
rate_headers = exc.context.get("rate_limit_headers", {}) if exc.context else {}
|
||
for header_name, header_value in rate_headers.items():
|
||
response.headers[header_name] = str(header_value)
|
||
|
||
return response
|
||
|
||
|
||
# 业务异常统一处理(使用业务错误码)
|
||
@app.exception_handler(BusinessException)
|
||
async def business_exception_handler(request: Request, exc: BusinessException):
|
||
"""处理通用业务异常"""
|
||
# 过滤敏感信息
|
||
filtered_message, filtered_context = SensitiveDataFilter.filter_message(exc.message, exc.context)
|
||
|
||
logger.error(
|
||
f"Business error: {filtered_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"context": filtered_context,
|
||
"error_code": exc.code.value if isinstance(exc.code, BizCode) else exc.code,
|
||
"cause": str(exc.cause) if exc.cause else None
|
||
},
|
||
exc_info=exc.cause is not None
|
||
)
|
||
raw_code = exc.code
|
||
if isinstance(raw_code, BizCode):
|
||
biz_code = raw_code
|
||
elif isinstance(raw_code, int):
|
||
try:
|
||
biz_code = BizCode(raw_code)
|
||
except ValueError:
|
||
biz_code = BizCode.BAD_REQUEST
|
||
else:
|
||
biz_code = BizCode.BAD_REQUEST
|
||
|
||
status_code = HTTP_MAPPING.get(biz_code, 400)
|
||
return JSONResponse(
|
||
status_code=status_code,
|
||
content=fail(code=biz_code.value, msg=filtered_message, error=filtered_message)
|
||
)
|
||
|
||
|
||
# 统一异常处理:将HTTPException转换为统一响应结构(支持国际化)
|
||
@app.exception_handler(HTTPException)
|
||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||
"""处理HTTP异常,支持国际化"""
|
||
# 获取当前语言
|
||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||
|
||
# 获取翻译服务
|
||
translation_service = get_translation_service()
|
||
|
||
# 尝试翻译标准HTTP错误
|
||
error_key_map = {
|
||
400: "errors.common.bad_request",
|
||
401: "errors.common.unauthorized",
|
||
403: "errors.common.forbidden",
|
||
404: "errors.common.not_found",
|
||
405: "errors.common.method_not_allowed",
|
||
409: "errors.common.conflict",
|
||
413: "errors.common.payload_too_large",
|
||
422: "errors.common.validation_failed",
|
||
429: "errors.common.too_many_requests",
|
||
500: "errors.common.internal_error",
|
||
502: "errors.common.bad_gateway",
|
||
503: "errors.common.service_unavailable",
|
||
504: "errors.common.gateway_timeout",
|
||
}
|
||
|
||
# 如果有对应的翻译键,使用翻译
|
||
if exc.status_code in error_key_map:
|
||
translated_message = translation_service.translate(
|
||
error_key_map[exc.status_code],
|
||
language
|
||
)
|
||
else:
|
||
# 否则过滤原始消息
|
||
translated_message = SensitiveDataFilter.filter_string(str(exc.detail))
|
||
|
||
logger.warning(
|
||
f"HTTP exception: {translated_message}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"status_code": exc.status_code,
|
||
"language": language
|
||
}
|
||
)
|
||
|
||
return JSONResponse(
|
||
status_code=exc.status_code,
|
||
content=fail(code=exc.status_code, msg=translated_message, error=exc.detail)
|
||
)
|
||
|
||
|
||
# 捕获未处理的异常,返回统一错误结构(支持国际化)
|
||
@app.exception_handler(Exception)
|
||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||
"""处理未捕获的异常,支持国际化"""
|
||
# 获取当前语言
|
||
language = getattr(request.state, "language", settings.I18N_DEFAULT_LANGUAGE)
|
||
|
||
# 获取翻译服务
|
||
translation_service = get_translation_service()
|
||
|
||
# 记录完整的堆栈跟踪(日志过滤器会自动过滤敏感信息)
|
||
logger.error(
|
||
f"Unhandled exception: {exc}",
|
||
extra={
|
||
"path": request.url.path,
|
||
"method": request.method,
|
||
"exception_type": type(exc).__name__,
|
||
"language": language,
|
||
"traceback": traceback.format_exc()
|
||
},
|
||
exc_info=True
|
||
)
|
||
|
||
# 生产环境隐藏详细错误信息
|
||
environment = os.getenv("ENVIRONMENT", "development")
|
||
if environment == "production":
|
||
# 使用翻译的通用错误消息
|
||
message = translation_service.translate(
|
||
"errors.common.internal_error",
|
||
language
|
||
)
|
||
else:
|
||
# 开发环境也要过滤敏感信息
|
||
message = SensitiveDataFilter.filter_string(str(exc))
|
||
|
||
return JSONResponse(
|
||
status_code=500,
|
||
content=fail(code=BizCode.INTERNAL_ERROR.value, msg=message, error=message)
|
||
)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=8000)
|