feat(memory): add perceptual memory retrieval service with BM25+embedding fusion
This commit is contained in:
@@ -1,17 +1,17 @@
|
||||
import asyncio
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
|
||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
# 创建 Statements 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# # 创建 Dialogues 索引
|
||||
# await connector.execute_query("""
|
||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 Chunks 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# 创建 MemorySummary 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
""")
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
|
||||
# 创建 Perceptual 感知记忆索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes():
|
||||
"""Create vector indexes for fast embedding similarity search.
|
||||
|
||||
@@ -50,8 +58,7 @@ async def create_vector_indexes():
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
|
||||
|
||||
|
||||
# Statement embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
|
||||
@@ -62,8 +69,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Chunk embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
|
||||
@@ -75,7 +81,6 @@ async def create_vector_indexes():
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Entity name embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
|
||||
@@ -86,8 +91,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
|
||||
# Memory summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
|
||||
@@ -98,7 +102,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
@@ -108,8 +112,8 @@ async def create_vector_indexes():
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
""")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
||||
@@ -120,15 +124,27 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
|
||||
|
||||
# Perceptual summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
|
||||
FOR (p:Perceptual)
|
||||
ON p.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
|
||||
"""Create uniqueness constraints for core node identifiers.
|
||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
try:
|
||||
# Dialogue.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -136,7 +152,7 @@ async def create_unique_constraints():
|
||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Statement.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -144,7 +160,7 @@ async def create_unique_constraints():
|
||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
# Chunk.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -152,13 +168,13 @@ async def create_unique_constraints():
|
||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
    """Run every index/constraint setup step, one after another.

    Awaits the full-text index step, then the vector index step, then the
    uniqueness-constraint step. A confirmation line is printed only after
    all three have returned without raising.
    """
    setup_steps = (
        create_fulltext_indexes,
        create_vector_indexes,
        create_unique_constraints,
    )
    # Steps are awaited sequentially, not gathered, to keep the original order.
    for run_step in setup_steps:
        await run_step()

    print("✓ All indexes and constraints created successfully!")
|
||||
|
||||
|
||||
@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
r.created_at = edge.created_at
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
# Full-text keyword search over Perceptual nodes.
#
# Parameters:
#   $q            - query text handed to the "perceptualFulltext" index
#   $end_user_id  - only nodes whose end_user_id equals this value are kept
#   $limit        - maximum number of rows returned
#
# Rows are ordered by the full-text score, highest first. The equality filter
# on end_user_id means a null $end_user_id is expected to match no rows
# (Cypher null comparison) - confirm with callers before relying on it.
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
       p.end_user_id AS end_user_id,
       p.perceptual_type AS perceptual_type,
       p.file_path AS file_path,
       p.file_name AS file_name,
       p.file_ext AS file_ext,
       p.summary AS summary,
       p.keywords AS keywords,
       p.topic AS topic,
       p.domain AS domain,
       p.created_at AS created_at,
       p.file_type AS file_type,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
# Vector similarity search over Perceptual nodes' summary embeddings.
#
# Parameters:
#   $embedding    - query vector compared against 'perceptual_summary_embedding_index'
#   $end_user_id  - only nodes whose end_user_id equals this value are kept
#   $limit        - maximum number of rows returned
#
# The index is asked for $limit * 100 candidates because the end_user_id
# filter is applied only after the index lookup. Fewer than $limit rows can
# still come back if most candidates belong to other users.
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
       p.end_user_id AS end_user_id,
       p.perceptual_type AS perceptual_type,
       p.file_path AS file_path,
       p.file_name AS file_name,
       p.file_ext AS file_ext,
       p.summary AS summary,
       p.keywords AS keywords,
       p.topic AS topic,
       p.domain AS domain,
       p.created_at AS created_at,
       p.file_type AS file_type,
       score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量更新节点的激活值
