Merge pull request #630 from SuanmoSuanyangTechnology/fix/celery

[changes]Community node attribute check
This commit is contained in:
Ke Sun
2026-03-19 19:26:52 +08:00
committed by GitHub
5 changed files with 123 additions and 14 deletions

View File

@@ -198,8 +198,15 @@ async def get_workspace_end_users(
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
try:
from app.tasks import init_community_clustering_for_users
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
from app.aioRedis import aio_redis_get
done_key = f"community_cluster:done:workspace:{workspace_id}"
already_done = await aio_redis_get(done_key)
if already_done:
api_logger.info(f"工作空间 {workspace_id} 社区数据已完整,跳过本次聚类任务投递")
else:
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")

View File

@@ -69,11 +69,13 @@ class LabelPropagationEngine:
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
@@ -468,12 +470,28 @@ class LabelPropagationEngine:
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
# 生成 summary_embedding
summary_embedding: Optional[List[float]] = None
if self.embedding_model_id and summary:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
vectors = await embedder.response([summary])
if vectors:
summary_embedding = vectors[0]
except Exception as e:
logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
summary_embedding=summary_embedding,
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:

View File

@@ -19,6 +19,8 @@ from app.repositories.neo4j.cypher_queries import (
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
GET_INCOMPLETE_COMMUNITIES,
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING,
)
logger = logging.getLogger(__name__)
@@ -170,6 +172,21 @@ class CommunityRepository:
logger.error(f"refresh_member_count failed: {e}")
return 0
async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]:
"""查询该用户下属性不完整的 Community 节点 ID 列表。
Args:
end_user_id: 用户 ID
check_embedding: 为 True 时额外检查 summary_embedding 是否缺失(仅当用户有 embedding 模型配置时传 True
"""
try:
query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES
result = await self.connector.execute_query(query, end_user_id=end_user_id)
return [row["community_id"] for row in result]
except Exception as e:
logger.error(f"get_incomplete_communities failed: {e}")
return []
async def update_community_metadata(
self,
community_id: str,
@@ -177,8 +194,9 @@ class CommunityRepository:
name: str,
summary: str,
core_entities: List[str],
summary_embedding: Optional[List[float]] = None,
) -> bool:
"""更新社区的名称、摘要核心实体列表。"""
"""更新社区的名称、摘要核心实体列表及 summary_embedding"""
try:
result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA,
@@ -187,6 +205,7 @@ class CommunityRepository:
name=name,
summary=summary,
core_entities=core_entities,
summary_embedding=summary_embedding,
)
return bool(result)
except Exception as e:

View File

@@ -1153,10 +1153,11 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.updated_at = datetime()
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.summary_embedding = $summary_embedding,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""
@@ -1202,3 +1203,18 @@ RETURN
properties(r) AS r_props,
startNode(r) = e AS r_from_e
"""
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
OR c.summary_embedding IS NULL
RETURN c.community_id AS community_id
"""

View File

@@ -2675,13 +2675,15 @@ def write_perceptual_memory(
time_limit=7200, # 2小时硬超时
soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
任务完成且所有用户数据均完整时,写入 Redis 标记,避免下次重复投递。
Args:
end_user_ids: 需要检查的用户 ID 列表
workspace_id: 工作空间 ID用于完成标记
Returns:
包含任务执行结果的字典
@@ -2707,6 +2709,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
# 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
user_llm_map: Dict[str, Optional[str]] = {}
user_embedding_map: Dict[str, Optional[str]] = {}
try:
with get_db_context() as db:
from app.services.memory_agent_service import get_end_users_connected_configs_batch
@@ -2718,21 +2721,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
try:
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
except Exception as e:
logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
logger.warning(f"[CommunityCluster] 用户 {uid} 加载配置失败,将使用 None: {e}")
user_llm_map[uid] = None
user_embedding_map[uid] = None
else:
user_llm_map[uid] = None
user_embedding_map[uid] = None
except Exception as e:
logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
logger.warning(f"[CommunityCluster] 批量获取配置失败,所有用户将使用 None: {e}")
for end_user_id in end_user_ids:
try:
# 已有社区节点则跳过
# 已有社区节点时,检查是否存在属性不完整的节点
has_communities = await repo.has_communities(end_user_id)
if has_communities:
skipped += 1
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
llm_model_id = user_llm_map.get(end_user_id)
embedding_model_id = user_embedding_map.get(end_user_id)
incomplete_ids = await repo.get_incomplete_communities(
end_user_id, check_embedding=bool(embedding_model_id)
)
if not incomplete_ids:
skipped += 1
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过")
continue
# 对不完整的社区节点逐一补全元数据
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
logger.info(
f"[CommunityCluster] 用户 {end_user_id} 发现 {len(incomplete_ids)} 个属性不完整的社区,开始补全"
)
patch_ok = 0
patch_fail = 0
for cid in incomplete_ids:
try:
await engine._generate_community_metadata(cid, end_user_id)
patch_ok += 1
except Exception as patch_err:
patch_fail += 1
logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}")
logger.info(
f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}"
)
initialized += 1
continue
# 检查是否有 ExtractedEntity 节点
@@ -2742,11 +2778,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
continue
# 每个用户使用自己的 llm_model_id
# 每个用户使用自己的 llm_model_id / embedding_model_id
llm_model_id = user_llm_map.get(end_user_id)
embedding_model_id = user_embedding_map.get(end_user_id)
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
logger.info(f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}")
@@ -2782,6 +2820,17 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
result = loop.run_until_complete(_run())
result["elapsed_time"] = time.time() - start_time
result["task_id"] = self.request.id
# 所有用户均完整(无需初始化也无失败),写入 Redis 标记1小时内不再重复投递
if workspace_id and result.get("initialized", 0) == 0 and result.get("failed", 0) == 0:
try:
_r = get_sync_redis_client()
if _r:
_r.set(f"community_cluster:done:workspace:{workspace_id}", "1", ex=3600)
logger.info(f"[CommunityCluster] 工作空间 {workspace_id} 数据完整已写入完成标记1小时有效")
except Exception as e:
logger.warning(f"[CommunityCluster] 写入完成标记失败: {e}")
return result
except Exception as e: