diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py
index 2ef9bafc..cbd2b532 100644
--- a/api/app/repositories/neo4j/graph_saver.py
+++ b/api/app/repositories/neo4j/graph_saver.py
@@ -293,15 +293,8 @@ async def save_dialog_and_statements_to_neo4j(
         logger.info("Transaction completed. Summary: %s", summary)
         logger.debug("Full transaction results: %r", results)
 
-        # 写入成功后,触发聚类(可通过环境变量 CLUSTERING_ENABLED=false 禁用,用于基准测试对比)
-        clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
-        if entity_nodes and clustering_enabled:
-            end_user_id = entity_nodes[0].end_user_id
-            new_entity_ids = [e.id for e in entity_nodes]
-            logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
-            asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
-        elif entity_nodes and not clustering_enabled:
-            logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
+        # After a successful write, trigger clustering asynchronously (does not block the write response)
+        schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
 
         return True
 
@@ -312,6 +305,50 @@ async def save_dialog_and_statements_to_neo4j(
         return False
 
 
+# Strong references to in-flight clustering tasks; the event loop itself only keeps weak ones.
+_background_clustering_tasks = set()
+
+
+def schedule_clustering_after_write(
+    entity_nodes: List,
+    config_id: Optional[str] = None,
+    llm_model_id: Optional[str] = None,
+) -> None:
+    """
+    Schedule the background clustering task after a successful Neo4j write.
+
+    Does nothing when ``entity_nodes`` is empty, or when the environment
+    variable CLUSTERING_ENABLED is set to "false" (case-insensitive; used for
+    benchmark comparisons). Otherwise the clustering coroutine is started with
+    asyncio.create_task so the write response is not blocked, which means this
+    must be called while an event loop is running.
+
+    Args:
+        entity_nodes: Entities that were just written; their ``id`` values are
+            passed to clustering and the first node's ``end_user_id`` is used
+            for the whole batch.
+        config_id: Forwarded unchanged to ``_trigger_clustering``.
+        llm_model_id: Forwarded unchanged to ``_trigger_clustering``.
+    """
+    if not entity_nodes:
+        return
+
+    clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").lower() != "false"
+    if not clustering_enabled:
+        logger.info("[Clustering] 聚类已禁用(CLUSTERING_ENABLED=false),跳过聚类触发")
+        return
+
+    end_user_id = entity_nodes[0].end_user_id
+    new_entity_ids = [e.id for e in entity_nodes]
+    logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
+    task = asyncio.create_task(
+        _trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)
+    )
+    # Hold the task until it finishes so it cannot be garbage-collected mid-run.
+    _background_clustering_tasks.add(task)
+    task.add_done_callback(_background_clustering_tasks.discard)
+
+
 async def _trigger_clustering(
     new_entity_ids: List[str],
     end_user_id: str,
diff --git a/api/app/tasks.py b/api/app/tasks.py
index f737bb48..8ad2c467 100644
--- a/api/app/tasks.py
+++ b/api/app/tasks.py
@@ -2707,32 +2707,26 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
         try:
             repo = CommunityRepository(connector)
 
-            # 获取 llm_model_id(从第一个用户的配置中读取,作为全局兜底)
-            llm_model_id = None
+            # Prefetch every user's config in one batch (author's note: built-in fallback to the workspace default config when a user's config is unavailable). NOTE(review): workspace_id is no longer passed to load_memory_config, so that fallback must live in the batch helper — confirm.
+            user_llm_map: Dict[str, Optional[str]] = {}
             try:
                 with get_db_context() as db:
-                    from app.services.memory_agent_service import get_end_user_connected_config
+                    from app.services.memory_agent_service import get_end_users_connected_configs_batch
                     from app.services.memory_config_service import MemoryConfigService
-                    for uid in end_user_ids:
-                        try:
-                            connected = get_end_user_connected_config(uid, db)
-                            config_id = connected.get("memory_config_id")
-                            workspace_id = connected.get("workspace_id")
-                            if config_id or workspace_id:
-                                cfg = MemoryConfigService(db).load_memory_config(
-                                    config_id=config_id, workspace_id=workspace_id
-                                )
-                                llm_model_id = str(cfg.llm_model_id)
-                                break
-                        except Exception:
-                            continue
+                    batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
+                    for uid, cfg_info in batch_configs.items():
+                        config_id = cfg_info.get("memory_config_id")
+                        if config_id:
+                            try:
+                                cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
+                                user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
+                            except Exception as e:
+                                logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
+                                user_llm_map[uid] = None
+                        else:
+                            user_llm_map[uid] = None
             except Exception as e:
-                logger.warning(f"[CommunityCluster] 获取 LLM 配置失败,将使用兜底值: {e}")
-
-            engine = LabelPropagationEngine(
-                connector=connector,
-                llm_model_id=llm_model_id,
-            )
+                logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
 
             for end_user_id in end_user_ids:
                 try:
@@ -2750,7 +2744,14 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
                         logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                         continue
 
-                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类")
+                    # Each user is clustered with their own llm_model_id (None when no config was loaded)
+                    llm_model_id = user_llm_map.get(end_user_id)
+                    engine = LabelPropagationEngine(
+                        connector=connector,
+                        llm_model_id=llm_model_id,
+                    )
+
+                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
                     await engine.full_clustering(end_user_id)
                     initialized += 1
                     logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
@@ -2779,15 +2780,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
         except ImportError:
             pass
 
-        try:
-            loop = asyncio.get_event_loop()
-            if loop.is_closed():
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
+        loop = set_asyncio_event_loop()
         result = loop.run_until_complete(_run())
         result["elapsed_time"] = time.time() - start_time
         result["task_id"] = self.request.id