Merge branch 'feature/node-aggregation' of github.com:SuanmoSuanyangTechnology/MemoryBear into feature/node-aggregation

This commit is contained in:
lanceyq
2026-03-16 13:11:12 +08:00
7 changed files with 342 additions and 49 deletions

View File

@@ -196,6 +196,7 @@ class LabelPropagationEngine:
await self._evaluate_merge(all_community_ids, end_user_id)
logger.info(
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
@@ -265,6 +266,7 @@ class LabelPropagationEngine:
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata(new_cid, end_user_id)
else:
# 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -276,6 +278,7 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata(target_cid, end_user_id)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -285,30 +288,50 @@ class LabelPropagationEngine:
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
合并时保留成员数最多的社区,其余成员迁移过来。
全量场景(社区数 > 20使用批量查询避免 N 次数据库往返。
"""
MERGE_THRESHOLD = 0.75
MERGE_THRESHOLD = 0.85
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
community_embeddings: Dict[str, Optional[List[float]]] = {}
community_sizes: Dict[str, int] = {}
for cid in community_ids:
members = await self.repo.get_community_members(cid, end_user_id)
community_sizes[cid] = len(members)
# 计算社区成员 embedding 的平均向量
valid_embeddings = [
m["name_embedding"]
for m in members
if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
avg = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
if len(community_ids) > BATCH_THRESHOLD:
# 批量查询:一次拉取所有社区成员
all_members = await self.repo.get_all_community_members_batch(
community_ids, end_user_id
)
for cid in community_ids:
members = all_members.get(cid, [])
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
community_embeddings[cid] = avg
else:
community_embeddings[cid] = None
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
else:
# 增量场景:逐个查询
for cid in community_ids:
members = await self.repo.get_community_members(cid, end_user_id)
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
# 找出应合并的社区对
to_merge: List[tuple] = []
@@ -322,18 +345,67 @@ class LabelPropagationEngine:
if sim > MERGE_THRESHOLD:
to_merge.append((cids[i], cids[j]))
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
# A≈B、B≈C 不代表 A≈C不能因传递性把 A/B/C 全部合并)
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
def get_root(x: str) -> str:
"""路径压缩,找到 x 当前所属的根社区。"""
while x in merged_into:
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
x = merged_into[x]
return x
for c1, c2 in to_merge:
keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2
dissolve = c2 if keep == c1 else c1
root1, root2 = get_root(c1), get_root(c2)
if root1 == root2:
continue
# 用合并后的最新平均向量重新验证相似度
# 防止链式传递A≈B 合并后 B 的向量已更新C 必须和新 B 相似才能合并
current_sim = _cosine_similarity(
community_embeddings.get(root1),
community_embeddings.get(root2),
)
if current_sim <= MERGE_THRESHOLD:
# 合并后向量已漂移,不再满足阈值,跳过
logger.debug(
f"[Clustering] 跳过合并 {root1}{root2}"
f"当前相似度 {current_sim:.3f}{MERGE_THRESHOLD}"
)
continue
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
dissolve = root2 if keep == root1 else root1
merged_into[dissolve] = keep
members = await self.repo.get_community_members(dissolve, end_user_id)
for m in members:
await self.repo.assign_entity_to_community(
m["id"], keep, end_user_id
)
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
# 合并后重新计算 keep 的平均向量(加权平均)
keep_emb = community_embeddings.get(keep)
dissolve_emb = community_embeddings.get(dissolve)
keep_size = community_sizes.get(keep, 0)
dissolve_size = community_sizes.get(dissolve, 0)
total_size = keep_size + dissolve_size
if keep_emb and dissolve_emb and total_size > 0:
dim = len(keep_emb)
community_embeddings[keep] = [
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
for i in range(dim)
]
community_embeddings[dissolve] = None
community_sizes[keep] = total_size
community_sizes[dissolve] = 0
await self.repo.refresh_member_count(keep, end_user_id)
logger.info(
f"[Clustering] 社区合并: {dissolve}{keep}"
f"迁移 {len(members)} 个成员"
f"相似度={current_sim:.3f}迁移 {len(members)} 个成员"
)
async def _flush_labels(