[changes] Community Clustering Retrieval Module
This commit is contained in:
388
api/app/tasks.py
388
api/app/tasks.py
@@ -2416,3 +2416,391 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_implicit_emotions_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
# 触发型任务标识,区别于 periodic_tasks 队列中的定时任务
|
||||
triggered=True,
|
||||
)
|
||||
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:对指定用户列表做存在性检查,无记录则执行首次初始化。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有数据的用户直接跳过。
|
||||
存量用户的数据刷新由定时任务 update_implicit_emotions_storage 负责。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.repositories.implicit_emotions_storage_repository import (
|
||||
ImplicitEmotionsStorageRepository,
|
||||
)
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from app.services.implicit_memory_service import ImplicitMemoryService
|
||||
|
||||
logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
|
||||
with get_db_context() as db:
|
||||
repo = ImplicitEmotionsStorageRepository(db)
|
||||
|
||||
for end_user_id in end_user_ids:
|
||||
existing = repo.get_by_end_user_id(end_user_id)
|
||||
if existing is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无记录,开始初始化")
|
||||
implicit_ok = False
|
||||
emotion_ok = False
|
||||
try:
|
||||
try:
|
||||
implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
|
||||
profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
|
||||
await implicit_service.save_profile_cache(
|
||||
end_user_id=end_user_id, profile_data=profile_data, db=db
|
||||
)
|
||||
implicit_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")
|
||||
|
||||
try:
|
||||
emotion_service = EmotionAnalyticsService()
|
||||
suggestions_data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=end_user_id, db=db, language="zh"
|
||||
)
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
|
||||
)
|
||||
emotion_ok = True
|
||||
except Exception as e:
|
||||
logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")
|
||||
|
||||
if implicit_ok or emotion_ok:
|
||||
initialized += 1
|
||||
else:
|
||||
failed += 1
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 初始化异常: {e}")
|
||||
|
||||
logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_interest_distribution_for_users",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""事件触发任务:检查指定用户列表的兴趣分布缓存,无缓存则生成并写入 Redis。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有缓存的用户直接跳过。
|
||||
默认生成中文(zh)兴趣分布数据。
|
||||
|
||||
Args:
|
||||
self: task object
|
||||
end_user_ids: 需要检查的用户ID列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
|
||||
logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
failed = 0
|
||||
skipped = 0
|
||||
language = "zh"
|
||||
|
||||
service = MemoryAgentService()
|
||||
|
||||
with get_db_context() as db:
|
||||
for end_user_id in end_user_ids:
|
||||
# 存在性检查:缓存有数据则跳过
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
if cached is not None:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
|
||||
try:
|
||||
result = await service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=5,
|
||||
language=language,
|
||||
)
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
expire=INTEREST_CACHE_EXPIRE,
|
||||
)
|
||||
initialized += 1
|
||||
logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")
|
||||
|
||||
logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.write_perceptual_memory",
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=3600,
|
||||
soft_time_limit=3300,
|
||||
)
|
||||
def write_perceptual_memory(
|
||||
self,
|
||||
end_user_id: str,
|
||||
model_api_config: dict,
|
||||
file_type: str,
|
||||
file_url: str,
|
||||
file_message: dict
|
||||
):
|
||||
"""
|
||||
Write perceptual memory for a user into PostgreSQL and Neo4j.
|
||||
|
||||
This task generates or updates the user's perceptual memory
|
||||
in the backend databases. It is intended to be executed asynchronously
|
||||
via Celery.
|
||||
|
||||
Args:
|
||||
end_user_id (uuid.UUID): The unique identifier of the end user.
|
||||
model_api_config (ModelInfo): API configuration for the model
|
||||
used to generate perceptual memory.
|
||||
file_type (str): The file type
|
||||
file_url (url): The url of file
|
||||
file_message (dict): The file message containing details about the file
|
||||
to be processed.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
|
||||
set_asyncio_event_loop()
|
||||
with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
|
||||
model_info = ModelInfo(**model_api_config)
|
||||
with get_db_context() as db:
|
||||
memory_perceptual_service = MemoryPerceptualService(db)
|
||||
return asyncio.run(memory_perceptual_service.generate_perceptual_memory(
|
||||
end_user_id,
|
||||
model_info,
|
||||
file_type,
|
||||
file_url,
|
||||
file_message,
|
||||
))
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 社区聚类补全任务(触发型)
|
||||
# =============================================================================
|
||||
|
||||
@celery_app.task(
|
||||
name="app.tasks.init_community_clustering_for_users",
|
||||
bind=True,
|
||||
ignore_result=False,
|
||||
max_retries=0,
|
||||
acks_late=False,
|
||||
time_limit=7200, # 2小时硬超时
|
||||
soft_time_limit=6900,
|
||||
)
|
||||
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
|
||||
"""触发型任务:检查指定用户列表,对有 ExtractedEntity 但无 Community 节点的用户执行全量聚类。
|
||||
|
||||
由 /dashboard/end_users 接口触发,已有社区节点的用户直接跳过。
|
||||
|
||||
Args:
|
||||
end_user_ids: 需要检查的用户 ID 列表
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
|
||||
initialized = 0
|
||||
skipped = 0
|
||||
failed = 0
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
repo = CommunityRepository(connector)
|
||||
|
||||
# 批量预取所有用户的配置(内置兜底:用户配置不可用时自动回退到工作空间默认配置)
|
||||
user_llm_map: Dict[str, Optional[str]] = {}
|
||||
user_embedding_map: Dict[str, Optional[str]] = {}
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
for uid, cfg_info in batch_configs.items():
|
||||
config_id = cfg_info.get("memory_config_id")
|
||||
if config_id:
|
||||
try:
|
||||
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
|
||||
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
|
||||
user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
|
||||
user_llm_map[uid] = None
|
||||
user_embedding_map[uid] = None
|
||||
else:
|
||||
user_llm_map[uid] = None
|
||||
user_embedding_map[uid] = None
|
||||
except Exception as e:
|
||||
logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")
|
||||
|
||||
for end_user_id in end_user_ids:
|
||||
try:
|
||||
# 已有社区节点则跳过
|
||||
has_communities = await repo.has_communities(end_user_id)
|
||||
if has_communities:
|
||||
skipped += 1
|
||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
|
||||
continue
|
||||
|
||||
# 检查是否有 ExtractedEntity 节点
|
||||
entities = await repo.get_all_entities(end_user_id)
|
||||
if not entities:
|
||||
skipped += 1
|
||||
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
|
||||
continue
|
||||
|
||||
# 每个用户使用自己的 llm_model_id
|
||||
llm_model_id = user_llm_map.get(end_user_id)
|
||||
embedding_model_id = user_embedding_map.get(end_user_id)
|
||||
engine = LabelPropagationEngine(
|
||||
connector=connector,
|
||||
llm_model_id=llm_model_id,
|
||||
embedding_model_id=embedding_model_id,
|
||||
)
|
||||
|
||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 有 {len(entities)} 个实体,开始全量聚类,llm_model_id={llm_model_id}")
|
||||
await engine.full_clustering(end_user_id)
|
||||
initialized += 1
|
||||
logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")
|
||||
|
||||
except Exception as e:
|
||||
failed += 1
|
||||
logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
logger.info(
|
||||
f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
|
||||
)
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"initialized": initialized,
|
||||
"skipped": skipped,
|
||||
"failed": failed,
|
||||
}
|
||||
|
||||
try:
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
loop = set_asyncio_event_loop()
|
||||
result = loop.run_until_complete(_run())
|
||||
result["elapsed_time"] = time.time() - start_time
|
||||
result["task_id"] = self.request.id
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user