Merge pull request #573 from SuanmoSuanyangTechnology/feature/node-aggregation

Feature/node aggregation
This commit is contained in:
Ke Sun
2026-03-17 15:55:02 +08:00
committed by GitHub
11 changed files with 542 additions and 148 deletions

View File

@@ -120,7 +120,7 @@ class SearchService:
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -146,8 +146,8 @@ class SearchService:
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -157,7 +157,7 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
@@ -165,6 +165,36 @@ class SearchService:
if isinstance(category_results, list): if isinstance(category_results, list):
answer_list.extend(category_results) answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements
if "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
else answer.get('communities', [])
)
community_ids = [
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
expand_result = await search_graph_community_expand(
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
finally:
await expand_connector.close()
# Extract clean content from all results # Extract clean content from all results
content_list = [ content_list = [
self.extract_content_from_result(ans) self.extract_content_from_result(ans)

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -171,6 +171,13 @@ async def write(
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]: for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -281,21 +281,23 @@ def rerank_with_activation(
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0) # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用 # 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score

View File

@@ -7,6 +7,7 @@
- 增量更新incremental_update新实体到达时只处理新实体及其邻居 - 增量更新incremental_update新实体到达时只处理新实体及其邻居
""" """
import asyncio
import logging import logging
import uuid import uuid
from math import sqrt from math import sqrt
@@ -19,8 +20,9 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛 # 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5 # 社区核心实体取 top-N 数量
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -69,11 +71,13 @@ class LabelPropagationEngine:
connector: Neo4jConnector, connector: Neo4jConnector,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
): ):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.config_id = config_id self.config_id = config_id
self.llm_model_id = llm_model_id self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -103,58 +107,81 @@ class LabelPropagationEngine:
async def full_clustering(self, end_user_id: str) -> None: async def full_clustering(self, end_user_id: str) -> None:
""" """
全量标签传播初始化。 全量标签传播初始化(分批处理,控制内存峰值)
1. 拉取所有实体,初始化每个实体为独立社区 策略:
2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 每次只加载 BATCH_SIZE 个实体及其邻居进内存
3. 直到标签不再变化或达到 MAX_ITERATIONS - labels 字典跨批次共享(只存 id→community_id内存极小
4. 将最终标签写入 Neo4j - 每批独立跑 MAX_ITERATIONS 轮 LPA批次间通过 labels 传递社区信息
- 所有批次完成后统一 flush 和 merge
""" """
entities = await self.repo.get_all_entities(end_user_id) BATCH_SIZE = 888 # 每批实体数,可按需调整
if not entities:
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
total_count = await self.repo.get_entity_count(end_user_id)
if not total_count:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return return
# 初始化:每个实体持有自己 id 作为社区标签 all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
embeddings: Dict[str, Optional[List[float]]] = { f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
e["id"]: e.get("name_embedding") for e in entities
}
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返 # labels 跨批次共享:只存 id→community_id内存极小
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...") labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id) del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS): for batch_start in range(0, total_count, BATCH_SIZE):
changed = 0 batch_entities = await self.repo.get_entities_page(
# 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历) end_user_id, skip=batch_start, limit=BATCH_SIZE
for entity in entities:
eid = entity["id"]
# 直接从缓存取邻居,不再发起 Neo4j 查询
neighbors = neighbors_cache.get(eid, [])
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"标签变化数: {changed}"
) )
if changed == 0: if not batch_entities:
logger.info("[Clustering] 标签已收敛,提前结束迭代")
break break
# 将最终标签写入 Neo4j batch_ids = [e["id"] for e in batch_entities]
batch_embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in batch_entities
}
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}"
f"加载 {len(batch_entities)} 个实体的邻居图..."
)
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
batch_ids, end_user_id
)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS):
changed = 0
for entity in batch_entities:
eid = entity["id"]
neighbors = neighbors_cache.get(eid, [])
# 注入跨批次的最新标签邻居可能在其他批次labels 里有其最新值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
break
# 释放本批次的大对象
del neighbors_cache, batch_embeddings, batch_entities
# 所有批次完成,统一写入 Neo4j
await self._flush_labels(labels, end_user_id) await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values())) pre_merge_count = len(set(labels.values()))
logger.info( logger.info(
@@ -162,7 +189,6 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体,开始后处理合并" f"{len(labels)} 个实体,开始后处理合并"
) )
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values())) all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id) await self._evaluate_merge(all_community_ids, end_user_id)
@@ -170,17 +196,15 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体" f"{len(labels)} 个实体"
) )
# 为所有社区生成元数据
# 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活社区 # 查询存活社区并生成元数据
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
surviving_communities = await self.repo.get_all_entities(end_user_id) surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({ surviving_community_ids = list({
e.get("community_id") for e in surviving_communities e.get("community_id") for e in surviving_communities
if e.get("community_id") if e.get("community_id")
}) })
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
for cid in surviving_community_ids: await self._generate_community_metadata(surviving_community_ids, end_user_id)
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update( async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str self, new_entity_ids: List[str], end_user_id: str
@@ -237,7 +261,7 @@ class LabelPropagationEngine:
logger.debug( logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
) )
await self._generate_community_metadata(new_cid, end_user_id) await self._generate_community_metadata([new_cid], end_user_id)
else: else:
# 加入得票最多的社区 # 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -249,7 +273,7 @@ class LabelPropagationEngine:
await self._evaluate_merge( await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id list(community_ids_in_neighbors), end_user_id
) )
await self._generate_community_metadata(target_cid, end_user_id) await self._generate_community_metadata([target_cid], end_user_id)
async def _evaluate_merge( async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str self, community_ids: List[str], end_user_id: str
@@ -413,71 +437,122 @@ class LabelPropagationEngine:
except Exception: except Exception:
return None return None
@staticmethod
def _build_entity_lines(members: List[Dict]) -> List[str]:
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
lines = []
for m in members:
m_name = m.get("name", "")
aliases = m.get("aliases") or []
description = m.get("description") or ""
aliases_str = f"(别名:{''.join(aliases)}" if aliases else ""
desc_str = f"{description}" if description else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}")
return lines
async def _generate_community_metadata( async def _generate_community_metadata(
self, community_id: str, end_user_id: str self, community_ids: List[str], end_user_id: str
) -> None: ) -> None:
""" """
为社区生成并写入元数据:名称、摘要、核心实体 一个或多个社区生成并写入元数据。
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM 流程:
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 1. 逐个社区调 LLM 生成 name / summary串行
2. 收集所有 summary一次性批量 embed
3. 单个社区用 update_community_metadata多个用 batch_update_community_metadata
""" """
try: if not community_ids:
members = await self.repo.get_community_members(community_id, end_user_id) return
if not members:
return from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# --- 阶段1并发调 LLM 生成每个社区的 name / summary ---
async def _build_one(cid: str):
members = await self.repo.get_community_members(cid, end_user_id)
if not members:
return None
# 核心实体:按 activation_value 降序取 top-N
sorted_members = sorted( sorted_members = sorted(
members, members,
key=lambda m: m.get("activation_value") or 0, key=lambda m: m.get("activation_value") or 0,
reverse=True, reverse=True,
) )
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else community_id[:8] entity_list_str = "\n".join(self._build_entity_lines(members))
summary = f"包含实体:{', '.join(all_names)}" prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要 f"请为这组实体所代表的主题:\n"
if self.llm_model_id: f"1. 起一个简洁的中文名称不超过10个字\n"
try: f"2. 写一句话摘要不超过50个字\n\n"
from app.db import get_db_context f"严格按以下格式输出,不要有其他内容:\n"
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory f"名称:<名称>\n摘要:<摘要>"
entity_list_str = "".join(all_names)
prompt = (
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
) )
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") with get_db_context() as db:
except Exception as e: llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
name, summary = "", ""
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids],
return_exceptions=True,
)
metadata_list = []
for cid, res in zip(community_ids, results):
if isinstance(res, Exception):
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
elif res is not None:
metadata_list.append(res)
if not metadata_list:
return
# --- 阶段2批量生成 summary_embedding ---
summaries = [m["summary"] for m in metadata_list]
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
# --- 阶段3写入单个 or 批量)---
if len(metadata_list) == 1:
m = metadata_list[0]
result = await self.repo.update_community_metadata(
community_id=m["community_id"],
end_user_id=m["end_user_id"],
name=m["name"],
summary=m["summary"],
core_entities=m["core_entities"],
summary_embedding=m["summary_embedding"],
)
if result:
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
else:
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
else:
ok = await self.repo.batch_update_community_metadata(metadata_list)
if ok:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
else:
logger.warning(f"[Clustering] 批量写入社区元数据失败")
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:

