diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 80e238fd..cb6e5804 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -141,8 +141,18 @@ class LabelPropagationEngine: # 将最终标签写入 Neo4j await self._flush_labels(labels, end_user_id) + pre_merge_count = len(set(labels.values())) logger.info( - f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区," + f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区," + f"{len(labels)} 个实体,开始后处理合并" + ) + + # 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度) + all_community_ids = list(set(labels.values())) + await self._evaluate_merge(all_community_ids, end_user_id) + + logger.info( + f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"{len(labels)} 个实体" ) @@ -221,30 +231,50 @@ class LabelPropagationEngine: 策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。 合并时保留成员数最多的社区,其余成员迁移过来。 + + 全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。 """ MERGE_THRESHOLD = 0.75 + BATCH_THRESHOLD = 20 # 超过此数量走批量查询 community_embeddings: Dict[str, Optional[List[float]]] = {} community_sizes: Dict[str, int] = {} - for cid in community_ids: - members = await self.repo.get_community_members(cid, end_user_id) - community_sizes[cid] = len(members) - # 计算社区成员 embedding 的平均向量 - valid_embeddings = [ - m["name_embedding"] - for m in members - if m.get("name_embedding") - ] - if valid_embeddings: - dim = len(valid_embeddings[0]) - avg = [ - sum(e[i] for e in valid_embeddings) / len(valid_embeddings) - for i in range(dim) + if len(community_ids) > BATCH_THRESHOLD: + # 批量查询:一次拉取所有社区成员 + all_members = await self.repo.get_all_community_members_batch( + community_ids, end_user_id + ) + for cid in community_ids: + members = all_members.get(cid, []) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") ] - community_embeddings[cid] = avg - else: - community_embeddings[cid] = None + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None + else: + # 增量场景:逐个查询 + for cid in community_ids: + members = await self.repo.get_community_members(cid, end_user_id) + community_sizes[cid] = len(members) + valid_embeddings = [ + m["name_embedding"] for m in members if m.get("name_embedding") + ] + if valid_embeddings: + dim = len(valid_embeddings[0]) + community_embeddings[cid] = [ + sum(e[i] for e in valid_embeddings) / len(valid_embeddings) + for i in range(dim) + ] + else: + community_embeddings[cid] = None # 找出应合并的社区对 to_merge: List[tuple] = [] @@ -258,14 +288,32 @@ class LabelPropagationEngine: if sim > MERGE_THRESHOLD: to_merge.append((cids[i], cids[j])) + logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区") + + # 执行合并:用 union-find 思路避免重复迁移已被合并的社区 + # 维护一个 canonical 映射,确保链式合并正确收敛 + canonical: Dict[str, str] = {cid: cid for cid in cids} + + def find(x: str) -> str: + while canonical[x] != x: + canonical[x] = canonical[canonical[x]] + x = canonical[x] + return x + for c1, c2 in to_merge: - keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2 - dissolve = c2 if keep == c1 else c1 + root1, root2 = find(c1), find(c2) + if root1 == root2: + continue # 已经在同一社区,跳过 + keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2 + dissolve = root2 if keep == root1 else root1 + canonical[dissolve] = keep + members = await self.repo.get_community_members(dissolve, end_user_id) for m in members: - await self.repo.assign_entity_to_community( - m["id"], keep, end_user_id - ) + await self.repo.assign_entity_to_community(m["id"], keep, end_user_id) + # 更新 sizes 以便后续合并决策准确 + community_sizes[keep] = community_sizes.get(keep, 0) + len(members) + community_sizes[dissolve] = 0 await self.repo.refresh_member_count(keep, end_user_id) logger.info( f"[Clustering] 社区合并: {dissolve} → {keep}," diff --git a/api/app/repositories/neo4j/community_repository.py b/api/app/repositories/neo4j/community_repository.py index 16e30a10..2a1f4f2b 100644 --- a/api/app/repositories/neo4j/community_repository.py +++ b/api/app/repositories/neo4j/community_repository.py @@ -14,6 +14,7 @@ from app.repositories.neo4j.cypher_queries import ( GET_ENTITY_NEIGHBORS, GET_ALL_ENTITIES_FOR_USER, GET_COMMUNITY_MEMBERS, + GET_ALL_COMMUNITY_MEMBERS_BATCH, CHECK_USER_HAS_COMMUNITIES, UPDATE_COMMUNITY_MEMBER_COUNT, ) @@ -101,6 +102,25 @@ class CommunityRepository: logger.error(f"get_community_members failed: {e}") return [] + async def get_all_community_members_batch( + self, community_ids: List[str], end_user_id: str + ) -> Dict[str, List[Dict]]: + """批量查询多个社区的成员,返回 {community_id: [members]} 字典。""" + try: + rows = await self.connector.execute_query( + GET_ALL_COMMUNITY_MEMBERS_BATCH, + community_ids=community_ids, + end_user_id=end_user_id, + ) + result: Dict[str, List[Dict]] = {} + for row in rows: + cid = row["community_id"] + result.setdefault(cid, []).append(row) + return result + except Exception as e: + logger.error(f"get_all_community_members_batch failed: {e}") + return {} + async def has_communities(self, end_user_id: str) -> bool: """检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。""" try: diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 947097a2..84889d65 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1065,26 +1065,6 @@ Graph_Node_query = """ # Community 节点 & BELONGS_TO_COMMUNITY 边 # ============================================================ -COMMUNITY_NODE_SAVE = """ -MERGE (c:Community {community_id: $community_id}) -SET c.end_user_id = $end_user_id, - c.formed_at = $formed_at, - c.updated_at = datetime(), - c.status = $status, - c.member_count = $member_count -RETURN c.community_id AS community_id -""" - -COMMUNITY_ADD_MEMBER = """ -MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) -MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) -MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c) -SET c.updated_at = datetime(), - c.member_count = $member_count -""" - - - # ─── Community 聚类相关 Cypher 模板 ─────────────────────────────────────────── COMMUNITY_NODE_UPSERT = """ @@ -1111,12 +1091,23 @@ DELETE r GET_ENTITY_NEIGHBORS = """ MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id}) -OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边) +OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id}) + +// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边) +OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) +OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id}) +WHERE nb2.id <> e.id + +WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors +UNWIND all_neighbors AS nb +WITH nb WHERE nb IS NOT NULL OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community) RETURN DISTINCT - nb.id AS id, - nb.name AS name, - nb.name_embedding AS name_embedding, + nb.id AS id, + nb.name AS name, + nb.name_embedding AS name_embedding, nb.activation_value AS activation_value, CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id """ @@ -1139,6 +1130,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, ORDER BY coalesce(e.activation_value, 0) DESC """ +GET_ALL_COMMUNITY_MEMBERS_BATCH = """ +MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) +WHERE c.community_id IN $community_ids +RETURN c.community_id AS community_id, + e.id AS id, + e.name_embedding AS name_embedding, + e.activation_value AS activation_value +""" + CHECK_USER_HAS_COMMUNITIES = """ MATCH (c:Community {end_user_id: $end_user_id}) RETURN count(c) AS community_count