Files
MemoryBear/api/app/repositories/neo4j/graph_saver.py
2026-04-10 10:43:43 +08:00

439 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import asyncio
import os
from typing import List, Optional
# Use the new repository layer
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes
from app.repositories.neo4j.cypher_queries import (
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
CHUNK_STATEMENT_EDGE_SAVE,
STATEMENT_ENTITY_EDGE_SAVE,
ENTITY_RELATIONSHIP_SAVE,
EXTRACTED_ENTITY_NODE_SAVE,
)
from app.core.memory.models.graph_models import (
DialogueNode,
ChunkNode,
StatementChunkEdge,
StatementEntityEdge,
StatementNode,
ExtractedEntityNode,
EntityEntityEdge,
PerceptualNode,
PerceptualEdge,
)
import logging
logger = logging.getLogger(__name__)
async def save_entities_and_relationships(
    entity_nodes: List[ExtractedEntityNode],
    entity_entity_edges: List[EntityEntityEdge],
    connector: Neo4jConnector
):
    """Save extracted entities and entity-to-entity relationships to Neo4j.

    Both payloads are built before any query is issued. Entities are written
    first with ``EXTRACTED_ENTITY_NODE_SAVE``, then relationships with
    ``ENTITY_RELATIONSHIP_SAVE``; each is a separate
    ``connector.execute_query`` call. Exceptions raised by the connector are
    not caught here and propagate to the caller.

    Args:
        entity_nodes: Entity nodes to persist; each is serialized with
            ``model_dump()``.
        entity_entity_edges: Entity-to-entity edges to persist.
        connector: Connector used to run the Cypher statements.

    Returns:
        None. The outcome of each step is reported through the module logger.
    """

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; an unset one stays None.
        return moment.isoformat() if moment else None

    all_entities = [entity.model_dump() for entity in entity_nodes]
    all_relationships = [
        {
            'source_id': edge.source,
            'target_id': edge.target,
            'predicate': edge.relation_type,
            'statement_id': edge.source_statement_id,
            'value': edge.relation_value,
            'statement': edge.statement,
            'valid_at': _iso(edge.valid_at),
            'invalid_at': _iso(edge.invalid_at),
            'created_at': _iso(edge.created_at),
            'expired_at': _iso(edge.expired_at),
            'run_id': edge.run_id,
            'end_user_id': edge.end_user_id,
        }
        for edge in entity_entity_edges
    ]

    # Save entities
    if all_entities:
        entity_uuids = await connector.execute_query(
            EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities
        )
        if entity_uuids:
            logger.info("Successfully saved %d entity nodes to Neo4j", len(entity_uuids))
        else:
            # An empty result is treated as a failed write, as before.
            logger.warning("Failed to save entity nodes to Neo4j")
    else:
        logger.info("No entity nodes to save")

    # Create relationships
    if all_relationships:
        relationship_uuids = await connector.execute_query(
            ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships
        )
        if relationship_uuids:
            logger.info(
                "Successfully saved %d entity relationships (edges) to Neo4j",
                len(relationship_uuids),
            )
        else:
            logger.warning("Failed to save entity relationships to Neo4j")
    else:
        logger.info("No entity relationships to save")
async def save_chunk_nodes(
    chunk_nodes: List[ChunkNode],
    connector: Neo4jConnector
):
    """Save chunk nodes to Neo4j by delegating to ``add_chunk_nodes``.

    Args:
        chunk_nodes: Chunk nodes to persist. An empty or missing list is a
            no-op apart from a log line.
        connector: Connector handed through to ``add_chunk_nodes``.

    Returns:
        None. The outcome is reported through the module logger; exceptions
        raised by ``add_chunk_nodes`` are not caught here.
    """
    if not chunk_nodes:
        logger.info("No chunk nodes to save")
        return

    chunk_uuids = await add_chunk_nodes(chunk_nodes, connector)
    if chunk_uuids:
        logger.info("Successfully saved %d chunk nodes to Neo4j", len(chunk_uuids))
    else:
        # A falsy result from the helper is treated as a failed write.
        logger.warning("Failed to save chunk nodes to Neo4j")
