[add] The application layer introduces the clustering community-retrieval module
This commit is contained in:
@@ -155,7 +155,7 @@ async def clean_databases(data) -> str:
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
@@ -169,11 +169,18 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# Extract text content
|
||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||
text_parts = []
|
||||
seen_community_names = set()
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
# community 节点用 name 去重
|
||||
if 'member_count' in item or 'core_entities' in item:
|
||||
community_name = item.get('name') or item.get('id', '')
|
||||
if community_name in seen_community_names:
|
||||
continue
|
||||
seen_community_names.add(community_name)
|
||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
@@ -354,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||
}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
@@ -390,8 +401,67 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||
_expanded_stmts_to_write = []
|
||||
try:
|
||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||
reranked = results_dict.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
if not community_hits:
|
||||
# 兼容非 hybrid 路径
|
||||
community_hits = results_dict.get('communities', [])
|
||||
community_ids = [c.get('id') for c in community_hits if c.get('id')]
|
||||
if community_ids:
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
expand_connector = Neo4jConnector()
|
||||
try:
|
||||
expand_result = await search_graph_community_expand(
|
||||
connector=expand_connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=10,
|
||||
)
|
||||
expanded_stmts = expand_result.get('expanded_statements', [])
|
||||
if expanded_stmts:
|
||||
# 去重:过滤掉直接检索已经包含的 statement 文本
|
||||
existing_lines = set(clean_content.splitlines())
|
||||
expanded_texts = [
|
||||
s['statement'] for s in expanded_stmts
|
||||
if s.get('statement') and s['statement'] not in existing_lines
|
||||
]
|
||||
if expanded_texts:
|
||||
clean_content = clean_content + '\n' + '\n'.join(expanded_texts)
|
||||
# 暂存展开结果,稍后写回 raw_results['results']
|
||||
_expanded_stmts_to_write = expanded_stmts
|
||||
logger.info(
|
||||
f"[Retrieve] 社区展开追加 {len(expanded_stmts)} 条 statements,"
|
||||
f"community_ids={community_ids}"
|
||||
)
|
||||
except Exception as expand_err:
|
||||
logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}")
|
||||
finally:
|
||||
await expand_connector.close()
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
# 写回展开结果,过一遍字段清洗去掉 DateTime 等不可序列化字段
|
||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||
_fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
def _clean(obj):
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove}
|
||||
if isinstance(obj, list):
|
||||
return [_clean(i) for i in obj]
|
||||
return obj
|
||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write)
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
|
||||
@@ -334,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
||||
}
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
)
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
|
||||
@@ -252,9 +252,10 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Clean dictionary
|
||||
@@ -310,7 +311,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
|
||||
@@ -21,7 +21,7 @@ class SearchService:
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
@@ -30,9 +30,11 @@ class SearchService:
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
node_type: Hint for node type ("community", "summary", etc.)
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
@@ -46,8 +48,21 @@ class SearchService:
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
prefix = f"[主题:{name}] " if name else ""
|
||||
content_parts.append(f"{prefix}{content}")
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
@@ -99,7 +114,8 @@ class SearchService:
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -114,6 +130,8 @@ class SearchService:
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
expand_communities: If True, expand community hits to member statements (default: True).
|
||||
Set to False for quick-summary paths that only need community-level text.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
@@ -165,8 +183,8 @@ class SearchService:
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements
|
||||
if "communities" in include:
|
||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||
if expand_communities and "communities" in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get('communities', [])
|
||||
if search_type == "hybrid"
|
||||
@@ -195,11 +213,12 @@ class SearchService:
|
||||
finally:
|
||||
await expand_connector.close()
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
# community 节点有 member_count 或 core_entities 字段
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
|
||||
@@ -748,6 +748,7 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
@@ -777,6 +778,12 @@ async def run_hybrid_search(
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
except Exception as emb_init_err:
|
||||
logger.warning(
|
||||
f"[PERF] Embedding search skipped due to init error "
|
||||
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
|
||||
)
|
||||
embedding_task = None
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
|
||||
@@ -120,6 +120,18 @@ async def create_vector_indexes():
|
||||
""")
|
||||
print("✓ Created: summary_embedding_index")
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
FOR (c:Community)
|
||||
ON c.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
||||
|
||||
@@ -305,6 +305,7 @@ async def search_graph(
|
||||
results = {}
|
||||
for key, result in zip(task_keys, task_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"search_graph: {key} 关键词查询异常: {result}")
|
||||
results[key] = []
|
||||
else:
|
||||
results[key] = result
|
||||
@@ -361,7 +362,11 @@ async def search_graph_by_embedding(
|
||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||
|
||||
if not embeddings or not embeddings[0]:
|
||||
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
||||
logger.warning(
|
||||
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
||||
f"query='{query_text[:50]}', end_user_id={end_user_id},向量检索跳过"
|
||||
)
|
||||
return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []}
|
||||
embedding = embeddings[0]
|
||||
|
||||
# Prepare tasks for parallel execution
|
||||
@@ -435,6 +440,7 @@ async def search_graph_by_embedding(
|
||||
|
||||
for key, result in zip(task_keys, task_results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
|
||||
results[key] = []
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
Reference in New Issue
Block a user