Fix/content attribute (#105)
* [fix]Fix the return of the "content" attribute
* [changes]Improve the code based on AI review
* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* [fix]Fix the return of the "content" attribute
* [changes]Improve the code based on AI review
* Apply suggestion from @sourcery-ai[bot]

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* [changes]Improve the code based on AI review

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
This commit is contained in:
@@ -620,34 +620,52 @@ class AccessHistoryManager:
|
|||||||
new_version = current_version + 1
|
new_version = current_version + 1
|
||||||
|
|
||||||
# 步骤2:使用乐观锁更新节点
|
# 步骤2:使用乐观锁更新节点
|
||||||
# 只有当版本号匹配时才更新
|
# 根据节点类型构建完整的查询语句
|
||||||
update_query = f"""
|
content_field_map = {
|
||||||
MATCH (n:{node_label} {{id: $node_id}})
|
'Statement': 'n.statement as statement',
|
||||||
"""
|
'MemorySummary': 'n.content as content',
|
||||||
|
'ExtractedEntity': 'null as content_placeholder' # 占位符,后续会被过滤
|
||||||
|
}
|
||||||
|
|
||||||
|
# 显式检查节点类型,不支持的类型抛出错误
|
||||||
|
if node_label not in content_field_map:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported node_label: {node_label}. "
|
||||||
|
f"Supported labels are: {list(content_field_map.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
content_field = content_field_map[node_label]
|
||||||
|
|
||||||
|
# 构建 WHERE 子句
|
||||||
|
where_conditions = []
|
||||||
if group_id:
|
if group_id:
|
||||||
update_query += " WHERE n.group_id = $group_id"
|
where_conditions.append("n.group_id = $group_id")
|
||||||
|
|
||||||
# 添加版本检查
|
# 添加版本检查
|
||||||
if current_version > 0:
|
if current_version > 0:
|
||||||
update_query += " AND n.version = $current_version"
|
where_conditions.append("n.version = $current_version")
|
||||||
else:
|
else:
|
||||||
# 如果节点没有版本号,检查是否为首次更新
|
where_conditions.append("(n.version IS NULL OR n.version = 0)")
|
||||||
update_query += " AND (n.version IS NULL OR n.version = 0)"
|
|
||||||
|
|
||||||
update_query += """
|
where_clause = " AND ".join(where_conditions) if where_conditions else "true"
|
||||||
|
|
||||||
|
# 构建完整的更新查询
|
||||||
|
update_query = f"""
|
||||||
|
MATCH (n:{node_label} {{id: $node_id}})
|
||||||
|
WHERE {where_clause}
|
||||||
SET n.activation_value = $activation_value,
|
SET n.activation_value = $activation_value,
|
||||||
n.access_history = $access_history,
|
n.access_history = $access_history,
|
||||||
n.last_access_time = $last_access_time,
|
n.last_access_time = $last_access_time,
|
||||||
n.access_count = $access_count,
|
n.access_count = $access_count,
|
||||||
n.version = $new_version
|
n.version = $new_version
|
||||||
RETURN n.id as id,
|
RETURN n.id as id,
|
||||||
n.statement as statement,
|
|
||||||
n.activation_value as activation_value,
|
n.activation_value as activation_value,
|
||||||
n.access_history as access_history,
|
n.access_history as access_history,
|
||||||
n.last_access_time as last_access_time,
|
n.last_access_time as last_access_time,
|
||||||
n.access_count as access_count,
|
n.access_count as access_count,
|
||||||
n.importance_score as importance_score,
|
n.importance_score as importance_score,
|
||||||
n.version as version
|
n.version as version,
|
||||||
|
{content_field}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
update_params = {
|
update_params = {
|
||||||
@@ -671,7 +689,11 @@ class AccessHistoryManager:
|
|||||||
f"Expected version {current_version}, but node was modified by another transaction."
|
f"Expected version {current_version}, but node was modified by another transaction."
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict(updated_node)
|
# 转换为字典并移除占位符字段
|
||||||
|
result_dict = dict(updated_node)
|
||||||
|
result_dict.pop('content_placeholder', None)
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
|
||||||
# 执行事务
|
# 执行事务
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -66,24 +66,38 @@ async def _update_activation_values_batch(
|
|||||||
max_retries=max_retries
|
max_retries=max_retries
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提取节点ID列表
|
# 提取节点ID列表并去重(保持原始顺序)
|
||||||
node_ids = [node.get('id') for node in nodes if node.get('id')]
|
seen_ids = set()
|
||||||
|
unique_node_ids = []
|
||||||
|
for node in nodes:
|
||||||
|
node_id = node.get('id')
|
||||||
|
if node_id and node_id not in seen_ids:
|
||||||
|
seen_ids.add(node_id)
|
||||||
|
unique_node_ids.append(node_id)
|
||||||
|
|
||||||
if not node_ids:
|
if not unique_node_ids:
|
||||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
|
# 记录去重信息(仅针对具有有效 ID 的节点)
|
||||||
|
id_nodes_count = sum(1 for n in nodes if n.get("id"))
|
||||||
|
if len(unique_node_ids) < id_nodes_count:
|
||||||
|
logger.info(
|
||||||
|
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||||
|
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||||
|
)
|
||||||
|
|
||||||
# 批量记录访问
|
# 批量记录访问
|
||||||
try:
|
try:
|
||||||
updated_nodes = await access_manager.record_batch_access(
|
updated_nodes = await access_manager.record_batch_access(
|
||||||
node_ids=node_ids,
|
node_ids=unique_node_ids,
|
||||||
node_label=node_label,
|
node_label=node_label,
|
||||||
group_id=group_id
|
group_id=group_id
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"批量更新激活值成功: {node_label}, "
|
f"批量更新激活值成功: {node_label}, "
|
||||||
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
|
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return updated_nodes
|
return updated_nodes
|
||||||
@@ -153,19 +167,38 @@ async def _update_search_results_activation(
|
|||||||
original_nodes = results[key]
|
original_nodes = results[key]
|
||||||
updated_nodes = update_result
|
updated_nodes = update_result
|
||||||
|
|
||||||
# 创建 ID 到原始节点的映射(用于快速查找 score)
|
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||||
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
|
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||||
|
|
||||||
# 合并数据:激活值来自更新结果,score 来自原始结果
|
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||||
merged_nodes = []
|
merged_nodes = []
|
||||||
for updated_node in updated_nodes:
|
for original_node in original_nodes:
|
||||||
node_id = updated_node.get('id')
|
node_id = original_node.get('id')
|
||||||
if node_id and node_id in original_map:
|
if node_id and node_id in updated_map:
|
||||||
# 保留原始的 score 字段
|
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||||
original_score = original_map[node_id].get('score')
|
merged_node = original_node.copy()
|
||||||
if original_score is not None:
|
|
||||||
updated_node['score'] = original_score
|
# 更新激活值相关字段
|
||||||
merged_nodes.append(updated_node)
|
activation_fields = {
|
||||||
|
'activation_value',
|
||||||
|
'access_history',
|
||||||
|
'last_access_time',
|
||||||
|
'access_count',
|
||||||
|
'importance_score',
|
||||||
|
'version',
|
||||||
|
'statement', # Statement 节点的内容字段
|
||||||
|
'content' # MemorySummary 节点的内容字段
|
||||||
|
}
|
||||||
|
|
||||||
|
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||||
|
for field in activation_fields:
|
||||||
|
if field in updated_map[node_id]:
|
||||||
|
merged_node[field] = updated_map[node_id][field]
|
||||||
|
|
||||||
|
merged_nodes.append(merged_node)
|
||||||
|
else:
|
||||||
|
# 如果没有更新数据,保留原始节点
|
||||||
|
merged_nodes.append(original_node)
|
||||||
|
|
||||||
updated_results[key] = merged_nodes
|
updated_results[key] = merged_nodes
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user