Files
MemoryBear/api/app/celery_task_scheduler.py
Eternity f93ec8d609 fix(core): fix end_user_id reference and add task status tracking
- Fix write_router to use actual_end_user_id instead of end_user_id
- Add task status tracking via Redis in scheduler
- Expose task_id in memory write response
- Fix logging import path in scheduler
2026-04-22 18:06:14 +08:00

284 lines
9.1 KiB
Python

import json
import threading
import time
import redis
from app.core.config import settings
from celery_app import celery_app
from app.core.logging_config import get_named_logger
# Named logger used by every function and method in this module.
logger = get_named_logger("task_scheduler")
# Redis stream that push_task() appends to and schedule_loop() reads from.
STREAM_KEY = "celery_task_stream"
# Redis hash of dispatched-but-unfinished tasks: Celery task id -> JSON metadata
# (lock_key, dispatched_at, msg_id) written by _dispatch().
PENDING_HASH = "scheduler:pending_tasks"
# Age in seconds (130 minutes) after which a pending task that has no stored
# Celery result is force-cleaned by _cleanup_finished().
TASK_TIMEOUT = 7800
def health_check_server(host="0.0.0.0", port=8001):
    """Start a background HTTP server that reports scheduler health.

    A minimal FastAPI app exposing ``GET /`` is served by uvicorn on a daemon
    thread, so this call returns immediately and the server thread does not
    keep the process alive on its own.

    Args:
        host: Interface to bind. Defaults to all interfaces.
        port: TCP port to listen on. Defaults to 8001.
    """
    # Imported inside the function, so the web stack is only needed when the
    # health server is actually started.
    import uvicorn
    from fastapi import FastAPI

    health_app = FastAPI()

    @health_app.get("/")
    def health():
        # `scheduler` is the module-level instance; it is resolved per request.
        return scheduler.health()

    threading.Thread(
        target=uvicorn.run,
        kwargs={
            "app": health_app,
            "host": host,
            "port": port,
            "log_config": None,
        },
        daemon=True,
    ).start()
    # The thread has only been started here; uvicorn binds the socket on its
    # own thread shortly afterwards.
    logger.info("[Health] Server started at http://%s:%s", host, port)
