Files
MemoryBear/api/app/repositories/neo4j/add_nodes.py
Eternity 89d188fbf3 Merge branch 'develop' into feature/multimodel_memory
# Conflicts:
#	api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py
#	api/app/repositories/neo4j/add_nodes.py
#	api/app/repositories/neo4j/cypher_queries.py
#	api/app/repositories/neo4j/graph_saver.py
#	api/app/services/memory_agent_service.py
#	api/app/services/multimodal_service.py
2026-03-24 14:15:18 +08:00

233 lines
9.8 KiB
Python

import logging
from typing import List, Optional
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
MEMORY_SUMMARY_NODE_SAVE
# Use the new repository layer
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Module-level logger, named after this module, shared by the node-save helpers below.
logger = logging.getLogger(__name__)
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
    """Delete every node (and its relationships) belonging to one end user.

    Despite the function name, this does not wipe the whole database: only
    nodes whose ``end_user_id`` property equals ``end_user_id`` are matched.
    ``DETACH DELETE`` removes their relationships along with them.

    Args:
        end_user_id: Owner id whose nodes should be removed.
        connector: Neo4j connector used to run the query.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete query.
    """
    # Bind the id as a Cypher query parameter instead of interpolating it into
    # the query text, so quotes or Cypher syntax in the value cannot change
    # the query (Cypher injection). str() keeps matching string-typed
    # end_user_id properties even if a non-str id (e.g. a UUID) is passed.
    result = await connector.execute_query(
        "MATCH (n {end_user_id: $end_user_id}) DETACH DELETE n",
        end_user_id=str(end_user_id),
    )
    logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully")
    return result
async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save dialogue nodes to Neo4j with a single batched query.

    Each DialogueNode is flattened into a dict of primitive values, with
    datetimes converted to ISO-8601 strings. The list is passed as the
    ``dialogues`` parameter of ``DIALOGUE_NODE_SAVE``.

    Args:
        dialogues: DialogueNode objects to save.
        connector: Neo4j connector instance.

    Returns:
        One value per record returned by the query: the record's ``uuid``
        field. Returns ``[]`` when ``dialogues`` is empty, or ``None`` if
        flattening or the query raised.
    """
    if not dialogues:
        logger.info("No dialogues to save")
        return []
    try:
        # Flatten DialogueNode objects to match the fields the Cypher query expects.
        flattened_dialogues = [
            {
                "id": dialogue.id,
                "end_user_id": dialogue.end_user_id,
                "run_id": dialogue.run_id,
                "ref_id": dialogue.ref_id,
                "name": dialogue.name,
                "created_at": dialogue.created_at.isoformat() if dialogue.created_at else None,
                "expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None,
                "content": dialogue.content,
                "dialog_embedding": dialogue.dialog_embedding,
            }
            for dialogue in dialogues
        ]
        result = await connector.execute_query(
            DIALOGUE_NODE_SAVE,
            dialogues=flattened_dialogues,
        )
        created_uuids = [record["uuid"] for record in result]
        logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
        return created_uuids
    except Exception as e:
        # Best-effort contract: callers get None rather than an exception.
        # logger.exception records the traceback so the failure can be diagnosed.
        logger.exception(f"Error creating dialogue nodes: {e}")
        return None
def _flatten_statement(statement: StatementNode) -> dict:
    """Convert one StatementNode into the flat dict that STATEMENT_NODE_SAVE consumes."""
    return {
        "id": statement.id,
        "name": statement.name,
        "end_user_id": statement.end_user_id,
        "run_id": statement.run_id,
        "chunk_id": statement.chunk_id,
        "created_at": statement.created_at.isoformat() if statement.created_at else None,
        "expired_at": statement.expired_at.isoformat() if statement.expired_at else None,
        "stmt_type": statement.stmt_type,
        # Guard against a missing temporal_info so that one incomplete
        # statement does not abort the whole batch with an AttributeError.
        "temporal_info": statement.temporal_info.value if statement.temporal_info is not None else None,
        "statement": statement.statement,
        "connect_strength": statement.connect_strength,
        # Empty embeddings are sent as null rather than an empty list.
        "chunk_embedding": statement.chunk_embedding if statement.chunk_embedding else None,
        "valid_at": statement.valid_at.isoformat() if statement.valid_at else None,
        "invalid_at": statement.invalid_at.isoformat() if statement.invalid_at else None,
        "statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
        # Speaker (used for role-based emotion extraction); None if the model lacks the field.
        "speaker": getattr(statement, "speaker", None),
        # Emotion fields
        "emotion_type": statement.emotion_type,
        "emotion_intensity": statement.emotion_intensity,
        "emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
        "emotion_subject": statement.emotion_subject,
        "emotion_target": statement.emotion_target,
        # ACT-R memory activation attributes
        "importance_score": statement.importance_score,
        "activation_value": statement.activation_value,
        "access_history": statement.access_history if statement.access_history else [],
        # NOTE(review): passed through unconverted, unlike the datetimes above,
        # which are sent as ISO strings. Confirm which type STATEMENT_NODE_SAVE expects.
        "last_access_time": statement.last_access_time,
        "access_count": statement.access_count,
    }


async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save statement nodes to Neo4j with a single batched query.

    Each StatementNode is flattened into a dict of primitive values (see
    ``_flatten_statement``). The list is passed as the ``statements``
    parameter of ``STATEMENT_NODE_SAVE``.

    Args:
        statements: StatementNode objects to save.
        connector: Neo4j connector instance.

    Returns:
        One value per record returned by the query: the record's ``uuid``
        field. Returns ``[]`` when ``statements`` is empty, or ``None`` if
        flattening or the query raised.
    """
    if not statements:
        logger.info("No statements to save")
        return []
    try:
        flattened_statements = [_flatten_statement(statement) for statement in statements]
        result = await connector.execute_query(
            STATEMENT_NODE_SAVE,
            statements=flattened_statements,
        )
        created_uuids = [record["uuid"] for record in result]
        logger.info(f"Successfully created {len(created_uuids)} statement nodes")
        return created_uuids
    except Exception as e:
        # Best-effort contract: callers get None rather than an exception.
        # logger.exception records the traceback so the failure can be diagnosed.
        logger.exception(f"Error creating statement nodes: {e}")
        return None