async def save_statement_chunk_edges(
    statement_chunk_edges: List[StatementChunkEdge],
    connector: Neo4jConnector
):
    """Save statement-chunk edges to Neo4j on a best-effort basis.

    Each edge is flattened into a plain dict and the whole batch is written
    with ``CHUNK_STATEMENT_EDGE_SAVE`` in one ``connector.execute_query``
    call. A failure of that call does not propagate to the caller, but it is
    logged with its traceback instead of being discarded.

    Args:
        statement_chunk_edges: Edges to persist. Nothing happens when the
            list is empty or missing.
        connector: Connector used to run the Cypher statement.

    Returns:
        None.
    """
    if not statement_chunk_edges:
        return

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; an unset one stays None.
        return moment.isoformat() if moment else None

    all_sc_edges = [
        {
            "id": edge.id,
            "source": edge.source,
            "target": edge.target,
            "end_user_id": edge.end_user_id,
            "run_id": edge.run_id,
            "created_at": _iso(edge.created_at),
            "expired_at": _iso(edge.expired_at),
        }
        for edge in statement_chunk_edges
    ]

    try:
        await connector.execute_query(
            CHUNK_STATEMENT_EDGE_SAVE,
            chunk_statement_edges=all_sc_edges,
        )
    except Exception:
        # Deliberately best-effort: keep the caller running, but leave a
        # trace so lost edges can be diagnosed.
        logger.warning(
            "Failed to save %d statement-chunk edges to Neo4j",
            len(all_sc_edges),
            exc_info=True,
        )
async def save_statement_entity_edges(
    statement_entity_edges: List[StatementEntityEdge],
    connector: Neo4jConnector
):
    """Save statement-entity edges to Neo4j on a best-effort basis.

    Each edge is flattened into a plain dict and the whole batch is written
    with ``STATEMENT_ENTITY_EDGE_SAVE`` in one ``connector.execute_query``
    call. A failure of that call does not propagate to the caller, but it is
    logged with its traceback instead of being discarded.

    Args:
        statement_entity_edges: Edges to persist. An empty or missing list is
            a no-op apart from a log line.
        connector: Connector used to run the Cypher statement.

    Returns:
        None.
    """
    if not statement_entity_edges:
        logger.info("No statement-entity edges to save")
        return

    def _iso(moment):
        # Timestamps are sent as ISO-8601 strings; an unset one stays None.
        return moment.isoformat() if moment else None

    all_se_edges = [
        {
            "source": edge.source,
            "target": edge.target,
            "end_user_id": edge.end_user_id,
            "run_id": edge.run_id,
            "connect_strength": edge.connect_strength,
            "created_at": _iso(edge.created_at),
            "expired_at": _iso(edge.expired_at),
        }
        for edge in statement_entity_edges
    ]

    try:
        await connector.execute_query(
            STATEMENT_ENTITY_EDGE_SAVE,
            relationships=all_se_edges,
        )
    except Exception:
        # Deliberately best-effort: keep the caller running, but leave a
        # trace so lost edges can be diagnosed.
        logger.warning(
            "Failed to save %d statement-entity edges to Neo4j",
            len(all_se_edges),
            exc_info=True,
        )
def _edge_timestamp(moment):
    """Return ``moment.isoformat()`` for a set timestamp, otherwise ``None``."""
    return moment.isoformat() if moment else None


