import asyncio import os from typing import List, Optional # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.add_nodes import add_dialogue_nodes, add_statement_nodes, add_chunk_nodes from app.repositories.neo4j.cypher_queries import ( STATEMENT_ENTITY_EDGE_SAVE, ENTITY_RELATIONSHIP_SAVE, EXTRACTED_ENTITY_NODE_SAVE, CHUNK_STATEMENT_EDGE_SAVE, STATEMENT_ENTITY_EDGE_SAVE, ENTITY_RELATIONSHIP_SAVE, EXTRACTED_ENTITY_NODE_SAVE, ) from app.core.memory.models.graph_models import ( DialogueNode, ChunkNode, StatementChunkEdge, StatementEntityEdge, StatementNode, ExtractedEntityNode, EntityEntityEdge, PerceptualNode, PerceptualEdge, ) import logging logger = logging.getLogger(__name__) async def save_entities_and_relationships( entity_nodes: List[ExtractedEntityNode], entity_entity_edges: List[EntityEntityEdge], connector: Neo4jConnector ): """Save entities and their relationships using graph models""" all_entities = [entity.model_dump() for entity in entity_nodes] all_relationships = [] for edge in entity_entity_edges: relationship = { 'source_id': edge.source, 'target_id': edge.target, 'predicate': edge.relation_type, 'statement_id': edge.source_statement_id, 'value': edge.relation_value, 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, 'created_at': edge.created_at.isoformat() if edge.created_at else None, 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, } all_relationships.append(relationship) # Save entities if all_entities: entity_uuids = await connector.execute_query(EXTRACTED_ENTITY_NODE_SAVE, entities=all_entities) if entity_uuids: print(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") else: print("Failed to save entity nodes to Neo4j") else: print("No entity nodes to save") # Create relationships if all_relationships: relationship_uuids = await connector.execute_query(ENTITY_RELATIONSHIP_SAVE, relationships=all_relationships) if relationship_uuids: print(f"Successfully saved {len(relationship_uuids)} entity relationships (edges) to Neo4j") else: print("Failed to save entity relationships to Neo4j") else: print("No entity relationships to save") async def save_chunk_nodes( chunk_nodes: List[ChunkNode], connector: Neo4jConnector ): """Save chunk nodes using graph models""" if not chunk_nodes: print("No chunk nodes to save") return chunk_uuids = await add_chunk_nodes(chunk_nodes, connector) if chunk_uuids: print(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") else: print("Failed to save chunk nodes to Neo4j") async def save_statement_chunk_edges( statement_chunk_edges: List[StatementChunkEdge], connector: Neo4jConnector ): """Save statement-chunk edges using graph models""" if not statement_chunk_edges: return all_sc_edges = [] for edge in statement_chunk_edges: all_sc_edges.append({ "id": edge.id, "source": edge.source, "target": edge.target, "end_user_id": edge.end_user_id, "run_id": edge.run_id, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, }) try: await connector.execute_query( CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=all_sc_edges ) except Exception: pass async def save_statement_entity_edges( statement_entity_edges: List[StatementEntityEdge], connector: Neo4jConnector ): """Save statement-entity edges using graph models""" if not statement_entity_edges: print("No statement-entity edges to save") return all_se_edges = [] for edge in statement_entity_edges: edge_data = { "source": edge.source, "target": edge.target, "end_user_id": edge.end_user_id, "run_id": edge.run_id, "connect_strength": edge.connect_strength, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, } all_se_edges.append(edge_data) if all_se_edges: try: await connector.execute_query( STATEMENT_ENTITY_EDGE_SAVE, relationships=all_se_edges ) except Exception: pass async def save_dialog_and_statements_to_neo4j( dialogue_nodes: List[DialogueNode], chunk_nodes: List[ChunkNode], statement_nodes: List[StatementNode], entity_nodes: List[ExtractedEntityNode], perceptual_nodes: List[PerceptualNode], entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], perceptual_edges: List[PerceptualEdge], connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 _trigger_clustering_sync() 显式触发。 Args: dialogue_nodes: List of DialogueNode objects to save chunk_nodes: List of ChunkNode objects to save statement_nodes: List of StatementNode objects to save entity_nodes: List of ExtractedEntityNode objects to save perceptual_nodes: List of PerceptualNode objects to save entity_edges: List of EntityEntityEdge objects to save statement_chunk_edges: List of StatementChunkEdge objects to save statement_entity_edges: List of StatementEntityEdge objects to save perceptual_edges: List of PerceptualEdge objects to save connector: Neo4j connector instance Returns: bool: True if successful, False otherwise """ # TODO 需要在去重消歧节阶段,做以下逻辑的处理 # 预处理:对特殊实体("用户"、"AI助手")复用 Neo4j 中已有节点的 ID, # 确保同一个 end_user_id 下只有一个"用户"节点和一个"AI助手"节点。 if entity_nodes: _SPECIAL_NAMES = {"用户", "我", "user", "i", "ai助手", "助手", "ai assistant", "assistant"} end_user_id = entity_nodes[0].end_user_id if entity_nodes else None if end_user_id: try: # 查询已有的特殊实体 cypher = """ MATCH (e:ExtractedEntity) WHERE e.end_user_id = $end_user_id AND toLower(e.name) IN $names RETURN e.id AS id, e.name AS name """ existing = await connector.execute_query( cypher, end_user_id=end_user_id, names=list(_SPECIAL_NAMES), ) # 建立 name(lower) → existing_id 映射 existing_id_map = {} for record in (existing or []): name_lower = (record.get("name") or "").strip().lower() if name_lower and record.get("id"): existing_id_map[name_lower] = record["id"] if existing_id_map: # 替换新实体的 ID 为已有 ID,同时更新所有引用该 ID 的边 for ent in entity_nodes: name_lower = (ent.name or "").strip().lower() if name_lower in existing_id_map: old_id = ent.id new_id = existing_id_map[name_lower] if old_id != new_id: ent.id = new_id # 更新 statement_entity_edges 中的引用 for edge in statement_entity_edges: if edge.target == old_id: edge.target = new_id if edge.source == old_id: edge.source = new_id # 更新 entity_edges 中的引用 for edge in entity_edges: if edge.source == old_id: edge.source = new_id if edge.target == old_id: edge.target = new_id logger.info( f"特殊实体 '{ent.name}' ID 复用: {old_id[:8]}... → {new_id[:8]}..." ) except Exception as e: logger.warning(f"特殊实体 ID 复用查询失败(不影响写入): {e}") # 定义事务函数,将所有写操作放在一个事务中 async def _save_all_in_transaction(tx): """在单个事务中执行所有保存操作,避免死锁""" results = {} # 1. Save all dialogue nodes in batch if dialogue_nodes: from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE dialogue_data = [node.model_dump() for node in dialogue_nodes] result = await tx.run(DIALOGUE_NODE_SAVE, dialogues=dialogue_data) dialogue_uuids = [record["uuid"] async for record in result] results['dialogues'] = dialogue_uuids logger.info(f"Dialogues saved to Neo4j with UUIDs: {dialogue_uuids}") # 2. Save all chunk nodes in batch if chunk_nodes: from app.repositories.neo4j.cypher_queries import CHUNK_NODE_SAVE chunk_data = [node.model_dump() for node in chunk_nodes] result = await tx.run(CHUNK_NODE_SAVE, chunks=chunk_data) chunk_uuids = [record["uuid"] async for record in result] results['chunks'] = chunk_uuids logger.info(f"Successfully saved {len(chunk_uuids)} chunk nodes to Neo4j") if perceptual_nodes: from app.repositories.neo4j.cypher_queries import PERCEPTUAL_NODE_SAVE perceptual_data = [node.model_dump() for node in perceptual_nodes] result = await tx.run(PERCEPTUAL_NODE_SAVE, perceptuals=perceptual_data) perceptual_uuids = [record["uuid"] async for record in result] results["perceptuals"] = perceptual_uuids logger.info(f"Successfully saved {len(perceptual_uuids)} perceptual nodes to Neo4j") # 3. Save all statement nodes in batch if statement_nodes: from app.repositories.neo4j.cypher_queries import STATEMENT_NODE_SAVE statement_data = [node.model_dump() for node in statement_nodes] result = await tx.run(STATEMENT_NODE_SAVE, statements=statement_data) statement_uuids = [record["uuid"] async for record in result] results['statements'] = statement_uuids logger.info(f"Successfully saved {len(statement_uuids)} statement nodes to Neo4j") # 4. Save entities if entity_nodes: from app.repositories.neo4j.cypher_queries import EXTRACTED_ENTITY_NODE_SAVE entity_data = [entity.model_dump() for entity in entity_nodes] result = await tx.run(EXTRACTED_ENTITY_NODE_SAVE, entities=entity_data) entity_uuids = [record["uuid"] async for record in result] results['entities'] = entity_uuids logger.info(f"Successfully saved {len(entity_uuids)} entity nodes to Neo4j") # 5. Create entity relationships if entity_edges: from app.repositories.neo4j.cypher_queries import ENTITY_RELATIONSHIP_SAVE relationship_data = [] for edge in entity_edges: relationship_data.append({ 'source_id': edge.source, 'target_id': edge.target, 'predicate': edge.relation_type, 'statement_id': edge.source_statement_id, 'value': edge.relation_value, 'statement': edge.statement, 'valid_at': edge.valid_at.isoformat() if edge.valid_at else None, 'invalid_at': edge.invalid_at.isoformat() if edge.invalid_at else None, 'created_at': edge.created_at.isoformat() if edge.created_at else None, 'expired_at': edge.expired_at.isoformat() if edge.expired_at else None, 'run_id': edge.run_id, 'end_user_id': edge.end_user_id, }) result = await tx.run(ENTITY_RELATIONSHIP_SAVE, relationships=relationship_data) rel_uuids = [record["uuid"] async for record in result] results['entity_relationships'] = rel_uuids logger.info(f"Successfully saved {len(rel_uuids)} entity relationships to Neo4j") # 6. Save statement-chunk edges if statement_chunk_edges: from app.repositories.neo4j.cypher_queries import CHUNK_STATEMENT_EDGE_SAVE sc_edge_data = [] for edge in statement_chunk_edges: sc_edge_data.append({ "id": edge.id, "source": edge.source, "target": edge.target, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "run_id": edge.run_id, "end_user_id": edge.end_user_id, }) result = await tx.run(CHUNK_STATEMENT_EDGE_SAVE, chunk_statement_edges=sc_edge_data) sc_uuids = [record["uuid"] async for record in result] results['statement_chunk_edges'] = sc_uuids logger.info(f"Successfully saved {len(sc_uuids)} statement-chunk edges to Neo4j") # 7. Save statement-entity edges if statement_entity_edges: from app.repositories.neo4j.cypher_queries import STATEMENT_ENTITY_EDGE_SAVE se_edge_data = [] for edge in statement_entity_edges: se_edge_data.append({ "source": edge.source, "target": edge.target, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, "run_id": edge.run_id, "end_user_id": edge.end_user_id, "connect_strength": getattr(edge, "connect_strength", "strong"), }) result = await tx.run(STATEMENT_ENTITY_EDGE_SAVE, relationships=se_edge_data) se_uuids = [record["uuid"] async for record in result] results['statement_entity_edges'] = se_uuids logger.info(f"Successfully saved {len(se_uuids)} statement-entity edges to Neo4j") if perceptual_edges: from app.repositories.neo4j.cypher_queries import PERCEPTUAL_CHUNK_EDGE_SAVE perceptual_edge_data = [] for edge in perceptual_edges: print(edge.source, edge.target) perceptual_edge_data.append({ "perceptual_id": edge.source, "chunk_id": edge.target, "end_user_id": edge.end_user_id, "created_at": edge.created_at.isoformat() if edge.created_at else None, }) result = await tx.run(PERCEPTUAL_CHUNK_EDGE_SAVE, edges=perceptual_edge_data) perceptual_edges_uuids = [record["uuid"] async for record in result] results['perceptual_chunk_edges'] = perceptual_edges_uuids logger.info(f"Successfully saved {len(perceptual_edges_uuids)} perceptual-chunk edges to Neo4j") return results try: # 使用显式写事务执行所有操作,避免死锁 results = await connector.execute_write_transaction(_save_all_in_transaction) summary = { key: len(value) for key, value in results.items() if isinstance(value, (list, tuple, set)) } logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) return True except Exception as e: logger.error(f"Neo4j integration error: {e}", exc_info=True) print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False async def _trigger_clustering_sync( entity_nodes: List, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: """ 同步等待聚类完成,避免与其他 LLM 任务并发冲突。 """ if not entity_nodes: return clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false" if not clustering_enabled: logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发") return end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) async def _trigger_clustering( new_entity_ids: List[str], end_user_id: str, llm_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 """ connector = None try: from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() engine = LabelPropagationEngine(connector, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True) finally: if connector: try: await connector.close() except Exception: pass