From f9fb480cc3e519a733b24e2957f5d8dc8b46ef49 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 12:30:00 +0800 Subject: [PATCH] [changes] Community Clustering Retrieval Module --- .../memory/agent/services/search_service.py | 38 ++- .../core/memory/agent/utils/write_tools.py | 9 +- api/app/core/memory/src/search.py | 12 +- .../clustering_engine/label_propagation.py | 149 ++++++---- api/app/main.py | 12 +- .../neo4j/community_repository.py | 42 ++- api/app/repositories/neo4j/cypher_queries.py | 121 +++++++-- api/app/repositories/neo4j/graph_saver.py | 15 +- api/app/repositories/neo4j/graph_search.py | 79 ++++++ api/app/repositories/neo4j/index_manager.py | 254 ++++++++++++++++++ redbear-mem-benchmark | 2 +- 11 files changed, 637 insertions(+), 96 deletions(-) create mode 100644 api/app/repositories/neo4j/index_manager.py diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 4fc4256e..2be18c97 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -120,7 +120,7 @@ class SearchService: raw_results is None if return_raw_results=False """ if include is None: - include = ["statements", "chunks", "entities", "summaries"] + include = ["statements", "chunks", "entities", "summaries", "communities"] # Clean query cleaned_query = self.clean_query(question) @@ -146,8 +146,8 @@ class SearchService: if search_type == "hybrid": reranked_results = answer.get('reranked_results', {}) - # Priority order: summaries first (most contextual), then statements, chunks, entities - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + # Priority order: summaries first (most contextual), then communities, statements, chunks, entities + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in reranked_results: @@ -157,13 +157,43 @@ class SearchService: else: # For keyword or embedding search, results are directly in answer dict # Apply same priority order - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in answer: category_results = answer[category] if isinstance(category_results, list): answer_list.extend(category_results) + + # 对命中的 community 节点展开其成员 statements + if "communities" in include: + community_results = ( + answer.get('reranked_results', {}).get('communities', []) + if search_type == "hybrid" + else answer.get('communities', []) + ) + community_ids = [ + r.get("id") for r in community_results if r.get("id") + ] + if community_ids and end_user_id: + try: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + connector = Neo4jConnector() + expand_result = await search_graph_community_expand( + connector=connector, + community_ids=community_ids, + end_user_id=end_user_id, + limit=10, + ) + await connector.close() + expanded_stmts = expand_result.get("expanded_statements", []) + if expanded_stmts: + # 展开的 statements 插入 communities 之后、statements 之前 + answer_list.extend(expanded_stmts) + logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") + except Exception as e: + logger.warning(f"社区展开检索失败,跳过: {e}") # Extract clean content from all results content_list = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b3707083..02aa1b44 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -171,6 +171,13 @@ async def write( ) if success: logger.info("Successfully saved all data to Neo4j") + # 写入成功后,异步触发聚类(不阻塞写入响应) + schedule_clustering_after_write( + all_entity_nodes, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, + embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 0e1d8424..3570d707 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -238,7 +238,7 @@ def rerank_with_activation( reranked: Dict[str, List[Dict[str, Any]]] = {} - for category in ["statements", "chunks", "entities", "summaries"]: + for category in ["statements", "chunks", "entities", "summaries", "communities"]: keyword_items = keyword_results.get(category, []) embedding_items = embedding_results.get(category, []) @@ -281,21 +281,23 @@ def rerank_with_activation( for item in items_list: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") if item_id and item_id in combined_items: - combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) + combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") # 步骤 4: 计算基础分数和最终分数 for item_id, item in combined_items.items(): bm25_norm = float(item.get("bm25_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0) - act_norm = float(item.get("normalized_activation_value", 0) or 0) + # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 + raw_act_norm = item.get("normalized_activation_value") + act_norm = float(raw_act_norm) if raw_act_norm is not None else None # 第一阶段:只考虑内容相关性(BM25 + Embedding) # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 content_score = alpha * bm25_norm + (1 - alpha) * emb_norm base_score = content_score # 第一阶段用内容分数 - # 存储激活度分数供第二阶段使用 - item["activation_score"] = act_norm + # 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序) + item["activation_score"] = act_norm # 可能为 None item["content_score"] = content_score item["base_score"] = base_score diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index cbc303b1..b4a16734 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -19,8 +19,9 @@ logger = logging.getLogger(__name__) # 全量迭代最大轮数,防止不收敛 MAX_ITERATIONS = 10 -# 社区摘要核心实体数量 -CORE_ENTITY_LIMIT = 5 + +# 社区核心实体取 top-N 数量 +CORE_ENTITY_LIMIT = 10 def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: @@ -69,11 +70,13 @@ class LabelPropagationEngine: connector: Neo4jConnector, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) self.config_id = config_id self.llm_model_id = llm_model_id + self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -103,58 +106,85 @@ class LabelPropagationEngine: async def full_clustering(self, end_user_id: str) -> None: """ - 全量标签传播初始化。 + 全量标签传播初始化(分批处理,控制内存峰值)。 - 1. 拉取所有实体,初始化每个实体为独立社区 - 2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 3. 直到标签不再变化或达到 MAX_ITERATIONS - 4. 将最终标签写入 Neo4j + 策略: + - 每次只加载 BATCH_SIZE 个实体及其邻居进内存 + - labels 字典跨批次共享(只存 id→community_id,内存极小) + - 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息 + - 所有批次完成后统一 flush 和 merge """ - entities = await self.repo.get_all_entities(end_user_id) - if not entities: + BATCH_SIZE = 2000 # 每批实体数,可按需调整 + + # 先查总数,决定批次数 + total_entities = await self.repo.get_all_entities(end_user_id) + if not total_entities: logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") return - # 初始化:每个实体持有自己 id 作为社区标签 - labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} - embeddings: Dict[str, Optional[List[float]]] = { - e["id"]: e.get("name_embedding") for e in entities - } + total_count = len(total_entities) + logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体," + f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批") - # 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返 - logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...") - neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id) - logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + # labels 跨批次共享:先用全量数据初始化(只存 id,内存极小) + labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} + # embeddings 也跨批次共享(每个向量 ~6KB,10万实体约 600MB,这是不可避免的) + # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻 + # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding, + # 所以只需要当前批次实体的 embedding,不需要全量 + del total_entities # 释放全量列表,后续按批次加载 - for iteration in range(MAX_ITERATIONS): - changed = 0 - # 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历) - for entity in entities: - eid = entity["id"] - # 直接从缓存取邻居,不再发起 Neo4j 查询 - neighbors = neighbors_cache.get(eid, []) - - # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) - enriched = [] - for nb in neighbors: - nb_copy = dict(nb) - nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) - enriched.append(nb_copy) - - new_label = _weighted_vote(enriched, embeddings.get(eid)) - if new_label and new_label != labels[eid]: - labels[eid] = new_label - changed += 1 - - logger.info( - f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}," - f"标签变化数: {changed}" + for batch_start in range(0, total_count, BATCH_SIZE): + batch_entities = await self.repo.get_entities_page( + end_user_id, skip=batch_start, limit=BATCH_SIZE ) - if changed == 0: - logger.info("[Clustering] 标签已收敛,提前结束迭代") + if not batch_entities: break - # 将最终标签写入 Neo4j + batch_ids = [e["id"] for e in batch_entities] + batch_embeddings: Dict[str, Optional[List[float]]] = { + e["id"]: e.get("name_embedding") for e in batch_entities + } + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:" + f"加载 {len(batch_entities)} 个实体的邻居图..." + ) + neighbors_cache = await self.repo.get_entity_neighbors_for_ids( + batch_ids, end_user_id + ) + logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + + for iteration in range(MAX_ITERATIONS): + changed = 0 + for entity in batch_entities: + eid = entity["id"] + neighbors = neighbors_cache.get(eid, []) + + # 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值) + enriched = [] + for nb in neighbors: + nb_copy = dict(nb) + nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) + enriched.append(nb_copy) + + new_label = _weighted_vote(enriched, batch_embeddings.get(eid)) + if new_label and new_label != labels[eid]: + labels[eid] = new_label + changed += 1 + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} " + f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}" + ) + if changed == 0: + logger.info("[Clustering] 标签已收敛,提前结束本批迭代") + break + + # 释放本批次的大对象 + del neighbors_cache, batch_embeddings, batch_entities + + # 所有批次完成,统一写入 Neo4j await self._flush_labels(labels, end_user_id) pre_merge_count = len(set(labels.values())) logger.info( @@ -162,17 +192,16 @@ class LabelPropagationEngine: f"{len(labels)} 个实体,开始后处理合并" ) - # 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度) all_community_ids = list(set(labels.values())) await self._evaluate_merge(all_community_ids, end_user_id) logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) - # 为所有社区生成元数据 - # 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区 - # 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID + + # 查询存活社区并生成元数据 surviving_communities = await self.repo.get_all_entities(end_user_id) surviving_community_ids = list({ e.get("community_id") for e in surviving_communities @@ -421,6 +450,7 @@ class LabelPropagationEngine: - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。 """ try: members = await self.repo.get_community_members(community_id, end_user_id) @@ -468,16 +498,33 @@ class LabelPropagationEngine: except Exception as e: logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") - await self.repo.update_community_metadata( + # 生成 summary_embedding + summary_embedding = None + if self.embedding_model_id and summary: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + results = await embedder.response([summary]) + summary_embedding = results[0] if results else None + except Exception as e: + logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}") + + result = await self.repo.update_community_metadata( community_id=community_id, end_user_id=end_user_id, name=name, summary=summary, core_entities=core_entities, + summary_embedding=summary_embedding, ) - logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") + if result: + logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...") + else: + logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False") except Exception as e: - logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True) @staticmethod def _new_community_id() -> str: diff --git a/api/app/main.py b/api/app/main.py index c6256e3c..5314b8b6 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger from app.core.response_utils import fail from app.core.models.scripts.loader import load_models from app.db import get_db_context +from app.repositories.neo4j.index_manager import ensure_indexes # Initialize logging system LoggingConfig.setup_logging() @@ -61,9 +62,18 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") + # 确保 Neo4j 索引存在(幂等,多环境安全) + try: + report = await ensure_indexes() + if report["errors"]: + logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}") + else: + logger.info(f"Neo4j 索引检查完成 [{report['uri']}]") + except Exception as e: + logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}") + logger.info("应用程序启动完成") yield - # 应用关闭事件 logger.info("应用程序正在关闭") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index f2f11f76..bf7fde1d 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -13,12 +13,15 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_LEAVE_ALL_COMMUNITIES, GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, + GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_ENTITY_NEIGHBORS_BATCH, + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, + UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -110,6 +113,41 @@ class CommunityRepository: logger.error(f"get_all_entities failed: {e}") return [] + async def get_entities_page( + self, end_user_id: str, skip: int, limit: int + ) -> List[Dict]: + """分页拉取实体,用于全量聚类分批处理。""" + try: + return await self.connector.execute_query( + GET_ENTITIES_PAGE, + end_user_id=end_user_id, + skip=skip, + limit=limit, + ) + except Exception as e: + logger.error(f"get_entities_page failed: {e}") + return [] + + async def get_entity_neighbors_for_ids( + self, entity_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。""" + try: + rows = await self.connector.execute_query( + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, + entity_ids=entity_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + eid = row["entity_id"] + neighbor = {k: v for k, v in row.items() if k != "entity_id"} + result.setdefault(eid, []).append(neighbor) + return result + except Exception as e: + logger.error(f"get_entity_neighbors_for_ids failed: {e}") + return {} + async def get_community_members( self, community_id: str, end_user_id: str ) -> List[Dict]: @@ -177,8 +215,9 @@ class CommunityRepository: name: str, summary: str, core_entities: List[str], + summary_embedding: Optional[List[float]] = None, ) -> bool: - """更新社区的名称、摘要和核心实体列表。""" + """更新社区的名称、摘要、核心实体列表和摘要向量。""" try: result = await self.connector.execute_query( UPDATE_COMMUNITY_METADATA, @@ -187,6 +226,7 @@ class CommunityRepository: name=name, summary=summary, core_entities=core_entities, + summary_embedding=summary_embedding, ) return bool(result) except Exception as e: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 48a5ac87..339adb43 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1132,11 +1132,11 @@ ORDER BY coalesce(e.activation_value, 0) DESC GET_ALL_COMMUNITY_MEMBERS_BATCH = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) -WHERE c.community_id IN $community_ids RETURN c.community_id AS community_id, - e.id AS id, - e.name_embedding AS name_embedding, - e.activation_value AS activation_value + e.id AS id, e.name AS name, e.entity_type AS entity_type, + e.importance_score AS importance_score, e.activation_value AS activation_value, + e.name_embedding AS name_embedding +ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC """ CHECK_USER_HAS_COMMUNITIES = """ @@ -1153,13 +1153,47 @@ RETURN c.community_id AS community_id, cnt AS member_count UPDATE_COMMUNITY_METADATA = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -SET c.name = $name, - c.summary = $summary, - c.core_entities = $core_entities, - c.updated_at = datetime() +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.summary_embedding = $summary_embedding, + c.updated_at = datetime() RETURN c.community_id AS community_id """ +GET_ENTITIES_PAGE = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN e.id AS id, + e.name AS name, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +ORDER BY e.id +SKIP $skip LIMIT $limit +""" + +GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """ +// 批量拉取指定实体列表的邻居(用于分批全量聚类) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +WHERE e.id IN $entity_ids +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id +WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH e, nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + e.id AS entity_id, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + GET_ALL_ENTITY_NEIGHBORS_BATCH = """ // 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) @@ -1185,20 +1219,59 @@ RETURN DISTINCT CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ -GET_COMMUNITY_GRAPH_DATA = """ -MATCH (c:Community {end_user_id: $end_user_id}) -MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) -OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) -RETURN - elementId(c) AS c_id, - properties(c) AS c_props, - elementId(e) AS e_id, - properties(e) AS e_props, - elementId(b) AS b_id, - elementId(e2) AS e2_id, - properties(e2) AS e2_props, - elementId(r) AS r_id, - type(r) AS r_type, - properties(r) AS r_props, - startNode(r) = e AS r_from_e + +# Community keyword search: matches name or summary via fulltext index +SEARCH_COMMUNITIES_BY_KEYWORD = """ +CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 向量检索 ────────────────────────────────────────────────── +# Community embedding-based search: cosine similarity on Community.summary_embedding +COMMUNITY_EMBEDDING_SEARCH = """ +CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding) +YIELD node AS c, score +WHERE c.summary_embedding IS NOT NULL + AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 展开检索 ────────────────────────────────────────────────── +# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索) +EXPAND_COMMUNITY_STATEMENTS = """ +MATCH (c:Community {community_id: $community_id}) +MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c) +MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +WHERE s.end_user_id = $end_user_id +RETURN s.statement AS statement, + s.id AS id, + s.end_user_id AS end_user_id, + s.created_at AS created_at, + s.valid_at AS valid_at, + s.invalid_at AS invalid_at, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + e.name AS source_entity, + c.name AS community_name +ORDER BY COALESCE(s.activation_value, 0) DESC +LIMIT $limit """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index cbd2b532..29e337f1 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,5 +1,4 @@ import asyncio -import os from typing import List, Optional # 使用新的仓储层 @@ -158,11 +157,12 @@ async def save_dialog_and_statements_to_neo4j( statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], connector: Neo4jConnector, - config_id: Optional[str] = None, - llm_model_id: Optional[str] = None, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. + 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 + schedule_clustering_after_write() 显式触发。 + Args: dialogue_nodes: List of DialogueNode objects to save chunk_nodes: List of ChunkNode objects to save @@ -293,9 +293,6 @@ async def save_dialog_and_statements_to_neo4j( logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id) - return True except Exception as e: @@ -309,6 +306,7 @@ def schedule_clustering_after_write( entity_nodes: List, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 写入 Neo4j 成功后,调度后台聚类任务。 @@ -327,7 +325,7 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) async def _trigger_clustering( @@ -335,6 +333,7 @@ async def _trigger_clustering( end_user_id: str, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 @@ -344,7 +343,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index e8f52535..19e40a82 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional from app.repositories.neo4j.cypher_queries import ( CHUNK_EMBEDDING_SEARCH, + COMMUNITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH, + EXPAND_COMMUNITY_STATEMENTS, MEMORY_SUMMARY_EMBEDDING_SEARCH, SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNKS_BY_CONTENT, + SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_ENTITIES_BY_NAME, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, @@ -285,6 +288,15 @@ async def search_graph( limit=limit, )) task_keys.append("summaries") + + if "communities" in include: + tasks.append(connector.execute_query( + SEARCH_COMMUNITIES_BY_KEYWORD, + q=q, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") # Execute all queries in parallel task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -396,6 +408,16 @@ async def search_graph_by_embedding( )) task_keys.append("summaries") + # Communities (向量语义匹配) + if "communities" in include: + tasks.append(connector.execute_query( + COMMUNITY_EMBEDDING_SEARCH, + embedding=embedding, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") + # Execute all queries in parallel query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -408,6 +430,7 @@ async def search_graph_by_embedding( "chunks": [], "entities": [], "summaries": [], + "communities": [], } for key, result in zip(task_keys, task_results): @@ -661,6 +684,62 @@ async def search_graph_by_chunk_id( return {"chunks": chunks} +async def search_graph_community_expand( + connector: Neo4jConnector, + community_ids: List[str], + end_user_id: str, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + 三期:社区展开检索 —— 主题 → 细节两级检索。 + + 命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体, + 再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点, + 按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。 + + Args: + connector: Neo4j 连接器 + community_ids: 已命中的社区 ID 列表 + end_user_id: 用户 ID,用于数据隔离 + limit: 每个社区最多返回的 Statement 数量 + + Returns: + {"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]} + """ + if not community_ids or not end_user_id: + return {"expanded_statements": []} + + tasks = [ + connector.execute_query( + EXPAND_COMMUNITY_STATEMENTS, + community_id=cid, + end_user_id=end_user_id, + limit=limit, + ) + for cid in community_ids + ] + + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + expanded: List[Dict[str, Any]] = [] + for cid, result in zip(community_ids, task_results): + if isinstance(result, Exception): + logger.warning(f"社区展开检索失败 community_id={cid}: {result}") + else: + expanded.extend(result) + + # 按 activation_value 全局排序后去重 + from app.core.memory.src.search import _deduplicate_results + expanded.sort( + key=lambda x: float(x.get("activation_value") or 0), + reverse=True, + ) + expanded = _deduplicate_results(expanded) + + logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}") + return {"expanded_statements": expanded} + + async def search_graph_by_created_at( connector: Neo4jConnector, end_user_id: Optional[str] = None, diff --git a/api/app/repositories/neo4j/index_manager.py b/api/app/repositories/neo4j/index_manager.py new file mode 100644 index 00000000..a1ab6689 --- /dev/null +++ b/api/app/repositories/neo4j/index_manager.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +"""Neo4j 索引管理模块 + +负责检查和创建 Neo4j 全文索引与向量索引。 +支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。 + +用法: + # 作为模块调用(应用启动时) + from app.repositories.neo4j.index_manager import ensure_indexes + await ensure_indexes() + + # 作为独立脚本执行(手动建索引) + python -m app.repositories.neo4j.index_manager +""" + +import asyncio +import logging +from dataclasses import dataclass +from typing import List + +from app.core.config import settings +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + + +# ───────────────────────────────────────────────────────────── +# 索引定义表 +# ───────────────────────────────────────────────────────────── + +@dataclass +class FulltextIndexDef: + name: str + label: str + properties: List[str] + + +@dataclass +class VectorIndexDef: + name: str + label: str + property: str + dimensions: int + similarity: str = "cosine" + + +# 全文索引清单(现有 + 新增 communities) +FULLTEXT_INDEXES: List[FulltextIndexDef] = [ + FulltextIndexDef("statementsFulltext", "Statement", ["statement"]), + FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]), + FulltextIndexDef("chunksFulltext", "Chunk", ["content"]), + FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]), + FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源 +] + +# 向量索引清单(预留 community 二期) +VECTOR_INDEXES: List[VectorIndexDef] = [ + VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536), + VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536), + VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536), + VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536), + # 二期:社区向量索引 + VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536), +] + + +# ───────────────────────────────────────────────────────────── +# 核心检查 / 创建逻辑 +# ───────────────────────────────────────────────────────────── + +async def _get_existing_indexes(connector: Neo4jConnector) -> set: + """查询 Neo4j 中已存在的索引名称集合""" + rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name") + return {row["name"] for row in rows} + + +async def _ensure_fulltext_index( + connector: Neo4jConnector, + idx: FulltextIndexDef, + existing: set, +) -> str: + """检查并按需创建全文索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 全文索引已存在: {idx.name}" + + props = ", ".join(f"n.{p}" for p in idx.properties) + cypher = ( + f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS ' + f'FOR (n:{idx.label}) ON EACH [{props}]' + ) + await connector.execute_query(cypher) + return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})" + + +async def _ensure_vector_index( + connector: Neo4jConnector, + idx: VectorIndexDef, + existing: set, +) -> str: + """检查并按需创建向量索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 向量索引已存在: {idx.name}" + + cypher = ( + f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS " + f"FOR (n:{idx.label}) ON n.{idx.property} " + f"OPTIONS {{indexConfig: {{" + f"`vector.dimensions`: {idx.dimensions}, " + f"`vector.similarity_function`: '{idx.similarity}'" + f"}}}}" + ) + await connector.execute_query(cypher) + return ( + f"[CREATE] 向量索引已创建: {idx.name} " + f"({idx.label}.{idx.property}, dim={idx.dimensions})" + ) + + +async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。 + + Args: + connector: 可选,传入已有连接器;为 None 时自动创建。 + + Returns: + dict: { + "uri": 当前连接的 Neo4j URI, + "fulltext": [操作日志列表], + "vector": [操作日志列表], + "errors": [错误信息列表], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + report = { + "uri": settings.NEO4J_URI, + "fulltext": [], + "vector": [], + "errors": [], + } + + try: + # 一次性拉取所有已有索引名 + existing = await _get_existing_indexes(connector) + logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}") + logger.info(f"[IndexManager] 已有索引数量: {len(existing)}") + + # 处理全文索引 + for idx in FULLTEXT_INDEXES: + try: + msg = await _ensure_fulltext_index(connector, idx, existing) + report["fulltext"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + # 处理向量索引 + for idx in VECTOR_INDEXES: + try: + msg = await _ensure_vector_index(connector, idx, existing) + report["vector"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + finally: + if own_connector: + await connector.close() + + return report + + +async def check_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 仅检查索引状态,不创建任何索引。 + + Returns: + dict: { + "uri": ..., + "present": [已存在的索引名], + "missing_fulltext": [缺失的全文索引名], + "missing_vector": [缺失的向量索引名], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + try: + existing = await _get_existing_indexes(connector) + missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing] + missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing] + + return { + "uri": settings.NEO4J_URI, + "present": sorted(existing), + "missing_fulltext": missing_ft, + "missing_vector": missing_vec, + } + finally: + if own_connector: + await connector.close() + + +# ───────────────────────────────────────────────────────────── +# 独立脚本入口 +# ───────────────────────────────────────────────────────────── + +async def _main(): + import sys + + print(f"\n{'='*60}") + print(f"Neo4j 索引管理工具") + print(f"环境: {settings.NEO4J_URI}") + print(f"{'='*60}\n") + + # 先检查 + print(">>> 检查当前索引状态...\n") + status = await check_indexes() + print(f" 已存在索引数: {len(status['present'])}") + if status["missing_fulltext"]: + print(f" 缺失全文索引: {status['missing_fulltext']}") + if status["missing_vector"]: + print(f" 缺失向量索引: {status['missing_vector']}") + + if not status["missing_fulltext"] and not status["missing_vector"]: + print("\n 所有索引均已存在,无需操作。") + return + + # 再创建 + print("\n>>> 开始创建缺失索引...\n") + report = await ensure_indexes() + + for msg in report["fulltext"] + report["vector"]: + print(f" {msg}") + + if report["errors"]: + print("\n[!] 以下索引创建失败:") + for err in report["errors"]: + print(f" {err}") + sys.exit(1) + else: + print("\n 全部索引处理完成。") + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index c3bbc693..89053e48 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit c3bbc6931c570e6fac88c0b00658b4f08dc2ac77 +Subproject commit 89053e48e932332d2a0f17760034ee2bce75ea43