From 2f0bb793d8ea48cabea42e478c04b4ea71140e02 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Thu, 2 Apr 2026 14:49:46 +0800 Subject: [PATCH] feat(memory): Add task result sanitization for JSON serialization - Remove unused TaskStatusResponse import from memory_api_schema - Add _sanitize_task_result() helper function to convert non-serializable types (UUID, datetime) to strings - Update get_write_task_status endpoint to use sanitization instead of TaskStatusResponse validation - Update get_read_task_status endpoint to use sanitization instead of TaskStatusResponse validation - Ensures Celery task results are properly JSON-serializable before returning to clients --- .../service/memory_api_controller.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 34135fec..9acd865f 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -15,7 +15,6 @@ from app.schemas.memory_api_schema import ( MemoryWriteRequest, MemoryWriteResponse, MemoryWriteSyncResponse, - TaskStatusResponse, ) from app.services.memory_api_service import MemoryAPIService @@ -23,6 +22,34 @@ router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) logger = get_business_logger() +def _sanitize_task_result(result: dict) -> dict: + """Make Celery task result JSON-serializable. + + Converts UUID and other non-serializable values to strings. + + Args: + result: Raw task result dict from task_service + + Returns: + JSON-safe dict + """ + import uuid as _uuid + from datetime import datetime + + def _convert(obj): + if isinstance(obj, dict): + return {k: _convert(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_convert(i) for i in obj] + if isinstance(obj, _uuid.UUID): + return str(obj) + if isinstance(obj, datetime): + return obj.isoformat() + return obj + + return _convert(result) + + @router.get("") async def get_memory_info(): """获取记忆服务信息(占位)""" @@ -80,7 +107,7 @@ async def get_write_task_status( from app.services.task_service import get_task_memory_write_result result = get_task_memory_write_result(task_id) - return success(data=TaskStatusResponse(**result).model_dump(), msg="Task status retrieved") + return success(data=_sanitize_task_result(result), msg="Task status retrieved") @router.post("/read") @@ -135,7 +162,7 @@ async def get_read_task_status( from app.services.task_service import get_task_memory_read_result result = get_task_memory_read_result(task_id) - return success(data=TaskStatusResponse(**result).model_dump(), msg="Task status retrieved") + return success(data=_sanitize_task_result(result), msg="Task status retrieved") @router.post("/write/sync")