[changes] Initial stage of community integration

This commit is contained in:
lanceyq
2026-03-11 18:04:04 +08:00
parent 5b431400be
commit fc58ac0408
3 changed files with 115 additions and 47 deletions

View File

@@ -141,8 +141,18 @@ class LabelPropagationEngine:
# 将最终标签写入 Neo4j
await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values()))
logger.info(
f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区,"
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体,开始后处理合并"
)
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id)
logger.info(
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
@@ -221,30 +231,50 @@ class LabelPropagationEngine:
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
合并时保留成员数最多的社区,其余成员迁移过来。
全量场景(社区数 > 20使用批量查询避免 N 次数据库往返。
"""
MERGE_THRESHOLD = 0.75
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
community_embeddings: Dict[str, Optional[List[float]]] = {}
community_sizes: Dict[str, int] = {}
for cid in community_ids:
members = await self.repo.get_community_members(cid, end_user_id)
community_sizes[cid] = len(members)
# 计算社区成员 embedding 的平均向量
valid_embeddings = [
m["name_embedding"]
for m in members
if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
avg = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
if len(community_ids) > BATCH_THRESHOLD:
# 批量查询:一次拉取所有社区成员
all_members = await self.repo.get_all_community_members_batch(
community_ids, end_user_id
)
for cid in community_ids:
members = all_members.get(cid, [])
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
community_embeddings[cid] = avg
else:
community_embeddings[cid] = None
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
else:
# 增量场景:逐个查询
for cid in community_ids:
members = await self.repo.get_community_members(cid, end_user_id)
community_sizes[cid] = len(members)
valid_embeddings = [
m["name_embedding"] for m in members if m.get("name_embedding")
]
if valid_embeddings:
dim = len(valid_embeddings[0])
community_embeddings[cid] = [
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
for i in range(dim)
]
else:
community_embeddings[cid] = None
# 找出应合并的社区对
to_merge: List[tuple] = []
@@ -258,14 +288,32 @@ class LabelPropagationEngine:
if sim > MERGE_THRESHOLD:
to_merge.append((cids[i], cids[j]))
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
# 执行合并:用 union-find 思路避免重复迁移已被合并的社区
# 维护一个 canonical 映射,确保链式合并正确收敛
canonical: Dict[str, str] = {cid: cid for cid in cids}
def find(x: str) -> str:
while canonical[x] != x:
canonical[x] = canonical[canonical[x]]
x = canonical[x]
return x
for c1, c2 in to_merge:
keep = c1 if community_sizes.get(c1, 0) >= community_sizes.get(c2, 0) else c2
dissolve = c2 if keep == c1 else c1
root1, root2 = find(c1), find(c2)
if root1 == root2:
continue # 已经在同一社区,跳过
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
dissolve = root2 if keep == root1 else root1
canonical[dissolve] = keep
members = await self.repo.get_community_members(dissolve, end_user_id)
for m in members:
await self.repo.assign_entity_to_community(
m["id"], keep, end_user_id
)
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
# 更新 sizes 以便后续合并决策准确
community_sizes[keep] = community_sizes.get(keep, 0) + len(members)
community_sizes[dissolve] = 0
await self.repo.refresh_member_count(keep, end_user_id)
logger.info(
f"[Clustering] 社区合并: {dissolve}{keep}"

View File

@@ -14,6 +14,7 @@ from app.repositories.neo4j.cypher_queries import (
GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER,
GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH,
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
)
@@ -101,6 +102,25 @@ class CommunityRepository:
logger.error(f"get_community_members failed: {e}")
return []
async def get_all_community_members_batch(
self, community_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
"""批量查询多个社区的成员,返回 {community_id: [members]} 字典。"""
try:
rows = await self.connector.execute_query(
GET_ALL_COMMUNITY_MEMBERS_BATCH,
community_ids=community_ids,
end_user_id=end_user_id,
)
result: Dict[str, List[Dict]] = {}
for row in rows:
cid = row["community_id"]
result.setdefault(cid, []).append(row)
return result
except Exception as e:
logger.error(f"get_all_community_members_batch failed: {e}")
return {}
async def has_communities(self, end_user_id: str) -> bool:
"""检查该用户是否已有 Community 节点(用于判断全量 vs 增量)。"""
try:

View File

@@ -1065,26 +1065,6 @@ Graph_Node_query = """
# Community 节点 & BELONGS_TO_COMMUNITY 边
# ============================================================
COMMUNITY_NODE_SAVE = """
MERGE (c:Community {community_id: $community_id})
SET c.end_user_id = $end_user_id,
c.formed_at = $formed_at,
c.updated_at = datetime(),
c.status = $status,
c.member_count = $member_count
RETURN c.community_id AS community_id
"""
COMMUNITY_ADD_MEMBER = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c)
SET c.updated_at = datetime(),
c.member_count = $member_count
"""
# ─── Community 聚类相关 Cypher 模板 ───────────────────────────────────────────
COMMUNITY_NODE_UPSERT = """
@@ -1111,12 +1091,23 @@ DELETE r
GET_ENTITY_NEIGHBORS = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb:ExtractedEntity {end_user_id: $end_user_id})
// 来源一直接关系邻居EXTRACTED_RELATIONSHIP 边)
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
// 来源二:同 Statement 共现邻居REFERENCES_ENTITY 边)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id
WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
@@ -1139,6 +1130,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
ORDER BY coalesce(e.activation_value, 0) DESC
"""
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
WHERE c.community_id IN $community_ids
RETURN c.community_id AS community_id,
e.id AS id,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value
"""
CHECK_USER_HAS_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
RETURN count(c) AS community_count