refactor(core): migrate task scheduler to per-user queue with dynamic sharding

This commit is contained in:
Eternity
2026-04-24 13:56:32 +08:00
parent f93ec8d609
commit be10bab763
3 changed files with 326 additions and 109 deletions

View File

@@ -1,21 +1,70 @@
import hashlib
import json import json
import os
import socket
import threading import threading
import time import time
import uuid
import redis import redis
from app.core.config import settings from app.core.config import settings
from celery_app import celery_app
from app.core.logging_config import get_named_logger from app.core.logging_config import get_named_logger
from app.celery_app import celery_app
logger = get_named_logger("task_scheduler") logger = get_named_logger("task_scheduler")
STREAM_KEY = "celery_task_stream" # per-user queue scheduler:uq:{user_id}
USER_QUEUE_PREFIX = "scheduler:uq:"
# User Collection of Pending Messages
ACTIVE_USERS = "scheduler:active_users"
# Set of users that can dispatch (ready signal)
READY_SET = "scheduler:ready_users"
# Metadata of tasks that have been dispatched and are pending completion
PENDING_HASH = "scheduler:pending_tasks" PENDING_HASH = "scheduler:pending_tasks"
TASK_TIMEOUT = 7800 # Dynamic Sharding: Instance Registry
REGISTRY_KEY = "scheduler:instances"
TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded
HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds)
INSTANCE_TTL = 30 # Instance timeout (seconds)
LUA_ATOMIC_LOCK = """
local dispatch_lock = KEYS[1]
local lock_key = KEYS[2]
local instance_id = ARGV[1]
local dispatch_ttl = tonumber(ARGV[2])
local lock_ttl = tonumber(ARGV[3])
if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then
return 0
end
if redis.call('EXISTS', lock_key) == 1 then
redis.call('DEL', dispatch_lock)
return -1
end
redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl)
return 1
"""
LUA_SAFE_DELETE = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
"""
def health_check_server(): def stable_hash(value: str) -> int:
return int.from_bytes(
hashlib.md5(value.encode("utf-8")).digest(),
"big"
)
def health_check_server(scheduler_ref):
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
@@ -23,19 +72,20 @@ def health_check_server():
@health_app.get("/") @health_app.get("/")
def health(): def health():
return scheduler.health() return scheduler_ref.health()
port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001"))
threading.Thread( threading.Thread(
target=uvicorn.run, target=uvicorn.run,
kwargs={ kwargs={
"app": health_app, "app": health_app,
"host": "0.0.0.0", "host": "0.0.0.0",
"port": 8001, "port": port,
"log_config": None "log_config": None,
}, },
daemon=True daemon=True,
).start() ).start()
logger.info(f"[Health] Server started at http://0.0.0.0:8001") logger.info("[Health] Server started at http://0.0.0.0:%s", port)
class RedisTaskScheduler: class RedisTaskScheduler:
@@ -50,27 +100,43 @@ class RedisTaskScheduler:
self.running = False self.running = False
self.dispatched = 0 self.dispatched = 0
self.errors = 0 self.errors = 0
self._leader = False
self.instance_id = f"{socket.gethostname()}-{os.getpid()}"
self._shard_index = 0
self._shard_count = 1
self._last_heartbeat = 0.0
def push_task(self, task_name, user_id, params): def push_task(self, task_name, user_id, params):
try: try:
msg_id = self.redis.xadd( msg_id = str(uuid.uuid4())
STREAM_KEY, msg = json.dumps({
fields={ "msg_id": msg_id,
"task_name": task_name, "task_name": task_name,
"user_id": user_id, "user_id": user_id,
"params": json.dumps(params), "params": json.dumps(params),
} })
)
self.redis.set( lock_key = f"{task_name}:{user_id}"
queue_key = f"{USER_QUEUE_PREFIX}{user_id}"
pipe = self.redis.pipeline()
pipe.rpush(queue_key, msg)
pipe.sadd(ACTIVE_USERS, user_id)
pipe.set(
f"task_tracker:{msg_id}", f"task_tracker:{msg_id}",
json.dumps({"status": "QUEUED", "task_id": None}), json.dumps({"status": "QUEUED", "task_id": None}),
ex=86400 ex=86400,
) )
pipe.execute()
if not self.redis.exists(lock_key):
self.redis.sadd(READY_SET, user_id)
logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id)
return msg_id return msg_id
except Exception as e: except Exception as e:
logger.error("Push task exception %s", e, exc_info=True) logger.error("Push task exception %s", e, exc_info=True)
raise e raise
def get_task_status(self, msg_id: str) -> dict: def get_task_status(self, msg_id: str) -> dict:
raw = self.redis.get(f"task_tracker:{msg_id}") raw = self.redis.get(f"task_tracker:{msg_id}")
@@ -81,6 +147,7 @@ class RedisTaskScheduler:
status = tracker["status"] status = tracker["status"]
task_id = tracker.get("task_id") task_id = tracker.get("task_id")
result_content = tracker.get("result") or {} result_content = tracker.get("result") or {}
if status == "DISPATCHED" and task_id: if status == "DISPATCHED" and task_id:
result_raw = self.redis.get(f"celery-task-meta-{task_id}") result_raw = self.redis.get(f"celery-task-meta-{task_id}")
if result_raw: if result_raw:
@@ -105,6 +172,7 @@ class RedisTaskScheduler:
cleanup_pipe = self.redis.pipeline() cleanup_pipe = self.redis.pipeline()
has_cleanup = False has_cleanup = False
ready_user_ids = set()
for task_id, raw_result in zip(task_ids, results): for task_id, raw_result in zip(task_ids, results):
try: try:
@@ -114,12 +182,16 @@ class RedisTaskScheduler:
age = now - dispatched_at age = now - dispatched_at
should_cleanup = False should_cleanup = False
result_data = None result_data = {}
if raw_result is not None: if raw_result is not None:
result_data = json.loads(raw_result) result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
should_cleanup = True should_cleanup = True
logger.info("Task finished: %s state=%s", task_id, result_data.get("status")) logger.info(
"Task finished: %s state=%s", task_id,
result_data.get("status"),
)
elif age > TASK_TIMEOUT: elif age > TASK_TIMEOUT:
should_cleanup = True should_cleanup = True
logger.warning( logger.warning(
@@ -128,9 +200,14 @@ class RedisTaskScheduler:
) )
if should_cleanup: if should_cleanup:
final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" final_status = (
cleanup_pipe.delete(lock_key) result_data.get("status", "UNKNOWN") if result_data else "EXPIRED"
)
self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id)
cleanup_pipe.hdel(PENDING_HASH, task_id) cleanup_pipe.hdel(PENDING_HASH, task_id)
tracker_msg_id = meta.get("msg_id") tracker_msg_id = meta.get("msg_id")
if tracker_msg_id: if tracker_msg_id:
cleanup_pipe.set( cleanup_pipe.set(
@@ -138,146 +215,286 @@ class RedisTaskScheduler:
json.dumps({ json.dumps({
"status": final_status, "status": final_status,
"task_id": task_id, "task_id": task_id,
"result": result_data.get("result") or {} "result": result_data.get("result") or {},
}), }),
ex=86400, ex=86400,
) )
has_cleanup = True has_cleanup = True
parts = lock_key.split(":", 1)
if len(parts) == 2:
ready_user_ids.add(parts[1])
except Exception as e: except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
self.errors += 1 self.errors += 1
if has_cleanup: if has_cleanup:
cleanup_pipe.execute() cleanup_pipe.execute()
if ready_user_ids:
self.redis.sadd(READY_SET, *ready_user_ids)
def _heartbeat(self):
now = time.time()
if now - self._last_heartbeat < HEARTBEAT_INTERVAL:
return
self._last_heartbeat = now
self.redis.hset(REGISTRY_KEY, self.instance_id, str(now))
all_instances = self.redis.hgetall(REGISTRY_KEY)
alive = []
dead = []
for iid, ts in all_instances.items():
if now - float(ts) < INSTANCE_TTL:
alive.append(iid)
else:
dead.append(iid)
if dead:
pipe = self.redis.pipeline()
for iid in dead:
pipe.hdel(REGISTRY_KEY, iid)
pipe.execute()
logger.info("Cleaned dead instances: %s", dead)
alive.sort()
self._shard_count = max(len(alive), 1)
self._shard_index = (
alive.index(self.instance_id) if self.instance_id in alive else 0
)
logger.debug(
"Shard: %s/%s (instance=%s, alive=%d)",
self._shard_index, self._shard_count,
self.instance_id, len(alive),
)
def _is_mine(self, user_id: str) -> bool:
if self._shard_count <= 1:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
def _dispatch(self, msg_id, msg_data) -> bool: def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data['user_id'] user_id = msg_data["user_id"]
task_name = msg_data['task_name'] task_name = msg_data["task_name"]
params = json.loads(msg_data.get('params', "{}")) params = json.loads(msg_data.get("params", "{}"))
lock_key = f"{task_name}:{user_id}" lock_key = f"{task_name}:{user_id}"
dispatch_lock = f"dispatch:{msg_id}"
result = self.redis.eval(
LUA_ATOMIC_LOCK, 2,
dispatch_lock, lock_key,
self.instance_id, str(300), str(3600),
)
if result == 0:
return False
if result == -1:
return False
try: try:
task = celery_app.send_task(task_name, kwargs=params) task = celery_app.send_task(task_name, kwargs=params)
except Exception as e:
pipe = self.redis.pipeline()
pipe.delete(dispatch_lock)
pipe.delete(lock_key)
pipe.execute()
self.errors += 1
logger.error(
"send_task failed for %s:%s msg=%s: %s",
task_name, user_id, msg_id, e, exc_info=True,
)
return False
try:
pipe = self.redis.pipeline() pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600) pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({ pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key, "lock_key": lock_key,
"dispatched_at": time.time(), "dispatched_at": time.time(),
"msg_id": msg_id "msg_id": msg_id,
})) }))
pipe.xdel(STREAM_KEY, msg_id) pipe.delete(dispatch_lock)
pipe.set( pipe.set(
f"task_tracker:{msg_id}", f"task_tracker:{msg_id}",
json.dumps({"status": "DISPATCHED", "task_id": task.id}), json.dumps({"status": "DISPATCHED", "task_id": task.id}),
ex=86400, ex=86400,
) )
pipe.execute() pipe.execute()
self.dispatched += 1
logger.info("Task dispatched: %s", task.id)
return True
except Exception as e: except Exception as e:
logger.error(
"Post-dispatch state update failed for %s: %s",
task.id, e, exc_info=True,
)
self.errors += 1 self.errors += 1
logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True)
return False
def _leader_lock_extend(self, lock, interval=20): self.dispatched += 1
while self._leader: logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
try: return True
lock.extend(60)
except redis.exceptions.LockNotOwnedError:
logger.warning("Lost leader lock during extend")
self._leader = False
except Exception as e:
logger.error("Lock extend error: %s", e)
for _ in range(interval):
if not self._leader:
break
time.sleep(1)
def schedule_loop(self): def _process_batch(self, user_ids):
self.running = True if not user_ids:
self._cleanup_finished()
resp = self.redis.xread(
streams={STREAM_KEY: '0-0'},
count=500,
block=5000,
)
if not resp:
return return
messages = [] pipe = self.redis.pipeline()
for stream_key, msgs in resp: for uid in user_ids:
messages.extend(msgs) pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
lock_keys = [] candidates = [] # (user_id, msg_dict)
for msg_id, msg_data in messages: empty_users = []
lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}")
for uid, head in zip(user_ids, heads):
if head is None:
empty_users.append(uid)
else:
try:
candidates.append((uid, json.loads(head)))
except (json.JSONDecodeError, TypeError) as e:
logger.error("Bad message in queue for user %s: %s", uid, e)
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
if empty_users:
pipe = self.redis.pipeline()
for uid in empty_users:
pipe.srem(ACTIVE_USERS, uid)
pipe.execute()
if not candidates:
return
for uid, msg in candidates:
if self._dispatch(msg["msg_id"], msg):
self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
pipe = self.redis.pipeline() pipe = self.redis.pipeline()
for key in lock_keys: pipe.smembers(READY_SET)
pipe.exists(key) pipe.delete(READY_SET)
lock_exists = pipe.execute() results = pipe.execute()
ready_users = results[0] or set()
deliver_keys = set() my_users = [uid for uid in ready_users if self._is_mine(uid)]
for (msg_id, msg_data), locked in zip(messages, lock_exists):
user_id = msg_data['user_id']
lock_key = f"{msg_data['task_name']}:{user_id}"
if locked or lock_key in deliver_keys: if not my_users:
continue time.sleep(0.5)
return
dispatched_successfully = self._dispatch(msg_id, msg_data) self._process_batch(my_users)
if dispatched_successfully:
deliver_keys.add(lock_key)
time.sleep(0.1) time.sleep(0.1)
def run_server(self): def _full_scan(self):
health_check_server() cursor = 0
ready_batch = []
while True:
cursor, user_ids = self.redis.sscan(
ACTIVE_USERS, cursor=cursor, count=1000,
)
if user_ids:
my_users = [uid for uid in user_ids if self._is_mine(uid)]
if my_users:
pipe = self.redis.pipeline()
for uid in my_users:
pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0)
heads = pipe.execute()
lock = self.redis.lock( for uid, head in zip(my_users, heads):
"scheduler:leader", if head is None:
timeout=60, continue
blocking_timeout=10, try:
thread_local=False msg = json.loads(head)
lock_key = f"{msg['task_name']}:{uid}"
ready_batch.append((uid, lock_key))
except (json.JSONDecodeError, TypeError):
continue
if cursor == 0:
break
if not ready_batch:
return
pipe = self.redis.pipeline()
for _, lock_key in ready_batch:
pipe.exists(lock_key)
lock_exists = pipe.execute()
ready_uids = [
uid for (uid, _), locked in zip(ready_batch, lock_exists)
if not locked
]
if ready_uids:
self.redis.sadd(READY_SET, *ready_uids)
logger.info("Full scan found %d ready users", len(ready_uids))
def run_server(self):
health_check_server(self)
self.running = True
last_full_scan = 0.0
full_scan_interval = 30.0
logger.info(
"Scheduler started: instance=%s", self.instance_id,
) )
while True: while True:
try: try:
if lock.acquire(blocking=True): self.schedule_loop()
self._leader = True
t = threading.Thread( now = time.time()
target=self._leader_lock_extend, if now - last_full_scan > full_scan_interval:
args=(lock, 20), self._full_scan()
daemon=True last_full_scan = now
)
t.start()
try:
while self._leader:
self.schedule_loop()
finally:
self._leader = False
t.join(timeout=30)
try:
lock.release()
except redis.exceptions.LockNotOwnedError:
pass
self.running = False
else:
time.sleep(5)
except Exception as e: except Exception as e:
logger.error("Scheduler exception %s", e, exc_info=True) logger.error("Scheduler exception %s", e, exc_info=True)
self.errors += 1
time.sleep(5) time.sleep(5)
def health(self) -> dict: def health(self) -> dict:
return { return {
"running": self.running, "running": self.running,
"pending": self.redis.xlen(STREAM_KEY), "active_users": self.redis.scard(ACTIVE_USERS),
"ready_users": self.redis.scard(READY_SET),
"pending_tasks": self.redis.hlen(PENDING_HASH),
"dispatched": self.dispatched, "dispatched": self.dispatched,
"errors": self.errors "errors": self.errors,
"shard": f"{self._shard_index}/{self._shard_count}",
"instance": self.instance_id,
} }
def shutdown(self):
logger.info("Scheduler shutting down: instance=%s", self.instance_id)
self.running = False
try:
self.redis.hdel(REGISTRY_KEY, self.instance_id)
except Exception as e:
logger.error("Shutdown cleanup error: %s", e)
scheduler: RedisTaskScheduler | None = None scheduler: RedisTaskScheduler | None = None
if scheduler is None: if scheduler is None:
scheduler = RedisTaskScheduler() scheduler = RedisTaskScheduler()
if __name__ == '__main__': if __name__ == "__main__":
import signal
import sys
def _signal_handler(signum, frame):
scheduler.shutdown()
sys.exit(0)
signal.signal(signal.SIGTERM, _signal_handler)
signal.signal(signal.SIGINT, _signal_handler)
scheduler.run_server() scheduler.run_server()

View File

@@ -8,4 +8,4 @@ class RetrievalSummaryProcessor:
@staticmethod @staticmethod
def verify(content: str, llm_client: RedBearLLM): def verify(content: str, llm_client: RedBearLLM):
return return

View File

@@ -187,7 +187,7 @@ class MemoryAPIService:
} }
) )
logger.info(f"Memory write task submitted, end_user_id={end_user_id}") logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}")
return { return {
"task_id": task_id, "task_id": task_id,