@celery_app.task(
    name="app.tasks.run_forgetting_cycle_task",
    bind=True,
    ignore_result=False,  # keep the result so it can be inspected in Flower
    max_retries=0,
    acks_late=False,
    time_limit=7200,
    # NOTE(review): the reviewed diff elides exactly one line of this decorator
    # call here (it falls between two hunks) — presumably another keyword
    # argument. Restore it from app/tasks.py before merging.
)
def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]:
    """Scheduled task: run the forgetting cycle for every end user.

    Each end user is processed independently; a failure for one user is
    recorded and does not stop the remaining users.

    Args:
        config_id: Optional configuration ID. When given, it is used for every
            end user instead of the user's own connected configuration. When
            None, each user's configuration is looked up individually.

    Returns:
        A dict with ``status`` ("SUCCESS" or "FAILED"), ``message`` and
        ``duration_seconds``. Unless the task aborted with an unexpected
        error, it also contains ``report`` with the aggregated
        ``merged_count``, ``failed_count``, ``processed_users``,
        ``total_users`` and the ``failed_users`` list.
    """
    # Function-scope imports keep this block self-contained and avoid relying
    # on module import order.
    from app.models.end_user_model import EndUser
    from app.services.memory_agent_service import get_end_user_connected_config
    from app.services.memory_forget_service import MemoryForgetService

    start_time = time.time()

    async def _process_users() -> Dict[str, Any]:
        with get_db_context() as db:
            # Only the IDs are needed. Snapshotting them as plain strings also
            # means no ORM attribute is touched after a rollback.
            end_user_ids = [str(row.id) for row in db.query(EndUser.id).all()]

            if not end_user_ids:
                logger.info("没有终端用户,跳过遗忘周期")
                return {
                    "status": "SUCCESS",
                    "message": "没有终端用户",
                    # Same keys as the regular report so consumers can rely
                    # on a single shape.
                    "report": {
                        "merged_count": 0,
                        "failed_count": 0,
                        "processed_users": 0,
                        "total_users": 0,
                        "failed_users": [],
                    },
                    "duration_seconds": time.time() - start_time,
                }

            logger.info(f"开始处理 {len(end_user_ids)} 个终端用户的遗忘周期")
            forget_service = MemoryForgetService()
            total_merged = total_failed = processed_users = 0
            failed_users = []

            for end_user_id in end_user_ids:
                try:
                    if config_id is not None:
                        # An explicit task argument overrides per-user lookup.
                        raw_config_id = config_id
                    else:
                        connected_config = get_end_user_connected_config(end_user_id, db)
                        raw_config_id = connected_config.get("memory_config_id")
                    user_config_id = resolve_config_id(raw_config_id, db)

                    if not user_config_id:
                        failed_users.append({"end_user_id": end_user_id, "error": "无法获取配置"})
                        continue

                    report = await forget_service.trigger_forgetting_cycle(
                        db=db, end_user_id=end_user_id, config_id=user_config_id
                    )

                    merged = report.get('merged_count', 0)
                    total_merged += merged
                    total_failed += report.get('failed_count', 0)
                    processed_users += 1

                    logger.info(f"用户 {end_user_id}: 融合 {merged} 对节点")

                except Exception as e:
                    logger.error(f"处理用户 {end_user_id} 失败: {e}", exc_info=True)
                    failed_users.append({"end_user_id": end_user_id, "error": str(e)})
                    # The session is shared by all users: reset it so a failed
                    # transaction cannot break every following user.
                    try:
                        db.rollback()
                    except Exception:
                        logger.warning(f"用户 {end_user_id} 回滚会话失败", exc_info=True)

            duration = time.time() - start_time
            logger.info(f"遗忘周期完成: {processed_users}/{len(end_user_ids)} 用户, "
                        f"融合 {total_merged} 对, 耗时 {duration:.2f}s")

            # A run in which no user could be processed must not be reported
            # as a success.
            status = "FAILED" if failed_users and not processed_users else "SUCCESS"

            return {
                "status": status,
                "message": f"处理 {processed_users} 个用户",
                "report": {
                    "merged_count": total_merged,
                    "failed_count": total_failed,
                    "processed_users": processed_users,
                    "total_users": len(end_user_ids),
                    "failed_users": failed_users,
                },
                "duration_seconds": duration,
            }

    # Run the coroutine on a fresh event loop; asyncio.run also closes it.
    try:
        return asyncio.run(_process_users())
    except Exception as e:
        logger.error(f"遗忘周期任务失败: {e}", exc_info=True)
        return {
            "status": "FAILED",
            "message": f"任务失败: {str(e)}",
            "duration_seconds": time.time() - start_time,
        }


# =============================================================================