Files
MemoryBear/api/app/repositories/neo4j/add_nodes.py

330 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import List, Optional
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n")
print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add dialogue nodes to Neo4j database.
Args:
dialogues: List of DialogueNode objects to save
connector: Neo4j connector instance
Returns:
List of created node UUIDs or None if failed
"""
if not dialogues:
print("No dialogues to save")
return []
try:
# Flatten DialogueNode objects to match Cypher expected fields
flattened_dialogues = []
for dialogue in dialogues:
flattened_dialogues.append({
"id": dialogue.id,
"end_user_id": dialogue.end_user_id,
"run_id": dialogue.run_id,
"ref_id": dialogue.ref_id,
"name": dialogue.name,
"created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
"expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None,
"content": dialogue.content,
"dialog_embedding": dialogue.dialog_embedding
})
result = await connector.execute_query(
DIALOGUE_NODE_SAVE,
dialogues=flattened_dialogues
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
return created_uuids
except Exception as e:
print(f"Error creating dialogue nodes: {e}")
return None
async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add statement nodes to Neo4j database.
Args:
statements: List of StatementNode objects to save
connector: Neo4j connector instance
Returns:
List of created node UUIDs or None if failed
"""
if not statements:
print("No statements to save")
return []
try:
# Flatten StatementNode objects to only include primitive types
flattened_statements = []
for statement in statements:
flattened_statement = {
"id": statement.id,
"name": statement.name,
"end_user_id": statement.end_user_id,
"run_id": statement.run_id,
"chunk_id": statement.chunk_id,
# "created_at": statement.created_at.isoformat(),
"created_at": statement.created_at.isoformat() if statement.created_at else None,
"expired_at": statement.expired_at.isoformat() if statement.expired_at else None,
"stmt_type": statement.stmt_type,
"temporal_info": statement.temporal_info.value,
"statement": statement.statement,
"connect_strength": statement.connect_strength,
"chunk_embedding": statement.chunk_embedding if statement.chunk_embedding else None,
# "temporal_validity_valid_at": statement.temporal_validity_valid_at.isoformat() if statement.temporal_validity_valid_at else None,
# "temporal_validity_invalid_at": statement.temporal_validity_invalid_at.isoformat() if statement.temporal_validity_invalid_at else None,
"valid_at": statement.valid_at.isoformat() if statement.valid_at else None,
"invalid_at": statement.invalid_at.isoformat() if statement.invalid_at else None,
# "triplet_extraction_info": json.dumps({
# "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [],
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
# 添加情绪字段处理
"emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity,
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
"emotion_subject": statement.emotion_subject,
"emotion_target": statement.emotion_target,
# 添加 ACT-R 记忆激活属性
"importance_score": statement.importance_score,
"activation_value": statement.activation_value,
"access_history": statement.access_history if statement.access_history else [],
"last_access_time": statement.last_access_time,
"access_count": statement.access_count
}
flattened_statements.append(flattened_statement)
result = await connector.execute_query(
STATEMENT_NODE_SAVE,
statements=flattened_statements
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} statement nodes")
return created_uuids
except Exception as e:
print(f"Error creating statement nodes: {e}")
return None
async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
"""Add chunk nodes to Neo4j in batch.
Args:
chunks: List of ChunkNode objects to add
connector: Neo4j connector instance
Returns:
List of created chunk UUIDs or None if failed
"""
if not chunks:
print("No chunk nodes to add")
return []
try:
# Convert chunk nodes to dictionaries for the query
flattened_chunks = []
for chunk in chunks:
# Flatten metadata properties to avoid Neo4j Map type issues
metadata = chunk.metadata if chunk.metadata else {}
flattened_chunk = {
"id": chunk.id,
"name": chunk.name,
"end_user_id": chunk.end_user_id,
"run_id": chunk.run_id,
"created_at": chunk.created_at.isoformat() if chunk.created_at else None,
"expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
"dialog_id": chunk.dialog_id,
"content": chunk.content,
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index"),
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
}
flattened_chunks.append(flattened_chunk)
result = await connector.execute_query(
CHUNK_NODE_SAVE,
chunks=flattened_chunks
)
created_uuids = [record["uuid"] for record in result]
print(f"Successfully created {len(created_uuids)} chunk nodes")
return created_uuids
except Exception as e:
print(f"Error creating chunk nodes: {e}")
return None
async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
List[str]]:
"""Add memory summary nodes to Neo4j in batch.
Args:
summaries: List of MemorySummaryNode objects to add
connector: Neo4j connector instance
Returns:
List of created summary node ids or None if failed
"""
if not summaries:
print("No memory summary nodes to add")
return []
try:
flattened = []
for s in summaries:
flattened.append({
"id": s.id,
"name": s.name,
"end_user_id": s.end_user_id,
"run_id": s.run_id,
"created_at": s.created_at.isoformat() if s.created_at else None,
"expired_at": s.expired_at.isoformat() if s.expired_at else None,
"dialog_id": s.dialog_id,
"chunk_ids": s.chunk_ids,
"content": s.content,
"memory_type": s.memory_type, # 添加 memory_type 字段
"summary_embedding": s.summary_embedding if s.summary_embedding else None,
"config_id": s.config_id, # 添加 config_id
})
result = await connector.execute_query(
MEMORY_SUMMARY_NODE_SAVE,
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None
async def add_perceptual_nodes(
perceptuals: list,
connector: Neo4jConnector,
embedder_client=None,
) -> Optional[List[str]]:
"""Add perceptual memory nodes to Neo4j in batch.
Args:
perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
connector: Neo4j connector instance
embedder_client: Optional embedder client for generating summary embeddings
Returns:
List of created node UUIDs or None if failed
"""
if not perceptuals:
print("No perceptual nodes to add")
return []
try:
flattened = []
for p in perceptuals:
meta = p.meta_data or {}
content_meta = meta.get("content", {})
# 生成 summary embedding如果有 embedder_client
summary_embedding = None
if embedder_client and p.summary:
try:
summary_embedding = (await embedder_client.response([p.summary]))[0]
except Exception as emb_err:
print(f"Failed to embed perceptual summary: {emb_err}")
flattened.append({
"id": str(p.id),
"end_user_id": str(p.end_user_id),
"perceptual_type": p.perceptual_type,
"file_path": p.file_path or "",
"file_name": p.file_name or "",
"file_ext": p.file_ext or "",
"summary": p.summary or "",
"keywords": content_meta.get("keywords", []),
"topic": content_meta.get("topic", ""),
"domain": content_meta.get("domain", ""),
"created_at": p.created_time.isoformat() if p.created_time else None,
"summary_embedding": summary_embedding,
})
result = await connector.execute_query(
PERCEPTUAL_NODE_SAVE,
perceptuals=flattened,
)
created_uuids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
return created_uuids
except Exception as e:
print(f"Failed to save Perceptual nodes to Neo4j: {e}")
return None
async def add_perceptual_dialogue_edges(
perceptuals: list,
dialog_id: str,
connector: Neo4jConnector,
) -> Optional[List[str]]:
"""Add edges between Perceptual nodes and Dialogue nodes.
Args:
perceptuals: List of MemoryPerceptualModel objects
dialog_id: The dialogue ID (or ref_id) to link to
connector: Neo4j connector instance
Returns:
List of created edge element IDs or None if failed
"""
if not perceptuals or not dialog_id:
return []
try:
edges = []
for p in perceptuals:
edges.append({
"perceptual_id": str(p.id),
"dialog_id": dialog_id,
"end_user_id": str(p.end_user_id),
"created_at": p.created_time.isoformat() if p.created_time else None,
})
result = await connector.execute_query(
PERCEPTUAL_DIALOGUE_EDGE_SAVE,
edges=edges,
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
return created_ids
except Exception as e:
print(f"Failed to save Perceptual-Dialogue edges: {e}")
return None