Merge pull request #630 from SuanmoSuanyangTechnology/fix/celery
[changes]Community node attribute check
This commit is contained in:
@@ -198,8 +198,15 @@ async def get_workspace_end_users(
|
||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
||||
try:
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
from app.aioRedis import aio_redis_get
|
||||
|
||||
done_key = f"community_cluster:done:workspace:{workspace_id}"
|
||||
already_done = await aio_redis_get(done_key)
|
||||
if already_done:
|
||||
api_logger.info(f"工作空间 {workspace_id} 社区数据已完整,跳过本次聚类任务投递")
|
||||
else:
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
|
||||
@@ -69,11 +69,13 @@ class LabelPropagationEngine:
|
||||
connector: Neo4jConnector,
|
||||
config_id: Optional[str] = None,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.config_id = config_id
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
@@ -468,12 +470,28 @@ class LabelPropagationEngine:
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
|
||||
|
||||
# 生成 summary_embedding
|
||||
summary_embedding: Optional[List[float]] = None
|
||||
if self.embedding_model_id and summary:
|
||||
try:
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
vectors = await embedder.response([summary])
|
||||
if vectors:
|
||||
summary_embedding = vectors[0]
|
||||
except Exception as e:
|
||||
logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
|
||||
|
||||
await self.repo.update_community_metadata(
|
||||
community_id=community_id,
|
||||
end_user_id=end_user_id,
|
||||
name=name,
|
||||
summary=summary,
|
||||
core_entities=core_entities,
|
||||
summary_embedding=summary_embedding,
|
||||
)
|
||||
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -19,6 +19,8 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
CHECK_USER_HAS_COMMUNITIES,
|
||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
GET_INCOMPLETE_COMMUNITIES,
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -170,6 +172,21 @@ class CommunityRepository:
|
||||
logger.error(f"refresh_member_count failed: {e}")
|
||||
return 0
|
||||
|
||||
async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]:
|
||||
"""查询该用户下属性不完整的 Community 节点 ID 列表。
|
||||
|
||||
Args:
|
||||
end_user_id: 用户 ID
|
||||
check_embedding: 为 True 时额外检查 summary_embedding 是否缺失(仅当用户有 embedding 模型配置时传 True)
|
||||
"""
|
||||
try:
|
||||
query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES
|
||||
result = await self.connector.execute_query(query, end_user_id=end_user_id)
|
||||
return [row["community_id"] for row in result]
|
||||
except Exception as e:
|
||||
logger.error(f"get_incomplete_communities failed: {e}")
|
||||
return []
|
||||
|
||||
async def update_community_metadata(
|
||||
self,
|
||||
community_id: str,
|
||||
@@ -177,8 +194,9 @@ class CommunityRepository:
|
||||
name: str,
|
||||
summary: str,
|
||||
core_entities: List[str],
|
||||
summary_embedding: Optional[List[float]] = None,
|
||||
) -> bool:
|
||||
"""更新社区的名称、摘要和核心实体列表。"""
|
||||
"""更新社区的名称、摘要、核心实体列表及 summary_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
UPDATE_COMMUNITY_METADATA,
|
||||
@@ -187,6 +205,7 @@ class CommunityRepository:
|
||||
name=name,
|
||||
summary=summary,
|
||||
core_entities=core_entities,
|
||||
summary_embedding=summary_embedding,
|
||||
)
|
||||
return bool(result)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1153,10 +1153,11 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
||||
|
||||
UPDATE_COMMUNITY_METADATA = """
|
||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||
SET c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.updated_at = datetime()
|
||||
SET c.name = $name,
|
||||
c.summary = $summary,
|
||||
c.core_entities = $core_entities,
|
||||
c.summary_embedding = $summary_embedding,
|
||||
c.updated_at = datetime()
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
@@ -1202,3 +1203,18 @@ RETURN
|
||||
properties(r) AS r_props,
|
||||
startNode(r) = e AS r_from_e
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
|
||||
MATCH (c:Community {end_user_id: $end_user_id})
|
||||
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
|
||||
OR c.name = '' OR c.summary = ''
|
||||
OR c.summary_embedding IS NULL
|
||||
RETURN c.community_id AS community_id
|
||||
"""
|
||||
|
||||
@@ -2675,13 +2675,15 @@ def write_perceptual_memory(
|
||||
time_limit=7200, # 2小时硬超时
|
||||
soft_time_limit=6900,
|
||||
)
|
||||
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
|
||||
任务完成且所有用户数据均完整时,写入 Redis 标记,避免下次重复投递。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户 ID 列表
|
||||
workspace_id: 工作空间 ID,用于完成标记
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
@@ -2707,6 +2709,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
||||
|
||||
# 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
|
||||
user_llm_map: Dict[str, Optional[str]] = {}
|
||||
user_embedding_map: Dict[str, Optional[str]] = {}
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
@@ -2718,21 +2721,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
||||
try:
|
||||
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
|
||||
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
|
||||
user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
|
||||
logger.warning(f"[CommunityCluster] 用户 {uid} 加载配置失败,将使用 None: {e}")
|
||||
user_llm_map[uid] = None
|
||||
user_embedding_map[uid] = None
|
||||
else:
|
||||
user_llm_map[uid] = None
|
||||
user_embedding_map[uid] = None
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
|
||||
logger.warning(f"[CommunityCluster] 批量获取配置失败,所有用户将使用 None: {e}")
|
||||
|
||||
for end_user_id in end_user_ids:
|
||||
try:
|
||||
# 已有社区节点则跳过
|
||||
# 已有社区节点时,检查是否存在属性不完整的节点
|
||||
has_communities = await repo.has_communities(end_user_id)
|
||||
if has_communities:
|
||||
skipped += 1
|
||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
|
||||
llm_model_id = user_llm_map.get(end_user_id)
|
||||
embedding_model_id = user_embedding_map.get(end_user_id)
|
||||
incomplete_ids = await repo.get_incomplete_communities(
|
||||
end_user_id, check_embedding=bool(embedding_model_id)
|
||||
)
|
||||
if not incomplete_ids:
|
||||
skipped += 1
|
||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过")
|
||||
continue
|
||||
|
||||
# 对不完整的社区节点逐一补全元数据
|
||||
engine = LabelPropagationEngine(
|
||||
connector=connector,
|
||||
llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id,
|
||||
)
|
||||
logger.info(
|
||||
f"[CommunityCluster] 用户 {end_user_id} 发现 {len(incomplete_ids)} 个属性不完整的社区,开始补全"
|
||||
)
|
||||
patch_ok = 0
|
||||
patch_fail = 0
|
||||
for cid in incomplete_ids:
|
||||
try:
|
||||
await engine._generate_community_metadata(cid, end_user_id)
|
||||
patch_ok += 1
|
||||
except Exception as patch_err:
|
||||
patch_fail += 1
|
||||
logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}")
|
||||
logger.info(
|
||||
f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}"
|
||||
)
|
||||
initialized += 1
|
||||
continue
|
||||
|
||||
# 检查是否有 ExtractedEntity 节点
|
||||
@@ -2742,11 +2778,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
|
||||
continue
|
||||
|
||||
# 每个用户使用自己的 llm_model_id
|
||||
# 每个用户使用自己的 llm_model_id / embedding_model_id
|
||||
llm_model_id = user_llm_map.get(end_user_id)
|
||||
embedding_model_id = user_embedding_map.get(end_user_id)
|
||||
engine = LabelPropagationEngine(
|
||||
connector=connector,
|
||||
llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id,
|
||||
)
|
||||
|
||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||
@@ -2782,6 +2820,17 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
# 所有用户均完整(无需初始化也无失败),写入 Redis 标记,1小时内不再重复投递
|
||||
if workspace_id and result.get("initialized", 0) == 0 and result.get("failed", 0) == 0:
|
||||
try:
|
||||
_r = get_sync_redis_client()
|
||||
if _r:
|
||||
_r.set(f"community_cluster:done:workspace:{workspace_id}", "1", ex=3600)
|
||||
logger.info(f"[CommunityCluster] 工作空间 {workspace_id} 数据完整,已写入完成标记(1小时有效)")
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunityCluster] 写入完成标记失败: {e}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user