[changes] Community Clustering Retrieval Module

This commit is contained in:
lanceyq
2026-03-16 12:30:00 +08:00
parent b1a7b58f97
commit f9fb480cc3
11 changed files with 637 additions and 96 deletions

View File

@@ -13,12 +13,15 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH,
GET_ALL_ENTITY_NEIGHBORS_BATCH,
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
UPDATE_COMMUNITY_METADATA,
)
logger = logging.getLogger(__name__)
@@ -110,6 +113,41 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}")
return []
async def get_entities_page(
self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:
"""分页拉取实体,用于全量聚类分批处理。"""
try:
return await self.connector.execute_query(
GET_ENTITIES_PAGE,
end_user_id=end_user_id,
skip=skip,
limit=limit,
)
except Exception as e:
logger.error(f"get_entities_page failed: {e}")
return []
async def get_entity_neighbors_for_ids(
self, entity_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
"""批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。"""
try:
rows = await self.connector.execute_query(
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
entity_ids=entity_ids,
end_user_id=end_user_id,
)
result: Dict[str, List[Dict]] = {}
for row in rows:
eid = row["entity_id"]
neighbor = {k: v for k, v in row.items() if k != "entity_id"}
result.setdefault(eid, []).append(neighbor)
return result
except Exception as e:
logger.error(f"get_entity_neighbors_for_ids failed: {e}")
return {}
async def get_community_members(
self, community_id: str, end_user_id: str
) -> List[Dict]:
@@ -177,8 +215,9 @@ class CommunityRepository:
name: str,
summary: str,
core_entities: List[str],
summary_embedding: Optional[List[float]] = None,
) -> bool:
"""更新社区的名称、摘要核心实体列表。"""
"""更新社区的名称、摘要核心实体列表和摘要向量"""
try:
result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA,
@@ -187,6 +226,7 @@ class CommunityRepository:
name=name,
summary=summary,
core_entities=core_entities,
summary_embedding=summary_embedding,
)
return bool(result)
except Exception as e: