From 117e29fbe3f525e3b00fc90be87fdafc470e768f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 16:46:09 +0800 Subject: [PATCH] fix(memory): improve optimistic lock resilience in access history manager - Increase max_retries from 3 to 5 for concurrent conflict recovery - Add randomized, linearly increasing backoff between retries to reduce contention - Merge duplicate node accesses in batch operations to avoid self-conflicts - Support access_times parameter for merged batch access counting - Add Community node label support in atomic update content field map --- .../access_history_manager.py | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a71c0957..cc477330 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -57,7 +57,7 @@ class AccessHistoryManager: self, connector: Neo4jConnector, actr_calculator: ACTRCalculator, - max_retries: int = 3 + max_retries: int = 5 ): """ 初始化访问历史管理器 @@ -76,7 +76,8 @@ class AccessHistoryManager: node_id: str, node_label: str, end_user_id: Optional[str] = None, - current_time: Optional[datetime] = None + current_time: Optional[datetime] = None, + access_times: int = 1 ) -> Dict[str, Any]: """ 记录节点访问并原子性更新所有相关字段 @@ -93,6 +94,7 @@ class AccessHistoryManager: node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) + access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: Dict[str, Any]: 更新后的节点数据,包含: @@ -134,7 +136,8 @@ class AccessHistoryManager: update_data = await self._calculate_update( node_data=node_data, current_time=current_time, - current_time_iso=current_time_iso + current_time_iso=current_time_iso, + 
access_times=access_times ) # 步骤3:原子性更新节点(使用事务) @@ -149,15 +152,21 @@ class AccessHistoryManager: f"成功记录访问: {node_label}[{node_id}], " f"activation={update_data['activation_value']:.4f}, " f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" ) return updated_node except Exception as e: if attempt < self.max_retries - 1: + # 随机退避:避免并发请求同时重试再次冲突 + import random + backoff = random.uniform(0.05, 0.2) * (attempt + 1) logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}" + f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," + f"{backoff:.3f}s 后重试: {str(e)}" ) + await asyncio.sleep(backoff) continue else: logger.error( @@ -179,10 +188,11 @@ class AccessHistoryManager: 批量记录多个节点的访问 为提高性能,批量更新多个节点的访问历史。 - 每个节点独立更新,失败的节点不影响其他节点。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新, + 从而避免同节点并发写入导致的乐观锁冲突。 Args: - node_ids: 节点ID列表 + node_ids: 节点ID列表(可包含重复ID) node_label: 节点标签(所有节点必须是同一类型) end_user_id: 组ID(可选) current_time: 当前时间(可选) @@ -196,25 +206,40 @@ class AccessHistoryManager: if current_time is None: current_time = datetime.now() - # PERFORMANCE FIX: Process all nodes in parallel instead of sequentially - tasks = [] + # 合并同一节点的访问次数,避免对同一节点并发写入 + access_count_map: Dict[str, int] = {} for node_id in node_ids: + access_count_map[node_id] = access_count_map.get(node_id, 0) + 1 + + merged_count = len(node_ids) - len(access_count_map) + if merged_count > 0: + logger.info( + f"批量访问合并: 原始={len(node_ids)}, " + f"去重后={len(access_count_map)}, 合并={merged_count}" + ) + + # 对去重后的节点并行发起更新 + tasks = [] + for node_id, access_times in access_count_map.items(): task = self.record_access( node_id=node_id, node_label=node_label, end_user_id=end_user_id, - current_time=current_time + current_time=current_time, + access_times=access_times ) - tasks.append(task) + tasks.append((node_id, task)) # Execute all tasks in parallel - task_results = await asyncio.gather(*tasks, return_exceptions=True) + task_results = await asyncio.gather( + *[t for _, t in 
tasks], return_exceptions=True + ) # Collect successful results and count failures results = [] failed_count = 0 - for node_id, result in zip(node_ids, task_results): + for (node_id, _), result in zip(tasks, task_results): if isinstance(result, Exception): failed_count += 1 logger.warning( @@ -225,7 +250,7 @@ class AccessHistoryManager: batch_duration = time.time() - batch_start logger.info( - f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " + f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, " f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" ) @@ -532,7 +557,8 @@ class AccessHistoryManager: self, node_data: Dict[str, Any], current_time: datetime, - current_time_iso: str + current_time_iso: str, + access_times: int = 1 ) -> Dict[str, Any]: """ 计算更新数据 @@ -541,6 +567,7 @@ class AccessHistoryManager: node_data: 当前节点数据 current_time: 当前时间(datetime对象) current_time_iso: 当前时间(ISO格式字符串) + access_times: 本次访问次数(合并后可能大于1) Returns: Dict[str, Any]: 更新数据,包含所有需要更新的字段 @@ -551,8 +578,8 @@ class AccessHistoryManager: if importance_score is None: importance_score = 0.5 - # 追加新的访问时间 - new_access_history = access_history + [current_time_iso] + # 追加新的访问时间(合并场景下追加多条相同时间戳) + new_access_history = access_history + [current_time_iso] * access_times # 修剪访问历史(如果过长) access_history_dt = [ @@ -642,7 +669,8 @@ class AccessHistoryManager: content_field_map = { 'Statement': 'n.statement as statement', 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 + 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤 + 'Community': 'n.summary as summary' } # 显式检查节点类型,不支持的类型抛出错误