feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
批量更新节点的激活值
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
"""
if not nodes:
return []
# 延迟导入以避免循环依赖
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
# 创建计算器和管理器实例
actr_calculator = ACTRCalculator()
access_manager = AccessHistoryManager(
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
actr_calculator=actr_calculator,
max_retries=max_retries
)
# 提取节点ID列表并去重保持原始顺序
seen_ids = set()
unique_node_ids = []
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
if node_id and node_id not in seen_ids:
seen_ids.add(node_id)
unique_node_ids.append(node_id)
if not unique_node_ids:
logger.warning(f"批量更新激活值没有有效的节点ID")
return nodes
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, "
f"去重后唯一ID数量={len(unique_node_ids)}"
)
# 批量记录访问
try:
updated_nodes = await access_manager.record_batch_access(
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
node_label=node_label,
end_user_id=end_user_id
)
logger.info(
f"批量更新激活值成功: {node_label}, "
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
)
return updated_nodes
except Exception as e:
logger.error(
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
'entities': 'ExtractedEntity',
'summaries': 'MemorySummary'
}
# 并行更新所有类型的节点
update_tasks = []
update_keys = []
for key, label in knowledge_node_types.items():
if key in results and results[key]:
update_tasks.append(
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
)
)
update_keys.append(key)
if not update_tasks:
return results
# 并行执行所有更新
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
# 更新结果字典,保留原始搜索分数
updated_results = results.copy()
for key, update_result in zip(update_keys, update_results):
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
# 保留原始的 score 字段BM25/Embedding 分数)
original_nodes = results[key]
updated_nodes = update_result
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
merged_nodes = []
for original_node in original_nodes:
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
if node_id and node_id in updated_map:
# 从原始节点开始,用更新后的激活值数据覆盖
merged_node = original_node.copy()
# 更新激活值相关字段
activation_fields = {
'activation_value',
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
'importance_score',
'version',
'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段
'content' # MemorySummary 节点的内容字段
}
# 只更新激活值相关字段,保留原始节点的其他字段
for field in activation_fields:
if field in updated_map[node_id]:
merged_node[field] = updated_map[node_id][field]
merged_nodes.append(merged_node)
else:
# 如果没有更新数据,保留原始节点
merged_nodes.append(original_node)
updated_results[key] = merged_nodes
else:
# 更新失败,记录错误但保留原始结果
logger.warning(
f"更新 {key} 激活值失败: {str(update_result)}"
)
return updated_results
async def search_graph(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -249,41 +251,45 @@ async def search_graph(
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Prepare tasks for parallel execution
tasks = []
task_keys = []
if "statements" in include:
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -293,15 +299,16 @@ async def search_graph(
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Build results dictionary
results = {}
for key, result in zip(task_keys, task_results):
@@ -310,14 +317,14 @@ async def search_graph(
results[key] = []
else:
results[key] = result
# Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results
for key in results:
if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization)
needs_activation_update = any(
@@ -331,17 +338,17 @@ async def search_graph(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
"""
Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
- Returns up to 'limit' per included type
"""
import time
# Get embedding for the query
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
logger.warning(
f"search_graph_by_embedding: embedding 生成失败或为空,"
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include:
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
"statements": [],
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
"summaries": [],
"communities": [],
}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
) -> Dict[str, List[Any]]:
"""
Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
chunk_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
results=results,
end_user_id=end_user_id
)
return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}