fix(memory): improve optimistic lock resilience in access history manager

- Increase max_retries from 3 to 5 for concurrent conflict recovery
- Add randomized exponential backoff between retries to reduce contention
- Merge duplicate node accesses in batch operations to avoid self-conflicts
- Support access_times parameter for merged batch access counting
- Add Community node label support in atomic update content field map
This commit is contained in:
lanceyq
2026-04-03 16:46:09 +08:00
parent a711635694
commit 117e29fbe3

View File

@@ -57,7 +57,7 @@ class AccessHistoryManager:
self, self,
connector: Neo4jConnector, connector: Neo4jConnector,
actr_calculator: ACTRCalculator, actr_calculator: ACTRCalculator,
max_retries: int = 3 max_retries: int = 5
): ):
""" """
初始化访问历史管理器 初始化访问历史管理器
@@ -76,7 +76,8 @@ class AccessHistoryManager:
node_id: str, node_id: str,
node_label: str, node_label: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None current_time: Optional[datetime] = None,
access_times: int = 1
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
记录节点访问并原子性更新所有相关字段 记录节点访问并原子性更新所有相关字段
@@ -93,6 +94,7 @@ class AccessHistoryManager:
node_label: 节点标签Statement, ExtractedEntity, MemorySummary node_label: 节点标签Statement, ExtractedEntity, MemorySummary
end_user_id: 组ID可选用于过滤 end_user_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间) current_time: 当前时间(可选,默认使用系统时间)
access_times: 本次访问次数默认1批量合并时可能大于1
Returns: Returns:
Dict[str, Any]: 更新后的节点数据,包含: Dict[str, Any]: 更新后的节点数据,包含:
@@ -134,7 +136,8 @@ class AccessHistoryManager:
update_data = await self._calculate_update( update_data = await self._calculate_update(
node_data=node_data, node_data=node_data,
current_time=current_time, current_time=current_time,
current_time_iso=current_time_iso current_time_iso=current_time_iso,
access_times=access_times
) )
# 步骤3原子性更新节点使用事务 # 步骤3原子性更新节点使用事务
@@ -149,15 +152,21 @@ class AccessHistoryManager:
f"成功记录访问: {node_label}[{node_id}], " f"成功记录访问: {node_label}[{node_id}], "
f"activation={update_data['activation_value']:.4f}, " f"activation={update_data['activation_value']:.4f}, "
f"access_count={update_data['access_count']}" f"access_count={update_data['access_count']}"
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
) )
return updated_node return updated_node
except Exception as e: except Exception as e:
if attempt < self.max_retries - 1: if attempt < self.max_retries - 1:
# 随机退避:避免并发请求同时重试再次冲突
import random
backoff = random.uniform(0.05, 0.2) * (attempt + 1)
logger.warning( logger.warning(
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}" f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}"
f"{backoff:.3f}s 后重试: {str(e)}"
) )
await asyncio.sleep(backoff)
continue continue
else: else:
logger.error( logger.error(
@@ -179,10 +188,11 @@ class AccessHistoryManager:
批量记录多个节点的访问 批量记录多个节点的访问
为提高性能,批量更新多个节点的访问历史。 为提高性能,批量更新多个节点的访问历史。
每个节点独立更新,失败的节点不影响其他节点。 对同一个节点的多次访问会先在内存中合并,只发起一次更新,
从而避免同节点并发写入导致的乐观锁冲突。
Args: Args:
node_ids: 节点ID列表 node_ids: 节点ID列表可包含重复ID
node_label: 节点标签(所有节点必须是同一类型) node_label: 节点标签(所有节点必须是同一类型)
end_user_id: 组ID可选 end_user_id: 组ID可选
current_time: 当前时间(可选) current_time: 当前时间(可选)
@@ -196,25 +206,40 @@ class AccessHistoryManager:
if current_time is None: if current_time is None:
current_time = datetime.now() current_time = datetime.now()
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially # 合并同一节点的访问次数,避免对同一节点并发写入
tasks = [] access_count_map: Dict[str, int] = {}
for node_id in node_ids: for node_id in node_ids:
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
merged_count = len(node_ids) - len(access_count_map)
if merged_count > 0:
logger.info(
f"批量访问合并: 原始={len(node_ids)}, "
f"去重后={len(access_count_map)}, 合并={merged_count}"
)
# 对去重后的节点并行发起更新
tasks = []
for node_id, access_times in access_count_map.items():
task = self.record_access( task = self.record_access(
node_id=node_id, node_id=node_id,
node_label=node_label, node_label=node_label,
end_user_id=end_user_id, end_user_id=end_user_id,
current_time=current_time current_time=current_time,
access_times=access_times
) )
tasks.append(task) tasks.append((node_id, task))
# Execute all tasks in parallel # Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(
*[t for _, t in tasks], return_exceptions=True
)
# Collect successful results and count failures # Collect successful results and count failures
results = [] results = []
failed_count = 0 failed_count = 0
for node_id, result in zip(node_ids, task_results): for (node_id, _), result in zip(tasks, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
failed_count += 1 failed_count += 1
logger.warning( logger.warning(
@@ -225,7 +250,7 @@ class AccessHistoryManager:
batch_duration = time.time() - batch_start batch_duration = time.time() - batch_start
logger.info( logger.info(
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
) )
@@ -532,7 +557,8 @@ class AccessHistoryManager:
self, self,
node_data: Dict[str, Any], node_data: Dict[str, Any],
current_time: datetime, current_time: datetime,
current_time_iso: str current_time_iso: str,
access_times: int = 1
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
计算更新数据 计算更新数据
@@ -541,6 +567,7 @@ class AccessHistoryManager:
node_data: 当前节点数据 node_data: 当前节点数据
current_time: 当前时间datetime对象 current_time: 当前时间datetime对象
current_time_iso: 当前时间ISO格式字符串 current_time_iso: 当前时间ISO格式字符串
access_times: 本次访问次数合并后可能大于1
Returns: Returns:
Dict[str, Any]: 更新数据,包含所有需要更新的字段 Dict[str, Any]: 更新数据,包含所有需要更新的字段
@@ -551,8 +578,8 @@ class AccessHistoryManager:
if importance_score is None: if importance_score is None:
importance_score = 0.5 importance_score = 0.5
# 追加新的访问时间 # 追加新的访问时间(合并场景下追加多条相同时间戳)
new_access_history = access_history + [current_time_iso] new_access_history = access_history + [current_time_iso] * access_times
# 修剪访问历史(如果过长) # 修剪访问历史(如果过长)
access_history_dt = [ access_history_dt = [
@@ -642,7 +669,8 @@ class AccessHistoryManager:
content_field_map = { content_field_map = {
'Statement': 'n.statement as statement', 'Statement': 'n.statement as statement',
'MemorySummary': 'n.content as content', 'MemorySummary': 'n.content as content',
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤
'Community': 'n.summary as summary'
} }
# 显式检查节点类型,不支持的类型抛出错误 # 显式检查节点类型,不支持的类型抛出错误