[add] Selective merge submission

This commit is contained in:
lanceyq
2026-03-13 15:47:53 +08:00
parent 967139cea4
commit 668539e737
4 changed files with 206 additions and 17 deletions

View File

@@ -116,6 +116,7 @@ celery_app.conf.update(
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
},
)
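The new route only takes effect if a worker actually consumes the periodic_tasks queue. A minimal sanity check, assuming the dict shown above is assigned to Celery's standard task_routes setting and that the app's Celery instance is importable as celery_app from app.tasks (worker name and concurrency below are illustrative):

# Start a worker bound to the periodic_tasks queue:
#   celery -A app.tasks worker -Q periodic_tasks -n periodic@%h --concurrency=2
# Verify the route without touching the broker:
from app.tasks import celery_app  # assumed import path, matching the task names above
route = celery_app.conf.task_routes["app.tasks.init_community_clustering_for_users"]
print(route)  # expected: {'queue': 'periodic_tasks'}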

View File

@@ -193,7 +193,19 @@ async def get_workspace_end_users(
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应)
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
try:
from app.tasks import init_community_clustering_for_users
init_community_clustering_for_users.apply_async(
kwargs={"end_user_ids": end_user_ids},
queue="periodic_tasks",
)
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
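The dispatch above is deliberately fire-and-forget: any enqueue failure is logged and swallowed so the dashboard response is never blocked. If the task id were needed for later inspection, the apply_async return value could be captured; a minimal sketch, assuming the standard Celery AsyncResult API (variable names are illustrative):

# Hypothetical variant that keeps a handle on the queued task:
async_result = init_community_clustering_for_users.apply_async(
    kwargs={"end_user_ids": end_user_ids},
    queue="periodic_tasks",
)
api_logger.info(f"Backfill task queued, task_id={async_result.id}")
# Later, e.g. from a debug endpoint:
# celery_app.AsyncResult(async_result.id).state  # PENDING / STARTED / SUCCESS / FAILURE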

View File

@@ -165,8 +165,15 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
unique_communities = list(set(labels.values()))
for cid in unique_communities:
        # Note: after _evaluate_merge some communities have been merged and dissolved; re-query Neo4j for the communities that actually survive
        # Do not reuse labels.values(); it still contains the IDs of already-dissolved communities
surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({
e.get("community_id") for e in surviving_communities
if e.get("community_id")
})
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update(
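Why the re-query instead of labels.values(): after _evaluate_merge, the propagation result still carries community IDs that were dissolved into their merge targets. A toy illustration (the IDs and dict shapes are hypothetical, not the engine's actual structures):

# Label propagation assigned three communities; _evaluate_merge then dissolved c2 into c1.
labels = {"entity_a": "c1", "entity_b": "c2", "entity_c": "c3"}
merged_into = {"c2": "c1"}

stale_ids = set(labels.values())                        # {'c1', 'c2', 'c3'}: c2 no longer exists
alive_ids = {merged_into.get(c, c) for c in stale_ids}  # {'c1', 'c3'}
assert "c2" not in alive_ids  # metadata generation must only see surviving communities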
@@ -249,7 +256,7 @@ class LabelPropagationEngine:
        For the full-clustering case (community count > 20), use a batch query to avoid N database round trips.
        """
MERGE_THRESHOLD = 0.75
MERGE_THRESHOLD = 0.85
        BATCH_THRESHOLD = 20  # above this count, use the batch query
community_embeddings: Dict[str, Optional[List[float]]] = {}
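The merge check below leans on _cosine_similarity over optional embedding lists. The diff does not show its implementation; a minimal sketch of such a helper, assuming plain Python float lists and a None-safe contract that returns 0.0 on degenerate input:

import math
from typing import List, Optional

def _cosine_similarity(a: Optional[List[float]], b: Optional[List[float]]) -> float:
    """Cosine similarity of two vectors; 0.0 for missing, mismatched, or zero-norm inputs."""
    if not a or not b or len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)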
@@ -305,34 +312,65 @@ class LabelPropagationEngine:
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
        # Execute merges: union-find style, to avoid re-migrating communities that were already merged
        # Maintain a canonical mapping so that chained merges converge correctly
canonical: Dict[str, str] = {cid: cid for cid in cids}
        # Execute merges pair by pair, recomputing the merged community's mean embedding after each step
        # This avoids union-find chaining that would indirectly merge semantically unrelated communities
        # (A≈B and B≈C do not imply A≈C; transitivity must not collapse A/B/C into one community)
        merged_into: Dict[str, str] = {}  # final dissolve → keep mapping
        def find(x: str) -> str:
            while canonical[x] != x:
                canonical[x] = canonical[canonical[x]]
                x = canonical[x]
            return x
def get_root(x: str) -> str:
"""路径压缩,找到 x 当前所属的根社区。"""
while x in merged_into:
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
x = merged_into[x]
return x
for c1, c2 in to_merge:
root1, root2 = find(c1), find(c2)
root1, root2 = get_root(c1), get_root(c2)
if root1 == root2:
                continue  # already in the same community, skip
continue
            # Re-validate similarity using the latest post-merge mean embeddings
            # Guards against chaining: after A≈B merges, B's embedding has shifted, so C must be similar to the new merged embedding to join
current_sim = _cosine_similarity(
community_embeddings.get(root1),
community_embeddings.get(root2),
)
if current_sim <= MERGE_THRESHOLD:
                # the merged embedding has drifted and no longer meets the threshold; skip
logger.debug(
f"[Clustering] 跳过合并 {root1}{root2}"
f"当前相似度 {current_sim:.3f}{MERGE_THRESHOLD}"
)
continue
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
dissolve = root2 if keep == root1 else root1
canonical[dissolve] = keep
merged_into[dissolve] = keep
members = await self.repo.get_community_members(dissolve, end_user_id)
for m in members:
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
            # update sizes so that subsequent merge decisions stay accurate
community_sizes[keep] = community_sizes.get(keep, 0) + len(members)
            # Recompute keep's mean embedding after the merge (size-weighted average)
keep_emb = community_embeddings.get(keep)
dissolve_emb = community_embeddings.get(dissolve)
keep_size = community_sizes.get(keep, 0)
dissolve_size = community_sizes.get(dissolve, 0)
total_size = keep_size + dissolve_size
if keep_emb and dissolve_emb and total_size > 0:
dim = len(keep_emb)
community_embeddings[keep] = [
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
for i in range(dim)
]
community_embeddings[dissolve] = None
community_sizes[keep] = total_size
community_sizes[dissolve] = 0
await self.repo.refresh_member_count(keep, end_user_id)
logger.info(
f"[Clustering] 社区合并: {dissolve}{keep}"
f"迁移 {len(members)} 个成员"
f"相似度={current_sim:.3f}迁移 {len(members)} 个成员"
)
async def _flush_labels(
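The non-transitivity the new code guards against is easy to reproduce with 2-D vectors: A is close to B and B to C, yet A and C fall below the 0.85 threshold, and re-checking C against the merged (size-weighted mean) embedding is what blocks the chain. A self-contained check, reusing the _cosine_similarity sketch above (the vectors are synthetic):

import math

A = [1.0, 0.0]                                      # community A's mean embedding
B = [math.cos(math.pi / 6), math.sin(math.pi / 6)]  # 30° from A
C = [math.cos(math.pi / 3), math.sin(math.pi / 3)]  # 30° from B, 60° from A

print(_cosine_similarity(A, B))   # 0.866 > 0.85 → merge A and B
print(_cosine_similarity(B, C))   # 0.866 > 0.85 → plain union-find would chain C in
print(_cosine_similarity(A, C))   # 0.500 ≤ 0.85 → but A and C are not actually similar

# The pairwise re-check against the weighted mean (equal sizes here) blocks the chain:
AB = [(a + b) / 2 for a, b in zip(A, B)]
print(_cosine_similarity(AB, C))  # 0.707 ≤ 0.85 → C stays a separate community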

View File

@@ -2662,3 +2662,141 @@ def write_perceptual_memory(
file_url,
file_message,
))
# =============================================================================
# Community-clustering backfill task (trigger-based)
# =============================================================================
@celery_app.task(
name="app.tasks.init_community_clustering_for_users",
bind=True,
ignore_result=False,
max_retries=0,
acks_late=False,
    time_limit=7200,  # 2-hour hard timeout
soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
Args:
end_user_ids: 需要检查的用户 ID 列表
Returns:
包含任务执行结果的字典
"""
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.logging_config import get_logger
from app.repositories.neo4j.community_repository import CommunityRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
logger = get_logger(__name__)
logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")
initialized = 0
skipped = 0
failed = 0
connector = Neo4jConnector()
try:
repo = CommunityRepository(connector)
            # Resolve llm_model_id: read it from the first user whose config provides one, as a global fallback
llm_model_id = None
try:
with get_db_context() as db:
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.memory_config_service import MemoryConfigService
for uid in end_user_ids:
try:
connected = get_end_user_connected_config(uid, db)
config_id = connected.get("memory_config_id")
workspace_id = connected.get("workspace_id")
if config_id or workspace_id:
cfg = MemoryConfigService(db).load_memory_config(
config_id=config_id, workspace_id=workspace_id
)
llm_model_id = str(cfg.llm_model_id)
break
except Exception:
continue
except Exception as e:
logger.warning(f"[CommunityCluster] 获取 LLM 配置失败,将使用兜底值: {e}")
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
)
for end_user_id in end_user_ids:
try:
                    # Skip users that already have community nodes
                    has_communities = await repo.has_communities(end_user_id)
                    if has_communities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] User {end_user_id} already has community nodes, skipping")
                        continue
                    # Check whether the user has any ExtractedEntity nodes
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] User {end_user_id} has no entity nodes, skipping")
                        continue
                    logger.info(f"[CommunityCluster] User {end_user_id} has {len(entities)} entities, starting full clustering")
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] User {end_user_id} clustering finished")
                except Exception as e:
                    failed += 1
                    logger.error(f"[CommunityCluster] User {end_user_id} clustering failed: {e}")
finally:
await connector.close()
logger.info(
f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
)
return {
"status": "SUCCESS",
"initialized": initialized,
"skipped": skipped,
"failed": failed,
}
try:
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
return result
except Exception as e:
return {
"status": "FAILURE",
"error": str(e),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
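For manual backfills the task can also be queued directly from a shell where the Celery app is importable; a minimal usage sketch (the user IDs are placeholders):

from app.tasks import init_community_clustering_for_users

result = init_community_clustering_for_users.apply_async(
    kwargs={"end_user_ids": ["end-user-id-1", "end-user-id-2"]},  # placeholder IDs
    queue="periodic_tasks",
)
print(result.id)            # task id, usable with celery_app.AsyncResult(...)
# result.get(timeout=7200)  # ignore_result=False, so this returns the status dict built above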