Merge pull request #573 from SuanmoSuanyangTechnology/feature/node-aggregation

Feature/node aggregation
This commit is contained in:
Ke Sun
2026-03-17 15:55:02 +08:00
committed by GitHub
11 changed files with 542 additions and 148 deletions

View File

@@ -120,7 +120,7 @@ class SearchService:
raw_results is None if return_raw_results=False
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query
cleaned_query = self.clean_query(question)
@@ -146,8 +146,8 @@ class SearchService:
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities']
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
@@ -157,13 +157,43 @@ class SearchService:
else:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities']
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements
if "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
else answer.get('communities', [])
)
community_ids = [
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
expand_result = await search_graph_community_expand(
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
finally:
await expand_connector.close()
# Extract clean content from all results
content_list = [

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -171,6 +171,13 @@ async def write(
)
if success:
logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break
else:
logger.warning("Failed to save some data to Neo4j")

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]:
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
@@ -281,21 +281,23 @@ def rerank_with_activation(
for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0)
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用
item["activation_score"] = act_norm
# 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score
item["base_score"] = base_score

View File

@@ -7,6 +7,7 @@
- 增量更新incremental_update新实体到达时只处理新实体及其邻居
"""
import asyncio
import logging
import uuid
from math import sqrt
@@ -19,8 +20,9 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5
# 社区核心实体取 top-N 数量
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -69,11 +71,13 @@ class LabelPropagationEngine:
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# ──────────────────────────────────────────────────────────────────────────
# 公开接口
@@ -103,58 +107,81 @@ class LabelPropagationEngine:
async def full_clustering(self, end_user_id: str) -> None:
"""
全量标签传播初始化。
全量标签传播初始化(分批处理,控制内存峰值)
1. 拉取所有实体,初始化每个实体为独立社区
2. 迭代:每轮对所有实体做邻居投票,更新社区标签
3. 直到标签不再变化或达到 MAX_ITERATIONS
4. 将最终标签写入 Neo4j
策略:
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
- labels 字典跨批次共享(只存 id→community_id内存极小
- 每批独立跑 MAX_ITERATIONS 轮 LPA批次间通过 labels 传递社区信息
- 所有批次完成后统一 flush 和 merge
"""
entities = await self.repo.get_all_entities(end_user_id)
if not entities:
BATCH_SIZE = 888 # 每批实体数,可按需调整
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
total_count = await self.repo.get_entity_count(end_user_id)
if not total_count:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return
# 初始化:每个实体持有自己 id 作为社区标签
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in entities
}
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
# labels 跨批次共享:只存 id→community_id内存极小
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
for iteration in range(MAX_ITERATIONS):
changed = 0
# 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
for entity in entities:
eid = entity["id"]
# 直接从缓存取邻居,不再发起 Neo4j 查询
neighbors = neighbors_cache.get(eid, [])
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"标签变化数: {changed}"
for batch_start in range(0, total_count, BATCH_SIZE):
batch_entities = await self.repo.get_entities_page(
end_user_id, skip=batch_start, limit=BATCH_SIZE
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束迭代")
if not batch_entities:
break
# 将最终标签写入 Neo4j
batch_ids = [e["id"] for e in batch_entities]
batch_embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in batch_entities
}
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}"
f"加载 {len(batch_entities)} 个实体的邻居图..."
)
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
batch_ids, end_user_id
)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS):
changed = 0
for entity in batch_entities:
eid = entity["id"]
neighbors = neighbors_cache.get(eid, [])
# 注入跨批次的最新标签邻居可能在其他批次labels 里有其最新值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
break
# 释放本批次的大对象
del neighbors_cache, batch_embeddings, batch_entities
# 所有批次完成,统一写入 Neo4j
await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values()))
logger.info(
@@ -162,7 +189,6 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体,开始后处理合并"
)
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id)
@@ -170,17 +196,15 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体"
)
# 为所有社区生成元数据
# 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活社区
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
# 查询存活社区并生成元数据
surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({
e.get("community_id") for e in surviving_communities
if e.get("community_id")
})
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
await self._generate_community_metadata(surviving_community_ids, end_user_id)
async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str
@@ -237,7 +261,7 @@ class LabelPropagationEngine:
logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
)
await self._generate_community_metadata(new_cid, end_user_id)
await self._generate_community_metadata([new_cid], end_user_id)
else:
# 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -249,7 +273,7 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata(target_cid, end_user_id)
await self._generate_community_metadata([target_cid], end_user_id)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -413,71 +437,122 @@ class LabelPropagationEngine:
except Exception:
return None
@staticmethod
def _build_entity_lines(members: List[Dict]) -> List[str]:
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
lines = []
for m in members:
m_name = m.get("name", "")
aliases = m.get("aliases") or []
description = m.get("description") or ""
aliases_str = f"(别名:{''.join(aliases)}" if aliases else ""
desc_str = f"{description}" if description else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}")
return lines
async def _generate_community_metadata(
self, community_id: str, end_user_id: str
self, community_ids: List[str], end_user_id: str
) -> None:
"""
为社区生成并写入元数据:名称、摘要、核心实体
一个或多个社区生成并写入元数据。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
流程:
1. 逐个社区调 LLM 生成 name / summary串行
2. 收集所有 summary一次性批量 embed
3. 单个社区用 update_community_metadata多个用 batch_update_community_metadata
"""
try:
members = await self.repo.get_community_members(community_id, end_user_id)
if not members:
return
if not community_ids:
return
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# --- 阶段1并发调 LLM 生成每个社区的 name / summary ---
async def _build_one(cid: str):
members = await self.repo.get_community_members(cid, end_user_id)
if not members:
return None
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8]
summary = f"包含实体:{', '.join(all_names)}"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
if self.llm_model_id:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
entity_list_str = "\n".join(self._build_entity_lines(members))
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
name, summary = "", ""
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids],
return_exceptions=True,
)
metadata_list = []
for cid, res in zip(community_ids, results):
if isinstance(res, Exception):
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
elif res is not None:
metadata_list.append(res)
if not metadata_list:
return
# --- 阶段2批量生成 summary_embedding ---
summaries = [m["summary"] for m in metadata_list]
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
# --- 阶段3写入单个 or 批量)---
if len(metadata_list) == 1:
m = metadata_list[0]
result = await self.repo.update_community_metadata(
community_id=m["community_id"],
end_user_id=m["end_user_id"],
name=m["name"],
summary=m["summary"],
core_entities=m["core_entities"],
summary_embedding=m["summary_embedding"],
)
if result:
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
else:
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
else:
ok = await self.repo.batch_update_community_metadata(metadata_list)
if ok:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
else:
logger.warning(f"[Clustering] 批量写入社区元数据失败")
@staticmethod
def _new_community_id() -> str: