From 5b431400be9959442dca774ddd95dc95719fd343 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 10 Mar 2026 17:06:43 +0800 Subject: [PATCH 01/14] [add] Create community nodes --- .../clustering_engine/__init__.py | 3 + .../clustering_engine/label_propagation.py | 311 ++++++++++++++++++ .../neo4j/community_repository.py | 129 ++++++++ api/app/repositories/neo4j/cypher_queries.py | 93 +++++- api/app/repositories/neo4j/graph_saver.py | 34 ++ 5 files changed, 569 insertions(+), 1 deletion(-) create mode 100644 api/app/core/memory/storage_services/clustering_engine/__init__.py create mode 100644 api/app/core/memory/storage_services/clustering_engine/label_propagation.py create mode 100644 api/app/repositories/neo4j/community_repository.py diff --git a/api/app/core/memory/storage_services/clustering_engine/__init__.py b/api/app/core/memory/storage_services/clustering_engine/__init__.py new file mode 100644 index 00000000..992d8bff --- /dev/null +++ b/api/app/core/memory/storage_services/clustering_engine/__init__.py @@ -0,0 +1,3 @@ +from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + +__all__ = ["LabelPropagationEngine"] diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py new file mode 100644 index 00000000..80e238fd --- /dev/null +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -0,0 +1,311 @@ +"""标签传播聚类引擎 + +基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。 + +支持两种模式: +- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代 +- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居 +""" + +import logging +import uuid +from math import sqrt +from typing import Dict, List, Optional + +from app.repositories.neo4j.community_repository import CommunityRepository +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + +# 全量迭代最大轮数,防止不收敛 +MAX_ITERATIONS = 10 + + +def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: + """计算两个向量的余弦相似度,任一为空则返回 0。""" + if not v1 or not v2 or len(v1) != len(v2): + return 0.0 + dot = sum(a * b for a, b in zip(v1, v2)) + norm1 = sqrt(sum(a * a for a in v1)) + norm2 = sqrt(sum(b * b for b in v2)) + if norm1 == 0 or norm2 == 0: + return 0.0 + return dot / (norm1 * norm2) + + +def _weighted_vote( + neighbors: List[Dict], + self_embedding: Optional[List[float]], +) -> Optional[str]: + """ + 加权多数投票,选出得票最高的社区。 + + 权重 = 语义相似度(name_embedding 余弦)* activation_value 加成 + 没有 community_id 的邻居不参与投票。 + """ + votes: Dict[str, float] = {} + for nb in neighbors: + cid = nb.get("community_id") + if not cid: + continue + sem = _cosine_similarity(self_embedding, nb.get("name_embedding")) + act = nb.get("activation_value") or 0.5 + # 语义相似度权重 0.6,激活值权重 0.4 + weight = 0.6 * sem + 0.4 * act + votes[cid] = votes.get(cid, 0.0) + weight + + if not votes: + return None + return max(votes, key=votes.__getitem__) + + +class LabelPropagationEngine: + """标签传播聚类引擎""" + + def __init__(self, connector: Neo4jConnector): + self.connector = connector + self.repo = CommunityRepository(connector) + + # ────────────────────────────────────────────────────────────────────────── + # 公开接口 + # ────────────────────────────────────────────────────────────────────────── + + async def run( + self, + end_user_id: str, + new_entity_ids: Optional[List[str]] = None, + ) -> None: + """ + 统一入口:自动判断全量还是增量。 + + - 若该用户尚无 Community 节点 → 全量初始化 + - 否则 → 增量更新(仅处理 new_entity_ids) + """ + has_communities = await self.repo.has_communities(end_user_id) + if not has_communities: + logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化") + await self.full_clustering(end_user_id) + else: + if new_entity_ids: + logger.info( + f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}" + ) + await self.incremental_update(new_entity_ids, end_user_id) + + async def full_clustering(self, end_user_id: str) -> None: + """ + 全量标签传播初始化。 + + 1. 拉取所有实体,初始化每个实体为独立社区 + 2. 迭代:每轮对所有实体做邻居投票,更新社区标签 + 3. 直到标签不再变化或达到 MAX_ITERATIONS + 4. 将最终标签写入 Neo4j + """ + entities = await self.repo.get_all_entities(end_user_id) + if not entities: + logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") + return + + # 初始化:每个实体持有自己 id 作为社区标签 + labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} + embeddings: Dict[str, Optional[List[float]]] = { + e["id"]: e.get("name_embedding") for e in entities + } + + for iteration in range(MAX_ITERATIONS): + changed = 0 + # 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历) + for entity in entities: + eid = entity["id"] + neighbors = await self.repo.get_entity_neighbors(eid, end_user_id) + + # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) + enriched = [] + for nb in neighbors: + nb_copy = dict(nb) + nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) + enriched.append(nb_copy) + + new_label = _weighted_vote(enriched, embeddings.get(eid)) + if new_label and new_label != labels[eid]: + labels[eid] = new_label + changed += 1 + + logger.info( + f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}," + f"标签变化数: {changed}" + ) + if changed == 0: + logger.info("[Clustering] 标签已收敛,提前结束迭代") + break + + # 将最终标签写入 Neo4j + await self._flush_labels(labels, end_user_id) + logger.info( + f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区," + f"{len(labels)} 个实体" + ) + + async def incremental_update( + self, new_entity_ids: List[str], end_user_id: str + ) -> None: + """ + 增量更新:只处理新实体及其邻居,不重跑全图。 + + 1. 对每个新实体查询邻居 + 2. 加权多数投票决定社区归属 + 3. 若邻居无社区 → 创建新社区 + 4. 若邻居分属多个社区 → 评估是否合并 + """ + for entity_id in new_entity_ids: + await self._process_single_entity(entity_id, end_user_id) + + # ────────────────────────────────────────────────────────────────────────── + # 内部方法 + # ────────────────────────────────────────────────────────────────────────── + + async def _process_single_entity( + self, entity_id: str, end_user_id: str + ) -> None: + """处理单个新实体的社区分配。""" + neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) + + # 查询自身 embedding(从邻居查询结果中无法获取,需单独查) + self_embedding = await self._get_entity_embedding(entity_id, end_user_id) + + if not neighbors: + # 孤立实体:创建单成员社区 + new_cid = self._new_community_id() + await self.repo.upsert_community(new_cid, end_user_id, member_count=1) + await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) + logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") + return + + # 统计邻居社区分布 + community_ids_in_neighbors = set( + nb["community_id"] for nb in neighbors if nb.get("community_id") + ) + + target_cid = _weighted_vote(neighbors, self_embedding) + + if target_cid is None: + # 邻居都没有社区,连同新实体一起创建新社区 + new_cid = self._new_community_id() + await self.repo.upsert_community(new_cid, end_user_id) + await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) + for nb in neighbors: + await self.repo.assign_entity_to_community( + nb["id"], new_cid, end_user_id + ) + await self.repo.refresh_member_count(new_cid, end_user_id) + logger.debug( + f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" + ) + else: + # 加入得票最多的社区 + await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) + await self.repo.refresh_member_count(target_cid, end_user_id) + logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}") + + # 若邻居分属多个社区,评估合并 + if len(community_ids_in_neighbors) > 1: + await self._evaluate_merge( + list(community_ids_in_neighbors), end_user_id + ) + + async def _evaluate_merge( + self, community_ids: List[str], end_user_id: str + ) -> None: + """ + 评估多个社区是否应合并。 + + 策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。 + 合并时保留成员数最多的社区,其余成员迁移过来。 + """ + MERGE_THRESHOLD = 0.75 + + community_embeddings: Dict[str, Optional[List[float]]] = {} + community_sizes: Dict[str, int] = {} + + for cid in community_ids: + members = await self.repo.get_community_members(cid, end_user_id) + community_sizes[cid] = len(members) + # 计算社区成员 embedding 的平均向量 + valid_embeddings = [ + m["name_embedding"] + for m in members + if m.get("name_embedding") + ] + if valid_embeddings: + dim = len(valid_embeddings[0]) + avg = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + community_embeddings[cid] = avg + else: + community_embeddings[cid] = None + + # 找出应合并的社区对 + to_merge: List[tuple] = [] + cids = list(community_ids) + for i in range(len(cids)): + for j in range(i + 1, len(cids)): + sim = _cosine_similarity( + community_embeddings[cids[i]], + community_embeddings[cids[j]], + ) + if sim > MERGE_THRESHOLD: + to_merge.append((cids[i], cids[j])) + + for c1, c2 in to_merge: + keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2 + dissolve = c2 if keep == c1 else c1 + members = await self.repo.get_community_members(dissolve, end_user_id) + for m in members: + await self.repo.assign_entity_to_community( + m["id"], keep, end_user_id + ) + await self.repo.refresh_member_count(keep, end_user_id) + logger.info( + f"[Clustering] 社区合并: {dissolve} → {keep}," + f"迁移 {len(members)} 个成员" + ) + + async def _flush_labels( + self, labels: Dict[str, str], end_user_id: str + ) -> None: + """将内存中的标签批量写入 Neo4j。""" + # 先创建所有唯一社区节点 + unique_communities = set(labels.values()) + for cid in unique_communities: + await self.repo.upsert_community(cid, end_user_id) + + # 再批量分配实体 + for entity_id, community_id in labels.items(): + await self.repo.assign_entity_to_community( + entity_id, community_id, end_user_id + ) + + # 刷新成员数 + for cid in unique_communities: + await self.repo.refresh_member_count(cid, end_user_id) + + async def _get_entity_embedding( + self, entity_id: str, end_user_id: str + ) -> Optional[List[float]]: + """查询单个实体的 name_embedding。""" + try: + result = await self.connector.execute_query( + "MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) " + "RETURN e.name_embedding AS name_embedding", + eid=entity_id, + uid=end_user_id, + ) + return result[0]["name_embedding"] if result else None + except Exception: + return None + + @staticmethod + def _new_community_id() -> str: + return str(uuid.uuid4()) diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py new file mode 100644 index 00000000..16e30a10 --- /dev/null +++ b/api/app/repositories/neo4j/community_repository.py @@ -0,0 +1,129 @@ +"""Community 节点仓库 + +管理 Neo4j 中 Community 节点及 BELONGS_TO_COMMUNITY 边的 CRUD 操作。 +""" + +import logging +from typing import Dict, List, Optional + +from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.repositories.neo4j.cypher_queries import ( + COMMUNITY_NODE_UPSERT, + ENTITY_JOIN_COMMUNITY, + ENTITY_LEAVE_ALL_COMMUNITIES, + GET_ENTITY_NEIGHBORS, + GET_ALL_ENTITIES_FOR_USER, + GET_COMMUNITY_MEMBERS, + CHECK_USER_HAS_COMMUNITIES, + UPDATE_COMMUNITY_MEMBER_COUNT, +) + +logger = logging.getLogger(__name__) + + +class CommunityRepository: + def __init__(self, connector: Neo4jConnector): + self.connector = connector + + async def upsert_community( + self, community_id: str, end_user_id: str, member_count: int = 0 + ) -> Optional[str]: + """创建或更新 Community 节点,返回 community_id。""" + try: + result = await self.connector.execute_query( + COMMUNITY_NODE_UPSERT, + community_id=community_id, + end_user_id=end_user_id, + member_count=member_count, + ) + return result[0]["community_id"] if result else None + except Exception as e: + logger.error(f"upsert_community failed: {e}") + return None + + async def assign_entity_to_community( + self, entity_id: str, community_id: str, end_user_id: str + ) -> bool: + """将实体关联到社区(先解除旧关联,再建立新关联)。""" + try: + await self.connector.execute_query( + ENTITY_LEAVE_ALL_COMMUNITIES, + entity_id=entity_id, + end_user_id=end_user_id, + ) + result = await self.connector.execute_query( + ENTITY_JOIN_COMMUNITY, + entity_id=entity_id, + community_id=community_id, + end_user_id=end_user_id, + ) + return bool(result) + except Exception as e: + logger.error(f"assign_entity_to_community failed: {e}") + return False + + async def get_entity_neighbors( + self, entity_id: str, end_user_id: str + ) -> List[Dict]: + """查询实体的直接邻居及其社区归属。""" + try: + return await self.connector.execute_query( + GET_ENTITY_NEIGHBORS, + entity_id=entity_id, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_entity_neighbors failed: {e}") + return [] + + async def get_all_entities(self, end_user_id: str) -> List[Dict]: + """拉取某用户下所有实体及其当前社区归属。""" + try: + return await self.connector.execute_query( + GET_ALL_ENTITIES_FOR_USER, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_all_entities failed: {e}") + return [] + + async def get_community_members( + self, community_id: str, end_user_id: str + ) -> List[Dict]: + """查询社区成员列表。""" + try: + return await self.connector.execute_query( + GET_COMMUNITY_MEMBERS, + community_id=community_id, + end_user_id=end_user_id, + ) + except Exception as e: + logger.error(f"get_community_members failed: {e}") + return [] + + async def has_communities(self, end_user_id: str) -> bool: + """检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。""" + try: + result = await self.connector.execute_query( + CHECK_USER_HAS_COMMUNITIES, + end_user_id=end_user_id, + ) + return result[0]["community_count"] > 0 if result else False + except Exception as e: + logger.error(f"has_communities failed: {e}") + return False + + async def refresh_member_count( + self, community_id: str, end_user_id: str + ) -> int: + """重新统计并更新社区成员数,返回最新数量。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_MEMBER_COUNT, + community_id=community_id, + end_user_id=end_user_id, + ) + return result[0]["member_count"] if result else 0 + except Exception as e: + logger.error(f"refresh_member_count failed: {e}") + return 0 diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 651c513f..947097a2 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1058,4 +1058,95 @@ Graph_Node_query = """ 3 AS priority LIMIT $limit - """ \ No newline at end of file + """ + + +# ============================================================ +# Community 节点 & BELONGS_TO_COMMUNITY 边 +# ============================================================ + +COMMUNITY_NODE_SAVE = """ +MERGE (c:Community {community_id: $community_id}) +SET c.end_user_id = $end_user_id, + c.formed_at = $formed_at, + c.updated_at = datetime(), + c.status = $status, + c.member_count = $member_count +RETURN c.community_id AS community_id +""" + +COMMUNITY_ADD_MEMBER = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c) +SET c.updated_at = datetime(), + c.member_count = $member_count +""" + + + +# ─── Community 聚类相关 Cypher 模板 ─────────────────────────────────────────── + +COMMUNITY_NODE_UPSERT = """ +MERGE (c:Community {community_id: $community_id}) +SET c.end_user_id = $end_user_id, + c.member_count = $member_count, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + +ENTITY_JOIN_COMMUNITY = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c) +SET c.updated_at = datetime() +RETURN e.id AS entity_id, c.community_id AS community_id +""" + +ENTITY_LEAVE_ALL_COMMUNITIES = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +MATCH (e)-[r:BELONGS_TO_COMMUNITY]->(:Community) +DELETE r +""" + +GET_ENTITY_NEIGHBORS = """ +MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_ALL_ENTITIES_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN e.id AS id, + e.name AS name, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_COMMUNITY_MEMBERS = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) +RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, + e.importance_score AS importance_score, e.activation_value AS activation_value, + e.name_embedding AS name_embedding +ORDER BY coalesce(e.activation_value, 0) DESC +""" + +CHECK_USER_HAS_COMMUNITIES = """ +MATCH (c:Community {end_user_id: $end_user_id}) +RETURN count(c) AS community_count +""" + +UPDATE_COMMUNITY_MEMBER_COUNT = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) +WITH c, count(e) AS cnt +SET c.member_count = cnt +RETURN c.community_id AS community_id, cnt AS member_count +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 526d16ec..a94bc23b 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,3 +1,4 @@ +import asyncio from typing import List # 使用新的仓储层 @@ -288,6 +289,14 @@ async def save_dialog_and_statements_to_neo4j( } logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) + + # 写入成功后,触发聚类 + if entity_nodes: + end_user_id = entity_nodes[0].end_user_id + new_entity_ids = [e.id for e in entity_nodes] + logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + await _trigger_clustering(new_entity_ids, end_user_id) + return True except Exception as e: @@ -295,3 +304,28 @@ async def save_dialog_and_statements_to_neo4j( print(f"Neo4j integration error: {e}") print("Continuing without database storage...") return False + + +async def _trigger_clustering( + new_entity_ids: List[str], + end_user_id: str, +) -> None: + """ + 聚类触发函数,自动判断全量初始化还是增量更新。 + """ + connector = None + try: + from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine + logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") + connector = Neo4jConnector() + engine = LabelPropagationEngine(connector) + await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) + logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") + except Exception as e: + logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True) + finally: + if connector: + try: + await connector.close() + except Exception: + pass From fc58ac0408c6110fcbd852ce85e17c9e95353ff0 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Wed, 11 Mar 2026 18:04:04 +0800 Subject: [PATCH 02/14] [changes] Initial stage of community integration --- .../clustering_engine/label_propagation.py | 94 ++++++++++++++----- .../neo4j/community_repository.py | 20 ++++ api/app/repositories/neo4j/cypher_queries.py | 48 +++++----- 3 files changed, 115 insertions(+), 47 deletions(-) diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 80e238fd..cb6e5804 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -141,8 +141,18 @@ class LabelPropagationEngine: # 将最终标签写入 Neo4j await self._flush_labels(labels, end_user_id) + pre_merge_count = len(set(labels.values())) logger.info( - f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区," + f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区," + f"{len(labels)} 个实体,开始后处理合并" + ) + + # 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度) + all_community_ids = list(set(labels.values())) + await self._evaluate_merge(all_community_ids, end_user_id) + + logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) @@ -221,30 +231,50 @@ class LabelPropagationEngine: 策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。 合并时保留成员数最多的社区,其余成员迁移过来。 + + 全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。 """ MERGE_THRESHOLD = 0.75 + BATCH_THRESHOLD = 20 # 超过此数量走批量查询 community_embeddings: Dict[str, Optional[List[float]]] = {} community_sizes: Dict[str, int] = {} - for cid in community_ids: - members = await self.repo.get_community_members(cid, end_user_id) - community_sizes[cid] = len(members) - # 计算社区成员 embedding 的平均向量 - valid_embeddings = [ - m["name_embedding"] - for m in members - if m.get("name_embedding") - ] - if valid_embeddings: - dim = len(valid_embeddings[0]) - avg = [ - sum(e[i] for e in valid_embeddings) / len(valid_embeddings) - for i in range(dim) + if len(community_ids) > BATCH_THRESHOLD: + # 批量查询:一次拉取所有社区成员 + all_members = await self.repo.get_all_community_members_batch( + community_ids, end_user_id + ) + for cid in community_ids: + members = all_members.get(cid, []) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") ] - community_embeddings[cid] = avg - else: - community_embeddings[cid] = None + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None + else: + # 增量场景:逐个查询 + for cid in community_ids: + members = await self.repo.get_community_members(cid, end_user_id) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") + ] + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None # 找出应合并的社区对 to_merge: List[tuple] = [] @@ -258,14 +288,32 @@ class LabelPropagationEngine: if sim > MERGE_THRESHOLD: to_merge.append((cids[i], cids[j])) + logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区") + + # 执行合并:用 union-find 思路避免重复迁移已被合并的社区 + # 维护一个 canonical 映射,确保链式合并正确收敛 + canonical: Dict[str, str] = {cid: cid for cid in cids} + + def find(x: str) -> str: + while canonical[x] != x: + canonical[x] = canonical[canonical[x]] + x = canonical[x] + return x + for c1, c2 in to_merge: - keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2 - dissolve = c2 if keep == c1 else c1 + root1, root2 = find(c1), find(c2) + if root1 == root2: + continue # 已经在同一社区,跳过 + keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2 + dissolve = root2 if keep == root1 else root1 + canonical[dissolve] = keep + members = await self.repo.get_community_members(dissolve, end_user_id) for m in members: - await self.repo.assign_entity_to_community( - m["id"], keep, end_user_id - ) + await self.repo.assign_entity_to_community(m["id"], keep, end_user_id) + # 更新 sizes 以便后续合并决策准确 + community_sizes[keep] = community_sizes.get(keep, 0) + len(members) + community_sizes[dissolve] = 0 await self.repo.refresh_member_count(keep, end_user_id) logger.info( f"[Clustering] 社区合并: {dissolve} → {keep}," diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 16e30a10..2a1f4f2b 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -14,6 +14,7 @@ from app.repositories.neo4j.cypher_queries import ( GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, GET_COMMUNITY_MEMBERS, + GET_ALL_COMMUNITY_MEMBERS_BATCH, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, ) @@ -101,6 +102,25 @@ class CommunityRepository: logger.error(f"get_community_members failed: {e}") return [] + async def get_all_community_members_batch( + self, community_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量查询多个社区的成员,返回 {community_id: [members]} 字典。""" + try: + rows = await self.connector.execute_query( + GET_ALL_COMMUNITY_MEMBERS_BATCH, + community_ids=community_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + cid = row["community_id"] + result.setdefault(cid, []).append(row) + return result + except Exception as e: + logger.error(f"get_all_community_members_batch failed: {e}") + return {} + async def has_communities(self, end_user_id: str) -> bool: """检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。""" try: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 947097a2..84889d65 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1065,26 +1065,6 @@ Graph_Node_query = """ # Community 节点 & BELONGS_TO_COMMUNITY 边 # ============================================================ -COMMUNITY_NODE_SAVE = """ -MERGE (c:Community {community_id: $community_id}) -SET c.end_user_id = $end_user_id, - c.formed_at = $formed_at, - c.updated_at = datetime(), - c.status = $status, - c.member_count = $member_count -RETURN c.community_id AS community_id -""" - -COMMUNITY_ADD_MEMBER = """ -MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c) -SET c.updated_at = datetime(), - c.member_count = $member_count -""" - - - # ─── Community 聚类相关 Cypher 模板 ─────────────────────────────────────────── COMMUNITY_NODE_UPSERT = """ @@ -1111,12 +1091,23 @@ DELETE r GET_ENTITY_NEIGHBORS = """ MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) -OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边) +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id + +WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH nb WHERE nb IS NOT NULL OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) RETURN DISTINCT - nb.id AS id, - nb.name AS name, - nb.name_embedding AS name_embedding, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, nb.activation_value AS activation_value, CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ @@ -1139,6 +1130,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, ORDER BY coalesce(e.activation_value, 0) DESC """ +GET_ALL_COMMUNITY_MEMBERS_BATCH = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) +WHERE c.community_id IN $community_ids +RETURN c.community_id AS community_id, + e.id AS id, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value +""" + CHECK_USER_HAS_COMMUNITIES = """ MATCH (c:Community {end_user_id: $end_user_id}) RETURN count(c) AS community_count From 7b8f101824055122631dd590f8f9b873211ec136 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 12 Mar 2026 20:27:50 +0800 Subject: [PATCH 03/14] [add] Create the attribute values of the community nodes --- .../core/memory/agent/utils/write_tools.py | 4 +- .../clustering_engine/label_propagation.py | 83 ++++++++++++++++++- .../neo4j/community_repository.py | 24 ++++++ api/app/repositories/neo4j/cypher_queries.py | 9 ++ api/app/repositories/neo4j/graph_saver.py | 20 +++-- 5 files changed, 132 insertions(+), 8 deletions(-) diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 22030278..b3707083 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -165,7 +165,9 @@ async def write( statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, - connector=neo4j_connector + connector=neo4j_connector, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, ) if success: logger.info("Successfully saved all data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index cb6e5804..251d4fea 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -19,6 +19,8 @@ logger = logging.getLogger(__name__) # 全量迭代最大轮数,防止不收敛 MAX_ITERATIONS = 10 +# 社区摘要核心实体数量 +CORE_ENTITY_LIMIT = 5 def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: @@ -62,9 +64,16 @@ def _weighted_vote( class LabelPropagationEngine: """标签传播聚类引擎""" - def __init__(self, connector: Neo4jConnector): + def __init__( + self, + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + ): self.connector = connector self.repo = CommunityRepository(connector) + self.config_id = config_id + self.llm_model_id = llm_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -155,6 +164,10 @@ class LabelPropagationEngine: f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) + # 为所有社区生成元数据 + unique_communities = list(set(labels.values())) + for cid in unique_communities: + await self._generate_community_metadata(cid, end_user_id) async def incremental_update( self, new_entity_ids: List[str], end_user_id: str @@ -211,6 +224,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) + await self._generate_community_metadata(new_cid, end_user_id) else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -222,6 +236,7 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) + await self._generate_community_metadata(target_cid, end_user_id) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -354,6 +369,72 @@ class LabelPropagationEngine: except Exception: return None + async def _generate_community_metadata( + self, community_id: str, end_user_id: str + ) -> None: + """ + 为社区生成并写入元数据:名称、摘要、核心实体。 + + - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) + - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + """ + try: + members = await self.repo.get_community_members(community_id, end_user_id) + if not members: + return + + # 核心实体:按 activation_value 降序取 top-N + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else community_id[:8] + summary = f"包含实体:{', '.join(all_names)}" + + # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 + if self.llm_model_id: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + entity_list_str = "、".join(all_names) + prompt = ( + f"以下是一组语义相关的实体:{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + + await self.repo.update_community_metadata( + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") + except Exception as e: + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") + @staticmethod def _new_community_id() -> str: return str(uuid.uuid4()) diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 2a1f4f2b..6c5c7618 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -17,6 +17,7 @@ from app.repositories.neo4j.cypher_queries import ( GET_ALL_COMMUNITY_MEMBERS_BATCH, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, + UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -147,3 +148,26 @@ class CommunityRepository: except Exception as e: logger.error(f"refresh_member_count failed: {e}") return 0 + + async def update_community_metadata( + self, + community_id: str, + end_user_id: str, + name: str, + summary: str, + core_entities: List[str], + ) -> bool: + """更新社区的名称、摘要和核心实体列表。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_METADATA, + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + ) + return bool(result) + except Exception as e: + logger.error(f"update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 84889d65..b270ed64 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1150,3 +1150,12 @@ WITH c, count(e) AS cnt SET c.member_count = cnt RETURN c.community_id AS community_id, cnt AS member_count """ + +UPDATE_COMMUNITY_METADATA = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index a94bc23b..2ef9bafc 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,5 +1,6 @@ import asyncio -from typing import List +import os +from typing import List, Optional # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -156,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j( entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. @@ -290,12 +293,15 @@ async def save_dialog_and_statements_to_neo4j( logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) - # 写入成功后,触发聚类 - if entity_nodes: + # 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比) + clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false" + if entity_nodes and clustering_enabled: end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - await _trigger_clustering(new_entity_ids, end_user_id) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) + elif entity_nodes and not clustering_enabled: + logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发") return True @@ -309,6 +315,8 @@ async def save_dialog_and_statements_to_neo4j( async def _trigger_clustering( new_entity_ids: List[str], end_user_id: str, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 @@ -318,7 +326,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector) + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: From f6d929ab7a48e9a5417f9f45f376de02a796058a Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 13 Mar 2026 12:59:36 +0800 Subject: [PATCH 04/14] [add] Community node interface development --- .../controllers/user_memory_controllers.py | 37 ++++ api/app/services/user_memory_service.py | 160 ++++++++++++++++++ 2 files changed, 197 insertions(+) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index d3fe7d83..be796ff9 100644 --- a/api/app/controllers/user_memory_controllers.py +++ b/api/app/controllers/user_memory_controllers.py @@ -17,6 +17,7 @@ from app.services.user_memory_service import ( UserMemoryService, analytics_memory_types, analytics_graph_data, + analytics_community_graph_data, ) from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction from app.schemas.response_schema import ApiResponse @@ -295,6 +296,42 @@ async def get_graph_data_api( return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e)) +@router.get("/analytics/community_graph", response_model=ApiResponse) +async def get_community_graph_data_api( + end_user_id: str, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), +) -> dict: + workspace_id = current_user.current_workspace_id + + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, " + f"workspace={workspace_id}" + ) + + try: + result = await analytics_community_graph_data(db=db, end_user_id=end_user_id) + + if "message" in result and result["statistics"]["total_nodes"] == 0: + api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}") + return success(data=result, msg=result.get("message", "查询成功")) + + api_logger.info( + f"成功获取社区图谱: end_user_id={end_user_id}, " + f"nodes={result['statistics']['total_nodes']}, " + f"edges={result['statistics']['total_edges']}" + ) + return success(data=result, msg="查询成功") + + except Exception as e: + api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") + return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e)) + + @router.get("/read_end_user/profile", response_model=ApiResponse) async def get_end_user_profile( end_user_id: str, diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 8bacc112..d21df064 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -1727,6 +1727,166 @@ async def analytics_graph_data( # 辅助函数 +async def analytics_community_graph_data( + db: Session, + end_user_id: str, +) -> Dict[str, Any]: + """ + 获取社区图谱数据,包含 Community 节点、ExtractedEntity 节点及其关系。 + + Returns: + 包含 nodes、edges、statistics 的字典,格式与 analytics_graph_data 一致 + """ + try: + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + if not end_user: + return { + "nodes": [], "edges": [], + "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}}, + "message": "用户不存在" + } + + # 查询社区节点、实体节点、BELONGS_TO_COMMUNITY 边、实体间关系 + cypher = """ + MATCH (c:Community {end_user_id: $end_user_id}) + MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) + OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) + RETURN + elementId(c) AS c_id, + properties(c) AS c_props, + elementId(e) AS e_id, + properties(e) AS e_props, + elementId(b) AS b_id, + elementId(e2) AS e2_id, + properties(e2) AS e2_props, + elementId(r) AS r_id, + type(r) AS r_type, + properties(r) AS r_props, + startNode(r) = e AS r_from_e + """ + rows = await _neo4j_connector.execute_query(cypher, end_user_id=end_user_id) + + nodes_map: Dict[str, dict] = {} + edges_map: Dict[str, dict] = {} + # 记录每个 Community 对应的实体 id 列表 + community_members: Dict[str, list] = {} + + for row in rows: + # Community 节点 + c_id = row["c_id"] + if c_id and c_id not in nodes_map: + raw = row["c_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "community_id", "end_user_id", "member_count", "updated_at", + "name", "summary", "core_entities", + ) if k in raw} + nodes_map[c_id] = { + "id": c_id, + "label": "Community", + "properties": props, + } + + # ExtractedEntity 节点 (e) + e_id = row["e_id"] + if e_id and e_id not in nodes_map: + raw = row["e_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "name", "end_user_id", "description", "created_at", "entity_type", + ) if k in raw} + # 注入所属社区名称(c 是 e 直接归属的社区) + c_raw = row["c_props"] or {} + props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or "" + nodes_map[e_id] = { + "id": e_id, + "label": "ExtractedEntity", + "properties": props, + } + + # ExtractedEntity 节点 (e2,可选) + e2_id = row.get("e2_id") + if e2_id and e2_id not in nodes_map: + raw = row["e2_props"] or {} + props = {k: _clean_neo4j_value(raw.get(k)) for k in ( + "name", "end_user_id", "description", "created_at", "entity_type", + ) if k in raw} + # e2 的社区归属在后处理阶段通过 community_members 补充 + props["community_name"] = "" + nodes_map[e2_id] = { + "id": e2_id, + "label": "ExtractedEntity", + "properties": props, + } + + # BELONGS_TO_COMMUNITY 边 + b_id = row["b_id"] + if b_id and b_id not in edges_map: + edges_map[b_id] = { + "id": b_id, + "source": e_id, + "target": c_id, + } + # 收集社区成员 id + if c_id and e_id: + community_members.setdefault(c_id, []) + if e_id not in community_members[c_id]: + community_members[c_id].append(e_id) + + # EXTRACTED_RELATIONSHIP 边(可选) + r_id = row.get("r_id") + if r_id and r_id not in edges_map and e2_id: + r_props = {k: _clean_neo4j_value(v) for k, v in (row["r_props"] or {}).items()} + source = e_id if row.get("r_from_e") else e2_id + target = e2_id if row.get("r_from_e") else e_id + edges_map[r_id] = { + "id": r_id, + "source": source, + "target": target, + } + + nodes = list(nodes_map.values()) + edges = list(edges_map.values()) + + # 为每个 Community 节点注入 member_entity_ids,同时补全 e2 节点的 community_name + for c_id, member_ids in community_members.items(): + c_node = nodes_map.get(c_id) + if c_node: + c_node["properties"]["member_entity_ids"] = member_ids + c_name = c_node["properties"].get("name") or "" + # 补全属于该社区但 community_name 为空的实体(即 e2 节点) + for eid in member_ids: + e_node = nodes_map.get(eid) + if e_node and e_node["label"] == "ExtractedEntity": + if not e_node["properties"].get("community_name"): + e_node["properties"]["community_name"] = c_name + + node_type_counts: Dict[str, int] = {} + for n in nodes: + node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1 + + return { + "nodes": nodes, + "edges": edges, + "statistics": { + "total_nodes": len(nodes), + "total_edges": len(edges), + "node_types": node_type_counts, + } + } + + except ValueError: + logger.error(f"无效的 end_user_id 格式: {end_user_id}") + return { + "nodes": [], "edges": [], + "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}}, + "message": "无效的用户ID格式" + } + except Exception as e: + logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True) + raise + + async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]: """ 根据节点类型提取需要的属性字段 From 6a0ee22d8145d17654b010656e52f7ead22bc92c Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 13 Mar 2026 14:43:29 +0800 Subject: [PATCH 05/14] [add] Create trigger events for the purpose of completing the existing data --- api/app/celery_app.py | 1 + .../memory_dashboard_controller.py | 14 +- .../clustering_engine/label_propagation.py | 70 +++++++-- api/app/tasks.py | 138 ++++++++++++++++++ 4 files changed, 206 insertions(+), 17 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 0319e079..dbfa9d51 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -113,6 +113,7 @@ celery_app.conf.update( 'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'}, 'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'}, 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, + 'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'}, }, ) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 1b5b45fb..f01445d3 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -177,7 +177,19 @@ async def get_workspace_end_users( await aio_redis_set(cache_key, json.dumps(result), expire=30) except Exception as e: api_logger.warning(f"Redis 缓存写入失败: {str(e)}") - + + # 触发社区聚类补全任务(异步,不阻塞接口响应) + # 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类 + try: + from app.tasks import init_community_clustering_for_users + init_community_clustering_for_users.apply_async( + kwargs={"end_user_ids": end_user_ids}, + queue="periodic_tasks", + ) + api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}") + except Exception as e: + api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") + api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return success(data=result, msg="宿主列表获取成功") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 251d4fea..4491b416 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -165,8 +165,15 @@ class LabelPropagationEngine: f"{len(labels)} 个实体" ) # 为所有社区生成元数据 - unique_communities = list(set(labels.values())) - for cid in unique_communities: + # 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区 + # 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID + surviving_communities = await self.repo.get_all_entities(end_user_id) + surviving_community_ids = list({ + e.get("community_id") for e in surviving_communities + if e.get("community_id") + }) + logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") + for cid in surviving_community_ids: await self._generate_community_metadata(cid, end_user_id) async def incremental_update( @@ -249,7 +256,7 @@ class LabelPropagationEngine: 全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。 """ - MERGE_THRESHOLD = 0.75 + MERGE_THRESHOLD = 0.85 BATCH_THRESHOLD = 20 # 超过此数量走批量查询 community_embeddings: Dict[str, Optional[List[float]]] = {} @@ -305,34 +312,65 @@ class LabelPropagationEngine: logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区") - # 执行合并:用 union-find 思路避免重复迁移已被合并的社区 - # 维护一个 canonical 映射,确保链式合并正确收敛 - canonical: Dict[str, str] = {cid: cid for cid in cids} + # 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量 + # 避免 union-find 链式传递导致语义不相关的社区被间接合并 + # (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并) + merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射 - def find(x: str) -> str: - while canonical[x] != x: - canonical[x] = canonical[canonical[x]] - x = canonical[x] + def get_root(x: str) -> str: + """路径压缩,找到 x 当前所属的根社区。""" + while x in merged_into: + merged_into[x] = merged_into.get(merged_into[x], merged_into[x]) + x = merged_into[x] return x for c1, c2 in to_merge: - root1, root2 = find(c1), find(c2) + root1, root2 = get_root(c1), get_root(c2) if root1 == root2: - continue # 已经在同一社区,跳过 + continue + + # 用合并后的最新平均向量重新验证相似度 + # 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并 + current_sim = _cosine_similarity( + community_embeddings.get(root1), + community_embeddings.get(root2), + ) + if current_sim <= MERGE_THRESHOLD: + # 合并后向量已漂移,不再满足阈值,跳过 + logger.debug( + f"[Clustering] 跳过合并 {root1} ↔ {root2}," + f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}" + ) + continue + keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2 dissolve = root2 if keep == root1 else root1 - canonical[dissolve] = keep + merged_into[dissolve] = keep members = await self.repo.get_community_members(dissolve, end_user_id) for m in members: await self.repo.assign_entity_to_community(m["id"], keep, end_user_id) - # 更新 sizes 以便后续合并决策准确 - community_sizes[keep] = community_sizes.get(keep, 0) + len(members) + + # 合并后重新计算 keep 的平均向量(加权平均) + keep_emb = community_embeddings.get(keep) + dissolve_emb = community_embeddings.get(dissolve) + keep_size = community_sizes.get(keep, 0) + dissolve_size = community_sizes.get(dissolve, 0) + total_size = keep_size + dissolve_size + if keep_emb and dissolve_emb and total_size > 0: + dim = len(keep_emb) + community_embeddings[keep] = [ + (keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size + for i in range(dim) + ] + community_embeddings[dissolve] = None + + community_sizes[keep] = total_size community_sizes[dissolve] = 0 await self.repo.refresh_member_count(keep, end_user_id) logger.info( f"[Clustering] 社区合并: {dissolve} → {keep}," - f"迁移 {len(members)} 个成员" + f"相似度={current_sim:.3f},迁移 {len(members)} 个成员" ) async def _flush_labels( diff --git a/api/app/tasks.py b/api/app/tasks.py index a6ebbb8e..134d4744 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2416,3 +2416,141 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: "elapsed_time": elapsed_time, "task_id": self.request.id } + + +# ============================================================================= +# 社区聚类补全任务(触发型) +# ============================================================================= + +@celery_app.task( + name="app.tasks.init_community_clustering_for_users", + bind=True, + ignore_result=False, + max_retries=0, + acks_late=False, + time_limit=7200, # 2小时硬超时 + soft_time_limit=6900, +) +def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。 + + 由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。 + + Args: + end_user_ids: 需要检查的用户 ID 列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.neo4j.community_repository import CommunityRepository + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + + logger = get_logger(__name__) + logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}") + + initialized = 0 + skipped = 0 + failed = 0 + + connector = Neo4jConnector() + try: + repo = CommunityRepository(connector) + + # 获取 llm_model_id(从第一个用户的配置中读取,作为全局兜底) + llm_model_id = None + try: + with get_db_context() as db: + from app.services.memory_agent_service import get_end_user_connected_config + from app.services.memory_config_service import MemoryConfigService + for uid in end_user_ids: + try: + connected = get_end_user_connected_config(uid, db) + config_id = connected.get("memory_config_id") + workspace_id = connected.get("workspace_id") + if config_id or workspace_id: + cfg = MemoryConfigService(db).load_memory_config( + config_id=config_id, workspace_id=workspace_id + ) + llm_model_id = str(cfg.llm_model_id) + break + except Exception: + continue + except Exception as e: + logger.warning(f"[CommunityCluster] 获取 LLM 配置失败,将使用兜底值: {e}") + + engine = LabelPropagationEngine( + connector=connector, + llm_model_id=llm_model_id, + ) + + for end_user_id in end_user_ids: + try: + # 已有社区节点则跳过 + has_communities = await repo.has_communities(end_user_id) + if has_communities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过") + continue + + # 检查是否有 ExtractedEntity 节点 + entities = await repo.get_all_entities(end_user_id) + if not entities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过") + continue + + logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类") + await engine.full_clustering(end_user_id) + initialized += 1 + logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") + + except Exception as e: + failed += 1 + logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}") + + finally: + await connector.close() + + logger.info( + f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}" + ) + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } From 01a1e8eab1ee97dbb04b6c79d1d3fd085b491b0f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Fri, 13 Mar 2026 14:50:21 +0800 Subject: [PATCH 06/14] [changes] Update the pointers in the main repository to point to the submodules --- redbear-mem-benchmark | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index 8494e824..b4ddbe9b 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit 8494e82498cb99c70ac67a64a544ff872432363a +Subproject commit b4ddbe9b19014bb8d2d20f1b41eb656d03a5e5ed From c244e9834f20a121789d8b7bdccb3704b270ec1f Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 12:30:00 +0800 Subject: [PATCH 07/14] [changes] Community Clustering Retrieval Module --- .../memory/agent/services/search_service.py | 38 +- .../core/memory/agent/utils/write_tools.py | 11 +- api/app/core/memory/src/search.py | 12 +- .../clustering_engine/label_propagation.py | 224 ++++++++-- api/app/main.py | 12 +- .../neo4j/community_repository.py | 65 +++ api/app/repositories/neo4j/cypher_queries.py | 134 ++++++ api/app/repositories/neo4j/graph_saver.py | 45 +- api/app/repositories/neo4j/graph_search.py | 79 ++++ api/app/repositories/neo4j/index_manager.py | 254 ++++++++++++ api/app/tasks.py | 388 ++++++++++++++++++ redbear-mem-benchmark | 2 +- 12 files changed, 1203 insertions(+), 61 deletions(-) create mode 100644 api/app/repositories/neo4j/index_manager.py diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 4fc4256e..2be18c97 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -120,7 +120,7 @@ class SearchService: raw_results is None if return_raw_results=False """ if include is None: - include = ["statements", "chunks", "entities", "summaries"] + include = ["statements", "chunks", "entities", "summaries", "communities"] # Clean query cleaned_query = self.clean_query(question) @@ -146,8 +146,8 @@ class SearchService: if search_type == "hybrid": reranked_results = answer.get('reranked_results', {}) - # Priority order: summaries first (most contextual), then statements, chunks, entities - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + # Priority order: summaries first (most contextual), then communities, statements, chunks, entities + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in reranked_results: @@ -157,13 +157,43 @@ class SearchService: else: # For keyword or embedding search, results are directly in answer dict # Apply same priority order - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in answer: category_results = answer[category] if isinstance(category_results, list): answer_list.extend(category_results) + + # 对命中的 community 节点展开其成员 statements + if "communities" in include: + community_results = ( + answer.get('reranked_results', {}).get('communities', []) + if search_type == "hybrid" + else answer.get('communities', []) + ) + community_ids = [ + r.get("id") for r in community_results if r.get("id") + ] + if community_ids and end_user_id: + try: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + connector = Neo4jConnector() + expand_result = await search_graph_community_expand( + connector=connector, + community_ids=community_ids, + end_user_id=end_user_id, + limit=10, + ) + await connector.close() + expanded_stmts = expand_result.get("expanded_statements", []) + if expanded_stmts: + # 展开的 statements 插入 communities 之后、statements 之前 + answer_list.extend(expanded_stmts) + logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") + except Exception as e: + logger.warning(f"社区展开检索失败,跳过: {e}") # Extract clean content from all results content_list = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 22030278..4e71f2c5 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -165,10 +165,17 @@ async def write( statement_chunk_edges=all_statement_chunk_edges, statement_entity_edges=all_statement_entity_edges, entity_edges=all_entity_entity_edges, - connector=neo4j_connector + connector=neo4j_connector, ) if success: logger.info("Successfully saved all data to Neo4j") + # 写入成功后,异步触发聚类(不阻塞写入响应) + schedule_clustering_after_write( + all_entity_nodes, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, + embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 0e1d8424..3570d707 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -238,7 +238,7 @@ def rerank_with_activation( reranked: Dict[str, List[Dict[str, Any]]] = {} - for category in ["statements", "chunks", "entities", "summaries"]: + for category in ["statements", "chunks", "entities", "summaries", "communities"]: keyword_items = keyword_results.get(category, []) embedding_items = embedding_results.get(category, []) @@ -281,21 +281,23 @@ def rerank_with_activation( for item in items_list: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") if item_id and item_id in combined_items: - combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) + combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") # 步骤 4: 计算基础分数和最终分数 for item_id, item in combined_items.items(): bm25_norm = float(item.get("bm25_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0) - act_norm = float(item.get("normalized_activation_value", 0) or 0) + # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 + raw_act_norm = item.get("normalized_activation_value") + act_norm = float(raw_act_norm) if raw_act_norm is not None else None # 第一阶段:只考虑内容相关性(BM25 + Embedding) # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 content_score = alpha * bm25_norm + (1 - alpha) * emb_norm base_score = content_score # 第一阶段用内容分数 - # 存储激活度分数供第二阶段使用 - item["activation_score"] = act_norm + # 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序) + item["activation_score"] = act_norm # 可能为 None item["content_score"] = content_score item["base_score"] = base_score diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 80e238fd..a116ba3b 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -20,6 +20,9 @@ logger = logging.getLogger(__name__) # 全量迭代最大轮数,防止不收敛 MAX_ITERATIONS = 10 +# 社区核心实体取 top-N 数量 +CORE_ENTITY_LIMIT = 10 + def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: """计算两个向量的余弦相似度,任一为空则返回 0。""" @@ -62,9 +65,18 @@ def _weighted_vote( class LabelPropagationEngine: """标签传播聚类引擎""" - def __init__(self, connector: Neo4jConnector): + def __init__( + self, + connector: Neo4jConnector, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, + ): self.connector = connector self.repo = CommunityRepository(connector) + self.config_id = config_id + self.llm_model_id = llm_model_id + self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -94,58 +106,110 @@ class LabelPropagationEngine: async def full_clustering(self, end_user_id: str) -> None: """ - 全量标签传播初始化。 + 全量标签传播初始化(分批处理,控制内存峰值)。 - 1. 拉取所有实体,初始化每个实体为独立社区 - 2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 3. 直到标签不再变化或达到 MAX_ITERATIONS - 4. 将最终标签写入 Neo4j + 策略: + - 每次只加载 BATCH_SIZE 个实体及其邻居进内存 + - labels 字典跨批次共享(只存 id→community_id,内存极小) + - 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息 + - 所有批次完成后统一 flush 和 merge """ - entities = await self.repo.get_all_entities(end_user_id) - if not entities: + BATCH_SIZE = 2000 # 每批实体数,可按需调整 + + # 先查总数,决定批次数 + total_entities = await self.repo.get_all_entities(end_user_id) + if not total_entities: logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") return - # 初始化:每个实体持有自己 id 作为社区标签 - labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} - embeddings: Dict[str, Optional[List[float]]] = { - e["id"]: e.get("name_embedding") for e in entities - } + total_count = len(total_entities) + logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体," + f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批") - for iteration in range(MAX_ITERATIONS): - changed = 0 - # 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历) - for entity in entities: - eid = entity["id"] - neighbors = await self.repo.get_entity_neighbors(eid, end_user_id) + # labels 跨批次共享:先用全量数据初始化(只存 id,内存极小) + labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} + # embeddings 也跨批次共享(每个向量 ~6KB,10万实体约 600MB,这是不可避免的) + # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻 + # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding, + # 所以只需要当前批次实体的 embedding,不需要全量 + del total_entities # 释放全量列表,后续按批次加载 - # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) - enriched = [] - for nb in neighbors: - nb_copy = dict(nb) - nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) - enriched.append(nb_copy) - - new_label = _weighted_vote(enriched, embeddings.get(eid)) - if new_label and new_label != labels[eid]: - labels[eid] = new_label - changed += 1 - - logger.info( - f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}," - f"标签变化数: {changed}" + for batch_start in range(0, total_count, BATCH_SIZE): + batch_entities = await self.repo.get_entities_page( + end_user_id, skip=batch_start, limit=BATCH_SIZE ) - if changed == 0: - logger.info("[Clustering] 标签已收敛,提前结束迭代") + if not batch_entities: break - # 将最终标签写入 Neo4j + batch_ids = [e["id"] for e in batch_entities] + batch_embeddings: Dict[str, Optional[List[float]]] = { + e["id"]: e.get("name_embedding") for e in batch_entities + } + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:" + f"加载 {len(batch_entities)} 个实体的邻居图..." + ) + neighbors_cache = await self.repo.get_entity_neighbors_for_ids( + batch_ids, end_user_id + ) + logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + + for iteration in range(MAX_ITERATIONS): + changed = 0 + for entity in batch_entities: + eid = entity["id"] + neighbors = neighbors_cache.get(eid, []) + + # 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值) + enriched = [] + for nb in neighbors: + nb_copy = dict(nb) + nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) + enriched.append(nb_copy) + + new_label = _weighted_vote(enriched, batch_embeddings.get(eid)) + if new_label and new_label != labels[eid]: + labels[eid] = new_label + changed += 1 + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} " + f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}" + ) + if changed == 0: + logger.info("[Clustering] 标签已收敛,提前结束本批迭代") + break + + # 释放本批次的大对象 + del neighbors_cache, batch_embeddings, batch_entities + + # 所有批次完成,统一写入 Neo4j await self._flush_labels(labels, end_user_id) + pre_merge_count = len(set(labels.values())) logger.info( - f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区," + f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区," + f"{len(labels)} 个实体,开始后处理合并" + ) + + all_community_ids = list(set(labels.values())) + await self._evaluate_merge(all_community_ids, end_user_id) + + logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) + # 查询存活社区并生成元数据 + surviving_communities = await self.repo.get_all_entities(end_user_id) + surviving_community_ids = list({ + e.get("community_id") for e in surviving_communities + if e.get("community_id") + }) + logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") + for cid in surviving_community_ids: + await self._generate_community_metadata(cid, end_user_id) + async def incremental_update( self, new_entity_ids: List[str], end_user_id: str ) -> None: @@ -306,6 +370,90 @@ class LabelPropagationEngine: except Exception: return None + async def _generate_community_metadata( + self, community_id: str, end_user_id: str + ) -> None: + """ + 为社区生成并写入元数据:名称、摘要、核心实体。 + + - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) + - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。 + """ + try: + members = await self.repo.get_community_members(community_id, end_user_id) + if not members: + return + + # 核心实体:按 activation_value 降序取 top-N + sorted_members = sorted( + members, + key=lambda m: m.get("activation_value") or 0, + reverse=True, + ) + core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] + all_names = [m["name"] for m in members if m.get("name")] + + name = "、".join(core_entities[:3]) if core_entities else community_id[:8] + summary = f"包含实体:{', '.join(all_names)}" + + # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 + if self.llm_model_id: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + entity_list_str = "、".join(all_names) + prompt = ( + f"以下是一组语义相关的实体:{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) + + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() + except Exception as e: + logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + + # 生成 summary_embedding + summary_embedding = None + if self.embedding_model_id and summary: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + results = await embedder.response([summary]) + summary_embedding = results[0] if results else None + except Exception as e: + logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}") + + result = await self.repo.update_community_metadata( + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + summary_embedding=summary_embedding, + ) + if result: + logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...") + else: + logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False") + except Exception as e: + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True) + @staticmethod def _new_community_id() -> str: return str(uuid.uuid4()) diff --git a/api/app/main.py b/api/app/main.py index af5ed796..21f56766 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger from app.core.response_utils import fail from app.core.models.scripts.loader import load_models from app.db import get_db_context +from app.repositories.neo4j.index_manager import ensure_indexes # Initialize logging system LoggingConfig.setup_logging() @@ -61,9 +62,18 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") + # 确保 Neo4j 索引存在(幂等,多环境安全) + try: + report = await ensure_indexes() + if report["errors"]: + logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}") + else: + logger.info(f"Neo4j 索引检查完成 [{report['uri']}]") + except Exception as e: + logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}") + logger.info("应用程序启动完成") yield - # 应用关闭事件 logger.info("应用程序正在关闭") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 16e30a10..78ecf6f6 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -13,9 +13,14 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_LEAVE_ALL_COMMUNITIES, GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, + GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, + GET_ALL_COMMUNITY_MEMBERS_BATCH, + GET_ALL_ENTITY_NEIGHBORS_BATCH, + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, + UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -87,6 +92,41 @@ class CommunityRepository: logger.error(f"get_all_entities failed: {e}") return [] + async def get_entities_page( + self, end_user_id: str, skip: int, limit: int + ) -> List[Dict]: + """分页拉取实体,用于全量聚类分批处理。""" + try: + return await self.connector.execute_query( + GET_ENTITIES_PAGE, + end_user_id=end_user_id, + skip=skip, + limit=limit, + ) + except Exception as e: + logger.error(f"get_entities_page failed: {e}") + return [] + + async def get_entity_neighbors_for_ids( + self, entity_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。""" + try: + rows = await self.connector.execute_query( + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, + entity_ids=entity_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + eid = row["entity_id"] + neighbor = {k: v for k, v in row.items() if k != "entity_id"} + result.setdefault(eid, []).append(neighbor) + return result + except Exception as e: + logger.error(f"get_entity_neighbors_for_ids failed: {e}") + return {} + async def get_community_members( self, community_id: str, end_user_id: str ) -> List[Dict]: @@ -127,3 +167,28 @@ class CommunityRepository: except Exception as e: logger.error(f"refresh_member_count failed: {e}") return 0 + + async def update_community_metadata( + self, + community_id: str, + end_user_id: str, + name: str, + summary: str, + core_entities: List[str], + summary_embedding: Optional[List[float]] = None, + ) -> bool: + """更新社区的名称、摘要、核心实体列表和摘要向量。""" + try: + result = await self.connector.execute_query( + UPDATE_COMMUNITY_METADATA, + community_id=community_id, + end_user_id=end_user_id, + name=name, + summary=summary, + core_entities=core_entities, + summary_embedding=summary_embedding, + ) + return bool(result) + except Exception as e: + logger.error(f"update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 947097a2..b42351b0 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1139,6 +1139,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, ORDER BY coalesce(e.activation_value, 0) DESC """ +GET_ALL_COMMUNITY_MEMBERS_BATCH = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN c.community_id AS community_id, + e.id AS id, e.name AS name, e.entity_type AS entity_type, + e.importance_score AS importance_score, e.activation_value AS activation_value, + e.name_embedding AS name_embedding +ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC +""" + CHECK_USER_HAS_COMMUNITIES = """ MATCH (c:Community {end_user_id: $end_user_id}) RETURN count(c) AS community_count @@ -1150,3 +1159,128 @@ WITH c, count(e) AS cnt SET c.member_count = cnt RETURN c.community_id AS community_id, cnt AS member_count """ + +UPDATE_COMMUNITY_METADATA = """ +MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.summary_embedding = $summary_embedding, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + +GET_ENTITIES_PAGE = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN e.id AS id, + e.name AS name, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +ORDER BY e.id +SKIP $skip LIMIT $limit +""" + +GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """ +// 批量拉取指定实体列表的邻居(用于分批全量聚类) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +WHERE e.id IN $entity_ids +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id +WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH e, nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + e.id AS entity_id, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + +GET_ALL_ENTITY_NEIGHBORS_BATCH = """ +// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源一:直接关系邻居 +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源二:同 Statement 共现邻居 +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id + +WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH e, nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + e.id AS entity_id, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + + +# Community keyword search: matches name or summary via fulltext index +SEARCH_COMMUNITIES_BY_KEYWORD = """ +CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 向量检索 ────────────────────────────────────────────────── +# Community embedding-based search: cosine similarity on Community.summary_embedding +COMMUNITY_EMBEDDING_SEARCH = """ +CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding) +YIELD node AS c, score +WHERE c.summary_embedding IS NOT NULL + AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 展开检索 ────────────────────────────────────────────────── +# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索) +EXPAND_COMMUNITY_STATEMENTS = """ +MATCH (c:Community {community_id: $community_id}) +MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c) +MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +WHERE s.end_user_id = $end_user_id +RETURN s.statement AS statement, + s.id AS id, + s.end_user_id AS end_user_id, + s.created_at AS created_at, + s.valid_at AS valid_at, + s.invalid_at AS invalid_at, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + e.name AS source_entity, + c.name AS community_name +ORDER BY COALESCE(s.activation_value, 0) DESC +LIMIT $limit +""" diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index a94bc23b..29e337f1 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,5 +1,5 @@ import asyncio -from typing import List +from typing import List, Optional # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -156,10 +156,13 @@ async def save_dialog_and_statements_to_neo4j( entity_edges: List[EntityEntityEdge], statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], - connector: Neo4jConnector + connector: Neo4jConnector, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. + 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 + schedule_clustering_after_write() 显式触发。 + Args: dialogue_nodes: List of DialogueNode objects to save chunk_nodes: List of ChunkNode objects to save @@ -290,13 +293,6 @@ async def save_dialog_and_statements_to_neo4j( logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) - # 写入成功后,触发聚类 - if entity_nodes: - end_user_id = entity_nodes[0].end_user_id - new_entity_ids = [e.id for e in entity_nodes] - logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - await _trigger_clustering(new_entity_ids, end_user_id) - return True except Exception as e: @@ -306,9 +302,38 @@ async def save_dialog_and_statements_to_neo4j( return False +def schedule_clustering_after_write( + entity_nodes: List, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, +) -> None: + """ + 写入 Neo4j 成功后,调度后台聚类任务。 + + 可通过环境变量 CLUSTERING_ENABLED=false 禁用(用于基准测试对比)。 + 使用 asyncio.create_task 异步触发,不阻塞写入响应。 + """ + if not entity_nodes: + return + + clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false" + if not clustering_enabled: + logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发") + return + + end_user_id = entity_nodes[0].end_user_id + new_entity_ids = [e.id for e in entity_nodes] + logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) + + async def _trigger_clustering( new_entity_ids: List[str], end_user_id: str, + config_id: Optional[str] = None, + llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 @@ -318,7 +343,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector) + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index e8f52535..19e40a82 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional from app.repositories.neo4j.cypher_queries import ( CHUNK_EMBEDDING_SEARCH, + COMMUNITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH, + EXPAND_COMMUNITY_STATEMENTS, MEMORY_SUMMARY_EMBEDDING_SEARCH, SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNKS_BY_CONTENT, + SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_ENTITIES_BY_NAME, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, @@ -285,6 +288,15 @@ async def search_graph( limit=limit, )) task_keys.append("summaries") + + if "communities" in include: + tasks.append(connector.execute_query( + SEARCH_COMMUNITIES_BY_KEYWORD, + q=q, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") # Execute all queries in parallel task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -396,6 +408,16 @@ async def search_graph_by_embedding( )) task_keys.append("summaries") + # Communities (向量语义匹配) + if "communities" in include: + tasks.append(connector.execute_query( + COMMUNITY_EMBEDDING_SEARCH, + embedding=embedding, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") + # Execute all queries in parallel query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -408,6 +430,7 @@ async def search_graph_by_embedding( "chunks": [], "entities": [], "summaries": [], + "communities": [], } for key, result in zip(task_keys, task_results): @@ -661,6 +684,62 @@ async def search_graph_by_chunk_id( return {"chunks": chunks} +async def search_graph_community_expand( + connector: Neo4jConnector, + community_ids: List[str], + end_user_id: str, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + 三期:社区展开检索 —— 主题 → 细节两级检索。 + + 命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体, + 再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点, + 按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。 + + Args: + connector: Neo4j 连接器 + community_ids: 已命中的社区 ID 列表 + end_user_id: 用户 ID,用于数据隔离 + limit: 每个社区最多返回的 Statement 数量 + + Returns: + {"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]} + """ + if not community_ids or not end_user_id: + return {"expanded_statements": []} + + tasks = [ + connector.execute_query( + EXPAND_COMMUNITY_STATEMENTS, + community_id=cid, + end_user_id=end_user_id, + limit=limit, + ) + for cid in community_ids + ] + + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + expanded: List[Dict[str, Any]] = [] + for cid, result in zip(community_ids, task_results): + if isinstance(result, Exception): + logger.warning(f"社区展开检索失败 community_id={cid}: {result}") + else: + expanded.extend(result) + + # 按 activation_value 全局排序后去重 + from app.core.memory.src.search import _deduplicate_results + expanded.sort( + key=lambda x: float(x.get("activation_value") or 0), + reverse=True, + ) + expanded = _deduplicate_results(expanded) + + logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}") + return {"expanded_statements": expanded} + + async def search_graph_by_created_at( connector: Neo4jConnector, end_user_id: Optional[str] = None, diff --git a/api/app/repositories/neo4j/index_manager.py b/api/app/repositories/neo4j/index_manager.py new file mode 100644 index 00000000..a1ab6689 --- /dev/null +++ b/api/app/repositories/neo4j/index_manager.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +"""Neo4j 索引管理模块 + +负责检查和创建 Neo4j 全文索引与向量索引。 +支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。 + +用法: + # 作为模块调用(应用启动时) + from app.repositories.neo4j.index_manager import ensure_indexes + await ensure_indexes() + + # 作为独立脚本执行(手动建索引) + python -m app.repositories.neo4j.index_manager +""" + +import asyncio +import logging +from dataclasses import dataclass +from typing import List + +from app.core.config import settings +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + + +# ───────────────────────────────────────────────────────────── +# 索引定义表 +# ───────────────────────────────────────────────────────────── + +@dataclass +class FulltextIndexDef: + name: str + label: str + properties: List[str] + + +@dataclass +class VectorIndexDef: + name: str + label: str + property: str + dimensions: int + similarity: str = "cosine" + + +# 全文索引清单(现有 + 新增 communities) +FULLTEXT_INDEXES: List[FulltextIndexDef] = [ + FulltextIndexDef("statementsFulltext", "Statement", ["statement"]), + FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]), + FulltextIndexDef("chunksFulltext", "Chunk", ["content"]), + FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]), + FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源 +] + +# 向量索引清单(预留 community 二期) +VECTOR_INDEXES: List[VectorIndexDef] = [ + VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536), + VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536), + VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536), + VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536), + # 二期:社区向量索引 + VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536), +] + + +# ───────────────────────────────────────────────────────────── +# 核心检查 / 创建逻辑 +# ───────────────────────────────────────────────────────────── + +async def _get_existing_indexes(connector: Neo4jConnector) -> set: + """查询 Neo4j 中已存在的索引名称集合""" + rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name") + return {row["name"] for row in rows} + + +async def _ensure_fulltext_index( + connector: Neo4jConnector, + idx: FulltextIndexDef, + existing: set, +) -> str: + """检查并按需创建全文索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 全文索引已存在: {idx.name}" + + props = ", ".join(f"n.{p}" for p in idx.properties) + cypher = ( + f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS ' + f'FOR (n:{idx.label}) ON EACH [{props}]' + ) + await connector.execute_query(cypher) + return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})" + + +async def _ensure_vector_index( + connector: Neo4jConnector, + idx: VectorIndexDef, + existing: set, +) -> str: + """检查并按需创建向量索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 向量索引已存在: {idx.name}" + + cypher = ( + f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS " + f"FOR (n:{idx.label}) ON n.{idx.property} " + f"OPTIONS {{indexConfig: {{" + f"`vector.dimensions`: {idx.dimensions}, " + f"`vector.similarity_function`: '{idx.similarity}'" + f"}}}}" + ) + await connector.execute_query(cypher) + return ( + f"[CREATE] 向量索引已创建: {idx.name} " + f"({idx.label}.{idx.property}, dim={idx.dimensions})" + ) + + +async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。 + + Args: + connector: 可选,传入已有连接器;为 None 时自动创建。 + + Returns: + dict: { + "uri": 当前连接的 Neo4j URI, + "fulltext": [操作日志列表], + "vector": [操作日志列表], + "errors": [错误信息列表], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + report = { + "uri": settings.NEO4J_URI, + "fulltext": [], + "vector": [], + "errors": [], + } + + try: + # 一次性拉取所有已有索引名 + existing = await _get_existing_indexes(connector) + logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}") + logger.info(f"[IndexManager] 已有索引数量: {len(existing)}") + + # 处理全文索引 + for idx in FULLTEXT_INDEXES: + try: + msg = await _ensure_fulltext_index(connector, idx, existing) + report["fulltext"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + # 处理向量索引 + for idx in VECTOR_INDEXES: + try: + msg = await _ensure_vector_index(connector, idx, existing) + report["vector"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + finally: + if own_connector: + await connector.close() + + return report + + +async def check_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 仅检查索引状态,不创建任何索引。 + + Returns: + dict: { + "uri": ..., + "present": [已存在的索引名], + "missing_fulltext": [缺失的全文索引名], + "missing_vector": [缺失的向量索引名], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + try: + existing = await _get_existing_indexes(connector) + missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing] + missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing] + + return { + "uri": settings.NEO4J_URI, + "present": sorted(existing), + "missing_fulltext": missing_ft, + "missing_vector": missing_vec, + } + finally: + if own_connector: + await connector.close() + + +# ───────────────────────────────────────────────────────────── +# 独立脚本入口 +# ───────────────────────────────────────────────────────────── + +async def _main(): + import sys + + print(f"\n{'='*60}") + print(f"Neo4j 索引管理工具") + print(f"环境: {settings.NEO4J_URI}") + print(f"{'='*60}\n") + + # 先检查 + print(">>> 检查当前索引状态...\n") + status = await check_indexes() + print(f" 已存在索引数: {len(status['present'])}") + if status["missing_fulltext"]: + print(f" 缺失全文索引: {status['missing_fulltext']}") + if status["missing_vector"]: + print(f" 缺失向量索引: {status['missing_vector']}") + + if not status["missing_fulltext"] and not status["missing_vector"]: + print("\n 所有索引均已存在,无需操作。") + return + + # 再创建 + print("\n>>> 开始创建缺失索引...\n") + report = await ensure_indexes() + + for msg in report["fulltext"] + report["vector"]: + print(f" {msg}") + + if report["errors"]: + print("\n[!] 以下索引创建失败:") + for err in report["errors"]: + print(f" {err}") + sys.exit(1) + else: + print("\n 全部索引处理完成。") + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/api/app/tasks.py b/api/app/tasks.py index a6ebbb8e..defa1aa0 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -2416,3 +2416,391 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]: "elapsed_time": elapsed_time, "task_id": self.request.id } + + +# ============================================================================= + +@celery_app.task( + name="app.tasks.init_implicit_emotions_for_users", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, + # 触发型任务标识,区别于 periodic_tasks 队列中的定时任务 + triggered=True, +) +def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。 + + 由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。 + 存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。 + + Args: + end_user_ids: 需要检查的用户ID列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.repositories.implicit_emotions_storage_repository import ( + ImplicitEmotionsStorageRepository, + ) + from app.services.emotion_analytics_service import EmotionAnalyticsService + from app.services.implicit_memory_service import ImplicitMemoryService + + logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}") + + initialized = 0 + failed = 0 + skipped = 0 + + with get_db_context() as db: + repo = ImplicitEmotionsStorageRepository(db) + + for end_user_id in end_user_ids: + existing = repo.get_by_end_user_id(end_user_id) + if existing is not None: + skipped += 1 + continue + + logger.info(f"用户 {end_user_id} 无记录,开始初始化") + implicit_ok = False + emotion_ok = False + try: + try: + implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id) + profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id) + await implicit_service.save_profile_cache( + end_user_id=end_user_id, profile_data=profile_data, db=db + ) + implicit_ok = True + except Exception as e: + logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}") + + try: + emotion_service = EmotionAnalyticsService() + suggestions_data = await emotion_service.generate_emotion_suggestions( + end_user_id=end_user_id, db=db, language="zh" + ) + await emotion_service.save_suggestions_cache( + end_user_id=end_user_id, suggestions_data=suggestions_data, db=db + ) + emotion_ok = True + except Exception as e: + logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}") + + if implicit_ok or emotion_ok: + initialized += 1 + else: + failed += 1 + except Exception as e: + failed += 1 + logger.error(f"用户 {end_user_id} 初始化异常: {e}") + + logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}") + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + loop = set_asyncio_event_loop() + + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + + +# ============================================================================= + +@celery_app.task( + name="app.tasks.init_interest_distribution_for_users", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) +def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。 + + 由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。 + 默认生成中文(zh)兴趣分布数据。 + + Args: + self: task object + end_user_ids: 需要检查的用户ID列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE + from app.services.memory_agent_service import MemoryAgentService + + logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}") + + initialized = 0 + failed = 0 + skipped = 0 + language = "zh" + + service = MemoryAgentService() + + with get_db_context() as db: + for end_user_id in end_user_ids: + # 存在性检查:缓存有数据则跳过 + cached = await InterestMemoryCache.get_interest_distribution( + end_user_id=end_user_id, + language=language, + ) + if cached is not None: + skipped += 1 + continue + + logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成") + try: + result = await service.get_interest_distribution_by_user( + end_user_id=end_user_id, + limit=5, + language=language, + ) + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + expire=INTEREST_CACHE_EXPIRE, + ) + initialized += 1 + logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功") + except Exception as e: + failed += 1 + logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}") + + logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}") + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + loop = set_asyncio_event_loop() + + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + + +@celery_app.task( + name="app.tasks.write_perceptual_memory", + bind=True, + ignore_result=True, + max_retries=0, + acks_late=False, + time_limit=3600, + soft_time_limit=3300, +) +def write_perceptual_memory( + self, + end_user_id: str, + model_api_config: dict, + file_type: str, + file_url: str, + file_message: dict +): + """ + Write perceptual memory for a user into PostgreSQL and Neo4j. + + This task generates or updates the user's perceptual memory + in the backend databases. It is intended to be executed asynchronously + via Celery. + + Args: + end_user_id (uuid.UUID): The unique identifier of the end user. + model_api_config (ModelInfo): API configuration for the model + used to generate perceptual memory. + file_type (str): The file type + file_url (url): The url of file + file_message (dict): The file message containing details about the file + to be processed. + + Returns: + None + """ + file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest() + set_asyncio_event_loop() + with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()): + model_info = ModelInfo(**model_api_config) + with get_db_context() as db: + memory_perceptual_service = MemoryPerceptualService(db) + return asyncio.run(memory_perceptual_service.generate_perceptual_memory( + end_user_id, + model_info, + file_type, + file_url, + file_message, + )) + + +# ============================================================================= +# 社区聚类补全任务(触发型) +# ============================================================================= + +@celery_app.task( + name="app.tasks.init_community_clustering_for_users", + bind=True, + ignore_result=False, + max_retries=0, + acks_late=False, + time_limit=7200, # 2小时硬超时 + soft_time_limit=6900, +) +def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]: + """触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。 + + 由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。 + + Args: + end_user_ids: 需要检查的用户 ID 列表 + + Returns: + 包含任务执行结果的字典 + """ + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.logging_config import get_logger + from app.repositories.neo4j.community_repository import CommunityRepository + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine + + logger = get_logger(__name__) + logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}") + + initialized = 0 + skipped = 0 + failed = 0 + + connector = Neo4jConnector() + try: + repo = CommunityRepository(connector) + + # 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置) + user_llm_map: Dict[str, Optional[str]] = {} + user_embedding_map: Dict[str, Optional[str]] = {} + try: + with get_db_context() as db: + from app.services.memory_agent_service import get_end_users_connected_configs_batch + from app.services.memory_config_service import MemoryConfigService + batch_configs = get_end_users_connected_configs_batch(end_user_ids, db) + for uid, cfg_info in batch_configs.items(): + config_id = cfg_info.get("memory_config_id") + if config_id: + try: + cfg = MemoryConfigService(db).load_memory_config(config_id=config_id) + user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None + user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None + except Exception as e: + logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}") + user_llm_map[uid] = None + user_embedding_map[uid] = None + else: + user_llm_map[uid] = None + user_embedding_map[uid] = None + except Exception as e: + logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}") + + for end_user_id in end_user_ids: + try: + # 已有社区节点则跳过 + has_communities = await repo.has_communities(end_user_id) + if has_communities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过") + continue + + # 检查是否有 ExtractedEntity 节点 + entities = await repo.get_all_entities(end_user_id) + if not entities: + skipped += 1 + logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过") + continue + + # 每个用户使用自己的 llm_model_id + llm_model_id = user_llm_map.get(end_user_id) + embedding_model_id = user_embedding_map.get(end_user_id) + engine = LabelPropagationEngine( + connector=connector, + llm_model_id=llm_model_id, + embedding_model_id=embedding_model_id, + ) + + logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}") + await engine.full_clustering(end_user_id) + initialized += 1 + logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成") + + except Exception as e: + failed += 1 + logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}") + + finally: + await connector.close() + + logger.info( + f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}" + ) + return { + "status": "SUCCESS", + "initialized": initialized, + "skipped": skipped, + "failed": failed, + } + + try: + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + loop = set_asyncio_event_loop() + result = loop.run_until_complete(_run()) + result["elapsed_time"] = time.time() - start_time + result["task_id"] = self.request.id + return result + + except Exception as e: + return { + "status": "FAILURE", + "error": str(e), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index 8494e824..89053e48 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit 8494e82498cb99c70ac67a64a544ff872432363a +Subproject commit 89053e48e932332d2a0f17760034ee2bce75ea43 From f9fb480cc3e519a733b24e2957f5d8dc8b46ef49 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 12:30:00 +0800 Subject: [PATCH 08/14] [changes] Community Clustering Retrieval Module --- .../memory/agent/services/search_service.py | 38 ++- .../core/memory/agent/utils/write_tools.py | 9 +- api/app/core/memory/src/search.py | 12 +- .../clustering_engine/label_propagation.py | 149 ++++++---- api/app/main.py | 12 +- .../neo4j/community_repository.py | 42 ++- api/app/repositories/neo4j/cypher_queries.py | 121 +++++++-- api/app/repositories/neo4j/graph_saver.py | 15 +- api/app/repositories/neo4j/graph_search.py | 79 ++++++ api/app/repositories/neo4j/index_manager.py | 254 ++++++++++++++++++ redbear-mem-benchmark | 2 +- 11 files changed, 637 insertions(+), 96 deletions(-) create mode 100644 api/app/repositories/neo4j/index_manager.py diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 4fc4256e..2be18c97 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -120,7 +120,7 @@ class SearchService: raw_results is None if return_raw_results=False """ if include is None: - include = ["statements", "chunks", "entities", "summaries"] + include = ["statements", "chunks", "entities", "summaries", "communities"] # Clean query cleaned_query = self.clean_query(question) @@ -146,8 +146,8 @@ class SearchService: if search_type == "hybrid": reranked_results = answer.get('reranked_results', {}) - # Priority order: summaries first (most contextual), then statements, chunks, entities - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + # Priority order: summaries first (most contextual), then communities, statements, chunks, entities + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in reranked_results: @@ -157,13 +157,43 @@ class SearchService: else: # For keyword or embedding search, results are directly in answer dict # Apply same priority order - priority_order = ['summaries', 'statements', 'chunks', 'entities'] + priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] for category in priority_order: if category in include and category in answer: category_results = answer[category] if isinstance(category_results, list): answer_list.extend(category_results) + + # 对命中的 community 节点展开其成员 statements + if "communities" in include: + community_results = ( + answer.get('reranked_results', {}).get('communities', []) + if search_type == "hybrid" + else answer.get('communities', []) + ) + community_ids = [ + r.get("id") for r in community_results if r.get("id") + ] + if community_ids and end_user_id: + try: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + connector = Neo4jConnector() + expand_result = await search_graph_community_expand( + connector=connector, + community_ids=community_ids, + end_user_id=end_user_id, + limit=10, + ) + await connector.close() + expanded_stmts = expand_result.get("expanded_statements", []) + if expanded_stmts: + # 展开的 statements 插入 communities 之后、statements 之前 + answer_list.extend(expanded_stmts) + logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") + except Exception as e: + logger.warning(f"社区展开检索失败,跳过: {e}") # Extract clean content from all results content_list = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index b3707083..02aa1b44 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -171,6 +171,13 @@ async def write( ) if success: logger.info("Successfully saved all data to Neo4j") + # 写入成功后,异步触发聚类(不阻塞写入响应) + schedule_clustering_after_write( + all_entity_nodes, + config_id=config_id, + llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, + embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 0e1d8424..3570d707 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -238,7 +238,7 @@ def rerank_with_activation( reranked: Dict[str, List[Dict[str, Any]]] = {} - for category in ["statements", "chunks", "entities", "summaries"]: + for category in ["statements", "chunks", "entities", "summaries", "communities"]: keyword_items = keyword_results.get(category, []) embedding_items = embedding_results.get(category, []) @@ -281,21 +281,23 @@ def rerank_with_activation( for item in items_list: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") if item_id and item_id in combined_items: - combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) + combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") # 步骤 4: 计算基础分数和最终分数 for item_id, item in combined_items.items(): bm25_norm = float(item.get("bm25_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0) - act_norm = float(item.get("normalized_activation_value", 0) or 0) + # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 + raw_act_norm = item.get("normalized_activation_value") + act_norm = float(raw_act_norm) if raw_act_norm is not None else None # 第一阶段:只考虑内容相关性(BM25 + Embedding) # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 content_score = alpha * bm25_norm + (1 - alpha) * emb_norm base_score = content_score # 第一阶段用内容分数 - # 存储激活度分数供第二阶段使用 - item["activation_score"] = act_norm + # 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序) + item["activation_score"] = act_norm # 可能为 None item["content_score"] = content_score item["base_score"] = base_score diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index cbc303b1..b4a16734 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -19,8 +19,9 @@ logger = logging.getLogger(__name__) # 全量迭代最大轮数,防止不收敛 MAX_ITERATIONS = 10 -# 社区摘要核心实体数量 -CORE_ENTITY_LIMIT = 5 + +# 社区核心实体取 top-N 数量 +CORE_ENTITY_LIMIT = 10 def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: @@ -69,11 +70,13 @@ class LabelPropagationEngine: connector: Neo4jConnector, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ): self.connector = connector self.repo = CommunityRepository(connector) self.config_id = config_id self.llm_model_id = llm_model_id + self.embedding_model_id = embedding_model_id # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -103,58 +106,85 @@ class LabelPropagationEngine: async def full_clustering(self, end_user_id: str) -> None: """ - 全量标签传播初始化。 + 全量标签传播初始化(分批处理,控制内存峰值)。 - 1. 拉取所有实体,初始化每个实体为独立社区 - 2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 3. 直到标签不再变化或达到 MAX_ITERATIONS - 4. 将最终标签写入 Neo4j + 策略: + - 每次只加载 BATCH_SIZE 个实体及其邻居进内存 + - labels 字典跨批次共享(只存 id→community_id,内存极小) + - 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息 + - 所有批次完成后统一 flush 和 merge """ - entities = await self.repo.get_all_entities(end_user_id) - if not entities: + BATCH_SIZE = 2000 # 每批实体数,可按需调整 + + # 先查总数,决定批次数 + total_entities = await self.repo.get_all_entities(end_user_id) + if not total_entities: logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") return - # 初始化:每个实体持有自己 id 作为社区标签 - labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} - embeddings: Dict[str, Optional[List[float]]] = { - e["id"]: e.get("name_embedding") for e in entities - } + total_count = len(total_entities) + logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体," + f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批") - # 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返 - logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...") - neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id) - logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + # labels 跨批次共享:先用全量数据初始化(只存 id,内存极小) + labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} + # embeddings 也跨批次共享(每个向量 ~6KB,10万实体约 600MB,这是不可避免的) + # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻 + # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding, + # 所以只需要当前批次实体的 embedding,不需要全量 + del total_entities # 释放全量列表,后续按批次加载 - for iteration in range(MAX_ITERATIONS): - changed = 0 - # 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历) - for entity in entities: - eid = entity["id"] - # 直接从缓存取邻居,不再发起 Neo4j 查询 - neighbors = neighbors_cache.get(eid, []) - - # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) - enriched = [] - for nb in neighbors: - nb_copy = dict(nb) - nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) - enriched.append(nb_copy) - - new_label = _weighted_vote(enriched, embeddings.get(eid)) - if new_label and new_label != labels[eid]: - labels[eid] = new_label - changed += 1 - - logger.info( - f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}," - f"标签变化数: {changed}" + for batch_start in range(0, total_count, BATCH_SIZE): + batch_entities = await self.repo.get_entities_page( + end_user_id, skip=batch_start, limit=BATCH_SIZE ) - if changed == 0: - logger.info("[Clustering] 标签已收敛,提前结束迭代") + if not batch_entities: break - # 将最终标签写入 Neo4j + batch_ids = [e["id"] for e in batch_entities] + batch_embeddings: Dict[str, Optional[List[float]]] = { + e["id"]: e.get("name_embedding") for e in batch_entities + } + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:" + f"加载 {len(batch_entities)} 个实体的邻居图..." + ) + neighbors_cache = await self.repo.get_entity_neighbors_for_ids( + batch_ids, end_user_id + ) + logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") + + for iteration in range(MAX_ITERATIONS): + changed = 0 + for entity in batch_entities: + eid = entity["id"] + neighbors = neighbors_cache.get(eid, []) + + # 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值) + enriched = [] + for nb in neighbors: + nb_copy = dict(nb) + nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id")) + enriched.append(nb_copy) + + new_label = _weighted_vote(enriched, batch_embeddings.get(eid)) + if new_label and new_label != labels[eid]: + labels[eid] = new_label + changed += 1 + + logger.info( + f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} " + f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}" + ) + if changed == 0: + logger.info("[Clustering] 标签已收敛,提前结束本批迭代") + break + + # 释放本批次的大对象 + del neighbors_cache, batch_embeddings, batch_entities + + # 所有批次完成,统一写入 Neo4j await self._flush_labels(labels, end_user_id) pre_merge_count = len(set(labels.values())) logger.info( @@ -162,17 +192,16 @@ class LabelPropagationEngine: f"{len(labels)} 个实体,开始后处理合并" ) - # 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度) all_community_ids = list(set(labels.values())) await self._evaluate_merge(all_community_ids, end_user_id) logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) - # 为所有社区生成元数据 - # 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区 - # 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID + + # 查询存活社区并生成元数据 surviving_communities = await self.repo.get_all_entities(end_user_id) surviving_community_ids = list({ e.get("community_id") for e in surviving_communities @@ -421,6 +450,7 @@ class LabelPropagationEngine: - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 + NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。 """ try: members = await self.repo.get_community_members(community_id, end_user_id) @@ -468,16 +498,33 @@ class LabelPropagationEngine: except Exception as e: logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") - await self.repo.update_community_metadata( + # 生成 summary_embedding + summary_embedding = None + if self.embedding_model_id and summary: + try: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + results = await embedder.response([summary]) + summary_embedding = results[0] if results else None + except Exception as e: + logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}") + + result = await self.repo.update_community_metadata( community_id=community_id, end_user_id=end_user_id, name=name, summary=summary, core_entities=core_entities, + summary_embedding=summary_embedding, ) - logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") + if result: + logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...") + else: + logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False") except Exception as e: - logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") + logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True) @staticmethod def _new_community_id() -> str: diff --git a/api/app/main.py b/api/app/main.py index c6256e3c..5314b8b6 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger from app.core.response_utils import fail from app.core.models.scripts.loader import load_models from app.db import get_db_context +from app.repositories.neo4j.index_manager import ensure_indexes # Initialize logging system LoggingConfig.setup_logging() @@ -61,9 +62,18 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") + # 确保 Neo4j 索引存在(幂等,多环境安全) + try: + report = await ensure_indexes() + if report["errors"]: + logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}") + else: + logger.info(f"Neo4j 索引检查完成 [{report['uri']}]") + except Exception as e: + logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}") + logger.info("应用程序启动完成") yield - # 应用关闭事件 logger.info("应用程序正在关闭") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index f2f11f76..bf7fde1d 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -13,12 +13,15 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_LEAVE_ALL_COMMUNITIES, GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, + GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_ENTITY_NEIGHBORS_BATCH, + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, + UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -110,6 +113,41 @@ class CommunityRepository: logger.error(f"get_all_entities failed: {e}") return [] + async def get_entities_page( + self, end_user_id: str, skip: int, limit: int + ) -> List[Dict]: + """分页拉取实体,用于全量聚类分批处理。""" + try: + return await self.connector.execute_query( + GET_ENTITIES_PAGE, + end_user_id=end_user_id, + skip=skip, + limit=limit, + ) + except Exception as e: + logger.error(f"get_entities_page failed: {e}") + return [] + + async def get_entity_neighbors_for_ids( + self, entity_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。""" + try: + rows = await self.connector.execute_query( + GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, + entity_ids=entity_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + eid = row["entity_id"] + neighbor = {k: v for k, v in row.items() if k != "entity_id"} + result.setdefault(eid, []).append(neighbor) + return result + except Exception as e: + logger.error(f"get_entity_neighbors_for_ids failed: {e}") + return {} + async def get_community_members( self, community_id: str, end_user_id: str ) -> List[Dict]: @@ -177,8 +215,9 @@ class CommunityRepository: name: str, summary: str, core_entities: List[str], + summary_embedding: Optional[List[float]] = None, ) -> bool: - """更新社区的名称、摘要和核心实体列表。""" + """更新社区的名称、摘要、核心实体列表和摘要向量。""" try: result = await self.connector.execute_query( UPDATE_COMMUNITY_METADATA, @@ -187,6 +226,7 @@ class CommunityRepository: name=name, summary=summary, core_entities=core_entities, + summary_embedding=summary_embedding, ) return bool(result) except Exception as e: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 48a5ac87..339adb43 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1132,11 +1132,11 @@ ORDER BY coalesce(e.activation_value, 0) DESC GET_ALL_COMMUNITY_MEMBERS_BATCH = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) -WHERE c.community_id IN $community_ids RETURN c.community_id AS community_id, - e.id AS id, - e.name_embedding AS name_embedding, - e.activation_value AS activation_value + e.id AS id, e.name AS name, e.entity_type AS entity_type, + e.importance_score AS importance_score, e.activation_value AS activation_value, + e.name_embedding AS name_embedding +ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC """ CHECK_USER_HAS_COMMUNITIES = """ @@ -1153,13 +1153,47 @@ RETURN c.community_id AS community_id, cnt AS member_count UPDATE_COMMUNITY_METADATA = """ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -SET c.name = $name, - c.summary = $summary, - c.core_entities = $core_entities, - c.updated_at = datetime() +SET c.name = $name, + c.summary = $summary, + c.core_entities = $core_entities, + c.summary_embedding = $summary_embedding, + c.updated_at = datetime() RETURN c.community_id AS community_id """ +GET_ENTITIES_PAGE = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN e.id AS id, + e.name AS name, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +ORDER BY e.id +SKIP $skip LIMIT $limit +""" + +GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """ +// 批量拉取指定实体列表的邻居(用于分批全量聚类) +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +WHERE e.id IN $entity_ids +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id +WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH e, nb WHERE nb IS NOT NULL +OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) +RETURN DISTINCT + e.id AS entity_id, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, + nb.activation_value AS activation_value, + CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id +""" + GET_ALL_ENTITY_NEIGHBORS_BATCH = """ // 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) @@ -1185,20 +1219,59 @@ RETURN DISTINCT CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ -GET_COMMUNITY_GRAPH_DATA = """ -MATCH (c:Community {end_user_id: $end_user_id}) -MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) -OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) -RETURN - elementId(c) AS c_id, - properties(c) AS c_props, - elementId(e) AS e_id, - properties(e) AS e_props, - elementId(b) AS b_id, - elementId(e2) AS e2_id, - properties(e2) AS e2_props, - elementId(r) AS r_id, - type(r) AS r_type, - properties(r) AS r_props, - startNode(r) = e AS r_from_e + +# Community keyword search: matches name or summary via fulltext index +SEARCH_COMMUNITIES_BY_KEYWORD = """ +CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 向量检索 ────────────────────────────────────────────────── +# Community embedding-based search: cosine similarity on Community.summary_embedding +COMMUNITY_EMBEDDING_SEARCH = """ +CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding) +YIELD node AS c, score +WHERE c.summary_embedding IS NOT NULL + AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id) +RETURN c.community_id AS id, + c.name AS name, + c.summary AS content, + c.core_entities AS core_entities, + c.member_count AS member_count, + c.end_user_id AS end_user_id, + c.updated_at AS updated_at, + score +ORDER BY score DESC +LIMIT $limit +""" + +# Community 展开检索 ────────────────────────────────────────────────── +# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索) +EXPAND_COMMUNITY_STATEMENTS = """ +MATCH (c:Community {community_id: $community_id}) +MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c) +MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +WHERE s.end_user_id = $end_user_id +RETURN s.statement AS statement, + s.id AS id, + s.end_user_id AS end_user_id, + s.created_at AS created_at, + s.valid_at AS valid_at, + s.invalid_at AS invalid_at, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + e.name AS source_entity, + c.name AS community_name +ORDER BY COALESCE(s.activation_value, 0) DESC +LIMIT $limit """ diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index cbd2b532..29e337f1 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -1,5 +1,4 @@ import asyncio -import os from typing import List, Optional # 使用新的仓储层 @@ -158,11 +157,12 @@ async def save_dialog_and_statements_to_neo4j( statement_chunk_edges: List[StatementChunkEdge], statement_entity_edges: List[StatementEntityEdge], connector: Neo4jConnector, - config_id: Optional[str] = None, - llm_model_id: Optional[str] = None, ) -> bool: """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. + 只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过 + schedule_clustering_after_write() 显式触发。 + Args: dialogue_nodes: List of DialogueNode objects to save chunk_nodes: List of ChunkNode objects to save @@ -293,9 +293,6 @@ async def save_dialog_and_statements_to_neo4j( logger.info("Transaction completed. Summary: %s", summary) logger.debug("Full transaction results: %r", results) - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id) - return True except Exception as e: @@ -309,6 +306,7 @@ def schedule_clustering_after_write( entity_nodes: List, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 写入 Neo4j 成功后,调度后台聚类任务。 @@ -327,7 +325,7 @@ def schedule_clustering_after_write( end_user_id = entity_nodes[0].end_user_id new_entity_ids = [e.id for e in entity_nodes] logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") - asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) + asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)) async def _trigger_clustering( @@ -335,6 +333,7 @@ async def _trigger_clustering( end_user_id: str, config_id: Optional[str] = None, llm_model_id: Optional[str] = None, + embedding_model_id: Optional[str] = None, ) -> None: """ 聚类触发函数,自动判断全量初始化还是增量更新。 @@ -344,7 +343,7 @@ async def _trigger_clustering( from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") connector = Neo4jConnector() - engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) + engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}") except Exception as e: diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index e8f52535..19e40a82 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional from app.repositories.neo4j.cypher_queries import ( CHUNK_EMBEDDING_SEARCH, + COMMUNITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH, + EXPAND_COMMUNITY_STATEMENTS, MEMORY_SUMMARY_EMBEDDING_SEARCH, SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNKS_BY_CONTENT, + SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_ENTITIES_BY_NAME, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, @@ -285,6 +288,15 @@ async def search_graph( limit=limit, )) task_keys.append("summaries") + + if "communities" in include: + tasks.append(connector.execute_query( + SEARCH_COMMUNITIES_BY_KEYWORD, + q=q, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") # Execute all queries in parallel task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -396,6 +408,16 @@ async def search_graph_by_embedding( )) task_keys.append("summaries") + # Communities (向量语义匹配) + if "communities" in include: + tasks.append(connector.execute_query( + COMMUNITY_EMBEDDING_SEARCH, + embedding=embedding, + end_user_id=end_user_id, + limit=limit, + )) + task_keys.append("communities") + # Execute all queries in parallel query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -408,6 +430,7 @@ async def search_graph_by_embedding( "chunks": [], "entities": [], "summaries": [], + "communities": [], } for key, result in zip(task_keys, task_results): @@ -661,6 +684,62 @@ async def search_graph_by_chunk_id( return {"chunks": chunks} +async def search_graph_community_expand( + connector: Neo4jConnector, + community_ids: List[str], + end_user_id: str, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + 三期:社区展开检索 —— 主题 → 细节两级检索。 + + 命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体, + 再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点, + 按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。 + + Args: + connector: Neo4j 连接器 + community_ids: 已命中的社区 ID 列表 + end_user_id: 用户 ID,用于数据隔离 + limit: 每个社区最多返回的 Statement 数量 + + Returns: + {"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]} + """ + if not community_ids or not end_user_id: + return {"expanded_statements": []} + + tasks = [ + connector.execute_query( + EXPAND_COMMUNITY_STATEMENTS, + community_id=cid, + end_user_id=end_user_id, + limit=limit, + ) + for cid in community_ids + ] + + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + expanded: List[Dict[str, Any]] = [] + for cid, result in zip(community_ids, task_results): + if isinstance(result, Exception): + logger.warning(f"社区展开检索失败 community_id={cid}: {result}") + else: + expanded.extend(result) + + # 按 activation_value 全局排序后去重 + from app.core.memory.src.search import _deduplicate_results + expanded.sort( + key=lambda x: float(x.get("activation_value") or 0), + reverse=True, + ) + expanded = _deduplicate_results(expanded) + + logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}") + return {"expanded_statements": expanded} + + async def search_graph_by_created_at( connector: Neo4jConnector, end_user_id: Optional[str] = None, diff --git a/api/app/repositories/neo4j/index_manager.py b/api/app/repositories/neo4j/index_manager.py new file mode 100644 index 00000000..a1ab6689 --- /dev/null +++ b/api/app/repositories/neo4j/index_manager.py @@ -0,0 +1,254 @@ +# -*- coding: utf-8 -*- +"""Neo4j 索引管理模块 + +负责检查和创建 Neo4j 全文索引与向量索引。 +支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。 + +用法: + # 作为模块调用(应用启动时) + from app.repositories.neo4j.index_manager import ensure_indexes + await ensure_indexes() + + # 作为独立脚本执行(手动建索引) + python -m app.repositories.neo4j.index_manager +""" + +import asyncio +import logging +from dataclasses import dataclass +from typing import List + +from app.core.config import settings +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = logging.getLogger(__name__) + + +# ───────────────────────────────────────────────────────────── +# 索引定义表 +# ───────────────────────────────────────────────────────────── + +@dataclass +class FulltextIndexDef: + name: str + label: str + properties: List[str] + + +@dataclass +class VectorIndexDef: + name: str + label: str + property: str + dimensions: int + similarity: str = "cosine" + + +# 全文索引清单(现有 + 新增 communities) +FULLTEXT_INDEXES: List[FulltextIndexDef] = [ + FulltextIndexDef("statementsFulltext", "Statement", ["statement"]), + FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]), + FulltextIndexDef("chunksFulltext", "Chunk", ["content"]), + FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]), + FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源 +] + +# 向量索引清单(预留 community 二期) +VECTOR_INDEXES: List[VectorIndexDef] = [ + VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536), + VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536), + VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536), + VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536), + # 二期:社区向量索引 + VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536), +] + + +# ───────────────────────────────────────────────────────────── +# 核心检查 / 创建逻辑 +# ───────────────────────────────────────────────────────────── + +async def _get_existing_indexes(connector: Neo4jConnector) -> set: + """查询 Neo4j 中已存在的索引名称集合""" + rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name") + return {row["name"] for row in rows} + + +async def _ensure_fulltext_index( + connector: Neo4jConnector, + idx: FulltextIndexDef, + existing: set, +) -> str: + """检查并按需创建全文索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 全文索引已存在: {idx.name}" + + props = ", ".join(f"n.{p}" for p in idx.properties) + cypher = ( + f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS ' + f'FOR (n:{idx.label}) ON EACH [{props}]' + ) + await connector.execute_query(cypher) + return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})" + + +async def _ensure_vector_index( + connector: Neo4jConnector, + idx: VectorIndexDef, + existing: set, +) -> str: + """检查并按需创建向量索引,返回操作状态描述""" + if idx.name in existing: + return f"[SKIP] 向量索引已存在: {idx.name}" + + cypher = ( + f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS " + f"FOR (n:{idx.label}) ON n.{idx.property} " + f"OPTIONS {{indexConfig: {{" + f"`vector.dimensions`: {idx.dimensions}, " + f"`vector.similarity_function`: '{idx.similarity}'" + f"}}}}" + ) + await connector.execute_query(cypher) + return ( + f"[CREATE] 向量索引已创建: {idx.name} " + f"({idx.label}.{idx.property}, dim={idx.dimensions})" + ) + + +async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。 + + Args: + connector: 可选,传入已有连接器;为 None 时自动创建。 + + Returns: + dict: { + "uri": 当前连接的 Neo4j URI, + "fulltext": [操作日志列表], + "vector": [操作日志列表], + "errors": [错误信息列表], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + report = { + "uri": settings.NEO4J_URI, + "fulltext": [], + "vector": [], + "errors": [], + } + + try: + # 一次性拉取所有已有索引名 + existing = await _get_existing_indexes(connector) + logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}") + logger.info(f"[IndexManager] 已有索引数量: {len(existing)}") + + # 处理全文索引 + for idx in FULLTEXT_INDEXES: + try: + msg = await _ensure_fulltext_index(connector, idx, existing) + report["fulltext"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + # 处理向量索引 + for idx in VECTOR_INDEXES: + try: + msg = await _ensure_vector_index(connector, idx, existing) + report["vector"].append(msg) + logger.info(f"[IndexManager] {msg}") + except Exception as e: + err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}" + report["errors"].append(err) + logger.error(f"[IndexManager] {err}") + + finally: + if own_connector: + await connector.close() + + return report + + +async def check_indexes(connector: Neo4jConnector | None = None) -> dict: + """ + 仅检查索引状态,不创建任何索引。 + + Returns: + dict: { + "uri": ..., + "present": [已存在的索引名], + "missing_fulltext": [缺失的全文索引名], + "missing_vector": [缺失的向量索引名], + } + """ + own_connector = connector is None + if own_connector: + connector = Neo4jConnector() + + try: + existing = await _get_existing_indexes(connector) + missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing] + missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing] + + return { + "uri": settings.NEO4J_URI, + "present": sorted(existing), + "missing_fulltext": missing_ft, + "missing_vector": missing_vec, + } + finally: + if own_connector: + await connector.close() + + +# ───────────────────────────────────────────────────────────── +# 独立脚本入口 +# ───────────────────────────────────────────────────────────── + +async def _main(): + import sys + + print(f"\n{'='*60}") + print(f"Neo4j 索引管理工具") + print(f"环境: {settings.NEO4J_URI}") + print(f"{'='*60}\n") + + # 先检查 + print(">>> 检查当前索引状态...\n") + status = await check_indexes() + print(f" 已存在索引数: {len(status['present'])}") + if status["missing_fulltext"]: + print(f" 缺失全文索引: {status['missing_fulltext']}") + if status["missing_vector"]: + print(f" 缺失向量索引: {status['missing_vector']}") + + if not status["missing_fulltext"] and not status["missing_vector"]: + print("\n 所有索引均已存在,无需操作。") + return + + # 再创建 + print("\n>>> 开始创建缺失索引...\n") + report = await ensure_indexes() + + for msg in report["fulltext"] + report["vector"]: + print(f" {msg}") + + if report["errors"]: + print("\n[!] 以下索引创建失败:") + for err in report["errors"]: + print(f" {err}") + sys.exit(1) + else: + print("\n 全部索引处理完成。") + + +if __name__ == "__main__": + asyncio.run(_main()) diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index c3bbc693..89053e48 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit c3bbc6931c570e6fac88c0b00658b4f08dc2ac77 +Subproject commit 89053e48e932332d2a0f17760034ee2bce75ea43 From f32d92b9d09eca24bfc755ed084d8bc838d18428 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 14:05:12 +0800 Subject: [PATCH 09/14] [Changes] --- .../memory/agent/services/search_service.py | 12 ++++----- .../core/memory/agent/utils/write_tools.py | 7 ----- .../clustering_engine/label_propagation.py | 18 +++++-------- .../neo4j/community_repository.py | 27 ++++++++++++++++++- api/app/repositories/neo4j/cypher_queries.py | 10 +++++++ 5 files changed, 49 insertions(+), 25 deletions(-) diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 2be18c97..c9346c16 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -176,24 +176,24 @@ class SearchService: r.get("id") for r in community_results if r.get("id") ] if community_ids and end_user_id: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + expand_connector = Neo4jConnector() try: - from app.repositories.neo4j.graph_search import search_graph_community_expand - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - connector = Neo4jConnector() expand_result = await search_graph_community_expand( - connector=connector, + connector=expand_connector, community_ids=community_ids, end_user_id=end_user_id, limit=10, ) - await connector.close() expanded_stmts = expand_result.get("expanded_statements", []) if expanded_stmts: - # 展开的 statements 插入 communities 之后、statements 之前 answer_list.extend(expanded_stmts) logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") except Exception as e: logger.warning(f"社区展开检索失败,跳过: {e}") + finally: + await expand_connector.close() # Extract clean content from all results content_list = [ diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index df8752d6..02aa1b44 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -178,13 +178,6 @@ async def write( llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, ) - # 写入成功后,异步触发聚类(不阻塞写入响应) - schedule_clustering_after_write( - all_entity_nodes, - config_id=config_id, - llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, - ) break else: logger.warning("Failed to save some data to Neo4j") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index b4a16734..46a7b8f3 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -116,23 +116,19 @@ class LabelPropagationEngine: """ BATCH_SIZE = 2000 # 每批实体数,可按需调整 - # 先查总数,决定批次数 - total_entities = await self.repo.get_all_entities(end_user_id) - if not total_entities: + # 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段 + total_count = await self.repo.get_entity_count(end_user_id) + if not total_count: logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") return - total_count = len(total_entities) + all_entity_ids = await self.repo.get_all_entity_ids(end_user_id) logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体," f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批") - # labels 跨批次共享:先用全量数据初始化(只存 id,内存极小) - labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities} - # embeddings 也跨批次共享(每个向量 ~6KB,10万实体约 600MB,这是不可避免的) - # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻 - # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding, - # 所以只需要当前批次实体的 embedding,不需要全量 - del total_entities # 释放全量列表,后续按批次加载 + # labels 跨批次共享:只存 id→community_id,内存极小 + labels: Dict[str, str] = {eid: eid for eid in all_entity_ids} + del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据 for batch_start in range(0, total_count, BATCH_SIZE): batch_entities = await self.repo.get_entities_page( diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index bf7fde1d..267ced4f 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -13,6 +13,8 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_LEAVE_ALL_COMMUNITIES, GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, + GET_ENTITY_COUNT_FOR_USER, + GET_ALL_ENTITY_IDS_FOR_USER, GET_ENTITIES_PAGE, GET_COMMUNITY_MEMBERS, GET_ALL_COMMUNITY_MEMBERS_BATCH, @@ -21,7 +23,6 @@ from app.repositories.neo4j.cypher_queries import ( CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, - UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -113,6 +114,30 @@ class CommunityRepository: logger.error(f"get_all_entities failed: {e}") return [] + async def get_entity_count(self, end_user_id: str) -> int: + """仅返回用户实体总数,不加载实体数据。""" + try: + result = await self.connector.execute_query( + GET_ENTITY_COUNT_FOR_USER, + end_user_id=end_user_id, + ) + return result[0]["entity_count"] if result else 0 + except Exception as e: + logger.error(f"get_entity_count failed: {e}") + return 0 + + async def get_all_entity_ids(self, end_user_id: str) -> List[str]: + """仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。""" + try: + result = await self.connector.execute_query( + GET_ALL_ENTITY_IDS_FOR_USER, + end_user_id=end_user_id, + ) + return [r["id"] for r in result] + except Exception as e: + logger.error(f"get_all_entity_ids failed: {e}") + return [] + async def get_entities_page( self, end_user_id: str, skip: int, limit: int ) -> List[Dict]: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 339adb43..01f2fa6a 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1122,6 +1122,16 @@ RETURN e.id AS id, CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ +GET_ENTITY_COUNT_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +RETURN count(e) AS entity_count +""" + +GET_ALL_ENTITY_IDS_FOR_USER = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) +RETURN e.id AS id +""" + GET_COMMUNITY_MEMBERS = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, From b8e85bed61f4f7f6c18818033dab4170f19e1922 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 14:47:57 +0800 Subject: [PATCH 10/14] [changes] Remove FileType and break the import loop --- api/app/models/memory_perceptual_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py index 9fed7c5d..7610b79f 100644 --- a/api/app/models/memory_perceptual_model.py +++ b/api/app/models/memory_perceptual_model.py @@ -7,7 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB from app.db import Base -from app.schemas import FileType +from app.schemas.app_schema import FileType class PerceptualType(IntEnum): From 19d149c129b43bc16c6b8ff1ea834ca797784f23 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 14:55:25 +0800 Subject: [PATCH 11/14] [add] Remove redundant logs --- .../storage_services/clustering_engine/label_propagation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 46a7b8f3..37376093 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -192,7 +192,6 @@ class LabelPropagationEngine: await self._evaluate_merge(all_community_ids, end_user_id) logger.info( - f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) From 8e6288bca8f038bfd5b44e713a5595afd9e63ee9 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 18:38:59 +0800 Subject: [PATCH 12/14] [changes] Change the same reference --- api/app/models/memory_perceptual_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/api/app/models/memory_perceptual_model.py b/api/app/models/memory_perceptual_model.py index 7610b79f..ae8cc1bd 100644 --- a/api/app/models/memory_perceptual_model.py +++ b/api/app/models/memory_perceptual_model.py @@ -7,8 +7,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import JSONB from app.db import Base -from app.schemas.app_schema import FileType - +from app.schemas import FileType class PerceptualType(IntEnum): VISION = 1 From 56adca9f22880cf614d7c0424a0835c276b9e0ed Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 23:06:41 +0800 Subject: [PATCH 13/14] [changes] Batch mode for metadata creation and unified management of indexes --- .../clustering_engine/label_propagation.py | 165 +++++++----- api/app/main.py | 11 - .../neo4j/community_repository.py | 23 ++ api/app/repositories/neo4j/create_indexes.py | 19 ++ api/app/repositories/neo4j/cypher_queries.py | 17 +- api/app/repositories/neo4j/index_manager.py | 254 ------------------ 6 files changed, 156 insertions(+), 333 deletions(-) delete mode 100644 api/app/repositories/neo4j/index_manager.py diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 37376093..21257f2e 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -7,6 +7,7 @@ - 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居 """ +import asyncio import logging import uuid from math import sqrt @@ -114,7 +115,7 @@ class LabelPropagationEngine: - 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息 - 所有批次完成后统一 flush 和 merge """ - BATCH_SIZE = 2000 # 每批实体数,可按需调整 + BATCH_SIZE = 888 # 每批实体数,可按需调整 # 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段 total_count = await self.repo.get_entity_count(end_user_id) @@ -203,8 +204,7 @@ class LabelPropagationEngine: if e.get("community_id") }) logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") - for cid in surviving_community_ids: - await self._generate_community_metadata(cid, end_user_id) + await self._generate_community_metadata(surviving_community_ids, end_user_id) async def incremental_update( self, new_entity_ids: List[str], end_user_id: str @@ -261,7 +261,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) - await self._generate_community_metadata(new_cid, end_user_id) + await self._generate_community_metadata([new_cid], end_user_id) else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -273,7 +273,7 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - await self._generate_community_metadata(target_cid, end_user_id) + await self._generate_community_metadata([target_cid], end_user_id) async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -437,89 +437,122 @@ class LabelPropagationEngine: except Exception: return None + @staticmethod + def _build_entity_lines(members: List[Dict]) -> List[str]: + """将实体列表格式化为 prompt 行,包含 name、aliases、description。""" + lines = [] + for m in members: + m_name = m.get("name", "") + aliases = m.get("aliases") or [] + description = m.get("description") or "" + aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else "" + desc_str = f":{description}" if description else "" + lines.append(f"- {m_name}{aliases_str}{desc_str}") + return lines + async def _generate_community_metadata( - self, community_id: str, end_user_id: str + self, community_ids: List[str], end_user_id: str ) -> None: """ - 为社区生成并写入元数据:名称、摘要、核心实体。 + 为一个或多个社区生成并写入元数据。 - - core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM) - - name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 - NOTE: core_entities按照激活值高低排序,会造成对边缘信息检索返回消息质量不高。 + 流程: + 1. 逐个社区调 LLM 生成 name / summary(串行) + 2. 收集所有 summary,一次性批量 embed + 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata """ - try: - members = await self.repo.get_community_members(community_id, end_user_id) - if not members: - return + if not community_ids: + return + + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # --- 阶段1:并发调 LLM 生成每个社区的 name / summary --- + async def _build_one(cid: str): + members = await self.repo.get_community_members(cid, end_user_id) + if not members: + return None - # 核心实体:按 activation_value 降序取 top-N sorted_members = sorted( members, key=lambda m: m.get("activation_value") or 0, reverse=True, ) core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] - all_names = [m["name"] for m in members if m.get("name")] - name = "、".join(core_entities[:3]) if core_entities else community_id[:8] - summary = f"包含实体:{', '.join(all_names)}" + entity_list_str = "\n".join(self._build_entity_lines(members)) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 写一句话摘要(不超过50个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) + with get_db_context() as db: + llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + response = await llm_client.chat([{"role": "user", "content": prompt}]) + text = response.content if hasattr(response, "content") else str(response) - # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 - if self.llm_model_id: - try: - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + name, summary = "", "" + for line in text.strip().splitlines(): + if line.startswith("名称:"): + name = line[3:].strip() + elif line.startswith("摘要:"): + summary = line[3:].strip() - entity_list_str = "、".join(all_names) - prompt = ( - f"以下是一组语义相关的实体:{entity_list_str}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 写一句话摘要(不超过50个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) + return { + "community_id": cid, + "end_user_id": end_user_id, + "name": name, + "summary": summary, + "core_entities": core_entities, + "summary_embedding": None, + } - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - except Exception as e: - logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") + results = await asyncio.gather( + *[_build_one(cid) for cid in community_ids], + return_exceptions=True, + ) + metadata_list = [] + for cid, res in zip(community_ids, results): + if isinstance(res, Exception): + logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res) + elif res is not None: + metadata_list.append(res) - # 生成 summary_embedding - summary_embedding = None - if self.embedding_model_id and summary: - try: - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - results = await embedder.response([summary]) - summary_embedding = results[0] if results else None - except Exception as e: - logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}") + if not metadata_list: + return + # --- 阶段2:批量生成 summary_embedding --- + summaries = [m["summary"] for m in metadata_list] + with get_db_context() as db: + embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + + # --- 阶段3:写入(单个 or 批量)--- + if len(metadata_list) == 1: + m = metadata_list[0] result = await self.repo.update_community_metadata( - community_id=community_id, - end_user_id=end_user_id, - name=name, - summary=summary, - core_entities=core_entities, - summary_embedding=summary_embedding, + community_id=m["community_id"], + end_user_id=m["end_user_id"], + name=m["name"], + summary=m["summary"], + core_entities=m["core_entities"], + summary_embedding=m["summary_embedding"], ) if result: - logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...") + logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...") else: - logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False") - except Exception as e: - logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True) + logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False") + else: + ok = await self.repo.batch_update_community_metadata(metadata_list) + if ok: + logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") + else: + logger.warning(f"[Clustering] 批量写入社区元数据失败") @staticmethod def _new_community_id() -> str: diff --git a/api/app/main.py b/api/app/main.py index 5314b8b6..c794f48a 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -18,7 +18,6 @@ from app.core.logging_config import LoggingConfig, get_logger from app.core.response_utils import fail from app.core.models.scripts.loader import load_models from app.db import get_db_context -from app.repositories.neo4j.index_manager import ensure_indexes # Initialize logging system LoggingConfig.setup_logging() @@ -62,16 +61,6 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") - # 确保 Neo4j 索引存在(幂等,多环境安全) - try: - report = await ensure_indexes() - if report["errors"]: - logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}") - else: - logger.info(f"Neo4j 索引检查完成 [{report['uri']}]") - except Exception as e: - logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}") - logger.info("应用程序启动完成") yield logger.info("应用程序正在关闭") diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 267ced4f..f9c4bd92 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -23,6 +23,7 @@ from app.repositories.neo4j.cypher_queries import ( CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_METADATA, + BATCH_UPDATE_COMMUNITY_METADATA, ) logger = logging.getLogger(__name__) @@ -257,3 +258,25 @@ class CommunityRepository: except Exception as e: logger.error(f"update_community_metadata failed: {e}") return False + + async def batch_update_community_metadata( + self, + communities: List[Dict], + ) -> bool: + """批量更新多个社区的元数据。 + + Args: + communities: 每项包含 community_id, end_user_id, name, summary, + core_entities, summary_embedding + """ + if not communities: + return True + try: + await self.connector.execute_query( + BATCH_UPDATE_COMMUNITY_METADATA, + communities=communities, + ) + return True + except Exception as e: + logger.error(f"batch_update_community_metadata failed: {e}") + return False diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 55dead1b..29f60fdd 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -42,6 +42,13 @@ async def create_fulltext_indexes(): OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) print("✓ Created: summariesFulltext") + + # 创建 Community 索引 + await connector.execute_query(""" + CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] + OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } + """) + print("✓ Created: communitiesFulltext") print("\nFull-text indexes created successfully with BM25 support.") except Exception as e: @@ -124,6 +131,18 @@ async def create_vector_indexes(): }} """) print("✓ Created: dialogue_embedding_index") + + # Community summary embedding index + await connector.execute_query(""" + CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS + FOR (c:Community) + ON c.summary_embedding + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """) + print("✓ Created: community_summary_embedding_index") print("\nVector indexes created successfully!") print("\nExpected performance improvement:") diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 01f2fa6a..16a26b3b 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1136,7 +1136,8 @@ GET_COMMUNITY_MEMBERS = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, e.importance_score AS importance_score, e.activation_value AS activation_value, - e.name_embedding AS name_embedding + e.name_embedding AS name_embedding, + e.aliases AS aliases, e.description AS description ORDER BY coalesce(e.activation_value, 0) DESC """ @@ -1145,7 +1146,8 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->( RETURN c.community_id AS community_id, e.id AS id, e.name AS name, e.entity_type AS entity_type, e.importance_score AS importance_score, e.activation_value AS activation_value, - e.name_embedding AS name_embedding + e.name_embedding AS name_embedding, + e.aliases AS aliases, e.description AS description ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC """ @@ -1171,6 +1173,17 @@ SET c.name = $name, RETURN c.community_id AS community_id """ +BATCH_UPDATE_COMMUNITY_METADATA = """ +UNWIND $communities AS row +MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id}) +SET c.name = row.name, + c.summary = row.summary, + c.core_entities = row.core_entities, + c.summary_embedding = row.summary_embedding, + c.updated_at = datetime() +RETURN c.community_id AS community_id +""" + GET_ENTITIES_PAGE = """ MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community) diff --git a/api/app/repositories/neo4j/index_manager.py b/api/app/repositories/neo4j/index_manager.py deleted file mode 100644 index a1ab6689..00000000 --- a/api/app/repositories/neo4j/index_manager.py +++ /dev/null @@ -1,254 +0,0 @@ -# -*- coding: utf-8 -*- -"""Neo4j 索引管理模块 - -负责检查和创建 Neo4j 全文索引与向量索引。 -支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。 - -用法: - # 作为模块调用(应用启动时) - from app.repositories.neo4j.index_manager import ensure_indexes - await ensure_indexes() - - # 作为独立脚本执行(手动建索引) - python -m app.repositories.neo4j.index_manager -""" - -import asyncio -import logging -from dataclasses import dataclass -from typing import List - -from app.core.config import settings -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -logger = logging.getLogger(__name__) - - -# ───────────────────────────────────────────────────────────── -# 索引定义表 -# ───────────────────────────────────────────────────────────── - -@dataclass -class FulltextIndexDef: - name: str - label: str - properties: List[str] - - -@dataclass -class VectorIndexDef: - name: str - label: str - property: str - dimensions: int - similarity: str = "cosine" - - -# 全文索引清单(现有 + 新增 communities) -FULLTEXT_INDEXES: List[FulltextIndexDef] = [ - FulltextIndexDef("statementsFulltext", "Statement", ["statement"]), - FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]), - FulltextIndexDef("chunksFulltext", "Chunk", ["content"]), - FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]), - FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]), # 第五检索源 -] - -# 向量索引清单(预留 community 二期) -VECTOR_INDEXES: List[VectorIndexDef] = [ - VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536), - VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536), - VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536), - VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536), - # 二期:社区向量索引 - VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536), -] - - -# ───────────────────────────────────────────────────────────── -# 核心检查 / 创建逻辑 -# ───────────────────────────────────────────────────────────── - -async def _get_existing_indexes(connector: Neo4jConnector) -> set: - """查询 Neo4j 中已存在的索引名称集合""" - rows = await connector.execute_query("SHOW INDEXES YIELD name RETURN name") - return {row["name"] for row in rows} - - -async def _ensure_fulltext_index( - connector: Neo4jConnector, - idx: FulltextIndexDef, - existing: set, -) -> str: - """检查并按需创建全文索引,返回操作状态描述""" - if idx.name in existing: - return f"[SKIP] 全文索引已存在: {idx.name}" - - props = ", ".join(f"n.{p}" for p in idx.properties) - cypher = ( - f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS ' - f'FOR (n:{idx.label}) ON EACH [{props}]' - ) - await connector.execute_query(cypher) - return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label} → {idx.properties})" - - -async def _ensure_vector_index( - connector: Neo4jConnector, - idx: VectorIndexDef, - existing: set, -) -> str: - """检查并按需创建向量索引,返回操作状态描述""" - if idx.name in existing: - return f"[SKIP] 向量索引已存在: {idx.name}" - - cypher = ( - f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS " - f"FOR (n:{idx.label}) ON n.{idx.property} " - f"OPTIONS {{indexConfig: {{" - f"`vector.dimensions`: {idx.dimensions}, " - f"`vector.similarity_function`: '{idx.similarity}'" - f"}}}}" - ) - await connector.execute_query(cypher) - return ( - f"[CREATE] 向量索引已创建: {idx.name} " - f"({idx.label}.{idx.property}, dim={idx.dimensions})" - ) - - -async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict: - """ - 检查并创建所有必要的 Neo4j 索引(幂等,可重复调用)。 - - Args: - connector: 可选,传入已有连接器;为 None 时自动创建。 - - Returns: - dict: { - "uri": 当前连接的 Neo4j URI, - "fulltext": [操作日志列表], - "vector": [操作日志列表], - "errors": [错误信息列表], - } - """ - own_connector = connector is None - if own_connector: - connector = Neo4jConnector() - - report = { - "uri": settings.NEO4J_URI, - "fulltext": [], - "vector": [], - "errors": [], - } - - try: - # 一次性拉取所有已有索引名 - existing = await _get_existing_indexes(connector) - logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}") - logger.info(f"[IndexManager] 已有索引数量: {len(existing)}") - - # 处理全文索引 - for idx in FULLTEXT_INDEXES: - try: - msg = await _ensure_fulltext_index(connector, idx, existing) - report["fulltext"].append(msg) - logger.info(f"[IndexManager] {msg}") - except Exception as e: - err = f"[ERROR] 全文索引 {idx.name} 创建失败: {e}" - report["errors"].append(err) - logger.error(f"[IndexManager] {err}") - - # 处理向量索引 - for idx in VECTOR_INDEXES: - try: - msg = await _ensure_vector_index(connector, idx, existing) - report["vector"].append(msg) - logger.info(f"[IndexManager] {msg}") - except Exception as e: - err = f"[ERROR] 向量索引 {idx.name} 创建失败: {e}" - report["errors"].append(err) - logger.error(f"[IndexManager] {err}") - - finally: - if own_connector: - await connector.close() - - return report - - -async def check_indexes(connector: Neo4jConnector | None = None) -> dict: - """ - 仅检查索引状态,不创建任何索引。 - - Returns: - dict: { - "uri": ..., - "present": [已存在的索引名], - "missing_fulltext": [缺失的全文索引名], - "missing_vector": [缺失的向量索引名], - } - """ - own_connector = connector is None - if own_connector: - connector = Neo4jConnector() - - try: - existing = await _get_existing_indexes(connector) - missing_ft = [i.name for i in FULLTEXT_INDEXES if i.name not in existing] - missing_vec = [i.name for i in VECTOR_INDEXES if i.name not in existing] - - return { - "uri": settings.NEO4J_URI, - "present": sorted(existing), - "missing_fulltext": missing_ft, - "missing_vector": missing_vec, - } - finally: - if own_connector: - await connector.close() - - -# ───────────────────────────────────────────────────────────── -# 独立脚本入口 -# ───────────────────────────────────────────────────────────── - -async def _main(): - import sys - - print(f"\n{'='*60}") - print(f"Neo4j 索引管理工具") - print(f"环境: {settings.NEO4J_URI}") - print(f"{'='*60}\n") - - # 先检查 - print(">>> 检查当前索引状态...\n") - status = await check_indexes() - print(f" 已存在索引数: {len(status['present'])}") - if status["missing_fulltext"]: - print(f" 缺失全文索引: {status['missing_fulltext']}") - if status["missing_vector"]: - print(f" 缺失向量索引: {status['missing_vector']}") - - if not status["missing_fulltext"] and not status["missing_vector"]: - print("\n 所有索引均已存在,无需操作。") - return - - # 再创建 - print("\n>>> 开始创建缺失索引...\n") - report = await ensure_indexes() - - for msg in report["fulltext"] + report["vector"]: - print(f" {msg}") - - if report["errors"]: - print("\n[!] 以下索引创建失败:") - for err in report["errors"]: - print(f" {err}") - sys.exit(1) - else: - print("\n 全部索引处理完成。") - - -if __name__ == "__main__": - asyncio.run(_main()) From 5df339b56d292403e84279e9ff8fb28a792377a9 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Mon, 16 Mar 2026 23:09:09 +0800 Subject: [PATCH 14/14] [changes] recovery log --- api/app/main.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/app/main.py b/api/app/main.py index c794f48a..c6256e3c 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -63,6 +63,7 @@ async def lifespan(app: FastAPI): logger.info("应用程序启动完成") yield + # 应用关闭事件 logger.info("应用程序正在关闭")