from typing import List, Optional

from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE, \
    MEMORY_SUMMARY_NODE_SAVE, PERCEPTUAL_NODE_SAVE, PERCEPTUAL_DIALOGUE_EDGE_SAVE
# Use the new repository layer
from app.repositories.neo4j.neo4j_connector import Neo4jConnector


def _isoformat_or_none(value):
    """Return ``value.isoformat()`` when ``value`` is truthy, otherwise ``None``."""
    return value.isoformat() if value else None


async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
    """Delete every node (and its relationships) belonging to one end user.

    Despite the name, this does not wipe the whole database: only nodes whose
    ``end_user_id`` property equals ``end_user_id`` are matched, and
    ``DETACH DELETE`` also removes the relationships attached to them.

    Args:
        end_user_id: Owner identifier whose nodes should be removed.
        connector: Neo4j connector instance.

    Returns:
        Whatever ``connector.execute_query`` returns for the delete query.
    """
    # Bind the id as a query parameter rather than interpolating it into the
    # Cypher text, so a crafted id cannot break out of the string literal and
    # inject arbitrary Cypher (e.g. an unrestricted MATCH ... DETACH DELETE).
    # str() keeps the original semantics, which always matched on the
    # string form of the id.
    result = await connector.execute_query(
        "MATCH (n {end_user_id: $end_user_id}) DETACH DELETE n",
        end_user_id=str(end_user_id),
    )
    print(f"All end_user_id: {end_user_id} node and edge deleted successfully")
    return result


async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Add dialogue nodes to Neo4j database.

    Args:
        dialogues: List of DialogueNode objects to save
        connector: Neo4j connector instance

    Returns:
        List of created node UUIDs (``[]`` for empty input) or None if failed
    """
    if not dialogues:
        print("No dialogues to save")
        return []

    try:
        # Flatten DialogueNode objects into the primitive fields DIALOGUE_NODE_SAVE expects.
        flattened_dialogues = [
            {
                "id": dialogue.id,
                "end_user_id": dialogue.end_user_id,
                "run_id": dialogue.run_id,
                "ref_id": dialogue.ref_id,
                "name": dialogue.name,
                "created_at": _isoformat_or_none(dialogue.created_at),
                "expired_at": _isoformat_or_none(dialogue.expired_at),
                "content": dialogue.content,
                "dialog_embedding": dialogue.dialog_embedding,
            }
            for dialogue in dialogues
        ]

        result = await connector.execute_query(
            DIALOGUE_NODE_SAVE,
            dialogues=flattened_dialogues
        )
        created_uuids = [record["uuid"] for record in result]
        print(f"Successfully created {len(created_uuids)} dialogue nodes: {created_uuids}")
        return created_uuids
    except Exception as e:
        print(f"Error creating dialogue nodes: {e}")
        return None


def _flatten_statement(statement: StatementNode) -> dict:
    """Convert a StatementNode into the primitive-valued dict STATEMENT_NODE_SAVE expects."""
    return {
        "id": statement.id,
        "name": statement.name,
        "end_user_id": statement.end_user_id,
        "run_id": statement.run_id,
        "chunk_id": statement.chunk_id,
        "created_at": _isoformat_or_none(statement.created_at),
        "expired_at": _isoformat_or_none(statement.expired_at),
        "stmt_type": statement.stmt_type,
        # NOTE(review): assumes temporal_info is always set and exposes `.value`
        # (enum-like); a None here raises and aborts the whole batch.
        "temporal_info": statement.temporal_info.value,
        "statement": statement.statement,
        "connect_strength": statement.connect_strength,
        # Empty embeddings are normalized to None.
        "chunk_embedding": statement.chunk_embedding or None,
        "valid_at": _isoformat_or_none(statement.valid_at),
        "invalid_at": _isoformat_or_none(statement.invalid_at),
        "statement_embedding": statement.statement_embedding or None,
        # Speaker field (used for role-based emotion extraction); tolerate
        # StatementNode variants that do not define the attribute.
        "speaker": getattr(statement, "speaker", None),
        # Emotion fields.
        "emotion_type": statement.emotion_type,
        "emotion_intensity": statement.emotion_intensity,
        "emotion_keywords": statement.emotion_keywords or [],
        "emotion_subject": statement.emotion_subject,
        "emotion_target": statement.emotion_target,
        # ACT-R memory activation attributes.
        "importance_score": statement.importance_score,
        "activation_value": statement.activation_value,
        "access_history": statement.access_history or [],
        # NOTE(review): passed through unchanged, unlike the other timestamps
        # above which are ISO-formatted — confirm the expected type.
        "last_access_time": statement.last_access_time,
        "access_count": statement.access_count,
    }


async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Add statement nodes to Neo4j database.

    Args:
        statements: List of StatementNode objects to save
        connector: Neo4j connector instance

    Returns:
        List of created node UUIDs (``[]`` for empty input) or None if failed
    """
    if not statements:
        print("No statements to save")
        return []

    try:
        # Flattening stays inside the try so a malformed node is reported
        # like any other failure instead of propagating.
        flattened_statements = [_flatten_statement(statement) for statement in statements]

        result = await connector.execute_query(
            STATEMENT_NODE_SAVE,
            statements=flattened_statements
        )
        created_uuids = [record["uuid"] for record in result]
        print(f"Successfully created {len(created_uuids)} statement nodes")
        return created_uuids
    except Exception as e:
        print(f"Error creating statement nodes: {e}")
        return None


async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> Optional[List[str]]:
    """Add chunk nodes to Neo4j in batch.

    Args:
        chunks: List of ChunkNode objects to add
        connector: Neo4j connector instance

    Returns:
        List of created chunk UUIDs (``[]`` for empty input) or None if failed
    """
    if not chunks:
        print("No chunk nodes to add")
        return []

    try:
        flattened_chunks = []
        for chunk in chunks:
            # Lift the needed metadata entries to top-level properties so no
            # nested Map value is written to Neo4j.
            metadata = chunk.metadata or {}
            flattened_chunks.append({
                "id": chunk.id,
                "name": chunk.name,
                "end_user_id": chunk.end_user_id,
                "run_id": chunk.run_id,
                "created_at": _isoformat_or_none(chunk.created_at),
                "expired_at": _isoformat_or_none(chunk.expired_at),
                "dialog_id": chunk.dialog_id,
                "content": chunk.content,
                "chunk_embedding": chunk.chunk_embedding or None,
                "sequence_number": chunk.sequence_number,
                "start_index": metadata.get("start_index"),
                "end_index": metadata.get("end_index"),
                # Speaker field (used for role-based emotion extraction).
                "speaker": getattr(chunk, "speaker", None),
            })

        result = await connector.execute_query(
            CHUNK_NODE_SAVE,
            chunks=flattened_chunks
        )
        created_uuids = [record["uuid"] for record in result]
        print(f"Successfully created {len(created_uuids)} chunk nodes")
        return created_uuids
    except Exception as e:
        print(f"Error creating chunk nodes: {e}")
        return None


async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector: Neo4jConnector) -> Optional[
        List[str]]:
    """Add memory summary nodes to Neo4j in batch.

    Args:
        summaries: List of MemorySummaryNode objects to add
        connector: Neo4j connector instance

    Returns:
        List of created summary node ids (``[]`` for empty input) or None if failed
    """
    if not summaries:
        print("No memory summary nodes to add")
        return []

    try:
        flattened = [
            {
                "id": s.id,
                "name": s.name,
                "end_user_id": s.end_user_id,
                "run_id": s.run_id,
                "created_at": _isoformat_or_none(s.created_at),
                "expired_at": _isoformat_or_none(s.expired_at),
                "dialog_id": s.dialog_id,
                "chunk_ids": s.chunk_ids,
                "content": s.content,
                "memory_type": s.memory_type,
                "summary_embedding": s.summary_embedding or None,
                "config_id": s.config_id,
            }
            for s in summaries
        ]

        result = await connector.execute_query(
            MEMORY_SUMMARY_NODE_SAVE,
            summaries=flattened
        )
        created_ids = [record.get("uuid") for record in result]
        print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
        return created_ids
    except Exception as e:
        print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
        return None


