[add] Application layer: introduce the community clustering retrieval module

lanceyq
2026-03-17 14:51:04 +08:00
parent 5df339b56d
commit dacfb360f6
7 changed files with 173 additions and 49 deletions

View File

@@ -155,7 +155,7 @@ async def clean_databases(data) -> str:
# Process reranked results
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']:
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
@@ -169,11 +169,18 @@ async def clean_databases(data) -> str:
elif isinstance(time_search, list):
content_list.extend(time_search)
# Extract text content
# Extract text content; dedupe communities by name (repeated tool calls produce duplicates)
text_parts = []
seen_community_names = set()
for item in content_list:
if isinstance(item, dict):
text = item.get('statement') or item.get('content', '')
# Dedupe community nodes by name
if 'member_count' in item or 'core_entities' in item:
community_name = item.get('name') or item.get('id', '')
if community_name in seen_community_names:
continue
seen_community_names.add(community_name)
text = item.get('statement') or item.get('content') or item.get('summary', '')
if text:
text_parts.append(text)
elif isinstance(item, str):
@@ -354,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
)
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
search_params = {
"end_user_id": end_user_id,
"return_raw_results": True,
"include": ["summaries", "statements", "chunks", "entities", "communities"],
}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
@@ -390,8 +401,67 @@ async def retrieve(state: ReadState) -> ReadState:
raw_results = tool_results['content']
clean_content = await clean_databases(raw_results)
# Community expansion: extract hit communities from the tool results,
# then pull related Statements along BELONGS_TO_COMMUNITY and append them to clean_content
_expanded_stmts_to_write = []
try:
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
reranked = results_dict.get('reranked_results', {})
community_hits = reranked.get('communities', [])
if not community_hits:
# Fall back for the non-hybrid result path
community_hits = results_dict.get('communities', [])
community_ids = [c.get('id') for c in community_hits if c.get('id')]
if community_ids:
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
expand_connector = Neo4jConnector()
try:
expand_result = await search_graph_community_expand(
connector=expand_connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=10,
)
expanded_stmts = expand_result.get('expanded_statements', [])
if expanded_stmts:
# Dedupe: drop statement texts already covered by the direct retrieval
existing_lines = set(clean_content.splitlines())
expanded_texts = [
s['statement'] for s in expanded_stmts
if s.get('statement') and s['statement'] not in existing_lines
]
if expanded_texts:
clean_content = clean_content + '\n' + '\n'.join(expanded_texts)
# Stash the expansion result; it is written back into raw_results['results'] below
_expanded_stmts_to_write = expanded_stmts
logger.info(
f"[Retrieve] Community expansion appended {len(expanded_stmts)} statements, "
f"community_ids={community_ids}"
)
except Exception as expand_err:
logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}")
finally:
await expand_connector.close()
except Exception as parse_err:
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
try:
raw_results = raw_results['results']
# Write the expansion result back, stripping DateTime and other non-serializable fields
if _expanded_stmts_to_write and isinstance(raw_results, dict):
_fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
}
def _clean(obj):
if isinstance(obj, dict):
return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove}
if isinstance(obj, list):
return [_clean(i) for i in obj]
return obj
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write)
except Exception:
raw_results = []
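
For reference, a minimal sketch of what search_graph_community_expand presumably does (the real implementation lives in app/repositories/neo4j/graph_search.py; the Cypher, the BELONGS_TO_COMMUNITY direction, the property names, and the execute_query parameter style here are assumptions, not the shipped code):

async def search_graph_community_expand(connector, community_ids, end_user_id, limit=10):
    # Assumed traversal: (Statement)-[:BELONGS_TO_COMMUNITY]->(Community),
    # scoped to the current end user
    query = """
        MATCH (s:Statement)-[:BELONGS_TO_COMMUNITY]->(c:Community)
        WHERE c.id IN $community_ids AND s.user_id = $end_user_id
        RETURN s.statement AS statement, c.id AS community_id
        LIMIT $limit
    """
    records = await connector.execute_query(
        query, community_ids=community_ids, end_user_id=end_user_id, limit=limit
    )
    # Shape matches what retrieve() reads back via expand_result.get('expanded_statements', [])
    return {"expanded_statements": [dict(r) for r in records]}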

View File

@@ -334,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState:
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
}
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False, # Path "2" only needs the community summary text, no expansion to Statements
)
# Debug: log how many community results were retrieved
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
community_hits = reranked.get('communities', [])
logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:

View File

@@ -252,9 +252,10 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
}
# Note: the 'id' field is kept; community expansion needs the community id to look up member statements
if isinstance(data, dict):
# Clean dictionary
@@ -310,7 +311,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
"output_path": None, # Don't save to file
"memory_config": memory_config,
"rerank_alpha": rerank_alpha,

View File

@@ -21,7 +21,7 @@ class SearchService:
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
@@ -30,9 +30,11 @@ class SearchService:
- Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field
- Chunks: extract 'content' field
- Communities: extract 'content' field (c.summary), prefixed with community name
Args:
result: Search result dictionary
node_type: Hint for node type ("community", "summary", etc.)
Returns:
Clean content string without metadata
@@ -46,8 +48,21 @@ class SearchService:
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Summaries/Chunks: extract content field
if 'content' in result and result['content']:
# Community nodes: carry member_count or core_entities fields, or node_type marks them explicitly
# Prefix with "[主题:{name}]" so the LLM can tell this is a topic-level summary
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
content = result.get('content', '')
if content:
prefix = f"[主题:{name}] " if name else ""
content_parts.append(f"{prefix}{content}")
elif 'content' in result and result['content']:
# Summaries / Chunks
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
@@ -99,7 +114,8 @@ class SearchService:
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config = None
memory_config = None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -114,6 +130,8 @@ class SearchService:
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: Memory configuration object (required)
expand_communities: If True, expand community hits to member statements (default: True).
Set to False for quick-summary paths that only need community-level text.
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
@@ -165,8 +183,8 @@ class SearchService:
if isinstance(category_results, list):
answer_list.extend(category_results)
# Expand hit community nodes into their member statements
if "communities" in include:
# Expand hit community nodes into their member statements (needed for paths "0"/"1", not for path "2")
if expand_communities and "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
@@ -195,11 +213,12 @@ class SearchService:
finally:
await expand_connector.close()
# Extract clean content from all results
content_list = [
self.extract_content_from_result(ans)
for ans in answer_list
]
# Extract clean content from all results, passing node_type to distinguish communities
content_list = []
for ans in answer_list:
# Community nodes carry member_count or core_entities fields
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines
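
A quick usage sketch of the new node_type handling; the sample dicts below are hypothetical and carry only the fields extract_content_from_result inspects:

svc = SearchService()
community_hit = {
    "name": "Travel preferences",
    "content": "Prefers independent island trips; often books diving tours.",
    "member_count": 12,
    "core_entities": ["Okinawa", "diving"],
}
summary_hit = {"content": "The user is planning an Okinawa itinerary."}

print(svc.extract_content_from_result(community_hit, node_type="community"))
# -> [主题:Travel preferences] Prefers independent island trips; often books diving tours.
print(svc.extract_content_from_result(summary_hit))
# -> The user is planning an Okinawa itinerary.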

View File

@@ -748,6 +748,7 @@ async def run_hybrid_search(
# Load the embedder config from the database (by ID) and build a RedBearModelConfig
config_load_start = time.time()
try:
with get_db_context() as db:
config_service = MemoryConfigService(db)
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
@@ -777,6 +778,12 @@ async def run_hybrid_search(
include=include,
)
)
except Exception as emb_init_err:
logger.warning(
f"[PERF] Embedding search skipped due to init error "
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
)
embedding_task = None
if keyword_task:
keyword_results = await keyword_task
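
With embedding_task left as None, the downstream merge presumably degrades to keyword-only results; a sketch of that guard (variable names follow this hunk, and the empty-category shape mirrors the fallback returned by search_graph_by_embedding):

if embedding_task:
    embedding_results = await embedding_task
else:
    # Embedder init failed: fall back to keyword-only retrieval
    embedding_results = {
        "statements": [], "chunks": [], "entities": [],
        "summaries": [], "communities": [],
    }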

View File

@@ -120,6 +120,18 @@ async def create_vector_indexes():
""")
print("✓ Created: summary_embedding_index")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
FOR (c:Community)
ON c.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: community_summary_embedding_index")
# Dialogue embedding index (optional)
await connector.execute_query("""
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS

View File

@@ -305,6 +305,7 @@ async def search_graph(
results = {}
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
logger.warning(f"search_graph: {key} 关键词查询异常: {result}")
results[key] = []
else:
results[key] = result
@@ -361,7 +362,11 @@ async def search_graph_by_embedding(
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
logger.warning(
f"search_graph_by_embedding: embedding generation failed or returned empty, "
f"query='{query_text[:50]}', end_user_id={end_user_id}; skipping vector search"
)
return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []}
embedding = embeddings[0]
# Prepare tasks for parallel execution
@@ -435,6 +440,7 @@ async def search_graph_by_embedding(
for key, result in zip(task_keys, task_results):
if isinstance(result, Exception):
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
results[key] = []
else:
results[key] = result
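
The isinstance(result, Exception) branches above match the asyncio.gather(..., return_exceptions=True) pattern; a self-contained sketch of how the per-category tasks are presumably dispatched (the dict-of-coroutines layout is an assumption based on this hunk):

import asyncio

async def run_parallel_queries(tasks: dict) -> dict:
    # tasks maps category key -> coroutine, e.g. {"statements": ..., "communities": ...}
    task_keys = list(tasks.keys())
    task_results = await asyncio.gather(*tasks.values(), return_exceptions=True)
    results = {}
    for key, result in zip(task_keys, task_results):
        if isinstance(result, Exception):
            # One failed category degrades to an empty list instead of failing the whole search
            results[key] = []
        else:
            results[key] = result
    return results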