|
||||
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
access_manager = AccessHistoryManager(
|
||||
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
|
||||
actr_calculator=actr_calculator,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
|
||||
# 提取节点ID列表并去重(保持原始顺序)
|
||||
seen_ids = set()
|
||||
unique_node_ids = []
|
||||
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
|
||||
if node_id and node_id not in seen_ids:
|
||||
seen_ids.add(node_id)
|
||||
unique_node_ids.append(node_id)
|
||||
|
||||
|
||||
if not unique_node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
|
||||
f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, "
|
||||
f"去重后唯一ID数量={len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
|
||||
node_label=node_label,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_nodes
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
|
||||
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
|
||||
|
||||
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
|
||||
'entities': 'ExtractedEntity',
|
||||
'summaries': 'MemorySummary'
|
||||
}
|
||||
|
||||
|
||||
# 并行更新所有类型的节点
|
||||
update_tasks = []
|
||||
update_keys = []
|
||||
|
||||
|
||||
for key, label in knowledge_node_types.items():
|
||||
if key in results and results[key]:
|
||||
update_tasks.append(
|
||||
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
|
||||
|
||||
if not update_tasks:
|
||||
return results
|
||||
|
||||
|
||||
# 并行执行所有更新
|
||||
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 更新结果字典,保留原始搜索分数
|
||||
updated_results = results.copy()
|
||||
for key, update_result in zip(update_keys, update_results):
|
||||
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
|
||||
# 保留原始的 score 字段(BM25/Embedding 分数)
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
|
||||
# 创建 ID 到更新节点的映射(用于快速查找激活值数据)
|
||||
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
|
||||
|
||||
|
||||
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
|
||||
merged_nodes = []
|
||||
for original_node in original_nodes:
|
||||
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
|
||||
if node_id and node_id in updated_map:
|
||||
# 从原始节点开始,用更新后的激活值数据覆盖
|
||||
merged_node = original_node.copy()
|
||||
|
||||
|
||||
# 更新激活值相关字段
|
||||
activation_fields = {
|
||||
'activation_value',
|
||||
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
|
||||
'importance_score',
|
||||
'version',
|
||||
'statement', # Statement 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
'content' # MemorySummary 节点的内容字段
|
||||
}
|
||||
|
||||
|
||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||
for field in activation_fields:
|
||||
if field in updated_map[node_id]:
|
||||
merged_node[field] = updated_map[node_id][field]
|
||||
|
||||
|
||||
merged_nodes.append(merged_node)
|
||||
else:
|
||||
# 如果没有更新数据,保留原始节点
|
||||
merged_nodes.append(original_node)
|
||||
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
# 更新失败,记录错误但保留原始结果
|
||||
logger.warning(
|
||||
f"更新 {key} 激活值失败: {str(update_result)}"
|
||||
)
|
||||
|
||||
|
||||
return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
connector: Neo4jConnector,
|
||||
q: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
@@ -249,41 +251,45 @@ async def search_graph(
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
|
||||
# Prepare tasks for parallel execution
|
||||
tasks = []
|
||||
task_keys = []
|
||||
|
||||
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("statements")
|
||||
|
||||
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("entities")
|
||||
|
||||
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("chunks")
|
||||
|
||||
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -293,15 +299,16 @@ async def search_graph(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
json_format=True,
|
||||
q=q,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
))
|
||||
task_keys.append("communities")
|
||||
|
||||
|
||||
# Execute all queries in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# Build results dictionary
|
||||
results = {}
|
||||
for key, result in zip(task_keys, task_results):
|
||||
@@ -310,14 +317,14 @@ async def search_graph(
|
||||
results[key] = []
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
|
||||
# Deduplicate results before updating activation values
|
||||
# This prevents duplicates from propagating through the pipeline
|
||||
from app.core.memory.src.search import _deduplicate_results
|
||||
for key in results:
|
||||
if isinstance(results[key], list):
|
||||
results[key] = _deduplicate_results(results[key])
|
||||
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
# Skip activation updates if only searching summaries (optimization)
|
||||
needs_activation_update = any(
|
||||
@@ -331,17 +338,17 @@ async def search_graph(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities", "summaries"],
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
|
||||
- Returns up to 'limit' per included type
|
||||
"""
|
||||
import time
|
||||
|
||||
|
||||
# Get embedding for the query
|
||||
embed_start = time.time()
|
||||
embeddings = await embedder_client.response([query_text])
|
||||
embed_time = time.time() - embed_start
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
logger.warning(
|
||||
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
||||
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
|
||||
if "statements" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
|
||||
if "chunks" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
|
||||
if "entities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
|
||||
if "summaries" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
|
||||
if "communities" in include:
|
||||
tasks.append(connector.execute_query(
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
json_format=True,
|
||||
embedding=embedding,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
|
||||
query_start = time.time()
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
query_time = time.time() - query_start
|
||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||
|
||||
# Build results dictionary
|
||||
results: Dict[str, List[Dict[str, Any]]] = {
|
||||
"statements": [],
|
||||
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
|
||||
"summaries": [],
|
||||
"communities": [],
|
||||
}
|
||||
|
||||
|
||||
for key, result in zip(task_keys, task_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
|
||||
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
|
||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
entities: List[Dict[str, Any]],
|
||||
use_contains_fallback: bool = True,
|
||||
batch_size: int = 500,
|
||||
max_concurrency: int = 5,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
||||
|
||||
|
||||
async def search_graph_by_keyword_temporal(
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
connector: Neo4jConnector,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
|
||||
- Returns up to 'limit' statements
|
||||
"""
|
||||
if not query_text:
|
||||
print(f"query_text不能为空")
|
||||
logger.warning(f"query_text不能为空")
|
||||
return {"statements": []}
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
|
||||
invalid_date=invalid_date,
|
||||
limit=limit,
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
logger.debug(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None,
|
||||
valid_date: Optional[str] = None,
|
||||
invalid_date: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
dialog_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Dialogues.
|
||||
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
|
||||
- Returns up to 'limit' dialogues
|
||||
"""
|
||||
if not dialog_id:
|
||||
print(f"dialog_id不能为空")
|
||||
logger.warning(f"dialog_id不能为空")
|
||||
return {"dialogues": []}
|
||||
|
||||
dialogues = await connector.execute_query(
|
||||
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
|
||||
|
||||
|
||||
async def search_graph_by_chunk_id(
|
||||
connector: Neo4jConnector,
|
||||
chunk_id : str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
chunk_id: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
if not chunk_id:
|
||||
print(f"chunk_id不能为空")
|
||||
logger.warning(f"chunk_id不能为空")
|
||||
return {"chunks": []}
|
||||
chunks = await connector.execute_query(
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
|
||||
|
||||
|
||||
async def search_graph_community_expand(
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
connector: Neo4jConnector,
|
||||
community_ids: List[str],
|
||||
end_user_id: str,
|
||||
limit: int = 10,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
|
||||
|
||||
|
||||
async def search_graph_by_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_G_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
created_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
created_at=created_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: Optional[str] = None,
|
||||
|
||||
valid_at: Optional[str] = None,
|
||||
limit: int = 1,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
|
||||
statements = await connector.execute_query(
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
end_user_id=end_user_id,
|
||||
|
||||
|
||||
|
||||
valid_at=valid_at,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
|
||||
results=results,
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_perceptual(
    connector: Neo4jConnector,
    q: str,
    end_user_id: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Keyword (fulltext) lookup over Perceptual memory nodes.

    Runs the ``SEARCH_PERCEPTUAL_BY_KEYWORD`` Cypher query; per the original
    documentation it matches the summary, topic, and domain fields through the
    ``perceptualFulltext`` index. A failing query is logged as a warning and
    treated as "no hits" instead of propagating to the caller.

    Args:
        connector: Neo4j connector used to execute the query
        q: Query text
        end_user_id: Optional user filter forwarded to the query
        limit: Max results forwarded to the query

    Returns:
        Dictionary with a single 'perceptuals' key holding the query rows after
        they have been passed through ``_deduplicate_results``
    """
    query_params = {"q": q, "end_user_id": end_user_id, "limit": limit}

    try:
        rows = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_KEYWORD,
            **query_params,
        )
    except Exception as e:
        # Best-effort search: degrade to an empty result set on any failure.
        logger.warning(f"search_perceptual: keyword search failed: {e}")
        rows = []

    # Imported at call time rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from app.core.memory.src.search import _deduplicate_results

    return {"perceptuals": _deduplicate_results(rows)}
|
||||
|
||||
|
||||
async def search_perceptual_by_embedding(
    connector: Neo4jConnector,
    embedder_client,
    query_text: str,
    end_user_id: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Semantic (embedding-based) lookup over Perceptual memory nodes.

    The query text is embedded through ``embedder_client`` and the resulting
    vector is handed to the ``PERCEPTUAL_EMBEDDING_SEARCH`` Cypher query; per
    the original documentation this ranks by cosine similarity on
    ``summary_embedding`` via the ``perceptual_summary_embedding_index``.
    A failing vector query is logged as a warning and treated as "no hits".

    Args:
        connector: Neo4j connector used to execute the query
        embedder_client: Embedding client exposing an awaitable ``response()``
            that is called with a one-element list containing the query text
        query_text: Query text to embed
        end_user_id: Optional user filter forwarded to the query
        limit: Max results forwarded to the query

    Returns:
        Dictionary with a single 'perceptuals' key holding the query rows after
        they have been passed through ``_deduplicate_results``; the list is
        empty when no usable embedding could be produced
    """
    vectors = await embedder_client.response([query_text])
    query_vector = vectors[0] if vectors else None

    # No embedding at all, or an empty first embedding: nothing to search with.
    if not query_vector:
        logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
        return {"perceptuals": []}

    query_params = {
        "embedding": query_vector,
        "end_user_id": end_user_id,
        "limit": limit,
    }

    try:
        rows = await connector.execute_query(
            PERCEPTUAL_EMBEDDING_SEARCH,
            **query_params,
        )
    except Exception as e:
        # Best-effort search: degrade to an empty result set on any failure.
        logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
        rows = []

    # Imported at call time rather than at module level
    # (presumably to avoid a circular import -- TODO confirm).
    from app.core.memory.src.search import _deduplicate_results

    return {"perceptuals": _deduplicate_results(rows)}
|
||||
|
||||
@@ -11,10 +11,28 @@ Classes:
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def _convert_neo4j_types(value: Any) -> Any:
    """Recursively convert neo4j temporal values into JSON-serializable ones.

    Conversion rules:
        * neo4j ``DateTime``: ISO-8601 string. Timezone-aware values go through
          the native Python ``datetime`` (``to_native().isoformat()``); naive
          values use the driver's own ``iso_format()``.
        * neo4j ``Date`` / ``Time``: ``iso_format()`` string.
        * neo4j ``Duration``: ``str(value)``.
        * ``dict`` / ``list`` / plain ``tuple``: rebuilt with every contained
          value converted recursively (dict keys are left as they are).
        * anything else: returned unchanged.

    Args:
        value: Any value taken from a query result (scalar or container).

    Returns:
        The converted value; containers are returned as new objects of the
        same kind, other values are returned as-is.
    """
    if isinstance(value, Neo4jDateTime):
        return value.to_native().isoformat() if value.tzinfo else value.iso_format()
    if isinstance(value, (Neo4jDate, Neo4jTime)):
        return value.iso_format()
    if isinstance(value, Neo4jDuration):
        return str(value)
    if isinstance(value, dict):
        return {k: _convert_neo4j_types(v) for k, v in value.items()}
    if isinstance(value, list):
        return [_convert_neo4j_types(item) for item in value]
    # Plain tuples may wrap nested dicts (Record.data() is documented to render
    # relationships as (start_node, type, end_node) tuples), so temporal values
    # inside them must be converted as well. The exact-type check is
    # deliberate: tuple subclasses are left untouched so their type and any
    # extra behaviour they carry are preserved.
    if type(value) is tuple:
        return tuple(_convert_neo4j_types(item) for item in value)
    return value
|
||||
|
||||
|
||||
class Neo4jConnector:
|
||||
"""Neo4j数据库连接器
|
||||
|
||||
@@ -59,11 +77,12 @@ class Neo4jConnector:
|
||||
"""
|
||||
await self.driver.close()
|
||||
|
||||
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||
"""执行Cypher查询
|
||||
|
||||
Args:
|
||||
query: Cypher查询语句
|
||||
json_format: json格式化
|
||||
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
||||
|
||||
Returns:
|
||||
@@ -78,7 +97,10 @@ class Neo4jConnector:
|
||||
**kwargs
|
||||
)
|
||||
records, summary, keys = result
|
||||
return [record.data() for record in records]
|
||||
if json_format:
|
||||
return [_convert_neo4j_types(record.data()) for record in records]
|
||||
else:
|
||||
return [record.data() for record in records]
|
||||
|
||||
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在写事务中执行操作
|
||||
|
||||
Reference in New Issue
Block a user