import logging from typing import List, Optional from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \ MEMORY_SUMMARY_NODE_SAVE # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector logger = logging.getLogger(__name__) async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") logger.warning(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add dialogue nodes to Neo4j database. Args: dialogues: List of DialogueNode objects to save connector: Neo4j connector instance Returns: List of created node UUIDs or None if failed """ if not dialogues: logger.info("No dialogues to save") return [] try: # Flatten DialogueNode objects to match Cypher expected fields flattened_dialogues = [] for dialogue in dialogues: flattened_dialogues.append({ "id": dialogue.id, "end_user_id": dialogue.end_user_id, "run_id": dialogue.run_id, "ref_id": dialogue.ref_id, "name": dialogue.name, "created_at": dialogue.created_at.isoformat() if dialogue.created_at else None, "expired_at": dialogue.expired_at.isoformat() if dialogue.expired_at else None, "content": dialogue.content, "dialog_embedding": dialogue.dialog_embedding }) result = await connector.execute_query( DIALOGUE_NODE_SAVE, dialogues=flattened_dialogues ) created_uuids = [record["uuid"] for record in result] logger.info(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}") return created_uuids except Exception as e: logger.error(f"Error creating dialogue nodes: {e}") return None async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add statement nodes to Neo4j database. Args: statements: List of StatementNode objects to save connector: Neo4j connector instance Returns: List of created node UUIDs or None if failed """ if not statements: logger.info("No statements to save") return [] try: # Flatten StatementNode objects to only include primitive types flattened_statements = [] for statement in statements: flattened_statement = { "id": statement.id, "name": statement.name, "end_user_id": statement.end_user_id, "run_id": statement.run_id, "chunk_id": statement.chunk_id, # "created_at": statement.created_at.isoformat(), "created_at": statement.created_at.isoformat() if statement.created_at else None, "expired_at": statement.expired_at.isoformat() if statement.expired_at else None, "stmt_type": statement.stmt_type, "temporal_info": statement.temporal_info.value, "statement": statement.statement, "connect_strength": statement.connect_strength, "chunk_embedding": statement.chunk_embedding if statement.chunk_embedding else None, # "temporal_validity_valid_at": statement.temporal_validity_valid_at.isoformat() if statement.temporal_validity_valid_at else None, # "temporal_validity_invalid_at": statement.temporal_validity_invalid_at.isoformat() if statement.temporal_validity_invalid_at else None, "valid_at": statement.valid_at.isoformat() if statement.valid_at else None, "invalid_at": statement.invalid_at.isoformat() if statement.invalid_at else None, # "triplet_extraction_info": json.dumps({ # "triplets": [triplet.model_dump() for triplet in statement.triplet_extraction_info.triplets] if statement.triplet_extraction_info else [], # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), "statement_embedding": statement.statement_embedding if statement.statement_embedding else None, # 添加 speaker 字段(用于基于角色的情绪提取) "speaker": statement.speaker if hasattr(statement, 'speaker') else None, # 添加情绪字段处理 "emotion_type": statement.emotion_type, "emotion_intensity": statement.emotion_intensity, "emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [], "emotion_subject": statement.emotion_subject, "emotion_target": statement.emotion_target, # 添加 ACT-R 记忆激活属性 "importance_score": statement.importance_score, "activation_value": statement.activation_value, "access_history": statement.access_history if statement.access_history else [], "last_access_time": statement.last_access_time, "access_count": statement.access_count } flattened_statements.append(flattened_statement) result = await connector.execute_query( STATEMENT_NODE_SAVE, statements=flattened_statements ) created_uuids = [record["uuid"] for record in result] logger.info(f"Successfully created {len(created_uuids)} statement nodes") return created_uuids except Exception as e: logger.error(f"Error creating statement nodes: {e}") return None async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]: """Add chunk nodes to Neo4j in batch. Args: chunks: List of ChunkNode objects to add connector: Neo4j connector instance Returns: List of created chunk UUIDs or None if failed """ if not chunks: logger.info("No chunk nodes to add") return [] try: # Convert chunk nodes to dictionaries for the query flattened_chunks = [] for chunk in chunks: # Flatten metadata properties to avoid Neo4j Map type issues metadata = chunk.metadata if chunk.metadata else {} flattened_chunk = { "id": chunk.id, "name": chunk.name, "end_user_id": chunk.end_user_id, "run_id": chunk.run_id, "created_at": chunk.created_at.isoformat() if chunk.created_at else None, "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, "dialog_id": chunk.dialog_id, "content": chunk.content, "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "sequence_number": chunk.sequence_number, "start_index": metadata.get("start_index"), "end_index": metadata.get("end_index"), # 添加 speaker 字段(用于基于角色的情绪提取) "speaker": chunk.speaker if hasattr(chunk, 'speaker') else None } flattened_chunks.append(flattened_chunk) result = await connector.execute_query( CHUNK_NODE_SAVE, chunks=flattened_chunks ) created_uuids = [record["uuid"] for record in result] logger.info(f"Successfully created {len(created_uuids)} chunk nodes") return created_uuids except Exception as e: logger.error(f"Error creating chunk nodes: {e}") return None async def add_memory_summary_nodes( summaries: List[MemorySummaryNode], connector: Neo4jConnector ) -> Optional[List[str]]: """Add memory summary nodes to Neo4j in batch. Args: summaries: List of MemorySummaryNode objects to add connector: Neo4j connector instance Returns: List of created summary node ids or None if failed """ if not summaries: logger.info("No memory summary nodes to add") return [] try: flattened = [] for s in summaries: flattened.append({ "id": s.id, "name": s.name, "end_user_id": s.end_user_id, "run_id": s.run_id, "created_at": s.created_at.isoformat() if s.created_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None, "dialog_id": s.dialog_id, "chunk_ids": s.chunk_ids, "content": s.content, "memory_type": s.memory_type, # 添加 memory_type 字段 "summary_embedding": s.summary_embedding if s.summary_embedding else None, "config_id": s.config_id, # 添加 config_id }) result = await connector.execute_query( MEMORY_SUMMARY_NODE_SAVE, summaries=flattened ) created_ids = [record.get("uuid") for record in result] logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j") return created_ids except Exception as e: logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}") return None