From 2319432182a41f2b7bbfc350581bce88fc8720f6 Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 26 Mar 2026 17:19:37 +0800 Subject: [PATCH] [changes] Set up Celery tasks to perform clustering --- api/app/celery_app.py | 3 + .../core/memory/agent/utils/write_tools.py | 50 +++-- .../clustering_engine/label_propagation.py | 199 ++++++++++++------ api/app/tasks.py | 94 +++++++++ 4 files changed, 263 insertions(+), 83 deletions(-) diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 58c89f8f..23fd82ed 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -108,6 +108,9 @@ celery_app.conf.update( 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题) + 'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 5829a5c9..55bcb8ba 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -8,6 +8,7 @@ import asyncio import time import uuid from datetime import datetime +from typing import List, Optional from dotenv import load_dotenv @@ -21,7 +22,7 @@ from app.core.memory.utils.log.logging_utils import log_time from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync +from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from 
app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import MemoryConfig @@ -177,28 +178,33 @@ async def write( if success: logger.info("Successfully saved all data to Neo4j") - # 同步用户别名到 PostgreSQL - try: - # 创建一个临时的 orchestrator 实例来调用同步方法 - temp_orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=neo4j_connector, - embedding_id=embedding_model_id - ) - await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs) - logger.info("Successfully synced user aliases to PostgreSQL") - except Exception as sync_error: - logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True) - # 不影响主流程 + # 使用 Celery 异步任务触发聚类(不阻塞主流程) + if all_entity_nodes: + try: + from app.tasks import run_incremental_clustering + + end_user_id = all_entity_nodes[0].end_user_id + new_entity_ids = [e.id for e in all_entity_nodes] + + # 异步提交 Celery 任务 + task = run_incremental_clustering.apply_async( + kwargs={ + "end_user_id": end_user_id, + "new_entity_ids": new_entity_ids, + "llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None, + "embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None, + }, + # 设置任务优先级(低优先级,不影响主业务) + priority=3, + ) + logger.info( + f"[Clustering] 增量聚类任务已提交到 Celery - " + f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}" + ) + except Exception as e: + # 聚类任务提交失败不影响主流程 + logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True) - # 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突) - await _trigger_clustering_sync( - all_entity_nodes, - llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None, - embedding_model_id=str( - memory_config.embedding_model_id) if memory_config.embedding_model_id else None, - ) break else: logger.warning("Failed to save some data to Neo4j") diff --git 
a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py index 0fa6a833..d0b121d7 100644 --- a/api/app/core/memory/storage_services/clustering_engine/label_propagation.py +++ b/api/app/core/memory/storage_services/clustering_engine/label_propagation.py @@ -76,6 +76,9 @@ class LabelPropagationEngine: self.repo = CommunityRepository(connector) self.llm_model_id = llm_model_id self.embedding_model_id = embedding_model_id + # 缓存客户端实例,避免重复初始化 + self._llm_client = None + self._embedder_client = None # ────────────────────────────────────────────────────────────────────────── # 公开接口 @@ -215,8 +218,17 @@ class LabelPropagationEngine: 3. 若邻居无社区 → 创建新社区 4. 若邻居分属多个社区 → 评估是否合并 """ + # 收集所有需要生成元数据的社区ID + communities_to_update = set() + for entity_id in new_entity_ids: - await self._process_single_entity(entity_id, end_user_id) + cid = await self._process_single_entity(entity_id, end_user_id) + if cid: + communities_to_update.add(cid) + + # 批量生成所有社区的元数据 + if communities_to_update: + await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True) # ────────────────────────────────────────────────────────────────────────── # 内部方法 @@ -224,8 +236,13 @@ class LabelPropagationEngine: async def _process_single_entity( self, entity_id: str, end_user_id: str - ) -> None: - """处理单个新实体的社区分配。""" + ) -> Optional[str]: + """ + 处理单个新实体的社区分配。 + + Returns: + str: 分配到的社区ID(如果有) + """ neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) # 查询自身 embedding(从邻居查询结果中无法获取,需单独查) @@ -237,8 +254,7 @@ class LabelPropagationEngine: await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") - await self._generate_community_metadata([new_cid], end_user_id) - return + return new_cid # 统计邻居社区分布 
community_ids_in_neighbors = set( @@ -260,7 +276,7 @@ class LabelPropagationEngine: logger.debug( f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" ) - await self._generate_community_metadata([new_cid], end_user_id) + return new_cid else: # 加入得票最多的社区 await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) @@ -272,8 +288,8 @@ class LabelPropagationEngine: await self._evaluate_merge( list(community_ids_in_neighbors), end_user_id ) - # 新实体加入后成员变化,强制重新生成元数据 - await self._generate_community_metadata([target_cid], end_user_id, force=True) + # 返回目标社区ID,稍后批量生成元数据 + return target_cid async def _evaluate_merge( self, community_ids: List[str], end_user_id: str @@ -456,20 +472,19 @@ class LabelPropagationEngine: self, community_ids: List[str], end_user_id: str, force: bool = False ) -> None: """ - 为一个或多个社区生成并写入元数据。 + 为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。 流程: - 1. 逐个社区调 LLM 生成 name / summary(串行) - 2. 收集所有 summary,一次性批量 embed - 3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata + 1. 批量准备所有社区的 prompt + 2. 并发调用 LLM 生成所有社区的 name / summary + 3. 批量 embed 所有 summary + 4. 
批量写入数据库 Args: force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后) """ - from app.db import get_db_context - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - - async def _build_one(cid: str) -> Optional[Dict]: + async def _prepare_one(cid: str) -> Optional[Dict]: + """准备单个社区的数据和 prompt""" try: if not force: check_embedding = bool(self.embedding_model_id) @@ -489,42 +504,32 @@ class LabelPropagationEngine: core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] all_names = [m["name"] for m in members if m.get("name")] + # 默认值 name = "、".join(core_entities[:3]) if core_entities else cid[:8] summary = f"包含实体:{', '.join(all_names)}" + # 准备 LLM prompt(如果配置了 LLM) + prompt = None if self.llm_model_id: - try: - entity_list_str = "\n".join(self._build_entity_lines(members)) - relationships = await self.repo.get_community_relationships(cid, end_user_id) - rel_lines = [ - f"- {r['subject']} → {r['predicate']} → {r['object']}" - for r in relationships - if r.get("subject") and r.get("predicate") and r.get("object") - ] - rel_section = ( - f"\n实体间关系:\n" + "\n".join(rel_lines) - if rel_lines else "" - ) - prompt = ( - f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" - f"请为这组实体所代表的主题:\n" - f"1. 起一个简洁的中文名称(不超过10个字)\n" - f"2. 
写一句话摘要(不超过80个字)\n\n" - f"严格按以下格式输出,不要有其他内容:\n" - f"名称:<名称>\n摘要:<摘要>" - ) - with get_db_context() as db: - llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) - response = await llm_client.chat([{"role": "user", "content": prompt}]) - text = response.content if hasattr(response, "content") else str(response) - - for line in text.strip().splitlines(): - if line.startswith("名称:"): - name = line[3:].strip() - elif line.startswith("摘要:"): - summary = line[3:].strip() - except Exception as e: - logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}") + entity_list_str = "\n".join(self._build_entity_lines(members)) + relationships = await self.repo.get_community_relationships(cid, end_user_id) + rel_lines = [ + f"- {r['subject']} → {r['predicate']} → {r['object']}" + for r in relationships + if r.get("subject") and r.get("predicate") and r.get("object") + ] + rel_section = ( + f"\n实体间关系:\n" + "\n".join(rel_lines) + if rel_lines else "" + ) + prompt = ( + f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" + f"请为这组实体所代表的主题:\n" + f"1. 起一个简洁的中文名称(不超过10个字)\n" + f"2. 
写一句话摘要(不超过80个字)\n\n" + f"严格按以下格式输出,不要有其他内容:\n" + f"名称:<名称>\n摘要:<摘要>" + ) return { "community_id": cid, @@ -532,14 +537,16 @@ class LabelPropagationEngine: "name": name, "summary": summary, "core_entities": core_entities, + "prompt": prompt, "summary_embedding": None, } except Exception as e: logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True) return None + # --- 阶段1:并发准备所有社区数据 --- results = await asyncio.gather( - *[_build_one(cid) for cid in community_ids], + *[_prepare_one(cid) for cid in community_ids], return_exceptions=True, ) metadata_list = [] @@ -553,19 +560,67 @@ class LabelPropagationEngine: logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}") return - # --- 阶段2:批量生成 summary_embedding --- - if self.embedding_model_id: - try: - summaries = [m["summary"] for m in metadata_list] - with get_db_context() as db: - embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id) - embeddings = await embedder.response(summaries) - for i, meta in enumerate(metadata_list): - meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None - except Exception as e: - logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) + # --- 阶段2:批量调用 LLM 生成 name 和 summary --- + if self.llm_model_id: + llm_client = self._get_llm_client() + if llm_client: + prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")] + + if prompts_to_process: + logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据") + + async def _call_llm(idx: int, meta: Dict) -> tuple: + """单个 LLM 调用""" + try: + response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}]) + text = response.content if hasattr(response, "content") else str(response) + return (idx, text, None) + except Exception as e: + logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}") + return (idx, None, e) + + # 并发调用所有 LLM 请求 + llm_results = await asyncio.gather( + *[_call_llm(idx, 
meta) for idx, meta in prompts_to_process], + return_exceptions=True + ) + + # 解析 LLM 响应 + for result in llm_results: + if isinstance(result, Exception): + continue + idx, text, error = result + if error or not text: + continue + + meta = metadata_list[idx] + for line in text.strip().splitlines(): + if line.startswith("名称:"): + meta["name"] = line[3:].strip() + elif line.startswith("摘要:"): + meta["summary"] = line[3:].strip() + + logger.info(f"[Clustering] LLM 批量生成完成") - # --- 阶段3:写入(单个 or 批量)--- + # --- 阶段3:批量生成 summary_embedding --- + if self.embedding_model_id: + embedder = self._get_embedder_client() + if embedder: + try: + summaries = [m["summary"] for m in metadata_list] + logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding") + embeddings = await embedder.response(summaries) + for i, meta in enumerate(metadata_list): + meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None + logger.info(f"[Clustering] Embedding 批量生成完成") + except Exception as e: + logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) + + # --- 阶段4:批量写入数据库 --- + # 移除 prompt 字段(不需要存储) + for m in metadata_list: + m.pop("prompt", None) + if len(metadata_list) == 1: m = metadata_list[0] result = await self.repo.update_community_metadata( @@ -582,6 +637,28 @@ class LabelPropagationEngine: ok = await self.repo.batch_update_community_metadata(metadata_list) if not ok: logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败") + else: + logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功") + + def _get_llm_client(self): + """获取或创建 LLM 客户端(单例模式)""" + if self._llm_client is None and self.llm_model_id: + from app.db import get_db_context + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + with get_db_context() as db: + self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id) + logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}") + return self._llm_client + + def 
_get_embedder_client(self):
+        """获取或创建 Embedder 客户端(单例模式)"""
+        if self._embedder_client is None and self.embedding_model_id:
+            from app.db import get_db_context
+            from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
+            with get_db_context() as db:
+                self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
+            logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
+        return self._embedder_client
 
     @staticmethod
     def _new_community_id() -> str:
diff --git a/api/app/tasks.py b/api/app/tasks.py
index 61736275..d5f09a29 100644
--- a/api/app/tasks.py
+++ b/api/app/tasks.py
@@ -2627,6 +2627,115 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
 # 社区聚类补全任务(触发型)
 # =============================================================================
 
+@celery_app.task(
+    name="app.tasks.run_incremental_clustering",
+    bind=True,
+    ignore_result=False,
+    max_retries=2,
+    acks_late=True,
+    time_limit=1800,  # hard timeout: 30 minutes
+    soft_time_limit=1700,
+)
+def run_incremental_clustering(
+    self,
+    end_user_id: str,
+    new_entity_ids: List[str],
+    llm_model_id: Optional[str] = None,
+    embedding_model_id: Optional[str] = None,
+) -> Dict[str, Any]:
+    """Run incremental community clustering for newly written entities.
+
+    Executed by a Celery worker so the caller that submits the task is not
+    blocked while clustering runs.
+
+    A failed attempt is re-queued through ``self.retry`` until the
+    ``max_retries`` budget declared on the decorator is used up; only the
+    final failed attempt returns the ``FAILURE`` dict.
+
+    Args:
+        end_user_id: ID of the end user whose entities are clustered.
+        new_entity_ids: IDs of the newly added entities.
+        llm_model_id: Optional LLM model ID forwarded to the engine.
+        embedding_model_id: Optional embedding model ID forwarded to the engine.
+
+    Returns:
+        On success: ``status`` ("SUCCESS"), ``end_user_id``, ``entity_count``,
+        ``elapsed_time`` and ``task_id``. After the last failed attempt:
+        ``status`` ("FAILURE"), ``error``, ``end_user_id``, ``elapsed_time``
+        and ``task_id``.
+    """
+    start_time = time.time()
+
+    async def _run() -> Dict[str, Any]:
+        from app.core.logging_config import get_logger
+        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
+        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
+
+        logger = get_logger(__name__)
+        logger.info(
+            f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
+            f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
+        )
+
+        connector = Neo4jConnector()
+        try:
+            engine = LabelPropagationEngine(
+                connector=connector,
+                llm_model_id=llm_model_id,
+                embedding_model_id=embedding_model_id,
+            )
+
+            # Run the incremental clustering pass.
+            await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
+
+            logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
+
+            return {
+                "status": "SUCCESS",
+                "end_user_id": end_user_id,
+                "entity_count": len(new_entity_ids),
+            }
+        except Exception as e:
+            logger.error(f"[IncrementalClustering] 增量聚类失败: {e}", exc_info=True)
+            raise
+        finally:
+            # Release the Neo4j connection on success and on failure.
+            await connector.close()
+
+    try:
+        loop = set_asyncio_event_loop()
+        result = loop.run_until_complete(_run())
+        result["elapsed_time"] = time.time() - start_time
+        result["task_id"] = self.request.id
+
+        logger.info(
+            f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
+            f"elapsed_time={result['elapsed_time']:.2f}s"
+        )
+
+        return result
+    except Exception as e:
+        elapsed_time = time.time() - start_time
+        logger.error(
+            f"[IncrementalClustering] 任务失败 - task_id={self.request.id}, "
+            f"elapsed_time={elapsed_time:.2f}s, error={str(e)}",
+            exc_info=True
+        )
+        # Without an explicit retry the decorator's max_retries never takes
+        # effect: returning a dict makes Celery record the run as successful.
+        # NOTE(review): assumes a re-run over the same entity IDs is acceptable
+        # (acks_late already allows redelivery) -- confirm.
+        if self.request.retries < self.max_retries:
+            raise self.retry(exc=e, countdown=30 * (self.request.retries + 1))
+        return {
+            "status": "FAILURE",
+            "error": str(e),
+            "end_user_id": end_user_id,
+            "elapsed_time": elapsed_time,
+            "task_id": self.request.id,
+        }
+
+
 @celery_app.task(
     name="app.tasks.init_community_clustering_for_users",
     bind=True,