class RedisTaskScheduler:
    """Redis-stream-backed scheduler that hands queued tasks to Celery.

    Producers call :meth:`push_task` to append a task to the ``STREAM_KEY``
    stream. The process that holds the ``scheduler:leader`` lock repeatedly
    runs :meth:`schedule_loop`, which dispatches at most one task per
    ``"<task_name>:<user_id>"`` lock key at a time and records progress under
    ``task_tracker:<msg_id>`` so callers can poll :meth:`get_task_status`.
    """

    # Seconds a ``task_tracker:*`` status record is kept in Redis.
    TRACKER_TTL = 86400
    # Seconds a per-(task, user) lock lives if it is never deleted explicitly.
    TASK_LOCK_TTL = 3600
    # Seconds the leader lock lives without being refreshed.
    LEADER_LOCK_TTL = 60

    def __init__(self):
        """Create the Redis client and reset the in-process counters."""
        self.redis = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB_CELERY_BACKEND,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
        )
        self.running = False   # True while this process is executing schedule_loop()
        self.dispatched = 0    # tasks successfully handed to Celery by this process
        self.errors = 0        # dispatch / cleanup failures seen by this process
        self._leader = False   # True while this process holds the leader lock

    def push_task(self, task_name, user_id, params):
        """Queue a task on the stream and create its status tracker.

        Args:
            task_name: Celery task name later passed to ``send_task``.
            user_id: Identifier combined with ``task_name`` to form the lock key.
            params: JSON-serialisable keyword arguments for the task.

        Returns:
            The stream entry id assigned by Redis; pass it to
            :meth:`get_task_status` to poll progress.

        Raises:
            Exception: Whatever the Redis client or ``json.dumps`` raised; the
                error is logged and re-raised unchanged.
        """
        try:
            msg_id = self.redis.xadd(
                STREAM_KEY,
                fields={
                    "task_name": task_name,
                    "user_id": user_id,
                    "params": json.dumps(params),
                },
            )
            self.redis.set(
                f"task_tracker:{msg_id}",
                json.dumps({"status": "QUEUED", "task_id": None}),
                ex=self.TRACKER_TTL,
            )
            return msg_id
        except Exception as e:
            logger.error("Push task exception %s", e, exc_info=True)
            # Bare raise keeps the original traceback intact.
            raise

    def get_task_status(self, msg_id: str) -> dict:
        """Return the current status of a task queued via :meth:`push_task`.

        Args:
            msg_id: Stream entry id returned by :meth:`push_task`.

        Returns:
            ``{"status": "NOT_FOUND"}`` when no tracker exists, otherwise a dict
            with ``status``, ``task_id`` and ``result``. While the tracker says
            ``DISPATCHED``, status and result are taken from the Celery result
            record (``celery-task-meta-<task_id>``) when one is present.
        """
        raw = self.redis.get(f"task_tracker:{msg_id}")
        if raw is None:
            return {"status": "NOT_FOUND"}

        tracker = json.loads(raw)
        status = tracker["status"]
        task_id = tracker.get("task_id")
        result_content = tracker.get("result") or {}

        if status == "DISPATCHED" and task_id:
            result_raw = self.redis.get(f"celery-task-meta-{task_id}")
            if result_raw:
                result_data = json.loads(result_raw)
                status = result_data.get("status", status)
                result_content = result_data.get("result")

        return {"status": status, "task_id": task_id, "result": result_content}

    def _cleanup_finished(self):
        """Release locks and finalise trackers for finished or expired tasks.

        Every entry of ``PENDING_HASH`` is checked against its Celery result
        record. An entry is cleaned up when the record reports ``SUCCESS``,
        ``FAILURE`` or ``REVOKED``, or when no record exists and the task was
        dispatched more than ``TASK_TIMEOUT`` seconds ago. Cleanup deletes the
        lock key, removes the hash entry and rewrites the tracker with the
        final status (``EXPIRED`` for timed-out tasks).
        """
        pending = self.redis.hgetall(PENDING_HASH)
        if not pending:
            return

        now = time.time()
        task_ids = list(pending)

        # Fetch all Celery result records in one round trip.
        pipe = self.redis.pipeline()
        for task_id in task_ids:
            pipe.get(f"celery-task-meta-{task_id}")
        results = pipe.execute()

        cleanup_pipe = self.redis.pipeline()
        has_cleanup = False

        for task_id, raw_result in zip(task_ids, results):
            try:
                meta = json.loads(pending[task_id])
                lock_key = meta["lock_key"]
                age = now - meta.get("dispatched_at", 0)

                result_data = None
                should_cleanup = False
                if raw_result is not None:
                    result_data = json.loads(raw_result)
                    if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
                        should_cleanup = True
                        logger.info(
                            "Task finished: %s state=%s",
                            task_id, result_data.get("status"),
                        )
                elif age > TASK_TIMEOUT:
                    should_cleanup = True
                    logger.warning(
                        "Task expired or lost: %s age=%.0fs, force cleanup",
                        task_id, age,
                    )

                if not should_cleanup:
                    continue

                # An expired task has no result record at all, so result_data
                # is None here and must not be dereferenced.
                if result_data is None:
                    final_status = "EXPIRED"
                    final_result = {}
                else:
                    final_status = result_data.get("status", "UNKNOWN")
                    final_result = result_data.get("result") or {}

                # Build the tracker payload before queueing anything, so a
                # failure for this task cannot leave half of its commands in
                # the shared pipeline.
                tracker_msg_id = meta.get("msg_id")
                tracker_payload = None
                if tracker_msg_id:
                    tracker_payload = json.dumps({
                        "status": final_status,
                        "task_id": task_id,
                        "result": final_result,
                    })

                cleanup_pipe.delete(lock_key)
                cleanup_pipe.hdel(PENDING_HASH, task_id)
                if tracker_payload is not None:
                    cleanup_pipe.set(
                        f"task_tracker:{tracker_msg_id}",
                        tracker_payload,
                        ex=self.TRACKER_TTL,
                    )
                has_cleanup = True
            except Exception as e:
                logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
                self.errors += 1

        if has_cleanup:
            cleanup_pipe.execute()

    def _dispatch(self, msg_id, msg_data) -> bool:
        """Send one stream entry to Celery and record the bookkeeping.

        On success a single pipeline sets the per-(task, user) lock, registers
        the task in ``PENDING_HASH``, deletes the stream entry and marks the
        tracker ``DISPATCHED``.

        Args:
            msg_id: Stream entry id of the message.
            msg_data: Entry fields; must contain ``task_name`` and ``user_id``,
                and may contain ``params`` as a JSON object string.

        Returns:
            True when the task was sent and recorded, False when sending or
            recording raised (the error is logged and counted).
        """
        user_id = msg_data['user_id']
        task_name = msg_data['task_name']
        params = json.loads(msg_data.get('params', "{}"))
        lock_key = f"{task_name}:{user_id}"

        try:
            task = celery_app.send_task(task_name, kwargs=params)

            pipe = self.redis.pipeline()
            pipe.set(lock_key, task.id, ex=self.TASK_LOCK_TTL)
            pipe.hset(PENDING_HASH, task.id, json.dumps({
                "lock_key": lock_key,
                "dispatched_at": time.time(),
                "msg_id": msg_id,
            }))
            pipe.xdel(STREAM_KEY, msg_id)
            pipe.set(
                f"task_tracker:{msg_id}",
                json.dumps({"status": "DISPATCHED", "task_id": task.id}),
                ex=self.TRACKER_TTL,
            )
            pipe.execute()

            self.dispatched += 1
            logger.info("Task dispatched: %s", task.id)
            return True
        except Exception as e:
            self.errors += 1
            logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True)
            return False

    def _leader_lock_extend(self, lock, interval=20):
        """Keep the leader lock alive while this process is the leader.

        Runs on a helper thread. Every ``interval`` seconds the lock's TTL is
        reset to ``LEADER_LOCK_TTL``; the loop stops as soon as ``self._leader``
        turns False, or clears that flag itself when the lock is found to be
        owned by someone else.

        Args:
            lock: The acquired redis-py lock object to refresh.
            interval: Seconds between refreshes.
        """
        while self._leader:
            try:
                # replace_ttl=True resets the TTL instead of adding to it.
                # Adding 60s every 20s would let the TTL grow without bound
                # and delay failover long after a leader crash.
                lock.extend(self.LEADER_LOCK_TTL, replace_ttl=True)
            except redis.exceptions.LockNotOwnedError:
                logger.warning("Lost leader lock during extend")
                self._leader = False
            except Exception as e:
                logger.error("Lock extend error: %s", e)

            # Sleep in one-second slices so a leadership change is noticed
            # quickly instead of after a full interval.
            for _ in range(interval):
                if not self._leader:
                    break
                time.sleep(1)

    def schedule_loop(self):
        """Run one scheduling pass: clean up, read the stream, dispatch.

        Up to 500 entries are read from the start of the stream (waiting up to
        5000 ms when it is empty). Entries are only removed on successful
        dispatch, so undelivered ones are seen again on the next pass. An entry
        is skipped when its lock key already exists in Redis or when another
        entry with the same key was dispatched earlier in this pass.
        """
        self.running = True
        self._cleanup_finished()

        resp = self.redis.xread(
            streams={STREAM_KEY: '0-0'},
            count=500,
            block=5000,
        )
        if not resp:
            return

        messages = []
        for _stream, msgs in resp:
            messages.extend(msgs)

        lock_keys = [
            f"{msg_data['task_name']}:{msg_data['user_id']}"
            for _msg_id, msg_data in messages
        ]

        # Check every lock key in one round trip.
        pipe = self.redis.pipeline()
        for key in lock_keys:
            pipe.exists(key)
        lock_exists = pipe.execute()

        deliver_keys = set()
        for (msg_id, msg_data), lock_key, locked in zip(messages, lock_keys, lock_exists):
            if locked or lock_key in deliver_keys:
                continue
            if self._dispatch(msg_id, msg_data):
                deliver_keys.add(lock_key)

        time.sleep(0.1)

    def run_server(self):
        """Run the scheduler forever, dispatching only while holding leadership.

        Starts the health endpoint, then loops: try to acquire the
        ``scheduler:leader`` lock; on success start the refresher thread and
        call :meth:`schedule_loop` until leadership is lost or an error occurs;
        then stop the refresher, release the lock and try again. Any exception
        is logged and followed by a 5 second pause. This method never returns.
        """
        health_check_server()

        lock = self.redis.lock(
            "scheduler:leader",
            timeout=self.LEADER_LOCK_TTL,
            blocking_timeout=10,
            # The lock is acquired on this thread but refreshed from the
            # helper thread, so its token must not be thread-local.
            thread_local=False,
        )

        while True:
            try:
                if lock.acquire(blocking=True):
                    self._leader = True
                    t = threading.Thread(
                        target=self._leader_lock_extend,
                        args=(lock, 20),
                        daemon=True,
                    )
                    t.start()
                    try:
                        while self._leader:
                            self.schedule_loop()
                    finally:
                        # Clearing the flag is what makes the refresher exit.
                        self._leader = False
                        t.join(timeout=30)
                        try:
                            lock.release()
                        except redis.exceptions.LockNotOwnedError:
                            # The lock already expired or was taken over;
                            # there is nothing left to release.
                            pass
                        self.running = False
                else:
                    time.sleep(5)
            except Exception as e:
                logger.error("Scheduler exception %s", e, exc_info=True)
                time.sleep(5)

    def health(self) -> dict:
        """Return a snapshot of scheduler state for the health endpoint.

        Returns:
            A dict with ``running``, ``pending`` (current length of the task
            stream), ``dispatched`` and ``errors``.
        """
        return {
            "running": self.running,
            "pending": self.redis.xlen(STREAM_KEY),
            "dispatched": self.dispatched,
            "errors": self.errors,
        }
# Module-level scheduler instance, created once when the module is imported;
# the health endpoint handler reads it by this name.
scheduler: RedisTaskScheduler | None = RedisTaskScheduler()

if __name__ == '__main__':
    # Executed as a script: run the leader-election and dispatch loop here.
    scheduler.run_server()