Merge pull request #687 from SuanmoSuanyangTechnology/fix/forget-celery

[fix] Fix the forgetting-cycle periodic task
This commit is contained in:
Ke Sun
2026-03-26 13:48:32 +08:00
committed by GitHub
2 changed files with 69 additions and 52 deletions

View File

@@ -315,6 +315,12 @@ class MemoryForgetService:
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
# 如果参数为 None,使用配置中的默认值
if max_merge_batch_size is None:
max_merge_batch_size = config.get('max_merge_batch_size', 100)
if min_days_since_access is None:
min_days_since_access = config.get('min_days_since_access', 30)
# 记录执行开始时间
execution_time = datetime.now()

View File

@@ -36,9 +36,11 @@ from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
)
from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
from app.services.memory_forget_service import MemoryForgetService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock
@@ -1860,7 +1862,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
@celery_app.task(
name="app.tasks.run_forgetting_cycle_task",
bind=True,
ignore_result=True,
ignore_result=False, # 改为 False 以便在 Flower 中查看结果
max_retries=0,
acks_late=False,
time_limit=7200,
@@ -1868,68 +1870,77 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
)
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
"""定时任务:运行遗忘周期
定期执行遗忘周期,识别并融合低激活值的知识节点
Args:
config_id: 配置ID(可选),如果为None则使用默认配置
Returns:
包含任务执行结果的字典
遍历所有终端用户,执行遗忘周期
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.services.memory_forget_service import MemoryForgetService
async def _process_users() -> Dict[str, Any]:
with get_db_context() as db:
try:
logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}")
end_users = db.query(EndUser).all()
if not end_users:
logger.info("没有终端用户,跳过遗忘周期")
return {"status": "SUCCESS", "message": "没有终端用户",
"report": {"merged_count": 0, "failed_count": 0, "processed_users": 0},
"duration_seconds": time.time() - start_time}
forget_service = MemoryForgetService()
logger.info(f"开始处理 {len(end_users)} 个终端用户的遗忘周期")
forget_service = MemoryForgetService()
total_merged = total_failed = processed_users = 0
failed_users = []
# 运行遗忘周期
# FIXME: MemoryForgetService
report = await forget_service.trigger_forgetting(
db=db,
end_user_id=None, # 处理所有组
config_id=config_id
)
for end_user in end_users:
try:
# 获取用户配置(自动回退到工作空间默认配置)
connected_config = get_end_user_connected_config(str(end_user.id), db)
user_config_id = resolve_config_id(connected_config.get("memory_config_id"), db)
if not user_config_id:
failed_users.append({"end_user_id": str(end_user.id), "error": "无法获取配置"})
continue
duration = time.time() - start_time
# 执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db, end_user_id=str(end_user.id), config_id=user_config_id
)
total_merged += report.get('merged_count', 0)
total_failed += report.get('failed_count', 0)
processed_users += 1
logger.info(f"用户 {end_user.id}: 融合 {report.get('merged_count', 0)} 对节点")
except Exception as e:
logger.error(f"处理用户 {end_user.id} 失败: {e}", exc_info=True)
failed_users.append({"end_user_id": str(end_user.id), "error": str(e)})
logger.info(
f"遗忘周期定时任务完成: "
f"融合 {report['merged_count']} 对节点, "
f"失败 {report['failed_count']} 对, "
f"耗时 {duration:.2f}"
)
duration = time.time() - start_time
logger.info(f"遗忘周期完成: {processed_users}/{len(end_users)} 用户, "
f"融合 {total_merged} 对, 耗时 {duration:.2f}s")
return {
"status": "SUCCESS",
"message": "遗忘周期执行成功",
"report": report,
"duration_seconds": duration
}
except Exception as e:
duration = time.time() - start_time
logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
return {
"status": "FAILED",
"message": f"遗忘周期执行失败: {str(e)}",
"duration_seconds": duration
}
return {
"status": "SUCCESS",
"message": f"处理 {processed_users} 个用户",
"report": {
"merged_count": total_merged,
"failed_count": total_failed,
"processed_users": processed_users,
"total_users": len(end_users),
"failed_users": failed_users
},
"duration_seconds": duration
}
# 运行异步函数
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
return result
finally:
loop.close()
return asyncio.run(_process_users())
except Exception as e:
logger.error(f"遗忘周期任务失败: {e}", exc_info=True)
return {
"status": "FAILED",
"message": f"任务失败: {str(e)}",
"duration_seconds": time.time() - start_time
}
# =============================================================================