View File

@@ -9,7 +9,6 @@ from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base from app.db import Base
from app.schemas import FileType from app.schemas import FileType
class PerceptualType(IntEnum): class PerceptualType(IntEnum):
VISION = 1 VISION = 1
AUDIO = 2 AUDIO = 2

View File

@@ -13,12 +13,17 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES, ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS, GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER, GET_ALL_ENTITIES_FOR_USER,
GET_ENTITY_COUNT_FOR_USER,
GET_ALL_ENTITY_IDS_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS, GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_COMMUNITY_MEMBERS_BATCH,
GET_ALL_ENTITY_NEIGHBORS_BATCH, GET_ALL_ENTITY_NEIGHBORS_BATCH,
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
CHECK_USER_HAS_COMMUNITIES, CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA, UPDATE_COMMUNITY_METADATA,
BATCH_UPDATE_COMMUNITY_METADATA,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -110,6 +115,65 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}") logger.error(f"get_all_entities failed: {e}")
return [] return []
async def get_entity_count(self, end_user_id: str) -> int:
"""仅返回用户实体总数,不加载实体数据。"""
try:
result = await self.connector.execute_query(
GET_ENTITY_COUNT_FOR_USER,
end_user_id=end_user_id,
)
return result[0]["entity_count"] if result else 0
except Exception as e:
logger.error(f"get_entity_count failed: {e}")
return 0
async def get_all_entity_ids(self, end_user_id: str) -> List[str]:
"""仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。"""
try:
result = await self.connector.execute_query(
GET_ALL_ENTITY_IDS_FOR_USER,
end_user_id=end_user_id,
)
return [r["id"] for r in result]
except Exception as e:
logger.error(f"get_all_entity_ids failed: {e}")
return []
async def get_entities_page(
self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:
"""分页拉取实体,用于全量聚类分批处理。"""
try:
return await self.connector.execute_query(
GET_ENTITIES_PAGE,
end_user_id=end_user_id,
skip=skip,
limit=limit,
)
except Exception as e:
logger.error(f"get_entities_page failed: {e}")
return []
async def get_entity_neighbors_for_ids(
self, entity_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
"""批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。"""
try:
rows = await self.connector.execute_query(
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
entity_ids=entity_ids,
end_user_id=end_user_id,
)
result: Dict[str, List[Dict]] = {}
for row in rows:
eid = row["entity_id"]
neighbor = {k: v for k, v in row.items() if k != "entity_id"}
result.setdefault(eid, []).append(neighbor)
return result
except Exception as e:
logger.error(f"get_entity_neighbors_for_ids failed: {e}")
return {}
async def get_community_members( async def get_community_members(
self, community_id: str, end_user_id: str self, community_id: str, end_user_id: str
) -> List[Dict]: ) -> List[Dict]:
@@ -177,8 +241,9 @@ class CommunityRepository:
name: str, name: str,
summary: str, summary: str,
core_entities: List[str], core_entities: List[str],
summary_embedding: Optional[List[float]] = None,
) -> bool: ) -> bool:
"""更新社区的名称、摘要核心实体列表。""" """更新社区的名称、摘要核心实体列表和摘要向量"""
try: try:
result = await self.connector.execute_query( result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA, UPDATE_COMMUNITY_METADATA,
@@ -187,8 +252,31 @@ class CommunityRepository:
name=name, name=name,
summary=summary, summary=summary,
core_entities=core_entities, core_entities=core_entities,
summary_embedding=summary_embedding,
) )
return bool(result) return bool(result)
except Exception as e: except Exception as e:
logger.error(f"update_community_metadata failed: {e}") logger.error(f"update_community_metadata failed: {e}")
return False return False
async def batch_update_community_metadata(
self,
communities: List[Dict],
) -> bool:
"""批量更新多个社区的元数据。
Args:
communities: 每项包含 community_id, end_user_id, name, summary,
core_entities, summary_embedding
"""
if not communities:
return True
try:
await self.connector.execute_query(
BATCH_UPDATE_COMMUNITY_METADATA,
communities=communities,
)
return True
except Exception as e:
logger.error(f"batch_update_community_metadata failed: {e}")
return False

View File

@@ -43,6 +43,13 @@ async def create_fulltext_indexes():
""") """)
print("✓ Created: summariesFulltext") print("✓ Created: summariesFulltext")
# 创建 Community 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
print("✓ Created: communitiesFulltext")
print("\nFull-text indexes created successfully with BM25 support.") print("\nFull-text indexes created successfully with BM25 support.")
except Exception as e: except Exception as e:
print(f"✗ Error creating full-text indexes: {e}") print(f"✗ Error creating full-text indexes: {e}")
@@ -125,6 +132,18 @@ async def create_vector_indexes():
""") """)
print("✓ Created: dialogue_embedding_index") print("✓ Created: dialogue_embedding_index")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
FOR (c:Community)
ON c.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: community_summary_embedding_index")
print("\nVector indexes created successfully!") print("\nVector indexes created successfully!")
print("\nExpected performance improvement:") print("\nExpected performance improvement:")
print(" Before: ~1.4s for embedding search") print(" Before: ~1.4s for embedding search")

View File

@@ -1122,21 +1122,33 @@ RETURN e.id AS id,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
""" """
GET_ENTITY_COUNT_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN count(e) AS entity_count
"""
GET_ALL_ENTITY_IDS_FOR_USER = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
RETURN e.id AS id
"""
GET_COMMUNITY_MEMBERS = """ GET_COMMUNITY_MEMBERS = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id}) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type, RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value, e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding e.name_embedding AS name_embedding,
e.aliases AS aliases, e.description AS description
ORDER BY coalesce(e.activation_value, 0) DESC ORDER BY coalesce(e.activation_value, 0) DESC
""" """
GET_ALL_COMMUNITY_MEMBERS_BATCH = """ GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
WHERE c.community_id IN $community_ids
RETURN c.community_id AS community_id, RETURN c.community_id AS community_id,
e.id AS id, e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding, e.name_embedding AS name_embedding,
e.activation_value AS activation_value e.aliases AS aliases, e.description AS description
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
""" """
CHECK_USER_HAS_COMMUNITIES = """ CHECK_USER_HAS_COMMUNITIES = """
@@ -1153,13 +1165,58 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """ UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name, SET c.name = $name,
c.summary = $summary, c.summary = $summary,
c.core_entities = $core_entities, c.core_entities = $core_entities,
c.updated_at = datetime() c.summary_embedding = $summary_embedding,
c.updated_at = datetime()
RETURN c.community_id AS community_id RETURN c.community_id AS community_id
""" """
BATCH_UPDATE_COMMUNITY_METADATA = """
UNWIND $communities AS row
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
SET c.name = row.name,
c.summary = row.summary,
c.core_entities = row.core_entities,
c.summary_embedding = row.summary_embedding,
c.updated_at = datetime()
RETURN c.community_id AS community_id
"""
GET_ENTITIES_PAGE = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN e.id AS id,
e.name AS name,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
ORDER BY e.id
SKIP $skip LIMIT $limit
"""
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
WHERE e.id IN $entity_ids
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
GET_ALL_ENTITY_NEIGHBORS_BATCH = """ GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) // 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
@@ -1185,20 +1242,59 @@ RETURN DISTINCT
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
""" """
GET_COMMUNITY_GRAPH_DATA = """
MATCH (c:Community {end_user_id: $end_user_id}) # Community keyword search: matches name or summary via fulltext index
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) SEARCH_COMMUNITIES_BY_KEYWORD = """
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
RETURN WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
elementId(c) AS c_id, RETURN c.community_id AS id,
properties(c) AS c_props, c.name AS name,
elementId(e) AS e_id, c.summary AS content,
properties(e) AS e_props, c.core_entities AS core_entities,
elementId(b) AS b_id, c.member_count AS member_count,
elementId(e2) AS e2_id, c.end_user_id AS end_user_id,
properties(e2) AS e2_props, c.updated_at AS updated_at,
elementId(r) AS r_id, score
type(r) AS r_type, ORDER BY score DESC
properties(r) AS r_props, LIMIT $limit
startNode(r) = e AS r_from_e """
# Community 向量检索 ──────────────────────────────────────────────────
# Community embedding-based search: cosine similarity on Community.summary_embedding
COMMUNITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.summary_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community 展开检索 ──────────────────────────────────────────────────
# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索)
EXPAND_COMMUNITY_STATEMENTS = """
MATCH (c:Community {community_id: $community_id})
MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c)
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
WHERE s.end_user_id = $end_user_id
RETURN s.statement AS statement,
s.id AS id,
s.end_user_id AS end_user_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
e.name AS source_entity,
c.name AS community_name
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
""" """

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import os
from typing import List, Optional from typing import List, Optional
# 使用新的仓储层 # 使用新的仓储层
@@ -158,11 +157,12 @@ async def save_dialog_and_statements_to_neo4j(
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector, connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
schedule_clustering_after_write() 显式触发。
Args: Args:
dialogue_nodes: List of DialogueNode objects to save dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save chunk_nodes: List of ChunkNode objects to save
@@ -293,9 +293,6 @@ async def save_dialog_and_statements_to_neo4j(
logger.info("Transaction completed. Summary: %s", summary) logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results) logger.debug("Full transaction results: %r", results)
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
return True return True
except Exception as e: except Exception as e:
@@ -309,6 +306,7 @@ def schedule_clustering_after_write(
entity_nodes: List, entity_nodes: List,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
写入 Neo4j 成功后,调度后台聚类任务。 写入 Neo4j 成功后,调度后台聚类任务。
@@ -327,7 +325,7 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes] new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
async def _trigger_clustering( async def _trigger_clustering(
@@ -335,6 +333,7 @@ async def _trigger_clustering(
end_user_id: str, end_user_id: str,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。
@@ -344,7 +343,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
connector = Neo4jConnector() connector = Neo4jConnector()
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}") logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
except Exception as e: except Exception as e:

View File

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
from app.repositories.neo4j.cypher_queries import ( from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
COMMUNITY_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
@@ -286,6 +289,15 @@ async def search_graph(
)) ))
task_keys.append("summaries") task_keys.append("summaries")
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
)) ))
task_keys.append("summaries") task_keys.append("summaries")
# Communities (向量语义匹配)
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
"chunks": [], "chunks": [],
"entities": [], "entities": [],
"summaries": [], "summaries": [],
"communities": [],
} }
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
return {"chunks": chunks} return {"chunks": chunks}
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
三期:社区展开检索 —— 主题 → 细节两级检索。
命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体,
再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点,
按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。
Args:
connector: Neo4j 连接器
community_ids: 已命中的社区 ID 列表
end_user_id: 用户 ID用于数据隔离
limit: 每个社区最多返回的 Statement 数量
Returns:
{"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]}
"""
if not community_ids or not end_user_id:
return {"expanded_statements": []}
tasks = [
connector.execute_query(
EXPAND_COMMUNITY_STATEMENTS,
community_id=cid,
end_user_id=end_user_id,
limit=limit,
)
for cid in community_ids
]
task_results = await asyncio.gather(*tasks, return_exceptions=True)
expanded: List[Dict[str, Any]] = []
for cid, result in zip(community_ids, task_results):
if isinstance(result, Exception):
logger.warning(f"社区展开检索失败 community_id={cid}: {result}")
else:
expanded.extend(result)
# 按 activation_value 全局排序后去重
from app.core.memory.src.search import _deduplicate_results
expanded.sort(
key=lambda x: float(x.get("activation_value") or 0),
reverse=True,
)
expanded = _deduplicate_results(expanded)
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
return {"expanded_statements": expanded}
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,