[changes] Community Clustering Retrieval Module

This commit is contained in:
lanceyq
2026-03-16 12:30:00 +08:00
parent b1a7b58f97
commit f9fb480cc3
11 changed files with 637 additions and 96 deletions

View File

@@ -120,7 +120,7 @@ class SearchService:
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -146,8 +146,8 @@ class SearchService:
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -157,13 +157,43 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
category_results = answer[category] category_results = answer[category]
if isinstance(category_results, list): if isinstance(category_results, list):
answer_list.extend(category_results) answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements
if "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
else answer.get('communities', [])
)
community_ids = [
r.get("id") for r in community_results if r.get("id")
]
if community_ids and end_user_id:
try:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
expand_result = await search_graph_community_expand(
connector=connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
await connector.close()
expanded_stmts = expand_result.get("expanded_statements", [])
if expanded_stmts:
# 展开的 statements 插入 communities 之后、statements 之前
answer_list.extend(expanded_stmts)
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
except Exception as e:
logger.warning(f"社区展开检索失败,跳过: {e}")
# Extract clean content from all results # Extract clean content from all results
content_list = [ content_list = [

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -171,6 +171,13 @@ async def write(
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries"]: for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -281,21 +281,23 @@ def rerank_with_activation(
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0) combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0) # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用 # 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score

View File

@@ -19,8 +19,9 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛 # 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
# 社区摘要核心实体数量
CORE_ENTITY_LIMIT = 5 # 社区核心实体取 top-N 数量
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -69,11 +70,13 @@ class LabelPropagationEngine:
connector: Neo4jConnector, connector: Neo4jConnector,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
): ):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.config_id = config_id self.config_id = config_id
self.llm_model_id = llm_model_id self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -103,58 +106,85 @@ class LabelPropagationEngine:
async def full_clustering(self, end_user_id: str) -> None: async def full_clustering(self, end_user_id: str) -> None:
""" """
全量标签传播初始化。 全量标签传播初始化(分批处理,控制内存峰值)
1. 拉取所有实体,初始化每个实体为独立社区 策略:
2. 迭代:每轮对所有实体做邻居投票,更新社区标签 - 每次只加载 BATCH_SIZE 个实体及其邻居进内存
3. 直到标签不再变化或达到 MAX_ITERATIONS - labels 字典跨批次共享(只存 id→community_id内存极小
4. 将最终标签写入 Neo4j - 每批独立跑 MAX_ITERATIONS 轮 LPA批次间通过 labels 传递社区信息
- 所有批次完成后统一 flush 和 merge
""" """
entities = await self.repo.get_all_entities(end_user_id) BATCH_SIZE = 2000 # 每批实体数,可按需调整
if not entities:
# 先查总数,决定批次数
total_entities = await self.repo.get_all_entities(end_user_id)
if not total_entities:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return return
# 初始化:每个实体持有自己 id 作为社区标签 total_count = len(total_entities)
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities} logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体,"
embeddings: Dict[str, Optional[List[float]]] = { f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}")
e["id"]: e.get("name_embedding") for e in entities
}
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返 # labels 跨批次共享:先用全量数据初始化(只存 id内存极小
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...") labels: Dict[str, str] = {e["id"]: e["id"] for e in total_entities}
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id) # embeddings 也跨批次共享(每个向量 ~6KB10万实体约 600MB这是不可避免的
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}") # 但只在当前批次的实体需要时才保留,其余批次的 embedding 不常驻
# 实际上 embeddings 只在 _weighted_vote 中用于计算 self_embedding
# 所以只需要当前批次实体的 embedding不需要全量
del total_entities # 释放全量列表,后续按批次加载
for iteration in range(MAX_ITERATIONS): for batch_start in range(0, total_count, BATCH_SIZE):
changed = 0 batch_entities = await self.repo.get_entities_page(
# 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历) end_user_id, skip=batch_start, limit=BATCH_SIZE
for entity in entities:
eid = entity["id"]
# 直接从缓存取邻居,不再发起 Neo4j 查询
neighbors = neighbors_cache.get(eid, [])
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"标签变化数: {changed}"
) )
if changed == 0: if not batch_entities:
logger.info("[Clustering] 标签已收敛,提前结束迭代")
break break
# 将最终标签写入 Neo4j batch_ids = [e["id"] for e in batch_entities]
batch_embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in batch_entities
}
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}"
f"加载 {len(batch_entities)} 个实体的邻居图..."
)
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
batch_ids, end_user_id
)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS):
changed = 0
for entity in batch_entities:
eid = entity["id"]
neighbors = neighbors_cache.get(eid, [])
# 注入跨批次的最新标签邻居可能在其他批次labels 里有其最新值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
break
# 释放本批次的大对象
del neighbors_cache, batch_embeddings, batch_entities
# 所有批次完成,统一写入 Neo4j
await self._flush_labels(labels, end_user_id) await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values())) pre_merge_count = len(set(labels.values()))
logger.info( logger.info(
@@ -162,17 +192,16 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体,开始后处理合并" f"{len(labels)} 个实体,开始后处理合并"
) )
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values())) all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id) await self._evaluate_merge(all_community_ids, end_user_id)
logger.info( logger.info(
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体" f"{len(labels)} 个实体"
) )
# 为所有社区生成元数据
# 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活社区 # 查询存活社区并生成元数据
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
surviving_communities = await self.repo.get_all_entities(end_user_id) surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({ surviving_community_ids = list({
e.get("community_id") for e in surviving_communities e.get("community_id") for e in surviving_communities
@@ -421,6 +450,7 @@ class LabelPropagationEngine:
- core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM - core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
- name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底 - name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
NOTE: core_entities按照激活值高低排序会造成对边缘信息检索返回消息质量不高。
""" """
try: try:
members = await self.repo.get_community_members(community_id, end_user_id) members = await self.repo.get_community_members(community_id, end_user_id)
@@ -468,16 +498,33 @@ class LabelPropagationEngine:
except Exception as e: except Exception as e:
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}") logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
await self.repo.update_community_metadata( # 生成 summary_embedding
summary_embedding = None
if self.embedding_model_id and summary:
try:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
results = await embedder.response([summary])
summary_embedding = results[0] if results else None
except Exception as e:
logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
result = await self.repo.update_community_metadata(
community_id=community_id, community_id=community_id,
end_user_id=end_user_id, end_user_id=end_user_id,
name=name, name=name,
summary=summary, summary=summary,
core_entities=core_entities, core_entities=core_entities,
summary_embedding=summary_embedding,
) )
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}") if result:
logger.info(f"[Clustering] 社区 {community_id} 元数据写入成功: name={name}, summary={summary[:30]}...")
else:
logger.warning(f"[Clustering] 社区 {community_id} 元数据写入返回 False")
except Exception as e: except Exception as e:
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}") logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}", exc_info=True)
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:

View File

@@ -18,6 +18,7 @@ from app.core.logging_config import LoggingConfig, get_logger
from app.core.response_utils import fail from app.core.response_utils import fail
from app.core.models.scripts.loader import load_models from app.core.models.scripts.loader import load_models
from app.db import get_db_context from app.db import get_db_context
from app.repositories.neo4j.index_manager import ensure_indexes
# Initialize logging system # Initialize logging system
LoggingConfig.setup_logging() LoggingConfig.setup_logging()
@@ -61,9 +62,18 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
# 确保 Neo4j 索引存在(幂等,多环境安全)
try:
report = await ensure_indexes()
if report["errors"]:
logger.warning(f"Neo4j 索引部分创建失败: {report['errors']}")
else:
logger.info(f"Neo4j 索引检查完成 [{report['uri']}]")
except Exception as e:
logger.warning(f"Neo4j 索引检查跳过(连接失败): {e}")
logger.info("应用程序启动完成") logger.info("应用程序启动完成")
yield yield
# 应用关闭事件
logger.info("应用程序正在关闭") logger.info("应用程序正在关闭")

View File

@@ -13,12 +13,15 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_LEAVE_ALL_COMMUNITIES, ENTITY_LEAVE_ALL_COMMUNITIES,
GET_ENTITY_NEIGHBORS, GET_ENTITY_NEIGHBORS,
GET_ALL_ENTITIES_FOR_USER, GET_ALL_ENTITIES_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS, GET_COMMUNITY_MEMBERS,
GET_ALL_COMMUNITY_MEMBERS_BATCH, GET_ALL_COMMUNITY_MEMBERS_BATCH,
GET_ALL_ENTITY_NEIGHBORS_BATCH, GET_ALL_ENTITY_NEIGHBORS_BATCH,
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
CHECK_USER_HAS_COMMUNITIES, CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT, UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA, UPDATE_COMMUNITY_METADATA,
UPDATE_COMMUNITY_METADATA,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -110,6 +113,41 @@ class CommunityRepository:
logger.error(f"get_all_entities failed: {e}") logger.error(f"get_all_entities failed: {e}")
return [] return []
async def get_entities_page(
    self, end_user_id: str, skip: int, limit: int
) -> List[Dict]:
    """Fetch one page of a user's entities for batched full clustering.

    Args:
        end_user_id: Owner whose entities are paged.
        skip: Number of rows to skip before the page starts.
        limit: Maximum number of rows in the page.

    Returns:
        The rows produced by GET_ENTITIES_PAGE, or an empty list when the
        query raises (the error is logged, not propagated).
    """
    query_params = {"end_user_id": end_user_id, "skip": skip, "limit": limit}
    try:
        page = await self.connector.execute_query(GET_ENTITIES_PAGE, **query_params)
    except Exception as e:
        logger.error(f"get_entities_page failed: {e}")
        return []
    return page
async def get_entity_neighbors_for_ids(
    self, entity_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
    """Fetch the neighbors of the given entities, grouped by entity id.

    Args:
        entity_ids: Ids of the entities whose neighbors are wanted.
        end_user_id: Owner of the entities.

    Returns:
        A mapping ``{entity_id: [neighbor_row, ...]}`` where each neighbor row
        is the query row without its ``entity_id`` column. An empty dict is
        returned when anything raises (the error is logged, not propagated).
    """
    try:
        rows = await self.connector.execute_query(
            GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
            entity_ids=entity_ids,
            end_user_id=end_user_id,
        )
        grouped: Dict[str, List[Dict]] = {}
        for row in rows:
            owner_id = row["entity_id"]
            neighbor_row = {
                column: value
                for column, value in row.items()
                if column != "entity_id"
            }
            if owner_id not in grouped:
                grouped[owner_id] = []
            grouped[owner_id].append(neighbor_row)
        return grouped
    except Exception as e:
        logger.error(f"get_entity_neighbors_for_ids failed: {e}")
        return {}
async def get_community_members( async def get_community_members(
self, community_id: str, end_user_id: str self, community_id: str, end_user_id: str
) -> List[Dict]: ) -> List[Dict]:
@@ -177,8 +215,9 @@ class CommunityRepository:
name: str, name: str,
summary: str, summary: str,
core_entities: List[str], core_entities: List[str],
summary_embedding: Optional[List[float]] = None,
) -> bool: ) -> bool:
"""更新社区的名称、摘要核心实体列表。""" """更新社区的名称、摘要核心实体列表和摘要向量"""
try: try:
result = await self.connector.execute_query( result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA, UPDATE_COMMUNITY_METADATA,
@@ -187,6 +226,7 @@ class CommunityRepository:
name=name, name=name,
summary=summary, summary=summary,
core_entities=core_entities, core_entities=core_entities,
summary_embedding=summary_embedding,
) )
return bool(result) return bool(result)
except Exception as e: except Exception as e:

View File

@@ -1132,11 +1132,11 @@ ORDER BY coalesce(e.activation_value, 0) DESC
GET_ALL_COMMUNITY_MEMBERS_BATCH = """ GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
WHERE c.community_id IN $community_ids
RETURN c.community_id AS community_id, RETURN c.community_id AS community_id,
e.id AS id, e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.name_embedding AS name_embedding, e.importance_score AS importance_score, e.activation_value AS activation_value,
e.activation_value AS activation_value e.name_embedding AS name_embedding
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
""" """
CHECK_USER_HAS_COMMUNITIES = """ CHECK_USER_HAS_COMMUNITIES = """
@@ -1153,13 +1153,47 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """ UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id}) MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name, SET c.name = $name,
c.summary = $summary, c.summary = $summary,
c.core_entities = $core_entities, c.core_entities = $core_entities,
c.updated_at = datetime() c.summary_embedding = $summary_embedding,
c.updated_at = datetime()
RETURN c.community_id AS community_id RETURN c.community_id AS community_id
""" """
# Paged listing of one user's ExtractedEntity nodes, ordered by e.id so that
# successive $skip/$limit windows are stable. Each row carries the entity's id,
# name, name_embedding, activation_value and the community_id of the Community
# it belongs to (null when it has none).
# NOTE(review): if an entity can hold more than one BELONGS_TO_COMMUNITY edge,
# the OPTIONAL MATCH yields one row per community and SKIP/LIMIT then pages over
# rows rather than entities — confirm the one-community-per-entity invariant.
GET_ENTITIES_PAGE = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN e.id AS id,
e.name AS name,
e.name_embedding AS name_embedding,
e.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
ORDER BY e.id
SKIP $skip LIMIT $limit
"""
# Neighbor lookup for an explicit list of entity ids ($entity_ids) of one user.
# A neighbor is either (1) an entity joined by an EXTRACTED_RELATIONSHIP edge in
# any direction, or (2) another entity referenced by a Statement that also
# references the source entity. One row is returned per (entity_id, neighbor)
# pair, with the neighbor's community_id (null when it has none). Entities
# without any neighbor produce no rows, because null neighbors are filtered out.
# NOTE(review): the chained OPTIONAL MATCH clauses multiply rows before
# collect(DISTINCT ...), which may be costly for densely connected entities, and
# the Statement node itself is not filtered by end_user_id (only nb2 is) —
# confirm both are acceptable.
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
WHERE e.id IN $entity_ids
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
WHERE nb2.id <> e.id
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
UNWIND all_neighbors AS nb
WITH e, nb WHERE nb IS NOT NULL
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN DISTINCT
e.id AS entity_id,
nb.id AS id,
nb.name AS name,
nb.name_embedding AS name_embedding,
nb.activation_value AS activation_value,
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
"""
GET_ALL_ENTITY_NEIGHBORS_BATCH = """ GET_ALL_ENTITY_NEIGHBORS_BATCH = """
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载) // 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
MATCH (e:ExtractedEntity {end_user_id: $end_user_id}) MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
@@ -1185,20 +1219,59 @@ RETURN DISTINCT
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
""" """
GET_COMMUNITY_GRAPH_DATA = """
MATCH (c:Community {end_user_id: $end_user_id}) # Community keyword search: matches name or summary via fulltext index
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c) SEARCH_COMMUNITIES_BY_KEYWORD = """
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id}) CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
RETURN WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
elementId(c) AS c_id, RETURN c.community_id AS id,
properties(c) AS c_props, c.name AS name,
elementId(e) AS e_id, c.summary AS content,
properties(e) AS e_props, c.core_entities AS core_entities,
elementId(b) AS b_id, c.member_count AS member_count,
elementId(e2) AS e2_id, c.end_user_id AS end_user_id,
properties(e2) AS e2_props, c.updated_at AS updated_at,
elementId(r) AS r_id, score
type(r) AS r_type, ORDER BY score DESC
properties(r) AS r_props, LIMIT $limit
startNode(r) = e AS r_from_e """
# Community vector search ──────────────────────────────────────────────────
# Community embedding-based search: cosine similarity on Community.summary_embedding
# The vector index is asked for $limit * 100 candidates because the
# end_user_id filter is applied only after the index lookup; the filtered rows
# are then ordered by score and trimmed to $limit. A null $end_user_id disables
# the per-user filter.
COMMUNITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.summary_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community 展开检索 ──────────────────────────────────────────────────
# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索)
EXPAND_COMMUNITY_STATEMENTS = """
MATCH (c:Community {community_id: $community_id})
MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c)
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
WHERE s.end_user_id = $end_user_id
RETURN s.statement AS statement,
s.id AS id,
s.end_user_id AS end_user_id,
s.created_at AS created_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
e.name AS source_entity,
c.name AS community_name
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
""" """

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import os
from typing import List, Optional from typing import List, Optional
# 使用新的仓储层 # 使用新的仓储层
@@ -158,11 +157,12 @@ async def save_dialog_and_statements_to_neo4j(
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
connector: Neo4jConnector, connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
) -> bool: ) -> bool:
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models. """Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
schedule_clustering_after_write() 显式触发。
Args: Args:
dialogue_nodes: List of DialogueNode objects to save dialogue_nodes: List of DialogueNode objects to save
chunk_nodes: List of ChunkNode objects to save chunk_nodes: List of ChunkNode objects to save
@@ -293,9 +293,6 @@ async def save_dialog_and_statements_to_neo4j(
logger.info("Transaction completed. Summary: %s", summary) logger.info("Transaction completed. Summary: %s", summary)
logger.debug("Full transaction results: %r", results) logger.debug("Full transaction results: %r", results)
# 写入成功后,异步触发聚类(不阻塞写入响应)
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
return True return True
except Exception as e: except Exception as e:
@@ -309,6 +306,7 @@ def schedule_clustering_after_write(
entity_nodes: List, entity_nodes: List,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
写入 Neo4j 成功后,调度后台聚类任务。 写入 Neo4j 成功后,调度后台聚类任务。
@@ -327,7 +325,7 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes] new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}") logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id)) asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
async def _trigger_clustering( async def _trigger_clustering(
@@ -335,6 +333,7 @@ async def _trigger_clustering(
end_user_id: str, end_user_id: str,
config_id: Optional[str] = None, config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None: ) -> None:
""" """
聚类触发函数,自动判断全量初始化还是增量更新。 聚类触发函数,自动判断全量初始化还是增量更新。
@@ -344,7 +343,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}") logger.info(f"[Clustering] 开始聚类end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
connector = Neo4jConnector() connector = Neo4jConnector()
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id) engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids) await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}") logger.info(f"[Clustering] 聚类完成end_user_id={end_user_id}")
except Exception as e: except Exception as e:

View File

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
from app.repositories.neo4j.cypher_queries import ( from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
COMMUNITY_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
@@ -285,6 +288,15 @@ async def search_graph(
limit=limit, limit=limit,
)) ))
task_keys.append("summaries") task_keys.append("summaries")
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
)) ))
task_keys.append("summaries") task_keys.append("summaries")
# Communities (向量语义匹配)
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
"chunks": [], "chunks": [],
"entities": [], "entities": [],
"summaries": [], "summaries": [],
"communities": [],
} }
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
return {"chunks": chunks} return {"chunks": chunks}
async def search_graph_community_expand(
    connector: Neo4jConnector,
    community_ids: List[str],
    end_user_id: str,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """Expand matched communities into their member statements (topic -> detail).

    Phase 3 of community retrieval: for every matched Community, one
    EXPAND_COMMUNITY_STATEMENTS query is run; all queries execute concurrently.
    The rows of the successful queries are merged, sorted by
    ``activation_value`` in descending order and then de-duplicated.

    Args:
        connector: Neo4j connector used to run the queries.
        community_ids: Ids of the communities that were hit. Falsy and
            repeated ids are dropped so each community is queried once.
        end_user_id: User id, used for data isolation.
        limit: Maximum number of Statement rows fetched per community.

    Returns:
        ``{"expanded_statements": [...]}``; the list is empty when there is
        nothing to expand. A failing query for one community is logged and
        skipped; it does not fail the whole expansion.
    """
    if not community_ids or not end_user_id:
        return {"expanded_statements": []}

    # dict.fromkeys keeps first-seen order while removing repeated ids, so a
    # community present in several result sets is not queried more than once.
    community_ids = list(dict.fromkeys(cid for cid in community_ids if cid))
    if not community_ids:
        return {"expanded_statements": []}

    tasks = [
        connector.execute_query(
            EXPAND_COMMUNITY_STATEMENTS,
            community_id=cid,
            end_user_id=end_user_id,
            limit=limit,
        )
        for cid in community_ids
    ]
    # return_exceptions=True: one failing community must not cancel the rest.
    task_results = await asyncio.gather(*tasks, return_exceptions=True)

    expanded: List[Dict[str, Any]] = []
    for cid, result in zip(community_ids, task_results):
        if isinstance(result, Exception):
            logger.warning(f"社区展开检索失败 community_id={cid}: {result}")
        else:
            expanded.extend(result)

    # Imported here rather than at module top, as in the original code.
    from app.core.memory.src.search import _deduplicate_results

    # Sort globally by activation_value first, then de-duplicate the sorted list.
    expanded.sort(
        key=lambda x: float(x.get("activation_value") or 0),
        reverse=True,
    )
    expanded = _deduplicate_results(expanded)

    logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
    return {"expanded_statements": expanded}
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,

View File

@@ -0,0 +1,254 @@
# -*- coding: utf-8 -*-
"""Neo4j 索引管理模块
负责检查和创建 Neo4j 全文索引与向量索引。
支持多环境(通过 .env 中的 NEO4J_URI/USERNAME/PASSWORD 区分)。
用法:
# 作为模块调用(应用启动时)
from app.repositories.neo4j.index_manager import ensure_indexes
await ensure_indexes()
# 作为独立脚本执行(手动建索引)
python -m app.repositories.neo4j.index_manager
"""
import asyncio
import logging
from dataclasses import dataclass
from typing import List
from app.core.config import settings
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────
# 索引定义表
# ─────────────────────────────────────────────────────────────
@dataclass
class FulltextIndexDef:
    """Declarative description of one Neo4j full-text index."""

    # Index name; compared against the names returned by SHOW INDEXES.
    name: str
    # Node label the index is created for.
    label: str
    # Node properties covered by the index (the ON EACH [...] list).
    properties: List[str]
@dataclass
class VectorIndexDef:
    """Declarative description of one Neo4j vector index."""

    # Index name; compared against the names returned by SHOW INDEXES.
    name: str
    # Node label the index is created for.
    label: str
    # Single node property that holds the embedding vector.
    property: str
    # Value written to the `vector.dimensions` index option.
    dimensions: int
    # Value written to the `vector.similarity_function` index option.
    similarity: str = "cosine"
# Full-text index catalogue: the existing indexes plus the new communities index.
FULLTEXT_INDEXES: List[FulltextIndexDef] = [
    FulltextIndexDef(
        name="statementsFulltext",
        label="Statement",
        properties=["statement"],
    ),
    FulltextIndexDef(
        name="entitiesFulltext",
        label="ExtractedEntity",
        properties=["name"],
    ),
    FulltextIndexDef(
        name="chunksFulltext",
        label="Chunk",
        properties=["content"],
    ),
    FulltextIndexDef(
        name="summariesFulltext",
        label="MemorySummary",
        properties=["content"],
    ),
    # Fifth retrieval source: communities.
    FulltextIndexDef(
        name="communitiesFulltext",
        label="Community",
        properties=["name", "summary"],
    ),
]
# Vector index catalogue (the community entry is reserved for phase 2).
VECTOR_INDEXES: List[VectorIndexDef] = [
    VectorIndexDef(
        name="statement_embedding_index",
        label="Statement",
        property="statement_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="chunk_embedding_index",
        label="Chunk",
        property="chunk_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="entity_embedding_index",
        label="ExtractedEntity",
        property="name_embedding",
        dimensions=1536,
    ),
    VectorIndexDef(
        name="summary_embedding_index",
        label="MemorySummary",
        property="summary_embedding",
        dimensions=1536,
    ),
    # Phase 2: vector index over community summaries.
    VectorIndexDef(
        name="community_summary_embedding_index",
        label="Community",
        property="summary_embedding",
        dimensions=1536,
    ),
]
# ─────────────────────────────────────────────────────────────
# 核心检查 / 创建逻辑
# ─────────────────────────────────────────────────────────────
async def _get_existing_indexes(connector: Neo4jConnector) -> set:
    """Return the set of index names reported by ``SHOW INDEXES``."""
    records = await connector.execute_query("SHOW INDEXES YIELD name RETURN name")
    index_names = set()
    for record in records:
        index_names.add(record["name"])
    return index_names
async def _ensure_fulltext_index(
    connector: Neo4jConnector,
    idx: FulltextIndexDef,
    existing: set,
) -> str:
    """Create the full-text index described by ``idx`` unless it already exists.

    Args:
        connector: Connector used to run the CREATE statement.
        idx: Definition of the index to ensure.
        existing: Names of the indexes already present in the database.

    Returns:
        A human-readable status line, prefixed with ``[SKIP]`` or ``[CREATE]``.
    """
    already_present = idx.name in existing
    if already_present:
        return f"[SKIP] 全文索引已存在: {idx.name}"

    # Each property is referenced through the node variable "n".
    indexed_props = ", ".join([f"n.{prop}" for prop in idx.properties])
    statement = " ".join(
        (
            f"CREATE FULLTEXT INDEX {idx.name} IF NOT EXISTS",
            f"FOR (n:{idx.label}) ON EACH [{indexed_props}]",
        )
    )
    await connector.execute_query(statement)
    return f"[CREATE] 全文索引已创建: {idx.name} ({idx.label}{idx.properties})"
async def _ensure_vector_index(
    connector: Neo4jConnector,
    idx: VectorIndexDef,
    existing: set,
) -> str:
    """Create the vector index described by ``idx`` unless it already exists.

    Args:
        connector: Connector used to run the CREATE statement.
        idx: Definition of the index to ensure.
        existing: Names of the indexes already present in the database.

    Returns:
        A human-readable status line, prefixed with ``[SKIP]`` or ``[CREATE]``.
    """
    already_present = idx.name in existing
    if already_present:
        return f"[SKIP] 向量索引已存在: {idx.name}"

    # Body of the indexConfig map; the option keys need backticks in Cypher.
    index_config = (
        f"`vector.dimensions`: {idx.dimensions}, "
        f"`vector.similarity_function`: '{idx.similarity}'"
    )
    statement = " ".join(
        (
            f"CREATE VECTOR INDEX {idx.name} IF NOT EXISTS",
            f"FOR (n:{idx.label}) ON n.{idx.property}",
            "OPTIONS {indexConfig: {" + index_config + "}}",
        )
    )
    await connector.execute_query(statement)

    target = f"({idx.label}.{idx.property}, dim={idx.dimensions})"
    return f"[CREATE] 向量索引已创建: {idx.name} {target}"
async def ensure_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Check for and create every required Neo4j index.

    Safe to call repeatedly: indexes that already exist are skipped.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed again before returning.

    Returns:
        dict with the keys
            "uri": ``settings.NEO4J_URI``,
            "fulltext": status lines for the full-text indexes,
            "vector": status lines for the vector indexes,
            "errors": one message per index whose creation failed.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    report = {
        "uri": settings.NEO4J_URI,
        "fulltext": [],
        "vector": [],
        "errors": [],
    }

    try:
        # Fetch the names of all existing indexes once, up front.
        known = await _get_existing_indexes(connector)
        logger.info(f"[IndexManager] 当前环境: {settings.NEO4J_URI}")
        logger.info(f"[IndexManager] 已有索引数量: {len(known)}")

        # (report key, label used in error messages, definitions, ensure function);
        # full-text indexes are processed before vector indexes.
        plan = (
            ("fulltext", "全文索引", FULLTEXT_INDEXES, _ensure_fulltext_index),
            ("vector", "向量索引", VECTOR_INDEXES, _ensure_vector_index),
        )
        for bucket, kind, definitions, ensure in plan:
            for definition in definitions:
                # One failing index must not stop the remaining ones.
                try:
                    outcome = await ensure(connector, definition, known)
                    report[bucket].append(outcome)
                    logger.info(f"[IndexManager] {outcome}")
                except Exception as exc:
                    failure = f"[ERROR] {kind} {definition.name} 创建失败: {exc}"
                    report["errors"].append(failure)
                    logger.error(f"[IndexManager] {failure}")
    finally:
        # Only close a connector that was created by this call.
        if owns_connector:
            await connector.close()

    return report
async def check_indexes(connector: Neo4jConnector | None = None) -> dict:
    """Report which required indexes exist, without creating anything.

    Args:
        connector: Optional existing connector. When ``None`` a connector is
            created here and closed again before returning.

    Returns:
        dict with the keys
            "uri": ``settings.NEO4J_URI``,
            "present": sorted names of the indexes that exist,
            "missing_fulltext": names of the missing full-text indexes,
            "missing_vector": names of the missing vector indexes.
    """
    owns_connector = connector is None
    if owns_connector:
        connector = Neo4jConnector()

    try:
        present = await _get_existing_indexes(connector)

        absent_fulltext = []
        for definition in FULLTEXT_INDEXES:
            if definition.name not in present:
                absent_fulltext.append(definition.name)

        absent_vector = []
        for definition in VECTOR_INDEXES:
            if definition.name not in present:
                absent_vector.append(definition.name)

        summary = {
            "uri": settings.NEO4J_URI,
            "present": sorted(present),
            "missing_fulltext": absent_fulltext,
            "missing_vector": absent_vector,
        }
        return summary
    finally:
        # Only close a connector that was created by this call.
        if owns_connector:
            await connector.close()
# ─────────────────────────────────────────────────────────────
# 独立脚本入口
# ─────────────────────────────────────────────────────────────
async def _main():
    """Command-line entry point: report index status, then create what is missing.

    Exits the process with status 1 when at least one index could not be
    created.
    """
    import sys

    rule = "=" * 60
    print(f"\n{rule}")
    print("Neo4j 索引管理工具")
    print(f"环境: {settings.NEO4J_URI}")
    print(f"{rule}\n")

    # Step 1: inspect the current state.
    print(">>> 检查当前索引状态...\n")
    status = await check_indexes()
    print(f" 已存在索引数: {len(status['present'])}")
    missing_fulltext = status["missing_fulltext"]
    missing_vector = status["missing_vector"]
    if missing_fulltext:
        print(f" 缺失全文索引: {missing_fulltext}")
    if missing_vector:
        print(f" 缺失向量索引: {missing_vector}")
    if not (missing_fulltext or missing_vector):
        print("\n 所有索引均已存在,无需操作。")
        return

    # Step 2: create whatever is missing.
    print("\n>>> 开始创建缺失索引...\n")
    report = await ensure_indexes()
    for line in (*report["fulltext"], *report["vector"]):
        print(f" {line}")

    failures = report["errors"]
    if not failures:
        print("\n 全部索引处理完成。")
        return

    print("\n[!] 以下索引创建失败:")
    for failure in failures:
        print(f" {failure}")
    sys.exit(1)


if __name__ == "__main__":
    asyncio.run(_main())