From c5ae82c3c2736b67a3ecec2a7c587820d62c3d2c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 22 Apr 2026 16:50:06 +0800 Subject: [PATCH 1/3] refactor(core): migrate memory write tasks to centralized scheduler --- api/app/celery_app.py | 7 +- api/app/celery_task_scheduler.py | 241 ++++++++++++++++++ .../service/memory_api_controller.py | 2 +- .../langgraph_graph/routing/write_router.py | 56 ++-- api/app/core/memory/pipelines/memory_read.py | 4 +- .../read_services/generate_engine/__init__.py | 0 .../query_preprocessor.py | 0 .../retrieval_summary.py | 0 .../read_services/search_engine/__init__.py | 0 .../{ => search_engine}/content_search.py | 2 +- .../{ => search_engine}/result_builder.py | 0 api/app/core/workflow/nodes/memory/node.py | 25 +- api/app/services/memory_api_service.py | 28 +- api/app/tasks.py | 30 +-- api/docker-compose.yml | 17 ++ 15 files changed, 358 insertions(+), 54 deletions(-) create mode 100644 api/app/celery_task_scheduler.py create mode 100644 api/app/core/memory/read_services/generate_engine/__init__.py rename api/app/core/memory/read_services/{ => generate_engine}/query_preprocessor.py (100%) rename api/app/core/memory/read_services/{ => generate_engine}/retrieval_summary.py (100%) create mode 100644 api/app/core/memory/read_services/search_engine/__init__.py rename api/app/core/memory/read_services/{ => search_engine}/content_search.py (99%) rename api/app/core/memory/read_services/{ => search_engine}/result_builder.py (100%) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index b0894eb8..717709da 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -17,6 +17,7 @@ def _mask_url(url: str) -> str: """隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议""" return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url) + # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') @@ -29,7 +30,7 @@ if 
platform.system() == 'Darwin': # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md _broker_url = os.getenv("CELERY_BROKER_URL") or \ - f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" + f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url @@ -66,11 +67,11 @@ celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', - + # # 时区 # timezone='Asia/Shanghai', # enable_utc=False, - + # 任务追踪 task_track_started=True, task_ignore_result=False, diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py new file mode 100644 index 00000000..88afe341 --- /dev/null +++ b/api/app/celery_task_scheduler.py @@ -0,0 +1,241 @@ +import json +import threading +import time + +import redis + +from app.core.config import settings +from celery_app import celery_app +from core.logging_config import get_named_logger + +logger = get_named_logger("task_scheduler") + +STREAM_KEY = "celery_task_stream" +PENDING_HASH = "scheduler:pending_tasks" +TASK_TIMEOUT = 7800 + + +def health_check_server(): + import uvicorn + from fastapi import FastAPI + + health_app = FastAPI() + + @health_app.get("/") + def health(): + return scheduler.health() + + threading.Thread( + target=uvicorn.run, + kwargs={ + "app": health_app, + "host": "0.0.0.0", + "port": 8001, + "log_config": None + }, + daemon=True + ).start() + logger.info(f"[Health] Server started at http://0.0.0.0:8001") + + +class RedisTaskScheduler: + def __init__(self): + self.redis = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB_CELERY_BACKEND, + 
password=settings.REDIS_PASSWORD, + decode_responses=True, + ) + self.running = False + self.dispatched = 0 + self.errors = 0 + self._leader = False + + def push_task(self, task_name, user_id, params): + try: + msg_id = self.redis.xadd( + STREAM_KEY, + fields={ + "task_name": task_name, + "user_id": user_id, + "params": json.dumps(params), + } + ) + return msg_id + except Exception as e: + logger.error("Push task exception %s", e, exc_info=True) + raise e + + def _cleanup_finished(self): + pending = self.redis.hgetall(PENDING_HASH) + if not pending: + return + + now = time.time() + task_ids = list(pending.keys()) + + pipe = self.redis.pipeline() + for task_id in task_ids: + pipe.get(f"celery-task-meta-{task_id}") + results = pipe.execute() + + cleanup_pipe = self.redis.pipeline() + has_cleanup = False + + for task_id, raw_result in zip(task_ids, results): + try: + meta = json.loads(pending[task_id]) + lock_key = meta["lock_key"] + dispatched_at = meta.get("dispatched_at", 0) + age = now - dispatched_at + + should_cleanup = False + if raw_result is not None: + result_data = json.loads(raw_result) + if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): + should_cleanup = True + logger.info("Task finished: %s state=%s", task_id, result_data.get("status")) + elif age > TASK_TIMEOUT: + should_cleanup = True + logger.warning( + "Task expired or lost: %s age=%.0fs, force cleanup", + task_id, age, + ) + + if should_cleanup: + cleanup_pipe.delete(lock_key) + cleanup_pipe.hdel(PENDING_HASH, task_id) + has_cleanup = True + except Exception as e: + logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) + self.errors += 1 + if has_cleanup: + cleanup_pipe.execute() + + def _dispatch(self, msg_id, msg_data) -> bool: + user_id = msg_data['user_id'] + task_name = msg_data['task_name'] + params = json.loads(msg_data.get('params', "{}")) + + lock_key = f"{task_name}:{user_id}" + try: + task = celery_app.send_task(task_name, kwargs=params) + pipe = 
self.redis.pipeline() + pipe.set(lock_key, task.id, ex=3600) + pipe.hset(PENDING_HASH, task.id, json.dumps({ + "lock_key": lock_key, + "dispatched_at": time.time() + })) + pipe.xdel(STREAM_KEY, msg_id) + pipe.execute() + self.dispatched += 1 + logger.info("Task dispatched: %s", task.id) + return True + except Exception as e: + self.errors += 1 + logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True) + return False + + def _leader_lock_extend(self, lock, interval=20): + while self._leader: + try: + lock.extend(60) + except redis.exceptions.LockNotOwnedError: + logger.warning("Lost leader lock during extend") + self._leader = False + except Exception as e: + logger.error("Lock extend error: %s", e) + for _ in range(interval): + if not self._leader: + break + time.sleep(1) + + def schedule_loop(self): + self.running = True + self._cleanup_finished() + resp = self.redis.xread( + streams={STREAM_KEY: '0-0'}, + count=500, + block=5000, + ) + if not resp: + return + + messages = [] + for stream_key, msgs in resp: + messages.extend(msgs) + + lock_keys = [] + for msg_id, msg_data in messages: + lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}") + + pipe = self.redis.pipeline() + for key in lock_keys: + pipe.exists(key) + lock_exists = pipe.execute() + + deliver_keys = set() + for (msg_id, msg_data), locked in zip(messages, lock_exists): + user_id = msg_data['user_id'] + lock_key = f"{msg_data['task_name']}:{user_id}" + + if locked or lock_key in deliver_keys: + continue + + key = self._dispatch(msg_id, msg_data) + if key: + deliver_keys.add(lock_key) + time.sleep(0.1) + + def run_server(self): + health_check_server() + + lock = self.redis.lock( + "scheduler:leader", + timeout=60, + blocking_timeout=10, + thread_local=False + ) + while True: + try: + if lock.acquire(blocking=True): + self._leader = True + t = threading.Thread( + target=self._leader_lock_extend, + args=(lock, 20), + daemon=True + ) + t.start() + try: + while self._leader: 
+ self.schedule_loop() + finally: + self._leader = False + t.join(timeout=30) + try: + lock.release() + except redis.exceptions.LockNotOwnedError: + pass + self.running = False + else: + time.sleep(5) + except Exception as e: + logger.error("Scheduler exception %s", e, exc_info=True) + time.sleep(5) + + def health(self) -> dict: + return { + "running": self.running, + "pending": self.redis.xlen(STREAM_KEY), + "dispatched": self.dispatched, + "errors": self.errors + } + + +scheduler: RedisTaskScheduler | None = None +if scheduler is None: + scheduler = RedisTaskScheduler() + +if __name__ == '__main__': + scheduler.run_server() diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 313781d2..de56f56e 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -86,7 +86,7 @@ async def write_memory( user_rag_memory_id=payload.user_rag_memory_id, ) - logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}") return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 74fb6bae..50f7ddb9 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,9 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from 
app.utils.config_utils import resolve_config_id +from celery_task_scheduler import scheduler logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') @@ -86,16 +85,28 @@ async def write( logger.info( f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: User ID - structured_messages, # message: JSON string format message list - str(actual_config_id), # config_id: Configuration ID string - storage_type, # storage_type: "neo4j" - user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # write_id = write_message_task.delay( + # actual_end_user_id, # end_user_id: User ID + # structured_messages, # message: JSON string format message list + # str(actual_config_id), # config_id: Configuration ID string + # storage_type, # storage_type: "neo4j" + # user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": structured_messages, + "config_id": str(actual_config_id), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + + # logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + # write_status = get_task_memory_write_result(str(write_id)) + # logger.info(f'[WRITE] Task result - user={actual_end_user_id}') async def term_memory_save(end_user_id, strategy_type, scope): @@ -164,13 +175,24 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) else: config_id = memory_config - 
write_message_task.delay( - end_user_id, # end_user_id: User ID - redis_messages, # message: JSON string format message list - config_id, # config_id: Configuration ID string - AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" - "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": redis_messages, + "config_id": config_id, + "storage_type": AgentMemory_Long_Term.STORAGE_NEO4J, + "user_rag_memory_id": "" + } ) + # write_message_task.delay( + # end_user_id, # end_user_id: User ID + # redis_messages, # message: JSON string format message list + # config_id, # config_id: Configuration ID string + # AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + # "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) count_store.update_sessions_count(end_user_id, 0, []) diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py index 96ff929a..35ace00d 100644 --- a/api/app/core/memory/pipelines/memory_read.py +++ b/api/app/core/memory/pipelines/memory_read.py @@ -1,8 +1,8 @@ from app.core.memory.enums import SearchStrategy, StorageType from app.core.memory.models.service_models import MemorySearchResult from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline -from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService -from app.core.memory.read_services.query_preprocessor import QueryPreprocessor +from core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService +from core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): diff --git a/api/app/core/memory/read_services/generate_engine/__init__.py b/api/app/core/memory/read_services/generate_engine/__init__.py new file mode 
100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/query_preprocessor.py b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py similarity index 100% rename from api/app/core/memory/read_services/query_preprocessor.py rename to api/app/core/memory/read_services/generate_engine/query_preprocessor.py diff --git a/api/app/core/memory/read_services/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py similarity index 100% rename from api/app/core/memory/read_services/retrieval_summary.py rename to api/app/core/memory/read_services/generate_engine/retrieval_summary.py diff --git a/api/app/core/memory/read_services/search_engine/__init__.py b/api/app/core/memory/read_services/search_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/content_search.py b/api/app/core/memory/read_services/search_engine/content_search.py similarity index 99% rename from api/app/core/memory/read_services/content_search.py rename to api/app/core/memory/read_services/search_engine/content_search.py index ef4e90f1..16c23f91 100644 --- a/api/app/core/memory/read_services/content_search.py +++ b/api/app/core/memory/read_services/search_engine/content_search.py @@ -8,7 +8,7 @@ from neo4j import Session from app.core.memory.enums import Neo4jNodeType from app.core.memory.memory_service import MemoryContext from app.core.memory.models.service_models import Memory, MemorySearchResult -from app.core.memory.read_services.result_builder import data_builder_factory +from core.memory.read_services.search_engine.result_builder import data_builder_factory from app.core.models import RedBearEmbeddings from app.core.rag.nlp.search import knowledge_retrieval from app.repositories import knowledge_repository diff --git a/api/app/core/memory/read_services/result_builder.py b/api/app/core/memory/read_services/search_engine/result_builder.py similarity index 100% 
rename from api/app/core/memory/read_services/result_builder.py rename to api/app/core/memory/read_services/search_engine/result_builder.py diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index bcdc80c7..b74af6b9 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -11,7 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read from app.schemas import FileInput -from app.tasks import write_message_task +from celery_task_scheduler import scheduler class MemoryReadNode(BaseNode): @@ -126,12 +126,23 @@ class MemoryWriteNode(BaseNode): "files": file_info }) - write_message_task.delay( - end_user_id=end_user_id, - message=messages, - config_id=str(self.typed_config.config_id), - storage_type=state["memory_storage_type"], - user_rag_memory_id=state["user_rag_memory_id"] + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": messages, + "config_id": str(self.typed_config.config_id), + "storage_type": state["memory_storage_type"], + "user_rag_memory_id": state["user_rag_memory_id"] + } ) + # write_message_task.delay( + # end_user_id=end_user_id, + # message=messages, + # config_id=str(self.typed_config.config_id), + # storage_type=state["memory_storage_type"], + # user_rag_memory_id=state["user_rag_memory_id"] + # ) return "success" diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 330b84ad..221ca4cf 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Optional from sqlalchemy.orm import Session +from app.celery_task_scheduler import scheduler from app.core.error_codes import BizCode from app.core.exceptions import 
BusinessException, ResourceNotFoundException from app.core.logging_config import get_logger @@ -166,19 +167,30 @@ class MemoryAPIService: # Convert to message list format expected by write_message_task messages = message if isinstance(message, list) else [{"role": "user", "content": message}] - from app.tasks import write_message_task - task = write_message_task.delay( + # from app.tasks import write_message_task + # task = write_message_task.delay( + # end_user_id, + # messages, + # config_id, + # storage_type, + # user_rag_memory_id or "", + # ) + scheduler.push_task( + "app.core.memory.agent.write_message", end_user_id, - messages, - config_id, - storage_type, - user_rag_memory_id or "", + { + "end_user_id": end_user_id, + "message": messages, + "config_id": config_id, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}") + logger.info(f"Memory write task submitted, end_user_id={end_user_id}") return { - "task_id": task.id, + # "task_id": task.id, "status": "PENDING", "end_user_id": end_user_id, } diff --git a/api/app/tasks.py b/api/app/tasks.py index 8bbbdc6e..5978af4d 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ElasticSearchVectorFactory, ) -from app.db import get_db, get_db_context +from app.db import get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema @@ -1993,7 +1993,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di end_users = db.query(EndUser).all() if not end_users: logger.info("没有终端用户,跳过遗忘周期") - return {"status": "SUCCESS", "message": "没有终端用户", + return {"status": "SUCCESS", "message": "没有终端用户", "report": {"merged_count": 0, 
"failed_count": 0, "processed_users": 0}, "duration_seconds": time.time() - start_time} @@ -2007,7 +2007,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # 获取用户配置(自动回退到工作空间默认配置) connected_config = get_end_user_connected_config(str(end_user.id), db) user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) - + if not user_config_id: failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) continue @@ -2016,13 +2016,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di report = await forget_service.trigger_forgetting_cycle( db=db, end_user_id=str(end_user.id), config_id=user_config_id ) - + total_merged += report.get('merged_count', 0) total_failed += report.get('failed_count', 0) processed_users += 1 - + logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") - + except Exception as e: logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) @@ -2769,18 +2769,18 @@ def run_incremental_clustering( 包含任务执行结果的字典 """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.core.logging_config import get_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine - + logger = get_logger(__name__) logger.info( f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, " f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}" ) - + connector = Neo4jConnector() try: engine = LabelPropagationEngine( @@ -2788,12 +2788,12 @@ def run_incremental_clustering( llm_model_id=llm_model_id, embedding_model_id=embedding_model_id, ) - + # 执行增量聚类 await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) - + logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}") - + return { "status": "SUCCESS", "end_user_id": 
end_user_id, @@ -2804,18 +2804,18 @@ def run_incremental_clustering( raise finally: await connector.close() - + try: loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time result["task_id"] = self.request.id - + logger.info( f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, " f"elapsed_time={result['elapsed_time']:.2f}s" ) - + return result except Exception as e: elapsed_time = time.time() - start_time diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 5d358f2c..a0fd4791 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -63,6 +63,23 @@ services: networks: - celery + celery-task-scheduler: + image: redbear-mem-open:latest + container_name: celery-task-scheduler + env_file: + - .env + volumes: + - /etc/localtime:/etc/localtime:ro + command: python app/celery_task_scheduler.py + restart: unless-stopped + healthcheck: + test: CMD curl -f 127.0.0.1:8001 || exit 1 + interval: 30s + timeout: 5s + retries: 3 + networks: + - celery + # Celery Beat - scheduler beat: image: redbear-mem-open:latest From f93ec8d6096af6d792fe13877c6e55b37242aade Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 22 Apr 2026 17:46:38 +0800 Subject: [PATCH 2/3] fix(core): fix end_user_id reference and add task status tracking - Fix write_router to use actual_end_user_id instead of end_user_id - Add task status tracking via Redis in scheduler - Expose task_id in memory write response - Fix logging import path in scheduler --- api/app/celery_task_scheduler.py | 50 +++++++++++++++++-- .../service/memory_api_controller.py | 6 +-- .../langgraph_graph/routing/write_router.py | 4 +- api/app/schemas/memory_api_schema.py | 8 +-- api/app/services/memory_api_service.py | 6 +-- 5 files changed, 58 insertions(+), 16 deletions(-) diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py index 88afe341..8cdce658 100644 --- a/api/app/celery_task_scheduler.py 
+++ b/api/app/celery_task_scheduler.py @@ -6,7 +6,7 @@ import redis from app.core.config import settings from celery_app import celery_app -from core.logging_config import get_named_logger +from app.core.logging_config import get_named_logger logger = get_named_logger("task_scheduler") @@ -62,11 +62,34 @@ class RedisTaskScheduler: "params": json.dumps(params), } ) + self.redis.set( + f"task_tracker:{msg_id}", + json.dumps({"status": "QUEUED", "task_id": None}), + ex=86400 + ) return msg_id except Exception as e: logger.error("Push task exception %s", e, exc_info=True) raise e + def get_task_status(self, msg_id: str) -> dict: + raw = self.redis.get(f"task_tracker:{msg_id}") + if raw is None: + return {"status": "NOT_FOUND"} + + tracker = json.loads(raw) + status = tracker["status"] + task_id = tracker.get("task_id") + result_content = tracker.get("result") or {} + if status == "DISPATCHED" and task_id: + result_raw = self.redis.get(f"celery-task-meta-{task_id}") + if result_raw: + result_data = json.loads(result_raw) + status = result_data.get("status", status) + result_content = result_data.get("result") + + return {"status": status, "task_id": task_id, "result": result_content} + def _cleanup_finished(self): pending = self.redis.hgetall(PENDING_HASH) if not pending: @@ -91,6 +114,7 @@ class RedisTaskScheduler: age = now - dispatched_at should_cleanup = False + result_data = None if raw_result is not None: result_data = json.loads(raw_result) if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): @@ -104,8 +128,20 @@ class RedisTaskScheduler: ) if should_cleanup: + final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" cleanup_pipe.delete(lock_key) cleanup_pipe.hdel(PENDING_HASH, task_id) + tracker_msg_id = meta.get("msg_id") + if tracker_msg_id: + cleanup_pipe.set( + f"task_tracker:{tracker_msg_id}", + json.dumps({ + "status": final_status, + "task_id": task_id, + "result": result_data.get("result") or {} + }), + ex=86400, + 
) has_cleanup = True except Exception as e: logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) @@ -125,9 +161,15 @@ class RedisTaskScheduler: pipe.set(lock_key, task.id, ex=3600) pipe.hset(PENDING_HASH, task.id, json.dumps({ "lock_key": lock_key, - "dispatched_at": time.time() + "dispatched_at": time.time(), + "msg_id": msg_id })) pipe.xdel(STREAM_KEY, msg_id) + pipe.set( + f"task_tracker:{msg_id}", + json.dumps({"status": "DISPATCHED", "task_id": task.id}), + ex=86400, + ) pipe.execute() self.dispatched += 1 logger.info("Task dispatched: %s", task.id) @@ -183,8 +225,8 @@ class RedisTaskScheduler: if locked or lock_key in deliver_keys: continue - key = self._dispatch(msg_id, msg_data) - if key: + dispatched_successfully = self._dispatch(msg_id, msg_data) + if dispatched_successfully: deliver_keys.add(lock_key) time.sleep(0.1) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index de56f56e..57664e4e 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -18,6 +18,7 @@ from app.schemas.memory_api_schema import ( MemoryWriteSyncResponse, ) from app.services.memory_api_service import MemoryAPIService +from celery_task_scheduler import scheduler router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) logger = get_business_logger() @@ -86,7 +87,7 @@ async def write_memory( user_rag_memory_id=payload.user_rag_memory_id, ) - logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}") + logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}") return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") @@ -105,8 +106,7 @@ async def get_write_task_status( """ logger.info(f"Write task status check - task_id: {task_id}") - from app.services.task_service import get_task_memory_write_result - result = 
get_task_memory_write_result(task_id) + result = scheduler.get_task_status(task_id) return success(data=_sanitize_task_result(result), msg="Task status retrieved") diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 50f7ddb9..7ef4ed12 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -94,9 +94,9 @@ async def write( # ) scheduler.push_task( "app.core.memory.agent.write_message", - end_user_id, + actual_end_user_id, { - "end_user_id": end_user_id, + "end_user_id": actual_end_user_id, "message": structured_messages, "config_id": str(actual_config_id), "storage_type": storage_type, diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index 4cc548f3..7e4ca74a 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel): """Response schema for memory write operation. 
Attributes: - task_id: Celery task ID for status polling - status: Initial task status (PENDING) + task_id: task ID for status polling + status: Initial task status (QUEUED) end_user_id: End user ID the write was submitted for """ - task_id: str = Field(..., description="Celery task ID for polling") - status: str = Field(..., description="Task status: PENDING") + task_id: str = Field(..., description="task ID for polling") + status: str = Field(..., description="Task status: QUEUED") end_user_id: str = Field(..., description="End user ID") diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 221ca4cf..a1ceef86 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -175,7 +175,7 @@ class MemoryAPIService: # storage_type, # user_rag_memory_id or "", # ) - scheduler.push_task( + task_id = scheduler.push_task( "app.core.memory.agent.write_message", end_user_id, { @@ -190,8 +190,8 @@ class MemoryAPIService: logger.info(f"Memory write task submitted, end_user_id={end_user_id}") return { - # "task_id": task.id, - "status": "PENDING", + "task_id": task_id, + "status": "QUEUED", "end_user_id": end_user_id, } From be10bab763e55b492a77afa9a6124bf8df1878fb Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 24 Apr 2026 13:56:32 +0800 Subject: [PATCH 3/3] refactor(core): migrate task scheduler to per-user queue with dynamic sharding --- api/app/celery_task_scheduler.py | 431 +++++++++++++----- .../generate_engine/retrieval_summary.py | 2 +- api/app/services/memory_api_service.py | 2 +- 3 files changed, 326 insertions(+), 109 deletions(-) diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py index 8cdce658..e7f946b6 100644 --- a/api/app/celery_task_scheduler.py +++ b/api/app/celery_task_scheduler.py @@ -1,21 +1,70 @@ +import hashlib import json +import os +import socket import threading import time +import uuid import redis from app.core.config 
import settings -from celery_app import celery_app from app.core.logging_config import get_named_logger +from app.celery_app import celery_app logger = get_named_logger("task_scheduler") -STREAM_KEY = "celery_task_stream" +# per-user queue scheduler:uq:{user_id} +USER_QUEUE_PREFIX = "scheduler:uq:" +# Set of users that have pending (queued) messages +ACTIVE_USERS = "scheduler:active_users" +# Set of users that can dispatch (ready signal) +READY_SET = "scheduler:ready_users" +# Metadata of tasks that have been dispatched and are pending completion PENDING_HASH = "scheduler:pending_tasks" -TASK_TIMEOUT = 7800 +# Dynamic Sharding: Instance Registry +REGISTRY_KEY = "scheduler:instances" + +TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded +HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds) +INSTANCE_TTL = 30 # Instance timeout (seconds) + +LUA_ATOMIC_LOCK = """ +local dispatch_lock = KEYS[1] +local lock_key = KEYS[2] +local instance_id = ARGV[1] +local dispatch_ttl = tonumber(ARGV[2]) +local lock_ttl = tonumber(ARGV[3]) + +if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then + return 0 +end + +if redis.call('EXISTS', lock_key) == 1 then + redis.call('DEL', dispatch_lock) + return -1 +end + +redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl) +return 1 +""" + +LUA_SAFE_DELETE = """ +if redis.call('GET', KEYS[1]) == ARGV[1] then + return redis.call('DEL', KEYS[1]) +end +return 0 +""" -def health_check_server(): +def stable_hash(value: str) -> int: + return int.from_bytes( + hashlib.md5(value.encode("utf-8")).digest(), + "big" + ) + + +def health_check_server(scheduler_ref): import uvicorn from fastapi import FastAPI @@ -23,19 +72,20 @@ def health_check_server(): @health_app.get("/") def health(): - return scheduler.health() + return scheduler_ref.health() + port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001")) threading.Thread( target=uvicorn.run, kwargs={ "app": health_app, "host": "0.0.0.0", "port": 
8001, - "log_config": None + "port": port, + "log_config": None, }, - daemon=True + daemon=True, ).start() - logger.info(f"[Health] Server started at http://0.0.0.0:8001") + logger.info("[Health] Server started at http://0.0.0.0:%s", port) class RedisTaskScheduler: @@ -50,27 +100,43 @@ class RedisTaskScheduler: self.running = False self.dispatched = 0 self.errors = 0 - self._leader = False + + self.instance_id = f"{socket.gethostname()}-{os.getpid()}" + self._shard_index = 0 + self._shard_count = 1 + self._last_heartbeat = 0.0 def push_task(self, task_name, user_id, params): try: - msg_id = self.redis.xadd( - STREAM_KEY, - fields={ - "task_name": task_name, - "user_id": user_id, - "params": json.dumps(params), - } - ) - self.redis.set( + msg_id = str(uuid.uuid4()) + msg = json.dumps({ + "msg_id": msg_id, + "task_name": task_name, + "user_id": user_id, + "params": json.dumps(params), + }) + + lock_key = f"{task_name}:{user_id}" + queue_key = f"{USER_QUEUE_PREFIX}{user_id}" + + pipe = self.redis.pipeline() + pipe.rpush(queue_key, msg) + pipe.sadd(ACTIVE_USERS, user_id) + pipe.set( f"task_tracker:{msg_id}", json.dumps({"status": "QUEUED", "task_id": None}), - ex=86400 + ex=86400, ) + pipe.execute() + + if not self.redis.exists(lock_key): + self.redis.sadd(READY_SET, user_id) + + logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id) return msg_id except Exception as e: logger.error("Push task exception %s", e, exc_info=True) - raise e + raise def get_task_status(self, msg_id: str) -> dict: raw = self.redis.get(f"task_tracker:{msg_id}") @@ -81,6 +147,7 @@ class RedisTaskScheduler: status = tracker["status"] task_id = tracker.get("task_id") result_content = tracker.get("result") or {} + if status == "DISPATCHED" and task_id: result_raw = self.redis.get(f"celery-task-meta-{task_id}") if result_raw: @@ -105,6 +172,7 @@ class RedisTaskScheduler: cleanup_pipe = self.redis.pipeline() has_cleanup = False + ready_user_ids = set() for task_id, 
raw_result in zip(task_ids, results): try: @@ -114,12 +182,16 @@ class RedisTaskScheduler: age = now - dispatched_at should_cleanup = False - result_data = None + result_data = {} + if raw_result is not None: result_data = json.loads(raw_result) if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): should_cleanup = True - logger.info("Task finished: %s state=%s", task_id, result_data.get("status")) + logger.info( + "Task finished: %s state=%s", task_id, + result_data.get("status"), + ) elif age > TASK_TIMEOUT: should_cleanup = True logger.warning( @@ -128,9 +200,14 @@ class RedisTaskScheduler: ) if should_cleanup: - final_status = result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" - cleanup_pipe.delete(lock_key) + final_status = ( + result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" + ) + + self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id) + cleanup_pipe.hdel(PENDING_HASH, task_id) + tracker_msg_id = meta.get("msg_id") if tracker_msg_id: cleanup_pipe.set( @@ -138,146 +215,286 @@ class RedisTaskScheduler: json.dumps({ "status": final_status, "task_id": task_id, - "result": result_data.get("result") or {} + "result": result_data.get("result") or {}, }), ex=86400, ) has_cleanup = True + + parts = lock_key.split(":", 1) + if len(parts) == 2: + ready_user_ids.add(parts[1]) + except Exception as e: logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) self.errors += 1 + if has_cleanup: cleanup_pipe.execute() + if ready_user_ids: + self.redis.sadd(READY_SET, *ready_user_ids) + + def _heartbeat(self): + now = time.time() + if now - self._last_heartbeat < HEARTBEAT_INTERVAL: + return + self._last_heartbeat = now + + self.redis.hset(REGISTRY_KEY, self.instance_id, str(now)) + + all_instances = self.redis.hgetall(REGISTRY_KEY) + + alive = [] + dead = [] + for iid, ts in all_instances.items(): + if now - float(ts) < INSTANCE_TTL: + alive.append(iid) + else: + dead.append(iid) + + if dead: + pipe = 
self.redis.pipeline() + for iid in dead: + pipe.hdel(REGISTRY_KEY, iid) + pipe.execute() + logger.info("Cleaned dead instances: %s", dead) + + alive.sort() + self._shard_count = max(len(alive), 1) + self._shard_index = ( + alive.index(self.instance_id) if self.instance_id in alive else 0 + ) + logger.debug( + "Shard: %s/%s (instance=%s, alive=%d)", + self._shard_index, self._shard_count, + self.instance_id, len(alive), + ) + + def _is_mine(self, user_id: str) -> bool: + if self._shard_count <= 1: + return True + return stable_hash(user_id) % self._shard_count == self._shard_index + def _dispatch(self, msg_id, msg_data) -> bool: - user_id = msg_data['user_id'] - task_name = msg_data['task_name'] - params = json.loads(msg_data.get('params', "{}")) + user_id = msg_data["user_id"] + task_name = msg_data["task_name"] + params = json.loads(msg_data.get("params", "{}")) lock_key = f"{task_name}:{user_id}" + dispatch_lock = f"dispatch:{msg_id}" + + result = self.redis.eval( + LUA_ATOMIC_LOCK, 2, + dispatch_lock, lock_key, + self.instance_id, str(300), str(3600), + ) + + if result == 0: + return False + if result == -1: + return False + try: task = celery_app.send_task(task_name, kwargs=params) + except Exception as e: + pipe = self.redis.pipeline() + pipe.delete(dispatch_lock) + pipe.delete(lock_key) + pipe.execute() + self.errors += 1 + logger.error( + "send_task failed for %s:%s msg=%s: %s", + task_name, user_id, msg_id, e, exc_info=True, + ) + return False + + try: pipe = self.redis.pipeline() pipe.set(lock_key, task.id, ex=3600) pipe.hset(PENDING_HASH, task.id, json.dumps({ "lock_key": lock_key, "dispatched_at": time.time(), - "msg_id": msg_id + "msg_id": msg_id, })) - pipe.xdel(STREAM_KEY, msg_id) + pipe.delete(dispatch_lock) pipe.set( f"task_tracker:{msg_id}", json.dumps({"status": "DISPATCHED", "task_id": task.id}), ex=86400, ) pipe.execute() - self.dispatched += 1 - logger.info("Task dispatched: %s", task.id) - return True except Exception as e: + logger.error( + 
"Post-dispatch state update failed for %s: %s", + task.id, e, exc_info=True, + ) self.errors += 1 - logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True) - return False - def _leader_lock_extend(self, lock, interval=20): - while self._leader: - try: - lock.extend(60) - except redis.exceptions.LockNotOwnedError: - logger.warning("Lost leader lock during extend") - self._leader = False - except Exception as e: - logger.error("Lock extend error: %s", e) - for _ in range(interval): - if not self._leader: - break - time.sleep(1) + self.dispatched += 1 + logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id) + return True - def schedule_loop(self): - self.running = True - self._cleanup_finished() - resp = self.redis.xread( - streams={STREAM_KEY: '0-0'}, - count=500, - block=5000, - ) - if not resp: + def _process_batch(self, user_ids): + if not user_ids: return - messages = [] - for stream_key, msgs in resp: - messages.extend(msgs) + pipe = self.redis.pipeline() + for uid in user_ids: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() - lock_keys = [] - for msg_id, msg_data in messages: - lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}") + candidates = [] # (user_id, msg_dict) + empty_users = [] + + for uid, head in zip(user_ids, heads): + if head is None: + empty_users.append(uid) + else: + try: + candidates.append((uid, json.loads(head))) + except (json.JSONDecodeError, TypeError) as e: + logger.error("Bad message in queue for user %s: %s", uid, e) + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + if empty_users: + pipe = self.redis.pipeline() + for uid in empty_users: + pipe.srem(ACTIVE_USERS, uid) + pipe.execute() + + if not candidates: + return + + for uid, msg in candidates: + if self._dispatch(msg["msg_id"], msg): + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + def schedule_loop(self): + self._heartbeat() + self._cleanup_finished() pipe = self.redis.pipeline() - for key in lock_keys: - 
pipe.exists(key) - lock_exists = pipe.execute() + pipe.smembers(READY_SET) + pipe.delete(READY_SET) + results = pipe.execute() + ready_users = results[0] or set() - deliver_keys = set() - for (msg_id, msg_data), locked in zip(messages, lock_exists): - user_id = msg_data['user_id'] - lock_key = f"{msg_data['task_name']}:{user_id}" + my_users = [uid for uid in ready_users if self._is_mine(uid)] - if locked or lock_key in deliver_keys: - continue + if not my_users: + time.sleep(0.5) + return - dispatched_successfully = self._dispatch(msg_id, msg_data) - if dispatched_successfully: - deliver_keys.add(lock_key) + self._process_batch(my_users) time.sleep(0.1) - def run_server(self): - health_check_server() + def _full_scan(self): + cursor = 0 + ready_batch = [] + while True: + cursor, user_ids = self.redis.sscan( + ACTIVE_USERS, cursor=cursor, count=1000, + ) + if user_ids: + my_users = [uid for uid in user_ids if self._is_mine(uid)] + if my_users: + pipe = self.redis.pipeline() + for uid in my_users: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() - lock = self.redis.lock( - "scheduler:leader", - timeout=60, - blocking_timeout=10, - thread_local=False + for uid, head in zip(my_users, heads): + if head is None: + continue + try: + msg = json.loads(head) + lock_key = f"{msg['task_name']}:{uid}" + ready_batch.append((uid, lock_key)) + except (json.JSONDecodeError, TypeError): + continue + + if cursor == 0: + break + + if not ready_batch: + return + + pipe = self.redis.pipeline() + for _, lock_key in ready_batch: + pipe.exists(lock_key) + lock_exists = pipe.execute() + + ready_uids = [ + uid for (uid, _), locked in zip(ready_batch, lock_exists) + if not locked + ] + + if ready_uids: + self.redis.sadd(READY_SET, *ready_uids) + logger.info("Full scan found %d ready users", len(ready_uids)) + + def run_server(self): + health_check_server(self) + self.running = True + + last_full_scan = 0.0 + full_scan_interval = 30.0 + + logger.info( + "Scheduler 
started: instance=%s", self.instance_id, ) + while True: try: - if lock.acquire(blocking=True): - self._leader = True - t = threading.Thread( - target=self._leader_lock_extend, - args=(lock, 20), - daemon=True - ) - t.start() - try: - while self._leader: - self.schedule_loop() - finally: - self._leader = False - t.join(timeout=30) - try: - lock.release() - except redis.exceptions.LockNotOwnedError: - pass - self.running = False - else: - time.sleep(5) + self.schedule_loop() + + now = time.time() + if now - last_full_scan > full_scan_interval: + self._full_scan() + last_full_scan = now + except Exception as e: logger.error("Scheduler exception %s", e, exc_info=True) + self.errors += 1 time.sleep(5) def health(self) -> dict: return { "running": self.running, - "pending": self.redis.xlen(STREAM_KEY), + "active_users": self.redis.scard(ACTIVE_USERS), + "ready_users": self.redis.scard(READY_SET), + "pending_tasks": self.redis.hlen(PENDING_HASH), "dispatched": self.dispatched, - "errors": self.errors + "errors": self.errors, + "shard": f"{self._shard_index}/{self._shard_count}", + "instance": self.instance_id, } + def shutdown(self): + logger.info("Scheduler shutting down: instance=%s", self.instance_id) + self.running = False + try: + self.redis.hdel(REGISTRY_KEY, self.instance_id) + except Exception as e: + logger.error("Shutdown cleanup error: %s", e) + scheduler: RedisTaskScheduler | None = None if scheduler is None: scheduler = RedisTaskScheduler() -if __name__ == '__main__': +if __name__ == "__main__": + import signal + import sys + + + def _signal_handler(signum, frame): + scheduler.shutdown() + sys.exit(0) + + + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + scheduler.run_server() diff --git a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py index 6b166cf2..c46e93f0 100644 --- 
a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py +++ b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py @@ -8,4 +8,4 @@ class RetrievalSummaryProcessor: @staticmethod def verify(content: str, llm_client: RedBearLLM): - return \ No newline at end of file + return diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a1ceef86..82d1c463 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -187,7 +187,7 @@ class MemoryAPIService: } ) - logger.info(f"Memory write task submitted, end_user_id={end_user_id}") + logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}") return { "task_id": task_id,