async def _reuse_special_entity_ids(
    entity_nodes,
    entity_edges,
    statement_entity_edges,
    connector,
) -> None:
    """Point special entities (user / assistant aliases) at IDs already in Neo4j.

    Looks up ``ExtractedEntity`` nodes of the same ``end_user_id`` whose
    lower-cased name is one of the special aliases. Every incoming entity
    with a matching name gets the stored ID, and all edges that referenced
    its previous ID are rewritten. The inputs are mutated in place. The
    stated intent is that one ``end_user_id`` keeps a single "user" node and
    a single "AI assistant" node.

    The step is best-effort: any exception is logged as a warning and the
    caller continues with whatever has been rewritten so far.

    TODO: this logic is meant to move into the dedup/disambiguation stage.
    """
    if not entity_nodes:
        return
    # The first entity's end_user_id is used for the whole batch.
    end_user_id = entity_nodes[0].end_user_id
    if not end_user_id:
        return

    # NOTE(review): the empty-string entry is kept exactly as found. It can
    # never produce a mapping (empty names are skipped below); it may be a
    # character lost in transit - confirm against the repository.
    special_names = {"用户", "", "user", "i", "ai助手", "助手", "ai assistant", "assistant"}
    cypher = """
    MATCH (e:ExtractedEntity)
    WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $names
    RETURN e.id AS id, e.name AS name
    """

    try:
        existing = await connector.execute_query(
            cypher,
            end_user_id=end_user_id,
            names=list(special_names),
        )

        # lower-cased name -> ID of the node that already exists
        existing_id_map = {}
        for record in existing or []:
            name_lower = (record.get("name") or "").strip().lower()
            if name_lower and record.get("id"):
                existing_id_map[name_lower] = record["id"]
        if not existing_id_map:
            return

        # previous ID -> reused ID, for the edge rewrite below
        id_remap = {}
        for ent in entity_nodes:
            new_id = existing_id_map.get((ent.name or "").strip().lower())
            old_id = ent.id
            if new_id is None or old_id == new_id:
                continue
            ent.id = new_id
            # First mapping wins if the same previous ID shows up twice.
            id_remap.setdefault(old_id, new_id)
            logger.info(
                "特殊实体 '%s' ID 复用: %s... → %s...",
                ent.name,
                str(old_id)[:8],
                str(new_id)[:8],
            )

        if not id_remap:
            return

        # One pass over all edges: each endpoint is rewritten at most once,
        # so a reused ID is never remapped a second time.
        for edge in [*(statement_entity_edges or []), *(entity_edges or [])]:
            if edge.source in id_remap:
                edge.source = id_remap[edge.source]
            if edge.target in id_remap:
                edge.target = id_remap[edge.target]
    except Exception as e:
        logger.warning("特殊实体 ID 复用查询失败(不影响写入): %s", e)


