Merge branch 'feature/node-aggregation' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/node-aggregation
This commit is contained in:
@@ -196,6 +196,7 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
@@ -265,6 +266,7 @@ class LabelPropagationEngine:
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata(new_cid, end_user_id)
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
@@ -276,6 +278,7 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata(target_cid, end_user_id)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -285,30 +288,50 @@ class LabelPropagationEngine:
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.75
|
||||
MERGE_THRESHOLD = 0.85
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
# 计算社区成员 embedding 的平均向量
|
||||
valid_embeddings = [
|
||||
m["name_embedding"]
|
||||
for m in members
|
||||
if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
avg = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
community_embeddings[cid] = avg
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
@@ -322,18 +345,67 @@ class LabelPropagationEngine:
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
|
||||
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
|
||||
# (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并)
|
||||
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
|
||||
|
||||
def get_root(x: str) -> str:
|
||||
"""路径压缩,找到 x 当前所属的根社区。"""
|
||||
while x in merged_into:
|
||||
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
|
||||
x = merged_into[x]
|
||||
return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2
|
||||
dissolve = c2 if keep == c1 else c1
|
||||
root1, root2 = get_root(c1), get_root(c2)
|
||||
if root1 == root2:
|
||||
continue
|
||||
|
||||
# 用合并后的最新平均向量重新验证相似度
|
||||
# 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并
|
||||
current_sim = _cosine_similarity(
|
||||
community_embeddings.get(root1),
|
||||
community_embeddings.get(root2),
|
||||
)
|
||||
if current_sim <= MERGE_THRESHOLD:
|
||||
# 合并后向量已漂移,不再满足阈值,跳过
|
||||
logger.debug(
|
||||
f"[Clustering] 跳过合并 {root1} ↔ {root2},"
|
||||
f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
merged_into[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(
|
||||
m["id"], keep, end_user_id
|
||||
)
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
|
||||
# 合并后重新计算 keep 的平均向量(加权平均)
|
||||
keep_emb = community_embeddings.get(keep)
|
||||
dissolve_emb = community_embeddings.get(dissolve)
|
||||
keep_size = community_sizes.get(keep, 0)
|
||||
dissolve_size = community_sizes.get(dissolve, 0)
|
||||
total_size = keep_size + dissolve_size
|
||||
if keep_emb and dissolve_emb and total_size > 0:
|
||||
dim = len(keep_emb)
|
||||
community_embeddings[keep] = [
|
||||
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
|
||||
for i in range(dim)
|
||||
]
|
||||
community_embeddings[dissolve] = None
|
||||
|
||||
community_sizes[keep] = total_size
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
f"迁移 {len(members)} 个成员"
|
||||
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
|
||||
)
|
||||
|
||||
async def _flush_labels(
|
||||
|
||||
Reference in New Issue
Block a user