Merge pull request #698 from SuanmoSuanyangTechnology/fix/batch-clustering

[changes] Set up Celery tasks to perform clustering
This commit is contained in:
Ke Sun
2026-03-26 18:36:23 +08:00
committed by GitHub
4 changed files with 281 additions and 83 deletions

View File

@@ -108,6 +108,9 @@ celery_app.conf.update(
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Clustering tasks → memory_tasks queue (same worker as memory tasks, avoiding the macOS fork issue)
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
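For context, a routed call needs no explicit queue: Celery resolves the destination from the task_routes mapping above. A minimal usage sketch follows, with illustrative IDs that are not taken from this PR:

# Routed to memory_tasks automatically via the task_routes entry above.
from app.tasks import run_incremental_clustering

run_incremental_clustering.apply_async(
    kwargs={"end_user_id": "user-123", "new_entity_ids": ["ent-1", "ent-2"]},  # hypothetical IDs
)
# Passing queue="..." to apply_async would override the routed destination.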

View File

@@ -8,6 +8,7 @@ import asyncio
import time
import uuid
from datetime import datetime
from typing import List, Optional
from dotenv import load_dotenv
@@ -21,7 +22,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -177,28 +178,33 @@ async def write(
if success:
logger.info("Successfully saved all data to Neo4j")
# Sync user aliases to PostgreSQL
try:
# Create a temporary orchestrator instance to invoke the sync method
temp_orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=neo4j_connector,
embedding_id=embedding_model_id
)
await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs)
logger.info("Successfully synced user aliases to PostgreSQL")
except Exception as sync_error:
logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True)
# Does not affect the main flow
# Trigger clustering via an async Celery task (does not block the main flow)
if all_entity_nodes:
try:
from app.tasks import run_incremental_clustering
end_user_id = all_entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in all_entity_nodes]
# Submit the Celery task asynchronously
task = run_incremental_clustering.apply_async(
kwargs={
"end_user_id": end_user_id,
"new_entity_ids": new_entity_ids,
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
},
# Set a low task priority so it does not impact the primary business flow
priority=3,
)
logger.info(
f"[Clustering] 增量聚类任务已提交到 Celery - "
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
)
except Exception as e:
# A failed clustering-task submission must not affect the main flow
logger.error(f"[Clustering] Failed to submit clustering task (main flow unaffected): {e}", exc_info=True)
# After a successful write, wait synchronously for clustering to finish (avoids concurrent conflicts with Memory Summary)
await _trigger_clustering_sync(
all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break
else:
logger.warning("Failed to save some data to Neo4j")

View File

@@ -76,6 +76,9 @@ class LabelPropagationEngine:
self.repo = CommunityRepository(connector)
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# Cache client instances to avoid repeated initialization
self._llm_client = None
self._embedder_client = None
# ──────────────────────────────────────────────────────────────────────────
# Public interface
@@ -215,8 +218,17 @@ class LabelPropagationEngine:
3. If neighbors have no community → create a new community
4. If neighbors belong to multiple communities → evaluate whether to merge
"""
# Collect the IDs of all communities whose metadata needs regeneration
communities_to_update = set()
for entity_id in new_entity_ids:
await self._process_single_entity(entity_id, end_user_id)
cid = await self._process_single_entity(entity_id, end_user_id)
if cid:
communities_to_update.add(cid)
# Generate metadata for all affected communities in one batch
if communities_to_update:
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
# ──────────────────────────────────────────────────────────────────────────
# Internal methods
@@ -224,8 +236,21 @@ class LabelPropagationEngine:
async def _process_single_entity(
self, entity_id: str, end_user_id: str
) -> None:
"""处理单个新实体的社区分配。"""
) -> Optional[str]:
"""
Assign a community to a single new entity.
The function assigns the new entity to a community; the possible cases are:
1. Isolated entity (no neighbors): create a new single-member community
2. No neighbor has a community: create a new community and add both the entity and its neighbors
3. Neighbors have communities: join the most suitable one, chosen by weighted voting
Returns:
Optional[str]: the assigned community ID. The current implementation always
returns a valid community ID, but the return type stays Optional to support
possible future extensions (e.g. an entity that cannot be assigned to any community).
Callers should check the truthiness of the return value.
"""
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
# Fetch the entity's own embedding (not available from the neighbor query results; must be queried separately)
@@ -237,8 +262,7 @@ class LabelPropagationEngine:
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
await self._generate_community_metadata([new_cid], end_user_id)
return
return new_cid
# Tally the community distribution among neighbors
community_ids_in_neighbors = set(
@@ -260,7 +284,7 @@ class LabelPropagationEngine:
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata([new_cid], end_user_id)
return new_cid
else:
# Join the community with the most votes
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -272,8 +296,8 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
# Membership changed after the new entity joined; force metadata regeneration
await self._generate_community_metadata([target_cid], end_user_id, force=True)
# Return the target community ID; metadata is generated later in batch
return target_cid
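The weighted vote referenced in the docstring is not part of this hunk; the sketch below is a hypothetical reconstruction of that step, and the community_id and edge_weight field names are assumptions rather than names from the repository:

from collections import Counter
from typing import Optional

def pick_target_community(neighbors: list) -> Optional[str]:
    """Weighted vote: each neighbor votes for its community, weighted by edge weight."""
    votes: Counter = Counter()
    for n in neighbors:
        cid = n.get("community_id")
        if cid:
            votes[cid] += n.get("edge_weight", 1.0)  # fall back to weight 1.0
    # The community with the highest total weight wins; None if no neighbor has one.
    return votes.most_common(1)[0][0] if votes else None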
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -456,20 +480,19 @@ class LabelPropagationEngine:
self, community_ids: List[str], end_user_id: str, force: bool = False
) -> None:
"""
Generate and write metadata for one or more communities.
Generate and write metadata for one or more communities (optimized: batched LLM calls).
Flow:
1. Call the LLM community by community to generate name / summary (serial)
2. Collect all summaries and embed them in one batch
3. Use update_community_metadata for a single community, batch_update_community_metadata for several
1. Prepare every community's prompt in one batch
2. Call the LLM concurrently to generate each community's name / summary
3. Embed all summaries in one batch
4. Write to the database in one batch
Args:
force: when True, skip the completeness check and force regeneration (used after membership changes from incremental updates)
"""
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
async def _build_one(cid: str) -> Optional[Dict]:
async def _prepare_one(cid: str) -> Optional[Dict]:
"""准备单个社区的数据和 prompt"""
try:
if not force:
check_embedding = bool(self.embedding_model_id)
@@ -489,42 +512,32 @@ class LabelPropagationEngine:
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
# Fallback defaults
name = "".join(core_entities[:3]) if core_entities else cid[:8]
summary = f"包含实体:{', '.join(all_names)}"
# Prepare the LLM prompt (if an LLM is configured)
prompt = None
if self.llm_model_id:
try:
entity_list_str = "\n".join(self._build_entity_lines(members))
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称(不超过10个字\n"
f"2. 写一句话摘要不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
entity_list_str = "\n".join(self._build_entity_lines(members))
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要(不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
return {
"community_id": cid,
@@ -532,14 +545,16 @@ class LabelPropagationEngine:
"name": name,
"summary": summary,
"core_entities": core_entities,
"prompt": prompt,
"summary_embedding": None,
}
except Exception as e:
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
return None
# --- Phase 1: prepare all community data concurrently ---
results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids],
*[_prepare_one(cid) for cid in community_ids],
return_exceptions=True,
)
metadata_list = []
@@ -553,19 +568,77 @@ class LabelPropagationEngine:
logger.warning(f"[Clustering] 无有效元数据可写入community_ids={community_ids}")
return
# --- Phase 2: generate summary embeddings in batch ---
if self.embedding_model_id:
try:
summaries = [m["summary"] for m in metadata_list]
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
except Exception as e:
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
# --- Phase 2: batch LLM calls to generate name and summary ---
if self.llm_model_id:
llm_client = self._get_llm_client()
if not llm_client:
logger.warning(
f"[Clustering] LLM 已配置model_id={self.llm_model_id})但客户端初始化失败,"
f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。"
)
if llm_client:
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
if prompts_to_process:
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
async def _call_llm(idx: int, meta: Dict) -> tuple:
"""单个 LLM 调用"""
try:
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
text = response.content if hasattr(response, "content") else str(response)
return (idx, text, None)
except Exception as e:
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
return (idx, None, e)
# Issue all LLM requests concurrently
llm_results = await asyncio.gather(
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
return_exceptions=True
)
# Parse the LLM responses
for result in llm_results:
if isinstance(result, Exception):
continue
idx, text, error = result
if error or not text:
continue
meta = metadata_list[idx]
for line in text.strip().splitlines():
if line.startswith("名称:"):
meta["name"] = line[3:].strip()
elif line.startswith("摘要:"):
meta["summary"] = line[3:].strip()
logger.info(f"[Clustering] LLM 批量生成完成")
# --- Phase 3: write (single or batch) ---
# --- Phase 3: generate summary embeddings in batch ---
if self.embedding_model_id:
embedder = self._get_embedder_client()
if not embedder:
logger.warning(
f"[Clustering] Embedding 已配置model_id={self.embedding_model_id})但客户端初始化失败,"
f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。"
)
if embedder:
try:
summaries = [m["summary"] for m in metadata_list]
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
logger.info(f"[Clustering] Embedding 批量生成完成")
except Exception as e:
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
# --- Phase 4: batch-write to the database ---
# Drop the prompt field (it is not stored)
for m in metadata_list:
m.pop("prompt", None)
if len(metadata_list) == 1:
m = metadata_list[0]
result = await self.repo.update_community_metadata(
@@ -582,6 +655,28 @@ class LabelPropagationEngine:
ok = await self.repo.batch_update_community_metadata(metadata_list)
if not ok:
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
else:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
def _get_llm_client(self):
"""获取或创建 LLM 客户端(单例模式)"""
if self._llm_client is None and self.llm_model_id:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
return self._llm_client
def _get_embedder_client(self):
"""获取或创建 Embedder 客户端(单例模式)"""
if self._embedder_client is None and self.embedding_model_id:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
return self._embedder_client
@staticmethod
def _new_community_id() -> str:

View File

@@ -2627,6 +2627,100 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
# Community clustering backfill tasks (trigger-based)
# =============================================================================
@celery_app.task(
name="app.tasks.run_incremental_clustering",
bind=True,
ignore_result=False,
max_retries=2,
acks_late=True,
time_limit=1800,  # 30-minute hard timeout
soft_time_limit=1700,
)
def run_incremental_clustering(
self,
end_user_id: str,
new_entity_ids: List[str],
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> Dict[str, Any]:
"""增量聚类任务:处理新增实体的社区分配和元数据生成。
此任务在后台异步执行,不阻塞 write_message 主流程。
Args:
end_user_id: 用户 ID
new_entity_ids: 新增实体 ID 列表
llm_model_id: LLM 模型 ID可选
embedding_model_id: Embedding 模型 ID可选
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
logger = get_logger(__name__)
logger.info(
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
)
connector = Neo4jConnector()
try:
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
# Run the incremental clustering
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
return {
"status": "SUCCESS",
"end_user_id": end_user_id,
"entity_count": len(new_entity_ids),
}
except Exception as e:
logger.error(f"[IncrementalClustering] 增量聚类失败: {e}", exc_info=True)
raise
finally:
await connector.close()
try:
loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
logger.info(
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
f"elapsed_time={result['elapsed_time']:.2f}s"
)
return result
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(
f"[IncrementalClustering] 任务失败 - task_id={self.request.id}, "
f"elapsed_time={elapsed_time:.2f}s, error={str(e)}",
exc_info=True
)
return {
"status": "FAILURE",
"error": str(e),
"end_user_id": end_user_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id,
}
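For testing, the task can be executed eagerly in-process instead of on a worker. A minimal sketch using Celery's standard eager mode, with illustrative IDs; it assumes a reachable Neo4j instance since the task body connects to one:

# Run synchronously in the current process; no broker or worker required.
celery_app.conf.task_always_eager = True

res = run_incremental_clustering.apply_async(
    kwargs={"end_user_id": "user-123", "new_entity_ids": ["ent-1"]},  # hypothetical IDs
)
assert res.get()["status"] == "SUCCESS"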
@celery_app.task(
name="app.tasks.init_community_clustering_for_users",
bind=True,