async def save_dialog_and_statements_to_neo4j(
    dialogue_nodes: List[DialogueNode],
    chunk_nodes: List[ChunkNode],
    statement_nodes: List[StatementNode],
    entity_nodes: List[ExtractedEntityNode],
    perceptual_nodes: List[PerceptualNode],
    entity_edges: List[EntityEntityEdge],
    statement_chunk_edges: List[StatementChunkEdge],
    statement_entity_edges: List[StatementEntityEdge],
    perceptual_edges: List[PerceptualEdge],
    connector: Neo4jConnector,
) -> bool:
    """Save dialogue, chunk, statement, entity and perceptual nodes plus all edges to Neo4j.

    This function only writes data; it does not trigger clustering. The
    caller is expected to call ``_trigger_clustering_sync()`` explicitly
    after a successful write.

    Before writing, special entities (user / assistant aliases) are pointed
    at IDs that already exist in Neo4j, which mutates ``entity_nodes``,
    ``entity_edges`` and ``statement_entity_edges`` in place. All writes then
    run inside one ``connector.execute_write_transaction`` call.

    Args:
        dialogue_nodes: List of DialogueNode objects to save
        chunk_nodes: List of ChunkNode objects to save
        statement_nodes: List of StatementNode objects to save
        entity_nodes: List of ExtractedEntityNode objects to save
        perceptual_nodes: List of PerceptualNode objects to save
        entity_edges: List of EntityEntityEdge objects to save
        statement_chunk_edges: List of StatementChunkEdge objects to save
        statement_entity_edges: List of StatementEntityEdge objects to save
        perceptual_edges: List of PerceptualEdge objects to save
        connector: Neo4j connector instance

    Returns:
        bool: True if the transaction completed, False if any exception was
        raised (the exception is logged, not re-raised).
    """
    await _reuse_special_entity_ids(
        entity_nodes, entity_edges, statement_entity_edges, connector
    )

    async def _save_all_in_transaction(tx):
        """Run every write in one transaction (the original note says this avoids deadlocks).

        Returns a dict mapping each saved category to the list of UUIDs the
        corresponding query returned. The Cypher constants are imported per
        branch, exactly when the original imported them.
        """
        results = {}

        async def _run_batch(cypher_query, **params):
            # Execute one batched statement and collect the returned UUIDs.
            cursor = await tx.run(cypher_query, **params)
            return [record["uuid"] async for record in cursor]

        # 1. Save all dialogue nodes in batch
        if dialogue_nodes:
            from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE
            dialogue_uuids = await _run_batch(
                DIALOGUE_NODE_SAVE,
                dialogues=[node.model_dump() for node in dialogue_nodes],
            )
            results['dialogues'] = dialogue_uuids
            logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}")

        # 2. Save all chunk nodes in batch
        if chunk_nodes:
            from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE
            chunk_uuids = await _run_batch(
                CHUNK_NODE_SAVE,
                chunks=[node.model_dump() for node in chunk_nodes],
            )
            results['chunks'] = chunk_uuids
            logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j")

        # Perceptual nodes are written after chunks and before statements.
        if perceptual_nodes:
            from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE
            perceptual_uuids = await _run_batch(
                PERCEPTUAL_NODE_SAVE,
                perceptuals=[node.model_dump() for node in perceptual_nodes],
            )
            results["perceptuals"] = perceptual_uuids
            logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j")

        # 3. Save all statement nodes in batch
        if statement_nodes:
            from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE
            statement_uuids = await _run_batch(
                STATEMENT_NODE_SAVE,
                statements=[node.model_dump() for node in statement_nodes],
            )
            results['statements'] = statement_uuids
            logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j")

        # 4. Save entities
        if entity_nodes:
            from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE
            entity_uuids = await _run_batch(
                EXTRACTED_ENTITY_NODE_SAVE,
                entities=[entity.model_dump() for entity in entity_nodes],
            )
            results['entities'] = entity_uuids
            logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j")

        # 5. Create entity relationships
        if entity_edges:
            from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE
            relationship_data = [
                {
                    'source_id': edge.source,
                    'target_id': edge.target,
                    'predicate': edge.relation_type,
                    'statement_id': edge.source_statement_id,
                    'value': edge.relation_value,
                    'statement': edge.statement,
                    'valid_at': _edge_timestamp(edge.valid_at),
                    'invalid_at': _edge_timestamp(edge.invalid_at),
                    'created_at': _edge_timestamp(edge.created_at),
                    'expired_at': _edge_timestamp(edge.expired_at),
                    'run_id': edge.run_id,
                    'end_user_id': edge.end_user_id,
                }
                for edge in entity_edges
            ]
            rel_uuids = await _run_batch(
                ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data
            )
            results['entity_relationships'] = rel_uuids
            logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j")

        # 6. Save statement-chunk edges
        if statement_chunk_edges:
            from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE
            sc_edge_data = [
                {
                    "id": edge.id,
                    "source": edge.source,
                    "target": edge.target,
                    "created_at": _edge_timestamp(edge.created_at),
                    "expired_at": _edge_timestamp(edge.expired_at),
                    "run_id": edge.run_id,
                    "end_user_id": edge.end_user_id,
                }
                for edge in statement_chunk_edges
            ]
            sc_uuids = await _run_batch(
                CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data
            )
            results['statement_chunk_edges'] = sc_uuids
            logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j")

        # 7. Save statement-entity edges
        if statement_entity_edges:
            from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE
            se_edge_data = [
                {
                    "source": edge.source,
                    "target": edge.target,
                    "created_at": _edge_timestamp(edge.created_at),
                    "expired_at": _edge_timestamp(edge.expired_at),
                    "run_id": edge.run_id,
                    "end_user_id": edge.end_user_id,
                    # Edges without the attribute are written as "strong".
                    "connect_strength": getattr(edge, "connect_strength", "strong"),
                }
                for edge in statement_entity_edges
            ]
            se_uuids = await _run_batch(
                STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data
            )
            results['statement_entity_edges'] = se_uuids
            logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j")

        # Perceptual-chunk edges go last, after both endpoints were written.
        if perceptual_edges:
            from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE
            perceptual_edge_data = [
                {
                    "perceptual_id": edge.source,
                    "chunk_id": edge.target,
                    "end_user_id": edge.end_user_id,
                    "created_at": _edge_timestamp(edge.created_at),
                }
                for edge in perceptual_edges
            ]
            perceptual_edges_uuids = await _run_batch(
                PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data
            )
            results['perceptual_chunk_edges'] = perceptual_edges_uuids
            logger.info(
                f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j"
            )

        return results

    try:
        # Use one explicit write transaction for all operations.
        results = await connector.execute_write_transaction(_save_all_in_transaction)
        summary = {
            key: len(value)
            for key, value in results.items()
            if isinstance(value, (list, tuple, set))
        }
        logger.info("Transaction completed. Summary: %s", summary)
        logger.debug("Full transaction results: %r", results)
        return True
    except Exception as e:
        # Storage is best-effort for the caller: report and signal failure
        # through the return value instead of raising.
        logger.error(f"Neo4j integration error: {e}", exc_info=True)
        logger.warning("Continuing without database storage...")
        return False
