Files
MemoryBear/api/app/utils/redis_lock.py

155 lines
3.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import redis
import uuid
import time
import threading
UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("del", KEYS[1])
else
return 0
end
"""
RENEW_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
"""
CLEANUP_DEAD_HEAD_SCRIPT = """
local queue_key = KEYS[1]
local lock_key = KEYS[2]
local first = redis.call("lindex", queue_key, 0)
if not first then
return 0
end
if redis.call("exists", lock_key) == 1 then
return 0
end
redis.call("lpop", queue_key)
return 1
"""
SAFE_RELEASE_QUEUE_SCRIPT = """
local queue_key = KEYS[1]
local value = ARGV[1]
local first = redis.call("lindex", queue_key, 0)
if first == value then
redis.call("lpop", queue_key)
return 1
end
return 0
"""
def _ensure_str(val):
"""统一将 Redis 返回值转为 str兼容 decode_responses=True/False"""
if val is None:
return None
if isinstance(val, bytes):
return val.decode("utf-8")
return str(val)
class RedisFairLock:
def __init__(
self,
key: str,
redis_client: redis.StrictRedis,
expire: int = 30,
retry_interval: float = 0.05,
timeout: float = 600,
auto_renewal: bool = True
):
self.key = key
self.queue_key = f"{key}:queue"
self.value = str(uuid.uuid4())
self.expire = expire
self.retry_interval = retry_interval
self.timeout = timeout
self.redis = redis_client
self._locked = False
self.auto_renewal = auto_renewal
self._renew_thread = None
self._stop_renew = threading.Event()
def acquire(self):
start = time.time()
self.redis.rpush(self.queue_key, self.value)
while True:
first = _ensure_str(self.redis.lindex(self.queue_key, 0))
if first == self.value:
ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
if ok:
self._locked = True
if self.auto_renewal:
self._start_renewal()
return True
if first:
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
if time.time() - start > self.timeout:
self.redis.lrem(self.queue_key, 0, self.value)
return False
time.sleep(self.retry_interval)
def _renewal_loop(self):
while not self._stop_renew.is_set():
time.sleep(self.expire / 3)
if self._stop_renew.is_set():
break
self.redis.eval(
RENEW_SCRIPT,
1,
self.key,
self.value,
str(self.expire)
)
def _start_renewal(self):
self._stop_renew = threading.Event()
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
self._renew_thread.start()
def _stop_renewal(self):
self._stop_renew.set()
if self._renew_thread:
self._renew_thread.join(timeout=1)
def release(self):
if not self._locked:
return
if self.auto_renewal:
self._stop_renewal()
self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
self._locked = False
def __enter__(self):
ok = self.acquire()
if not ok:
raise RuntimeError(f"Get redis lock timeout: {self.key}")
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.release()