Merge #85 into develop from feature/actr-forget
[feature]actr-记忆遗忘需求开发
* feature/actr-forget: (12 commits squashed)
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
- [fix]Eliminate the interference caused by redundant code
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85
This commit is contained in:
@@ -106,7 +106,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
"emotion_intensity": statement.emotion_intensity,
|
||||
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
|
||||
"emotion_subject": statement.emotion_subject,
|
||||
"emotion_target": statement.emotion_target
|
||||
"emotion_target": statement.emotion_target,
|
||||
# 添加 ACT-R 记忆激活属性
|
||||
"importance_score": statement.importance_score,
|
||||
"activation_value": statement.activation_value,
|
||||
"access_history": statement.access_history if statement.access_history else [],
|
||||
"last_access_time": statement.last_access_time,
|
||||
"access_count": statement.access_count
|
||||
}
|
||||
flattened_statements.append(flattened_statement)
|
||||
|
||||
|
||||
@@ -38,7 +38,12 @@ SET s += {
|
||||
valid_at: statement.valid_at,
|
||||
invalid_at: statement.invalid_at,
|
||||
statement_embedding: statement.statement_embedding,
|
||||
relevence_info: statement.relevence_info
|
||||
relevence_info: statement.relevence_info,
|
||||
importance_score: statement.importance_score,
|
||||
activation_value: statement.activation_value,
|
||||
access_history: statement.access_history,
|
||||
last_access_time: statement.last_access_time,
|
||||
access_count: statement.access_count
|
||||
}
|
||||
RETURN s.id AS uuid
|
||||
"""
|
||||
@@ -111,7 +116,12 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
|
||||
ELSE e.connect_strength
|
||||
END
|
||||
END
|
||||
END,
|
||||
e.importance_score = CASE WHEN entity.importance_score IS NOT NULL THEN entity.importance_score ELSE coalesce(e.importance_score, 0.5) END,
|
||||
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END,
|
||||
e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END,
|
||||
e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END,
|
||||
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END
|
||||
RETURN e.id AS uuid
|
||||
"""
|
||||
|
||||
@@ -225,6 +235,10 @@ RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.entity_type AS entity_type,
|
||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||
e.last_access_time AS last_access_time,
|
||||
COALESCE(e.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -243,6 +257,10 @@ RETURN s.id AS id,
|
||||
s.expired_at AS expired_at,
|
||||
s.valid_at AS valid_at,
|
||||
s.invalid_at AS invalid_at,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -258,6 +276,9 @@ RETURN c.id AS chunk_id,
|
||||
c.group_id AS group_id,
|
||||
c.content AS content,
|
||||
c.dialog_id AS dialog_id,
|
||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||
c.last_access_time AS last_access_time,
|
||||
COALESCE(c.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -278,6 +299,10 @@ RETURN s.id AS id,
|
||||
s.invalid_at AS invalid_at,
|
||||
c.id AS chunk_id_from_rel,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -305,6 +330,10 @@ RETURN e.id AS id,
|
||||
e.connect_strength AS connect_strength,
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
collect(DISTINCT c.id) AS chunk_ids,
|
||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||
e.last_access_time AS last_access_time,
|
||||
COALESCE(e.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -322,6 +351,9 @@ RETURN c.id AS chunk_id,
|
||||
c.sequence_number AS sequence_number,
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||
c.last_access_time AS last_access_time,
|
||||
COALESCE(c.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -419,7 +451,11 @@ RETURN s.id AS id,
|
||||
s.created_at AS created_at,
|
||||
s.valid_at AS valid_at,
|
||||
s.invalid_at AS invalid_at,
|
||||
collect(DISTINCT s.id) AS statement_ids
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count
|
||||
ORDER BY datetime(s.created_at) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -446,6 +482,10 @@ RETURN s.id AS id,
|
||||
s.invalid_at AS invalid_at,
|
||||
c.id AS chunk_id_from_rel,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY s.created_at DESC, score DESC
|
||||
LIMIT $limit
|
||||
@@ -635,6 +675,10 @@ RETURN m.id AS id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
m.created_at AS created_at,
|
||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||
m.last_access_time AS last_access_time,
|
||||
COALESCE(m.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -653,6 +697,10 @@ RETURN m.id AS id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
m.created_at AS created_at,
|
||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||
m.last_access_time AS last_access_time,
|
||||
COALESCE(m.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
|
||||
@@ -55,6 +55,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
if 'aliases' not in n or n['aliases'] is None:
|
||||
n['aliases'] = []
|
||||
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
return ExtractedEntityNode(**n)
|
||||
|
||||
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -24,6 +25,157 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量更新节点的激活值
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史和激活值。
|
||||
使用重试机制处理更新失败的情况。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
nodes: 节点列表,每个节点必须包含 'id' 字段
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选)
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
access_manager = AccessHistoryManager(
|
||||
connector=connector,
|
||||
actr_calculator=actr_calculator,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# 提取节点ID列表
|
||||
node_ids = [node.get('id') for node in nodes if node.get('id')]
|
||||
|
||||
if not node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
node_ids=node_ids,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
|
||||
)
|
||||
|
||||
return updated_nodes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
|
||||
)
|
||||
# 失败时返回原始节点列表
|
||||
return nodes
|
||||
|
||||
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
group_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
|
||||
对 Statement、ExtractedEntity、MemorySummary 节点进行批量激活值更新。
|
||||
ChunkNode 和 DialogueNode 不参与激活值更新(数据层隔离)。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
results: 搜索结果字典,包含不同类型节点的列表
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
|
||||
"""
|
||||
# 定义需要更新激活值的节点类型
|
||||
knowledge_node_types = {
|
||||
'statements': 'Statement',
|
||||
'entities': 'ExtractedEntity',
|
||||
'summaries': 'MemorySummary'
|
||||
}
|
||||
|
||||
# 并行更新所有类型的节点
|
||||
update_tasks = []
|
||||
update_keys = []
|
||||
|
||||
for key, label in knowledge_node_types.items():
|
||||
if key in results and results[key]:
|
||||
update_tasks.append(
|
||||
_update_activation_values_batch(
|
||||
connector=connector,
|
||||
nodes=results[key],
|
||||
node_label=label,
|
||||
group_id=group_id
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
|
||||
if not update_tasks:
|
||||
return results
|
||||
|
||||
# 并行执行所有更新
|
||||
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
|
||||
|
||||
# 更新结果字典,保留原始搜索分数
|
||||
updated_results = results.copy()
|
||||
for key, update_result in zip(update_keys, update_results):
|
||||
if not isinstance(update_result, Exception):
|
||||
# 更新成功,合并原始搜索结果和更新后的激活值数据
|
||||
# 保留原始的 score 字段(BM25/Embedding 分数)
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
# 创建 ID 到原始节点的映射(用于快速查找 score)
|
||||
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
|
||||
|
||||
# 合并数据:激活值来自更新结果,score 来自原始结果
|
||||
merged_nodes = []
|
||||
for updated_node in updated_nodes:
|
||||
node_id = updated_node.get('id')
|
||||
if node_id and node_id in original_map:
|
||||
# 保留原始的 score 字段
|
||||
original_score = original_map[node_id].get('score')
|
||||
if original_score is not None:
|
||||
updated_node['score'] = original_score
|
||||
merged_nodes.append(updated_node)
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
# 更新失败,记录错误但保留原始结果
|
||||
logger.warning(
|
||||
f"更新 {key} 激活值失败: {str(update_result)}"
|
||||
)
|
||||
|
||||
return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
@@ -36,6 +188,7 @@ async def search_graph(
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Statements: matches s.statement CONTAINS q
|
||||
- Entities: matches e.name CONTAINS q
|
||||
@@ -50,7 +203,7 @@ async def search_graph(
|
||||
include: List of categories to search (default: all)
|
||||
|
||||
Returns:
|
||||
Dictionary with search results per category
|
||||
Dictionary with search results per category (with updated activation values)
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
@@ -106,6 +259,13 @@ async def search_graph(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -121,6 +281,7 @@ async def search_graph_by_embedding(
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Computes query embedding with the provided embedder_client
|
||||
- Ranks by cosine similarity in Cypher
|
||||
@@ -203,6 +364,16 @@ async def search_graph_by_embedding(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
|
||||
return results
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
@@ -304,6 +475,8 @@ async def search_graph_by_keyword_temporal(
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements containing query_text created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -326,7 +499,15 @@ async def search_graph_by_keyword_temporal(
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
return {"statements": statements}
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
@@ -342,6 +523,8 @@ async def search_graph_by_temporal(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -362,7 +545,16 @@ async def search_graph_by_temporal(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
@@ -419,6 +611,8 @@ async def search_graph_by_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -436,7 +630,16 @@ async def search_graph_by_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -448,6 +651,8 @@ async def search_graph_by_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -465,7 +670,16 @@ async def search_graph_by_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -477,6 +691,8 @@ async def search_graph_g_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -494,7 +710,16 @@ async def search_graph_g_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -506,6 +731,8 @@ async def search_graph_g_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -523,7 +750,16 @@ async def search_graph_g_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -535,6 +771,8 @@ async def search_graph_l_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -552,7 +790,16 @@ async def search_graph_l_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -564,6 +811,8 @@ async def search_graph_l_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -581,4 +830,13 @@ async def search_graph_l_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -8,7 +8,6 @@ Classes:
|
||||
Neo4jConnector: Neo4j数据库连接器,提供异步查询接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
@@ -85,6 +84,63 @@ class Neo4jConnector:
|
||||
records, summary, keys = result
|
||||
return [record.data() for record in records]
|
||||
|
||||
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在写事务中执行操作
|
||||
|
||||
提供显式事务支持,确保操作的原子性。
|
||||
如果事务函数抛出异常,所有更改将自动回滚。
|
||||
|
||||
Args:
|
||||
transaction_func: 事务函数,接收 tx 参数并执行查询
|
||||
**kwargs: 传递给事务函数的额外参数
|
||||
|
||||
Returns:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def create_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "CREATE (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_write_transaction(
|
||||
... create_node, name="Alice"
|
||||
... )
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_write(transaction_func, **kwargs)
|
||||
|
||||
async def execute_read_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在读事务中执行操作
|
||||
|
||||
提供显式事务支持用于读操作。
|
||||
|
||||
Args:
|
||||
transaction_func: 事务函数,接收 tx 参数并执行查询
|
||||
**kwargs: 传递给事务函数的额外参数
|
||||
|
||||
Returns:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def get_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_read_transaction(
|
||||
... get_node, name="Alice"
|
||||
... )
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_read(transaction_func, **kwargs)
|
||||
|
||||
async def delete_group(self, group_id: str):
|
||||
"""删除指定组的所有数据
|
||||
|
||||
|
||||
@@ -75,6 +75,13 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
n['emotion_subject'] = n.get('emotion_subject')
|
||||
n['emotion_target'] = n.get('emotion_target')
|
||||
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
return StatementNode(**n)
|
||||
|
||||
async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
|
||||
|
||||
Reference in New Issue
Block a user