fix(memory): improve optimistic lock resilience in access history manager

- Increase max_retries from 3 to 5 for concurrent conflict recovery
- Add randomized, linearly increasing backoff (random base delay × attempt number) between retries to reduce contention
- Merge duplicate node accesses in batch operations to avoid self-conflicts
- Support access_times parameter for merged batch access counting
- Add Community node label support in atomic update content field map
This commit is contained in:
lanceyq
2026-04-03 16:46:09 +08:00
parent a711635694
commit 117e29fbe3

View File

@@ -57,7 +57,7 @@ class AccessHistoryManager:
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
max_retries: int = 3
max_retries: int = 5
):
"""
初始化访问历史管理器
@@ -76,7 +76,8 @@ class AccessHistoryManager:
node_id: str,
node_label: str,
end_user_id: Optional[str] = None,
current_time: Optional[datetime] = None
current_time: Optional[datetime] = None,
access_times: int = 1
) -> Dict[str, Any]:
"""
记录节点访问并原子性更新所有相关字段
@@ -93,6 +94,7 @@ class AccessHistoryManager:
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
end_user_id: 组ID(可选,用于过滤)
current_time: 当前时间(可选,默认使用系统时间)
access_times: 本次访问次数(默认1,批量合并时可能大于1)
Returns:
Dict[str, Any]: 更新后的节点数据,包含:
@@ -134,7 +136,8 @@ class AccessHistoryManager:
update_data = await self._calculate_update(
node_data=node_data,
current_time=current_time,
current_time_iso=current_time_iso
current_time_iso=current_time_iso,
access_times=access_times
)
# 步骤3原子性更新节点使用事务
@@ -149,15 +152,21 @@ class AccessHistoryManager:
f"成功记录访问: {node_label}[{node_id}], "
f"activation={update_data['activation_value']:.4f}, "
f"access_count={update_data['access_count']}"
f"{f', 合并访问次数={access_times}' if access_times > 1 else ''}"
)
return updated_node
except Exception as e:
if attempt < self.max_retries - 1:
# 随机退避:避免并发请求同时重试再次冲突
import random
backoff = random.uniform(0.05, 0.2) * (attempt + 1)
logger.warning(
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}"
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}"
f"{backoff:.3f}s 后重试: {str(e)}"
)
await asyncio.sleep(backoff)
continue
else:
logger.error(
@@ -179,10 +188,11 @@ class AccessHistoryManager:
批量记录多个节点的访问
为提高性能,批量更新多个节点的访问历史。
每个节点独立更新,失败的节点不影响其他节点。
对同一个节点的多次访问会先在内存中合并,只发起一次更新,
从而避免同节点并发写入导致的乐观锁冲突。
Args:
node_ids: 节点ID列表
node_ids: 节点ID列表(可包含重复ID)
node_label: 节点标签(所有节点必须是同一类型)
end_user_id: 组ID(可选)
current_time: 当前时间(可选)
@@ -196,25 +206,40 @@ class AccessHistoryManager:
if current_time is None:
current_time = datetime.now()
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
tasks = []
# 合并同一节点的访问次数,避免对同一节点并发写入
access_count_map: Dict[str, int] = {}
for node_id in node_ids:
access_count_map[node_id] = access_count_map.get(node_id, 0) + 1
merged_count = len(node_ids) - len(access_count_map)
if merged_count > 0:
logger.info(
f"批量访问合并: 原始={len(node_ids)}, "
f"去重后={len(access_count_map)}, 合并={merged_count}"
)
# 对去重后的节点并行发起更新
tasks = []
for node_id, access_times in access_count_map.items():
task = self.record_access(
node_id=node_id,
node_label=node_label,
end_user_id=end_user_id,
current_time=current_time
current_time=current_time,
access_times=access_times
)
tasks.append(task)
tasks.append((node_id, task))
# Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
task_results = await asyncio.gather(
*[t for _, t in tasks], return_exceptions=True
)
# Collect successful results and count failures
results = []
failed_count = 0
for node_id, result in zip(node_ids, task_results):
for (node_id, _), result in zip(tasks, task_results):
if isinstance(result, Exception):
failed_count += 1
logger.warning(
@@ -225,7 +250,7 @@ class AccessHistoryManager:
batch_duration = time.time() - batch_start
logger.info(
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(access_count_map)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
)
@@ -532,7 +557,8 @@ class AccessHistoryManager:
self,
node_data: Dict[str, Any],
current_time: datetime,
current_time_iso: str
current_time_iso: str,
access_times: int = 1
) -> Dict[str, Any]:
"""
计算更新数据
@@ -541,6 +567,7 @@ class AccessHistoryManager:
node_data: 当前节点数据
current_time: 当前时间(datetime对象)
current_time_iso: 当前时间(ISO格式字符串)
access_times: 本次访问次数(合并后可能大于1)
Returns:
Dict[str, Any]: 更新数据,包含所有需要更新的字段
@@ -551,8 +578,8 @@ class AccessHistoryManager:
if importance_score is None:
importance_score = 0.5
# 追加新的访问时间
new_access_history = access_history + [current_time_iso]
# 追加新的访问时间(合并场景下追加多条相同时间戳)
new_access_history = access_history + [current_time_iso] * access_times
# 修剪访问历史(如果过长)
access_history_dt = [
@@ -642,7 +669,8 @@ class AccessHistoryManager:
content_field_map = {
'Statement': 'n.statement as statement',
'MemorySummary': 'n.content as content',
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
'ExtractedEntity': 'null as content_placeholder', # 占位符,后续会被过滤
'Community': 'n.summary as summary'
}
# 显式检查节点类型,不支持的类型抛出错误