async def add_perceptual_nodes(
        perceptuals: list,
        connector: Neo4jConnector,
        embedder_client=None,
) -> Optional[List[str]]:
    """Add perceptual memory nodes to Neo4j in batch.

    Args:
        perceptuals: List of MemoryPerceptualModel objects from PostgreSQL
        connector: Neo4j connector instance
        embedder_client: Optional embedder client for generating summary embeddings

    Returns:
        List of created node UUIDs (``[]`` for empty input) or None if failed
    """
    if not perceptuals:
        print("No perceptual nodes to add")
        return []

    try:
        flattened = []
        for p in perceptuals:
            meta = p.meta_data or {}
            # `or {}` also covers a "content" key that is present but null,
            # which previously raised on `.get` and discarded the whole batch.
            content_meta = meta.get("content") or {}

            # Generate the summary embedding when an embedder_client is given;
            # an embedding failure is reported and the node is still saved
            # without an embedding.
            summary_embedding = None
            if embedder_client and p.summary:
                try:
                    summary_embedding = (await embedder_client.response([p.summary]))[0]
                except Exception as emb_err:
                    print(f"Failed to embed perceptual summary: {emb_err}")

            flattened.append({
                "id": str(p.id),
                "end_user_id": str(p.end_user_id),
                "perceptual_type": p.perceptual_type,
                "file_path": p.file_path or "",
                "file_name": p.file_name or "",
                "file_ext": p.file_ext or "",
                "summary": p.summary or "",
                "keywords": content_meta.get("keywords", []),
                "topic": content_meta.get("topic", ""),
                "domain": content_meta.get("domain", ""),
                "created_at": _isoformat_or_none(p.created_time),
                "summary_embedding": summary_embedding,
            })

        result = await connector.execute_query(
            PERCEPTUAL_NODE_SAVE,
            perceptuals=flattened,
        )
        created_uuids = [record.get("uuid") for record in result]
        print(f"Successfully saved {len(created_uuids)} Perceptual nodes to Neo4j")
        return created_uuids
    except Exception as e:
        print(f"Failed to save Perceptual nodes to Neo4j: {e}")
        return None


async def add_perceptual_dialogue_edges(
        perceptuals: list,
        dialog_id: str,
        connector: Neo4jConnector,
) -> Optional[List[str]]:
    """Add edges between Perceptual nodes and Dialogue nodes.

    Args:
        perceptuals: List of MemoryPerceptualModel objects
        dialog_id: The dialogue ID (or ref_id) to link to
        connector: Neo4j connector instance

    Returns:
        List of created edge element IDs, ``[]`` when there is nothing to link
        (no perceptuals or no dialog_id), or None if failed
    """
    if not perceptuals or not dialog_id:
        return []

    try:
        edges = [
            {
                "perceptual_id": str(p.id),
                "dialog_id": dialog_id,
                "end_user_id": str(p.end_user_id),
                "created_at": _isoformat_or_none(p.created_time),
            }
            for p in perceptuals
        ]

        result = await connector.execute_query(
            PERCEPTUAL_DIALOGUE_EDGE_SAVE,
            edges=edges,
        )
        created_ids = [record.get("uuid") for record in result]
        print(f"Successfully saved {len(created_ids)} Perceptual-Dialogue edges to Neo4j")
        return created_ids
    except Exception as e:
        print(f"Failed to save Perceptual-Dialogue edges: {e}")
        return None