diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index aa421237..68260f26 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -409,59 +409,24 @@ async def retrieve(state: ReadState) -> ReadState: reranked = results_dict.get('reranked_results', {}) community_hits = reranked.get('communities', []) if not community_hits: - # 兼容非 hybrid 路径 community_hits = results_dict.get('communities', []) - community_ids = [c.get('id') for c in community_hits if c.get('id')] - if community_ids: - from app.repositories.neo4j.graph_search import search_graph_community_expand - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - expand_connector = Neo4jConnector() - try: - expand_result = await search_graph_community_expand( - connector=expand_connector, - community_ids=community_ids, - end_user_id=end_user_id, - limit=10, - ) - expanded_stmts = expand_result.get('expanded_statements', []) - if expanded_stmts: - # 去重:过滤掉直接检索已经包含的 statement 文本 - existing_lines = set(clean_content.splitlines()) - expanded_texts = [ - s['statement'] for s in expanded_stmts - if s.get('statement') and s['statement'] not in existing_lines - ] - if expanded_texts: - clean_content = clean_content + '\n' + '\n'.join(expanded_texts) - # 暂存展开结果,稍后写回 raw_results['results'] - _expanded_stmts_to_write = expanded_stmts - logger.info( - f"[Retrieve] 社区展开追加 {len(expanded_stmts)} 条 statements," - f"community_ids={community_ids}" - ) - except Exception as expand_err: - logger.warning(f"[Retrieve] 社区展开检索失败,跳过: {expand_err}") - finally: - await expand_connector.close() + if community_hits: + from app.core.memory.agent.services.search_service import expand_communities_to_statements + _expanded_stmts_to_write, new_texts = await expand_communities_to_statements( + community_results=community_hits, + end_user_id=end_user_id, + existing_content=clean_content, + ) + if new_texts: + clean_content = clean_content + '\n' + '\n'.join(new_texts) except Exception as parse_err: logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}") try: raw_results = raw_results['results'] - # 写回展开结果,过一遍字段清洗去掉 DateTime 等不可序列化字段 + # 写回展开结果,接口返回中可见(已在 helper 中清洗过字段) if _expanded_stmts_to_write and isinstance(raw_results, dict): - _fields_to_remove = { - 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', - 'expired_at', 'created_at', 'chunk_id', 'apply_id', - 'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary' - } - def _clean(obj): - if isinstance(obj, dict): - return {k: _clean(v) for k, v in obj.items() if k not in _fields_to_remove} - if isinstance(obj, list): - return [_clean(i) for i in obj] - return obj - raw_results.setdefault('reranked_results', {})['expanded_statements'] = _clean(_expanded_stmts_to_write) + raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write except Exception: raw_results = [] diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 663081aa..d967a285 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -348,8 +348,8 @@ async def Input_Summary(state: ReadState) -> ReadState: if raw_results and isinstance(raw_results, dict): reranked = raw_results.get('reranked_results', {}) community_hits = reranked.get('communities', []) - logger.info(f"[Input_Summary] community 命中数: {len(community_hits)}, " - f"summary 命中数: {len(reranked.get('summaries', []))}") + logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, " + f"summary 命中数: {len(reranked.get('summaries', []))}") else: retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index e89951cc..90b1c088 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query logger = get_agent_logger(__name__) +# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化) +_EXPAND_FIELDS_TO_REMOVE = { + 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', + 'expired_at', 'created_at', 'chunk_id', 'apply_id', + 'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary' +} + + +def _clean_expand_fields(obj): + """递归过滤展开结果中不可序列化的字段(DateTime 等)。""" + if isinstance(obj, dict): + return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE} + if isinstance(obj, list): + return [_clean_expand_fields(i) for i in obj] + return obj + + +async def expand_communities_to_statements( + community_results: List[dict], + end_user_id: str, + existing_content: str = "", + limit: int = 10, +) -> Tuple[List[dict], List[str]]: + """ + 社区展开 helper:给定命中的 community 列表,拉取关联 Statement。 + + - 对展开结果去重(过滤已在 existing_content 中出现的文本) + - 过滤不可序列化字段 + - 返回 (cleaned_expanded_stmts, new_texts) + - cleaned_expanded_stmts: 可直接写回 raw_results 的列表 + - new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content + """ + community_ids = [r.get("id") for r in community_results if r.get("id")] + if not community_ids or not end_user_id: + return [], [] + + from app.repositories.neo4j.graph_search import search_graph_community_expand + from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + connector = Neo4jConnector() + try: + result = await search_graph_community_expand( + connector=connector, + community_ids=community_ids, + end_user_id=end_user_id, + limit=limit, + ) + except Exception as e: + logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}") + return [], [] + finally: + await connector.close() + + expanded_stmts = result.get("expanded_statements", []) + if not expanded_stmts: + return [], [] + + existing_lines = set(existing_content.splitlines()) + new_texts = [ + s["statement"] for s in expanded_stmts + if s.get("statement") and s["statement"] not in existing_lines + ] + cleaned = _clean_expand_fields(expanded_stmts) + logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}") + return cleaned, new_texts + class SearchService: """Service for executing hybrid search and processing results.""" @@ -190,28 +256,11 @@ class SearchService: if search_type == "hybrid" else answer.get('communities', []) ) - community_ids = [ - r.get("id") for r in community_results if r.get("id") - ] - if community_ids and end_user_id: - from app.repositories.neo4j.graph_search import search_graph_community_expand - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - expand_connector = Neo4jConnector() - try: - expand_result = await search_graph_community_expand( - connector=expand_connector, - community_ids=community_ids, - end_user_id=end_user_id, - limit=10, - ) - expanded_stmts = expand_result.get("expanded_statements", []) - if expanded_stmts: - answer_list.extend(expanded_stmts) - logger.info(f"社区展开检索追加 {len(expanded_stmts)} 条 statements") - except Exception as e: - logger.warning(f"社区展开检索失败,跳过: {e}") - finally: - await expand_connector.close() + cleaned_stmts, new_texts = await expand_communities_to_statements( + community_results=community_results, + end_user_id=end_user_id, + ) + answer_list.extend(cleaned_stmts) # Extract clean content from all results,按类型传入 node_type 区分 community content_list = [] diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index eb45e4d1..1cab1f3f 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -727,6 +727,11 @@ async def run_hybrid_search( keyword_task = None embedding_task = None + keyword_task = None + embedding_task = None + keyword_results: Dict[str, List] = {} + embedding_results: Dict[str, List] = {} + if search_type in ["keyword", "hybrid"]: # Keyword-based search logger.info("[PERF] Starting keyword search...")