diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py index 8cdce658..e7f946b6 100644 --- a/api/app/celery_task_scheduler.py +++ b/api/app/celery_task_scheduler.py @@ -1,21 +1,70 @@ +import hashlib import json +import os +import socket import threading import time +import uuid import redis from app.core.config import settings -from celery_app import celery_app from app.core.logging_config import get_named_logger +from app.celery_app import celery_app logger = get_named_logger("task_scheduler") -STREAM_KEY = "celery_task_stream" +# per-user queue scheduler:uq:{user_id} +USER_QUEUE_PREFIX = "scheduler:uq:" +# User Collection of Pending Messages +ACTIVE_USERS = "scheduler:active_users" +# Set of users that can dispatch (ready signal) +READY_SET = "scheduler:ready_users" +# Metadata of tasks that have been dispatched and are pending completion PENDING_HASH = "scheduler:pending_tasks" -TASK_TIMEOUT = 7800 +# Dynamic Sharding: Instance Registry +REGISTRY_KEY = "scheduler:instances" + +TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded +HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds) +INSTANCE_TTL = 30 # Instance timeout (seconds) + +LUA_ATOMIC_LOCK = """ +local dispatch_lock = KEYS[1] +local lock_key = KEYS[2] +local instance_id = ARGV[1] +local dispatch_ttl = tonumber(ARGV[2]) +local lock_ttl = tonumber(ARGV[3]) + +if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then + return 0 +end + +if redis.call('EXISTS', lock_key) == 1 then + redis.call('DEL', dispatch_lock) + return -1 +end + +redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl) +return 1 +""" + +LUA_SAFE_DELETE = """ +if redis.call('GET', KEYS[1]) == ARGV[1] then + return redis.call('DEL', KEYS[1]) +end +return 0 +""" -def health_check_server(): +def stable_hash(value: str) -> int: + return int.from_bytes( + hashlib.md5(value.encode("utf-8")).digest(), + "big" + ) + + +def health_check_server(scheduler_ref): import uvicorn from fastapi import FastAPI @@ -23,19 +72,20 @@ def health_check_server(): @health_app.get("/") def health(): - return scheduler.health() + return scheduler_ref.health() + port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001")) threading.Thread( target=uvicorn.run, kwargs={ "app": health_app, "host": "0.0.0.0", - "port": 8001, - "log_config": None + "port": port, + "log_config": None, }, - daemon=True + daemon=True, ).start() - logger.info(f"[Health] Server started at http://0.0.0.0:8001") + logger.info("[Health] Server started at http://0.0.0.0:%s", port) class RedisTaskScheduler: @@ -50,27 +100,43 @@ class RedisTaskScheduler: self.running = False self.dispatched = 0 self.errors = 0 - self._leader = False + + self.instance_id = f"{socket.gethostname()}-{os.getpid()}" + self._shard_index = 0 + self._shard_count = 1 + self._last_heartbeat = 0.0 def push_task(self, task_name, user_id, params): try: - msg_id = self.redis.xadd( - STREAM_KEY, - fields={ - "task_name": task_name, - "user_id": user_id, - "params": json.dumps(params), - } - ) - self.redis.set( + msg_id = str(uuid.uuid4()) + msg = json.dumps({ + "msg_id": msg_id, + "task_name": task_name, + "user_id": user_id, + "params": json.dumps(params), + }) + + lock_key = f"{task_name}:{user_id}" + queue_key = f"{USER_QUEUE_PREFIX}{user_id}" + + pipe = self.redis.pipeline() + pipe.rpush(queue_key, msg) + pipe.sadd(ACTIVE_USERS, user_id) + pipe.set( f"task_tracker:{msg_id}", json.dumps({"status": "QUEUED", "task_id": None}), - ex=86400 + ex=86400, ) + pipe.execute() + + if not self.redis.exists(lock_key): + self.redis.sadd(READY_SET, user_id) + + logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id) return msg_id except Exception as e: logger.error("Push task exception %s", e, exc_info=True) - raise e + raise def get_task_status(self, msg_id: str) -> dict: raw = self.redis.get(f"task_tracker:{msg_id}") @@ -81,6 +147,7 @@ class RedisTaskScheduler: status = tracker["status"] task_id = tracker.get("task_id") result_content = tracker.get("result") or {} + if status == "DISPATCHED" and task_id: result_raw = self.redis.get(f"celery-task-meta-{task_id}") if result_raw: @@ -105,6 +172,7 @@ class RedisTaskScheduler: cleanup_pipe = self.redis.pipeline() has_cleanup = False + ready_user_ids = set() for task_id, raw_result in zip(task_ids, results): try: @@ -114,12 +182,16 @@ class RedisTaskScheduler: age = now - dispatched_at should_cleanup = False - result_data = None + result_data = {} + if raw_result is not None: result_data = json.loads(raw_result) if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): should_cleanup = True - logger.info("Task finished: %s state=%s", task_id, result_data.get("status")) + logger.info( + "Task finished: %s state=%s", task_id, + result_data.get("status"), + ) elif age > TASK_TIMEOUT: should_cleanup = True logger.warning( @@ -128,9 +200,14 @@ class RedisTaskScheduler: ) if should_cleanup: - final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" - cleanup_pipe.delete(lock_key) + final_status = ( + result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" + ) + + self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id) + cleanup_pipe.hdel(PENDING_HASH, task_id) + tracker_msg_id = meta.get("msg_id") if tracker_msg_id: cleanup_pipe.set( @@ -138,146 +215,286 @@ class RedisTaskScheduler: json.dumps({ "status": final_status, "task_id": task_id, - "result": result_data.get("result") or {} + "result": result_data.get("result") or {}, }), ex=86400, ) has_cleanup = True + + parts = lock_key.split(":", 1) + if len(parts) == 2: + ready_user_ids.add(parts[1]) + except Exception as e: logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) self.errors += 1 + if has_cleanup: cleanup_pipe.execute() + if ready_user_ids: + self.redis.sadd(READY_SET, *ready_user_ids) + + def _heartbeat(self): + now = time.time() + if now - self._last_heartbeat < HEARTBEAT_INTERVAL: + return + self._last_heartbeat = now + + self.redis.hset(REGISTRY_KEY, self.instance_id, str(now)) + + all_instances = self.redis.hgetall(REGISTRY_KEY) + + alive = [] + dead = [] + for iid, ts in all_instances.items(): + if now - float(ts) < INSTANCE_TTL: + alive.append(iid) + else: + dead.append(iid) + + if dead: + pipe = self.redis.pipeline() + for iid in dead: + pipe.hdel(REGISTRY_KEY, iid) + pipe.execute() + logger.info("Cleaned dead instances: %s", dead) + + alive.sort() + self._shard_count = max(len(alive), 1) + self._shard_index = ( + alive.index(self.instance_id) if self.instance_id in alive else 0 + ) + logger.debug( + "Shard: %s/%s (instance=%s, alive=%d)", + self._shard_index, self._shard_count, + self.instance_id, len(alive), + ) + + def _is_mine(self, user_id: str) -> bool: + if self._shard_count <= 1: + return True + return stable_hash(user_id) % self._shard_count == self._shard_index + def _dispatch(self, msg_id, msg_data) -> bool: - user_id = msg_data['user_id'] - task_name = msg_data['task_name'] - params = json.loads(msg_data.get('params', "{}")) + user_id = msg_data["user_id"] + task_name = msg_data["task_name"] + params = json.loads(msg_data.get("params", "{}")) lock_key = f"{task_name}:{user_id}" + dispatch_lock = f"dispatch:{msg_id}" + + result = self.redis.eval( + LUA_ATOMIC_LOCK, 2, + dispatch_lock, lock_key, + self.instance_id, str(300), str(3600), + ) + + if result == 0: + return False + if result == -1: + return False + try: task = celery_app.send_task(task_name, kwargs=params) + except Exception as e: + pipe = self.redis.pipeline() + pipe.delete(dispatch_lock) + pipe.delete(lock_key) + pipe.execute() + self.errors += 1 + logger.error( + "send_task failed for %s:%s msg=%s: %s", + task_name, user_id, msg_id, e, exc_info=True, + ) + return False + + try: pipe = self.redis.pipeline() pipe.set(lock_key, task.id, ex=3600) pipe.hset(PENDING_HASH, task.id, json.dumps({ "lock_key": lock_key, "dispatched_at": time.time(), - "msg_id": msg_id + "msg_id": msg_id, })) - pipe.xdel(STREAM_KEY, msg_id) + pipe.delete(dispatch_lock) pipe.set( f"task_tracker:{msg_id}", json.dumps({"status": "DISPATCHED", "task_id": task.id}), ex=86400, ) pipe.execute() - self.dispatched += 1 - logger.info("Task dispatched: %s", task.id) - return True except Exception as e: + logger.error( + "Post-dispatch state update failed for %s: %s", + task.id, e, exc_info=True, + ) self.errors += 1 - logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True) - return False - def _leader_lock_extend(self, lock, interval=20): - while self._leader: - try: - lock.extend(60) - except redis.exceptions.LockNotOwnedError: - logger.warning("Lost leader lock during extend") - self._leader = False - except Exception as e: - logger.error("Lock extend error: %s", e) - for _ in range(interval): - if not self._leader: - break - time.sleep(1) + self.dispatched += 1 + logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id) + return True - def schedule_loop(self): - self.running = True - self._cleanup_finished() - resp = self.redis.xread( - streams={STREAM_KEY: '0-0'}, - count=500, - block=5000, - ) - if not resp: + def _process_batch(self, user_ids): + if not user_ids: return - messages = [] - for stream_key, msgs in resp: - messages.extend(msgs) + pipe = self.redis.pipeline() + for uid in user_ids: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() - lock_keys = [] - for msg_id, msg_data in messages: - lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}") + candidates = [] # (user_id, msg_dict) + empty_users = [] + + for uid, head in zip(user_ids, heads): + if head is None: + empty_users.append(uid) + else: + try: + candidates.append((uid, json.loads(head))) + except (json.JSONDecodeError, TypeError) as e: + logger.error("Bad message in queue for user %s: %s", uid, e) + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + if empty_users: + pipe = self.redis.pipeline() + for uid in empty_users: + pipe.srem(ACTIVE_USERS, uid) + pipe.execute() + + if not candidates: + return + + for uid, msg in candidates: + if self._dispatch(msg["msg_id"], msg): + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + def schedule_loop(self): + self._heartbeat() + self._cleanup_finished() pipe = self.redis.pipeline() - for key in lock_keys: - pipe.exists(key) - lock_exists = pipe.execute() + pipe.smembers(READY_SET) + pipe.delete(READY_SET) + results = pipe.execute() + ready_users = results[0] or set() - deliver_keys = set() - for (msg_id, msg_data), locked in zip(messages, lock_exists): - user_id = msg_data['user_id'] - lock_key = f"{msg_data['task_name']}:{user_id}" + my_users = [uid for uid in ready_users if self._is_mine(uid)] - if locked or lock_key in deliver_keys: - continue + if not my_users: + time.sleep(0.5) + return - dispatched_successfully = self._dispatch(msg_id, msg_data) - if dispatched_successfully: - deliver_keys.add(lock_key) + self._process_batch(my_users) time.sleep(0.1) - def run_server(self): - health_check_server() + def _full_scan(self): + cursor = 0 + ready_batch = [] + while True: + cursor, user_ids = self.redis.sscan( + ACTIVE_USERS, cursor=cursor, count=1000, + ) + if user_ids: + my_users = [uid for uid in user_ids if self._is_mine(uid)] + if my_users: + pipe = self.redis.pipeline() + for uid in my_users: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() - lock = self.redis.lock( - "scheduler:leader", - timeout=60, - blocking_timeout=10, - thread_local=False + for uid, head in zip(my_users, heads): + if head is None: + continue + try: + msg = json.loads(head) + lock_key = f"{msg['task_name']}:{uid}" + ready_batch.append((uid, lock_key)) + except (json.JSONDecodeError, TypeError): + continue + + if cursor == 0: + break + + if not ready_batch: + return + + pipe = self.redis.pipeline() + for _, lock_key in ready_batch: + pipe.exists(lock_key) + lock_exists = pipe.execute() + + ready_uids = [ + uid for (uid, _), locked in zip(ready_batch, lock_exists) + if not locked + ] + + if ready_uids: + self.redis.sadd(READY_SET, *ready_uids) + logger.info("Full scan found %d ready users", len(ready_uids)) + + def run_server(self): + health_check_server(self) + self.running = True + + last_full_scan = 0.0 + full_scan_interval = 30.0 + + logger.info( + "Scheduler started: instance=%s", self.instance_id, ) + while True: try: - if lock.acquire(blocking=True): - self._leader = True - t = threading.Thread( - target=self._leader_lock_extend, - args=(lock, 20), - daemon=True - ) - t.start() - try: - while self._leader: - self.schedule_loop() - finally: - self._leader = False - t.join(timeout=30) - try: - lock.release() - except redis.exceptions.LockNotOwnedError: - pass - self.running = False - else: - time.sleep(5) + self.schedule_loop() + + now = time.time() + if now - last_full_scan > full_scan_interval: + self._full_scan() + last_full_scan = now + except Exception as e: logger.error("Scheduler exception %s", e, exc_info=True) + self.errors += 1 time.sleep(5) def health(self) -> dict: return { "running": self.running, - "pending": self.redis.xlen(STREAM_KEY), + "active_users": self.redis.scard(ACTIVE_USERS), + "ready_users": self.redis.scard(READY_SET), + "pending_tasks": self.redis.hlen(PENDING_HASH), "dispatched": self.dispatched, - "errors": self.errors + "errors": self.errors, + "shard": f"{self._shard_index}/{self._shard_count}", + "instance": self.instance_id, } + def shutdown(self): + logger.info("Scheduler shutting down: instance=%s", self.instance_id) + self.running = False + try: + self.redis.hdel(REGISTRY_KEY, self.instance_id) + except Exception as e: + logger.error("Shutdown cleanup error: %s", e) + scheduler: RedisTaskScheduler | None = None if scheduler is None: scheduler = RedisTaskScheduler() -if __name__ == '__main__': +if __name__ == "__main__": + import signal + import sys + + + def _signal_handler(signum, frame): + scheduler.shutdown() + sys.exit(0) + + + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + scheduler.run_server() diff --git a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py index 6b166cf2..c46e93f0 100644 --- a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py +++ b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py @@ -8,4 +8,4 @@ class RetrievalSummaryProcessor: @staticmethod def verify(content: str, llm_client: RedBearLLM): - return \ No newline at end of file + return diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a1ceef86..82d1c463 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -187,7 +187,7 @@ class MemoryAPIService: } ) - logger.info(f"Memory write task submitted, end_user_id={end_user_id}") + logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}") return { "task_id": task_id,