[changes] Set up Celery tasks to perform clustering
This commit is contained in:
@@ -8,6 +8,7 @@ import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -21,7 +22,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
@@ -177,28 +178,33 @@ async def write(
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
|
||||
# 同步用户别名到 PostgreSQL
|
||||
try:
|
||||
# 创建一个临时的 orchestrator 实例来调用同步方法
|
||||
temp_orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
embedding_id=embedding_model_id
|
||||
)
|
||||
await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs)
|
||||
logger.info("Successfully synced user aliases to PostgreSQL")
|
||||
except Exception as sync_error:
|
||||
logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True)
|
||||
# 不影响主流程
|
||||
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||
if all_entity_nodes:
|
||||
try:
|
||||
from app.tasks import run_incremental_clustering
|
||||
|
||||
end_user_id = all_entity_nodes[0].end_user_id
|
||||
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||
|
||||
# 异步提交 Celery 任务
|
||||
task = run_incremental_clustering.apply_async(
|
||||
kwargs={
|
||||
"end_user_id": end_user_id,
|
||||
"new_entity_ids": new_entity_ids,
|
||||
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
},
|
||||
# 设置任务优先级(低优先级,不影响主业务)
|
||||
priority=3,
|
||||
)
|
||||
logger.info(
|
||||
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||
)
|
||||
except Exception as e:
|
||||
# 聚类任务提交失败不影响主流程
|
||||
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||
|
||||
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
||||
await _trigger_clustering_sync(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(
|
||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
|
||||
@@ -76,6 +76,9 @@ class LabelPropagationEngine:
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
# 缓存客户端实例,避免重复初始化
|
||||
self._llm_client = None
|
||||
self._embedder_client = None
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -215,8 +218,17 @@ class LabelPropagationEngine:
|
||||
3. 若邻居无社区 → 创建新社区
|
||||
4. 若邻居分属多个社区 → 评估是否合并
|
||||
"""
|
||||
# 收集所有需要生成元数据的社区ID
|
||||
communities_to_update = set()
|
||||
|
||||
for entity_id in new_entity_ids:
|
||||
await self._process_single_entity(entity_id, end_user_id)
|
||||
cid = await self._process_single_entity(entity_id, end_user_id)
|
||||
if cid:
|
||||
communities_to_update.add(cid)
|
||||
|
||||
# 批量生成所有社区的元数据
|
||||
if communities_to_update:
|
||||
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 内部方法
|
||||
@@ -224,8 +236,13 @@ class LabelPropagationEngine:
|
||||
|
||||
async def _process_single_entity(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> None:
|
||||
"""处理单个新实体的社区分配。"""
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
处理单个新实体的社区分配。
|
||||
|
||||
Returns:
|
||||
str: 分配到的社区ID(如果有)
|
||||
"""
|
||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||
|
||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||
@@ -237,8 +254,7 @@ class LabelPropagationEngine:
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
return
|
||||
return new_cid
|
||||
|
||||
# 统计邻居社区分布
|
||||
community_ids_in_neighbors = set(
|
||||
@@ -260,7 +276,7 @@ class LabelPropagationEngine:
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
return new_cid
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
@@ -272,8 +288,8 @@ class LabelPropagationEngine:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
# 新实体加入后成员变化,强制重新生成元数据
|
||||
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
||||
# 返回目标社区ID,稍后批量生成元数据
|
||||
return target_cid
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
@@ -456,20 +472,19 @@ class LabelPropagationEngine:
|
||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据。
|
||||
为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。
|
||||
|
||||
流程:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
1. 批量准备所有社区的 prompt
|
||||
2. 并发调用 LLM 生成所有社区的 name / summary
|
||||
3. 批量 embed 所有 summary
|
||||
4. 批量写入数据库
|
||||
|
||||
Args:
|
||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||
"""
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
async def _build_one(cid: str) -> Optional[Dict]:
|
||||
async def _prepare_one(cid: str) -> Optional[Dict]:
|
||||
"""准备单个社区的数据和 prompt"""
|
||||
try:
|
||||
if not force:
|
||||
check_embedding = bool(self.embedding_model_id)
|
||||
@@ -489,42 +504,32 @@ class LabelPropagationEngine:
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
all_names = [m["name"] for m in members if m.get("name")]
|
||||
|
||||
# 默认值
|
||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||
summary = f"包含实体:{', '.join(all_names)}"
|
||||
|
||||
# 准备 LLM prompt(如果配置了 LLM)
|
||||
prompt = None
|
||||
if self.llm_model_id:
|
||||
try:
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
@@ -532,14 +537,16 @@ class LabelPropagationEngine:
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"prompt": prompt,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
# --- 阶段1:并发准备所有社区数据 ---
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
*[_prepare_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
metadata_list = []
|
||||
@@ -553,19 +560,67 @@ class LabelPropagationEngine:
|
||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
if self.embedding_model_id:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
# --- 阶段2:批量调用 LLM 生成 name 和 summary ---
|
||||
if self.llm_model_id:
|
||||
llm_client = self._get_llm_client()
|
||||
if llm_client:
|
||||
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
|
||||
|
||||
if prompts_to_process:
|
||||
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
|
||||
|
||||
async def _call_llm(idx: int, meta: Dict) -> tuple:
|
||||
"""单个 LLM 调用"""
|
||||
try:
|
||||
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
return (idx, text, None)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
|
||||
return (idx, None, e)
|
||||
|
||||
# 并发调用所有 LLM 请求
|
||||
llm_results = await asyncio.gather(
|
||||
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# 解析 LLM 响应
|
||||
for result in llm_results:
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
idx, text, error = result
|
||||
if error or not text:
|
||||
continue
|
||||
|
||||
meta = metadata_list[idx]
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
meta["name"] = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
meta["summary"] = line[3:].strip()
|
||||
|
||||
logger.info(f"[Clustering] LLM 批量生成完成")
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
# --- 阶段3:批量生成 summary_embedding ---
|
||||
if self.embedding_model_id:
|
||||
embedder = self._get_embedder_client()
|
||||
if embedder:
|
||||
try:
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
logger.info(f"[Clustering] Embedding 批量生成完成")
|
||||
except Exception as e:
|
||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||
|
||||
# --- 阶段4:批量写入数据库 ---
|
||||
# 移除 prompt 字段(不需要存储)
|
||||
for m in metadata_list:
|
||||
m.pop("prompt", None)
|
||||
|
||||
if len(metadata_list) == 1:
|
||||
m = metadata_list[0]
|
||||
result = await self.repo.update_community_metadata(
|
||||
@@ -582,6 +637,28 @@ class LabelPropagationEngine:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if not ok:
|
||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||
else:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
|
||||
def _get_llm_client(self):
    """Return the LLM client cached on this engine, creating it on first use.

    The client is built from ``self.llm_model_id`` the first time it is
    requested and stored in ``self._llm_client``; later calls return the
    stored object. The cache is per engine instance, not process-wide.

    Returns:
        The cached client, or ``None`` when ``llm_model_id`` is not set or
        the client could not be created. ``None`` means "LLM unavailable";
        the metadata generation code checks the result for truthiness
        before using it.
    """
    if self._llm_client is None and self.llm_model_id:
        try:
            # Function-scope imports, same as the other methods of this class.
            from app.db import get_db_context
            from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

            # NOTE(review): the client outlives the db context it was built
            # from -- confirm it does not keep using that session afterwards.
            with get_db_context() as db:
                self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
        except Exception as e:
            # A failed client lookup must not abort the caller: report it and
            # signal "unavailable". The cache stays empty, so the next call
            # tries again.
            logger.error(
                f"[Clustering] LLM 客户端初始化失败: model_id={self.llm_model_id}, error={e}",
                exc_info=True,
            )
            return None
        logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
    return self._llm_client
|
||||
|
||||
def _get_embedder_client(self):
    """Return the embedder client cached on this engine, creating it on first use.

    The client is built from ``self.embedding_model_id`` the first time it is
    requested and stored in ``self._embedder_client``; later calls return the
    stored object. The cache is per engine instance, not process-wide.

    Returns:
        The cached client, or ``None`` when ``embedding_model_id`` is not set
        or the client could not be created. ``None`` means "embedder
        unavailable"; the metadata generation code checks the result for
        truthiness before using it.
    """
    if self._embedder_client is None and self.embedding_model_id:
        try:
            # Function-scope imports, same as the other methods of this class.
            from app.db import get_db_context
            from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

            # NOTE(review): the client outlives the db context it was built
            # from -- confirm it does not keep using that session afterwards.
            with get_db_context() as db:
                self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
        except Exception as e:
            # A failed client lookup must not abort the caller: report it and
            # signal "unavailable". The cache stays empty, so the next call
            # tries again.
            logger.error(
                f"[Clustering] Embedder 客户端初始化失败: model_id={self.embedding_model_id}, error={e}",
                exc_info=True,
            )
            return None
        logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
    return self._embedder_client
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
|
||||
Reference in New Issue
Block a user