[changes] Set up Celery tasks to perform clustering
This commit is contained in:
@@ -108,6 +108,9 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
|
# Clustering tasks → memory_tasks queue (使用相同的 worker,避免 macOS fork 问题)
|
||||||
|
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Document tasks → document_tasks queue (prefork worker)
|
# Document tasks → document_tasks queue (prefork worker)
|
||||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
@@ -177,28 +178,33 @@ async def write(
|
|||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
|
||||||
# 同步用户别名到 PostgreSQL
|
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
|
||||||
try:
|
if all_entity_nodes:
|
||||||
# 创建一个临时的 orchestrator 实例来调用同步方法
|
try:
|
||||||
temp_orchestrator = ExtractionOrchestrator(
|
from app.tasks import run_incremental_clustering
|
||||||
llm_client=llm_client,
|
|
||||||
embedder_client=embedder_client,
|
end_user_id = all_entity_nodes[0].end_user_id
|
||||||
connector=neo4j_connector,
|
new_entity_ids = [e.id for e in all_entity_nodes]
|
||||||
embedding_id=embedding_model_id
|
|
||||||
)
|
# 异步提交 Celery 任务
|
||||||
await temp_orchestrator._update_end_user_other_name(all_entity_nodes, chunked_dialogs)
|
task = run_incremental_clustering.apply_async(
|
||||||
logger.info("Successfully synced user aliases to PostgreSQL")
|
kwargs={
|
||||||
except Exception as sync_error:
|
"end_user_id": end_user_id,
|
||||||
logger.error(f"Failed to sync user aliases to PostgreSQL: {sync_error}", exc_info=True)
|
"new_entity_ids": new_entity_ids,
|
||||||
# 不影响主流程
|
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
|
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
|
},
|
||||||
|
# 设置任务优先级(低优先级,不影响主业务)
|
||||||
|
priority=3,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[Clustering] 增量聚类任务已提交到 Celery - "
|
||||||
|
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# 聚类任务提交失败不影响主流程
|
||||||
|
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
|
||||||
|
|
||||||
# 写入成功后,同步等待聚类完成(避免与 Memory Summary 并发冲突)
|
|
||||||
await _trigger_clustering_sync(
|
|
||||||
all_entity_nodes,
|
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
|
||||||
embedding_model_id=str(
|
|
||||||
memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ class LabelPropagationEngine:
|
|||||||
self.repo = CommunityRepository(connector)
|
self.repo = CommunityRepository(connector)
|
||||||
self.llm_model_id = llm_model_id
|
self.llm_model_id = llm_model_id
|
||||||
self.embedding_model_id = embedding_model_id
|
self.embedding_model_id = embedding_model_id
|
||||||
|
# 缓存客户端实例,避免重复初始化
|
||||||
|
self._llm_client = None
|
||||||
|
self._embedder_client = None
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -215,8 +218,17 @@ class LabelPropagationEngine:
|
|||||||
3. 若邻居无社区 → 创建新社区
|
3. 若邻居无社区 → 创建新社区
|
||||||
4. 若邻居分属多个社区 → 评估是否合并
|
4. 若邻居分属多个社区 → 评估是否合并
|
||||||
"""
|
"""
|
||||||
|
# 收集所有需要生成元数据的社区ID
|
||||||
|
communities_to_update = set()
|
||||||
|
|
||||||
for entity_id in new_entity_ids:
|
for entity_id in new_entity_ids:
|
||||||
await self._process_single_entity(entity_id, end_user_id)
|
cid = await self._process_single_entity(entity_id, end_user_id)
|
||||||
|
if cid:
|
||||||
|
communities_to_update.add(cid)
|
||||||
|
|
||||||
|
# 批量生成所有社区的元数据
|
||||||
|
if communities_to_update:
|
||||||
|
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 内部方法
|
# 内部方法
|
||||||
@@ -224,8 +236,13 @@ class LabelPropagationEngine:
|
|||||||
|
|
||||||
async def _process_single_entity(
|
async def _process_single_entity(
|
||||||
self, entity_id: str, end_user_id: str
|
self, entity_id: str, end_user_id: str
|
||||||
) -> None:
|
) -> Optional[str]:
|
||||||
"""处理单个新实体的社区分配。"""
|
"""
|
||||||
|
处理单个新实体的社区分配。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 分配到的社区ID(如果有)
|
||||||
|
"""
|
||||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||||
|
|
||||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||||
@@ -237,8 +254,7 @@ class LabelPropagationEngine:
|
|||||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||||
await self._generate_community_metadata([new_cid], end_user_id)
|
return new_cid
|
||||||
return
|
|
||||||
|
|
||||||
# 统计邻居社区分布
|
# 统计邻居社区分布
|
||||||
community_ids_in_neighbors = set(
|
community_ids_in_neighbors = set(
|
||||||
@@ -260,7 +276,7 @@ class LabelPropagationEngine:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata([new_cid], end_user_id)
|
return new_cid
|
||||||
else:
|
else:
|
||||||
# 加入得票最多的社区
|
# 加入得票最多的社区
|
||||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||||
@@ -272,8 +288,8 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
# 新实体加入后成员变化,强制重新生成元数据
|
# 返回目标社区ID,稍后批量生成元数据
|
||||||
await self._generate_community_metadata([target_cid], end_user_id, force=True)
|
return target_cid
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -456,20 +472,19 @@ class LabelPropagationEngine:
|
|||||||
self, community_ids: List[str], end_user_id: str, force: bool = False
|
self, community_ids: List[str], end_user_id: str, force: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为一个或多个社区生成并写入元数据。
|
为一个或多个社区生成并写入元数据(优化版:批量 LLM 调用)。
|
||||||
|
|
||||||
流程:
|
流程:
|
||||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
1. 批量准备所有社区的 prompt
|
||||||
2. 收集所有 summary,一次性批量 embed
|
2. 并发调用 LLM 生成所有社区的 name / summary
|
||||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
3. 批量 embed 所有 summary
|
||||||
|
4. 批量写入数据库
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
|
||||||
"""
|
"""
|
||||||
from app.db import get_db_context
|
async def _prepare_one(cid: str) -> Optional[Dict]:
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
"""准备单个社区的数据和 prompt"""
|
||||||
|
|
||||||
async def _build_one(cid: str) -> Optional[Dict]:
|
|
||||||
try:
|
try:
|
||||||
if not force:
|
if not force:
|
||||||
check_embedding = bool(self.embedding_model_id)
|
check_embedding = bool(self.embedding_model_id)
|
||||||
@@ -489,42 +504,32 @@ class LabelPropagationEngine:
|
|||||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||||
all_names = [m["name"] for m in members if m.get("name")]
|
all_names = [m["name"] for m in members if m.get("name")]
|
||||||
|
|
||||||
|
# 默认值
|
||||||
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
name = "、".join(core_entities[:3]) if core_entities else cid[:8]
|
||||||
summary = f"包含实体:{', '.join(all_names)}"
|
summary = f"包含实体:{', '.join(all_names)}"
|
||||||
|
|
||||||
|
# 准备 LLM prompt(如果配置了 LLM)
|
||||||
|
prompt = None
|
||||||
if self.llm_model_id:
|
if self.llm_model_id:
|
||||||
try:
|
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
rel_lines = [
|
||||||
rel_lines = [
|
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
for r in relationships
|
||||||
for r in relationships
|
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
]
|
||||||
]
|
rel_section = (
|
||||||
rel_section = (
|
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
if rel_lines else ""
|
||||||
if rel_lines else ""
|
)
|
||||||
)
|
prompt = (
|
||||||
prompt = (
|
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
f"请为这组实体所代表的主题:\n"
|
||||||
f"请为这组实体所代表的主题:\n"
|
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
f"严格按以下格式输出,不要有其他内容:\n"
|
||||||
f"严格按以下格式输出,不要有其他内容:\n"
|
f"名称:<名称>\n摘要:<摘要>"
|
||||||
f"名称:<名称>\n摘要:<摘要>"
|
)
|
||||||
)
|
|
||||||
with get_db_context() as db:
|
|
||||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
|
||||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
|
|
||||||
for line in text.strip().splitlines():
|
|
||||||
if line.startswith("名称:"):
|
|
||||||
name = line[3:].strip()
|
|
||||||
elif line.startswith("摘要:"):
|
|
||||||
summary = line[3:].strip()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"community_id": cid,
|
"community_id": cid,
|
||||||
@@ -532,14 +537,16 @@ class LabelPropagationEngine:
|
|||||||
"name": name,
|
"name": name,
|
||||||
"summary": summary,
|
"summary": summary,
|
||||||
"core_entities": core_entities,
|
"core_entities": core_entities,
|
||||||
|
"prompt": prompt,
|
||||||
"summary_embedding": None,
|
"summary_embedding": None,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# --- 阶段1:并发准备所有社区数据 ---
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
*[_build_one(cid) for cid in community_ids],
|
*[_prepare_one(cid) for cid in community_ids],
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
metadata_list = []
|
metadata_list = []
|
||||||
@@ -553,19 +560,67 @@ class LabelPropagationEngine:
|
|||||||
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
logger.warning(f"[Clustering] 无有效元数据可写入,community_ids={community_ids}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# --- 阶段2:批量生成 summary_embedding ---
|
# --- 阶段2:批量调用 LLM 生成 name 和 summary ---
|
||||||
if self.embedding_model_id:
|
if self.llm_model_id:
|
||||||
try:
|
llm_client = self._get_llm_client()
|
||||||
summaries = [m["summary"] for m in metadata_list]
|
if llm_client:
|
||||||
with get_db_context() as db:
|
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
|
||||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
|
||||||
embeddings = await embedder.response(summaries)
|
if prompts_to_process:
|
||||||
for i, meta in enumerate(metadata_list):
|
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
|
||||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
|
||||||
except Exception as e:
|
async def _call_llm(idx: int, meta: Dict) -> tuple:
|
||||||
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
"""单个 LLM 调用"""
|
||||||
|
try:
|
||||||
|
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
return (idx, text, None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
|
||||||
|
return (idx, None, e)
|
||||||
|
|
||||||
|
# 并发调用所有 LLM 请求
|
||||||
|
llm_results = await asyncio.gather(
|
||||||
|
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
|
||||||
|
return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析 LLM 响应
|
||||||
|
for result in llm_results:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
continue
|
||||||
|
idx, text, error = result
|
||||||
|
if error or not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
meta = metadata_list[idx]
|
||||||
|
for line in text.strip().splitlines():
|
||||||
|
if line.startswith("名称:"):
|
||||||
|
meta["name"] = line[3:].strip()
|
||||||
|
elif line.startswith("摘要:"):
|
||||||
|
meta["summary"] = line[3:].strip()
|
||||||
|
|
||||||
|
logger.info(f"[Clustering] LLM 批量生成完成")
|
||||||
|
|
||||||
# --- 阶段3:写入(单个 or 批量)---
|
# --- 阶段3:批量生成 summary_embedding ---
|
||||||
|
if self.embedding_model_id:
|
||||||
|
embedder = self._get_embedder_client()
|
||||||
|
if embedder:
|
||||||
|
try:
|
||||||
|
summaries = [m["summary"] for m in metadata_list]
|
||||||
|
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
|
||||||
|
embeddings = await embedder.response(summaries)
|
||||||
|
for i, meta in enumerate(metadata_list):
|
||||||
|
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||||
|
logger.info(f"[Clustering] Embedding 批量生成完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# --- 阶段4:批量写入数据库 ---
|
||||||
|
# 移除 prompt 字段(不需要存储)
|
||||||
|
for m in metadata_list:
|
||||||
|
m.pop("prompt", None)
|
||||||
|
|
||||||
if len(metadata_list) == 1:
|
if len(metadata_list) == 1:
|
||||||
m = metadata_list[0]
|
m = metadata_list[0]
|
||||||
result = await self.repo.update_community_metadata(
|
result = await self.repo.update_community_metadata(
|
||||||
@@ -582,6 +637,28 @@ class LabelPropagationEngine:
|
|||||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||||
if not ok:
|
if not ok:
|
||||||
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
|
||||||
|
else:
|
||||||
|
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||||
|
|
||||||
|
def _get_llm_client(self):
|
||||||
|
"""获取或创建 LLM 客户端(单例模式)"""
|
||||||
|
if self._llm_client is None and self.llm_model_id:
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
with get_db_context() as db:
|
||||||
|
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||||
|
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
|
||||||
|
return self._llm_client
|
||||||
|
|
||||||
|
def _get_embedder_client(self):
|
||||||
|
"""获取或创建 Embedder 客户端(单例模式)"""
|
||||||
|
if self._embedder_client is None and self.embedding_model_id:
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
with get_db_context() as db:
|
||||||
|
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||||
|
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
|
||||||
|
return self._embedder_client
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
|
|||||||
@@ -2627,6 +2627,100 @@ def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[
|
|||||||
# 社区聚类补全任务(触发型)
|
# 社区聚类补全任务(触发型)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
@celery_app.task(
|
||||||
|
name="app.tasks.run_incremental_clustering",
|
||||||
|
bind=True,
|
||||||
|
ignore_result=False,
|
||||||
|
max_retries=2,
|
||||||
|
acks_late=True,
|
||||||
|
time_limit=1800, # 30分钟硬超时
|
||||||
|
soft_time_limit=1700,
|
||||||
|
)
|
||||||
|
def run_incremental_clustering(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
new_entity_ids: List[str],
|
||||||
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""增量聚类任务:处理新增实体的社区分配和元数据生成。
|
||||||
|
|
||||||
|
此任务在后台异步执行,不阻塞 write_message 主流程。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户 ID
|
||||||
|
new_entity_ids: 新增实体 ID 列表
|
||||||
|
llm_model_id: LLM 模型 ID(可选)
|
||||||
|
embedding_model_id: Embedding 模型 ID(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含任务执行结果的字典
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
async def _run() -> Dict[str, Any]:
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
logger.info(
|
||||||
|
f"[IncrementalClustering] 开始增量聚类任务 - end_user_id={end_user_id}, "
|
||||||
|
f"实体数={len(new_entity_ids)}, llm_model_id={llm_model_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
engine = LabelPropagationEngine(
|
||||||
|
connector=connector,
|
||||||
|
llm_model_id=llm_model_id,
|
||||||
|
embedding_model_id=embedding_model_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 执行增量聚类
|
||||||
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
|
|
||||||
|
logger.info(f"[IncrementalClustering] 增量聚类完成 - end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "SUCCESS",
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"entity_count": len(new_entity_ids),
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[IncrementalClustering] 增量聚类失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
loop = set_asyncio_event_loop()
|
||||||
|
result = loop.run_until_complete(_run())
|
||||||
|
result["elapsed_time"] = time.time() - start_time
|
||||||
|
result["task_id"] = self.request.id
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[IncrementalClustering] 任务完成 - task_id={self.request.id}, "
|
||||||
|
f"elapsed_time={result['elapsed_time']:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
elapsed_time = time.time() - start_time
|
||||||
|
logger.error(
|
||||||
|
f"[IncrementalClustering] 任务失败 - task_id={self.request.id}, "
|
||||||
|
f"elapsed_time={elapsed_time:.2f}s, error={str(e)}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"status": "FAILURE",
|
||||||
|
"error": str(e),
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"elapsed_time": elapsed_time,
|
||||||
|
"task_id": self.request.id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(
|
@celery_app.task(
|
||||||
name="app.tasks.init_community_clustering_for_users",
|
name="app.tasks.init_community_clustering_for_users",
|
||||||
bind=True,
|
bind=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user