From dacfb360f680ac596eed544524bdd90956e2124a Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Tue, 17 Mar 2026 14:51:04 +0800 Subject: [PATCH] [add] The application layer introduces the clustering community-retrieval module --- .../langgraph_graph/nodes/retrieve_nodes.py | 78 ++++++++++++++++++- .../langgraph_graph/nodes/summary_nodes.py | 15 +++- .../agent/langgraph_graph/tools/tool.py | 5 +- .../memory/agent/services/search_service.py | 41 +++++++--- api/app/core/memory/src/search.py | 63 ++++++++------- api/app/repositories/neo4j/create_indexes.py | 12 +++ api/app/repositories/neo4j/graph_search.py | 8 +- 7 files changed, 173 insertions(+), 49 deletions(-) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index f2cd0d3d..aa421237 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -155,7 +155,7 @@ async def clean_databases(data) -> str: # Process reranked results reranked = results.get('reranked_results', {}) if reranked: - for category in ['summaries', 'statements', 'chunks', 'entities']: + for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']: items = reranked.get(category, []) if isinstance(items, list): content_list.extend(items) @@ -169,11 +169,18 @@ async def clean_databases(data) -> str: elif isinstance(time_search, list): content_list.extend(time_search) - # Extract text content + # Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复) text_parts = [] + seen_community_names = set() for item in content_list: if isinstance(item, dict): - text = item.get('statement') or item.get('content', '') + # community 节点用 name 去重 + if 'member_count' in item or 'core_entities' in item: + community_name = item.get('name') or item.get('id', '') + if community_name in seen_community_names: + continue + seen_community_names.add(community_name) + text = item.get('statement') or item.get('content') or item.get('summary', '') if text: text_parts.append(text) elif isinstance(item, str): @@ -354,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState: ) time_retrieval_tool = create_time_retrieval_tool(end_user_id) - search_params = {"end_user_id": end_user_id, "return_raw_results": True} + search_params = { + "end_user_id": end_user_id, + "return_raw_results": True, + "include": ["summaries", "statements", "chunks", "entities", "communities"], + } hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, @@ -390,8 +401,67 @@ async def retrieve(state: ReadState) -> ReadState: raw_results = tool_results['content'] clean_content = await clean_databases(raw_results) + # 社区展开:从 tool 返回结果中提取命中的 community, + # 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content + _expanded_stmts_to_write = [] + try: + results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {} + reranked = results_dict.get('reranked_results', {}) + community_hits = reranked.get('communities', []) + if not community_hits: + # 兼容非 hybrid 路径 + community_hits = results_dict.get('communities', []) + community_ids = [c.get('id') for c in community_hits if c.get('id')] + if community_ids: + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + expand_connector = Neo4jConnector() + try: + expand_result = await search_graph_community_expand( + connector=expand_connector, + community_ids=community_ids, + end_user_id=end_user_id, + limit=10, + ) + expanded_stmts = expand_result.get('expanded_statements', []) + if expanded_stmts: + # 去重:过滤掉直接检索已经包含的 statement 文本 + existing_lines = set(clean_content.splitlines()) + expanded_texts = [ + s['statement'] for s in expanded_stmts + if s.get('statement') and s['statement'] not in existing_lines + ] + if expanded_texts: + clean_content = clean_content + '\n' + '\n'.join(expanded_texts) + # 暂存展开结果,稍后写回 raw_results['results'] + _expanded_stmts_to_write = expanded_stmts + logger.info( + f"[Retrieve] 社区展开追加 {len(expanded_stmts)} 条 statements," + f"community_ids={community_ids}" + ) + except Exception as expand_err: + logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}") + finally: + await expand_connector.close() + except Exception as parse_err: + logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}") + try: raw_results = raw_results['results'] + # 写回展开结果,过一遍字段清洗去掉 DateTime 等不可序列化字段 + if _expanded_stmts_to_write and isinstance(raw_results, dict): + _fields_to_remove = { + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary' + } + def _clean(obj): + if isinstance(obj, dict): + return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove} + if isinstance(obj, list): + return [_clean(i) for i in obj] + return obj + raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write) except Exception: raw_results = [] diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 030acc9a..663081aa 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -334,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState: "end_user_id": end_user_id, "question": data, "return_raw_results": True, - "include": ["summaries"] # Only search summary nodes for faster performance + "include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点 } try: if storage_type != "rag": - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, - memory_config=memory_config) + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search( + **search_params, + memory_config=memory_config, + expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement + ) + # 调试:打印 community 检索结果数量 + if raw_results and isinstance(raw_results, dict): + reranked = raw_results.get('reranked_results', {}) + community_hits = reranked.get('communities', []) + logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, " + f"summary 命中数: {len(reranked.get('summaries', []))}") else: retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index 9bd2b2cf..ae2c5772 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -252,9 +252,10 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): # TODO: fact_summary functionality temporarily disabled, will be enabled after future development fields_to_remove = { 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', - 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id', + 'expired_at', 'created_at', 'chunk_id', 'apply_id', 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary" } + # 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements if isinstance(data, dict): # Clean dictionary @@ -310,7 +311,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): "search_type": search_type, "end_user_id": end_user_id or search_params.get("end_user_id"), "limit": limit or search_params.get("limit", 10), - "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), + "include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]), "output_path": None, # Don't save to file "memory_config": memory_config, "rerank_alpha": rerank_alpha, diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index c9346c16..e89951cc 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -21,7 +21,7 @@ class SearchService: """Initialize the search service.""" logger.info("SearchService initialized") - def extract_content_from_result(self, result: dict) -> str: + def extract_content_from_result(self, result: dict, node_type: str = "") -> str: """ Extract only meaningful content from search results, dropping all metadata. @@ -30,9 +30,11 @@ class SearchService: - Entities: extract 'name' and 'fact_summary' fields - Summaries: extract 'content' field - Chunks: extract 'content' field + - Communities: extract 'content' field (c.summary), prefixed with community name Args: result: Search result dictionary + node_type: Hint for node type ("community", "summary", etc.) Returns: Clean content string without metadata @@ -46,8 +48,21 @@ class SearchService: if 'statement' in result and result['statement']: content_parts.append(result['statement']) - # Summaries/Chunks: extract content field - if 'content' in result and result['content']: + # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 + # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 + is_community = ( + node_type == "community" + or 'member_count' in result + or 'core_entities' in result + ) + if is_community: + name = result.get('name', '') + content = result.get('content', '') + if content: + prefix = f"[主题:{name}] " if name else "" + content_parts.append(f"{prefix}{content}") + elif 'content' in result and result['content']: + # Summaries / Chunks content_parts.append(result['content']) # Entities: extract name and fact_summary (commented out in original) @@ -99,7 +114,8 @@ class SearchService: rerank_alpha: float = 0.4, output_path: str = "search_results.json", return_raw_results: bool = False, - memory_config = None + memory_config = None, + expand_communities: bool = True, ) -> Tuple[str, str, Optional[dict]]: """ Execute hybrid search and return clean content. @@ -114,6 +130,8 @@ class SearchService: output_path: Path to save search results (default: "search_results.json") return_raw_results: If True, also return the raw search results as third element (default: False) memory_config: Memory configuration object (required) + expand_communities: If True, expand community hits to member statements (default: True). + Set to False for quick-summary paths that only need community-level text. Returns: Tuple of (clean_content, cleaned_query, raw_results) @@ -165,8 +183,8 @@ class SearchService: if isinstance(category_results, list): answer_list.extend(category_results) - # 对命中的 community 节点展开其成员 statements - if "communities" in include: + # 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要) + if expand_communities and "communities" in include: community_results = ( answer.get('reranked_results', {}).get('communities', []) if search_type == "hybrid" @@ -195,11 +213,12 @@ class SearchService: finally: await expand_connector.close() - # Extract clean content from all results - content_list = [ - self.extract_content_from_result(ans) - for ans in answer_list - ] + # Extract clean content from all results,按类型传入 node_type 区分 community + content_list = [] + for ans in answer_list: + # community 节点有 member_count 或 core_entities 字段 + ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" + content_list.append(self.extract_content_from_result(ans, node_type=ntype)) # Filter out empty strings and join with newlines diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 3570d707..eb45e4d1 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -748,35 +748,42 @@ async def run_hybrid_search( # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig config_load_start = time.time() - with get_db_context() as db: - config_service = MemoryConfigService(db) - embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) - rb_config = RedBearModelConfig( - model_name=embedder_config_dict["model_name"], - provider=embedder_config_dict["provider"], - api_key=embedder_config_dict["api_key"], - base_url=embedder_config_dict["base_url"], - type="llm" - ) - config_load_time = time.time() - config_load_start - logger.info(f"[PERF] Config loading took {config_load_time:.4f}s") - - # Init embedder - embedder_init_start = time.time() - embedder = OpenAIEmbedderClient(model_config=rb_config) - embedder_init_time = time.time() - embedder_init_start - logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") - - embedding_task = asyncio.create_task( - search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=query_text, - end_user_id=end_user_id, - limit=limit, - include=include, + try: + with get_db_context() as db: + config_service = MemoryConfigService(db) + embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) + rb_config = RedBearModelConfig( + model_name=embedder_config_dict["model_name"], + provider=embedder_config_dict["provider"], + api_key=embedder_config_dict["api_key"], + base_url=embedder_config_dict["base_url"], + type="llm" ) - ) + config_load_time = time.time() - config_load_start + logger.info(f"[PERF] Config loading took {config_load_time:.4f}s") + + # Init embedder + embedder_init_start = time.time() + embedder = OpenAIEmbedderClient(model_config=rb_config) + embedder_init_time = time.time() - embedder_init_start + logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") + + embedding_task = asyncio.create_task( + search_graph_by_embedding( + connector=connector, + embedder_client=embedder, + query_text=query_text, + end_user_id=end_user_id, + limit=limit, + include=include, + ) + ) + except Exception as emb_init_err: + logger.warning( + f"[PERF] Embedding search skipped due to init error " + f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}" + ) + embedding_task = None if keyword_task: keyword_results = await keyword_task diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 29f60fdd..d9e94117 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -119,6 +119,18 @@ async def create_vector_indexes(): }} """) print("✓ Created: summary_embedding_index") + + # Community summary embedding index + await connector.execute_query(""" + CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS + FOR (c:Community) + ON c.summary_embedding + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """) + print("✓ Created: community_summary_embedding_index") # Dialogue embedding index (optional) await connector.execute_query(""" diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 19e40a82..d3aabd32 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -305,6 +305,7 @@ async def search_graph( results = {} for key, result in zip(task_keys, task_results): if isinstance(result, Exception): + logger.warning(f"search_graph: {key} 关键词查询异常: {result}") results[key] = [] else: results[key] = result @@ -361,7 +362,11 @@ async def search_graph_by_embedding( print(f"[PERF] Embedding generation took: {embed_time:.4f}s") if not embeddings or not embeddings[0]: - return {"statements": [], "chunks": [], "entities": [], "summaries": []} + logger.warning( + f"search_graph_by_embedding: embedding 生成失败或为空," + f"query='{query_text[:50]}', end_user_id={end_user_id},向量检索跳过" + ) + return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []} embedding = embeddings[0] # Prepare tasks for parallel execution @@ -435,6 +440,7 @@ async def search_graph_by_embedding( for key, result in zip(task_keys, task_results): if isinstance(result, Exception): + logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}") results[key] = [] else: results[key] = result