refactor(core): migrate memory write tasks to centralized scheduler

This commit is contained in:
Eternity
2026-04-22 16:50:06 +08:00
parent 6f323f2435
commit c5ae82c3c2
15 changed files with 358 additions and 54 deletions

View File

@@ -17,6 +17,7 @@ def _mask_url(url: str) -> str:
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
@@ -29,7 +30,7 @@ if platform.system() == 'Darwin':
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
_broker_url = os.getenv("CELERY_BROKER_URL") or \
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
@@ -66,11 +67,11 @@ celery_app.conf.update(
task_serializer='json',
accept_content=['json'],
result_serializer='json',
# # 时区
# timezone='Asia/Shanghai',
# enable_utc=False,
# 任务追踪
task_track_started=True,
task_ignore_result=False,

View File

@@ -0,0 +1,241 @@
import json
import threading
import time
import redis
from app.core.config import settings
from celery_app import celery_app
from core.logging_config import get_named_logger
logger = get_named_logger("task_scheduler")
STREAM_KEY = "celery_task_stream"
PENDING_HASH = "scheduler:pending_tasks"
TASK_TIMEOUT = 7800
def health_check_server():
import uvicorn
from fastapi import FastAPI
health_app = FastAPI()
@health_app.get("/")
def health():
return scheduler.health()
threading.Thread(
target=uvicorn.run,
kwargs={
"app": health_app,
"host": "0.0.0.0",
"port": 8001,
"log_config": None
},
daemon=True
).start()
logger.info(f"[Health] Server started at http://0.0.0.0:8001")
class RedisTaskScheduler:
def __init__(self):
self.redis = redis.Redis(
host=settings.REDIS_HOST,
port=settings.REDIS_PORT,
db=settings.REDIS_DB_CELERY_BACKEND,
password=settings.REDIS_PASSWORD,
decode_responses=True,
)
self.running = False
self.dispatched = 0
self.errors = 0
self._leader = False
def push_task(self, task_name, user_id, params):
try:
msg_id = self.redis.xadd(
STREAM_KEY,
fields={
"task_name": task_name,
"user_id": user_id,
"params": json.dumps(params),
}
)
return msg_id
except Exception as e:
logger.error("Push task exception %s", e, exc_info=True)
raise e
def _cleanup_finished(self):
pending = self.redis.hgetall(PENDING_HASH)
if not pending:
return
now = time.time()
task_ids = list(pending.keys())
pipe = self.redis.pipeline()
for task_id in task_ids:
pipe.get(f"celery-task-meta-{task_id}")
results = pipe.execute()
cleanup_pipe = self.redis.pipeline()
has_cleanup = False
for task_id, raw_result in zip(task_ids, results):
try:
meta = json.loads(pending[task_id])
lock_key = meta["lock_key"]
dispatched_at = meta.get("dispatched_at", 0)
age = now - dispatched_at
should_cleanup = False
if raw_result is not None:
result_data = json.loads(raw_result)
if result_data.get("status") in ("SUCCESS", "FAILURE", "REVOKED"):
should_cleanup = True
logger.info("Task finished: %s state=%s", task_id, result_data.get("status"))
elif age > TASK_TIMEOUT:
should_cleanup = True
logger.warning(
"Task expired or lost: %s age=%.0fs, force cleanup",
task_id, age,
)
if should_cleanup:
cleanup_pipe.delete(lock_key)
cleanup_pipe.hdel(PENDING_HASH, task_id)
has_cleanup = True
except Exception as e:
logger.error("Cleanup error for %s: %s", task_id, e, exc_info=True)
self.errors += 1
if has_cleanup:
cleanup_pipe.execute()
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data['user_id']
task_name = msg_data['task_name']
params = json.loads(msg_data.get('params', "{}"))
lock_key = f"{task_name}:{user_id}"
try:
task = celery_app.send_task(task_name, kwargs=params)
pipe = self.redis.pipeline()
pipe.set(lock_key, task.id, ex=3600)
pipe.hset(PENDING_HASH, task.id, json.dumps({
"lock_key": lock_key,
"dispatched_at": time.time()
}))
pipe.xdel(STREAM_KEY, msg_id)
pipe.execute()
self.dispatched += 1
logger.info("Task dispatched: %s", task.id)
return True
except Exception as e:
self.errors += 1
logger.error("Task dispatch error for %s: %s", task_name, e, exc_info=True)
return False
def _leader_lock_extend(self, lock, interval=20):
while self._leader:
try:
lock.extend(60)
except redis.exceptions.LockNotOwnedError:
logger.warning("Lost leader lock during extend")
self._leader = False
except Exception as e:
logger.error("Lock extend error: %s", e)
for _ in range(interval):
if not self._leader:
break
time.sleep(1)
def schedule_loop(self):
self.running = True
self._cleanup_finished()
resp = self.redis.xread(
streams={STREAM_KEY: '0-0'},
count=500,
block=5000,
)
if not resp:
return
messages = []
for stream_key, msgs in resp:
messages.extend(msgs)
lock_keys = []
for msg_id, msg_data in messages:
lock_keys.append(f"{msg_data['task_name']}:{msg_data['user_id']}")
pipe = self.redis.pipeline()
for key in lock_keys:
pipe.exists(key)
lock_exists = pipe.execute()
deliver_keys = set()
for (msg_id, msg_data), locked in zip(messages, lock_exists):
user_id = msg_data['user_id']
lock_key = f"{msg_data['task_name']}:{user_id}"
if locked or lock_key in deliver_keys:
continue
key = self._dispatch(msg_id, msg_data)
if key:
deliver_keys.add(lock_key)
time.sleep(0.1)
def run_server(self):
health_check_server()
lock = self.redis.lock(
"scheduler:leader",
timeout=60,
blocking_timeout=10,
thread_local=False
)
while True:
try:
if lock.acquire(blocking=True):
self._leader = True
t = threading.Thread(
target=self._leader_lock_extend,
args=(lock, 20),
daemon=True
)
t.start()
try:
while self._leader:
self.schedule_loop()
finally:
self._leader = False
t.join(timeout=30)
try:
lock.release()
except redis.exceptions.LockNotOwnedError:
pass
self.running = False
else:
time.sleep(5)
except Exception as e:
logger.error("Scheduler exception %s", e, exc_info=True)
time.sleep(5)
def health(self) -> dict:
return {
"running": self.running,
"pending": self.redis.xlen(STREAM_KEY),
"dispatched": self.dispatched,
"errors": self.errors
}
scheduler: RedisTaskScheduler | None = None
if scheduler is None:
scheduler = RedisTaskScheduler()
if __name__ == '__main__':
scheduler.run_server()

View File

@@ -86,7 +86,7 @@ async def write_memory(
user_rag_memory_id=payload.user_rag_memory_id,
)
logger.info(f"Memory write task submitted: task_id={result['task_id']}, end_user_id: {payload.end_user_id}")
logger.info(f"Memory write task submitted: end_user_id: {payload.end_user_id}")
return success(data=MemoryWriteResponse(**result).model_dump(), msg="Memory write task submitted")

View File

@@ -12,9 +12,8 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
from celery_task_scheduler import scheduler
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
@@ -86,16 +85,28 @@ async def write(
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# write_id = write_message_task.delay(
# actual_end_user_id, # end_user_id: User ID
# structured_messages, # message: JSON string format message list
# str(actual_config_id), # config_id: Configuration ID string
# storage_type, # storage_type: "neo4j"
# user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
{
"end_user_id": end_user_id,
"message": structured_messages,
"config_id": str(actual_config_id),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
# logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
# write_status = get_task_memory_write_result(str(write_id))
# logger.info(f'[WRITE] Task result - user={actual_end_user_id}')
async def term_memory_save(end_user_id, strategy_type, scope):
@@ -164,13 +175,24 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
else:
config_id = memory_config
write_message_task.delay(
end_user_id, # end_user_id: User ID
redis_messages, # message: JSON string format message list
config_id, # config_id: Configuration ID string
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
{
"end_user_id": end_user_id,
"message": redis_messages,
"config_id": config_id,
"storage_type": AgentMemory_Long_Term.STORAGE_NEO4J,
"user_rag_memory_id": ""
}
)
# write_message_task.delay(
# end_user_id, # end_user_id: User ID
# redis_messages, # message: JSON string format message list
# config_id, # config_id: Configuration ID string
# AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
# "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
# )
count_store.update_sessions_count(end_user_id, 0, [])

View File

@@ -1,8 +1,8 @@
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
from core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
from core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):

View File

@@ -8,7 +8,7 @@ from neo4j import Session
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.memory_service import MemoryContext
from app.core.memory.models.service_models import Memory, MemorySearchResult
from app.core.memory.read_services.result_builder import data_builder_factory
from core.memory.read_services.search_engine.result_builder import data_builder_factory
from app.core.models import RedBearEmbeddings
from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository

View File

@@ -11,7 +11,7 @@ from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read
from app.schemas import FileInput
from app.tasks import write_message_task
from celery_task_scheduler import scheduler
class MemoryReadNode(BaseNode):
@@ -126,12 +126,23 @@ class MemoryWriteNode(BaseNode):
"files": file_info
})
write_message_task.delay(
end_user_id=end_user_id,
message=messages,
config_id=str(self.typed_config.config_id),
storage_type=state["memory_storage_type"],
user_rag_memory_id=state["user_rag_memory_id"]
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
{
"end_user_id": end_user_id,
"message": messages,
"config_id": str(self.typed_config.config_id),
"storage_type": state["memory_storage_type"],
"user_rag_memory_id": state["user_rag_memory_id"]
}
)
# write_message_task.delay(
# end_user_id=end_user_id,
# message=messages,
# config_id=str(self.typed_config.config_id),
# storage_type=state["memory_storage_type"],
# user_rag_memory_id=state["user_rag_memory_id"]
# )
return "success"

