"""Leader-elected dispatcher that feeds queued tasks from a Redis stream to Celery.

Producers call ``scheduler.push_task(...)`` to append an entry to ``STREAM_KEY``.
Exactly one scheduler process (the holder of the ``scheduler:leader`` Redis lock)
drains the stream and sends each entry to Celery, allowing at most one in-flight
task per ``"<task_name>:<user_id>"`` key.  In-flight tasks are tracked in
``PENDING_HASH`` and released once their Celery result reaches a terminal state
or they exceed ``TASK_TIMEOUT``.
"""
import json
import threading
import time

import redis

# Every project import uses the ``app.`` root: mixing ``celery_app`` / ``core.*``
# with ``app.*`` loads the same files twice under different module names.
from app.celery_app import celery_app
from app.core.config import settings
from app.core.logging_config import get_named_logger

logger = get_named_logger("task_scheduler")

STREAM_KEY = "celery_task_stream"
PENDING_HASH = "scheduler:pending_tasks"
TASK_TIMEOUT = 7800  # seconds before a tracked task is considered lost

# NOTE(review): the per-user lock expires before TASK_TIMEOUT, so a task running
# longer than an hour no longer blocks the next task for the same user -- confirm
# this is intended.  The value is unchanged from the original code.
TASK_LOCK_TTL = 3600
LEADER_LOCK_KEY = "scheduler:leader"
LEADER_LOCK_TTL = 60  # seconds; refreshed periodically while leading
READ_BATCH = 500  # stream entries fetched per XREAD
READ_BLOCK_MS = 5000  # how long the first XREAD waits on an empty stream
MAX_SCAN_PER_LOOP = 5000  # upper bound on entries examined per schedule_loop()
BLOCKED_BACKOFF = 1.0  # pause when every queued entry is behind a held lock
DISPATCH_PAUSE = 0.1  # pause after a productive pass
TERMINAL_STATES = frozenset({"SUCCESS", "FAILURE", "REVOKED"})


def health_check_server(host: str = "0.0.0.0", port: int = 8001):
    """Start a background HTTP server exposing ``scheduler.health()`` at ``/``.

    The server runs in a daemon thread so it never keeps the process alive.

    Args:
        host: Interface to bind.
        port: TCP port to listen on.
    """
    # Imported lazily so producers that only import ``scheduler`` to call
    # push_task() do not need the web stack at import time.
    import uvicorn
    from fastapi import FastAPI

    health_app = FastAPI()

    @health_app.get("/")
    def health():
        return scheduler.health()

    threading.Thread(
        target=uvicorn.run,
        kwargs={
            "app": health_app,
            "host": host,
            "port": port,
            "log_config": None,
        },
        daemon=True,
    ).start()
    logger.info("[Health] Server started at http://%s:%s", host, port)


