From 8f609ba29c636f1456b835118ba821fe23531831 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 1 Apr 2026 11:15:06 +0800 Subject: [PATCH] fix(redis_lock): refactor RedisFairLock to use ZSET for queue management and fix loop shutdown - Replace list-based queue with sorted set for better dead client cleanup - Add zombie cleanup buffer to handle expired queue entries - Fix potential None loop reference in graceful shutdown - Add task start time to write_message_task result - Update lock acquisition script to use ZSET operations - Remove unused queue cleanup scripts - Ensure proper lock release and renewal failure handling --- api/app/tasks.py | 6 ++- api/app/utils/redis_lock.py | 96 ++++++++++++++++++++----------------- 2 files changed, 58 insertions(+), 44 deletions(-) diff --git a/api/app/tasks.py b/api/app/tasks.py index 72421a5f..fa2fa55d 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1176,6 +1176,7 @@ def write_message_task( redis_client = get_sync_redis_client() lock = None + loop = None if redis_client is not None: lock = RedisFairLock( key=f"memory_write:{end_user_id}", @@ -1196,6 +1197,7 @@ def write_message_task( } try: + task_start_time = int(time.time()) loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1219,6 +1221,7 @@ def write_message_task( return { "status": "SUCCESS", "result": result, + "start_at": task_start_time, "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, @@ -1252,7 +1255,8 @@ def write_message_task( logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # Gracefully shutdown the event loop to prevent # 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__ - _shutdown_loop_gracefully(loop) + if loop: + _shutdown_loop_gracefully(loop) # unused task diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index a86ba46e..f517cbb5 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,14 +1,15 @@ -import redis -import uuid -import time import threading +import time +import uuid + +import redis UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then return redis.call("del", KEYS[1]) -else - return 0 end + +return 0 """ RENEW_SCRIPT = """ @@ -19,38 +20,44 @@ else end """ -CLEANUP_DEAD_HEAD_SCRIPT = """ +ACQUIRE_SCRIPT = """ local queue_key = KEYS[1] local lock_key = KEYS[2] -local first = redis.call("lindex", queue_key, 0) -if not first then - return 0 +local client_id = ARGV[1] +local expire = tonumber(ARGV[2]) +local time_out = tonumber(ARGV[3]) + +local now = tonumber(redis.call("time")[1]) + +if redis.call("zscore", queue_key, client_id) == false then + redis.call("zadd", queue_key, now, client_id) end -if redis.call("exists", lock_key) == 1 then - return 0 +local expired = redis.call("zrangebyscore", queue_key, 0, now - time_out) + +for _, v in ipairs(expired) do + redis.call("zrem", queue_key, v) end -redis.call("lpop", queue_key) -return 1 -""" +local first = redis.call("zrange", queue_key, 0, 0)[1] +if first == client_id then -SAFE_RELEASE_QUEUE_SCRIPT = """ -local queue_key = KEYS[1] -local value = ARGV[1] + if redis.call("set", lock_key, client_id, "NX", "EX", expire) then + redis.call("zrem", queue_key, client_id) + return 1 + end -local first = redis.call("lindex", queue_key, 0) -if first == value then - redis.call("lpop", queue_key) - return 1 + if redis.call("get", lock_key) == client_id then + redis.call("expire", lock_key, expire) + return 1 + end end return 0 """ def _ensure_str(val): - """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" if val is None: return None if isinstance(val, bytes): @@ -59,18 +66,21 @@ def _ensure_str(val): class RedisFairLock: + # ZOMBIE CLEAN BUFFER + CLEANUP_BUFFER = 30 + def __init__( self, key: str, redis_client: redis.StrictRedis, expire: int = 30, - retry_interval: float = 0.05, + retry_interval: float = 1, timeout: float = 600, auto_renewal: bool = True ): self.key = key - self.queue_key = f"{key}:queue" - self.value = str(uuid.uuid4()) + self.queue_key = f"{key}:zset" + self.value = f"{uuid.uuid4().hex}:{int(time.time())}" self.expire = expire self.retry_interval = retry_interval self.timeout = timeout @@ -83,25 +93,25 @@ class RedisFairLock: def acquire(self): start = time.time() - self.redis.rpush(self.queue_key, self.value) - while True: - first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + ok = self.redis.eval( + ACQUIRE_SCRIPT, + 2, + self.queue_key, + self.key, + self.value, + str(self.expire), + str(self.timeout + self.CLEANUP_BUFFER) + ) - if first == self.value: - ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) - if ok: - self._locked = True - - if self.auto_renewal: - self._start_renewal() - return True - - if first: - self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + if ok == 1: + self._locked = True + if self.auto_renewal: + self._start_renewal() + return True if time.time() - start > self.timeout: - self.redis.lrem(self.queue_key, 0, self.value) + self.redis.zrem(self.queue_key, self.value) return False time.sleep(self.retry_interval) @@ -112,13 +122,15 @@ class RedisFairLock: if self._stop_renew.is_set(): break - self.redis.eval( + success = self.redis.eval( RENEW_SCRIPT, 1, self.key, self.value, str(self.expire) ) + if not success: + break def _start_renewal(self): self._stop_renew = threading.Event() @@ -139,8 +151,6 @@ class RedisFairLock: self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) - self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) - self._locked = False def __enter__(self):