Merge #85 into develop from feature/actr-forget
[feature]actr-记忆遗忘需求开发
* feature/actr-forget: (12 commits squashed)
- [feature]
1. Extended the fields of the date_config table;
2. Added a new activation value calculation and introduced the ACT-R parameter in Neo4j.
- [feature] 1. Create a forgetting strategy executor; 2. Create the forgetting scheduler
- [feature] Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- [feature]
1. Extended the fields of the date_config table;
2. Added a new activation value calculation and introduced the ACT-R parameter in Neo4j.
- [feature] 1. Create a forgetting strategy executor; 2. Create the forgetting scheduler
- [feature] Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
- [fix]Eliminate the interference caused by redundant code
- [feature]
1. Extended the fields of the date_config table;
2. Added a new activation value calculation and introduced the ACT-R parameter in Neo4j.
- [feature] 1. Create a forgetting strategy executor; 2. Create the forgetting scheduler
- [feature] Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -24,6 +25,157 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
    connector: Neo4jConnector,
    nodes: List[Dict[str, Any]],
    node_label: str,
    group_id: Optional[str] = None,
    max_retries: int = 3
) -> List[Dict[str, Any]]:
    """
    Record one access for a batch of nodes and refresh their activation values.

    The IDs of all nodes are handed to
    ``AccessHistoryManager.record_batch_access`` in a single call instead of
    updating node by node. ``max_retries`` is forwarded to the manager's
    constructor.

    This is a best-effort operation: a failure of the batch update is logged
    (with traceback) and never propagated to the caller.

    Args:
        connector: Neo4j connector handed to the access-history manager.
        nodes: Node dicts; only entries with a truthy ``'id'`` are updated.
        node_label: Node label (Statement, ExtractedEntity, MemorySummary).
        group_id: Optional group ID forwarded to the manager.
        max_retries: Maximum retry count given to the manager.

    Returns:
        List[Dict[str, Any]]:
            - ``[]`` when ``nodes`` is empty;
            - the manager's result when the batch update succeeds;
            - the original ``nodes`` list, untouched, when no node carries a
              usable ID or when the batch update raises.
    """
    if not nodes:
        return []

    # Imported lazily to avoid a circular dependency.
    from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
    from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator

    # Build the calculator and the manager that performs the actual update.
    actr_calculator = ACTRCalculator()
    access_manager = AccessHistoryManager(
        connector=connector,
        actr_calculator=actr_calculator,
        max_retries=max_retries
    )

    # Collect the usable node IDs; nodes without a truthy 'id' are skipped.
    candidate_ids = (node.get('id') for node in nodes)
    node_ids = [node_id for node_id in candidate_ids if node_id]

    if not node_ids:
        logger.warning("批量更新激活值:没有有效的节点ID")
        return nodes

    try:
        updated_nodes = await access_manager.record_batch_access(
            node_ids=node_ids,
            node_label=node_label,
            group_id=group_id
        )

        logger.info(
            f"批量更新激活值成功: {node_label}, "
            f"更新数量={len(updated_nodes)}/{len(node_ids)}"
        )

        return updated_nodes

    except Exception as e:
        # Activation bookkeeping must not break the search that triggered it,
        # so the error is swallowed -- but logged with its traceback so the
        # root cause stays diagnosable.
        logger.exception(
            f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
        )
        # Fall back to the caller's original node list.
        return nodes
