[changes] Community Clustering Retrieval Module

This commit is contained in:
lanceyq
2026-03-16 12:30:00 +08:00
parent 5b431400be
commit c244e9834f
12 changed files with 1203 additions and 61 deletions

View File

@@ -120,7 +120,7 @@ class SearchService:
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -146,8 +146,8 @@ class SearchService:
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -157,13 +157,43 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
category_results = answer[category] category_results = answer[category]
if isinstance(category_results, list): if isinstance(category_results, list):
answer_list.extend(category_results) answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements
if "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
else answer.get('communities', [])
)
community_ids = [
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
try:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
expand_result = await search_graph_community_expand(
connector=connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
await connector.close()
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
# 展开的 statements 插入 communities 之后、statements 之前
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
# Extract clean content from all results # Extract clean content from all results
content_list = [ content_list = [

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -165,10 +165,17 @@ async def write(
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
connector=neo4j_connector connector=neo4j_connector,
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]: for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -281,21 +281,23 @@ def rerank_with_activation(
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0) # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用 # 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score

View File

@@ -20,6 +20,9 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛 # 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
# 社区核心实体取 top-N 数量
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
"""计算两个向量的余弦相似度,任一为空则返回 0。""" """计算两个向量的余弦相似度,任一为空则返回 0。"""
@@ -62,9 +65,18 @@ def _weighted_vote(
class LabelPropagationEngine: class LabelPropagationEngine:
"""标签传播聚类引擎""" """标签传播聚类引擎"""
def __init__(self, connector: Neo4jConnector): def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -94,58 +106,110 @@ class LabelPropagationEngine:
async def full_clustering(self, end_user_id: str) -> None: async def full_clustering(self, end_user_id: str) -> None:
""" """
全量标签传播初始化。 全量标签传播初始化(分批处理,控制内存峰值)
1. 拉取所有实体,初始化每个实体为独立社区 策略:
2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 每次只加载 BATCH_SIZE 个实体及其邻居进内存
3. 直到标签不再变化或达到 MAX_ITERATIONS - labels 字典跨批次共享(只存 id→community_id内存极小
4. 将最终标签写入 Neo4j - 每批独立跑 MAX_ITERATIONS 轮 LPA批次间通过 labels 传递社区信息
- 所有批次完成后统一 flush 和 merge
""" """
entities = await self.repo.get_all_entities(end_user_id) BATCH_SIZE = 2000 # 每批实体数,可按需调整
if not entities:
# 先查总数,决定批次数
total_entities = await self.repo.get_all_entities(end_user_id)
if not total_entities:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return return
# 初始化:每个实体持有自己 id 作为社区标签 total_count = len(total_entities)
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
embeddings: Dict[str, Optional[List[float]]] = { f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
e["id"]: e.get("name_embedding") for e in entities
}
for iteration in range(MAX_ITERATIONS): # labels 跨批次共享:先用全量数据初始化(只存 id内存极小
changed = 0 labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities}
# 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历 # embeddings 也跨批次共享(每个向量 ~6KB10万实体约 600MB这是不可避免的
for entity in entities: # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻
eid = entity["id"] # 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding
neighbors = await self.repo.get_entity_neighbors(eid, end_user_id) # 所以只需要当前批次实体的 embedding不需要全量
del total_entities # 释放全量列表,后续按批次加载
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值) for batch_start in range(0, total_count, BATCH_SIZE):
enriched = [] batch_entities = await self.repo.get_entities_page(
for nb in neighbors: end_user_id, skip=batch_start, limit=BATCH_SIZE
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"标签变化数: {changed}"
) )
if changed == 0: if not batch_entities:
logger.info("[Clustering] 标签已收敛,提前结束迭代")
break break
# 将最终标签写入 Neo4j batch_ids = [e["id"] for e in batch_entities]
batch_embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in batch_entities
}
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}"
f"加载 {len(batch_entities)} 个实体的邻居图..."
)
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
batch_ids, end_user_id
)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS):
changed = 0
for entity in batch_entities:
eid = entity["id"]
neighbors = neighbors_cache.get(eid, [])
# 注入跨批次的最新标签邻居可能在其他批次labels 里有其最新值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
break
# 释放本批次的大对象
del neighbors_cache, batch_embeddings, batch_entities
# 所有批次完成,统一写入 Neo4j
await self._flush_labels(labels, end_user_id) await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values()))
logger.info( logger.info(
f"[Clustering] 全量聚类完成,共 {len(set(labels.values()))} 个社区," f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体,开始后处理合并"
)
all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id)
logger.info(
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体" f"{len(labels)} 个实体"
) )
# 查询存活社区并生成元数据
surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({
e.get("community_id") for e in surviving_communities
if e.get("community_id")
})
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update( async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str self, new_entity_ids: List[str], end_user_id: str
) -> None: ) -> None:
@@ -306,6 +370,90 @@ class LabelPropagationEngine:
except Exception: except Exception:
return None return None
async def _generate_community_metadata(
    self, community_id: str, end_user_id: str
) -> None:
    """Generate and persist a community's name, summary and core entities.

    - ``core_entities``: names of the top ``CORE_ENTITY_LIMIT`` members ranked
      by ``activation_value`` (no LLM needed).
    - ``name`` / ``summary``: produced by the LLM when ``self.llm_model_id`` is
      set; otherwise (or when the LLM call fails or omits a field) a fallback
      built from the entity names is used.
    - ``summary_embedding``: computed from the final summary when
      ``self.embedding_model_id`` is set; ``None`` otherwise or on failure.

    NOTE: ranking core entities by activation value may lower retrieval
    quality for peripheral (low-activation) information.

    Args:
        community_id: Identifier of the community to describe.
        end_user_id: Owner of the community; passed through to the repository.

    Returns:
        None. All failures are logged and swallowed so that one community
        cannot abort the metadata pass over the remaining communities.
    """
    try:
        members = await self.repo.get_community_members(community_id, end_user_id)
        if not members:
            return

        # Core entities: top-N by activation_value, missing values rank last.
        ranked_members = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [
            m["name"] for m in ranked_members[:CORE_ENTITY_LIMIT] if m.get("name")
        ]
        all_names = [m["name"] for m in members if m.get("name")]

        # Fallback values. Names are joined with an explicit separator so that
        # adjacent entity names do not run together into one unreadable token.
        name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
        summary = f"包含实体:{', '.join(all_names)}"

        # With an LLM configured, ask it for a better name and summary.
        if self.llm_model_id:
            try:
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                entity_list_str = "、".join(all_names)
                prompt = (
                    f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                    f"请为这组实体所代表的主题:\n"
                    f"1. 起一个简洁的中文名称不超过10个字\n"
                    f"2. 写一句话摘要不超过50个字\n\n"
                    f"严格按以下格式输出,不要有其他内容:\n"
                    f"名称:<名称>\n摘要:<摘要>"
                )
                with get_db_context() as db:
                    factory = MemoryClientFactory(db)
                    llm_client = factory.get_llm_client(self.llm_model_id)
                    response = await llm_client.chat([{"role": "user", "content": prompt}])
                text = response.content if hasattr(response, "content") else str(response)

                # Accept both ASCII and full-width colons (models frequently
                # emit the latter in Chinese output) and tolerate indentation.
                # Both prefix variants have the same length, so one slice works.
                name_prefixes = ("名称:", "名称:")
                summary_prefixes = ("摘要:", "摘要:")
                prefix_len = len(name_prefixes[0])
                for raw_line in text.strip().splitlines():
                    line = raw_line.strip()
                    if line.startswith(name_prefixes):
                        value = line[prefix_len:].strip()
                        # An empty value must not wipe out the fallback name.
                        if value:
                            name = value
                    elif line.startswith(summary_prefixes):
                        value = line[prefix_len:].strip()
                        if value:
                            summary = value
            except Exception as e:
                logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

        # Embed the final summary so the community is reachable by vector search.
        summary_embedding = None
        if self.embedding_model_id and summary:
            try:
                from app.db import get_db_context
                from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

                with get_db_context() as db:
                    embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
                    results = await embedder.response([summary])
                summary_embedding = results[0] if results else None
            except Exception as e:
                logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")

        result = await self.repo.update_community_metadata(
            community_id=community_id,
            end_user_id=end_user_id,
            name=name,
            summary=summary,
            core_entities=core_entities,
            summary_embedding=summary_embedding,
        )
        if result:
            logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...")
        else:
            logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False")
    except Exception as e:
        logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True)
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())

