[changes] Community Clustering Retrieval Module

This commit is contained in:
lanceyq
2026-03-16 12:30:00 +08:00
parent b1a7b58f97
commit f9fb480cc3
11 changed files with 637 additions and 96 deletions

View File

@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH,
COMMUNITY_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
@@ -285,6 +288,15 @@ async def search_graph(
limit=limit,
))
task_keys.append("summaries")
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -396,6 +408,16 @@ async def search_graph_by_embedding(
))
task_keys.append("summaries")
# Communities (向量语义匹配)
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -408,6 +430,7 @@ async def search_graph_by_embedding(
"chunks": [],
"entities": [],
"summaries": [],
"communities": [],
}
for key, result in zip(task_keys, task_results):
@@ -661,6 +684,62 @@ async def search_graph_by_chunk_id(
return {"chunks": chunks}
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
三期:社区展开检索 —— 主题 → 细节两级检索。
命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体,
再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点,
按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。
Args:
connector: Neo4j 连接器
community_ids: 已命中的社区 ID 列表
end_user_id: 用户 ID用于数据隔离
limit: 每个社区最多返回的 Statement 数量
Returns:
{"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]}
"""
if not community_ids or not end_user_id:
return {"expanded_statements": []}
tasks = [
connector.execute_query(
EXPAND_COMMUNITY_STATEMENTS,
community_id=cid,
end_user_id=end_user_id,
limit=limit,
)
for cid in community_ids
]
task_results = await asyncio.gather(*tasks, return_exceptions=True)
expanded: List[Dict[str, Any]] = []
for cid, result in zip(community_ids, task_results):
if isinstance(result, Exception):
logger.warning(f"社区展开检索失败 community_id={cid}: {result}")
else:
expanded.extend(result)
# 按 activation_value 全局排序后去重
from app.core.memory.src.search import _deduplicate_results
expanded.sort(
key=lambda x: float(x.get("activation_value") or 0),
reverse=True,
)
expanded = _deduplicate_results(expanded)
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
return {"expanded_statements": expanded}
async def search_graph_by_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,