From 117e29fbe3f525e3b00fc90be87fdafc470e768f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 16:46:09 +0800 Subject: [PATCH 1/3] fix(memory): improve optimistic lock resilience in access history manager - Increase max_retries from 3 to 5 for concurrent conflict recovery - Add randomized exponential backoff between retries to reduce contention - Merge duplicate node accesses in batch operations to avoid self-conflicts - Support access_times parameter for merged batch access counting - Add Community node label support in atomic update content field map --- .../access_history_manager.py | 62 ++++++++++++++----- 1 file changed, 45 insertions(+), 17 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a71c0957..cc477330 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -57,7 +57,7 @@ class AccessHistoryManager: self, connector: Neo4jConnector, actr_calculator: ACTRCalculator, - max_retries: int = 3 + max_retries: int = 5 ): """ 初始化访问历史管理器 @@ -76,7 +76,8 @@ class AccessHistoryManager: node_id: str, node_label: str, end_user_id: Optional[str] = None, - current_time: Optional[datetime] = None + current_time: Optional[datetime] = None, + access_times: int = 1 ) -> Dict[str, Any]: """ 记录节点访问并原子性更新所有相关字段 @@ -93,6 +94,7 @@ class AccessHistoryManager: node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) + access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: Dict[str, Any]: 更新后的节点数据,包含: @@ -134,7 +136,8 @@ class AccessHistoryManager: update_data = await self._calculate_update( node_data=node_data, current_time=current_time, - current_time_iso=current_time_iso + current_time_iso=current_time_iso, + access_times=access_times ) # 步骤3:原子性更新节点(使用事务) @@ -149,15 +152,21 @@ class AccessHistoryManager: f"成功记录访问: {node_label}[{node_id}], " f"activation={update_data['activation_value']:.4f}, " f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" ) return updated_node except Exception as e: if attempt < self.max_retries - 1: + # 随机退避:避免并发请求同时重试再次冲突 + import random + backoff = random.uniform(0.05, 0.2) * (attempt + 1) logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}" + f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," + f"{backoff:.3f}s 后重试: {str(e)}" ) + await asyncio.sleep(backoff) continue else: logger.error( @@ -179,10 +188,11 @@ class AccessHistoryManager: 批量记录多个节点的访问 为提高性能,批量更新多个节点的访问历史。 - 每个节点独立更新,失败的节点不影响其他节点。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新, + 从而避免同节点并发写入导致的乐观锁冲突。 Args: - node_ids: 节点ID列表 + node_ids: 节点ID列表(可包含重复ID) node_label: 节点标签(所有节点必须是同一类型) end_user_id: 组ID(可选) current_time: 当前时间(可选) @@ -196,25 +206,40 @@ class AccessHistoryManager: if current_time is None: current_time = datetime.now() - # PERFORMANCE FIX: Process all nodes in parallel instead of sequentially - tasks = [] + # 合并同一节点的访问次数,避免对同一节点并发写入 + access_count_map: Dict[str, int] = {} for node_id in node_ids: + access_count_map[node_id] = access_count_map.get(node_id, 0) + 1 + + merged_count = len(node_ids) - len(access_count_map) + if merged_count > 0: + logger.info( + f"批量访问合并: 原始={len(node_ids)}, " + f"去重后={len(access_count_map)}, 合并={merged_count}" + ) + + # 对去重后的节点并行发起更新 + tasks = [] + for node_id, access_times in access_count_map.items(): task = self.record_access( node_id=node_id, node_label=node_label, end_user_id=end_user_id, - current_time=current_time + current_time=current_time, + access_times=access_times ) - tasks.append(task) + tasks.append((node_id, task)) # Execute all tasks in parallel - task_results = await asyncio.gather(*tasks, return_exceptions=True) + task_results = await asyncio.gather( + *[t for _, t in tasks], return_exceptions=True + ) # Collect successful results and count failures results = [] failed_count = 0 - for node_id, result in zip(node_ids, task_results): + for (node_id, _), result in zip(tasks, task_results): if isinstance(result, Exception): failed_count += 1 logger.warning( @@ -225,7 +250,7 @@ class AccessHistoryManager: batch_duration = time.time() - batch_start logger.info( - f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " + f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, " f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" ) @@ -532,7 +557,8 @@ class AccessHistoryManager: self, node_data: Dict[str, Any], current_time: datetime, - current_time_iso: str + current_time_iso: str, + access_times: int = 1 ) -> Dict[str, Any]: """ 计算更新数据 @@ -541,6 +567,7 @@ class AccessHistoryManager: node_data: 当前节点数据 current_time: 当前时间(datetime对象) current_time_iso: 当前时间(ISO格式字符串) + access_times: 本次访问次数(合并后可能大于1) Returns: Dict[str, Any]: 更新数据,包含所有需要更新的字段 @@ -551,8 +578,8 @@ class AccessHistoryManager: if importance_score is None: importance_score = 0.5 - # 追加新的访问时间 - new_access_history = access_history + [current_time_iso] + # 追加新的访问时间(合并场景下追加多条相同时间戳) + new_access_history = access_history + [current_time_iso] * access_times # 修剪访问历史(如果过长) access_history_dt = [ @@ -642,7 +669,8 @@ class AccessHistoryManager: content_field_map = { 'Statement': 'n.statement as statement', 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 + 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤 + 'Community': 'n.summary as summary' } # 显式检查节点类型,不支持的类型抛出错误 From 00a8099857beaca14cbf5ffcb322e052eaa1b690 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 16:55:53 +0800 Subject: [PATCH 2/3] changes:(api) Change the "jitter" to "tremble". --- .../forgetting_engine/access_history_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index cc477330..a5f48982 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -159,9 +159,9 @@ class AccessHistoryManager: except Exception as e: if attempt < self.max_retries - 1: - # 随机退避:避免并发请求同时重试再次冲突 + # 带抖动的指数退避:base * 2^attempt * random(0.5, 1.0) import random - backoff = random.uniform(0.05, 0.2) * (attempt + 1) + backoff = 0.05 * (2 ** attempt) * random.uniform(0.5, 1.0) logger.warning( f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," f"{backoff:.3f}s 后重试: {str(e)}" From 99862db7a05dc06e02af0179df5a6f2d48aff013 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 3 Apr 2026 18:40:03 +0800 Subject: [PATCH 3/3] refactor(forgetting-engine): replace optimistic locking with APOC atomic operations in access history manager - Replace version-based optimistic locking and retry loop with apoc.atomic.add/insert for concurrent safety - Merge duplicate accesses within a batch before updating (access_count_delta) - Simplify _calculate_update to only compute on new timestamps instead of full history rebuild - Remove max_retries instance variable (kept as param for backward compat) - Trim verbose docstrings and inline comments --- .../access_history_manager.py | 420 ++++++------------ 1 file changed, 126 insertions(+), 294 deletions(-) diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index a5f48982..e5254646 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -42,15 +42,14 @@ class AccessHistoryManager: - access_count: 访问次数 特性: - - 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚 - - 并发安全:使用乐观锁机制防止并发冲突 + - 原子性更新:使用 APOC 原子操作确保并发安全 + - 批次内合并:同一批次中对同一节点的多次访问合并为一次更新 - 一致性保证:提供一致性检查和自动修复功能 - 智能修剪:自动修剪过长的访问历史 Attributes: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数 """ def __init__( @@ -65,12 +64,11 @@ class AccessHistoryManager: Args: connector: Neo4j连接器实例 actr_calculator: ACT-R激活值计算器实例 - max_retries: 并发冲突时的最大重试次数(默认3次) + max_retries: 已废弃,保留参数兼容性(APOC 原子操作无需重试) """ self.connector = connector self.actr_calculator = actr_calculator - self.max_retries = max_retries - + async def record_access( self, node_id: str, @@ -82,13 +80,6 @@ class AccessHistoryManager: """ 记录节点访问并原子性更新所有相关字段 - 这是核心方法,实现了: - 1. 首次访问:初始化access_history,计算初始激活值 - 2. 后续访问:追加访问历史,重新计算激活值 - 3. 历史修剪:当历史过长时自动修剪 - 4. 原子性:所有字段在单个事务中更新 - 5. 并发安全:使用乐观锁重试机制 - Args: node_id: 节点ID node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) @@ -97,17 +88,11 @@ class AccessHistoryManager: access_times: 本次访问次数(默认1,批量合并时可能大于1) Returns: - Dict[str, Any]: 更新后的节点数据,包含: - - id: 节点ID - - activation_value: 更新后的激活值 - - access_history: 更新后的访问历史 - - last_access_time: 最后访问时间 - - access_count: 访问次数 - - importance_score: 重要性分数 + Dict[str, Any]: 更新后的节点数据 Raises: ValueError: 如果节点不存在或节点标签无效 - RuntimeError: 如果重试次数耗尽仍然失败 + RuntimeError: 如果更新失败 """ if current_time is None: current_time = datetime.now() @@ -121,62 +106,48 @@ class AccessHistoryManager: f"Invalid node_label: {node_label}. Must be one of {valid_labels}" ) - # 使用乐观锁重试机制处理并发冲突 - for attempt in range(self.max_retries): - try: - # 步骤1:读取当前节点状态 - node_data = await self._fetch_node(node_id, node_label, end_user_id) - - if not node_data: - raise ValueError( - f"Node not found: {node_label} with id={node_id}" - ) - - # 步骤2:计算新的访问历史和激活值 - update_data = await self._calculate_update( - node_data=node_data, - current_time=current_time, - current_time_iso=current_time_iso, - access_times=access_times + try: + # 步骤1:读取当前节点状态 + node_data = await self._fetch_node(node_id, node_label, end_user_id) + + if not node_data: + raise ValueError( + f"Node not found: {node_label} with id={node_id}" ) - - # 步骤3:原子性更新节点(使用事务) - updated_node = await self._atomic_update( - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - - logger.info( - f"成功记录访问: {node_label}[{node_id}], " - f"activation={update_data['activation_value']:.4f}, " - f"access_count={update_data['access_count']}" - f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" - ) - - return updated_node - - except Exception as e: - if attempt < self.max_retries - 1: - # 带抖动的指数退避:base * 2^attempt * random(0.5, 1.0) - import random - backoff = 0.05 * (2 ** attempt) * random.uniform(0.5, 1.0) - logger.warning( - f"访问记录失败(尝试 {attempt + 1}/{self.max_retries})," - f"{backoff:.3f}s 后重试: {str(e)}" - ) - await asyncio.sleep(backoff) - continue - else: - logger.error( - f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], " - f"错误: {str(e)}" - ) - raise RuntimeError( - f"Failed to record access after {self.max_retries} attempts: {str(e)}" - ) - + + # 步骤2:计算新的访问历史和激活值 + update_data = await self._calculate_update( + node_data=node_data, + current_time=current_time, + current_time_iso=current_time_iso, + access_times=access_times + ) + + # 步骤3:使用 APOC 原子操作更新节点(无需重试) + updated_node = await self._atomic_update( + node_id=node_id, + node_label=node_label, + update_data=update_data, + end_user_id=end_user_id + ) + + logger.info( + f"成功记录访问: {node_label}[{node_id}], " + f"activation={update_data['activation_value']:.4f}, " + f"access_count={update_data['access_count']}" + f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}" + ) + + return updated_node + + except Exception as e: + logger.error( + f"访问记录失败: {node_label}[{node_id}], 错误: {str(e)}" + ) + raise RuntimeError( + f"Failed to record access: {str(e)}" + ) from e + async def record_batch_access( self, node_ids: List[str], @@ -187,9 +158,7 @@ class AccessHistoryManager: """ 批量记录多个节点的访问 - 为提高性能,批量更新多个节点的访问历史。 - 对同一个节点的多次访问会先在内存中合并,只发起一次更新, - 从而避免同节点并发写入导致的乐观锁冲突。 + 对同一个节点的多次访问会先在内存中合并,只发起一次更新。 Args: node_ids: 节点ID列表(可包含重复ID) @@ -230,12 +199,10 @@ class AccessHistoryManager: ) tasks.append((node_id, task)) - # Execute all tasks in parallel task_results = await asyncio.gather( *[t for _, t in tasks], return_exceptions=True ) - # Collect successful results and count failures results = [] failed_count = 0 @@ -255,7 +222,7 @@ class AccessHistoryManager: ) return results - + async def check_consistency( self, node_id: str, @@ -264,22 +231,6 @@ class AccessHistoryManager: ) -> Tuple[ConsistencyCheckResult, Optional[str]]: """ 检查节点数据的一致性 - - 验证以下一致性规则: - 1. access_history[-1] == last_access_time - 2. len(access_history) == access_count - 3. 如果有访问历史,必须有激活值 - 4. 激活值必须在有效范围内 [offset, 1.0] - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Tuple[ConsistencyCheckResult, Optional[str]]: - - 一致性检查结果枚举 - - 错误描述(如果不一致) """ node_data = await self._fetch_node(node_id, node_label, end_user_id) @@ -291,7 +242,6 @@ class AccessHistoryManager: access_count = node_data.get('access_count', 0) activation_value = node_data.get('activation_value') - # 检查1:access_history[-1] == last_access_time if access_history and last_access_time: if access_history[-1] != last_access_time: return ( @@ -300,7 +250,6 @@ class AccessHistoryManager: f"last_access_time={last_access_time}" ) - # 检查2:len(access_history) == access_count if len(access_history) != access_count: return ( ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT, @@ -308,14 +257,12 @@ class AccessHistoryManager: f"access_count={access_count}" ) - # 检查3:有访问历史必须有激活值 if access_history and activation_value is None: return ( ConsistencyCheckResult.MISSING_ACTIVATION, "Node has access_history but activation_value is None" ) - # 检查4:激活值范围 if activation_value is not None: offset = self.actr_calculator.offset if not (offset <= activation_value <= 1.0): @@ -326,30 +273,14 @@ class AccessHistoryManager: ) return ConsistencyCheckResult.CONSISTENT, None - + async def check_batch_consistency( self, node_label: str, end_user_id: Optional[str] = None, limit: int = 1000 ) -> Dict[str, Any]: - """ - 批量检查多个节点的一致性 - - Args: - node_label: 节点标签 - end_user_id: 组ID(可选) - limit: 检查的最大节点数 - - Returns: - Dict[str, Any]: 一致性检查报告,包含: - - total_checked: 检查的节点总数 - - consistent_count: 一致的节点数 - - inconsistent_count: 不一致的节点数 - - inconsistencies: 不一致节点的详细信息列表 - - consistency_rate: 一致性率(0-1) - """ - # 查询所有相关节点 + """批量检查多个节点的一致性""" query = f""" MATCH (n:{node_label}) WHERE n.access_history IS NOT NULL @@ -368,7 +299,6 @@ class AccessHistoryManager: results = await self.connector.execute_query(query, **params) node_ids = [r['id'] for r in results] - # 检查每个节点 inconsistencies = [] consistent_count = 0 @@ -407,32 +337,15 @@ class AccessHistoryManager: ) return report - + async def repair_inconsistency( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> bool: - """ - 自动修复节点的数据不一致问题 - - 修复策略: - 1. 如果access_history[-1] != last_access_time:使用access_history[-1] - 2. 如果len(access_history) != access_count:使用len(access_history) - 3. 如果有历史但无激活值:重新计算激活值 - 4. 如果激活值超出范围:重新计算激活值 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - bool: 修复成功返回True,否则返回False - """ + """自动修复节点的数据不一致问题""" try: - # 检查一致性 result, message = await self.check_consistency( node_id=node_id, node_label=node_label, @@ -443,7 +356,6 @@ class AccessHistoryManager: logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]") return True - # 获取节点数据 node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") @@ -452,17 +364,13 @@ class AccessHistoryManager: access_history = node_data.get('access_history') or [] importance_score = node_data.get('importance_score', 0.5) - # 准备修复数据 repair_data = {} - # 修复last_access_time if access_history: repair_data['last_access_time'] = access_history[-1] - # 修复access_count repair_data['access_count'] = len(access_history) - # 修复activation_value if access_history: current_time = datetime.now() last_access_dt = datetime.fromisoformat(access_history[-1]) @@ -478,7 +386,6 @@ class AccessHistoryManager: ) repair_data['activation_value'] = activation_value - # 执行修复 query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -509,26 +416,16 @@ class AccessHistoryManager: f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}" ) return False - + # ==================== 私有辅助方法 ==================== - + async def _fetch_node( self, node_id: str, node_label: str, end_user_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: - """ - 获取节点数据 - - Args: - node_id: 节点ID - node_label: 节点标签 - end_user_id: 组ID(可选) - - Returns: - Optional[Dict[str, Any]]: 节点数据,如果不存在返回None - """ + """获取节点数据""" query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ @@ -552,7 +449,7 @@ class AccessHistoryManager: if results: return results[0] return None - + async def _calculate_update( self, node_data: Dict[str, Any], @@ -570,43 +467,37 @@ class AccessHistoryManager: access_times: 本次访问次数(合并后可能大于1) Returns: - Dict[str, Any]: 更新数据,包含所有需要更新的字段 + Dict[str, Any]: 更新数据 """ - access_history = node_data.get('access_history') or [] - # Handle None importance_score - default to 0.5 importance_score = node_data.get('importance_score') if importance_score is None: importance_score = 0.5 - # 追加新的访问时间(合并场景下追加多条相同时间戳) - new_access_history = access_history + [current_time_iso] * access_times + # 本次新增的时间戳 + new_timestamps = [current_time_iso] * access_times - # 修剪访问历史(如果过长) - access_history_dt = [ - datetime.fromisoformat(ts) for ts in new_access_history - ] + # 仅用本次新增的访问记录计算激活值 + new_history_dt = [current_time] * access_times trimmed_history_dt = self.actr_calculator.trim_access_history( - access_history=access_history_dt, + access_history=new_history_dt, current_time=current_time ) - trimmed_history = [ts.isoformat() for ts in trimmed_history_dt] - # 计算新的激活值 activation_value = self.actr_calculator.calculate_memory_activation( access_history=trimmed_history_dt, current_time=current_time, - last_access_time=current_time, # 最后访问时间就是当前时间 + last_access_time=current_time, importance_score=importance_score ) - # 返回所有需要更新的字段 return { 'activation_value': activation_value, - 'access_history': trimmed_history, + 'new_timestamps': new_timestamps, + 'access_count_delta': access_times, + 'access_count': len(trimmed_history_dt), 'last_access_time': current_time_iso, - 'access_count': len(trimmed_history) } - + async def _atomic_update( self, node_id: str, @@ -615,10 +506,10 @@ class AccessHistoryManager: end_user_id: Optional[str] = None ) -> Dict[str, Any]: """ - 原子性更新节点(使用乐观锁) + 原子性更新节点(使用 APOC 原子操作) - 使用Neo4j事务和版本号确保所有字段同时更新或回滚。 - 实现乐观锁机制防止并发冲突。 + 使用 apoc.atomic.add 和 apoc.atomic.insert 保证并发安全, + 无需 version 字段和乐观锁,数据库层面保证原子性。 Args: node_id: 节点ID @@ -630,127 +521,68 @@ class AccessHistoryManager: Dict[str, Any]: 更新后的节点数据 Raises: - RuntimeError: 如果更新失败或发生版本冲突 + RuntimeError: 如果更新失败 """ - # 定义事务函数 - async def update_transaction(tx, node_id, node_label, update_data, end_user_id): - # 步骤1:读取当前节点并获取版本号 - read_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - """ - if end_user_id: - read_query += " WHERE n.end_user_id = $end_user_id" - read_query += """ - RETURN n.id as id, - n.version as version, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score - """ + content_field_map = { + 'Statement': 'n.statement as statement', + 'MemorySummary': 'n.content as content', + 'ExtractedEntity': 'null as content_placeholder', + 'Community': 'n.summary as summary' + } + + if node_label not in content_field_map: + raise ValueError( + f"Unsupported node_label: {node_label}. " + f"Supported labels are: {list(content_field_map.keys())}" + ) + + content_field = content_field_map[node_label] + + where_clause = "" + if end_user_id: + where_clause = " AND n.end_user_id = $end_user_id" + + query = f""" + MATCH (n:{node_label} {{id: $node_id}}) + WHERE true{where_clause} + CALL apoc.atomic.add(n, 'access_count', $access_count_delta, 5) YIELD oldValue AS old_count + WITH n + CALL (n) {{ + UNWIND $new_timestamps AS ts + CALL apoc.atomic.insert(n, 'access_history', size(n.access_history), ts, 5) YIELD oldValue + RETURN count(*) AS inserted + }} + SET n.activation_value = $activation_value, + n.last_access_time = $last_access_time + RETURN n.id as id, + n.activation_value as activation_value, + n.access_history as access_history, + n.last_access_time as last_access_time, + n.access_count as access_count, + n.importance_score as importance_score, + {content_field} + """ + + params = { + 'node_id': node_id, + 'access_count_delta': update_data['access_count_delta'], + 'new_timestamps': update_data['new_timestamps'], + 'activation_value': update_data['activation_value'], + 'last_access_time': update_data['last_access_time'], + } + if end_user_id: + params['end_user_id'] = end_user_id + + try: + results = await self.connector.execute_query(query, **params) - read_params = {'node_id': node_id} - if end_user_id: - read_params['end_user_id'] = end_user_id - - read_result = await tx.run(read_query, **read_params) - current_node = await read_result.single() - - if not current_node: + if not results: raise RuntimeError(f"Node not found: {node_label}[{node_id}]") - # 获取当前版本号(如果不存在则为0) - current_version = current_node.get('version', 0) or 0 - new_version = current_version + 1 - - # 步骤2:使用乐观锁更新节点 - # 根据节点类型构建完整的查询语句 - content_field_map = { - 'Statement': 'n.statement as statement', - 'MemorySummary': 'n.content as content', - 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤 - 'Community': 'n.summary as summary' - } - - # 显式检查节点类型,不支持的类型抛出错误 - if node_label not in content_field_map: - raise ValueError( - f"Unsupported node_label: {node_label}. " - f"Supported labels are: {list(content_field_map.keys())}" - ) - - content_field = content_field_map[node_label] - - # 构建 WHERE 子句 - where_conditions = [] - if end_user_id: - where_conditions.append("n.end_user_id = $end_user_id") - - # 添加版本检查 - if current_version > 0: - where_conditions.append("n.version = $current_version") - else: - where_conditions.append("(n.version IS NULL OR n.version = 0)") - - where_clause = " AND ".join(where_conditions) if where_conditions else "true" - - # 构建完整的更新查询 - update_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - WHERE {where_clause} - SET n.activation_value = $activation_value, - n.access_history = $access_history, - n.last_access_time = $last_access_time, - n.access_count = $access_count, - n.version = $new_version - RETURN n.id as id, - n.activation_value as activation_value, - n.access_history as access_history, - n.last_access_time as last_access_time, - n.access_count as access_count, - n.importance_score as importance_score, - n.version as version, - {content_field} - """ - - update_params = { - 'node_id': node_id, - 'current_version': current_version, - 'new_version': new_version, - 'activation_value': update_data['activation_value'], - 'access_history': update_data['access_history'], - 'last_access_time': update_data['last_access_time'], - 'access_count': update_data['access_count'] - } - if end_user_id: - update_params['end_user_id'] = end_user_id - - update_result = await tx.run(update_query, **update_params) - updated_node = await update_result.single() - - if not updated_node: - raise RuntimeError( - f"Version conflict detected for {node_label}[{node_id}]. " - f"Expected version {current_version}, but node was modified by another transaction." - ) - - # 转换为字典并移除占位符字段 - result_dict = dict(updated_node) + result_dict = dict(results[0]) result_dict.pop('content_placeholder', None) return result_dict - - # 执行事务 - try: - result = await self.connector.execute_write_transaction( - update_transaction, - node_id=node_id, - node_label=node_label, - update_data=update_data, - end_user_id=end_user_id - ) - return result except Exception as e: logger.error( f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"