async def _trigger_clustering_sync(
    entity_nodes: List,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
) -> None:
    """Run clustering for the given entities and wait until it has finished.

    The call is awaited rather than scheduled in the background; the original
    note gives the reason as avoiding concurrent conflicts with other LLM
    tasks.

    Nothing happens when ``entity_nodes`` is empty, or when the
    ``CLUSTERING_ENABLED`` environment variable is set to ``false`` (compared
    case-insensitively; it defaults to ``true``).

    Args:
        entity_nodes: Entities that were just written. The first one supplies
            the ``end_user_id`` for the whole batch.
        llm_model_id: Optional model ID passed through to ``_trigger_clustering``.
        embedding_model_id: Optional model ID passed through to
            ``_trigger_clustering``.
    """
    if not entity_nodes:
        return

    switch = os.getenv("CLUSTERING_ENABLED", "true")
    if switch.lower() == "false":
        logger.info("[Clustering] 聚类已禁用CLUSTERING_ENABLED=false跳过聚类触发")
        return

    owner_id = entity_nodes[0].end_user_id
    entity_ids = [node.id for node in entity_nodes]
    logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(entity_ids)}, end_user_id: {owner_id}")

    await _trigger_clustering(
        entity_ids,
        owner_id,
        llm_model_id=llm_model_id,
        embedding_model_id=embedding_model_id,
    )
async def _trigger_clustering(
    new_entity_ids: List[str],
    end_user_id: str,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
) -> None:
    """Run the label-propagation clustering engine for one end user.

    Creates its own ``Neo4jConnector``, builds a ``LabelPropagationEngine``
    and awaits ``engine.run``. Per the original note, the engine itself
    decides between a full initialization and an incremental update - that
    decision is not visible in this function.

    Any exception is logged and swallowed, so a clustering failure never
    propagates to the caller. The connector is always closed; a failure to
    close it is logged rather than silently ignored.

    Args:
        new_entity_ids: IDs of the entities that were just written.
        end_user_id: End user whose entities are clustered.
        llm_model_id: Optional model ID forwarded to the engine.
        embedding_model_id: Optional model ID forwarded to the engine.
    """
    connector = None
    try:
        # Imported here, as in the original, rather than at module load.
        from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
        logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
        connector = Neo4jConnector()
        engine = LabelPropagationEngine(
            connector,
            llm_model_id=llm_model_id,
            embedding_model_id=embedding_model_id,
        )
        await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
        logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
    except Exception as e:
        logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True)
    finally:
        # Identity check: a connector object that happens to be falsy must
        # still be closed.
        if connector is not None:
            try:
                await connector.close()
            except Exception:
                logger.warning(
                    "[Clustering] Failed to close Neo4j connector",
                    exc_info=True,
                )