|
||||
|
||||
|
||||
async def _update_search_results_activation(
    connector: Neo4jConnector,
    results: Dict[str, List[Dict[str, Any]]],
    group_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Refresh the activation values of every knowledge node in a search result.

    Statement, ExtractedEntity and MemorySummary nodes (result keys
    ``'statements'``, ``'entities'``, ``'summaries'``) are updated in batches,
    one batch per node type, all batches running concurrently. Other keys
    (e.g. chunks, dialogues) are passed through untouched.

    The update is a side effect of searching and must not alter *which* hits
    are returned: every original hit is kept, in its original order. A hit is
    replaced by its updated version (with the original ``score`` restored) only
    when the updater returned a node with the same ``id``.

    Args:
        connector: Neo4j connector forwarded to the batch updater.
        results: Search results keyed by category, each a list of node dicts.
        group_id: Optional group ID forwarded to the batch updater.

    Returns:
        Dict[str, List[Dict[str, Any]]]: ``results`` itself when there is
        nothing to update; otherwise a shallow copy in which the knowledge
        node lists carry the updated activation data.
    """
    # Result keys that take part in activation updates, with their node label.
    knowledge_node_types = {
        'statements': 'Statement',
        'entities': 'ExtractedEntity',
        'summaries': 'MemorySummary'
    }

    # Only categories that are present and non-empty need an update.
    update_keys = [key for key in knowledge_node_types if results.get(key)]
    if not update_keys:
        return results

    # Run one batch update per category concurrently. Exceptions are returned
    # as values so that one failing category cannot discard the others.
    update_results = await asyncio.gather(
        *(
            _update_activation_values_batch(
                connector=connector,
                nodes=results[key],
                node_label=knowledge_node_types[key],
                group_id=group_id
            )
            for key in update_keys
        ),
        return_exceptions=True
    )

    updated_results = results.copy()
    for key, update_result in zip(update_keys, update_results):
        # BaseException also covers a cancelled child task, which gather()
        # hands back as a value when return_exceptions=True.
        if isinstance(update_result, BaseException):
            # Update failed: log it and keep the original hits for this key.
            logger.warning(
                f"更新 {key} 激活值失败: {str(update_result)}"
            )
            continue

        # Index the updated nodes by ID for O(1) lookup.
        # NOTE(review): assumes the updater's nodes expose the same 'id' key
        # as the search hits -- confirm against record_batch_access.
        updated_by_id = {}
        for updated_node in update_result:
            node_id = updated_node.get('id')
            if node_id:
                updated_by_id[node_id] = updated_node

        # Walk the ORIGINAL hits so that ranking order is preserved and no
        # hit disappears just because its activation update was skipped.
        merged_nodes = []
        for original_node in results[key]:
            node_id = original_node.get('id')
            updated_node = updated_by_id.get(node_id) if node_id else None
            if updated_node is None:
                merged_nodes.append(original_node)
                continue

            # Copy so the updater's objects are not mutated.
            merged_node = dict(updated_node)
            # The search score (BM25 / embedding) comes from the original hit.
            original_score = original_node.get('score')
            if original_score is not None:
                merged_node['score'] = original_score
            merged_nodes.append(merged_node)

        updated_results[key] = merged_nodes

    return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
@@ -36,6 +188,7 @@ async def search_graph(
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Statements: matches s.statement CONTAINS q
|
||||
- Entities: matches e.name CONTAINS q
|
||||
@@ -50,7 +203,7 @@ async def search_graph(
|
||||
include: List of categories to search (default: all)
|
||||
|
||||
Returns:
|
||||
Dictionary with search results per category
|
||||
Dictionary with search results per category (with updated activation values)
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
@@ -106,6 +259,13 @@ async def search_graph(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -121,6 +281,7 @@ async def search_graph_by_embedding(
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Computes query embedding with the provided embedder_client
|
||||
- Ranks by cosine similarity in Cypher
|
||||
@@ -203,6 +364,16 @@ async def search_graph_by_embedding(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
|
||||
return results
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
@@ -304,6 +475,8 @@ async def search_graph_by_keyword_temporal(
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements containing query_text created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -326,7 +499,15 @@ async def search_graph_by_keyword_temporal(
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
return {"statements": statements}
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
@@ -342,6 +523,8 @@ async def search_graph_by_temporal(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -362,7 +545,16 @@ async def search_graph_by_temporal(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
@@ -419,6 +611,8 @@ async def search_graph_by_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -436,7 +630,16 @@ async def search_graph_by_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -448,6 +651,8 @@ async def search_graph_by_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -465,7 +670,16 @@ async def search_graph_by_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -477,6 +691,8 @@ async def search_graph_g_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -494,7 +710,16 @@ async def search_graph_g_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -506,6 +731,8 @@ async def search_graph_g_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -523,7 +750,16 @@ async def search_graph_g_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -535,6 +771,8 @@ async def search_graph_l_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -552,7 +790,16 @@ async def search_graph_l_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -564,6 +811,8 @@ async def search_graph_l_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -581,4 +830,13 @@ async def search_graph_l_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
Reference in New Issue
Block a user