[changes] Initial stage of community integration
This commit is contained in:
@@ -141,8 +141,18 @@ class LabelPropagationEngine:
|
||||
|
||||
# 将最终标签写入 Neo4j
|
||||
await self._flush_labels(labels, end_user_id)
|
||||
pre_merge_count = len(set(labels.values()))
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区,"
|
||||
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体,开始后处理合并"
|
||||
)
|
||||
|
||||
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
|
||||
all_community_ids = list(set(labels.values()))
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
|
||||
@@ -221,30 +231,50 @@ class LabelPropagationEngine:
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.75
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
# 计算社区成员 embedding 的平均向量
|
||||
valid_embeddings = [
|
||||
m["name_embedding"]
|
||||
for m in members
|
||||
if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
avg = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
community_embeddings[cid] = avg
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
@@ -258,14 +288,32 @@ class LabelPropagationEngine:
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:用 union-find 思路避免重复迁移已被合并的社区
|
||||
# 维护一个 canonical 映射,确保链式合并正确收敛
|
||||
canonical: Dict[str, str] = {cid: cid for cid in cids}
|
||||
|
||||
def find(x: str) -> str:
    """Return the representative (root) community id for ``x``.

    Walks up the ``canonical`` parent mapping until it reaches an id that
    maps to itself. While walking, each visited id is re-pointed at its
    grandparent, so later lookups follow a shorter chain.
    """
    parent = canonical[x]
    while parent != x:
        # Skip one level: point x at its grandparent, then continue from there.
        grandparent = canonical[parent]
        canonical[x] = grandparent
        x = grandparent
        parent = canonical[x]
    return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2
|
||||
dissolve = c2 if keep == c1 else c1
|
||||
root1, root2 = find(c1), find(c2)
|
||||
if root1 == root2:
|
||||
continue # 已经在同一社区,跳过
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
canonical[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(
|
||||
m["id"], keep, end_user_id
|
||||
)
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
# 更新 sizes 以便后续合并决策准确
|
||||
community_sizes[keep] = community_sizes.get(keep, 0) + len(members)
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
|
||||
Reference in New Issue
Block a user