feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -1,17 +1,17 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector()
try:
# 创建 Statements 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# # 创建 Dialogues 索引
# await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# 创建 Chunks 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# 创建 MemorySummary 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
""")
# 创建 Community 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
# 创建 Perceptual 感知记忆索引
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally:
await connector.close()
async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search.
@@ -50,8 +58,7 @@ async def create_vector_indexes():
"""
connector = Neo4jConnector()
try:
# Statement embedding index
await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -62,8 +69,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Chunk embedding index
await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}}
""")
# Entity name embedding index
await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -86,8 +91,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Memory summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -98,7 +102,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
@@ -108,8 +112,8 @@ async def create_vector_indexes():
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
""")
# Dialogue embedding index (optional)
await connector.execute_query("""
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
@@ -120,15 +124,27 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine'
}}
""")
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally:
await connector.close()
async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates.
"""
connector = Neo4jConnector()
try:
try:
# Dialogue.id unique
await connector.execute_query(
"""
@@ -136,7 +152,7 @@ async def create_unique_constraints():
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
"""
)
# Statement.id unique
await connector.execute_query(
"""
@@ -144,7 +160,7 @@ async def create_unique_constraints():
FOR (s:Statement) REQUIRE s.id IS UNIQUE
"""
)
# Chunk.id unique
await connector.execute_query(
"""
@@ -152,13 +168,13 @@ async def create_unique_constraints():
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
"""
)
finally:
await connector.close()
async def create_all_indexes():
"""Create all indexes and constraints in one go."""
await create_fulltext_indexes()
await create_vector_indexes()
await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
批量更新节点的激活值
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
"""
if not nodes:
return []
# 延迟导入以避免循环依赖
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
# 创建计算器和管理器实例
actr_calculator = ACTRCalculator()
access_manager = AccessHistoryManager(
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
actr_calculator=actr_calculator,
max_retries=max_retries
)
# 提取节点ID列表并去重保持原始顺序
seen_ids = set()
unique_node_ids = []
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
if node_id and node_id not in seen_ids:
seen_ids.add(node_id)
unique_node_ids.append(node_id)
if not unique_node_ids:
logger.warning(f"批量更新激活值没有有效的节点ID")
return nodes
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, "
f"去重后唯一ID数量={len(unique_node_ids)}"
)
# 批量记录访问
try:
updated_nodes = await access_manager.record_batch_access(
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
node_label=node_label,
end_user_id=end_user_id
)
logger.info(
f"批量更新激活值成功: {node_label}, "
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
)
return updated_nodes
except Exception as e:
logger.error(
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
'entities': 'ExtractedEntity',
'summaries': 'MemorySummary'
}
# 并行更新所有类型的节点
update_tasks = []
update_keys = []
for key, label in knowledge_node_types.items():
if key in results and results[key]:
update_tasks.append(
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
)
)
update_keys.append(key)
if not update_tasks:
return results
# 并行执行所有更新
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
# 更新结果字典,保留原始搜索分数
updated_results = results.copy()
for key, update_result in zip(update_keys, update_results):
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
# 保留原始的 score 字段BM25/Embedding 分数)
original_nodes = results[key]
updated_nodes = update_result
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
merged_nodes = []
for original_node in original_nodes:
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
if node_id and node_id in updated_map:
# 从原始节点开始,用更新后的激活值数据覆盖
merged_node = original_node.copy()
# 更新激活值相关字段
activation_fields = {
'activation_value',
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
'importance_score',
'version',
'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段
'content' # MemorySummary 节点的内容字段
}
# 只更新激活值相关字段,保留原始节点的其他字段
for field in activation_fields:
if field in updated_map[node_id]:
merged_node[field] = updated_map[node_id][field]
merged_nodes.append(merged_node)
else:
# 如果没有更新数据,保留原始节点
merged_nodes.append(original_node)
updated_results[key] = merged_nodes
else:
# 更新失败,记录错误但保留原始结果
logger.warning(
f"更新 {key} 激活值失败: {str(update_result)}"
)
return updated_results
async def search_graph(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -249,41 +251,45 @@ async def search_graph(
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Prepare tasks for parallel execution
tasks = []
task_keys = []
if "statements" in include:
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -293,15 +299,16 @@ async def search_graph(
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Build results dictionary
results = {}
for key, result in zip(task_keys, task_results):
@@ -310,14 +317,14 @@ async def search_graph(
results[key] = []
else:
results[key] = result
# Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results
for key in results:
if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
@@ -331,17 +338,17 @@ async def search_graph(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
"""
Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
- Returns up to 'limit' per included type
"""
import time
# Get embedding for the query
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
logger.warning(
f"search_graph_by_embedding: embedding 生成失败或为空,"
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include:
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
"statements": [],
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
"summaries": [],
"communities": [],
}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
) -> Dict[str, List[Any]]:
"""
Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
chunk_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector:
"""Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
"""
await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询
Args:
query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询
Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs
)
records, summary, keys = result
return [record.data() for record in records]
if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作