def _flatten_chunk(chunk: ChunkNode) -> dict:
    """Convert one ChunkNode into the flat dict that CHUNK_NODE_SAVE consumes."""
    # Lift the offsets out of metadata instead of passing the dict itself,
    # because Neo4j node properties cannot hold nested maps.
    metadata = chunk.metadata or {}
    return {
        "id": chunk.id,
        "name": chunk.name,
        "end_user_id": chunk.end_user_id,
        "run_id": chunk.run_id,
        "created_at": chunk.created_at.isoformat() if chunk.created_at else None,
        "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None,
        "dialog_id": chunk.dialog_id,
        "content": chunk.content,
        # An empty embedding is sent as null rather than an empty list.
        "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
        "sequence_number": chunk.sequence_number,
        "start_index": metadata.get("start_index"),
        "end_index": metadata.get("end_index"),
        # Speaker (used for role-based emotion extraction); None if the model lacks the field.
        "speaker": getattr(chunk, "speaker", None),
    }


async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Save chunk nodes to Neo4j with a single batched query.

    Each ChunkNode is flattened into a dict of primitive values (see
    ``_flatten_chunk``). The list is passed as the ``chunks`` parameter of
    ``CHUNK_NODE_SAVE``.

    Args:
        chunks: ChunkNode objects to save.
        connector: Neo4j connector instance.

    Returns:
        One value per record returned by the query: the record's ``uuid``
        field. Returns ``[]`` when ``chunks`` is empty, or ``None`` if
        flattening or the query raised.
    """
    if not chunks:
        logger.info("No chunk nodes to add")
        return []
    try:
        flattened_chunks = [_flatten_chunk(chunk) for chunk in chunks]
        result = await connector.execute_query(
            CHUNK_NODE_SAVE,
            chunks=flattened_chunks,
        )
        created_uuids = [record["uuid"] for record in result]
        logger.info(f"Successfully created {len(created_uuids)} chunk nodes")
        return created_uuids
    except Exception as e:
        # Best-effort contract: callers get None rather than an exception.
        # logger.exception records the traceback so the failure can be diagnosed.
        logger.exception(f"Error creating chunk nodes: {e}")
        return None
async def add_memory_summary_nodes(
    summaries: List[MemorySummaryNode],
    connector: Neo4jConnector
) -> Optional[List[str]]:
    """Save memory summary nodes to Neo4j with a single batched query.

    Each MemorySummaryNode is flattened into a dict of primitive values, with
    datetimes converted to ISO-8601 strings. The list is passed as the
    ``summaries`` parameter of ``MEMORY_SUMMARY_NODE_SAVE``.

    Args:
        summaries: MemorySummaryNode objects to save.
        connector: Neo4j connector instance.

    Returns:
        One value per record returned by the query: the record's ``uuid``
        field, or None if a record lacks it. Returns ``[]`` when
        ``summaries`` is empty, or ``None`` if flattening or the query raised.
    """
    if not summaries:
        logger.info("No memory summary nodes to add")
        return []
    try:
        flattened = [
            {
                "id": s.id,
                "name": s.name,
                "end_user_id": s.end_user_id,
                "run_id": s.run_id,
                "created_at": s.created_at.isoformat() if s.created_at else None,
                "expired_at": s.expired_at.isoformat() if s.expired_at else None,
                "dialog_id": s.dialog_id,
                "chunk_ids": s.chunk_ids,
                "content": s.content,
                "memory_type": s.memory_type,
                # An empty embedding is sent as null rather than an empty list.
                "summary_embedding": s.summary_embedding if s.summary_embedding else None,
                "config_id": s.config_id,
            }
            for s in summaries
        ]
        result = await connector.execute_query(
            MEMORY_SUMMARY_NODE_SAVE,
            summaries=flattened,
        )
        # .get() yields None for a record without a "uuid" key instead of raising.
        created_ids = [record.get("uuid") for record in result]
        logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
        return created_ids
    except Exception as e:
        # Best-effort contract: callers get None rather than an exception.
        # logger.exception records the traceback so the failure can be diagnosed.
        logger.exception(f"Failed to save MemorySummary nodes to Neo4j: {e}")
        return None