From 16cf6eee9bf205eac5ec594f7a2a2b542ea56c79 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:37:03 +0800 Subject: [PATCH] Fix/develop memory bug (#350) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 遗漏的历史映射 * 遗漏的历史映射 * fix_timeline_memories * fix_timeline_memories * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * write_gragp/bug_fix * Multiple independent transactions - single transaction * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * memory_content ->memory_config_id * tasks/bug_fix/long * tasks_reflection/bug/fix * tasks_reflection/bug/fix * tasks_reflection/bug/fix * tasks_reflection/bug/fix --- api/app/services/memory_reflection_service.py | 7 + api/app/tasks.py | 344 ++++++++++-------- 2 files changed, 190 insertions(+), 161 deletions(-) diff --git a/api/app/services/memory_reflection_service.py b/api/app/services/memory_reflection_service.py index 0e542ff0..371a9e72 100644 --- a/api/app/services/memory_reflection_service.py +++ b/api/app/services/memory_reflection_service.py @@ -364,6 +364,13 @@ class MemoryReflectionService: reflexion_range_value = config_data.get("reflexion_range") if reflexion_range_value is None or reflexion_range_value == "": reflexion_range_value = "partial" + # Map legacy/invalid values to valid enum values + reflexion_range_mapping = { + "retrieval": "partial", # Map old 'retrieval' to 'partial' + "partial": "partial", + "all": "all" + } + reflexion_range_value = reflexion_range_mapping.get(reflexion_range_value, "partial") reflexion_range = ReflectionRange(reflexion_range_value) baseline_value = config_data.get("baseline") diff --git a/api/app/tasks.py b/api/app/tasks.py index e2c295ab..539a3700 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -405,7 +405,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): # 2. sync data match db_knowledge.type: - case "Web": # Crawl webpages in batches through a web crawler + case "Web": # Crawl webpages in batches through a web crawler entry_url = db_knowledge.parser_config.get("entry_url", "") max_pages = db_knowledge.parser_config.get("max_pages", 20) delay_seconds = db_knowledge.parser_config.get("delay_seconds", 1.0) @@ -428,19 +428,21 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db_file = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url == crawled_document.url).first() if db_file: - if db_file.file_size == crawled_document.content_length: # same + if db_file.file_size == crawled_document.content_length: # same continue - else: # --update + else: # --update if crawled_document.content_length: # 1. update file db_file.file_name = f"{crawled_document.title}.txt" - db_file.file_ext=".txt" - db_file.file_size=crawled_document.content_length + db_file.file_ext = ".txt" + db_file.file_size = crawled_document.content_length db.commit() db.refresh(db_file) # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) - Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, + exist_ok=True) # Ensure that the directory exists save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") # update file if os.path.exists(save_path): @@ -460,7 +462,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.refresh(db_document) # 3. Document parsing, vectorization, and storage parse_document(file_path=save_path, document_id=db_document.id) - else: # --add + else: # --add if crawled_document.content_length: # 1. upload file upload_file = file_schema.FileCreate( @@ -507,8 +509,9 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): db.commit() # 3. Document parsing, vectorization, and storage parse_document(file_path=save_path, document_id=db_document.id) - db_files = db.query(File).filter(File.kb_id == db_knowledge.id, File.file_url.notin_(file_urls)).all() - if db_files: # --delete + db_files = db.query(File).filter(File.kb_id == db_knowledge.id, + File.file_url.notin_(file_urls)).all() + if db_files: # --delete for db_file in db_files: db_document = db.query(Document).filter(Document.kb_id == db_knowledge.id, Document.file_id == db_file.id).first() @@ -535,7 +538,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): case "Third-party": # Integration of knowledge bases from three parties yuque_user_id = db_knowledge.parser_config.get("yuque_user_id", "") feishu_app_id = db_knowledge.parser_config.get("feishu_app_id", "") - if yuque_user_id: # Yuque Knowledge Base + if yuque_user_id: # Yuque Knowledge Base yuque_token = db_knowledge.parser_config.get("yuque_token", "") # Create yuqueAPIClient api_client = YuqueAPIClient( @@ -571,11 +574,14 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): else: # --update # 1. update file # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) - Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) + Path(save_dir).mkdir(parents=True, + exist_ok=True) # Ensure that the directory exists # download document from Feishu FileInfo - async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, + save_dir: str): async with api_client as client: file_path = await client.download_document(doc, save_dir) return file_path @@ -613,11 +619,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): else: # --add # 1. update file # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} - save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) + save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), + str(db_knowledge.parent_id)) Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists # download document from Feishu FileInfo - async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, save_dir: str): + async def async_download_document(api_client: YuqueAPIClient, doc: YuqueDocInfo, + save_dir: str): async with api_client as client: file_path = await client.download_document(doc, save_dir) return file_path @@ -697,7 +705,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): except Exception as e: print(f"\n\nError during fetch feishu: {e}") - if feishu_app_id: # Feishu Knowledge Base + if feishu_app_id: # Feishu Knowledge Base feishu_app_secret = db_knowledge.parser_config.get("feishu_app_secret", "") feishu_folder_token = db_knowledge.parser_config.get("feishu_folder_token", "") # Create feishuAPIClient @@ -708,11 +716,13 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): try: # 初始化存储获取飞书 URLs 的集合 file_urls = set() + # Get all files from folder async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str): async with api_client as client: files = await client.list_all_folder_files(feishu_folder_token, recursive=True) return files + files = asyncio.run(async_get_files(api_client, feishu_folder_token)) # Filter out folders, only sync documents documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]] @@ -728,12 +738,16 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): # Construct a save path:/files/{kb_id}/{parent_id}/{file.id}{file_extension} save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) - Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + Path(save_dir).mkdir(parents=True, + exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo - async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, + save_dir: str): async with api_client as client: file_path = await client.download_document(document=doc, save_dir=save_dir) return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) save_path = os.path.join(save_dir, f"{db_file.id}{db_file.file_ext}") @@ -770,11 +784,14 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): save_dir = os.path.join(settings.FILE_PATH, str(db_knowledge.id), str(db_knowledge.parent_id)) Path(save_dir).mkdir(parents=True, exist_ok=True) # Ensure that the directory exists + # download document from Feishu FileInfo - async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str): + async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, + save_dir: str): async with api_client as client: file_path = await client.download_document(document=doc, save_dir=save_dir) return file_path + file_path = asyncio.run(async_download_document(api_client, doc, save_dir)) # add db_file file_name = os.path.basename(file_path) @@ -788,7 +805,7 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): file_ext=file_extension.lower(), file_size=file_size, file_url=doc.url, - created_at = doc.modified_time + created_at=doc.modified_time ) db_file = File(**upload_file.model_dump()) db.add(db_file) @@ -853,7 +870,6 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): case _: # General print(f"General: No synchronization needed\n") - result = f"sync knowledge '{db_knowledge.name}' processed successfully." return result except Exception as e: @@ -866,8 +882,8 @@ def sync_knowledge_for_kb(kb_id: uuid.UUID): @celery_app.task(name="app.core.memory.agent.read_message", bind=True) -def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str, storage_type:str, user_rag_memory_id:str) -> Dict[str, Any]: - +def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, + config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: """Celery task to process a read message via MemoryAgentService. Args: @@ -876,15 +892,15 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s history: Conversation history search_switch: Search switch parameter config_id: Configuration ID as string (will be converted to UUID) - + Returns: Dict containing the result and metadata - + Raises: Exception on failure """ start_time = time.time() - + # Convert config_id string to UUID actual_config_id = None if config_id: @@ -893,7 +909,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s except (ValueError, AttributeError): # If conversion fails, leave as None and try to resolve pass - + # Resolve config_id if None if actual_config_id is None: try: @@ -907,12 +923,13 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s except Exception: # Log but continue - will fail later with proper error pass - + async def _run() -> str: db = next(get_db()) try: service = MemoryAgentService() - return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) + return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, + storage_type, user_rag_memory_id) finally: db.close() @@ -923,7 +940,7 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s nest_asyncio.apply() except ImportError: pass - + # 尝试获取现有事件循环,如果不存在则创建新的 try: loop = asyncio.get_event_loop() @@ -933,10 +950,10 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - + return { "status": "SUCCESS", "result": result, @@ -964,9 +981,10 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type:str, user_rag_memory_id:str, language: str = "zh") -> Dict[str, Any]: +def write_message_task(self, end_user_id: str, message: str, config_id: str, storage_type: str, user_rag_memory_id: str, + language: str = "zh") -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. - + Args: end_user_id: Group ID for the memory agent (also used as end_user_id) message: Message to write @@ -974,25 +992,27 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID language: 语言类型 ("zh" 中文, "en" 英文) - + Returns: Dict containing the result and metadata - + Raises: Exception on failure """ from app.core.logging_config import get_logger logger = get_logger(__name__) - - logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}, language={language}") + + logger.info( + f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}, language={language}") start_time = time.time() - + # Convert config_id string to UUID actual_config_id = None if config_id: try: actual_config_id = uuid.UUID(config_id) if isinstance(config_id, str) else config_id - logger.info(f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") + logger.info( + f"[CELERY WRITE] Converted config_id to UUID: {actual_config_id} (type: {type(actual_config_id).__name__})") except (ValueError, AttributeError) as e: logger.error(f"[CELERY WRITE] Invalid config_id format: {config_id}, error: {e}") return { @@ -1003,7 +1023,7 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto "elapsed_time": 0.0, "task_id": self.request.id } - + # Resolve config_id if None if actual_config_id is None: try: @@ -1021,9 +1041,11 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto async def _run() -> str: db = next(get_db()) try: - logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") + logger.info( + f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() - result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id, language) + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, + user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result except Exception as e: @@ -1039,7 +1061,7 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto nest_asyncio.apply() except ImportError: pass - + # 尝试获取现有事件循环,如果不存在则创建新的 try: loop = asyncio.get_event_loop() @@ -1049,12 +1071,13 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time - - logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - + + logger.info( + f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + return { "status": "SUCCESS", "result": result, @@ -1071,9 +1094,10 @@ def write_message_task(self, end_user_id: str, message: str, config_id: str, sto detailed_error = "; ".join(error_messages) else: detailed_error = str(e) - - logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", exc_info=True) - + + logger.error(f"[CELERY WRITE] Task failed - elapsed_time={elapsed_time:.2f}s, error={detailed_error}", + exc_info=True) + return { "status": "FAILURE", "error": detailed_error, @@ -1100,16 +1124,17 @@ def reflection_engine() -> None: @celery_app.task(name="app.core.memory.agent.reflection.timer") def reflection_timer_task() -> None: """Periodic Celery task that invokes reflection_engine. - + Raises an exception on failure. """ reflection_engine() + # unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. - + # Returns status data dict that gets written to Redis. # """ # client = redis.Redis( @@ -1157,31 +1182,31 @@ def reflection_timer_task() -> None: @celery_app.task(name="app.controllers.memory_storage_controller.search_all") def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: """定时任务:查询工作空间下所有宿主的记忆总量并写入数据库 - + Args: workspace_id: 工作空间ID - + Returns: 包含任务执行结果的字典 """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.models.app_model import App from app.models.end_user_model import EndUser from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all - + with get_db_context() as db: try: workspace_uuid = uuid.UUID(workspace_id) - + # 1. 查询当前workspace下的所有app(仅未删除的) apps = db.query(App).filter( App.workspace_id == workspace_uuid, App.is_active.is_(True) ).all() - + if not apps: # 如果没有app,总量为0 memory_increment = write_memory_increment( @@ -1197,17 +1222,17 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "memory_increment_id": str(memory_increment.id), "created_at": memory_increment.created_at.isoformat(), } - + # 2. 查询所有app下的end_user_id(去重) app_ids = [app.id for app in apps] end_users = db.query(EndUser.id).filter( EndUser.app_id.in_(app_ids) ).distinct().all() - + # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 total_num = 0 end_user_details = [] - + for (end_user_id,) in end_users: try: # 调用 search_all 接口查询该宿主的总量 @@ -1225,14 +1250,14 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: "total": 0, "error": str(e) }) - + # 4. 写入数据库 memory_increment = write_memory_increment( db=db, workspace_id=workspace_uuid, total_num=total_num ) - + return { "status": "SUCCESS", "workspace_id": workspace_id, @@ -1244,7 +1269,7 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: } except Exception as e: raise e - + try: result = asyncio.run(_run()) elapsed_time = time.time() - start_time @@ -1263,18 +1288,18 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: @celery_app.task( name="app.tasks.regenerate_memory_cache", bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=3600, - soft_time_limit=3300, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, ) def regenerate_memory_cache(self) -> Dict[str, Any]: """定时任务:为所有用户重新生成记忆洞察和用户摘要缓存 - + 遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。 实现错误隔离,单个用户失败不影响其他用户的处理。 - + Returns: 包含任务执行结果的字典,包括: - status: 任务状态 (SUCCESS/FAILURE) @@ -1288,57 +1313,57 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: - task_id: 任务ID """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.core.logging_config import get_logger from app.repositories.end_user_repository import EndUserRepository from app.services.user_memory_service import UserMemoryService - + logger = get_logger(__name__) logger.info("开始执行记忆缓存重新生成定时任务") - + service = UserMemoryService() - + total_users = 0 successful = 0 failed = 0 workspace_results = [] - + with get_db_context() as db: try: # 获取所有活动工作空间 repo = EndUserRepository(db) workspaces = repo.get_all_active_workspaces() logger.info(f"找到 {len(workspaces)} 个活动工作空间") - + # 遍历每个工作空间 for workspace_id in workspaces: logger.info(f"开始处理工作空间: {workspace_id}") workspace_start_time = time.time() - + try: # 获取工作空间的所有终端用户 end_users = repo.get_all_by_workspace(workspace_id) workspace_user_count = len(end_users) total_users += workspace_user_count - + logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户") - + workspace_successful = 0 workspace_failed = 0 workspace_errors = [] - + # 遍历每个用户并生成缓存 for end_user in end_users: end_user_id = str(end_user.id) - + try: # 生成记忆洞察 insight_result = await service.generate_and_cache_insight(db, end_user_id) - + # 生成用户摘要 summary_result = await service.generate_and_cache_summary(db, end_user_id) - + # 检查是否都成功 if insight_result["success"] and summary_result["success"]: workspace_successful += 1 @@ -1354,7 +1379,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } workspace_errors.append(error_info) logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}") - + except Exception as e: # 单个用户失败不影响其他用户(错误隔离) workspace_failed += 1 @@ -1365,9 +1390,9 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } workspace_errors.append(error_info) logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}") - + workspace_elapsed = time.time() - workspace_start_time - + # 记录工作空间处理结果 workspace_result = { "workspace_id": str(workspace_id), @@ -1378,13 +1403,13 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: "elapsed_time": workspace_elapsed } workspace_results.append(workspace_result) - + logger.info( f"工作空间 {workspace_id} 处理完成: " f"总数={workspace_user_count}, 成功={workspace_successful}, " f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒" ) - + except Exception as e: # 工作空间处理失败,记录错误并继续处理下一个 logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}") @@ -1396,14 +1421,14 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: "failed": 0, "errors": [] }) - + # 记录总体统计信息 logger.info( f"记忆缓存重新生成定时任务完成: " f"工作空间数={len(workspaces)}, 总用户数={total_users}, " f"成功={successful}, 失败={failed}" ) - + return { "status": "SUCCESS", "message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功", @@ -1413,7 +1438,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: "failed": failed, "workspace_results": workspace_results } - + except Exception as e: logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}") return { @@ -1425,7 +1450,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: "failed": failed, "workspace_results": workspace_results } - + try: # 使用 nest_asyncio 来避免事件循环冲突 try: @@ -1433,7 +1458,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: nest_asyncio.apply() except ImportError: pass - + # 尝试获取现有事件循环,如果不存在则创建新的 try: loop = asyncio.get_event_loop() @@ -1443,12 +1468,12 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id - + return result except Exception as e: elapsed_time = time.time() - start_time @@ -1460,15 +1485,14 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } - @celery_app.task( name="app.tasks.workspace_reflection_task", bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=300, - soft_time_limit=240, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=300, + soft_time_limit=240, ) def workspace_reflection_task(self) -> Dict[str, Any]: """定时任务:每30秒运行工作空间反思功能 @@ -1487,7 +1511,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: ) api_logger = get_api_logger() - + with get_db_context() as db: try: # 获取所有工作空间 @@ -1518,15 +1542,16 @@ def workspace_reflection_task(self) -> Dict[str, Any]: workspace_reflection_results = [] for data in result['apps_detailed_info']: - if data['data_configs'] == []: + if data['memory_configs'] == []: continue releases = data['releases'] - data_configs = data['data_configs'] + memory_configs = data['memory_configs'] end_users = data['end_users'] - for base, config, user in zip(releases, data_configs, end_users): - if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str(user['app_id']): + for base, config, user in zip(releases, memory_configs, end_users): + if str(base['config']) == str(config['config_id']) and str(base['app_id']) == str( + user['app_id']): # 调用反思服务 api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}") @@ -1614,75 +1639,73 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } - - @celery_app.task( name="app.tasks.run_forgetting_cycle_task", bind=True, - ignore_result=True, - max_retries=0, - acks_late=False, - time_limit=7200, - soft_time_limit=7000, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=7200, + soft_time_limit=7000, ) def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Dict[str, Any]: """定时任务:运行遗忘周期 - + 定期执行遗忘周期,识别并融合低激活值的知识节点。 - + Args: config_id: 配置ID(可选,如果为None则使用默认配置) - + Returns: 包含任务执行结果的字典 """ start_time = time.time() - + async def _run() -> Dict[str, Any]: from app.core.logging_config import get_api_logger from app.services.memory_forget_service import MemoryForgetService - + api_logger = get_api_logger() - + with get_db_context() as db: try: api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}") - + forget_service = MemoryForgetService() - + # 运行遗忘周期 report = await forget_service.trigger_forgetting( db=db, end_user_id=None, # 处理所有组 config_id=config_id ) - + duration = time.time() - start_time - + api_logger.info( f"遗忘周期定时任务完成: " f"融合 {report['merged_count']} 对节点, " f"失败 {report['failed_count']} 对, " f"耗时 {duration:.2f} 秒" ) - + return { "status": "SUCCESS", "message": "遗忘周期执行成功", "report": report, "duration_seconds": duration } - + except Exception as e: duration = time.time() - start_time api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True) - + return { "status": "FAILED", "message": f"遗忘周期执行失败: {str(e)}", "duration_seconds": duration } - + # 运行异步函数 loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -1692,7 +1715,6 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di finally: loop.close() - # ============================================================================= # Long-term Memory Storage Tasks (Batched Write Strategies) # ============================================================================= @@ -1705,27 +1727,27 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # time_window: int = 5 # ) -> Dict[str, Any]: # """Celery task for time-based long-term memory storage. - + # Retrieves recent sessions from Redis within time window and writes to Neo4j. - + # Args: # end_user_id: End user identifier # config_id: Memory configuration ID # time_window: Time window in minutes for retrieving recent sessions - + # Returns: # Dict containing task status and metadata # """ # from app.core.logging_config import get_logger # logger = get_logger(__name__) - + # logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}") # start_time = time.time() - + # async def _run() -> Dict[str, Any]: # from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage # from app.services.memory_config_service import MemoryConfigService - + # db = next(get_db()) # try: # # Load memory config @@ -1734,20 +1756,20 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # config_id=config_id, # service_name="LongTermStorageTask" # ) - + # # Execute time-based storage # await memory_long_term_storage(end_user_id, memory_config, time_window) - + # return {"status": "SUCCESS", "strategy": "time", "time_window": time_window} # finally: # db.close() - + # try: # import nest_asyncio # nest_asyncio.apply() # except ImportError: # pass - + # try: # loop = asyncio.get_event_loop() # if loop.is_closed(): @@ -1756,13 +1778,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # except RuntimeError: # loop = asyncio.new_event_loop() # asyncio.set_event_loop(loop) - + # try: # result = loop.run_until_complete(_run()) # elapsed_time = time.time() - start_time - + # logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s") - + # return { # **result, # "end_user_id": end_user_id, @@ -1773,7 +1795,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # except Exception as e: # elapsed_time = time.time() - start_time # logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True) - + # return { # "status": "FAILURE", # "strategy": "time", @@ -1793,45 +1815,45 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # config_id: str # ) -> Dict[str, Any]: # """Celery task for aggregate-based long-term memory storage. - + # Uses LLM to determine if new messages describe the same event as history. # Only writes to Neo4j if messages represent new information (not duplicates). - + # Args: # end_user_id: End user identifier # langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] # config_id: Memory configuration ID - + # Returns: # Dict containing task status, is_same_event flag, and metadata # """ # from app.core.logging_config import get_logger # logger = get_logger(__name__) - + # logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}") # start_time = time.time() - + # async def _run() -> Dict[str, Any]: # from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment # from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format # from app.core.memory.agent.utils.redis_tool import write_store # from app.services.memory_config_service import MemoryConfigService - + # db = next(get_db()) # try: # # Save to Redis buffer first # write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) - + # # Load memory config # config_service = MemoryConfigService(db) # memory_config = config_service.load_memory_config( # config_id=config_id, # service_name="LongTermStorageTask" # ) - + # # Execute aggregate judgment # result = await aggregate_judgment(end_user_id, langchain_messages, memory_config) - + # return { # "status": "SUCCESS", # "strategy": "aggregate", @@ -1840,13 +1862,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # } # finally: # db.close() - + # try: # import nest_asyncio # nest_asyncio.apply() # except ImportError: # pass - + # try: # loop = asyncio.get_event_loop() # if loop.is_closed(): @@ -1855,13 +1877,13 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # except RuntimeError: # loop = asyncio.new_event_loop() # asyncio.set_event_loop(loop) - + # try: # result = loop.run_until_complete(_run()) # elapsed_time = time.time() - start_time - + # logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s") - + # return { # **result, # "end_user_id": end_user_id, @@ -1872,7 +1894,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # except Exception as e: # elapsed_time = time.time() - start_time # logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True) - + # return { # "status": "FAILURE", # "strategy": "aggregate", @@ -1881,4 +1903,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di # "config_id": config_id, # "elapsed_time": elapsed_time, # "task_id": self.request.id -# } +# } \ No newline at end of file