View File

@@ -10,6 +10,7 @@ from typing import Any, Dict, Optional
from sqlalchemy.orm import Session
from app.celery_task_scheduler import scheduler
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_logger
@@ -166,19 +167,30 @@ class MemoryAPIService:
# Convert to message list format expected by write_message_task
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
from app.tasks import write_message_task
task = write_message_task.delay(
# from app.tasks import write_message_task
# task = write_message_task.delay(
# end_user_id,
# messages,
# config_id,
# storage_type,
# user_rag_memory_id or "",
# )
scheduler.push_task(
"app.core.memory.agent.write_message",
end_user_id,
messages,
config_id,
storage_type,
user_rag_memory_id or "",
{
"end_user_id": end_user_id,
"message": messages,
"config_id": config_id,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id or ""
}
)
logger.info(f"Memory write task submitted: task_id={task.id}, end_user_id={end_user_id}")
logger.info(f"Memory write task submitted, end_user_id={end_user_id}")
return {
"task_id": task.id,
# "task_id": task.id,
"status": "PENDING",
"end_user_id": end_user_id,
}

View File

@@ -34,7 +34,7 @@ from app.core.rag.prompts.generator import question_proposal
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
ElasticSearchVectorFactory,
)
from app.db import get_db, get_db_context
from app.db import get_db_context
from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema
@@ -1993,7 +1993,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
end_users = db.query(EndUser).all()
if not end_users:
logger.info("没有终端用户,跳过遗忘周期")
return {"status": "SUCCESS", "message": "没有终端用户",
return {"status": "SUCCESS", "message": "没有终端用户",
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
"duration_seconds": time.time() - start_time}
@@ -2007,7 +2007,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
# 获取用户配置(自动回退到工作空间默认配置)
connected_config = get_end_user_connected_config(str(end_user.id), db)
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
if not user_config_id:
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
continue
@@ -2016,13 +2016,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
report = await forget_service.trigger_forgetting_cycle(
db=db, end_user_id=str(end_user.id), config_id=user_config_id
)
total_merged += report.get('merged_count', 0)
total_failed += report.get('failed_count', 0)
processed_users += 1
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
except Exception as e:
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
@@ -2769,18 +2769,18 @@ def run_incremental_clustering(
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
logger = get_logger(__name__)
logger.info(
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
)
connector = Neo4jConnector()
try:
engine = LabelPropagationEngine(
@@ -2788,12 +2788,12 @@ def run_incremental_clustering(
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
# 执行增量聚类
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
return {
"status": "SUCCESS",
"end_user_id": end_user_id,
@@ -2804,18 +2804,18 @@ def run_incremental_clustering(
raise
finally:
await connector.close()
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
logger.info(
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
f"elapsed_time={result['elapsed_time']:.2f}s"
)
return result
except Exception as e:
elapsed_time = time.time() - start_time

View File

@@ -63,6 +63,23 @@ services:
networks:
- celery
celery-task-scheduler:
image: redbear-mem-open:latest
container_name: celery-task-scheduler
env_file:
- .env
volumes:
- /etc/localtime:/etc/localtime:ro
command: python app/celery_task_scheduler.py
restart: unless-stopped
healthcheck:
test: CMD curl -f 127.0.0.1:8001 || exit 1
interval: 30s
timeout: 5s
retries: 3
networks:
- celery
# Celery Beat - scheduler
beat:
image: redbear-mem-open:latest