class RedisTaskScheduler:
    """Queue tasks in a Redis stream and dispatch them to Celery, one per user key."""

    def __init__(self):
        # Uses the Celery *result backend* database because _cleanup_finished()
        # reads the ``celery-task-meta-<id>`` keys stored there.
        self.redis = redis.Redis(
            host=settings.REDIS_HOST,
            port=settings.REDIS_PORT,
            db=settings.REDIS_DB_CELERY_BACKEND,
            password=settings.REDIS_PASSWORD,
            decode_responses=True,
        )
        self.running = False
        self.dispatched = 0
        self.errors = 0
        self._leader = False

    def push_task(self, task_name, user_id, params):
        """Append a task request to the stream.

        Args:
            task_name: Registered Celery task name.
            user_id: Owner of the task; tasks sharing ``task_name`` and
                ``user_id`` are dispatched one at a time.
            params: Keyword arguments for the task (must be JSON-serialisable;
                values JSON cannot encode natively are stringified).

        Returns:
            The stream entry id assigned by Redis.

        Raises:
            ValueError: If ``task_name`` or ``user_id`` is missing.
            Exception: Any Redis or serialisation error is logged and re-raised.
        """
        if not task_name or user_id is None or user_id == "":
            raise ValueError("push_task requires a task_name and a user_id")
        try:
            return self.redis.xadd(
                STREAM_KEY,
                fields={
                    "task_name": str(task_name),
                    # Stream fields must be str/bytes/number; callers may hand
                    # over UUID objects.
                    "user_id": str(user_id),
                    # default=str is deliberate: values such as UUIDs were
                    # accepted when these calls went through ``.delay()``.
                    "params": json.dumps(params, default=str),
                },
            )
        except Exception as e:
            logger.error("Push task exception %s", e, exc_info=True)
            raise

    def _result_state(self, task_id, raw_result):
        """Return the ``status`` field of a stored Celery result, or None if unknown."""
        if raw_result is None:
            return None
        try:
            return json.loads(raw_result).get("status")
        except (TypeError, ValueError, AttributeError) as e:
            logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
            self.errors += 1
            return None

    def _cleanup_finished(self):
        """Release bookkeeping for tasks that have finished or timed out.

        A tracked task is released when its Celery result is in a terminal state,
        or when it has been tracked for longer than ``TASK_TIMEOUT`` regardless of
        any non-terminal state recorded for it.  The per-user lock is deleted only
        if it still names this task, so a lock taken by a newer task survives.
        """
        pending = self.redis.hgetall(PENDING_HASH)
        if not pending:
            return

        now = time.time()
        tracked = []  # (task_id, lock_key, dispatched_at)
        corrupt = []
        for task_id, raw_meta in pending.items():
            try:
                meta = json.loads(raw_meta)
                tracked.append(
                    (task_id, meta["lock_key"], float(meta.get("dispatched_at", 0)))
                )
            except (TypeError, ValueError, KeyError, AttributeError) as e:
                logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
                self.errors += 1
                # The entry can never be interpreted; drop it instead of
                # re-logging it on every pass.
                corrupt.append(task_id)

        replies = []
        if tracked:
            pipe = self.redis.pipeline()
            for task_id, lock_key, _ in tracked:
                pipe.get(f"celery-task-meta-{task_id}")
                pipe.get(lock_key)
            replies = pipe.execute()

        cleanup_pipe = self.redis.pipeline()
        has_cleanup = False
        for task_id in corrupt:
            cleanup_pipe.hdel(PENDING_HASH, task_id)
            has_cleanup = True

        # replies alternate: result payload, current lock owner.
        for (task_id, lock_key, dispatched_at), raw_result, lock_owner in zip(
            tracked, replies[0::2], replies[1::2]
        ):
            state = self._result_state(task_id, raw_result)
            age = now - dispatched_at
            if state in TERMINAL_STATES:
                logger.info("Task finished: %s state=%s", task_id, state)
            elif age > TASK_TIMEOUT:
                logger.warning(
                    "Task expired or lost: %s age=%.0fs, force cleanup",
                    task_id, age,
                )
            else:
                continue

            if lock_owner == task_id:
                cleanup_pipe.delete(lock_key)
            cleanup_pipe.hdel(PENDING_HASH, task_id)
            has_cleanup = True

        if has_cleanup:
            cleanup_pipe.execute()

    def _dispatch(self, msg_id, msg_data) -> bool:
        """Send one stream entry to Celery and record it as in flight.

        Args:
            msg_id: Stream entry id.
            msg_data: Entry fields (``task_name``, ``user_id``, ``params``).

        Returns:
            True if the task was sent and recorded; False otherwise.  Malformed
            entries are removed from the stream because they can never succeed.
        """
        try:
            user_id = msg_data['user_id']
            task_name = msg_data['task_name']
            params = json.loads(msg_data.get('params', "{}"))
            if not isinstance(params, dict):
                raise TypeError(
                    f"params must decode to an object, got {type(params).__name__}"
                )
        except (KeyError, TypeError, ValueError) as e:
            self.errors += 1
            logger.error("Dropping malformed stream message %s: %s", msg_id, e)
            self.redis.xdel(STREAM_KEY, msg_id)
            return False

        lock_key = f"{task_name}:{user_id}"
        try:
            task = celery_app.send_task(task_name, kwargs=params)
            # Lock, tracking entry and stream removal are queued on one pipeline
            # so they are applied together.
            pipe = self.redis.pipeline()
            pipe.set(lock_key, task.id, ex=TASK_LOCK_TTL)
            pipe.hset(PENDING_HASH, task.id, json.dumps({
                "lock_key": lock_key,
                "dispatched_at": time.time(),
            }))
            pipe.xdel(STREAM_KEY, msg_id)
            pipe.execute()
        except Exception as e:
            self.errors += 1
            logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True)
            return False

        self.dispatched += 1
        logger.info("Task dispatched: %s", task.id)
        return True

    def _leader_lock_extend(self, lock, interval=20):
        """Keep the leader lock alive until leadership is dropped or lost.

        Args:
            lock: The acquired leader lock.
            interval: Seconds between refreshes (checked once per second so the
                thread exits promptly when leadership ends).
        """
        while self._leader:
            try:
                # replace_ttl=True resets the TTL; the default *adds* to the
                # remaining TTL, which would grow without bound and delay
                # failover after a crash.
                lock.extend(LEADER_LOCK_TTL, replace_ttl=True)
            except redis.exceptions.LockError:
                # Covers LockNotOwnedError and "unlocked lock": either way this
                # process must stop acting as leader.
                logger.warning("Lost leader lock during extend")
                self._leader = False
            except Exception as e:
                logger.error("Lock extend error: %s", e)
            for _ in range(interval):
                if not self._leader:
                    break
                time.sleep(1)

    @staticmethod
    def _lock_key(msg_data):
        """Return the per-user lock key for a stream entry, or None if malformed."""
        try:
            return f"{msg_data['task_name']}:{msg_data['user_id']}"
        except (KeyError, TypeError):
            return None

    def _dispatch_batch(self, messages, busy_keys) -> int:
        """Dispatch every entry in ``messages`` whose lock key is free.

        Args:
            messages: ``(msg_id, fields)`` pairs in stream order.
            busy_keys: Lock keys already known to be held; updated in place with
                keys found locked and keys dispatched here.

        Returns:
            Number of entries dispatched.
        """
        keyed = [(msg_id, data, self._lock_key(data)) for msg_id, data in messages]

        to_probe = list({
            key for _, _, key in keyed if key is not None and key not in busy_keys
        })
        if to_probe:
            pipe = self.redis.pipeline()
            for key in to_probe:
                pipe.exists(key)
            busy_keys.update(
                key for key, held in zip(to_probe, pipe.execute()) if held
            )

        sent = 0
        for msg_id, data, key in keyed:
            if key in busy_keys:
                continue
            # Malformed entries (key is None) go through _dispatch, which drops them.
            if self._dispatch(msg_id, data):
                busy_keys.add(key)
                sent += 1
        return sent

    def schedule_loop(self):
        """Run one scheduling pass: clean up finished tasks, then drain the stream.

        The stream is paged from the beginning so a backlog for one locked user
        cannot hide entries for other users; at most ``MAX_SCAN_PER_LOOP`` entries
        are examined per pass.  Only the first read blocks (when the stream is
        empty).
        """
        self.running = True
        self._cleanup_finished()

        cursor = "0-0"
        busy_keys = set()
        scanned = 0
        sent = 0
        while scanned < MAX_SCAN_PER_LOOP:
            resp = self.redis.xread(
                streams={STREAM_KEY: cursor},
                count=READ_BATCH,
                block=READ_BLOCK_MS if scanned == 0 else None,
            )
            messages = [entry for _stream, entries in (resp or []) for entry in entries]
            if not messages:
                break
            scanned += len(messages)
            cursor = messages[-1][0]
            sent += self._dispatch_batch(messages, busy_keys)
            if len(messages) < READ_BATCH:
                break

        if not scanned:
            return  # the blocking read already waited
        # Entries that stay queued behind held locks make XREAD return at once;
        # back off so the loop does not spin on Redis.
        time.sleep(DISPATCH_PAUSE if sent else BLOCKED_BACKOFF)

    def _lead(self, lock):
        """Act as leader until leadership is lost or a scheduling pass raises."""
        self._leader = True
        extender = threading.Thread(
            target=self._leader_lock_extend,
            args=(lock, 20),
            daemon=True,
        )
        extender.start()
        try:
            while self._leader:
                self.schedule_loop()
        finally:
            self._leader = False
            self.running = False
            extender.join(timeout=30)
            try:
                lock.release()
            except redis.exceptions.LockError:
                # Already expired or taken over; nothing left to release.
                pass

    def run_server(self):
        """Start the health endpoint and compete for leadership forever."""
        health_check_server()

        lock = self.redis.lock(
            LEADER_LOCK_KEY,
            timeout=LEADER_LOCK_TTL,
            blocking_timeout=10,
            # The lock is acquired here but extended from another thread.
            thread_local=False,
        )
        while True:
            try:
                if lock.acquire(blocking=True):
                    self._lead(lock)
                else:
                    time.sleep(5)
            except Exception as e:
                logger.error("Scheduler exception %s", e, exc_info=True)
                time.sleep(5)

    def health(self) -> dict:
        """Return scheduler counters and the current stream length.

        Raises:
            redis.RedisError: If Redis cannot be reached (the endpoint then fails,
                which marks the container unhealthy).
        """
        return {
            "running": self.running,
            "leader": self._leader,
            "pending": self.redis.xlen(STREAM_KEY),
            "dispatched": self.dispatched,
            "errors": self.errors,
        }


