refactor(core): migrate task scheduler to per-user queue with dynamic sharding

This commit is contained in:
Eternity
2026-04-24 13:56:32 +08:00
parent f93ec8d609
commit be10bab763
3 changed files with 326 additions and 109 deletions

View File

@@ -1,21 +1,70 @@
import hashlib
import json
import os
import socket
import threading
import time
import uuid
import redis
from app.core.config import settings
from celery_app import celery_app
from app.core.logging_config import get_named_logger
from app.celery_app import celery_app
logger = get_named_logger("task_scheduler")
STREAM_KEY = "celery_task_stream"
# per-user queue scheduler:uq:{user_id}
USER_QUEUE_PREFIX = "scheduler:uq:"
# User Collection of Pending Messages
ACTIVE_USERS = "scheduler:active_users"
# Set of users that can dispatch (ready signal)
READY_SET = "scheduler:ready_users"
# Metadata of tasks that have been dispatched and are pending completion
PENDING_HASH = "scheduler:pending_tasks"
TASK_TIMEOUT = 7800
# Dynamic Sharding: Instance Registry
REGISTRY_KEY = "scheduler:instances"
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
INSTANCE_TTL = 30 # Instance timeout (seconds)
LUA_ATOMIC_LOCK = """
local dispatch_lock = KEYS[1]
local lock_key = KEYS[2]
local instance_id = ARGV[1]
local dispatch_ttl = tonumber(ARGV[2])
local lock_ttl = tonumber(ARGV[3])
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
return 0
end
if redis.call('EXISTS', lock_key) == 1 then
redis.call('DEL', dispatch_lock)
return -1
end
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
return 1
"""
LUA_SAFE_DELETE = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
"""
def health_check_server():
def stable_hash(value: str) -> int:
return int.from_bytes(
hashlib.md5(value.encode("utf-8")).digest(),
"big"
)
def health_check_server(scheduler_ref):
import uvicorn
from fastapi import FastAPI
@@ -23,19 +72,20 @@ def health_check_server():
@health_app.get("/")
def health():
return scheduler.health()
return scheduler_ref.health()
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
threading.Thread(
target=uvicorn.run,
kwargs={
"app": health_app,
"host": "0.0.0.0",
"port": 8001,
"log_config": None
"port": port,
"log_config": None,
},
daemon=True
daemon=True,
).start()
logger.info(f"[Health] Server started at http://0.0.0.0:8001")
logger.info("[Health] Server started at http://0.0.0.0:%s", port)
class RedisTaskScheduler:
@@ -50,27 +100,43 @@ class RedisTaskScheduler:
self.running = False
self.dispatched = 0
self.errors = 0
self._leader = False
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
self._shard_index = 0
self._shard_count = 1
self._last_heartbeat = 0.0
def push_task(self, task_name, user_id, params):
try:
msg_id = self.redis.xadd(
STREAM_KEY,
fields={
"task_name": task_name,
"user_id": user_id,
"params": json.dumps(params),
}
)
self.redis.set(
msg_id = str(uuid.uuid4())
msg = json.dumps({
"msg_id": msg_id,
"task_name": task_name,
"user_id": user_id,
"params": json.dumps(params),
})
lock_key = f"{task_name}:{user_id}"
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
pipe = self.redis.pipeline()
pipe.rpush(queue_key, msg)
pipe.sadd(ACTIVE_USERS, user_id)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "QUEUED", "task_id": None}),
ex=86400
ex=86400,
)
pipe.execute()
if not self.redis.exists(lock_key):
self.redis.sadd(READY_SET, user_id)
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
return msg_id
except Exception as e:
logger.error("Push task exception %s", e, exc_info=True)
raise e
raise
def get_task_status(self, msg_id: str) -> dict:
raw = self.redis.get(f"task_tracker:{msg_id}")
@@ -81,6 +147,7 @@ class RedisTaskScheduler:
status = tracker["status"]
task_id = tracker.get("task_id")
result_content = tracker.get("result") or {}
if status == "DISPATCHED" and task_id:
result_raw = self.redis.get(f"celery-task-meta-{task_id}")
if result_raw:
@@ -105,6 +172,7 @@ class RedisTaskScheduler:
cleanup_pipe = self.redis.pipeline()
has_cleanup = False
ready_user_ids = set()
for task_id, raw_result in zip(task_ids, results):
try:
@@ -114,12 +182,16 @@ class RedisTaskScheduler:
age = now - dispatched_at
should_cleanup = False
result_data = None
result_data = {}
if raw_result is not None:
result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
should_cleanup = True
logger.info("Task finished: %s state=%s", task_id, result_data.get("status"))
logger.info(
"Task finished: %s state=%s", task_id,
result_data.get("status"),
)
elif age > TASK_TIMEOUT:
should_cleanup = True
logger.warning(
@@ -128,9 +200,14 @@ class RedisTaskScheduler:
)
if should_cleanup:
final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
cleanup_pipe.delete(lock_key)
final_status = (
result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
)
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
cleanup_pipe.hdel(PENDING_HASH, task_id)
tracker_msg_id = meta.get("msg_id")
if tracker_msg_id:
cleanup_pipe.set(
@@ -138,146 +215,286 @@ class RedisTaskScheduler:
json.dumps({
"status": final_status,
"task_id": task_id,
"result": result_data.get("result") or {}
"result": result_data.get("result") or {},
}),
ex=86400,
)
has_cleanup = True
parts = lock_key.split(":", 1)
if len(parts) == 2:
ready_user_ids.add(parts[1])
except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
self.errors += 1
if has_cleanup:
cleanup_pipe.execute()
if ready_user_ids:
self.redis.sadd(READY_SET, *ready_user_ids)
def _heartbeat(self):
now = time.time()
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
return
self._last_heartbeat = now
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
all_instances = self.redis.hgetall(REGISTRY_KEY)
alive = []
dead = []
for iid, ts in all_instances.items():
if now - float(ts) < INSTANCE_TTL:
alive.append(iid)
else:
dead.append(iid)
if dead:
pipe = self.redis.pipeline()
for iid in dead:
pipe.hdel(REGISTRY_KEY, iid)
pipe.execute()
logger.info("Cleaned dead instances: %s", dead)
alive.sort()
self._shard_count = max(len(alive), 1)
self._shard_index = (
alive.index(self.instance_id) if self.instance_id in alive else 0
)
logger.debug(
"Shard: %s/%s (instance=%s, alive=%d)",
self._shard_index, self._shard_count,
self.instance_id, len(alive),
)
def _is_mine(self, user_id: str) -> bool:
if self._shard_count <= 1:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data['user_id']
task_name = msg_data['task_name']
params = json.loads(msg_data.get('params', "{}"))
user_id = msg_data["user_id"]
task_name = msg_data["task_name"]
params = json.loads(msg_data.get("params", "{}"))
lock_key = f"{task_name}:{user_id}"
dispatch_lock = f"dispatch:{msg_id}"
result = self.redis.eval(
LUA_ATOMIC_LOCK, 2,
dispatch_lock, lock_key,
self.instance_id, str(300), str(3600),
)
if result == 0:
return False
if result == -1:
return False
try:
task = celery_app.send_task(task_name, kwargs=params)
except Exception as e:
pipe = self.redis.pipeline()
pipe.delete(dispatch_lock)
pipe.delete(lock_key)
pipe.execute()
self.errors += 1
logger.error(
"send_task failed for %s:%s msg=%s: %s",
task_name, user_id, msg_id, e, exc_info=True,
)
return False
try:
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time(),
"msg_id": msg_id
"msg_id": msg_id,
}))
pipe.xdel(STREAM_KEY, msg_id)
pipe.delete(dispatch_lock)
pipe.set(
f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400,
)
pipe.execute()
self.dispatched += 1
logger.info("Task dispatched: %s", task.id)
return True
except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
self.errors += 1
logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True)
return False
def _leader_lock_extend(self, lock, interval=20):
while self._leader:
try:
lock.extend(60)
except redis.exceptions.LockNotOwnedError:
logger.warning("Lost leader lock during extend")
self._leader = False
except Exception as e:
logger.error("Lock extend error: %s", e)
for _ in range(interval):
if not self._leader:
break
time.sleep(1)
self.dispatched += 1
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
return True
def schedule_loop(self):
self.running = True
self._cleanup_finished()
resp = self.redis.xread(
streams={STREAM_KEY: '0-0'},
count=500,
block=5000,
)
if not resp:
def _process_batch(self, user_ids):
if not user_ids:
return
messages = []
for stream_key, msgs in resp:
messages.extend(msgs)
pipe = self.redis.pipeline()
for uid in user_ids:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
lock_keys = []
for msg_id, msg_data in messages:
lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}")
candidates = [] # (user_id, msg_dict)
empty_users = []
for uid, head in zip(user_ids, heads):
if head is None:
empty_users.append(uid)
else:
try:
candidates.append((uid, json.loads(head)))
except (json.JSONDecodeError, TypeError) as e:
logger.error("Bad message in queue for user %s: %s", uid, e)
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
if empty_users:
pipe = self.redis.pipeline()
for uid in empty_users:
pipe.srem(ACTIVE_USERS, uid)
pipe.execute()
if not candidates:
return
for uid, msg in candidates:
if self._dispatch(msg["msg_id"], msg):
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
pipe = self.redis.pipeline()
for key in lock_keys:
pipe.exists(key)
lock_exists = pipe.execute()
pipe.smembers(READY_SET)
pipe.delete(READY_SET)
results = pipe.execute()
ready_users = results[0] or set()
deliver_keys = set()
for (msg_id, msg_data), locked in zip(messages, lock_exists):
user_id = msg_data['user_id']
lock_key = f"{msg_data['task_name']}:{user_id}"
my_users = [uid for uid in ready_users if self._is_mine(uid)]
if locked or lock_key in deliver_keys:
continue
if not my_users:
time.sleep(0.5)
return
dispatched_successfully = self._dispatch(msg_id, msg_data)
if dispatched_successfully:
deliver_keys.add(lock_key)
self._process_batch(my_users)
time.sleep(0.1)
def run_server(self):
health_check_server()
def _full_scan(self):
cursor = 0
ready_batch = []
while True:
cursor, user_ids = self.redis.sscan(
ACTIVE_USERS, cursor=cursor, count=1000,
)
if user_ids:
my_users = [uid for uid in user_ids if self._is_mine(uid)]
if my_users:
pipe = self.redis.pipeline()
for uid in my_users:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
lock = self.redis.lock(
"scheduler:leader",
timeout=60,
blocking_timeout=10,
thread_local=False
for uid, head in zip(my_users, heads):
if head is None:
continue
try:
msg = json.loads(head)
lock_key = f"{msg['task_name']}:{uid}"
ready_batch.append((uid, lock_key))
except (json.JSONDecodeError, TypeError):
continue
if cursor == 0:
break
if not ready_batch:
return
pipe = self.redis.pipeline()
for _, lock_key in ready_batch:
pipe.exists(lock_key)
lock_exists = pipe.execute()
ready_uids = [
uid for (uid, _), locked in zip(ready_batch, lock_exists)
if not locked
]
if ready_uids:
self.redis.sadd(READY_SET, *ready_uids)
logger.info("Full scan found %d ready users", len(ready_uids))
def run_server(self):
health_check_server(self)
self.running = True
last_full_scan = 0.0
full_scan_interval = 30.0
logger.info(
"Scheduler started: instance=%s", self.instance_id,
)
while True:
try:
if lock.acquire(blocking=True):
self._leader = True
t = threading.Thread(
target=self._leader_lock_extend,
args=(lock, 20),
daemon=True
)
t.start()
try:
while self._leader:
self.schedule_loop()
finally:
self._leader = False
t.join(timeout=30)
try:
lock.release()
except redis.exceptions.LockNotOwnedError:
pass
self.running = False
else:
time.sleep(5)
self.schedule_loop()
now = time.time()
if now - last_full_scan > full_scan_interval:
self._full_scan()
last_full_scan = now
except Exception as e:
logger.error("Scheduler exception %s", e, exc_info=True)
self.errors += 1
time.sleep(5)
def health(self) -> dict:
return {
"running": self.running,
"pending": self.redis.xlen(STREAM_KEY),
"active_users": self.redis.scard(ACTIVE_USERS),
"ready_users": self.redis.scard(READY_SET),
"pending_tasks": self.redis.hlen(PENDING_HASH),
"dispatched": self.dispatched,
"errors": self.errors
"errors": self.errors,
"shard": f"{self._shard_index}/{self._shard_count}",
"instance": self.instance_id,
}
def shutdown(self):
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
self.running = False
try:
self.redis.hdel(REGISTRY_KEY, self.instance_id)
except Exception as e:
logger.error("Shutdown cleanup error: %s", e)
scheduler: RedisTaskScheduler | None = None
if scheduler is None:
scheduler = RedisTaskScheduler()
if __name__ == '__main__':
if __name__ == "__main__":
import signal
import sys
def _signal_handler(signum, frame):
scheduler.shutdown()
sys.exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
scheduler.run_server()