diff --git a/api/app/celery_app.py b/api/app/celery_app.py index b0894eb8..717709da 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -17,6 +17,7 @@ def _mask_url(url: str) -> str: """隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议""" return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url) + # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') @@ -29,7 +30,7 @@ if platform.system() == 'Darwin': # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md _broker_url = os.getenv("CELERY_BROKER_URL") or \ - f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" + f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url @@ -66,11 +67,11 @@ celery_app.conf.update( task_serializer='json', accept_content=['json'], result_serializer='json', - + # # 时区 # timezone='Asia/Shanghai', # enable_utc=False, - + # 任务追踪 task_track_started=True, task_ignore_result=False, diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py new file mode 100644 index 00000000..e7f946b6 --- /dev/null +++ b/api/app/celery_task_scheduler.py @@ -0,0 +1,500 @@ +import hashlib +import json +import os +import socket +import threading +import time +import uuid + +import redis + +from app.core.config import settings +from app.core.logging_config import get_named_logger +from app.celery_app import celery_app + +logger = get_named_logger("task_scheduler") + +# per-user queue scheduler:uq:{user_id} +USER_QUEUE_PREFIX = "scheduler:uq:" +# User Collection of Pending Messages +ACTIVE_USERS = "scheduler:active_users" +# Set of users that can dispatch (ready signal) +READY_SET = "scheduler:ready_users" +# Metadata of tasks that have been dispatched and are pending completion +PENDING_HASH = "scheduler:pending_tasks" +# Dynamic Sharding: Instance Registry +REGISTRY_KEY = "scheduler:instances" + +TASK_TIMEOUT = 7800 # Task timeout (seconds), considered lost if exceeded +HEARTBEAT_INTERVAL = 10 # Heartbeat interval (seconds) +INSTANCE_TTL = 30 # Instance timeout (seconds) + +LUA_ATOMIC_LOCK = """ +local dispatch_lock = KEYS[1] +local lock_key = KEYS[2] +local instance_id = ARGV[1] +local dispatch_ttl = tonumber(ARGV[2]) +local lock_ttl = tonumber(ARGV[3]) + +if redis.call('SET', dispatch_lock, instance_id, 'NX', 'EX', dispatch_ttl) == false then + return 0 +end + +if redis.call('EXISTS', lock_key) == 1 then + redis.call('DEL', dispatch_lock) + return -1 +end + +redis.call('SET', lock_key, 'dispatching', 'EX', lock_ttl) +return 1 +""" + +LUA_SAFE_DELETE = """ +if redis.call('GET', KEYS[1]) == ARGV[1] then + return redis.call('DEL', KEYS[1]) +end +return 0 +""" + + +def stable_hash(value: str) -> int: + return int.from_bytes( + hashlib.md5(value.encode("utf-8")).digest(), + "big" + ) + + +def health_check_server(scheduler_ref): + import uvicorn + from fastapi import FastAPI + + health_app = FastAPI() + + @health_app.get("/") + def health(): + return scheduler_ref.health() + + port = int(os.environ.get("SCHEDULER_HEALTH_PORT", "8001")) + threading.Thread( + target=uvicorn.run, + kwargs={ + "app": health_app, + "host": "0.0.0.0", + "port": port, + "log_config": None, + }, + daemon=True, + ).start() + logger.info("[Health] Server started at http://0.0.0.0:%s", port) + + +class RedisTaskScheduler: + def __init__(self): + self.redis = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB_CELERY_BACKEND, + password=settings.REDIS_PASSWORD, + decode_responses=True, + ) + self.running = False + self.dispatched = 0 + self.errors = 0 + + self.instance_id = f"{socket.gethostname()}-{os.getpid()}" + self._shard_index = 0 + self._shard_count = 1 + self._last_heartbeat = 0.0 + + def push_task(self, task_name, user_id, params): + try: + msg_id = str(uuid.uuid4()) + msg = json.dumps({ + "msg_id": msg_id, + "task_name": task_name, + "user_id": user_id, + "params": json.dumps(params), + }) + + lock_key = f"{task_name}:{user_id}" + queue_key = f"{USER_QUEUE_PREFIX}{user_id}" + + pipe = self.redis.pipeline() + pipe.rpush(queue_key, msg) + pipe.sadd(ACTIVE_USERS, user_id) + pipe.set( + f"task_tracker:{msg_id}", + json.dumps({"status": "QUEUED", "task_id": None}), + ex=86400, + ) + pipe.execute() + + if not self.redis.exists(lock_key): + self.redis.sadd(READY_SET, user_id) + + logger.info("Task pushed: msg_id=%s task=%s user=%s", msg_id, task_name, user_id) + return msg_id + except Exception as e: + logger.error("Push task exception %s", e, exc_info=True) + raise + + def get_task_status(self, msg_id: str) -> dict: + raw = self.redis.get(f"task_tracker:{msg_id}") + if raw is None: + return {"status": "NOT_FOUND"} + + tracker = json.loads(raw) + status = tracker["status"] + task_id = tracker.get("task_id") + result_content = tracker.get("result") or {} + + if status == "DISPATCHED" and task_id: + result_raw = self.redis.get(f"celery-task-meta-{task_id}") + if result_raw: + result_data = json.loads(result_raw) + status = result_data.get("status", status) + result_content = result_data.get("result") + + return {"status": status, "task_id": task_id, "result": result_content} + + def _cleanup_finished(self): + pending = self.redis.hgetall(PENDING_HASH) + if not pending: + return + + now = time.time() + task_ids = list(pending.keys()) + + pipe = self.redis.pipeline() + for task_id in task_ids: + pipe.get(f"celery-task-meta-{task_id}") + results = pipe.execute() + + cleanup_pipe = self.redis.pipeline() + has_cleanup = False + ready_user_ids = set() + + for task_id, raw_result in zip(task_ids, results): + try: + meta = json.loads(pending[task_id]) + lock_key = meta["lock_key"] + dispatched_at = meta.get("dispatched_at", 0) + age = now - dispatched_at + + should_cleanup = False + result_data = {} + + if raw_result is not None: + result_data = json.loads(raw_result) + if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"): + should_cleanup = True + logger.info( + "Task finished: %s state=%s", task_id, + result_data.get("status"), + ) + elif age > TASK_TIMEOUT: + should_cleanup = True + logger.warning( + "Task expired or lost: %s age=%.0fs, force cleanup", + task_id, age, + ) + + if should_cleanup: + final_status = ( + result_data.get("status", "UNKNOWN") if result_data else "EXPIRED" + ) + + self.redis.eval(LUA_SAFE_DELETE, 1, lock_key, task_id) + + cleanup_pipe.hdel(PENDING_HASH, task_id) + + tracker_msg_id = meta.get("msg_id") + if tracker_msg_id: + cleanup_pipe.set( + f"task_tracker:{tracker_msg_id}", + json.dumps({ + "status": final_status, + "task_id": task_id, + "result": result_data.get("result") or {}, + }), + ex=86400, + ) + has_cleanup = True + + parts = lock_key.split(":", 1) + if len(parts) == 2: + ready_user_ids.add(parts[1]) + + except Exception as e: + logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True) + self.errors += 1 + + if has_cleanup: + cleanup_pipe.execute() + + if ready_user_ids: + self.redis.sadd(READY_SET, *ready_user_ids) + + def _heartbeat(self): + now = time.time() + if now - self._last_heartbeat < HEARTBEAT_INTERVAL: + return + self._last_heartbeat = now + + self.redis.hset(REGISTRY_KEY, self.instance_id, str(now)) + + all_instances = self.redis.hgetall(REGISTRY_KEY) + + alive = [] + dead = [] + for iid, ts in all_instances.items(): + if now - float(ts) < INSTANCE_TTL: + alive.append(iid) + else: + dead.append(iid) + + if dead: + pipe = self.redis.pipeline() + for iid in dead: + pipe.hdel(REGISTRY_KEY, iid) + pipe.execute() + logger.info("Cleaned dead instances: %s", dead) + + alive.sort() + self._shard_count = max(len(alive), 1) + self._shard_index = ( + alive.index(self.instance_id) if self.instance_id in alive else 0 + ) + logger.debug( + "Shard: %s/%s (instance=%s, alive=%d)", + self._shard_index, self._shard_count, + self.instance_id, len(alive), + ) + + def _is_mine(self, user_id: str) -> bool: + if self._shard_count <= 1: + return True + return stable_hash(user_id) % self._shard_count == self._shard_index + + def _dispatch(self, msg_id, msg_data) -> bool: + user_id = msg_data["user_id"] + task_name = msg_data["task_name"] + params = json.loads(msg_data.get("params", "{}")) + + lock_key = f"{task_name}:{user_id}" + dispatch_lock = f"dispatch:{msg_id}" + + result = self.redis.eval( + LUA_ATOMIC_LOCK, 2, + dispatch_lock, lock_key, + self.instance_id, str(300), str(3600), + ) + + if result == 0: + return False + if result == -1: + return False + + try: + task = celery_app.send_task(task_name, kwargs=params) + except Exception as e: + pipe = self.redis.pipeline() + pipe.delete(dispatch_lock) + pipe.delete(lock_key) + pipe.execute() + self.errors += 1 + logger.error( + "send_task failed for %s:%s msg=%s: %s", + task_name, user_id, msg_id, e, exc_info=True, + ) + return False + + try: + pipe = self.redis.pipeline() + pipe.set(lock_key, task.id, ex=3600) + pipe.hset(PENDING_HASH, task.id, json.dumps({ + "lock_key": lock_key, + "dispatched_at": time.time(), + "msg_id": msg_id, + })) + pipe.delete(dispatch_lock) + pipe.set( + f"task_tracker:{msg_id}", + json.dumps({"status": "DISPATCHED", "task_id": task.id}), + ex=86400, + ) + pipe.execute() + except Exception as e: + logger.error( + "Post-dispatch state update failed for %s: %s", + task.id, e, exc_info=True, + ) + self.errors += 1 + + self.dispatched += 1 + logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id) + return True + + def _process_batch(self, user_ids): + if not user_ids: + return + + pipe = self.redis.pipeline() + for uid in user_ids: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() + + candidates = [] # (user_id, msg_dict) + empty_users = [] + + for uid, head in zip(user_ids, heads): + if head is None: + empty_users.append(uid) + else: + try: + candidates.append((uid, json.loads(head))) + except (json.JSONDecodeError, TypeError) as e: + logger.error("Bad message in queue for user %s: %s", uid, e) + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + if empty_users: + pipe = self.redis.pipeline() + for uid in empty_users: + pipe.srem(ACTIVE_USERS, uid) + pipe.execute() + + if not candidates: + return + + for uid, msg in candidates: + if self._dispatch(msg["msg_id"], msg): + self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + + def schedule_loop(self): + self._heartbeat() + self._cleanup_finished() + + pipe = self.redis.pipeline() + pipe.smembers(READY_SET) + pipe.delete(READY_SET) + results = pipe.execute() + ready_users = results[0] or set() + + my_users = [uid for uid in ready_users if self._is_mine(uid)] + + if not my_users: + time.sleep(0.5) + return + + self._process_batch(my_users) + time.sleep(0.1) + + def _full_scan(self): + cursor = 0 + ready_batch = [] + while True: + cursor, user_ids = self.redis.sscan( + ACTIVE_USERS, cursor=cursor, count=1000, + ) + if user_ids: + my_users = [uid for uid in user_ids if self._is_mine(uid)] + if my_users: + pipe = self.redis.pipeline() + for uid in my_users: + pipe.lindex(f"{USER_QUEUE_PREFIX}{uid}", 0) + heads = pipe.execute() + + for uid, head in zip(my_users, heads): + if head is None: + continue + try: + msg = json.loads(head) + lock_key = f"{msg['task_name']}:{uid}" + ready_batch.append((uid, lock_key)) + except (json.JSONDecodeError, TypeError): + continue + + if cursor == 0: + break + + if not ready_batch: + return + + pipe = self.redis.pipeline() + for _, lock_key in ready_batch: + pipe.exists(lock_key) + lock_exists = pipe.execute() + + ready_uids = [ + uid for (uid, _), locked in zip(ready_batch, lock_exists) + if not locked + ] + + if ready_uids: + self.redis.sadd(READY_SET, *ready_uids) + logger.info("Full scan found %d ready users", len(ready_uids)) + + def run_server(self): + health_check_server(self) + self.running = True + + last_full_scan = 0.0 + full_scan_interval = 30.0 + + logger.info( + "Scheduler started: instance=%s", self.instance_id, + ) + + while True: + try: + self.schedule_loop() + + now = time.time() + if now - last_full_scan > full_scan_interval: + self._full_scan() + last_full_scan = now + + except Exception as e: + logger.error("Scheduler exception %s", e, exc_info=True) + self.errors += 1 + time.sleep(5) + + def health(self) -> dict: + return { + "running": self.running, + "active_users": self.redis.scard(ACTIVE_USERS), + "ready_users": self.redis.scard(READY_SET), + "pending_tasks": self.redis.hlen(PENDING_HASH), + "dispatched": self.dispatched, + "errors": self.errors, + "shard": f"{self._shard_index}/{self._shard_count}", + "instance": self.instance_id, + } + + def shutdown(self): + logger.info("Scheduler shutting down: instance=%s", self.instance_id) + self.running = False + try: + self.redis.hdel(REGISTRY_KEY, self.instance_id) + except Exception as e: + logger.error("Shutdown cleanup error: %s", e) + + +scheduler: RedisTaskScheduler | None = None +if scheduler is None: + scheduler = RedisTaskScheduler() + +if __name__ == "__main__": + import signal + import sys + + + def _signal_handler(signum, frame): + scheduler.shutdown() + sys.exit(0) + + + signal.signal(signal.SIGTERM, _signal_handler) + signal.signal(signal.SIGINT, _signal_handler) + + scheduler.run_server() diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 313781d2..57664e4e 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -18,6 +18,7 @@ from app.schemas.memory_api_schema import ( MemoryWriteSyncResponse, ) from app.services.memory_api_service import MemoryAPIService +from celery_task_scheduler import scheduler router = APIRouter(prefix="/memory", tags=["V1 - Memory API"]) logger = get_business_logger() @@ -86,7 +87,7 @@ async def write_memory( user_rag_memory_id=payload.user_rag_memory_id, ) - logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}") + logger.info(f"Memory write task submitted: task_id: {result['task_id']} end_user_id: {payload.end_user_id}") return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted") @@ -105,8 +106,7 @@ async def get_write_task_status( """ logger.info(f"Write task status check - task_id: {task_id}") - from app.services.task_service import get_task_memory_write_result - result = get_task_memory_write_result(task_id) + result = scheduler.get_task_status(task_id) return success(data=_sanitize_task_result(result), msg="Task status retrieved") diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 74fb6bae..7ef4ed12 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,9 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id +from celery_task_scheduler import scheduler logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') @@ -86,16 +85,28 @@ async def write( logger.info( f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: User ID - structured_messages, # message: JSON string format message list - str(actual_config_id), # config_id: Configuration ID string - storage_type, # storage_type: "neo4j" - user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # write_id = write_message_task.delay( + # actual_end_user_id, # end_user_id: User ID + # structured_messages, # message: JSON string format message list + # str(actual_config_id), # config_id: Configuration ID string + # storage_type, # storage_type: "neo4j" + # user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) + scheduler.push_task( + "app.core.memory.agent.write_message", + actual_end_user_id, + { + "end_user_id": actual_end_user_id, + "message": structured_messages, + "config_id": str(actual_config_id), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + + # logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + # write_status = get_task_memory_write_result(str(write_id)) + # logger.info(f'[WRITE] Task result - user={actual_end_user_id}') async def term_memory_save(end_user_id, strategy_type, scope): @@ -164,13 +175,24 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) else: config_id = memory_config - write_message_task.delay( - end_user_id, # end_user_id: User ID - redis_messages, # message: JSON string format message list - config_id, # config_id: Configuration ID string - AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" - "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": redis_messages, + "config_id": config_id, + "storage_type": AgentMemory_Long_Term.STORAGE_NEO4J, + "user_rag_memory_id": "" + } ) + # write_message_task.delay( + # end_user_id, # end_user_id: User ID + # redis_messages, # message: JSON string format message list + # config_id, # config_id: Configuration ID string + # AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + # "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) count_store.update_sessions_count(end_user_id, 0, []) diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py index 96ff929a..35ace00d 100644 --- a/api/app/core/memory/pipelines/memory_read.py +++ b/api/app/core/memory/pipelines/memory_read.py @@ -1,8 +1,8 @@ from app.core.memory.enums import SearchStrategy, StorageType from app.core.memory.models.service_models import MemorySearchResult from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline -from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService -from app.core.memory.read_services.query_preprocessor import QueryPreprocessor +from core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService +from core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): diff --git a/api/app/core/memory/read_services/generate_engine/__init__.py b/api/app/core/memory/read_services/generate_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/query_preprocessor.py b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py similarity index 100% rename from api/app/core/memory/read_services/query_preprocessor.py rename to api/app/core/memory/read_services/generate_engine/query_preprocessor.py diff --git a/api/app/core/memory/read_services/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py similarity index 94% rename from api/app/core/memory/read_services/retrieval_summary.py rename to api/app/core/memory/read_services/generate_engine/retrieval_summary.py index 6b166cf2..c46e93f0 100644 --- a/api/app/core/memory/read_services/retrieval_summary.py +++ b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py @@ -8,4 +8,4 @@ class RetrievalSummaryProcessor: @staticmethod def verify(content: str, llm_client: RedBearLLM): - return \ No newline at end of file + return diff --git a/api/app/core/memory/read_services/search_engine/__init__.py b/api/app/core/memory/read_services/search_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/content_search.py b/api/app/core/memory/read_services/search_engine/content_search.py similarity index 99% rename from api/app/core/memory/read_services/content_search.py rename to api/app/core/memory/read_services/search_engine/content_search.py index ef4e90f1..16c23f91 100644 --- a/api/app/core/memory/read_services/content_search.py +++ b/api/app/core/memory/read_services/search_engine/content_search.py @@ -8,7 +8,7 @@ from neo4j import Session from app.core.memory.enums import Neo4jNodeType from app.core.memory.memory_service import MemoryContext from app.core.memory.models.service_models import Memory, MemorySearchResult -from app.core.memory.read_services.result_builder import data_builder_factory +from core.memory.read_services.search_engine.result_builder import data_builder_factory from app.core.models import RedBearEmbeddings from app.core.rag.nlp.search import knowledge_retrieval from app.repositories import knowledge_repository diff --git a/api/app/core/memory/read_services/result_builder.py b/api/app/core/memory/read_services/search_engine/result_builder.py similarity index 100% rename from api/app/core/memory/read_services/result_builder.py rename to api/app/core/memory/read_services/search_engine/result_builder.py diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index bcdc80c7..b74af6b9 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -11,7 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read from app.schemas import FileInput -from app.tasks import write_message_task +from celery_task_scheduler import scheduler class MemoryReadNode(BaseNode): @@ -126,12 +126,23 @@ class MemoryWriteNode(BaseNode): "files": file_info }) - write_message_task.delay( - end_user_id=end_user_id, - message=messages, - config_id=str(self.typed_config.config_id), - storage_type=state["memory_storage_type"], - user_rag_memory_id=state["user_rag_memory_id"] + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": messages, + "config_id": str(self.typed_config.config_id), + "storage_type": state["memory_storage_type"], + "user_rag_memory_id": state["user_rag_memory_id"] + } ) + # write_message_task.delay( + # end_user_id=end_user_id, + # message=messages, + # config_id=str(self.typed_config.config_id), + # storage_type=state["memory_storage_type"], + # user_rag_memory_id=state["user_rag_memory_id"] + # ) return "success" diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index 4cc548f3..7e4ca74a 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -112,12 +112,12 @@ class MemoryWriteResponse(BaseModel): """Response schema for memory write operation. Attributes: - task_id: Celery task ID for status polling - status: Initial task status (PENDING) + task_id: task ID for status polling + status: Initial task status (QUEUED) end_user_id: End user ID the write was submitted for """ - task_id: str = Field(..., description="Celery task ID for polling") - status: str = Field(..., description="Task status: PENDING") + task_id: str = Field(..., description="task ID for polling") + status: str = Field(..., description="Task status: QUEUED") end_user_id: str = Field(..., description="End user ID") diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 330b84ad..82d1c463 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Optional from sqlalchemy.orm import Session +from app.celery_task_scheduler import scheduler from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.logging_config import get_logger @@ -166,20 +167,31 @@ class MemoryAPIService: # Convert to message list format expected by write_message_task messages = message if isinstance(message, list) else [{"role": "user", "content": message}] - from app.tasks import write_message_task - task = write_message_task.delay( + # from app.tasks import write_message_task + # task = write_message_task.delay( + # end_user_id, + # messages, + # config_id, + # storage_type, + # user_rag_memory_id or "", + # ) + task_id = scheduler.push_task( + "app.core.memory.agent.write_message", end_user_id, - messages, - config_id, - storage_type, - user_rag_memory_id or "", + { + "end_user_id": end_user_id, + "message": messages, + "config_id": config_id, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}") + logger.info(f"Memory write task submitted, task_id={task_id} end_user_id={end_user_id}") return { - "task_id": task.id, - "status": "PENDING", + "task_id": task_id, + "status": "QUEUED", "end_user_id": end_user_id, } diff --git a/api/app/tasks.py b/api/app/tasks.py index 92843175..fdc717f5 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ( ElasticSearchVectorFactory, ) -from app.db import get_db, get_db_context +from app.db import get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema @@ -2025,7 +2025,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di end_users = db.query(EndUser).all() if not end_users: logger.info("没有终端用户,跳过遗忘周期") - return {"status": "SUCCESS", "message": "没有终端用户", + return {"status": "SUCCESS", "message": "没有终端用户", "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0}, "duration_seconds": time.time() - start_time} @@ -2039,7 +2039,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # 获取用户配置(自动回退到工作空间默认配置) connected_config = get_end_user_connected_config(str(end_user.id), db) user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) - + if not user_config_id: failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) continue @@ -2048,13 +2048,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di report = await forget_service.trigger_forgetting_cycle( db=db, end_user_id=str(end_user.id), config_id=user_config_id ) - + total_merged += report.get('merged_count', 0) total_failed += report.get('failed_count', 0) processed_users += 1 - + logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") - + except Exception as e: logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) @@ -2801,18 +2801,18 @@ def run_incremental_clustering( 包含任务执行结果的字典 """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.core.logging_config import get_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine - + logger = get_logger(__name__) logger.info( f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, " f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}" ) - + connector = Neo4jConnector() try: engine = LabelPropagationEngine( @@ -2820,12 +2820,12 @@ def run_incremental_clustering( llm_model_id=llm_model_id, embedding_model_id=embedding_model_id, ) - + # 执行增量聚类 await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) - + logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}") - + return { "status": "SUCCESS", "end_user_id": end_user_id, @@ -2836,18 +2836,18 @@ def run_incremental_clustering( raise finally: await connector.close() - + try: loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) result["elapsed_time"] = time.time() - start_time result["task_id"] = self.request.id - + logger.info( f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, " f"elapsed_time={result['elapsed_time']:.2f}s" ) - + return result except Exception as e: elapsed_time = time.time() - start_time diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 5d358f2c..a0fd4791 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -63,6 +63,23 @@ services: networks: - celery + celery-task-scheduler: + image: redbear-mem-open:latest + container_name: celery-task-scheduler + env_file: + - .env + volumes: + - /etc/localtime:/etc/localtime:ro + command: python app/celery_task_scheduler.py + restart: unless-stopped + healthcheck: + test: CMD curl -f 127.0.0.1:8001 || exit 1 + interval: 30s + timeout: 5s + retries: 3 + networks: + - celery + # Celery Beat - scheduler beat: image: redbear-mem-open:latest