Merge pull request #555 from SuanmoSuanyangTechnology/feature/cluster
Feature/cluster
This commit is contained in:
@@ -116,6 +116,7 @@ celery_app.conf.update(
|
|||||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -194,6 +194,15 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
|
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
||||||
|
try:
|
||||||
|
from app.tasks import init_community_clustering_for_users
|
||||||
|
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
||||||
|
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.services.user_memory_service import (
|
|||||||
UserMemoryService,
|
UserMemoryService,
|
||||||
analytics_memory_types,
|
analytics_memory_types,
|
||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
@@ -295,6 +296,42 @@ async def get_graph_data_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api(
    end_user_id: str,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> dict:
    """Return the community graph for one end user.

    Requires the caller to have an active workspace. An empty graph that
    carries a ``message`` is still returned as a success, with that message
    as the response text. Any exception raised while querying is converted
    into an ``INTERNAL_ERROR`` response.
    """
    active_workspace = current_user.current_workspace_id

    # Guard clause: no workspace selected means there is nothing to query.
    if active_workspace is None:
        api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    api_logger.info(
        f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
        f"workspace={active_workspace}"
    )

    try:
        graph = await analytics_community_graph_data(db=db, end_user_id=end_user_id)

        is_empty_with_note = "message" in graph and graph["statistics"]["total_nodes"] == 0
        if not is_empty_with_note:
            api_logger.info(
                f"成功获取社区图谱: end_user_id={end_user_id}, "
                f"nodes={graph['statistics']['total_nodes']}, "
                f"edges={graph['statistics']['total_edges']}"
            )
            return success(data=graph, msg="查询成功")

        # Empty result: surface the service's own explanation to the client.
        api_logger.warning(f"社区图谱查询返回空结果: {graph.get('message')}")
        return success(data=graph, msg=graph.get("message", "查询成功"))

    except Exception as e:
        api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_end_user_profile(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -165,7 +165,9 @@ async def write(
|
|||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
connector=neo4j_connector
|
connector=neo4j_connector,
|
||||||
|
config_id=config_id,
|
||||||
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||||
|
|
||||||
|
__all__ = ["LabelPropagationEngine"]
|
||||||
@@ -0,0 +1,484 @@
|
|||||||
|
"""标签传播聚类引擎
|
||||||
|
|
||||||
|
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
|
||||||
|
|
||||||
|
支持两种模式:
|
||||||
|
- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代
|
||||||
|
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from math import sqrt
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 全量迭代最大轮数,防止不收敛
|
||||||
|
MAX_ITERATIONS = 10
|
||||||
|
# 社区摘要核心实体数量
|
||||||
|
CORE_ENTITY_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||||
|
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||||
|
if not v1 or not v2 or len(v1) != len(v2):
|
||||||
|
return 0.0
|
||||||
|
dot = sum(a * b for a, b in zip(v1, v2))
|
||||||
|
norm1 = sqrt(sum(a * a for a in v1))
|
||||||
|
norm2 = sqrt(sum(b * b for b in v2))
|
||||||
|
if norm1 == 0 or norm2 == 0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm1 * norm2)
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_vote(
|
||||||
|
neighbors: List[Dict],
|
||||||
|
self_embedding: Optional[List[float]],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
加权多数投票,选出得票最高的社区。
|
||||||
|
|
||||||
|
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||||
|
没有 community_id 的邻居不参与投票。
|
||||||
|
"""
|
||||||
|
votes: Dict[str, float] = {}
|
||||||
|
for nb in neighbors:
|
||||||
|
cid = nb.get("community_id")
|
||||||
|
if not cid:
|
||||||
|
continue
|
||||||
|
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||||
|
act = nb.get("activation_value") or 0.5
|
||||||
|
# 语义相似度权重 0.6,激活值权重 0.4
|
||||||
|
weight = 0.6 * sem + 0.4 * act
|
||||||
|
votes[cid] = votes.get(cid, 0.0) + weight
|
||||||
|
|
||||||
|
if not votes:
|
||||||
|
return None
|
||||||
|
return max(votes, key=votes.__getitem__)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelPropagationEngine:
    """Label-propagation community clustering for ExtractedEntity nodes.

    Two modes are available:

    - ``full_clustering``: every entity of a user is iterated until labels
      stop changing or ``MAX_ITERATIONS`` rounds have run.
    - ``incremental_update``: only the given new entities are placed, based
      on their neighbours' communities.

    ``run`` chooses between them depending on whether the user already has
    Community nodes.
    """

    # Centroid cosine similarity above which two communities are merged.
    MERGE_THRESHOLD = 0.85
    # With more communities than this, members are fetched in one batch query.
    BATCH_THRESHOLD = 20

    def __init__(
        self,
        connector: Neo4jConnector,
        config_id: Optional[str] = None,
        llm_model_id: Optional[str] = None,
    ):
        """Store the connector and build the community repository on top of it.

        Args:
            connector: Neo4j connector used for all queries.
            config_id: Stored on the instance; not read by this class.
            llm_model_id: When set, community names/summaries are generated
                with this LLM; otherwise entity-name fallbacks are used.
        """
        self.connector = connector
        self.repo = CommunityRepository(connector)
        self.config_id = config_id
        self.llm_model_id = llm_model_id

    # ──────────────────────────────────────────────────────────────────────
    # Public interface
    # ──────────────────────────────────────────────────────────────────────

    async def run(
        self,
        end_user_id: str,
        new_entity_ids: Optional[List[str]] = None,
    ) -> None:
        """Single entry point: choose full or incremental clustering.

        - The user has no Community node yet: run a full clustering.
        - Otherwise: place only ``new_entity_ids`` (nothing happens when the
          list is empty or ``None``).
        """
        has_communities = await self.repo.has_communities(end_user_id)
        if not has_communities:
            logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
            await self.full_clustering(end_user_id)
        else:
            if new_entity_ids:
                logger.info(
                    f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
                )
                await self.incremental_update(new_entity_ids, end_user_id)

    async def full_clustering(self, end_user_id: str) -> None:
        """Run label propagation over all entities of a user.

        1. Load every entity; each starts as its own community.
        2. Each round, every entity adopts the winner of its neighbours' vote.
        3. Stop when no label changes or after ``MAX_ITERATIONS`` rounds.
        4. Write the labels to Neo4j, merge similar communities, then
           generate metadata for the communities that survive.
        """
        entities = await self.repo.get_all_entities(end_user_id)
        if not entities:
            logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
            return

        # Initial state: each entity's own id is its community label.
        labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
        embeddings: Dict[str, Optional[List[float]]] = {
            e["id"]: e.get("name_embedding") for e in entities
        }

        # Preload all neighbour lists once so the iterations below do not
        # issue one Neo4j round trip per entity per round.
        logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
        neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
        logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")

        for iteration in range(MAX_ITERATIONS):
            changed = 0
            # Entities are visited in the order the query returned them;
            # the order is not shuffled between rounds.
            for entity in entities:
                eid = entity["id"]
                neighbors = neighbors_cache.get(eid, [])

                # Overlay the in-memory labels on the neighbour records so
                # the vote sees this run's labels, not the values in Neo4j.
                enriched = [
                    {**nb, "community_id": labels.get(nb["id"], nb.get("community_id"))}
                    for nb in neighbors
                ]

                new_label = _weighted_vote(enriched, embeddings.get(eid))
                if new_label and new_label != labels[eid]:
                    labels[eid] = new_label
                    changed += 1

            logger.info(
                f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS},"
                f"标签变化数: {changed}"
            )
            if changed == 0:
                logger.info("[Clustering] 标签已收敛,提前结束迭代")
                break

        # Persist the final labels.
        await self._flush_labels(labels, end_user_id)
        pre_merge_count = len(set(labels.values()))
        logger.info(
            f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
            f"{len(labels)} 个实体,开始后处理合并"
        )

        # One merge pass over all communities (centroid cosine similarity).
        all_community_ids = list(set(labels.values()))
        await self._evaluate_merge(all_community_ids, end_user_id)

        logger.info(
            f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
            f"{len(labels)} 个实体"
        )

        # Some communities were dissolved by the merge, so the surviving ids
        # are re-read from Neo4j instead of reusing labels.values().
        surviving_communities = await self.repo.get_all_entities(end_user_id)
        surviving_community_ids = list({
            e.get("community_id") for e in surviving_communities
            if e.get("community_id")
        })
        logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
        for cid in surviving_community_ids:
            await self._generate_community_metadata(cid, end_user_id)

    async def incremental_update(
        self, new_entity_ids: List[str], end_user_id: str
    ) -> None:
        """Place each new entity without re-running the whole graph.

        Entities are processed one after another; see
        ``_process_single_entity`` for the placement rules.
        """
        for entity_id in new_entity_ids:
            await self._process_single_entity(entity_id, end_user_id)

    # ──────────────────────────────────────────────────────────────────────
    # Internal helpers
    # ──────────────────────────────────────────────────────────────────────

    async def _process_single_entity(
        self, entity_id: str, end_user_id: str
    ) -> None:
        """Assign one new entity to a community.

        - No neighbours: create a single-member community.
        - Neighbours but none has a community: create one community holding
          the entity and all its neighbours.
        - Otherwise: join the vote winner; if the neighbours span several
          communities, evaluate merging them, then refresh metadata of the
          community the entity ended up in and of any merge survivors.
        """
        neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)

        if not neighbors:
            # Isolated entity: single-member community. The embedding is not
            # needed on this path, so it is not fetched.
            new_cid = self._new_community_id()
            await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
            await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
            logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
            return

        # The neighbour query does not return the entity's own embedding.
        self_embedding = await self._get_entity_embedding(entity_id, end_user_id)

        # Distinct communities present among the neighbours.
        community_ids_in_neighbors = set(
            nb["community_id"] for nb in neighbors if nb.get("community_id")
        )

        target_cid = _weighted_vote(neighbors, self_embedding)

        if target_cid is None:
            # No neighbour has a community: found a new one for all of them.
            new_cid = self._new_community_id()
            await self.repo.upsert_community(new_cid, end_user_id)
            await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
            for nb in neighbors:
                await self.repo.assign_entity_to_community(
                    nb["id"], new_cid, end_user_id
                )
            await self.repo.refresh_member_count(new_cid, end_user_id)
            logger.debug(
                f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
            )
            await self._generate_community_metadata(new_cid, end_user_id)
        else:
            # Join the community with the most votes.
            await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
            await self.repo.refresh_member_count(target_cid, end_user_id)
            logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")

            to_refresh = [target_cid]
            if len(community_ids_in_neighbors) > 1:
                merged = await self._evaluate_merge(
                    list(community_ids_in_neighbors), end_user_id
                ) or {}
                # target_cid may itself have been dissolved; follow it to the
                # surviving community, and refresh every other survivor too
                # because their member sets changed.
                to_refresh = list(dict.fromkeys(
                    [merged.get(target_cid, target_cid), *merged.values()]
                ))
            for cid in to_refresh:
                await self._generate_community_metadata(cid, end_user_id)

    @staticmethod
    def _mean_embedding(members: List[Dict]) -> Optional[List[float]]:
        """Return the element-wise mean of the members' ``name_embedding``.

        Members without an embedding are ignored, as are embeddings whose
        length differs from the first usable one. Returns ``None`` when no
        member has an embedding.
        """
        vectors = [m["name_embedding"] for m in members if m.get("name_embedding")]
        if not vectors:
            return None
        dim = len(vectors[0])
        vectors = [v for v in vectors if len(v) == dim]
        return [sum(v[i] for v in vectors) / len(vectors) for i in range(dim)]

    async def _evaluate_merge(
        self, community_ids: List[str], end_user_id: str
    ) -> Dict[str, str]:
        """Merge communities whose centroids are similar enough.

        The centroid of a community is the mean ``name_embedding`` of its
        members. Pairs whose centroid cosine similarity exceeds
        ``MERGE_THRESHOLD`` are merge candidates. For each candidate the
        similarity is re-checked against the current (possibly already
        merged) centroids, so A≈B and B≈C does not merge A and C unless C is
        still similar to the merged centroid. The community with more
        members is kept; the other one's members are moved into it.

        When more than ``BATCH_THRESHOLD`` communities are given, members
        are fetched with a single batch query.

        Returns:
            Mapping of each dissolved community id to the id of the
            community that finally absorbed it (empty if nothing merged).
        """
        threshold = self.MERGE_THRESHOLD

        # Fetch members: one batch query for many communities, otherwise
        # one query per community.
        members_by_cid: Dict[str, List[Dict]] = {}
        if len(community_ids) > self.BATCH_THRESHOLD:
            all_members = await self.repo.get_all_community_members_batch(
                community_ids, end_user_id
            )
            for cid in community_ids:
                members_by_cid[cid] = all_members.get(cid, [])
        else:
            for cid in community_ids:
                members_by_cid[cid] = await self.repo.get_community_members(cid, end_user_id)

        community_embeddings: Dict[str, Optional[List[float]]] = {}
        community_sizes: Dict[str, int] = {}
        for cid, members in members_by_cid.items():
            community_sizes[cid] = len(members)
            community_embeddings[cid] = self._mean_embedding(members)

        # Collect candidate pairs from the initial centroids.
        to_merge: List[tuple] = []
        cids = list(community_ids)
        for i, left in enumerate(cids):
            for right in cids[i + 1:]:
                sim = _cosine_similarity(
                    community_embeddings[left],
                    community_embeddings[right],
                )
                if sim > threshold:
                    to_merge.append((left, right))

        logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")

        merged_into: Dict[str, str] = {}  # dissolved id -> id it was merged into

        def get_root(x: str) -> str:
            """Follow merged_into to the surviving community, shortening the chain."""
            while x in merged_into:
                merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
                x = merged_into[x]
            return x

        for c1, c2 in to_merge:
            root1, root2 = get_root(c1), get_root(c2)
            if root1 == root2:
                continue

            # Re-validate against the current centroids of both roots.
            current_sim = _cosine_similarity(
                community_embeddings.get(root1),
                community_embeddings.get(root2),
            )
            if current_sim <= threshold:
                logger.debug(
                    f"[Clustering] 跳过合并 {root1} ↔ {root2},"
                    f"当前相似度 {current_sim:.3f} ≤ {threshold}"
                )
                continue

            keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
            dissolve = root2 if keep == root1 else root1
            merged_into[dissolve] = keep

            members = await self.repo.get_community_members(dissolve, end_user_id)
            for m in members:
                await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)

            # Update keep's centroid as the size-weighted average of both.
            keep_emb = community_embeddings.get(keep)
            dissolve_emb = community_embeddings.get(dissolve)
            keep_size = community_sizes.get(keep, 0)
            dissolve_size = community_sizes.get(dissolve, 0)
            total_size = keep_size + dissolve_size
            if keep_emb and dissolve_emb and total_size > 0:
                dim = len(keep_emb)
                community_embeddings[keep] = [
                    (keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
                    for i in range(dim)
                ]
            community_embeddings[dissolve] = None

            community_sizes[keep] = total_size
            community_sizes[dissolve] = 0
            # NOTE(review): the dissolved Community node itself is left in
            # place and its stored member_count is not refreshed here —
            # confirm whether it should be cleaned up.
            await self.repo.refresh_member_count(keep, end_user_id)
            logger.info(
                f"[Clustering] 社区合并: {dissolve} → {keep},"
                f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
            )

        return {dissolved: get_root(dissolved) for dissolved in list(merged_into)}

    async def _flush_labels(
        self, labels: Dict[str, str], end_user_id: str
    ) -> None:
        """Write the in-memory ``entity_id -> community_id`` labels to Neo4j."""
        # Create every distinct community node first.
        unique_communities = set(labels.values())
        for cid in unique_communities:
            await self.repo.upsert_community(cid, end_user_id)

        # Then attach each entity to its community.
        for entity_id, community_id in labels.items():
            await self.repo.assign_entity_to_community(
                entity_id, community_id, end_user_id
            )

        # Finally recompute the member counts.
        for cid in unique_communities:
            await self.repo.refresh_member_count(cid, end_user_id)

    async def _get_entity_embedding(
        self, entity_id: str, end_user_id: str
    ) -> Optional[List[float]]:
        """Fetch one entity's ``name_embedding``.

        Returns ``None`` when the entity is not found or the query fails;
        a failure is logged rather than raised so placement can continue
        without the semantic part of the vote.
        """
        try:
            result = await self.connector.execute_query(
                "MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
                "RETURN e.name_embedding AS name_embedding",
                eid=entity_id,
                uid=end_user_id,
            )
            return result[0]["name_embedding"] if result else None
        except Exception as e:
            logger.warning(f"[Clustering] _get_entity_embedding failed for {entity_id}: {e}")
            return None

    @staticmethod
    def _parse_llm_metadata(text: str) -> Dict[str, str]:
        """Extract ``name`` / ``summary`` from an LLM reply.

        Recognises lines starting with ``名称`` or ``摘要`` followed by a
        full-width or ASCII colon, ignoring surrounding whitespace. Empty
        values are dropped so they cannot overwrite the fallback values.
        When a field appears more than once, the last occurrence wins.
        """
        labels = (("name", "名称"), ("summary", "摘要"))
        parsed: Dict[str, str] = {}
        for raw_line in text.strip().splitlines():
            line = raw_line.strip()
            for key, label in labels:
                for colon in (":", ":"):
                    prefix = label + colon
                    if line.startswith(prefix):
                        value = line[len(prefix):].strip()
                        if value:
                            parsed[key] = value
        return parsed

    async def _generate_community_metadata(
        self, community_id: str, end_user_id: str
    ) -> None:
        """Build and store a community's name, summary and core entities.

        - ``core_entities``: names of the top ``CORE_ENTITY_LIMIT`` members
          by ``activation_value`` (no LLM involved).
        - ``name`` / ``summary``: generated by the LLM when ``llm_model_id``
          is set; otherwise, or when the LLM call fails, built from the
          entity names.

        Does nothing for a community without members. Errors are logged and
        not propagated.
        """
        try:
            members = await self.repo.get_community_members(community_id, end_user_id)
            if not members:
                return

            # Core entities: highest activation first.
            sorted_members = sorted(
                members,
                key=lambda m: m.get("activation_value") or 0,
                reverse=True,
            )
            core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
            all_names = [m["name"] for m in members if m.get("name")]

            # Fallback values used when no LLM is configured or it fails.
            name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
            summary = f"包含实体:{', '.join(all_names)}"

            if self.llm_model_id:
                try:
                    # Imported lazily, as in the rest of this method's LLM path.
                    from app.db import get_db_context
                    from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                    entity_list_str = "、".join(all_names)
                    prompt = (
                        f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                        f"请为这组实体所代表的主题:\n"
                        f"1. 起一个简洁的中文名称(不超过10个字)\n"
                        f"2. 写一句话摘要(不超过50个字)\n\n"
                        f"严格按以下格式输出,不要有其他内容:\n"
                        f"名称:<名称>\n摘要:<摘要>"
                    )
                    with get_db_context() as db:
                        factory = MemoryClientFactory(db)
                        llm_client = factory.get_llm_client(self.llm_model_id)
                        response = await llm_client.chat([{"role": "user", "content": prompt}])
                        text = response.content if hasattr(response, "content") else str(response)

                    parsed = self._parse_llm_metadata(text)
                    name = parsed.get("name", name)
                    summary = parsed.get("summary", summary)
                except Exception as e:
                    logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

            await self.repo.update_community_metadata(
                community_id=community_id,
                end_user_id=end_user_id,
                name=name,
                summary=summary,
                core_entities=core_entities,
            )
            logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
        except Exception as e:
            logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")

    @staticmethod
    def _new_community_id() -> str:
        """Return a fresh random UUID4 string to use as a community id."""
        return str(uuid.uuid4())
|
||||||
194
api/app/repositories/neo4j/community_repository.py
Normal file
194
api/app/repositories/neo4j/community_repository.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""Community 节点仓库
|
||||||
|
|
||||||
|
管理 Neo4j 中 Community 节点及 BELONGS_TO_COMMUNITY 边的 CRUD 操作。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.repositories.neo4j.cypher_queries import (
|
||||||
|
COMMUNITY_NODE_UPSERT,
|
||||||
|
ENTITY_JOIN_COMMUNITY,
|
||||||
|
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||||
|
GET_ENTITY_NEIGHBORS,
|
||||||
|
GET_ALL_ENTITIES_FOR_USER,
|
||||||
|
GET_COMMUNITY_MEMBERS,
|
||||||
|
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||||
|
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||||
|
CHECK_USER_HAS_COMMUNITIES,
|
||||||
|
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||||
|
UPDATE_COMMUNITY_METADATA,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CommunityRepository:
    """Data access for Community nodes and BELONGS_TO_COMMUNITY edges.

    Every method runs one or two Cypher templates through the connector,
    logs any exception and returns a neutral fallback value instead of
    raising.
    """

    def __init__(self, connector: Neo4jConnector):
        self.connector = connector

    async def upsert_community(
        self, community_id: str, end_user_id: str, member_count: int = 0
    ) -> Optional[str]:
        """Create or update a Community node; return its id, or None on failure."""
        try:
            rows = await self.connector.execute_query(
                COMMUNITY_NODE_UPSERT,
                community_id=community_id,
                end_user_id=end_user_id,
                member_count=member_count,
            )
            if not rows:
                return None
            return rows[0]["community_id"]
        except Exception as e:
            logger.error(f"upsert_community failed: {e}")
            return None

    async def assign_entity_to_community(
        self, entity_id: str, community_id: str, end_user_id: str
    ) -> bool:
        """Move an entity into a community.

        Existing memberships are removed first, then the new edge is
        created. The two steps are separate queries; returns True only when
        the join query yields a row.
        """
        try:
            # Step 1: drop whatever community edges the entity has.
            await self.connector.execute_query(
                ENTITY_LEAVE_ALL_COMMUNITIES,
                entity_id=entity_id,
                end_user_id=end_user_id,
            )
            # Step 2: attach it to the requested community.
            joined = await self.connector.execute_query(
                ENTITY_JOIN_COMMUNITY,
                entity_id=entity_id,
                community_id=community_id,
                end_user_id=end_user_id,
            )
            return True if joined else False
        except Exception as e:
            logger.error(f"assign_entity_to_community failed: {e}")
            return False

    async def get_entity_neighbors(
        self, entity_id: str, end_user_id: str
    ) -> List[Dict]:
        """Return the entity's direct neighbours with their community ids."""
        try:
            neighbours = await self.connector.execute_query(
                GET_ENTITY_NEIGHBORS,
                entity_id=entity_id,
                end_user_id=end_user_id,
            )
            return neighbours
        except Exception as e:
            logger.error(f"get_entity_neighbors failed: {e}")
            return []

    async def get_all_entity_neighbors_batch(
        self, end_user_id: str
    ) -> Dict[str, List[Dict]]:
        """Fetch neighbours of all of a user's entities in one query.

        Returns ``{entity_id: [neighbour, ...]}``, where each neighbour is
        the result row without its ``entity_id`` column. Returns ``{}`` on
        failure.
        """
        try:
            rows = await self.connector.execute_query(
                GET_ALL_ENTITY_NEIGHBORS_BATCH,
                end_user_id=end_user_id,
            )
            grouped: Dict[str, List[Dict]] = {}
            for row in rows:
                owner = row["entity_id"]
                record = {}
                for column, value in row.items():
                    if column != "entity_id":
                        record[column] = value
                if owner not in grouped:
                    grouped[owner] = []
                grouped[owner].append(record)
            return grouped
        except Exception as e:
            logger.error(f"get_all_entity_neighbors_batch failed: {e}")
            return {}

    async def get_all_entities(self, end_user_id: str) -> List[Dict]:
        """Return all of a user's entities with their current community id."""
        try:
            entities = await self.connector.execute_query(
                GET_ALL_ENTITIES_FOR_USER,
                end_user_id=end_user_id,
            )
            return entities
        except Exception as e:
            logger.error(f"get_all_entities failed: {e}")
            return []

    async def get_community_members(
        self, community_id: str, end_user_id: str
    ) -> List[Dict]:
        """Return the member entities of one community."""
        try:
            members = await self.connector.execute_query(
                GET_COMMUNITY_MEMBERS,
                community_id=community_id,
                end_user_id=end_user_id,
            )
            return members
        except Exception as e:
            logger.error(f"get_community_members failed: {e}")
            return []

    async def get_all_community_members_batch(
        self, community_ids: List[str], end_user_id: str
    ) -> Dict[str, List[Dict]]:
        """Fetch members of several communities at once.

        Returns ``{community_id: [member_row, ...]}``; ``{}`` on failure.
        """
        try:
            rows = await self.connector.execute_query(
                GET_ALL_COMMUNITY_MEMBERS_BATCH,
                community_ids=community_ids,
                end_user_id=end_user_id,
            )
            grouped: Dict[str, List[Dict]] = {}
            for row in rows:
                key = row["community_id"]
                if key not in grouped:
                    grouped[key] = []
                grouped[key].append(row)
            return grouped
        except Exception as e:
            logger.error(f"get_all_community_members_batch failed: {e}")
            return {}

    async def has_communities(self, end_user_id: str) -> bool:
        """Tell whether the user already has any Community node.

        Used to choose between full and incremental clustering. Returns
        False when the query yields nothing or fails.
        """
        try:
            rows = await self.connector.execute_query(
                CHECK_USER_HAS_COMMUNITIES,
                end_user_id=end_user_id,
            )
            if not rows:
                return False
            return rows[0]["community_count"] > 0
        except Exception as e:
            logger.error(f"has_communities failed: {e}")
            return False

    async def refresh_member_count(
        self, community_id: str, end_user_id: str
    ) -> int:
        """Recount a community's members, store the count and return it.

        Returns 0 when the query yields nothing or fails.
        """
        try:
            rows = await self.connector.execute_query(
                UPDATE_COMMUNITY_MEMBER_COUNT,
                community_id=community_id,
                end_user_id=end_user_id,
            )
            if not rows:
                return 0
            return rows[0]["member_count"]
        except Exception as e:
            logger.error(f"refresh_member_count failed: {e}")
            return 0

    async def update_community_metadata(
        self,
        community_id: str,
        end_user_id: str,
        name: str,
        summary: str,
        core_entities: List[str],
    ) -> bool:
        """Store a community's name, summary and core-entity list.

        Returns True when the update query yields a row, False otherwise.
        """
        try:
            updated = await self.connector.execute_query(
                UPDATE_COMMUNITY_METADATA,
                community_id=community_id,
                end_user_id=end_user_id,
                name=name,
                summary=summary,
                core_entities=core_entities,
            )
            return True if updated else False
        except Exception as e:
            logger.error(f"update_community_metadata failed: {e}")
            return False
|
||||||
@@ -1059,3 +1059,146 @@ Graph_Node_query = """
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Community 节点 & BELONGS_TO_COMMUNITY 边
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# ─── Community 聚类相关 Cypher 模板 ───────────────────────────────────────────
|
||||||
|
|
||||||
|
# Create the Community node if it does not exist (matched on community_id
# only), then overwrite end_user_id / member_count and stamp updated_at.
# Returns the community_id.
# NOTE(review): the MERGE key does not include end_user_id — confirm that
# community ids are unique across users.
COMMUNITY_NODE_UPSERT = """
MERGE (c:Community {community_id: $community_id})
SET c.end_user_id = $end_user_id,
c.member_count = $member_count,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

# Link an entity to a community of the same end user with a
# BELONGS_TO_COMMUNITY edge. Both nodes must already exist; if either MATCH
# finds nothing, no row is returned.
ENTITY_JOIN_COMMUNITY = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
MERGE (e)-[:BELONGS_TO_COMMUNITY]->(c)
SET c.updated_at = datetime()
RETURN e.id AS entity_id, c.community_id AS community_id
"""

# Delete every BELONGS_TO_COMMUNITY edge leaving the entity. Returns nothing.
ENTITY_LEAVE_ALL_COMMUNITIES = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})
MATCH (e)-[r:BELONGS_TO_COMMUNITY]->(:Community)
DELETE r
"""

# Neighbours of one entity, drawn from two sources: entities connected by an
# EXTRACTED_RELATIONSHIP edge (either direction), and entities referenced by
# a Statement that also references this entity. Each neighbour is returned
# with its community_id, or null when it has no community.
GET_ENTITY_NEIGHBORS = """
MATCH (e:ExtractedEntity {id: $entity_id, end_user_id: $end_user_id})

// 来源一:直接关系邻居(EXTRACTED_RELATIONSHIP 边)
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})

// 来源二:同 Statement 共现邻居(REFERENCES_ENTITY 边)
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id

WITH collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# All entities of one end user, each with its community_id (null when the
# entity has no BELONGS_TO_COMMUNITY edge).
GET_ALL_ENTITIES_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN e.id AS id,
e.name AS name,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# Members of one community for one end user, ordered by activation_value
# descending (missing values sort as 0).
GET_COMMUNITY_MEMBERS = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding
ORDER BY coalesce(e.activation_value, 0) DESC
"""
|
||||||
|
|
||||||
|
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
|
WHERE c.community_id IN $community_ids
|
||||||
|
RETURN c.community_id AS community_id,
|
||||||
|
e.id AS id,
|
||||||
|
e.name_embedding AS name_embedding,
|
||||||
|
e.activation_value AS activation_value
|
||||||
|
"""
|
||||||
|
|
||||||
|
CHECK_USER_HAS_COMMUNITIES = """
|
||||||
|
MATCH (c:Community {end_user_id: $end_user_id})
|
||||||
|
RETURN count(c) AS community_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Recount the members of one community and store the result on the node.
#
# The Community node is matched first and its members are joined with
# OPTIONAL MATCH. A plain MATCH on the membership pattern yields no rows for a
# community whose last member has left, so the SET would never run and
# member_count would keep its stale value; with OPTIONAL MATCH count(e) is 0
# in that case (count() ignores nulls) and the node is updated.
# The community is scoped by end_user_id in the same way as
# ENTITY_JOIN_COMMUNITY and UPDATE_COMMUNITY_METADATA.
# Returns one row (community_id, member_count) when the community exists and
# no rows otherwise.
UPDATE_COMMUNITY_MEMBER_COUNT = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
OPTIONAL MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
WITH c, count(e) AS cnt
SET c.member_count = cnt
RETURN c.community_id AS community_id, cnt AS member_count
"""
|
||||||
|
|
||||||
|
# Overwrite the descriptive fields (name, summary, core_entities) of one
# community and refresh its update timestamp. Yields no rows when the
# community does not exist for this user.
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""

# Batch variant of GET_ENTITY_NEIGHBORS: for every entity of the user, return
# one row per (entity_id, neighbour) pair, using the same two neighbour
# sources (direct EXTRACTED_RELATIONSHIP edges and co-reference by a
# Statement) and the neighbour's current community_id.
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})

// 来源一:直接关系邻居
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})

// 来源二:同 Statement 共现邻居
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id

WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""

# Raw material for the community graph view: one row per
# (community, member entity, optional related entity). The r_* and e2_*
# columns are null when the member has no EXTRACTED_RELATIONSHIP edge;
# r_from_e tells whether that relationship starts at the member entity.
GET_COMMUNITY_GRAPH_DATA = """
MATCH (c:Community {end_user_id: $end_user_id})
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c)
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id})
RETURN
elementId(c) AS c_id,
properties(c) AS c_props,
elementId(e) AS e_id,
properties(e) AS e_props,
elementId(b) AS b_id,
elementId(e2) AS e2_id,
properties(e2) AS e2_props,
elementId(r) AS r_id,
type(r) AS r_type,
properties(r) AS r_props,
startNode(r) = e AS r_from_e
"""
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
from typing import List
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -155,7 +157,9 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
entity_edges: List[EntityEntityEdge],
|
entity_edges: List[EntityEntityEdge],
|
||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector
|
connector: Neo4jConnector,
|
||||||
|
config_id: Optional[str] = None,
|
||||||
|
llm_model_id: Optional[str] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
@@ -288,6 +292,10 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
}
|
}
|
||||||
logger.info("Transaction completed. Summary: %s", summary)
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
logger.debug("Full transaction results: %r", results)
|
logger.debug("Full transaction results: %r", results)
|
||||||
|
|
||||||
|
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||||
|
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -295,3 +303,55 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
print(f"Neo4j integration error: {e}")
|
print(f"Neo4j integration error: {e}")
|
||||||
print("Continuing without database storage...")
|
print("Continuing without database storage...")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
# Strong references to clustering tasks that are still running. The event loop
# itself only keeps weak references to tasks, so a fire-and-forget task that
# nobody else references may be garbage-collected before it finishes.
_pending_clustering_tasks: set = set()


def schedule_clustering_after_write(
    entity_nodes: List,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
) -> None:
    """Schedule background community clustering after a successful Neo4j write.

    The clustering coroutine is started with ``asyncio.create_task`` so the
    caller is not blocked. A reference to the task is kept until it completes
    so it cannot be garbage-collected mid-run.

    Clustering can be switched off by setting the environment variable
    ``CLUSTERING_ENABLED=false`` (case-insensitive), e.g. for benchmark
    comparisons.

    Args:
        entity_nodes: Entities that were just written. Each item must expose
            ``id`` and ``end_user_id``. Nothing is scheduled when empty.
        config_id: Passed through unchanged to ``_trigger_clustering``.
        llm_model_id: Passed through unchanged to ``_trigger_clustering``.

    Raises:
        RuntimeError: From ``asyncio.create_task`` when no event loop is
            running in the current thread.
    """
    if not entity_nodes:
        return

    clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
    if not clustering_enabled:
        logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
        return

    # NOTE(review): the user id is taken from the first entity only; this
    # assumes every entity in the batch belongs to the same end user.
    end_user_id = entity_nodes[0].end_user_id
    new_entity_ids = [e.id for e in entity_nodes]
    logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")

    task = asyncio.create_task(
        _trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)
    )
    # Hold the task until it is done, then drop the reference automatically.
    _pending_clustering_tasks.add(task)
    task.add_done_callback(_pending_clustering_tasks.discard)
|
||||||
|
|
||||||
|
|
||||||
|
async def _trigger_clustering(
    new_entity_ids: List[str],
    end_user_id: str,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
) -> None:
    """Run community clustering for freshly written entities.

    Opens its own ``Neo4jConnector``, hands it to a ``LabelPropagationEngine``
    and awaits ``engine.run``. Per the original note, the engine decides
    between a full initialisation and an incremental update (not visible
    here). The function is meant to run as a background task, so it never
    raises: every failure is logged with its traceback and swallowed.

    Args:
        new_entity_ids: Ids of the entities that were just written.
        end_user_id: Owner of those entities.
        config_id: Forwarded to ``LabelPropagationEngine``.
        llm_model_id: Forwarded to ``LabelPropagationEngine``.
    """
    connector = None
    try:
        # Imported inside the function rather than at module level, so an
        # import failure is caught and logged like any other clustering error.
        from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine

        logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
        connector = Neo4jConnector()
        engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
        await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
        logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
    except Exception as e:
        logger.error(f"[Clustering] 聚类触发失败: {e}", exc_info=True)
    finally:
        # The connector stays None when construction (or the import) failed.
        if connector is not None:
            try:
                await connector.close()
            except Exception as close_err:
                # Still best-effort, but leave a trace instead of hiding a
                # possible connection leak.
                logger.warning(f"[Clustering] 关闭 Neo4j 连接失败: {close_err}")
|
||||||
|
|||||||
@@ -1727,6 +1727,150 @@ async def analytics_graph_data(
|
|||||||
|
|
||||||
# 辅助函数
|
# 辅助函数
|
||||||
|
|
||||||
|
# Property keys that are copied from Neo4j onto the returned graph nodes.
_COMMUNITY_NODE_PROP_KEYS = (
    "community_id", "end_user_id", "member_count", "updated_at",
    "name", "summary", "core_entities",
)
_ENTITY_NODE_PROP_KEYS = (
    "name", "end_user_id", "description", "created_at", "entity_type",
)


def _empty_community_graph(message: str) -> Dict[str, Any]:
    """Build the empty-graph payload used when there is nothing to return."""
    return {
        "nodes": [], "edges": [],
        "statistics": {"total_nodes": 0, "total_edges": 0, "node_types": {}, "edge_types": {}},
        "message": message,
    }


def _pick_clean_props(raw, keys) -> Dict[str, Any]:
    """Copy the whitelisted keys present in *raw* (may be None), cleaning each value."""
    raw = raw or {}
    return {k: _clean_neo4j_value(raw.get(k)) for k in keys if k in raw}


async def analytics_community_graph_data(
    db: Session,
    end_user_id: str,
) -> Dict[str, Any]:
    """Return the community graph of one end user.

    The graph contains Community nodes, the ExtractedEntity nodes that belong
    to them, entities related to those members, the BELONGS_TO_COMMUNITY edges
    and the relationships between entities.

    Args:
        db: Database session used to check that the end user exists.
        end_user_id: End user id; must be a valid UUID string.

    Returns:
        A dict with ``nodes``, ``edges`` and ``statistics`` (``total_nodes``,
        ``total_edges``, ``node_types``, ``edge_types``). When the id is
        malformed or the user does not exist, an empty graph with an extra
        ``message`` key is returned instead.

        Each node is ``{"id", "label", "properties"}``. Community nodes get a
        ``member_entity_ids`` property; entity nodes get ``community_name``
        ("" when the entity is not a member of any returned community).
        Each edge is ``{"id", "source", "target"}``.

    Raises:
        Exception: Any error other than a malformed ``end_user_id`` is logged
            and re-raised.
    """
    # Parse the id in its own narrow try block: a ValueError raised later
    # (repository, Neo4j, value cleaning) must not be reported as "invalid id".
    try:
        user_uuid = uuid.UUID(end_user_id)
    except ValueError:
        logger.error(f"无效的 end_user_id 格式: {end_user_id}")
        return _empty_community_graph("无效的用户ID格式")

    try:
        end_user = EndUserRepository(db).get_by_id(user_uuid)
        if not end_user:
            return _empty_community_graph("用户不存在")

        # Communities, member entities, membership edges and entity relations.
        from app.repositories.neo4j.cypher_queries import GET_COMMUNITY_GRAPH_DATA
        rows = await _neo4j_connector.execute_query(GET_COMMUNITY_GRAPH_DATA, end_user_id=end_user_id)

        nodes_map: Dict[str, dict] = {}
        edges_map: Dict[str, dict] = {}
        edge_type_counts: Dict[str, int] = {}
        # community id -> member entity ids. A dict is used as an ordered set:
        # O(1) de-duplication while keeping first-seen order.
        community_members: Dict[str, Dict[str, None]] = {}

        for row in rows:
            c_id = row["c_id"]
            e_id = row["e_id"]
            e2_id = row.get("e2_id")
            c_raw = row["c_props"] or {}

            # Community node
            if c_id and c_id not in nodes_map:
                nodes_map[c_id] = {
                    "id": c_id,
                    "label": "Community",
                    "properties": _pick_clean_props(c_raw, _COMMUNITY_NODE_PROP_KEYS),
                }

            # Member entity (e): c is the community it directly belongs to.
            if e_id and e_id not in nodes_map:
                props = _pick_clean_props(row["e_props"], _ENTITY_NODE_PROP_KEYS)
                props["community_name"] = _clean_neo4j_value(c_raw.get("name")) or ""
                nodes_map[e_id] = {
                    "id": e_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }

            # Related entity (e2, optional): its community is unknown in this
            # row and is filled in by the post-processing step below.
            if e2_id and e2_id not in nodes_map:
                props = _pick_clean_props(row["e2_props"], _ENTITY_NODE_PROP_KEYS)
                props["community_name"] = ""
                nodes_map[e2_id] = {
                    "id": e2_id,
                    "label": "ExtractedEntity",
                    "properties": props,
                }

            # BELONGS_TO_COMMUNITY edge
            b_id = row["b_id"]
            if b_id and b_id not in edges_map:
                edges_map[b_id] = {
                    "id": b_id,
                    "source": e_id,
                    "target": c_id,
                }
                edge_type_counts["BELONGS_TO_COMMUNITY"] = edge_type_counts.get("BELONGS_TO_COMMUNITY", 0) + 1

            if c_id and e_id:
                community_members.setdefault(c_id, {})[e_id] = None

            # EXTRACTED_RELATIONSHIP edge (optional). The query matches it
            # undirected, so r_from_e restores the stored direction.
            # NOTE(review): the query also returns r_props, but edges only
            # expose id/source/target; confirm whether consumers need more.
            r_id = row.get("r_id")
            if r_id and r_id not in edges_map and e2_id:
                from_e = row.get("r_from_e")
                edges_map[r_id] = {
                    "id": r_id,
                    "source": e_id if from_e else e2_id,
                    "target": e2_id if from_e else e_id,
                }
                r_type = row.get("r_type") or "EXTRACTED_RELATIONSHIP"
                edge_type_counts[r_type] = edge_type_counts.get(r_type, 0) + 1

        nodes = list(nodes_map.values())
        edges = list(edges_map.values())

        # Attach member ids to each community and fill in community_name for
        # members that were first seen as an e2 node (name left empty above).
        for c_id, members in community_members.items():
            c_node = nodes_map.get(c_id)
            if not c_node:
                continue
            member_ids = list(members)
            c_node["properties"]["member_entity_ids"] = member_ids
            c_name = c_node["properties"].get("name") or ""
            for eid in member_ids:
                e_node = nodes_map.get(eid)
                if (
                    e_node
                    and e_node["label"] == "ExtractedEntity"
                    and not e_node["properties"].get("community_name")
                ):
                    e_node["properties"]["community_name"] = c_name

        node_type_counts: Dict[str, int] = {}
        for n in nodes:
            node_type_counts[n["label"]] = node_type_counts.get(n["label"], 0) + 1

        return {
            "nodes": nodes,
            "edges": edges,
            "statistics": {
                "total_nodes": len(nodes),
                "total_edges": len(edges),
                "node_types": node_type_counts,
                # Same key as in the empty-graph payloads above.
                "edge_types": edge_type_counts,
            }
        }

    except Exception as e:
        logger.error(f"获取社区图谱数据失败: {str(e)}", exc_info=True)
        raise
|
||||||
|
|
||||||
|
|
||||||
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
|
async def _extract_node_properties(label: str, properties: Dict[str, Any],node_id: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
根据节点类型提取需要的属性字段
|
根据节点类型提取需要的属性字段
|
||||||
|
|||||||
131
api/app/tasks.py
131
api/app/tasks.py
@@ -2662,3 +2662,134 @@ def write_perceptual_memory(
|
|||||||
file_url,
|
file_url,
|
||||||
file_message,
|
file_message,
|
||||||
))
|
))
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# 社区聚类补全任务(触发型)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
@celery_app.task(
    name="app.tasks.init_community_clustering_for_users",
    bind=True,
    ignore_result=False,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # hard time limit: 2 hours
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Backfill community clustering for users that have entities but no communities.

    On-demand task (per the original note, triggered by the
    /dashboard/end_users endpoint). For every given user:

    * skip when the user already has at least one Community node;
    * skip when the user has no ExtractedEntity nodes;
    * otherwise run a full clustering with that user's own ``llm_model_id``.

    A failure for one user is logged and counted; it does not stop the others.

    Args:
        end_user_ids: Ids of the users to check.

    Returns:
        ``{"status": "SUCCESS", "initialized", "skipped", "failed",
        "elapsed_time", "task_id"}`` when the run completes (even if some
        users failed), or ``{"status": "FAILURE", "error", "elapsed_time",
        "task_id"}`` when the run itself raised.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.logging_config import get_logger
        from app.repositories.neo4j.community_repository import CommunityRepository
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine

        logger = get_logger(__name__)
        logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")

        initialized = 0
        skipped = 0
        failed = 0

        connector = Neo4jConnector()
        try:
            repo = CommunityRepository(connector)

            # Resolve every user's LLM model id up front. Best effort: on any
            # failure the affected users fall back to llm_model_id=None.
            # NOTE(review): the original comment says the batch lookup falls
            # back to the workspace default config when a user's own config is
            # unavailable; that lives in the helper and is not visible here.
            user_llm_map: Dict[str, Optional[str]] = {}
            try:
                with get_db_context() as db:
                    from app.services.memory_agent_service import get_end_users_connected_configs_batch
                    from app.services.memory_config_service import MemoryConfigService
                    batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
                    for uid, cfg_info in batch_configs.items():
                        config_id = cfg_info.get("memory_config_id")
                        if config_id:
                            try:
                                cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
                                user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
                            except Exception as e:
                                logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
                                user_llm_map[uid] = None
                        else:
                            user_llm_map[uid] = None
            except Exception as e:
                logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")

            for end_user_id in end_user_ids:
                try:
                    # Users that already have communities need no backfill.
                    has_communities = await repo.has_communities(end_user_id)
                    if has_communities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                        continue

                    # Nothing to cluster without entities.
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                        continue

                    # Each user is clustered with its own llm_model_id.
                    llm_model_id = user_llm_map.get(end_user_id)
                    engine = LabelPropagationEngine(
                        connector=connector,
                        llm_model_id=llm_model_id,
                    )

                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")

                except Exception as e:
                    failed += 1
                    # Keep the traceback so per-user failures can be diagnosed.
                    logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}", exc_info=True)

        finally:
            # A failing close must not discard the counts gathered above and
            # turn an otherwise completed run into FAILURE.
            try:
                await connector.close()
            except Exception as e:
                logger.warning(f"[CommunityCluster] 关闭 Neo4j 连接失败: {e}")

        logger.info(
            f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
        )
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        # nest_asyncio is optional; when installed it is applied before the
        # coroutine is driven with run_until_complete.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result

    except Exception as e:
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
|
||||||
|
|||||||
Reference in New Issue
Block a user