View File

@@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail from app.core.response_utils import fail
from app.core.models.scripts.loader import load_models from app.core.models.scripts.loader import load_models
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.index_manager import ensure_indexes
# Initialize logging system # Initialize logging system
LoggingConfig.setup_logging() LoggingConfig.setup_logging()
@@ -61,9 +62,18 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
# 确保 Neo4j 索引存在(幂等,多环境安全)
try:
report = await ensure_indexes()
if report["errors"]:
logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}")
else:
logger.info(f"Neo4j 索引检查完成 [{report['uri']}]")
except Exception as e:
logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}")
logger.info("应用程序启动完成") logger.info("应用程序启动完成")
yield yield
# 应用关闭事件
logger.info("应用程序正在关闭") logger.info("应用程序正在关闭")

View File

@@ -13,9 +13,14 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES, ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS, GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER, GET_ALL_ENTITIES_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS, GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH,
GET_ALL_ENTITY_NEIGHBORS_BATCH,
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
CHECK_USER_HAS_COMMUNITIES, CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -87,6 +92,41 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}") logger.error(f"get_all_entities failed: {e}")
return [] return []
async def get_entities_page(
    self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:
    """Fetch one page of a user's entities for batched full clustering.

    Args:
        end_user_id: User whose entities are listed.
        skip: Number of entities to skip (page offset).
        limit: Maximum number of entities to return.

    Returns:
        The rows returned by the ``GET_ENTITIES_PAGE`` query, or an empty
        list when the query raises (the error is logged, not propagated).
    """
    query_params = {"end_user_id": end_user_id, "skip": skip, "limit": limit}
    try:
        page = await self.connector.execute_query(GET_ENTITIES_PAGE, **query_params)
    except Exception as exc:
        logger.error(f"get_entities_page failed: {exc}")
        return []
    return page
async def get_entity_neighbors_for_ids(
    self, entity_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
    """Load the neighbors of several entities in a single query.

    Args:
        entity_ids: Ids of the entities whose neighbors are wanted.
        end_user_id: User the entities belong to.

    Returns:
        Mapping ``{entity_id: [neighbor, ...]}`` where each neighbor is the
        query row without its ``entity_id`` column. Entities for which the
        query returns no rows are absent from the mapping. An empty dict is
        returned when anything fails (the error is logged).
    """
    try:
        rows = await self.connector.execute_query(
            GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
            entity_ids=entity_ids,
            end_user_id=end_user_id,
        )
        grouped: Dict[str, List[Dict]] = {}
        for row in rows:
            owner_id = row["entity_id"]
            # Everything except the grouping key describes the neighbor.
            neighbor_fields = {
                key: value for key, value in row.items() if key != "entity_id"
            }
            if owner_id not in grouped:
                grouped[owner_id] = []
            grouped[owner_id].append(neighbor_fields)
        return grouped
    except Exception as exc:
        logger.error(f"get_entity_neighbors_for_ids failed: {exc}")
        return {}
async def get_community_members( async def get_community_members(
self, community_id: str, end_user_id: str self, community_id: str, end_user_id: str
) -> List[Dict]: ) -> List[Dict]:
@@ -127,3 +167,28 @@ class CommunityRepository:
except Exception as e: except Exception as e:
logger.error(f"refresh_member_count failed: {e}") logger.error(f"refresh_member_count failed: {e}")
return 0 return 0
async def update_community_metadata(
    self,
    community_id: str,
    end_user_id: str,
    name: str,
    summary: str,
    core_entities: List[str],
    summary_embedding: Optional[List[float]] = None,
) -> bool:
    """Write a community's name, summary, core entities and summary vector.

    Args:
        community_id: Community to update.
        end_user_id: Owner of the community.
        name: New community name.
        summary: New community summary text.
        core_entities: Names of the community's core entities.
        summary_embedding: Embedding of ``summary``; ``None`` stores no vector.

    Returns:
        ``True`` when the query returned a non-empty result, ``False`` when
        it returned nothing or raised (the error is logged).
    """
    query_params = dict(
        community_id=community_id,
        end_user_id=end_user_id,
        name=name,
        summary=summary,
        core_entities=core_entities,
        summary_embedding=summary_embedding,
    )
    try:
        rows = await self.connector.execute_query(
            UPDATE_COMMUNITY_METADATA, **query_params
        )
        return True if rows else False
    except Exception as exc:
        logger.error(f"update_community_metadata failed: {exc}")
        return False

View File

@@ -1139,6 +1139,15 @@ RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
ORDER BY coalesce(e.activation_value, 0) DESC ORDER BY coalesce(e.activation_value, 0) DESC
""" """
# All community memberships of one user in a single query: one row per
# (community, member entity), ordered by community id and then by
# activation_value descending (a missing activation_value sorts as 0).
# Parameters: $end_user_id.
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN c.community_id AS community_id,
e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
"""
CHECK_USER_HAS_COMMUNITIES = """ CHECK_USER_HAS_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id}) MATCH (c:Community {end_user_id: $end_user_id})
RETURN count(c) AS community_count RETURN count(c) AS community_count
@@ -1150,3 +1159,128 @@ WITH c, count(e) AS cnt
SET c.member_count = cnt SET c.member_count = cnt
RETURN c.community_id AS community_id, cnt AS member_count RETURN c.community_id AS community_id, cnt AS member_count
""" """
# Set name, summary, core_entities and summary_embedding on the Community
# matched by both community_id and end_user_id, and stamp updated_at.
# Returns the community_id; returns no rows when no such community exists.
# Parameters: $community_id, $end_user_id, $name, $summary, $core_entities,
# $summary_embedding.
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.summary_embedding = $summary_embedding,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""
# One page of a user's entities, ordered by e.id so that SKIP/LIMIT paging is
# stable; community_id is null for entities that belong to no community.
# NOTE(review): an entity with more than one BELONGS_TO_COMMUNITY edge would
# produce one row per edge, which shifts page boundaries - confirm that the
# data model guarantees at most one community per entity.
# Parameters: $end_user_id, $skip, $limit.
GET_ENTITIES_PAGE = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN e.id AS id,
e.name AS name,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
ORDER BY e.id
SKIP $skip LIMIT $limit
"""
# Neighbors of the entities listed in $entity_ids (used by batched full
# clustering). A neighbor is either (1) connected through an
# EXTRACTED_RELATIONSHIP edge in any direction, or (2) referenced by a
# Statement that also references the entity. Returns one distinct row per
# (entity_id, neighbor) with the neighbor's current community_id (or null).
# Entities without any neighbor produce no rows, because UNWIND of an empty
# list emits nothing.
# NOTE(review): the three chained OPTIONAL MATCH clauses multiply rows before
# collect(DISTINCT ...) collapses them; consider profiling on dense graphs.
# Parameters: $end_user_id, $entity_ids.
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
WHERE e.id IN $entity_ids
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
# Neighbors of every entity of one user (used to preload the neighbor graph
# for full clustering). Same neighbor definition and row shape as
# GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS, but without the id filter: neighbors come
# from EXTRACTED_RELATIONSHIP edges (any direction) and from co-reference by
# the same Statement. Entities without neighbors produce no rows.
# Parameters: $end_user_id.
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
// 来源一:直接关系邻居
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
// 来源二:同 Statement 共现邻居
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
# Community keyword search: queries the "communitiesFulltext" full-text index
# with $q. When $end_user_id is non-null, only that user's communities are
# kept. The community summary is returned under the `content` column and the
# community_id under `id`; rows are ordered by full-text score, best first.
# Parameters: $q, $end_user_id (nullable), $limit.
SEARCH_COMMUNITIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community vector search ─────────────────────────────────────────────
# Community embedding-based search: queries the
# 'community_summary_embedding_index' vector index on
# Community.summary_embedding. The index is asked for $limit * 100 candidates
# because the end_user_id filter is applied after the index lookup; the final
# result is cut back to $limit rows ordered by score, best first.
# Parameters: $embedding, $end_user_id (nullable), $limit.
COMMUNITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.summary_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community expansion search ──────────────────────────────────────────
# After a community is hit, pull the Statement nodes referencing any of its
# member entities (two-level retrieval: topic -> details).
# Rows are ordered by the statement's activation_value (missing treated as 0),
# while the returned activation_value column falls back to importance_score
# and then 0.5.
# NOTE(review): only the Statement is filtered by end_user_id; the Community
# match uses community_id alone.
# NOTE(review): rows are (statement, member entity) pairs, so a statement that
# references several members appears several times and each copy counts
# against LIMIT before the caller de-duplicates.
# Parameters: $community_id, $end_user_id, $limit.
EXPAND_COMMUNITY_STATEMENTS = """
MATCH (c:Community {community_id: $community_id})
MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c)
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
WHERE s.end_user_id = $end_user_id
RETURN s.statement AS statement,
s.id AS id,
s.end_user_id AS end_user_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
e.name AS source_entity,
c.name AS community_name
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
"""

View File

@@ -1,5 +1,5 @@
import asyncio import asyncio
from typing import List from typing import List, Optional
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -156,10 +156,13 @@ async def save_dialog_and_statements_to_neo4j(
entity_edges: List[EntityEntityEdge], entity_edges: List[EntityEntityEdge],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector connector: Neo4jConnector,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
schedule_clustering_after_write() 显式触发。
Args: Args:
dialogue_nodes: List of DialogueNode objects to save dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save chunk_nodes: List of ChunkNode objects to save
@@ -290,13 +293,6 @@ async def save_dialog_and_statements_to_neo4j(
logger.info("Transaction completed. Summary: %s", summary) logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results) logger.debug("Full transaction results: %r", results)
# 写入成功后,触发聚类
if entity_nodes:
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id)
return True return True
except Exception as e: except Exception as e:
@@ -306,9 +302,38 @@ async def save_dialog_and_statements_to_neo4j(
return False return False
# Strong references to in-flight clustering tasks. The event loop only keeps
# weak references to tasks, so a fire-and-forget task whose result is dropped
# can be garbage-collected before it finishes.
_BACKGROUND_CLUSTERING_TASKS: set = set()


def schedule_clustering_after_write(
    entity_nodes: List,
    config_id: Optional[str] = None,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
) -> None:
    """Schedule the background clustering job after a successful Neo4j write.

    Clustering can be switched off with the environment variable
    ``CLUSTERING_ENABLED=false`` (used for benchmark comparisons). The job is
    started with ``asyncio.create_task`` so the write response is not blocked;
    this function must therefore be called while an event loop is running.

    Args:
        entity_nodes: Entity nodes that were just written. Each must expose
            ``id`` and ``end_user_id``; the user id is taken from the first
            node. Nothing is scheduled when the list is empty or ``None``.
        config_id: Forwarded unchanged to ``_trigger_clustering``.
        llm_model_id: Forwarded unchanged to ``_trigger_clustering``.
        embedding_model_id: Forwarded unchanged to ``_trigger_clustering``.
    """
    if not entity_nodes:
        return

    # Strip surrounding whitespace so that values such as "false " coming from
    # .env files still disable clustering.
    clustering_enabled = os.getenv("CLUSTERING_ENABLED", "true").strip().lower() != "false"
    if not clustering_enabled:
        logger.info("[Clustering] 聚类已禁用CLUSTERING_ENABLED=false跳过聚类触发")
        return

    end_user_id = entity_nodes[0].end_user_id
    new_entity_ids = [e.id for e in entity_nodes]
    logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")

    task = asyncio.create_task(
        _trigger_clustering(
            new_entity_ids,
            end_user_id,
            config_id=config_id,
            llm_model_id=llm_model_id,
            embedding_model_id=embedding_model_id,
        )
    )
    # Keep the task alive until it completes, then drop the reference.
    _BACKGROUND_CLUSTERING_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_CLUSTERING_TASKS.discard)
async def _trigger_clustering( async def _trigger_clustering(
new_entity_ids: List[str], new_entity_ids: List[str],
end_user_id: str, end_user_id: str,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。
@@ -318,7 +343,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
connector = Neo4jConnector() connector = Neo4jConnector()
engine = LabelPropagationEngine(connector) engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}") logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
except Exception as e: except Exception as e:

View File

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
from app.repositories.neo4j.cypher_queries import ( from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
COMMUNITY_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
@@ -285,6 +288,15 @@ async def search_graph(
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
)) ))
task_keys.append("summaries") task_keys.append("summaries")
# Communities (向量语义匹配)
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
"chunks": [], "chunks": [],
"entities": [], "entities": [],
"summaries": [], "summaries": [],
"communities": [],
} }
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
return {"chunks": chunks} return {"chunks": chunks}
async def search_graph_community_expand(
    connector: Neo4jConnector,
    community_ids: List[str],
    end_user_id: str,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Phase 3: community expansion search (topic -> detail retrieval).

    For every community that was hit, run ``EXPAND_COMMUNITY_STATEMENTS`` to
    fetch the statements attached to its member entities. The per-community
    results are merged, sorted by ``activation_value`` descending and then
    de-duplicated, giving a "topic summary -> concrete memory" recall.

    Args:
        connector: Neo4j connector used to run the queries.
        community_ids: Ids of the communities that were hit.
        end_user_id: User id, used for data isolation.
        limit: Row limit passed to each per-community query.

    Returns:
        ``{"expanded_statements": [...]}`` where each item is a statement row
        that also carries ``community_name`` and ``source_entity``. The list
        is empty when ``community_ids`` or ``end_user_id`` is empty.
    """
    if not (community_ids and end_user_id):
        return {"expanded_statements": []}

    # One query per community, all executed concurrently.
    pending_queries = []
    for community_id in community_ids:
        pending_queries.append(
            connector.execute_query(
                EXPAND_COMMUNITY_STATEMENTS,
                community_id=community_id,
                end_user_id=end_user_id,
                limit=limit,
            )
        )
    outcomes = await asyncio.gather(*pending_queries, return_exceptions=True)

    collected: List[Dict[str, Any]] = []
    for community_id, outcome in zip(community_ids, outcomes):
        if not isinstance(outcome, Exception):
            collected.extend(outcome)
            continue
        # A failed community is skipped; the others still contribute.
        logger.warning(f"社区展开检索失败 community_id={community_id}: {outcome}")

    # Sort globally by activation_value first, then de-duplicate.
    from app.core.memory.src.search import _deduplicate_results

    ranked = sorted(
        collected,
        key=lambda row: float(row.get("activation_value") or 0),
        reverse=True,
    )
    unique_statements = _deduplicate_results(ranked)

    logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(unique_statements)}")
    return {"expanded_statements": unique_statements}
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
"""Neo4j 索引管理模块
负责检查和创建 Neo4j 全文索引与向量索引。
支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。
用法:
# 作为模块调用(应用启动时)
from app.repositories.neo4j.index_manager import ensure_indexes
await ensure_indexes()
# 作为独立脚本执行(手动建索引)
python -m app.repositories.neo4j.index_manager
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import List
from app.core.config import settings
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────
# 索引定义表
# ─────────────────────────────────────────────────────────────
@dataclass
class FulltextIndexDef:
    """Declarative description of one Neo4j full-text index."""

    name: str  # index name; compared against existing index names and used in CREATE FULLTEXT INDEX
    label: str  # node label the index is created for
    properties: List[str]  # node properties covered by the index (rendered as n.<prop>)
@dataclass
class VectorIndexDef:
    """Declarative description of one Neo4j vector index."""

    name: str  # index name; compared against existing index names and used in CREATE VECTOR INDEX
    label: str  # node label the index is created for
    property: str  # single node property the index is built on (rendered as n.<property>)
    dimensions: int  # value written to the `vector.dimensions` index option
    similarity: str = "cosine"  # value written to the `vector.similarity_function` index option
# Full-text indexes managed by this module (the existing ones plus the new
# Community index).
FULLTEXT_INDEXES: List[FulltextIndexDef] = [
    FulltextIndexDef("statementsFulltext", "Statement", ["statement"]),
    FulltextIndexDef("entitiesFulltext", "ExtractedEntity", ["name"]),
    FulltextIndexDef("chunksFulltext", "Chunk", ["content"]),
    FulltextIndexDef("summariesFulltext", "MemorySummary", ["content"]),
    FulltextIndexDef("communitiesFulltext", "Community", ["name", "summary"]),  # fifth retrieval source
]
# Vector indexes managed by this module (includes the phase-two community
# index). Every entry is declared with 1536 dimensions.
VECTOR_INDEXES: List[VectorIndexDef] = [
    VectorIndexDef("statement_embedding_index", "Statement", "statement_embedding", 1536),
    VectorIndexDef("chunk_embedding_index", "Chunk", "chunk_embedding", 1536),
    VectorIndexDef("entity_embedding_index", "ExtractedEntity","name_embedding", 1536),
    VectorIndexDef("summary_embedding_index", "MemorySummary", "summary_embedding", 1536),
    # Phase two: community vector index
    VectorIndexDef("community_summary_embedding_index", "Community", "summary_embedding", 1536),
]
# ─────────────────────────────────────────────────────────────
# 核心检查 / 创建逻辑
# ─────────────────────────────────────────────────────────────
async def _get_existing_indexes(connector: Neo4jConnector) -> set:
    """Return the set of index names currently reported by ``SHOW INDEXES``."""
    records = await connector.execute_query("SHOW INDEXES YIELD name RETURN name")
    index_names = set()
    for record in records:
        index_names.add(record["name"])
    return index_names
async def _ensure_fulltext_index(
    connector: Neo4jConnector,
    idx: FulltextIndexDef,
    existing: set,
) -> str:
    """Create the full-text index described by ``idx`` unless it already exists.

    Args:
        connector: Neo4j connector used to run the CREATE statement.
        idx: Definition of the index (name, label, properties).
        existing: Names of the indexes already present in the database.

    Returns:
        A status line: ``[SKIP] ...`` when the name is already in
        ``existing``, otherwise ``[CREATE] ...`` after the statement ran.
    """
    if idx.name not in existing:
        # Each property is referenced through the pattern variable ``n``.
        indexed_props = ", ".join([f"n.{prop}" for prop in idx.properties])
        statement = (
            f'CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS '
            f'FOR (n:{idx.label}) ON EACH [{indexed_props}]'
        )
        await connector.execute_query(statement)
        return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label}{idx.properties})"

    return f"[SKIP] 全文索引已存在: {idx.name}"
async def _ensure_vector_index(
    connector: Neo4jConnector,
    idx: VectorIndexDef,
    existing: set,
) -> str:
    """Create the vector index described by ``idx`` unless it already exists.

    Args:
        connector: Neo4j connector used to run the CREATE statement.
        idx: Definition of the index (name, label, property, dimensions,
            similarity function).
        existing: Names of the indexes already present in the database.

    Returns:
        A status line: ``[SKIP] ...`` when the name is already in
        ``existing``, otherwise ``[CREATE] ...`` after the statement ran.
    """
    already_present = idx.name in existing
    if not already_present:
        # Doubled braces are literal braces of the Cypher OPTIONS map.
        statement = (
            f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS "
            f"FOR (n:{idx.label}) ON n.{idx.property} "
            f"OPTIONS {{indexConfig: {{"
            f"`vector.dimensions`: {idx.dimensions}, "
            f"`vector.similarity_function`: '{idx.similarity}'"
            f"}}}}"
        )
        await connector.execute_query(statement)
        created_msg = (
            f"[CREATE] 向量索引已创建: {idx.name} "
            f"({idx.label}.{idx.property}, dim={idx.dimensions})"
        )
        return created_msg

    return f"[SKIP] 向量索引已存在: {idx.name}"
async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Check for and create every index listed in this module.

    Existing index names are fetched once; each definition in
    ``FULLTEXT_INDEXES`` and ``VECTOR_INDEXES`` is then created only if its
    name is missing, so the call can be repeated safely. A failure on one
    index is recorded and does not stop the remaining ones.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed before returning; a caller-supplied
            connector is left open.

    Returns:
        dict with the keys:
            "uri": the configured Neo4j URI,
            "fulltext": status lines for the full-text indexes,
            "vector": status lines for the vector indexes,
            "errors": error lines for indexes that could not be created.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    fulltext_log: list = []
    vector_log: list = []
    error_log: list = []
    report = {
        "uri": settings.NEO4J_URI,
        "fulltext": fulltext_log,
        "vector": vector_log,
        "errors": error_log,
    }

    try:
        # Fetch the names of all existing indexes in a single query.
        existing = await _get_existing_indexes(connector)
        logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}")
        logger.info(f"[IndexManager] 已有索引数量: {len(existing)}")

        # Full-text indexes
        for ft_def in FULLTEXT_INDEXES:
            try:
                outcome = await _ensure_fulltext_index(connector, ft_def, existing)
                fulltext_log.append(outcome)
                logger.info(f"[IndexManager] {outcome}")
            except Exception as exc:
                failure = f"[ERROR] 全文索引 {ft_def.name} 创建失败: {exc}"
                error_log.append(failure)
                logger.error(f"[IndexManager] {failure}")

        # Vector indexes
        for vec_def in VECTOR_INDEXES:
            try:
                outcome = await _ensure_vector_index(connector, vec_def, existing)
                vector_log.append(outcome)
                logger.info(f"[IndexManager] {outcome}")
            except Exception as exc:
                failure = f"[ERROR] 向量索引 {vec_def.name} 创建失败: {exc}"
                error_log.append(failure)
                logger.error(f"[IndexManager] {failure}")
    finally:
        # Only close a connector that was created inside this function.
        if owns_connector:
            await connector.close()

    return report
async def check_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Report which managed indexes exist, without creating anything.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed before returning.

    Returns:
        dict with the keys:
            "uri": the configured Neo4j URI,
            "present": sorted names of all existing indexes,
            "missing_fulltext": names from ``FULLTEXT_INDEXES`` not present,
            "missing_vector": names from ``VECTOR_INDEXES`` not present.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    try:
        present = await _get_existing_indexes(connector)

        absent_fulltext = []
        for ft_def in FULLTEXT_INDEXES:
            if ft_def.name not in present:
                absent_fulltext.append(ft_def.name)

        absent_vector = []
        for vec_def in VECTOR_INDEXES:
            if vec_def.name not in present:
                absent_vector.append(vec_def.name)

        return {
            "uri": settings.NEO4J_URI,
            "present": sorted(present),
            "missing_fulltext": absent_fulltext,
            "missing_vector": absent_vector,
        }
    finally:
        # Only close a connector that was created inside this function.
        if owns_connector:
            await connector.close()
# ─────────────────────────────────────────────────────────────
# 独立脚本入口
# ─────────────────────────────────────────────────────────────
async def _main():
    """Command-line entry point: report index status and create what is missing.

    Prints the current status, returns early when nothing is missing,
    otherwise runs ``ensure_indexes`` and prints its report. Exits the
    process with status 1 when any index creation failed.
    """
    import sys

    print(f"\n{'='*60}")
    print(f"Neo4j 索引管理工具")
    print(f"环境: {settings.NEO4J_URI}")
    print(f"{'='*60}\n")

    # Step 1: inspect the current state.
    print(">>> 检查当前索引状态...\n")
    status = await check_indexes()
    absent_fulltext = status["missing_fulltext"]
    absent_vector = status["missing_vector"]

    print(f" 已存在索引数: {len(status['present'])}")
    if absent_fulltext:
        print(f" 缺失全文索引: {status['missing_fulltext']}")
    if absent_vector:
        print(f" 缺失向量索引: {status['missing_vector']}")

    if not (absent_fulltext or absent_vector):
        print("\n 所有索引均已存在,无需操作。")
        return

    # Step 2: create whatever is missing.
    print("\n>>> 开始创建缺失索引...\n")
    report = await ensure_indexes()
    for line in report["fulltext"] + report["vector"]:
        print(f" {line}")

    if not report["errors"]:
        print("\n 全部索引处理完成。")
        return

    print("\n[!] 以下索引创建失败:")
    for failure in report["errors"]:
        print(f" {failure}")
    sys.exit(1)
# Run the index check/creation tool when this module is executed as a script
# (e.g. `python -m app.repositories.neo4j.index_manager`).
if __name__ == "__main__":
    asyncio.run(_main())

View File

@@ -2416,3 +2416,391 @@ def update_implicit_emotions_storage(self) -> Dict[str, Any]:
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
# =============================================================================
@celery_app.task(
    name="app.tasks.init_implicit_emotions_for_users",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
    # Marks this as an event-triggered task, as opposed to the scheduled
    # tasks in the periodic_tasks queue.
    triggered=True,
)
def init_implicit_emotions_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Event-triggered task: initialise implicit-memory/emotion data on first use.

    For every user ID, looks up an existing storage record; users that
    already have one are skipped. For the others, the implicit-memory profile
    and the emotion suggestions are generated and cached independently, and
    the user counts as initialised when at least one of the two succeeded.

    Per the original notes, this task is triggered by the
    /dashboard/end_users endpoint, and refreshing data of existing users is
    left to the scheduled task ``update_implicit_emotions_storage``.

    Args:
        end_user_ids: IDs of the users to check.

    Returns:
        A dict with ``status`` ("SUCCESS"/"FAILURE"), ``elapsed_time`` and
        ``task_id``; on success also the ``initialized`` / ``skipped`` /
        ``failed`` counters, on failure the ``error`` text.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.repositories.implicit_emotions_storage_repository import (
            ImplicitEmotionsStorageRepository,
        )
        from app.services.emotion_analytics_service import EmotionAnalyticsService
        from app.services.implicit_memory_service import ImplicitMemoryService

        logger.info(f"开始按需初始化隐性记忆/情绪数据,候选用户数: {len(end_user_ids)}")

        initialized = 0
        failed = 0
        skipped = 0

        with get_db_context() as db:
            repo = ImplicitEmotionsStorageRepository(db)

            for end_user_id in end_user_ids:
                existing = repo.get_by_end_user_id(end_user_id)
                if existing is not None:
                    skipped += 1
                    continue

                logger.info(f"用户 {end_user_id} 无记录,开始初始化")
                implicit_ok = False
                emotion_ok = False
                try:
                    # The two initialisations are independent: a failure of
                    # one must not prevent the other from running.
                    try:
                        implicit_service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
                        profile_data = await implicit_service.generate_complete_profile(user_id=end_user_id)
                        await implicit_service.save_profile_cache(
                            end_user_id=end_user_id, profile_data=profile_data, db=db
                        )
                        implicit_ok = True
                    except Exception as e:
                        logger.error(f"用户 {end_user_id} 隐性记忆初始化失败: {e}")

                    try:
                        emotion_service = EmotionAnalyticsService()
                        suggestions_data = await emotion_service.generate_emotion_suggestions(
                            end_user_id=end_user_id, db=db, language="zh"
                        )
                        await emotion_service.save_suggestions_cache(
                            end_user_id=end_user_id, suggestions_data=suggestions_data, db=db
                        )
                        emotion_ok = True
                    except Exception as e:
                        logger.error(f"用户 {end_user_id} 情绪建议初始化失败: {e}")

                    if implicit_ok or emotion_ok:
                        initialized += 1
                    else:
                        failed += 1
                except Exception as e:
                    failed += 1
                    logger.error(f"用户 {end_user_id} 初始化异常: {e}")

        logger.info(f"按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        # ignore_result=True means the returned dict is discarded, so the
        # failure must be logged here or it leaves no trace at all.
        logger.error(f"按需初始化隐性记忆/情绪数据任务失败: {e}")
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
# =============================================================================
@celery_app.task(
    name="app.tasks.init_interest_distribution_for_users",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
)
def init_interest_distribution_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Event-triggered task: build the interest-distribution cache on first use.

    For every user ID, checks ``InterestMemoryCache`` for a cached interest
    distribution in language "zh"; users with a cached value are skipped. For
    the others, the distribution is generated (``limit=5``) and written to
    the cache with ``INTEREST_CACHE_EXPIRE``.

    Per the original notes, this task is triggered by the
    /dashboard/end_users endpoint.

    Args:
        self: The bound Celery task instance.
        end_user_ids: IDs of the users to check.

    Returns:
        A dict with ``status`` ("SUCCESS"/"FAILURE"), ``elapsed_time`` and
        ``task_id``; on success also the ``initialized`` / ``skipped`` /
        ``failed`` counters, on failure the ``error`` text.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.cache.memory.interest_memory import InterestMemoryCache, INTEREST_CACHE_EXPIRE
        from app.services.memory_agent_service import MemoryAgentService

        logger.info(f"开始按需初始化兴趣分布缓存,候选用户数: {len(end_user_ids)}")

        initialized = 0
        failed = 0
        skipped = 0
        language = "zh"

        service = MemoryAgentService()

        # NOTE(review): ``db`` is never used inside this block; the session is
        # kept as in the original in case get_db_context() has side effects
        # the service relies on -- confirm and drop if not needed.
        with get_db_context() as db:
            for end_user_id in end_user_ids:
                # Existence check: skip users that already have cached data.
                cached = await InterestMemoryCache.get_interest_distribution(
                    end_user_id=end_user_id,
                    language=language,
                )
                if cached is not None:
                    skipped += 1
                    continue

                logger.info(f"用户 {end_user_id} 无兴趣分布缓存,开始生成")
                try:
                    result = await service.get_interest_distribution_by_user(
                        end_user_id=end_user_id,
                        limit=5,
                        language=language,
                    )
                    await InterestMemoryCache.set_interest_distribution(
                        end_user_id=end_user_id,
                        language=language,
                        data=result,
                        expire=INTEREST_CACHE_EXPIRE,
                    )
                    initialized += 1
                    logger.info(f"用户 {end_user_id} 兴趣分布缓存生成成功")
                except Exception as e:
                    failed += 1
                    logger.error(f"用户 {end_user_id} 兴趣分布缓存生成失败: {e}")

        logger.info(f"兴趣分布按需初始化完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}")
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        # ignore_result=True means the returned dict is discarded, so the
        # failure must be logged here or it leaves no trace at all.
        logger.error(f"兴趣分布按需初始化任务失败: {e}")
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }
@celery_app.task(
    name="app.tasks.write_perceptual_memory",
    bind=True,
    ignore_result=True,
    max_retries=0,
    acks_late=False,
    time_limit=3600,
    soft_time_limit=3300,
)
def write_perceptual_memory(
    self,
    end_user_id: str,
    model_api_config: dict,
    file_type: str,
    file_url: str,
    file_message: dict
):
    """Generate perceptual memory for one file of a user.

    The work is serialised per file: a Redis lock keyed by the MD5 of
    ``file_url`` is held while ``MemoryPerceptualService`` generates the
    perceptual memory. Per the original description, the result is persisted
    to PostgreSQL and Neo4j by that service.

    Args:
        end_user_id (str): Identifier of the end user.
        model_api_config (dict): Keyword arguments used to build the
            ``ModelInfo`` that configures the generating model.
        file_type (str): The file type.
        file_url (str): URL of the file; also the source of the lock key.
        file_message (dict): Details about the file to be processed.

    Returns:
        Whatever ``generate_perceptual_memory`` returns (the task is declared
        with ``ignore_result=True``, so Celery does not store it).
    """
    # MD5 is used only to derive a compact lock key, not for security.
    file_url_md5 = hashlib.md5(file_url.encode("utf-8")).hexdigest()
    # Run on the loop prepared by the shared helper, like the other tasks in
    # this module. asyncio.run() would create and close a separate loop and
    # leave the thread without a current event loop afterwards.
    loop = set_asyncio_event_loop()
    with RedisLock(f"perceptual:{file_url_md5}", redis_client=get_sync_redis_client()):
        model_info = ModelInfo(**model_api_config)
        with get_db_context() as db:
            memory_perceptual_service = MemoryPerceptualService(db)
            return loop.run_until_complete(memory_perceptual_service.generate_perceptual_memory(
                end_user_id,
                model_info,
                file_type,
                file_url,
                file_message,
            ))
# =============================================================================
# 社区聚类补全任务(触发型)
# =============================================================================
@celery_app.task(
    name="app.tasks.init_community_clustering_for_users",
    bind=True,
    ignore_result=False,
    max_retries=0,
    acks_late=False,
    time_limit=7200,  # 2-hour hard time limit
    soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
    """Event-triggered task: run full community clustering for users lacking it.

    For every user ID: users that already have community nodes are skipped,
    as are users without any entities; for the rest a
    ``LabelPropagationEngine`` is built with that user's LLM/embedding model
    IDs (``None`` when no configuration could be loaded) and
    ``full_clustering`` is executed. A failure for one user is counted and
    does not stop the others.

    Per the original notes, this task is triggered by the
    /dashboard/end_users endpoint.

    Args:
        end_user_ids: IDs of the users to check.

    Returns:
        A dict with ``status`` ("SUCCESS"/"FAILURE"), ``elapsed_time`` and
        ``task_id``; on success also the ``initialized`` / ``skipped`` /
        ``failed`` counters, on failure the ``error`` text.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.logging_config import get_logger
        from app.repositories.neo4j.community_repository import CommunityRepository
        from app.repositories.neo4j.neo4j_connector import Neo4jConnector
        from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine

        logger = get_logger(__name__)
        logger.info(f"[CommunityCluster] 开始社区聚类补全任务,候选用户数: {len(end_user_ids)}")

        initialized = 0
        skipped = 0
        failed = 0

        connector = Neo4jConnector()
        try:
            repo = CommunityRepository(connector)

            # Prefetch the model configuration of all users in one batch.
            # (The original note says the batch helper itself falls back to
            # the workspace default when a user has no usable config.)
            user_llm_map: Dict[str, Optional[str]] = {}
            user_embedding_map: Dict[str, Optional[str]] = {}
            try:
                with get_db_context() as db:
                    from app.services.memory_agent_service import get_end_users_connected_configs_batch
                    from app.services.memory_config_service import MemoryConfigService
                    batch_configs = get_end_users_connected_configs_batch(end_user_ids, db)
                    for uid, cfg_info in batch_configs.items():
                        config_id = cfg_info.get("memory_config_id")
                        if config_id:
                            try:
                                cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
                                user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
                                user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
                            except Exception as e:
                                logger.warning(f"[CommunityCluster] 用户 {uid} 加载 LLM 配置失败,将使用 None: {e}")
                                user_llm_map[uid] = None
                                user_embedding_map[uid] = None
                        else:
                            user_llm_map[uid] = None
                            user_embedding_map[uid] = None
            except Exception as e:
                # Best effort: with empty maps every user falls back to None.
                logger.warning(f"[CommunityCluster] 批量获取 LLM 配置失败,所有用户将使用 None: {e}")

            for end_user_id in end_user_ids:
                try:
                    # Skip users that already have community nodes.
                    has_communities = await repo.has_communities(end_user_id)
                    if has_communities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 已有社区节点,跳过")
                        continue

                    # Skip users that have no entity nodes to cluster.
                    entities = await repo.get_all_entities(end_user_id)
                    if not entities:
                        skipped += 1
                        logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
                        continue

                    # Each user is clustered with their own model IDs.
                    llm_model_id = user_llm_map.get(end_user_id)
                    embedding_model_id = user_embedding_map.get(end_user_id)
                    engine = LabelPropagationEngine(
                        connector=connector,
                        llm_model_id=llm_model_id,
                        embedding_model_id=embedding_model_id,
                    )

                    # Separators keep the user ID and the entity count from
                    # running together in the log line.
                    logger.info(
                        f"[CommunityCluster] 用户 {end_user_id} 共 {len(entities)} 个实体,"
                        f"开始全量聚类, llm_model_id={llm_model_id}"
                    )
                    await engine.full_clustering(end_user_id)
                    initialized += 1
                    logger.info(f"[CommunityCluster] 用户 {end_user_id} 聚类完成")

                except Exception as e:
                    failed += 1
                    logger.error(f"[CommunityCluster] 用户 {end_user_id} 聚类失败: {e}")
        finally:
            await connector.close()

        logger.info(
            f"[CommunityCluster] 任务完成: 初始化={initialized}, 跳过={skipped}, 失败={failed}"
        )
        return {
            "status": "SUCCESS",
            "initialized": initialized,
            "skipped": skipped,
            "failed": failed,
        }

    try:
        # nest_asyncio is optional; when installed it is applied before the
        # coroutine is driven with run_until_complete.
        try:
            import nest_asyncio
            nest_asyncio.apply()
        except ImportError:
            pass

        loop = set_asyncio_event_loop()
        result = loop.run_until_complete(_run())
        result["elapsed_time"] = time.time() - start_time
        result["task_id"] = self.request.id
        return result
    except Exception as e:
        # Log the task-level failure instead of only encoding it in the result.
        logger.error(f"[CommunityCluster] 社区聚类补全任务失败: {e}")
        return {
            "status": "FAILURE",
            "error": str(e),
            "elapsed_time": time.time() - start_time,
            "task_id": self.request.id,
        }