[add] Create trigger events for the purpose of completing the existing data

This commit is contained in:
lanceyq
2026-03-13 14:43:29 +08:00
parent f6d929ab7a
commit 6a0ee22d81
4 changed files with 206 additions and 17 deletions

View File

@@ -165,8 +165,15 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
unique_communities = list(set(labels.values()))
for cid in unique_communities:
# 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({
e.get("community_id") for e in surviving_communities
if e.get("community_id")
})
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update(
@@ -249,7 +256,7 @@ class LabelPropagationEngine:
全量场景(社区数 > 20使用批量查询避免 N 次数据库往返。
"""
MERGE_THRESHOLD = 0.75
MERGE_THRESHOLD = 0.85
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
community_embeddings: Dict[str, Optional[List[float]]] = {}
@@ -305,34 +312,65 @@ class LabelPropagationEngine:
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
# 执行合并:用 union-find 思路避免重复迁移已被合并社区
# 维护一个 canonical 映射,确保链式合并正确收敛
canonical: Dict[str, str] = {cid: cid for cid in cids}
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
# A≈B、B≈C 不代表 A≈C不能因传递性把 A/B/C 全部合并)
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
def find(x: str) -> str:
while canonical[x] != x:
canonical[x] = canonical[canonical[x]]
x = canonical[x]
def get_root(x: str) -> str:
"""路径压缩,找到 x 当前所属的根社区。"""
while x in merged_into:
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
x = merged_into[x]
return x
for c1, c2 in to_merge:
root1, root2 = find(c1), find(c2)
root1, root2 = get_root(c1), get_root(c2)
if root1 == root2:
continue # 已经在同一社区,跳过
continue
# 用合并后的最新平均向量重新验证相似度
# 防止链式传递A≈B 合并后 B 的向量已更新C 必须和新 B 相似才能合并
current_sim = _cosine_similarity(
community_embeddings.get(root1),
community_embeddings.get(root2),
)
if current_sim <= MERGE_THRESHOLD:
# 合并后向量已漂移,不再满足阈值,跳过
logger.debug(
f"[Clustering] 跳过合并 {root1}{root2}"
f"当前相似度 {current_sim:.3f}{MERGE_THRESHOLD}"
)
continue
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
dissolve = root2 if keep == root1 else root1
canonical[dissolve] = keep
merged_into[dissolve] = keep
members = await self.repo.get_community_members(dissolve, end_user_id)
for m in members:
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
# 更新 sizes 以便后续合并决策准确
community_sizes[keep] = community_sizes.get(keep, 0) + len(members)
# 合并后重新计算 keep 的平均向量(加权平均)
keep_emb = community_embeddings.get(keep)
dissolve_emb = community_embeddings.get(dissolve)
keep_size = community_sizes.get(keep, 0)
dissolve_size = community_sizes.get(dissolve, 0)
total_size = keep_size + dissolve_size
if keep_emb and dissolve_emb and total_size > 0:
dim = len(keep_emb)
community_embeddings[keep] = [
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
for i in range(dim)
]
community_embeddings[dissolve] = None
community_sizes[keep] = total_size
community_sizes[dissolve] = 0
await self.repo.refresh_member_count(keep, end_user_id)
logger.info(
f"[Clustering] 社区合并: {dissolve}{keep}"
f"迁移 {len(members)} 个成员"
f"相似度={current_sim:.3f}迁移 {len(members)} 个成员"
)
async def _flush_labels(