[changes]
This commit is contained in:
@@ -121,12 +121,18 @@ class LabelPropagationEngine:
|
||||
e["id"]: e.get("name_embedding") for e in entities
|
||||
}
|
||||
|
||||
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
|
||||
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
|
||||
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
|
||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
changed = 0
|
||||
# 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
|
||||
for entity in entities:
|
||||
eid = entity["id"]
|
||||
neighbors = await self.repo.get_entity_neighbors(eid, end_user_id)
|
||||
# 直接从缓存取邻居,不再发起 Neo4j 查询
|
||||
neighbors = neighbors_cache.get(eid, [])
|
||||
|
||||
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
|
||||
enriched = []
|
||||
|
||||
@@ -15,6 +15,7 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
GET_ALL_ENTITIES_FOR_USER,
|
||||
GET_COMMUNITY_MEMBERS,
|
||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
@@ -78,6 +79,26 @@ class CommunityRepository:
|
||||
logger.error(f"get_entity_neighbors failed: {e}")
|
||||
return []
|
||||
|
||||
async def get_all_entity_neighbors_batch(
    self, end_user_id: str
) -> Dict[str, List[Dict]]:
    """Fetch the neighbors of every entity of one user with a single query.

    Runs ``GET_ALL_ENTITY_NEIGHBORS_BATCH`` once and groups the returned rows
    by their ``entity_id`` column, so a full clustering pass can preload all
    neighbors instead of issuing one query per entity.

    Args:
        end_user_id: Forwarded to the query as the ``end_user_id`` parameter.

    Returns:
        A dict mapping each ``entity_id`` to a list of neighbor dicts. Each
        neighbor dict carries every column of its row except ``entity_id``.
        Entities for which the query yields no row are absent from the dict.
        If anything raises, the error is logged and an empty dict is returned.
    """
    try:
        records = await self.connector.execute_query(
            GET_ALL_ENTITY_NEIGHBORS_BATCH,
            end_user_id=end_user_id,
        )
        grouped: Dict[str, List[Dict]] = {}
        for record in records:
            owner_id = record["entity_id"]
            # Everything except the grouping key describes the neighbor itself.
            neighbor_fields = {
                column: value
                for column, value in record.items()
                if column != "entity_id"
            }
            if owner_id in grouped:
                grouped[owner_id].append(neighbor_fields)
            else:
                grouped[owner_id] = [neighbor_fields]
        return grouped
    except Exception as e:
        # Best-effort: callers receive an empty mapping rather than an error.
        logger.error(f"get_all_entity_neighbors_batch failed: {e}")
        return {}
|
||||
|
||||
async def get_all_entities(self, end_user_id: str) -> List[Dict]:
|
||||
"""拉取某用户下所有实体及其当前社区归属。"""
|
||||
try:
|
||||
|
||||
@@ -1159,3 +1159,28 @@ SET c.name = $name,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
# Cypher query: for every ExtractedEntity owned by $end_user_id, return one
# row per (entity, neighbor) pair. Neighbors come from two sources:
#   1. entities linked through an EXTRACTED_RELATIONSHIP edge (either direction);
#   2. other entities referenced by a Statement that also references the entity.
# Each row carries entity_id plus the neighbor's id, name, name_embedding,
# activation_value and community_id (null when the neighbor has no community).
#
# The direct neighbors are aggregated in their own WITH step *before* the
# Statement patterns are expanded. Chaining all OPTIONAL MATCH clauses ahead of
# a single aggregation makes the intermediate row count per entity the product
# of (direct neighbors) x (statement co-occurrence rows); collecting early
# keeps it additive while producing the same distinct result rows.
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})

// 来源一:直接关系邻居
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
WITH e, collect(DISTINCT nb1) AS direct_neighbors

// 来源二:同 Statement 共现邻居
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id

WITH e, direct_neighbors + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
    e.id AS entity_id,
    nb.id AS id,
    nb.name AS name,
    nb.name_embedding AS name_embedding,
    nb.activation_value AS activation_value,
    CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
|
||||
|
||||
Reference in New Issue
Block a user