# Module-level singleton shared by producers (push_task) and the server entry point.
scheduler: RedisTaskScheduler = RedisTaskScheduler()

if __name__ == '__main__':
    scheduler.run_server()
a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,9 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id +from celery_task_scheduler import scheduler logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') @@ -86,16 +85,28 @@ async def write( logger.info( f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: User ID - structured_messages, # message: JSON string format message list - str(actual_config_id), # config_id: Configuration ID string - storage_type, # storage_type: "neo4j" - user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # write_id = write_message_task.delay( + # actual_end_user_id, # end_user_id: User ID + # structured_messages, # message: JSON string format message list + # str(actual_config_id), # config_id: Configuration ID string + # storage_type, # storage_type: "neo4j" + # user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": structured_messages, + "config_id": str(actual_config_id), + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = 
get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + + # logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + # write_status = get_task_memory_write_result(str(write_id)) + # logger.info(f'[WRITE] Task result - user={actual_end_user_id}') async def term_memory_save(end_user_id, strategy_type, scope): @@ -164,13 +175,24 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) else: config_id = memory_config - write_message_task.delay( - end_user_id, # end_user_id: User ID - redis_messages, # message: JSON string format message list - config_id, # config_id: Configuration ID string - AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" - "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": redis_messages, + "config_id": config_id, + "storage_type": AgentMemory_Long_Term.STORAGE_NEO4J, + "user_rag_memory_id": "" + } ) + # write_message_task.delay( + # end_user_id, # end_user_id: User ID + # redis_messages, # message: JSON string format message list + # config_id, # config_id: Configuration ID string + # AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + # "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) + # ) count_store.update_sessions_count(end_user_id, 0, []) diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py index 96ff929a..35ace00d 100644 --- a/api/app/core/memory/pipelines/memory_read.py +++ b/api/app/core/memory/pipelines/memory_read.py @@ -1,8 +1,8 @@ from app.core.memory.enums import SearchStrategy, StorageType from app.core.memory.models.service_models import MemorySearchResult from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline -from 
app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService -from app.core.memory.read_services.query_preprocessor import QueryPreprocessor +from core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService +from core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): diff --git a/api/app/core/memory/read_services/generate_engine/__init__.py b/api/app/core/memory/read_services/generate_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/query_preprocessor.py b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py similarity index 100% rename from api/app/core/memory/read_services/query_preprocessor.py rename to api/app/core/memory/read_services/generate_engine/query_preprocessor.py diff --git a/api/app/core/memory/read_services/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py similarity index 100% rename from api/app/core/memory/read_services/retrieval_summary.py rename to api/app/core/memory/read_services/generate_engine/retrieval_summary.py diff --git a/api/app/core/memory/read_services/search_engine/__init__.py b/api/app/core/memory/read_services/search_engine/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/core/memory/read_services/content_search.py b/api/app/core/memory/read_services/search_engine/content_search.py similarity index 99% rename from api/app/core/memory/read_services/content_search.py rename to api/app/core/memory/read_services/search_engine/content_search.py index ef4e90f1..16c23f91 100644 --- a/api/app/core/memory/read_services/content_search.py +++ b/api/app/core/memory/read_services/search_engine/content_search.py @@ -8,7 +8,7 @@ from neo4j import Session from app.core.memory.enums import Neo4jNodeType from app.core.memory.memory_service 
import MemoryContext from app.core.memory.models.service_models import Memory, MemorySearchResult -from app.core.memory.read_services.result_builder import data_builder_factory +from core.memory.read_services.search_engine.result_builder import data_builder_factory from app.core.models import RedBearEmbeddings from app.core.rag.nlp.search import knowledge_retrieval from app.repositories import knowledge_repository diff --git a/api/app/core/memory/read_services/result_builder.py b/api/app/core/memory/read_services/search_engine/result_builder.py similarity index 100% rename from api/app/core/memory/read_services/result_builder.py rename to api/app/core/memory/read_services/search_engine/result_builder.py diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index bcdc80c7..b74af6b9 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -11,7 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.db import get_db_read from app.schemas import FileInput -from app.tasks import write_message_task +from celery_task_scheduler import scheduler class MemoryReadNode(BaseNode): @@ -126,12 +126,23 @@ class MemoryWriteNode(BaseNode): "files": file_info }) - write_message_task.delay( - end_user_id=end_user_id, - message=messages, - config_id=str(self.typed_config.config_id), - storage_type=state["memory_storage_type"], - user_rag_memory_id=state["user_rag_memory_id"] + scheduler.push_task( + "app.core.memory.agent.write_message", + end_user_id, + { + "end_user_id": end_user_id, + "message": messages, + "config_id": str(self.typed_config.config_id), + "storage_type": state["memory_storage_type"], + "user_rag_memory_id": state["user_rag_memory_id"] + } ) + # write_message_task.delay( + # end_user_id=end_user_id, + # message=messages, + # 
config_id=str(self.typed_config.config_id), + # storage_type=state["memory_storage_type"], + # user_rag_memory_id=state["user_rag_memory_id"] + # ) return "success" diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 330b84ad..221ca4cf 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -10,6 +10,7 @@ from typing import Any, Dict, Optional from sqlalchemy.orm import Session +from app.celery_task_scheduler import scheduler from app.core.error_codes import BizCode from app.core.exceptions import BusinessException, ResourceNotFoundException from app.core.logging_config import get_logger @@ -166,19 +167,30 @@ class MemoryAPIService: # Convert to message list format expected by write_message_task messages = message if isinstance(message, list) else [{"role": "user", "content": message}] - from app.tasks import write_message_task - task = write_message_task.delay( + # from app.tasks import write_message_task + # task = write_message_task.delay( + # end_user_id, + # messages, + # config_id, + # storage_type, + # user_rag_memory_id or "", + # ) + scheduler.push_task( + "app.core.memory.agent.write_message", end_user_id, - messages, - config_id, - storage_type, - user_rag_memory_id or "", + { + "end_user_id": end_user_id, + "message": messages, + "config_id": config_id, + "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id or "" + } ) - logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}") + logger.info(f"Memory write task submitted, end_user_id={end_user_id}") return { - "task_id": task.id, + # "task_id": task.id, "status": "PENDING", "end_user_id": end_user_id, } diff --git a/api/app/tasks.py b/api/app/tasks.py index 8bbbdc6e..5978af4d 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal from app.core.rag.vdb.elasticsearch.elasticsearch_vector 
import ( ElasticSearchVectorFactory, ) -from app.db import get_db, get_db_context +from app.db import get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema @@ -1993,7 +1993,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di end_users = db.query(EndUser).all() if not end_users: logger.info("没有终端用户,跳过遗忘周期") - return {"status": "SUCCESS", "message": "没有终端用户", + return {"status": "SUCCESS", "message": "没有终端用户", "report": {"merged_count": 0, "failed_count": 0, "processed_users": 0}, "duration_seconds": time.time() - start_time} @@ -2007,7 +2007,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # 获取用户配置(自动回退到工作空间默认配置) connected_config = get_end_user_connected_config(str(end_user.id), db) user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db) - + if not user_config_id: failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"}) continue @@ -2016,13 +2016,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di report = await forget_service.trigger_forgetting_cycle( db=db, end_user_id=str(end_user.id), config_id=user_config_id ) - + total_merged += report.get('merged_count', 0) total_failed += report.get('failed_count', 0) processed_users += 1 - + logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点") - + except Exception as e: logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True) failed_users.append({"end_user_id": str(end_user.id), "error": str(e)}) @@ -2769,18 +2769,18 @@ def run_incremental_clustering( 包含任务执行结果的字典 """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.core.logging_config import get_logger from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine - + 
# Redis-stream task dispatcher (runs app/celery_task_scheduler.py).
  celery-task-scheduler:
    image: redbear-mem-open:latest
    container_name: celery-task-scheduler
    env_file:
      - .env
    volumes:
      - /etc/localtime:/etc/localtime:ro
    command: python app/celery_task_scheduler.py
    restart: unless-stopped
    healthcheck:
      # A plain string is executed through the shell, so a leading "CMD" would be
      # run as a program name and the check would always fail; use the explicit
      # CMD-SHELL list form.  NOTE(review): requires curl inside the image.
      test: ["CMD-SHELL", "curl -f http://127.0.0.1:8001/ || exit 1"]
      interval: 30s
      timeout: 5s
      retries: 3
    networks:
      - celery