diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 60c22855..e77ed683 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -116,6 +116,7 @@ celery_app.conf.update( 'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'}, 'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'}, 'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'}, + 'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'}, }, ) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 3bbb5cf7..ce32a519 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -193,7 +193,19 @@ async def get_workspace_end_users( await aio_redis_set(cache_key, json.dumps(result), expire=30) except Exception as e: api_logger.warning(f"Redis 缓存写入失败: {str(e)}") - + + # 触发社区聚类补全任务(异步,不阻塞接口响应) + # 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类 + try: + from app.tasks import init_community_clustering_for_users + init_community_clustering_for_users.apply_async( + kwargs={"end_user_ids": end_user_ids}, + queue="periodic_tasks", + ) + api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}") + except Exception as e: + api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") + api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") return success(data=result, msg="宿主列表获取成功") diff --git a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 251d4fea..4491b416 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -165,8 +165,15 @@ class LabelPropagationEngine: f"{len(labels)} 个实体" ) # 为所有社区生成元数据 - unique_communities = list(set(labels.values())) - for cid in unique_communities: + # 
注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区 + # 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID + surviving_communities = await self.repo.get_all_entities(end_user_id) + surviving_community_ids = list({ + e.get("community_id") for e in surviving_communities + if e.get("community_id") + }) + logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") + for cid in surviving_community_ids: await self._generate_community_metadata(cid, end_user_id) async def incremental_update( @@ -249,7 +256,7 @@ class LabelPropagationEngine: 全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。 """ - MERGE_THRESHOLD = 0.75 + MERGE_THRESHOLD = 0.85 BATCH_THRESHOLD = 20 # 超过此数量走批量查询 community_embeddings: Dict[str, Optional[List[float]]] = {} @@ -305,34 +312,65 @@ class LabelPropagationEngine: logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区") - # 执行合并:用 union-find 思路避免重复迁移已被合并的社区 - # 维护一个 canonical 映射,确保链式合并正确收敛 - canonical: Dict[str, str] = {cid: cid for cid in cids} + # 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量 + # 避免 union-find 链式传递导致语义不相关的社区被间接合并 + # (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并) + merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射 - def find(x: str) -> str: - while canonical[x] != x: - canonical[x] = canonical[canonical[x]] - x = canonical[x] + def get_root(x: str) -> str: + """路径压缩,找到 x 当前所属的根社区。""" + while x in merged_into: + merged_into[x] = merged_into.get(merged_into[x], merged_into[x]) + x = merged_into[x] return x for c1, c2 in to_merge: - root1, root2 = find(c1), find(c2) + root1, root2 = get_root(c1), get_root(c2) if root1 == root2: - continue # 已经在同一社区,跳过 + continue + + # 用合并后的最新平均向量重新验证相似度 + # 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并 + current_sim = _cosine_similarity( + community_embeddings.get(root1), + community_embeddings.get(root2), + ) + if current_sim <= MERGE_THRESHOLD: + # 合并后向量已漂移,不再满足阈值,跳过 + logger.debug( + f"[Clustering] 跳过合并 {root1} ↔ {root2}," + f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}" + ) + continue + keep = root1 if community_sizes.get(root1, 0) >= 
# =============================================================================
# Community clustering backfill task (triggered on demand)
# =============================================================================

@celery_app.task(
    name="app.tasks.init_community_clustering_for_users",
    bind=True,
    ignore_result=False,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # hard limit: 2 hours
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Run full clustering for users that have entities but no communities.

    Triggered by the /dashboard/end_users endpoint. For every candidate user
    the task first asks the community repository whether community nodes
    already exist; such users are skipped, as are users without any entity.
    All remaining users get a full clustering run.

    Args:
        end_user_ids: IDs of the users to inspect. Empty values and duplicates
            are ignored; ``None`` or an empty list makes the task a no-op.

    Returns:
        A dict with ``status`` ("SUCCESS" or "FAILURE"), ``elapsed_time`` and
        ``task_id``. On success it also carries the ``initialized``,
        ``skipped`` and ``failed`` counters; on failure it carries ``error``.
    """
    start_time = time.time()

    # Drop empty IDs and duplicates while keeping the caller's order, so the
    # same user is never inspected twice within one run.
    candidate_ids: List[str] = list(
        dict.fromkeys(uid for uid in (end_user_ids or []) if uid)
    )
    if not candidate_ids:
        # Nothing to do: avoid opening a Neo4j connection at all.
        return {
            "status": "SUCCESS",
            "initialized": 0,
            "skipped": 0,
            "failed": 0,
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }

    async def _run() -> Dict[str, Any]:
        from celery.exceptions import SoftTimeLimitExceeded
        from app.core.logging_config import get_logger
        from app.repositories.neo4j.community_repository import CommunityRepository
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine

        logger = get_logger(__name__)
        logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(candidate_ids)}")

        initialized = 0
        skipped = 0
        failed = 0

        connector = Neo4jConnector()
        try:
            repo = CommunityRepository(connector)

            # Resolve llm_model_id from the first user whose memory config can
            # be loaded; that single value is then used for every user.
            llm_model_id = None
            try:
                with get_db_context() as db:
                    from app.services.memory_agent_service import get_end_user_connected_config
                    from app.services.memory_config_service import MemoryConfigService
                    for uid in candidate_ids:
                        try:
                            connected = get_end_user_connected_config(uid, db)
                            config_id = connected.get("memory_config_id")
                            workspace_id = connected.get("workspace_id")
                            if not (config_id or workspace_id):
                                continue
                            cfg = MemoryConfigService(db).load_memory_config(
                                config_id=config_id, workspace_id=workspace_id
                            )
                            model_id = cfg.llm_model_id
                            if model_id:
                                # Guard against str(None) == "None" being
                                # passed on as if it were a real model id.
                                llm_model_id = str(model_id)
                                break
                        except SoftTimeLimitExceeded:
                            raise
                        except Exception:
                            # Best effort: try the next user's configuration.
                            continue
            except SoftTimeLimitExceeded:
                raise
            except Exception as e:
                logger.warning(f"[CommunityCluster] 获取 LLM 配置失败,将使用兜底值: {e}")

            engine = LabelPropagationEngine(
                connector=connector,
                llm_model_id=llm_model_id,
            )

            for end_user_id in candidate_ids:
                try:
                    # Users that already own community nodes are skipped.
                    if await repo.has_communities(end_user_id):
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                        continue

                    # Without entities there is nothing to cluster.
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                        continue

                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类")
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")

                except SoftTimeLimitExceeded:
                    # The soft time limit is an Exception subclass; it must
                    # abort the whole task instead of being counted as one
                    # failed user, otherwise the loop keeps running until the
                    # hard limit kills the worker process.
                    raise
                except Exception as e:
                    # One failing user must not stop the remaining users.
                    failed += 1
                    logger.error(
                        f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}",
                        exc_info=True,
                    )

        finally:
            await connector.close()

        logger.info(
            f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
        )
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        # nest_asyncio is optional; when installed it allows running the
        # coroutine even if an event loop is already active.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        # Reuse the current event loop when it is usable, otherwise create one.
        try:
            loop = asyncio.get_event_loop()
            if loop.is_closed():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result

    except Exception as e:
        # Report the failure in the result payload instead of raising.
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }