diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index 913874f1..1a2e3cbc 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -620,34 +620,52 @@ class AccessHistoryManager: new_version = current_version + 1 # 步骤2:使用乐观锁更新节点 - # 只有当版本号匹配时才更新 - update_query = f""" - MATCH (n:{node_label} {{id: $node_id}}) - """ + # 根据节点类型构建完整的查询语句 + content_field_map = { + 'Statement': 'n.statement as statement', + 'MemorySummary': 'n.content as content', + 'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤 + } + + # 显式检查节点类型,不支持的类型抛出错误 + if node_label not in content_field_map: + raise ValueError( + f"Unsupported node_label: {node_label}. " + f"Supported labels are: {list(content_field_map.keys())}" + ) + + content_field = content_field_map[node_label] + + # 构建 WHERE 子句 + where_conditions = [] if group_id: - update_query += " WHERE n.group_id = $group_id" + where_conditions.append("n.group_id = $group_id") # 添加版本检查 if current_version > 0: - update_query += " AND n.version = $current_version" + where_conditions.append("n.version = $current_version") else: - # 如果节点没有版本号,检查是否为首次更新 - update_query += " AND (n.version IS NULL OR n.version = 0)" + where_conditions.append("(n.version IS NULL OR n.version = 0)") - update_query += """ + where_clause = " AND ".join(where_conditions) if where_conditions else "true" + + # 构建完整的更新查询 + update_query = f""" + MATCH (n:{node_label} {{id: $node_id}}) + WHERE {where_clause} SET n.activation_value = $activation_value, n.access_history = $access_history, n.last_access_time = $last_access_time, n.access_count = $access_count, n.version = $new_version RETURN n.id as id, - n.statement as statement, n.activation_value as activation_value, n.access_history as access_history, n.last_access_time as last_access_time, n.access_count as access_count, n.importance_score as importance_score, - n.version as version + n.version as version, + {content_field} """ update_params = { @@ -671,7 +689,11 @@ class AccessHistoryManager: f"Expected version {current_version}, but node was modified by another transaction." ) - return dict(updated_node) + # 转换为字典并移除占位符字段 + result_dict = dict(updated_node) + result_dict.pop('content_placeholder', None) + + return result_dict # 执行事务 try: diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 1549ef86..80756793 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -66,24 +66,38 @@ async def _update_activation_values_batch( max_retries=max_retries ) - # 提取节点ID列表 - node_ids = [node.get('id') for node in nodes if node.get('id')] + # 提取节点ID列表并去重(保持原始顺序) + seen_ids = set() + unique_node_ids = [] + for node in nodes: + node_id = node.get('id') + if node_id and node_id not in seen_ids: + seen_ids.add(node_id) + unique_node_ids.append(node_id) - if not node_ids: + if not unique_node_ids: logger.warning(f"批量更新激活值:没有有效的节点ID") return nodes + + # 记录去重信息(仅针对具有有效 ID 的节点) + id_nodes_count = sum(1 for n in nodes if n.get("id")) + if len(unique_node_ids) < id_nodes_count: + logger.info( + f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, " + f"去重后唯一ID数量={len(unique_node_ids)}" + ) # 批量记录访问 try: updated_nodes = await access_manager.record_batch_access( - node_ids=node_ids, + node_ids=unique_node_ids, node_label=node_label, group_id=group_id ) logger.info( f"批量更新激活值成功: {node_label}, " - f"更新数量={len(updated_nodes)}/{len(node_ids)}" + f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}" ) return updated_nodes @@ -153,19 +167,38 @@ async def _update_search_results_activation( original_nodes = results[key] updated_nodes = update_result - # 创建 ID 到原始节点的映射(用于快速查找 score) - original_map = {node.get('id'): node for node in original_nodes if node.get('id')} + # 创建 ID 到更新节点的映射(用于快速查找激活值数据) + updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')} - # 合并数据:激活值来自更新结果,score 来自原始结果 + # 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充 merged_nodes = [] - for updated_node in updated_nodes: - node_id = updated_node.get('id') - if node_id and node_id in original_map: - # 保留原始的 score 字段 - original_score = original_map[node_id].get('score') - if original_score is not None: - updated_node['score'] = original_score - merged_nodes.append(updated_node) + for original_node in original_nodes: + node_id = original_node.get('id') + if node_id and node_id in updated_map: + # 从原始节点开始,用更新后的激活值数据覆盖 + merged_node = original_node.copy() + + # 更新激活值相关字段 + activation_fields = { + 'activation_value', + 'access_history', + 'last_access_time', + 'access_count', + 'importance_score', + 'version', + 'statement', # Statement 节点的内容字段 + 'content' # MemorySummary 节点的内容字段 + } + + # 只更新激活值相关字段,保留原始节点的其他字段 + for field in activation_fields: + if field in updated_map[node_id]: + merged_node[field] = updated_map[node_id][field] + + merged_nodes.append(merged_node) + else: + # 如果没有更新数据,保留原始节点 + merged_nodes.append(original_node) updated_results[key] = merged_nodes else: