Merge pull request #573 from SuanmoSuanyangTechnology/feature/node-aggregation
Feature/node aggregation
This commit is contained in:
@@ -120,7 +120,7 @@ class SearchService:
|
|||||||
raw_results is None if return_raw_results=False
|
raw_results is None if return_raw_results=False
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries"]
|
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||||
|
|
||||||
# Clean query
|
# Clean query
|
||||||
cleaned_query = self.clean_query(question)
|
cleaned_query = self.clean_query(question)
|
||||||
@@ -146,8 +146,8 @@ class SearchService:
|
|||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
reranked_results = answer.get('reranked_results', {})
|
reranked_results = answer.get('reranked_results', {})
|
||||||
|
|
||||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in reranked_results:
|
if category in include and category in reranked_results:
|
||||||
@@ -157,13 +157,43 @@ class SearchService:
|
|||||||
else:
|
else:
|
||||||
# For keyword or embedding search, results are directly in answer dict
|
# For keyword or embedding search, results are directly in answer dict
|
||||||
# Apply same priority order
|
# Apply same priority order
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in answer:
|
if category in include and category in answer:
|
||||||
category_results = answer[category]
|
category_results = answer[category]
|
||||||
if isinstance(category_results, list):
|
if isinstance(category_results, list):
|
||||||
answer_list.extend(category_results)
|
answer_list.extend(category_results)
|
||||||
|
|
||||||
|
# 对命中的 community 节点展开其成员 statements
|
||||||
|
if "communities" in include:
|
||||||
|
community_results = (
|
||||||
|
answer.get('reranked_results', {}).get('communities', [])
|
||||||
|
if search_type == "hybrid"
|
||||||
|
else answer.get('communities', [])
|
||||||
|
)
|
||||||
|
community_ids = [
|
||||||
|
r.get("id") for r in community_results if r.get("id")
|
||||||
|
]
|
||||||
|
if community_ids and end_user_id:
|
||||||
|
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
expand_connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
expand_result = await search_graph_community_expand(
|
||||||
|
connector=expand_connector,
|
||||||
|
community_ids=community_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
expanded_stmts = expand_result.get("expanded_statements", [])
|
||||||
|
if expanded_stmts:
|
||||||
|
answer_list.extend(expanded_stmts)
|
||||||
|
logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"社区展开检索失败,跳过: {e}")
|
||||||
|
finally:
|
||||||
|
await expand_connector.close()
|
||||||
|
|
||||||
# Extract clean content from all results
|
# Extract clean content from all results
|
||||||
content_list = [
|
content_list = [
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
@@ -171,6 +171,13 @@ async def write(
|
|||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||||
|
schedule_clustering_after_write(
|
||||||
|
all_entity_nodes,
|
||||||
|
config_id=config_id,
|
||||||
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
|
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ def rerank_with_activation(
|
|||||||
|
|
||||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||||
keyword_items = keyword_results.get(category, [])
|
keyword_items = keyword_results.get(category, [])
|
||||||
embedding_items = embedding_results.get(category, [])
|
embedding_items = embedding_results.get(category, [])
|
||||||
|
|
||||||
@@ -281,21 +281,23 @@ def rerank_with_activation(
|
|||||||
for item in items_list:
|
for item in items_list:
|
||||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||||
if item_id and item_id in combined_items:
|
if item_id and item_id in combined_items:
|
||||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||||
|
|
||||||
# 步骤 4: 计算基础分数和最终分数
|
# 步骤 4: 计算基础分数和最终分数
|
||||||
for item_id, item in combined_items.items():
|
for item_id, item in combined_items.items():
|
||||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||||
|
raw_act_norm = item.get("normalized_activation_value")
|
||||||
|
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||||
|
|
||||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||||
base_score = content_score # 第一阶段用内容分数
|
base_score = content_score # 第一阶段用内容分数
|
||||||
|
|
||||||
# 存储激活度分数供第二阶段使用
|
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||||
item["activation_score"] = act_norm
|
item["activation_score"] = act_norm # 可能为 None
|
||||||
item["content_score"] = content_score
|
item["content_score"] = content_score
|
||||||
item["base_score"] = base_score
|
item["base_score"] = base_score
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
@@ -19,8 +20,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# 全量迭代最大轮数,防止不收敛
|
# 全量迭代最大轮数,防止不收敛
|
||||||
MAX_ITERATIONS = 10
|
MAX_ITERATIONS = 10
|
||||||
# 社区摘要核心实体数量
|
|
||||||
CORE_ENTITY_LIMIT = 5
|
# 社区核心实体取 top-N 数量
|
||||||
|
CORE_ENTITY_LIMIT = 10
|
||||||
|
|
||||||
|
|
||||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||||
@@ -69,11 +71,13 @@ class LabelPropagationEngine:
|
|||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config_id: Optional[str] = None,
|
config_id: Optional[str] = None,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.connector = connector
|
self.connector = connector
|
||||||
self.repo = CommunityRepository(connector)
|
self.repo = CommunityRepository(connector)
|
||||||
self.config_id = config_id
|
self.config_id = config_id
|
||||||
self.llm_model_id = llm_model_id
|
self.llm_model_id = llm_model_id
|
||||||
|
self.embedding_model_id = embedding_model_id
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -103,58 +107,81 @@ class LabelPropagationEngine:
|
|||||||
|
|
||||||
async def full_clustering(self, end_user_id: str) -> None:
|
async def full_clustering(self, end_user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
全量标签传播初始化。
|
全量标签传播初始化(分批处理,控制内存峰值)。
|
||||||
|
|
||||||
1. 拉取所有实体,初始化每个实体为独立社区
|
策略:
|
||||||
2. 迭代:每轮对所有实体做邻居投票,更新社区标签
|
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
|
||||||
3. 直到标签不再变化或达到 MAX_ITERATIONS
|
- labels 字典跨批次共享(只存 id→community_id,内存极小)
|
||||||
4. 将最终标签写入 Neo4j
|
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||||
|
- 所有批次完成后统一 flush 和 merge
|
||||||
"""
|
"""
|
||||||
entities = await self.repo.get_all_entities(end_user_id)
|
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||||
if not entities:
|
|
||||||
|
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||||
|
total_count = await self.repo.get_entity_count(end_user_id)
|
||||||
|
if not total_count:
|
||||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 初始化:每个实体持有自己 id 作为社区标签
|
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
|
||||||
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
|
logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
|
||||||
embeddings: Dict[str, Optional[List[float]]] = {
|
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批")
|
||||||
e["id"]: e.get("name_embedding") for e in entities
|
|
||||||
}
|
|
||||||
|
|
||||||
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
|
# labels 跨批次共享:只存 id→community_id,内存极小
|
||||||
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
|
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
|
||||||
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
|
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
|
||||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
|
||||||
|
|
||||||
for iteration in range(MAX_ITERATIONS):
|
for batch_start in range(0, total_count, BATCH_SIZE):
|
||||||
changed = 0
|
batch_entities = await self.repo.get_entities_page(
|
||||||
# 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
|
end_user_id, skip=batch_start, limit=BATCH_SIZE
|
||||||
for entity in entities:
|
|
||||||
eid = entity["id"]
|
|
||||||
# 直接从缓存取邻居,不再发起 Neo4j 查询
|
|
||||||
neighbors = neighbors_cache.get(eid, [])
|
|
||||||
|
|
||||||
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
|
|
||||||
enriched = []
|
|
||||||
for nb in neighbors:
|
|
||||||
nb_copy = dict(nb)
|
|
||||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
|
||||||
enriched.append(nb_copy)
|
|
||||||
|
|
||||||
new_label = _weighted_vote(enriched, embeddings.get(eid))
|
|
||||||
if new_label and new_label != labels[eid]:
|
|
||||||
labels[eid] = new_label
|
|
||||||
changed += 1
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS},"
|
|
||||||
f"标签变化数: {changed}"
|
|
||||||
)
|
)
|
||||||
if changed == 0:
|
if not batch_entities:
|
||||||
logger.info("[Clustering] 标签已收敛,提前结束迭代")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# 将最终标签写入 Neo4j
|
batch_ids = [e["id"] for e in batch_entities]
|
||||||
|
batch_embeddings: Dict[str, Optional[List[float]]] = {
|
||||||
|
e["id"]: e.get("name_embedding") for e in batch_entities
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:"
|
||||||
|
f"加载 {len(batch_entities)} 个实体的邻居图..."
|
||||||
|
)
|
||||||
|
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
|
||||||
|
batch_ids, end_user_id
|
||||||
|
)
|
||||||
|
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||||
|
|
||||||
|
for iteration in range(MAX_ITERATIONS):
|
||||||
|
changed = 0
|
||||||
|
for entity in batch_entities:
|
||||||
|
eid = entity["id"]
|
||||||
|
neighbors = neighbors_cache.get(eid, [])
|
||||||
|
|
||||||
|
# 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值)
|
||||||
|
enriched = []
|
||||||
|
for nb in neighbors:
|
||||||
|
nb_copy = dict(nb)
|
||||||
|
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||||
|
enriched.append(nb_copy)
|
||||||
|
|
||||||
|
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
|
||||||
|
if new_label and new_label != labels[eid]:
|
||||||
|
labels[eid] = new_label
|
||||||
|
changed += 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
|
||||||
|
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
|
||||||
|
)
|
||||||
|
if changed == 0:
|
||||||
|
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
|
||||||
|
break
|
||||||
|
|
||||||
|
# 释放本批次的大对象
|
||||||
|
del neighbors_cache, batch_embeddings, batch_entities
|
||||||
|
|
||||||
|
# 所有批次完成,统一写入 Neo4j
|
||||||
await self._flush_labels(labels, end_user_id)
|
await self._flush_labels(labels, end_user_id)
|
||||||
pre_merge_count = len(set(labels.values()))
|
pre_merge_count = len(set(labels.values()))
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -162,7 +189,6 @@ class LabelPropagationEngine:
|
|||||||
f"{len(labels)} 个实体,开始后处理合并"
|
f"{len(labels)} 个实体,开始后处理合并"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
|
|
||||||
all_community_ids = list(set(labels.values()))
|
all_community_ids = list(set(labels.values()))
|
||||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||||
|
|
||||||
@@ -170,17 +196,15 @@ class LabelPropagationEngine:
|
|||||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||||
f"{len(labels)} 个实体"
|
f"{len(labels)} 个实体"
|
||||||
)
|
)
|
||||||
# 为所有社区生成元数据
|
|
||||||
# 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区
|
# 查询存活社区并生成元数据
|
||||||
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
|
|
||||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||||
surviving_community_ids = list({
|
surviving_community_ids = list({
|
||||||
e.get("community_id") for e in surviving_communities
|
e.get("community_id") for e in surviving_communities
|
||||||
if e.get("community_id")
|
if e.get("community_id")
|
||||||
})
|
})
|
||||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||||
for cid in surviving_community_ids:
|
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||||
await self._generate_community_metadata(cid, end_user_id)
|
|
||||||
|
|
||||||
async def incremental_update(
|
async def incremental_update(
|
||||||
self, new_entity_ids: List[str], end_user_id: str
|
self, new_entity_ids: List[str], end_user_id: str
|
||||||
@@ -237,7 +261,7 @@ class LabelPropagationEngine:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata(new_cid, end_user_id)
|
await self._generate_community_metadata([new_cid], end_user_id)
|
||||||
else:
|
else:
|
||||||
# 加入得票最多的社区
|
# 加入得票最多的社区
|
||||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||||
@@ -249,7 +273,7 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata(target_cid, end_user_id)
|
await self._generate_community_metadata([target_cid], end_user_id)
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -413,71 +437,122 @@ class LabelPropagationEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||||
|
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
|
||||||
|
lines = []
|
||||||
|
for m in members:
|
||||||
|
m_name = m.get("name", "")
|
||||||
|
aliases = m.get("aliases") or []
|
||||||
|
description = m.get("description") or ""
|
||||||
|
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||||
|
desc_str = f":{description}" if description else ""
|
||||||
|
lines.append(f"- {m_name}{aliases_str}{desc_str}")
|
||||||
|
return lines
|
||||||
|
|
||||||
async def _generate_community_metadata(
|
async def _generate_community_metadata(
|
||||||
self, community_id: str, end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为社区生成并写入元数据:名称、摘要、核心实体。
|
为一个或多个社区生成并写入元数据。
|
||||||
|
|
||||||
- core_entities:按 activation_value 排序取 top-N 实体名称列表(无需 LLM)
|
流程:
|
||||||
- name / summary:若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
|
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||||
|
2. 收集所有 summary,一次性批量 embed
|
||||||
|
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||||
"""
|
"""
|
||||||
try:
|
if not community_ids:
|
||||||
members = await self.repo.get_community_members(community_id, end_user_id)
|
return
|
||||||
if not members:
|
|
||||||
return
|
from app.db import get_db_context
|
||||||
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
|
||||||
|
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||||
|
async def _build_one(cid: str):
|
||||||
|
members = await self.repo.get_community_members(cid, end_user_id)
|
||||||
|
if not members:
|
||||||
|
return None
|
||||||
|
|
||||||
# 核心实体:按 activation_value 降序取 top-N
|
|
||||||
sorted_members = sorted(
|
sorted_members = sorted(
|
||||||
members,
|
members,
|
||||||
key=lambda m: m.get("activation_value") or 0,
|
key=lambda m: m.get("activation_value") or 0,
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||||
all_names = [m["name"] for m in members if m.get("name")]
|
|
||||||
|
|
||||||
name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
|
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||||
summary = f"包含实体:{', '.join(all_names)}"
|
prompt = (
|
||||||
|
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
|
||||||
# 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
|
f"请为这组实体所代表的主题:\n"
|
||||||
if self.llm_model_id:
|
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||||
try:
|
f"2. 写一句话摘要(不超过50个字)\n\n"
|
||||||
from app.db import get_db_context
|
f"严格按以下格式输出,不要有其他内容:\n"
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
f"名称:<名称>\n摘要:<摘要>"
|
||||||
|
|
||||||
entity_list_str = "、".join(all_names)
|
|
||||||
prompt = (
|
|
||||||
f"以下是一组语义相关的实体:{entity_list_str}\n\n"
|
|
||||||
f"请为这组实体所代表的主题:\n"
|
|
||||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
|
||||||
f"2. 写一句话摘要(不超过50个字)\n\n"
|
|
||||||
f"严格按以下格式输出,不要有其他内容:\n"
|
|
||||||
f"名称:<名称>\n摘要:<摘要>"
|
|
||||||
)
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
llm_client = factory.get_llm_client(self.llm_model_id)
|
|
||||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
|
|
||||||
for line in text.strip().splitlines():
|
|
||||||
if line.startswith("名称:"):
|
|
||||||
name = line[3:].strip()
|
|
||||||
elif line.startswith("摘要:"):
|
|
||||||
summary = line[3:].strip()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
|
|
||||||
|
|
||||||
await self.repo.update_community_metadata(
|
|
||||||
community_id=community_id,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
name=name,
|
|
||||||
summary=summary,
|
|
||||||
core_entities=core_entities,
|
|
||||||
)
|
)
|
||||||
logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
|
with get_db_context() as db:
|
||||||
except Exception as e:
|
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||||
logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
|
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||||
|
text = response.content if hasattr(response, "content") else str(response)
|
||||||
|
|
||||||
|
name, summary = "", ""
|
||||||
|
for line in text.strip().splitlines():
|
||||||
|
if line.startswith("名称:"):
|
||||||
|
name = line[3:].strip()
|
||||||
|
elif line.startswith("摘要:"):
|
||||||
|
summary = line[3:].strip()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"community_id": cid,
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"name": name,
|
||||||
|
"summary": summary,
|
||||||
|
"core_entities": core_entities,
|
||||||
|
"summary_embedding": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
results = await asyncio.gather(
|
||||||
|
*[_build_one(cid) for cid in community_ids],
|
||||||
|
return_exceptions=True,
|
||||||
|
)
|
||||||
|
metadata_list = []
|
||||||
|
for cid, res in zip(community_ids, results):
|
||||||
|
if isinstance(res, Exception):
|
||||||
|
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
|
||||||
|
elif res is not None:
|
||||||
|
metadata_list.append(res)
|
||||||
|
|
||||||
|
if not metadata_list:
|
||||||
|
return
|
||||||
|
|
||||||
|
# --- 阶段2:批量生成 summary_embedding ---
|
||||||
|
summaries = [m["summary"] for m in metadata_list]
|
||||||
|
with get_db_context() as db:
|
||||||
|
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||||
|
embeddings = await embedder.response(summaries)
|
||||||
|
for i, meta in enumerate(metadata_list):
|
||||||
|
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||||
|
|
||||||
|
# --- 阶段3:写入(单个 or 批量)---
|
||||||
|
if len(metadata_list) == 1:
|
||||||
|
m = metadata_list[0]
|
||||||
|
result = await self.repo.update_community_metadata(
|
||||||
|
community_id=m["community_id"],
|
||||||
|
end_user_id=m["end_user_id"],
|
||||||
|
name=m["name"],
|
||||||
|
summary=m["summary"],
|
||||||
|
core_entities=m["core_entities"],
|
||||||
|
summary_embedding=m["summary_embedding"],
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||||
|
else:
|
||||||
|
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||||
|
else:
|
||||||
|
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||||
|
if ok:
|
||||||
|
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||||
|
else:
|
||||||
|
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from sqlalchemy.dialects.postgresql import JSONB
|
|||||||
from app.db import Base
|
from app.db import Base
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
|
|
||||||
class PerceptualType(IntEnum):
|
class PerceptualType(IntEnum):
|
||||||
VISION = 1
|
VISION = 1
|
||||||
AUDIO = 2
|
AUDIO = 2
|
||||||
|
|||||||
@@ -13,12 +13,17 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||||
GET_ENTITY_NEIGHBORS,
|
GET_ENTITY_NEIGHBORS,
|
||||||
GET_ALL_ENTITIES_FOR_USER,
|
GET_ALL_ENTITIES_FOR_USER,
|
||||||
|
GET_ENTITY_COUNT_FOR_USER,
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER,
|
||||||
|
GET_ENTITIES_PAGE,
|
||||||
GET_COMMUNITY_MEMBERS,
|
GET_COMMUNITY_MEMBERS,
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||||
CHECK_USER_HAS_COMMUNITIES,
|
CHECK_USER_HAS_COMMUNITIES,
|
||||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -110,6 +115,65 @@ class CommunityRepository:
|
|||||||
logger.error(f"get_all_entities failed: {e}")
|
logger.error(f"get_all_entities failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_entity_count(self, end_user_id: str) -> int:
|
||||||
|
"""仅返回用户实体总数,不加载实体数据。"""
|
||||||
|
try:
|
||||||
|
result = await self.connector.execute_query(
|
||||||
|
GET_ENTITY_COUNT_FOR_USER,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
return result[0]["entity_count"] if result else 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entity_count failed: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def get_all_entity_ids(self, end_user_id: str) -> List[str]:
|
||||||
|
"""仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。"""
|
||||||
|
try:
|
||||||
|
result = await self.connector.execute_query(
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
return [r["id"] for r in result]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_all_entity_ids failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_entities_page(
|
||||||
|
self, end_user_id: str, skip: int, limit: int
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""分页拉取实体,用于全量聚类分批处理。"""
|
||||||
|
try:
|
||||||
|
return await self.connector.execute_query(
|
||||||
|
GET_ENTITIES_PAGE,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entities_page failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_entity_neighbors_for_ids(
|
||||||
|
self, entity_ids: List[str], end_user_id: str
|
||||||
|
) -> Dict[str, List[Dict]]:
|
||||||
|
"""批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。"""
|
||||||
|
try:
|
||||||
|
rows = await self.connector.execute_query(
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||||
|
entity_ids=entity_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
result: Dict[str, List[Dict]] = {}
|
||||||
|
for row in rows:
|
||||||
|
eid = row["entity_id"]
|
||||||
|
neighbor = {k: v for k, v in row.items() if k != "entity_id"}
|
||||||
|
result.setdefault(eid, []).append(neighbor)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entity_neighbors_for_ids failed: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
async def get_community_members(
|
async def get_community_members(
|
||||||
self, community_id: str, end_user_id: str
|
self, community_id: str, end_user_id: str
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
@@ -177,8 +241,9 @@ class CommunityRepository:
|
|||||||
name: str,
|
name: str,
|
||||||
summary: str,
|
summary: str,
|
||||||
core_entities: List[str],
|
core_entities: List[str],
|
||||||
|
summary_embedding: Optional[List[float]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""更新社区的名称、摘要和核心实体列表。"""
|
"""更新社区的名称、摘要、核心实体列表和摘要向量。"""
|
||||||
try:
|
try:
|
||||||
result = await self.connector.execute_query(
|
result = await self.connector.execute_query(
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
@@ -187,8 +252,31 @@ class CommunityRepository:
|
|||||||
name=name,
|
name=name,
|
||||||
summary=summary,
|
summary=summary,
|
||||||
core_entities=core_entities,
|
core_entities=core_entities,
|
||||||
|
summary_embedding=summary_embedding,
|
||||||
)
|
)
|
||||||
return bool(result)
|
return bool(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"update_community_metadata failed: {e}")
|
logger.error(f"update_community_metadata failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def batch_update_community_metadata(
|
||||||
|
self,
|
||||||
|
communities: List[Dict],
|
||||||
|
) -> bool:
|
||||||
|
"""批量更新多个社区的元数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
communities: 每项包含 community_id, end_user_id, name, summary,
|
||||||
|
core_entities, summary_embedding
|
||||||
|
"""
|
||||||
|
if not communities:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
await self.connector.execute_query(
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||||
|
communities=communities,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"batch_update_community_metadata failed: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -42,6 +42,13 @@ async def create_fulltext_indexes():
|
|||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: summariesFulltext")
|
print("✓ Created: summariesFulltext")
|
||||||
|
|
||||||
|
# 创建 Community 索引
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||||
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
|
""")
|
||||||
|
print("✓ Created: communitiesFulltext")
|
||||||
|
|
||||||
print("\nFull-text indexes created successfully with BM25 support.")
|
print("\nFull-text indexes created successfully with BM25 support.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -124,6 +131,18 @@ async def create_vector_indexes():
|
|||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: dialogue_embedding_index")
|
print("✓ Created: dialogue_embedding_index")
|
||||||
|
|
||||||
|
# Community summary embedding index
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||||
|
FOR (c:Community)
|
||||||
|
ON c.summary_embedding
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
""")
|
||||||
|
print("✓ Created: community_summary_embedding_index")
|
||||||
|
|
||||||
print("\nVector indexes created successfully!")
|
print("\nVector indexes created successfully!")
|
||||||
print("\nExpected performance improvement:")
|
print("\nExpected performance improvement:")
|
||||||
|
|||||||
@@ -1122,21 +1122,33 @@ RETURN e.id AS id,
|
|||||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
GET_ENTITY_COUNT_FOR_USER = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
RETURN count(e) AS entity_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
RETURN e.id AS id
|
||||||
|
"""
|
||||||
|
|
||||||
GET_COMMUNITY_MEMBERS = """
|
GET_COMMUNITY_MEMBERS = """
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||||
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||||
e.name_embedding AS name_embedding
|
e.name_embedding AS name_embedding,
|
||||||
|
e.aliases AS aliases, e.description AS description
|
||||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
WHERE c.community_id IN $community_ids
|
|
||||||
RETURN c.community_id AS community_id,
|
RETURN c.community_id AS community_id,
|
||||||
e.id AS id,
|
e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||||
|
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
e.activation_value AS activation_value
|
e.aliases AS aliases, e.description AS description
|
||||||
|
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CHECK_USER_HAS_COMMUNITIES = """
|
CHECK_USER_HAS_COMMUNITIES = """
|
||||||
@@ -1153,13 +1165,58 @@ RETURN c.community_id AS community_id, cnt AS member_count
|
|||||||
|
|
||||||
UPDATE_COMMUNITY_METADATA = """
|
UPDATE_COMMUNITY_METADATA = """
|
||||||
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
||||||
SET c.name = $name,
|
SET c.name = $name,
|
||||||
c.summary = $summary,
|
c.summary = $summary,
|
||||||
c.core_entities = $core_entities,
|
c.core_entities = $core_entities,
|
||||||
c.updated_at = datetime()
|
c.summary_embedding = $summary_embedding,
|
||||||
|
c.updated_at = datetime()
|
||||||
RETURN c.community_id AS community_id
|
RETURN c.community_id AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||||
|
UNWIND $communities AS row
|
||||||
|
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||||
|
SET c.name = row.name,
|
||||||
|
c.summary = row.summary,
|
||||||
|
c.core_entities = row.core_entities,
|
||||||
|
c.summary_embedding = row.summary_embedding,
|
||||||
|
c.updated_at = datetime()
|
||||||
|
RETURN c.community_id AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ENTITIES_PAGE = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name AS name,
|
||||||
|
e.name_embedding AS name_embedding,
|
||||||
|
e.activation_value AS activation_value,
|
||||||
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
|
ORDER BY e.id
|
||||||
|
SKIP $skip LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
|
||||||
|
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
WHERE e.id IN $entity_ids
|
||||||
|
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
WHERE nb2.id <> e.id
|
||||||
|
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||||
|
UNWIND all_neighbors AS nb
|
||||||
|
WITH e, nb WHERE nb IS NOT NULL
|
||||||
|
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
|
RETURN DISTINCT
|
||||||
|
e.id AS entity_id,
|
||||||
|
nb.id AS id,
|
||||||
|
nb.name AS name,
|
||||||
|
nb.name_embedding AS name_embedding,
|
||||||
|
nb.activation_value AS activation_value,
|
||||||
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
||||||
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
@@ -1185,20 +1242,59 @@ RETURN DISTINCT
|
|||||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
GET_COMMUNITY_GRAPH_DATA = """
|
|
||||||
MATCH (c:Community {end_user_id: $end_user_id})
|
# Community keyword search: matches name or summary via fulltext index
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[b:BELONGS_TO_COMMUNITY]->(c)
|
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||||
OPTIONAL MATCH (e)-[r:EXTRACTED_RELATIONSHIP]-(e2:ExtractedEntity {end_user_id: $end_user_id})
|
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
|
||||||
RETURN
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
elementId(c) AS c_id,
|
RETURN c.community_id AS id,
|
||||||
properties(c) AS c_props,
|
c.name AS name,
|
||||||
elementId(e) AS e_id,
|
c.summary AS content,
|
||||||
properties(e) AS e_props,
|
c.core_entities AS core_entities,
|
||||||
elementId(b) AS b_id,
|
c.member_count AS member_count,
|
||||||
elementId(e2) AS e2_id,
|
c.end_user_id AS end_user_id,
|
||||||
properties(e2) AS e2_props,
|
c.updated_at AS updated_at,
|
||||||
elementId(r) AS r_id,
|
score
|
||||||
type(r) AS r_type,
|
ORDER BY score DESC
|
||||||
properties(r) AS r_props,
|
LIMIT $limit
|
||||||
startNode(r) = e AS r_from_e
|
"""
|
||||||
|
|
||||||
|
# Community 向量检索 ──────────────────────────────────────────────────
|
||||||
|
# Community embedding-based search: cosine similarity on Community.summary_embedding
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS c, score
|
||||||
|
WHERE c.summary_embedding IS NOT NULL
|
||||||
|
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
RETURN c.community_id AS id,
|
||||||
|
c.name AS name,
|
||||||
|
c.summary AS content,
|
||||||
|
c.core_entities AS core_entities,
|
||||||
|
c.member_count AS member_count,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.updated_at AS updated_at,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Community 展开检索 ──────────────────────────────────────────────────
|
||||||
|
# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索)
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS = """
|
||||||
|
MATCH (c:Community {community_id: $community_id})
|
||||||
|
MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c)
|
||||||
|
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
|
WHERE s.end_user_id = $end_user_id
|
||||||
|
RETURN s.statement AS statement,
|
||||||
|
s.id AS id,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
s.invalid_at AS invalid_at,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
e.name AS source_entity,
|
||||||
|
c.name AS community_name
|
||||||
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
# 使用新的仓储层
|
# 使用新的仓储层
|
||||||
@@ -158,11 +157,12 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
|
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||||
|
schedule_clustering_after_write() 显式触发。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dialogue_nodes: List of DialogueNode objects to save
|
dialogue_nodes: List of DialogueNode objects to save
|
||||||
chunk_nodes: List of ChunkNode objects to save
|
chunk_nodes: List of ChunkNode objects to save
|
||||||
@@ -293,9 +293,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
logger.info("Transaction completed. Summary: %s", summary)
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
logger.debug("Full transaction results: %r", results)
|
logger.debug("Full transaction results: %r", results)
|
||||||
|
|
||||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
|
||||||
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -309,6 +306,7 @@ def schedule_clustering_after_write(
|
|||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
config_id: Optional[str] = None,
|
config_id: Optional[str] = None,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
写入 Neo4j 成功后,调度后台聚类任务。
|
写入 Neo4j 成功后,调度后台聚类任务。
|
||||||
@@ -327,7 +325,7 @@ def schedule_clustering_after_write(
|
|||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
|
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
@@ -335,6 +333,7 @@ async def _trigger_clustering(
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
config_id: Optional[str] = None,
|
config_id: Optional[str] = None,
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||||
@@ -344,7 +343,7 @@ async def _trigger_clustering(
|
|||||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
|
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import (
|
from app.repositories.neo4j.cypher_queries import (
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
CHUNK_EMBEDDING_SEARCH,
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH,
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
ENTITY_EMBEDDING_SEARCH,
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
SEARCH_CHUNKS_BY_CONTENT,
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
@@ -285,6 +288,15 @@ async def search_graph(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
|
|
||||||
|
if "communities" in include:
|
||||||
|
tasks.append(connector.execute_query(
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
|
q=q,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
))
|
||||||
|
task_keys.append("communities")
|
||||||
|
|
||||||
# Execute all queries in parallel
|
# Execute all queries in parallel
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
|
|||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
|
|
||||||
|
# Communities (向量语义匹配)
|
||||||
|
if "communities" in include:
|
||||||
|
tasks.append(connector.execute_query(
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH,
|
||||||
|
embedding=embedding,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
))
|
||||||
|
task_keys.append("communities")
|
||||||
|
|
||||||
# Execute all queries in parallel
|
# Execute all queries in parallel
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
|
|||||||
"chunks": [],
|
"chunks": [],
|
||||||
"entities": [],
|
"entities": [],
|
||||||
"summaries": [],
|
"summaries": [],
|
||||||
|
"communities": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, result in zip(task_keys, task_results):
|
for key, result in zip(task_keys, task_results):
|
||||||
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
|
|||||||
return {"chunks": chunks}
|
return {"chunks": chunks}
|
||||||
|
|
||||||
|
|
||||||
|
async def search_graph_community_expand(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
community_ids: List[str],
|
||||||
|
end_user_id: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||||
|
|
||||||
|
命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体,
|
||||||
|
再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点,
|
||||||
|
按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: Neo4j 连接器
|
||||||
|
community_ids: 已命中的社区 ID 列表
|
||||||
|
end_user_id: 用户 ID,用于数据隔离
|
||||||
|
limit: 每个社区最多返回的 Statement 数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]}
|
||||||
|
"""
|
||||||
|
if not community_ids or not end_user_id:
|
||||||
|
return {"expanded_statements": []}
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
connector.execute_query(
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
|
community_id=cid,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
for cid in community_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
expanded: List[Dict[str, Any]] = []
|
||||||
|
for cid, result in zip(community_ids, task_results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"社区展开检索失败 community_id={cid}: {result}")
|
||||||
|
else:
|
||||||
|
expanded.extend(result)
|
||||||
|
|
||||||
|
# 按 activation_value 全局排序后去重
|
||||||
|
from app.core.memory.src.search import _deduplicate_results
|
||||||
|
expanded.sort(
|
||||||
|
key=lambda x: float(x.get("activation_value") or 0),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
expanded = _deduplicate_results(expanded)
|
||||||
|
|
||||||
|
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
|
||||||
|
return {"expanded_statements": expanded}
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_by_created_at(
|
async def search_graph_by_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|||||||
Submodule redbear-mem-benchmark updated: c3bbc6931c...89053e48e9
Reference in New Issue
Block a user