From 9cbe9d5edc12c622d2d5a576f00d4873bcd1f7ce Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 1 Apr 2026 17:19:03 +0800 Subject: [PATCH] feat(memory): add perceptual memory retrieval service with BM25+embedding fusion --- .../nodes/perceptual_retrieve_node.py | 408 +++++++++++++++++ .../langgraph_graph/nodes/problem_nodes.py | 1 - .../langgraph_graph/nodes/summary_nodes.py | 76 +++- .../agent/langgraph_graph/read_graph.py | 15 +- .../memory/agent/services/search_service.py | 91 ++-- api/app/core/memory/agent/utils/llm_tools.py | 2 +- api/app/core/memory/src/search.py | 283 ++++++------ api/app/main.py | 1 + api/app/repositories/neo4j/create_indexes.py | 70 +-- api/app/repositories/neo4j/cypher_queries.py | 41 ++ api/app/repositories/neo4j/graph_search.py | 414 ++++++++++-------- api/app/repositories/neo4j/neo4j_connector.py | 26 +- api/app/services/memory_agent_service.py | 23 +- 13 files changed, 1042 insertions(+), 409 deletions(-) create mode 100644 api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py new file mode 100644 index 00000000..f248afa5 --- /dev/null +++ b/api/app/core/memory/agent/langgraph_graph/nodes/perceptual_retrieve_node.py @@ -0,0 +1,408 @@ +""" +Perceptual Memory Retrieval Node & Service + +Provides PerceptualSearchService for searching perceptual memories (vision, audio, +text, conversation) from Neo4j using keyword fulltext + embedding semantic search +with BM25+embedding fusion reranking. + +Also provides the perceptual_retrieve_node for use as a LangGraph node. +""" +import asyncio +import math +from typing import List, Dict, Any, Optional + +from app.core.logging_config import get_agent_logger +from app.core.memory.agent.utils.llm_tools import ReadState +from app.core.memory.utils.data.text_utils import escape_lucene_query +from app.repositories.neo4j.graph_search import ( + search_perceptual, + search_perceptual_by_embedding, +) +from app.repositories.neo4j.neo4j_connector import Neo4jConnector + +logger = get_agent_logger(__name__) + + +class PerceptualSearchService: + """ + 感知记忆检索服务。 + + 封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。 + 调用方只需提供 query / keywords、end_user_id、memory_config,即可获得 + 格式化并排序后的感知记忆列表和拼接文本。 + + Usage: + service = PerceptualSearchService(end_user_id=..., memory_config=...) + results = await service.search(query="...", keywords=[...], limit=10) + # results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M} + """ + + DEFAULT_ALPHA = 0.6 + DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5 + + def __init__( + self, + end_user_id: str, + memory_config: Any, + alpha: float = DEFAULT_ALPHA, + content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD, + ): + self.end_user_id = end_user_id + self.memory_config = memory_config + self.alpha = alpha + self.content_score_threshold = content_score_threshold + + async def search( + self, + query: str, + keywords: Optional[List[str]] = None, + limit: int = 10, + ) -> Dict[str, Any]: + """ + 执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。 + + 对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数, + 确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。 + + Args: + query: 原始用户查询(用于向量检索和 BM25 补查) + keywords: 关键词列表(用于全文检索),为 None 时使用 [query] + limit: 最大返回数量 + + Returns: + { + "memories": [格式化后的记忆 dict, ...], + "content": "拼接的纯文本摘要", + "keyword_raw": int, + "embedding_raw": int, + } + """ + if keywords is None: + keywords = [query] if query else [] + + connector = Neo4jConnector() + try: + kw_task = self._keyword_search(connector, keywords, limit) + emb_task = self._embedding_search(connector, query, limit) + + kw_results, emb_results = await asyncio.gather( + kw_task, emb_task, return_exceptions=True + ) + if isinstance(kw_results, Exception): + logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}") + kw_results = [] + if isinstance(emb_results, Exception): + logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}") + emb_results = [] + + # 补查 BM25:找出 embedding 命中但 keyword 未命中的 id, + # 用原始 query 对这些节点补查全文索引拿 BM25 score + kw_ids = {r.get("id") for r in kw_results if r.get("id")} + emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids} + + if emb_only_ids and query: + backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit) + # 把补查到的 BM25 score 注入到 embedding 结果中 + backfill_map = {r["id"]: r.get("score", 0) for r in backfill} + for r in emb_results: + rid = r.get("id", "") + if rid in backfill_map: + r["bm25_backfill_score"] = backfill_map[rid] + logger.info( + f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, " + f"{len(backfill_map)} got BM25 scores" + ) + + reranked = self._rerank(kw_results, emb_results, limit) + + memories = [] + content_parts = [] + for record in reranked: + fmt = self._format_result(record) + fmt["score"] = round(record.get("content_score", 0), 4) + memories.append(fmt) + content_parts.append(self._build_content_text(fmt)) + + logger.info( + f"[PerceptualSearch] {len(memories)} results after rerank " + f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})" + ) + return { + "memories": memories, + "content": "\n\n".join(content_parts), + "keyword_raw": len(kw_results), + "embedding_raw": len(emb_results), + } + finally: + await connector.close() + + async def _bm25_backfill( + self, + connector: Neo4jConnector, + query: str, + target_ids: set, + limit: int, + ) -> List[dict]: + """ + 对指定 id 集合补查全文索引 BM25 score。 + + 用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。 + """ + escaped = escape_lucene_query(query) + if not escaped.strip(): + return [] + try: + r = await search_perceptual( + connector=connector, q=escaped, + end_user_id=self.end_user_id, + limit=limit * 5, # 多查一些以提高命中率 + ) + all_hits = r.get("perceptuals", []) + return [h for h in all_hits if h.get("id") in target_ids] + except Exception as e: + logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}") + return [] + + async def _keyword_search( + self, + connector: Neo4jConnector, + keywords: List[str], + limit: int, + ) -> List[dict]: + """并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。""" + seen_ids: set = set() + all_results: List[dict] = [] + + async def _one(kw: str): + escaped = escape_lucene_query(kw) + if not escaped.strip(): + return [] + r = await search_perceptual( + connector=connector, q=escaped, + end_user_id=self.end_user_id, limit=limit, + ) + return r.get("perceptuals", []) + + tasks = [_one(kw) for kw in keywords[:10]] + batch = await asyncio.gather(*tasks, return_exceptions=True) + + for result in batch: + if isinstance(result, Exception): + logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}") + continue + for rec in result: + rid = rec.get("id", "") + if rid and rid not in seen_ids: + seen_ids.add(rid) + all_results.append(rec) + + all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True) + return all_results[:limit] + + async def _embedding_search( + self, + connector: Neo4jConnector, + query_text: str, + limit: int, + ) -> List[dict]: + """向量语义检索,返回原始结果(不做阈值过滤)。""" + try: + from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient + from app.core.models.base import RedBearModelConfig + from app.db import get_db_context + from app.services.memory_config_service import MemoryConfigService + + with get_db_context() as db: + cfg = MemoryConfigService(db).get_embedder_config( + str(self.memory_config.embedding_model_id) + ) + client = OpenAIEmbedderClient(RedBearModelConfig(**cfg)) + + r = await search_perceptual_by_embedding( + connector=connector, embedder_client=client, + query_text=query_text, end_user_id=self.end_user_id, + limit=limit, + ) + return r.get("perceptuals", []) + except Exception as e: + logger.warning(f"[PerceptualSearch] embedding search failed: {e}") + return [] + + def _rerank( + self, + keyword_results: List[dict], + embedding_results: List[dict], + limit: int, + ) -> List[dict]: + """BM25 + embedding 融合排序。 + + 对 embedding 结果中带有 bm25_backfill_score 的条目, + 将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。 + """ + # 把补查的 BM25 score 合并到 keyword_results 中统一归一化 + emb_backfill_items = [] + for item in embedding_results: + backfill_score = item.get("bm25_backfill_score") + if backfill_score is not None and item.get("id"): + emb_backfill_items.append({"id": item["id"], "score": backfill_score}) + + # 合并后统一归一化 BM25 scores + all_bm25_items = keyword_results + emb_backfill_items + all_bm25_items = self._normalize_scores(all_bm25_items) + + # 建立 id -> normalized BM25 score 的映射 + bm25_norm_map: Dict[str, float] = {} + for item in all_bm25_items: + item_id = item.get("id", "") + if item_id: + bm25_norm_map[item_id] = float(item.get("normalized_score", 0)) + + # 归一化 embedding scores + embedding_results = self._normalize_scores(embedding_results) + + # 合并 + combined: Dict[str, dict] = {} + for item in keyword_results: + item_id = item.get("id", "") + if not item_id: + continue + combined[item_id] = item.copy() + combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0) + combined[item_id]["embedding_score"] = 0.0 + + for item in embedding_results: + item_id = item.get("id", "") + if not item_id: + continue + if item_id in combined: + combined[item_id]["embedding_score"] = item.get("normalized_score", 0) + else: + combined[item_id] = item.copy() + combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0) + combined[item_id]["embedding_score"] = item.get("normalized_score", 0) + + for item in combined.values(): + bm25 = float(item.get("bm25_score", 0) or 0) + emb = float(item.get("embedding_score", 0) or 0) + item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb + + results = list(combined.values()) + before = len(results) + results = [r for r in results if r["content_score"] >= self.content_score_threshold] + results.sort(key=lambda x: x["content_score"], reverse=True) + results = results[:limit] + + logger.info( + f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} " + f"(alpha={self.alpha}, threshold={self.content_score_threshold})" + ) + return results + + @staticmethod + def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]: + """Z-score + sigmoid 归一化。""" + if not items: + return items + scores = [float(it.get(field, 0) or 0) for it in items] + if len(scores) <= 1: + for it in items: + it[f"normalized_{field}"] = 1.0 + return items + mean = sum(scores) / len(scores) + var = sum((s - mean) ** 2 for s in scores) / len(scores) + std = math.sqrt(var) + if std == 0: + for it in items: + it[f"normalized_{field}"] = 1.0 + else: + for it, s in zip(items, scores): + z = (s - mean) / std + it[f"normalized_{field}"] = 1 / (1 + math.exp(-z)) + return items + + @staticmethod + def _format_result(record: dict) -> dict: + return { + "id": record.get("id", ""), + "perceptual_type": record.get("perceptual_type", ""), + "file_name": record.get("file_name", ""), + "file_path": record.get("file_path", ""), + "summary": record.get("summary", ""), + "topic": record.get("topic", ""), + "domain": record.get("domain", ""), + "keywords": record.get("keywords", []), + "created_at": str(record.get("created_at", "")), + "file_type": record.get("file_type", ""), + "score": record.get("score", 0), + } + + @staticmethod + def _build_content_text(formatted: dict) -> str: + parts = [] + if formatted["summary"]: + parts.append(formatted["summary"]) + if formatted["topic"]: + parts.append(f"[主题: {formatted['topic']}]") + if formatted["keywords"]: + kw_list = formatted["keywords"] + if isinstance(kw_list, list): + parts.append(f"[关键词: {', '.join(kw_list)}]") + if formatted["file_name"]: + parts.append(f"[文件: {formatted['file_name']}]") + return " ".join(parts) + + +def _extract_keywords_from_problems(problem_extension: dict) -> List[str]: + """Extract search keywords from problem extension results.""" + keywords = [] + context = problem_extension.get("context", {}) + if isinstance(context, dict): + for original_q, extended_qs in context.items(): + keywords.append(original_q) + if isinstance(extended_qs, list): + keywords.extend(extended_qs) + return keywords + + +async def perceptual_retrieve_node(state: ReadState) -> ReadState: + """ + LangGraph node: perceptual memory retrieval. + + Uses PerceptualSearchService to run keyword + embedding search with + BM25 fusion reranking, then writes results to state['perceptual_data']. + """ + end_user_id = state.get("end_user_id", "") + problem_extension = state.get("problem_extension", {}) + original_query = state.get("data", "") + memory_config = state.get("memory_config", None) + + logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}") + + keywords = _extract_keywords_from_problems(problem_extension) + if not keywords: + keywords = [original_query] if original_query else [] + + logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted") + + service = PerceptualSearchService( + end_user_id=end_user_id, + memory_config=memory_config, + ) + search_result = await service.search( + query=original_query, + keywords=keywords, + limit=10, + ) + + result = { + "memories": search_result["memories"], + "content": search_result["content"], + "_intermediate": { + "type": "perceptual_retrieve", + "title": "感知记忆检索", + "data": search_result["memories"], + "query": original_query, + "result_count": len(search_result["memories"]), + }, + } + return {"perceptual_data": result} diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 3030669c..2d6eaa81 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState: logger.info(f"Problem extension result: {aggregated_dict}") # Emit intermediate output for frontend - print(time.time() - start) result = { "context": aggregated_dict, "original": data, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index d967a285..1bf68966 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,7 +1,11 @@ +import asyncio import os import time from app.core.logging_config import get_agent_logger, log_time +from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import ( + PerceptualSearchService, +) from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, SummaryResponse, @@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState: try: if storage_type != "rag": - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search( + + async def _perceptual_search(): + service = PerceptualSearchService( + end_user_id=end_user_id, + memory_config=memory_config, + ) + return await service.search(query=data, limit=5) + + hybrid_task = SearchService().execute_hybrid_search( **search_params, memory_config=memory_config, - expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement + expand_communities=False, ) + perceptual_task = _perceptual_search() + + gather_results = await asyncio.gather( + hybrid_task, perceptual_task, return_exceptions=True + ) + hybrid_result = gather_results[0] + perceptual_results = gather_results[1] + + # 处理 hybrid search 异常 + if isinstance(hybrid_result, Exception): + raise hybrid_result + retrieve_info, question, raw_results = hybrid_result + + # 处理感知记忆结果 + if isinstance(perceptual_results, Exception): + logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}") + perceptual_results = [] + + # 拼接感知记忆内容到 retrieve_info + if perceptual_results and isinstance(perceptual_results, dict): + perceptual_content = perceptual_results.get("content", "") + if perceptual_content: + retrieve_info = f"{retrieve_info}\n\n\n{perceptual_content}" + count = len(perceptual_results.get("memories", [])) + logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)") + # 调试:打印 community 检索结果数量 if raw_results and isinstance(raw_results, dict): reranked = raw_results.get('reranked_results', {}) @@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState: "error": str(e) } end = time.time() - try: - duration = end - start - except Exception: - duration = 0.0 + duration = end - start log_time('检索', duration) return {"summary": summary} @@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState: retrieve_info_str = list(set(retrieve_info_str)) retrieve_info_str = '\n'.join(retrieve_info_str) - aimessages = await summary_llm(state, history, retrieve_info_str, - 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + + aimessages = await summary_llm( + state, + history, + retrieve_info_str, + 'direct_summary_prompt.jinja2', + 'retrieve_summary', RetrieveSummaryResponse, + "1" + ) if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState: retrieve_info_str += i + '\n' history = await summary_history(state) + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + data = { "query": query, "history": history, @@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState: if key == 'answer_small': for i in value: retrieve_info_str += i + '\n' + + # Merge perceptual memory content + perceptual_data = state.get("perceptual_data", {}) + perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else "" + if perceptual_content: + retrieve_info_str = f"{retrieve_info_str}\n\n\n{perceptual_content}" + data = { "query": query, "history": history, diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index e698e6ad..d3ca4ea7 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( retrieve_nodes, ) +from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import ( + perceptual_retrieve_node, +) from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( Input_Summary, Retrieve_Summary, @@ -48,13 +51,14 @@ async def make_read_graph(): """ try: # Build workflow graph - workflow = StateGraph(ReadState) + workflow = StateGraph(ReadState) workflow.add_node("content_input", content_input_node) workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Input_Summary", Input_Summary) workflow.add_node("Retrieve", retrieve_nodes) # workflow.add_node("Retrieve", retrieve) + workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node) workflow.add_node("Verify", Verify) workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) @@ -65,14 +69,15 @@ async def make_read_graph(): workflow.add_conditional_edges("content_input", Split_continue) workflow.add_edge("Input_Summary", END) workflow.add_edge("Split_The_Problem", "Problem_Extension") - workflow.add_edge("Problem_Extension", "Retrieve") + # After Problem_Extension, retrieve perceptual memory first, then main Retrieve + workflow.add_edge("Problem_Extension", "Perceptual_Retrieve") + workflow.add_edge("Perceptual_Retrieve", "Retrieve") workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) - '''-----''' # workflow.add_edge("Retrieve", END) # Compile workflow @@ -80,7 +85,5 @@ async def make_read_graph(): yield graph except Exception as e: - print(f"创建工作流失败: {e}") + logger.error(f"创建工作流失败: {e}") raise - finally: - print("工作流创建完成") diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 90b1c088..eaa5f0ab 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger from app.core.memory.src.search import run_hybrid_search from app.core.memory.utils.data.text_utils import escape_lucene_query - logger = get_agent_logger(__name__) # 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化) @@ -31,10 +30,10 @@ def _clean_expand_fields(obj): async def expand_communities_to_statements( - community_results: List[dict], - end_user_id: str, - existing_content: str = "", - limit: int = 10, + community_results: List[dict], + end_user_id: str, + existing_content: str = "", + limit: int = 10, ) -> Tuple[List[dict], List[str]]: """ 社区展开 helper:给定命中的 community 列表,拉取关联 Statement。 @@ -76,17 +75,18 @@ async def expand_communities_to_statements( if s.get("statement") and s["statement"] not in existing_lines ] cleaned = _clean_expand_fields(expanded_stmts) - logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}") + logger.info( + f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}") return cleaned, new_texts class SearchService: """Service for executing hybrid search and processing results.""" - + def __init__(self): """Initialize the search service.""" logger.info("SearchService initialized") - + def extract_content_from_result(self, result: dict, node_type: str = "") -> str: """ Extract only meaningful content from search results, dropping all metadata. @@ -107,19 +107,19 @@ class SearchService: """ if not isinstance(result, dict): return str(result) - + content_parts = [] - + # Statements: extract statement field if 'statement' in result and result['statement']: content_parts.append(result['statement']) - + # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 is_community = ( - node_type == "community" - or 'member_count' in result - or 'core_entities' in result + node_type == "community" + or 'member_count' in result + or 'core_entities' in result ) if is_community: name = result.get('name', '') @@ -130,16 +130,16 @@ class SearchService: elif 'content' in result and result['content']: # Summaries / Chunks content_parts.append(result['content']) - + # Entities: extract name and fact_summary (commented out in original) # if 'name' in result and result['name']: # content_parts.append(result['name']) # if result.get('fact_summary'): # content_parts.append(result['fact_summary']) - + # Return concatenated content or empty string return '\n'.join(content_parts) if content_parts else "" - + def clean_query(self, query: str) -> str: """ Clean and escape query text for Lucene. @@ -155,33 +155,33 @@ class SearchService: Cleaned and escaped query string """ q = str(query).strip() - + # Remove wrapping quotes if (q.startswith("'") and q.endswith("'")) or ( - q.startswith('"') and q.endswith('"') + q.startswith('"') and q.endswith('"') ): q = q[1:-1] - + # Remove newlines and carriage returns q = q.replace('\r', ' ').replace('\n', ' ').strip() - + # Apply Lucene escaping q = escape_lucene_query(q) - + return q - + async def execute_hybrid_search( - self, - end_user_id: str, - question: str, - limit: int = 5, - search_type: str = "hybrid", - include: Optional[List[str]] = None, - rerank_alpha: float = 0.4, - output_path: str = "search_results.json", - return_raw_results: bool = False, - memory_config = None, - expand_communities: bool = True, + self, + end_user_id: str, + question: str, + limit: int = 5, + search_type: str = "hybrid", + include: Optional[List[str]] = None, + rerank_alpha: float = 0.4, + output_path: str = "search_results.json", + return_raw_results: bool = False, + memory_config=None, + expand_communities: bool = True, ) -> Tuple[str, str, Optional[dict]]: """ Execute hybrid search and return clean content. @@ -205,10 +205,10 @@ class SearchService: """ if include is None: include = ["statements", "chunks", "entities", "summaries", "communities"] - + # Clean query cleaned_query = self.clean_query(question) - + try: # Execute search answer = await run_hybrid_search( @@ -221,18 +221,18 @@ class SearchService: memory_config=memory_config, rerank_alpha=rerank_alpha ) - + # Extract results based on search type and include parameter # Prioritize summaries as they contain synthesized contextual information answer_list = [] - + # For hybrid search, use reranked_results if search_type == "hybrid": reranked_results = answer.get('reranked_results', {}) - + # Priority order: summaries first (most contextual), then communities, statements, chunks, entities priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] - + for category in priority_order: if category in include and category in reranked_results: category_results = reranked_results[category] @@ -242,7 +242,7 @@ class SearchService: # For keyword or embedding search, results are directly in answer dict # Apply same priority order priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] - + for category in priority_order: if category in include and category in answer: category_results = answer[category] @@ -261,7 +261,7 @@ class SearchService: end_user_id=end_user_id, ) answer_list.extend(cleaned_stmts) - + # Extract clean content from all results,按类型传入 node_type 区分 community content_list = [] for ans in answer_list: @@ -269,19 +269,18 @@ class SearchService: ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" content_list.append(self.extract_content_from_result(ans, node_type=ntype)) - # Filter out empty strings and join with newlines clean_content = '\n'.join([c for c in content_list if c]) - + # Log first 200 chars logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...") - + # Return raw results if requested if return_raw_results: return clean_content, cleaned_query, answer else: return clean_content, cleaned_query, None - + except Exception as e: logger.error( f"Search failed for query '{question}' in group '{end_user_id}': {e}", diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index ea8add48..21bc1777 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -1,4 +1,3 @@ -import os from collections import defaultdict from pathlib import Path from typing import Annotated, TypedDict @@ -52,6 +51,7 @@ class ReadState(TypedDict): embedding_id: str memory_config: object # 新增字段用于传递内存配置对象 retrieve: dict + perceptual_data: dict RetrieveSummary: dict InputSummary: dict verify: dict diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index a3c40dcd..ef39a12e 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -43,6 +43,7 @@ load_dotenv() logger = get_memory_logger(__name__) + def _parse_datetime(value: Any) -> Optional[datetime]: """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" if value is None: @@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") if score_field == "activation_value" and score is None: scores.append(None) # 保持 None,稍后特殊处理 continue - + if score is not None and isinstance(score, (int, float)): scores.append(float(score)) else: @@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") if not scores: return results - + # 过滤掉 None 值,只对有效分数进行归一化 valid_scores = [s for s in scores if s is not None] - + if not valid_scores: # 所有分数都是 None,不进行归一化 for item in results: @@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") item[f"normalized_{score_field}"] = None return results - if len(valid_scores) == 1: # Single valid score, set to 1.0 + if len(valid_scores) == 1: # Single valid score, set to 1.0 for item, score in zip(results, scores): if score_field in item or score_field == "activation_value": if score is None: @@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") return results - def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ Remove duplicate items from search results based on content. @@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: seen_ids = set() seen_content = set() deduplicated = [] - + for item in items: # Try multiple ID fields to identify unique items item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") - + # Extract content from various possible fields content = ( - item.get("text") or - item.get("content") or - item.get("statement") or - item.get("name") or - "" + item.get("text") or + item.get("content") or + item.get("statement") or + item.get("name") or + "" ) - + # Normalize content for comparison (strip whitespace and lowercase) normalized_content = str(content).strip().lower() if content else "" - + # Check if we've seen this ID or content before is_duplicate = False - + if item_id and item_id in seen_ids: is_duplicate = True elif normalized_content and normalized_content in seen_content: # Only check content duplication if content is not empty is_duplicate = True - + if not is_duplicate: # Mark as seen if item_id: seen_ids.add(item_id) if normalized_content: # Only track non-empty content seen_content.add(normalized_content) - + deduplicated.append(item) - + return deduplicated def rerank_with_activation( - keyword_results: Dict[str, List[Dict[str, Any]]], - embedding_results: Dict[str, List[Dict[str, Any]]], - alpha: float = 0.6, - limit: int = 10, - forgetting_config: ForgettingEngineConfig | None = None, - activation_boost_factor: float = 0.8, - now: datetime | None = None, + keyword_results: Dict[str, List[Dict[str, Any]]], + embedding_results: Dict[str, List[Dict[str, Any]]], + alpha: float = 0.6, + limit: int = 10, + forgetting_config: ForgettingEngineConfig | None = None, + activation_boost_factor: float = 0.8, + now: datetime | None = None, + content_score_threshold: float = 0.5, ) -> Dict[str, List[Dict[str, Any]]]: """ 两阶段排序:先按内容相关性筛选,再按激活值排序。 @@ -222,6 +223,8 @@ def rerank_with_activation( forgetting_config: 遗忘引擎配置(当前未使用) activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8) now: 当前时间(用于遗忘计算) + content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score), + 低于此阈值的结果会被过滤。默认 0.5。 返回: 带评分元数据的重排序结果,按 final_score 排序 @@ -229,26 +232,26 @@ def rerank_with_activation( # 验证权重范围 if not (0 <= alpha <= 1): raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}") - + # 初始化遗忘引擎(如果需要) engine = None if forgetting_config: engine = ForgettingEngine(forgetting_config) now_dt = now or datetime.now() - + reranked: Dict[str, List[Dict[str, Any]]] = {} - + for category in ["statements", "chunks", "entities", "summaries", "communities"]: keyword_items = keyword_results.get(category, []) embedding_items = embedding_results.get(category, []) - + # 步骤 1: 归一化分数 keyword_items = normalize_scores(keyword_items, "score") embedding_items = normalize_scores(embedding_items, "score") - + # 步骤 2: 按 ID 合并结果(去重) combined_items: Dict[str, Dict[str, Any]] = {} - + # 添加关键词结果 for item in keyword_items: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") @@ -257,7 +260,7 @@ def rerank_with_activation( combined_items[item_id] = item.copy() combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) combined_items[item_id]["embedding_score"] = 0 # 默认值 - + # 添加或更新向量嵌入结果 for item in embedding_items: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") @@ -271,18 +274,18 @@ def rerank_with_activation( combined_items[item_id] = item.copy() combined_items[item_id]["bm25_score"] = 0 # 默认值 combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - + # 步骤 3: 归一化激活度分数 # 为所有项准备激活度值列表 items_list = list(combined_items.values()) items_list = normalize_scores(items_list, "activation_value") - + # 更新 combined_items 中的归一化激活度分数 for item in items_list: item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") if item_id and item_id in combined_items: combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") - + # 步骤 4: 计算基础分数和最终分数 for item_id, item in combined_items.items(): bm25_norm = float(item.get("bm25_score", 0) or 0) @@ -290,45 +293,45 @@ def rerank_with_activation( # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 raw_act_norm = item.get("normalized_activation_value") act_norm = float(raw_act_norm) if raw_act_norm is not None else None - + # 第一阶段:只考虑内容相关性(BM25 + Embedding) # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 content_score = alpha * bm25_norm + (1 - alpha) * emb_norm base_score = content_score # 第一阶段用内容分数 - + # 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序) item["activation_score"] = act_norm # 可能为 None item["content_score"] = content_score item["base_score"] = base_score - + # 步骤 5: 应用遗忘曲线(可选) if engine: # 计算受激活度影响的记忆强度 importance = float(item.get("importance_score", 0.5) or 0.5) - + # 获取 activation_value activation_val = item.get("activation_value") - + # 只对有激活值的节点应用遗忘曲线 if activation_val is not None and isinstance(activation_val, (int, float)): activation_val = float(activation_val) - + # 计算记忆强度:importance_score × (1 + activation_value × boost_factor) memory_strength = importance * (1 + activation_val * activation_boost_factor) - + # 计算经过的时间(天数) dt = _parse_datetime(item.get("created_at")) if dt is None: time_elapsed_days = 0.0 else: time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) - + # 获取遗忘权重 forgetting_weight = engine.calculate_weight( time_elapsed=time_elapsed_days, memory_strength=memory_strength ) - + # 应用到基础分数 item["forgetting_weight"] = forgetting_weight item["final_score"] = base_score * forgetting_weight @@ -338,7 +341,7 @@ def rerank_with_activation( else: # 不使用遗忘曲线 item["final_score"] = base_score - + # 步骤 6: 两阶段排序和限制 # 第一阶段:按内容相关性(base_score)排序,取 Top-K first_stage_limit = limit * 3 # 可配置,取3倍候选 @@ -347,11 +350,11 @@ def rerank_with_activation( key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序 reverse=True )[:first_stage_limit] - + # 第二阶段:分离有激活值和无激活值的节点 items_with_activation = [] items_without_activation = [] - + for item in first_stage_sorted: activation_score = item.get("activation_score") # 检查是否有有效的激活值(不是 None) @@ -359,14 +362,14 @@ def rerank_with_activation( items_with_activation.append(item) else: items_without_activation.append(item) - + # 优先按激活值排序有激活值的节点 sorted_with_activation = sorted( items_with_activation, key=lambda x: float(x.get("activation_score", 0) or 0), reverse=True ) - + # 如果有激活值的节点不足 limit,用无激活值的节点补充 if len(sorted_with_activation) < limit: needed = limit - len(sorted_with_activation) @@ -374,7 +377,7 @@ def rerank_with_activation( sorted_items = sorted_with_activation + items_without_activation[:needed] else: sorted_items = sorted_with_activation[:limit] - + # 两阶段排序完成,更新 final_score 以反映实际排序依据 # Stage 1: 按 content_score 筛选候选(已完成) # Stage 2: 按 activation_score 排序(已完成) @@ -390,16 +393,29 @@ def rerank_with_activation( else: # 无激活值:使用内容相关性分数 item["final_score"] = item.get("base_score", 0) - - # 最终去重确保没有重复项 + + if content_score_threshold > 0: + before_count = len(sorted_items) + sorted_items = [ + item for item in sorted_items + if float(item.get("content_score", 0) or 0) >= content_score_threshold + ] + filtered_count = before_count - len(sorted_items) + if filtered_count > 0: + logger.info( + f"[rerank] {category}: filtered {filtered_count}/{before_count} " + f"items below content_score_threshold={content_score_threshold}" + ) + sorted_items = _deduplicate_results(sorted_items) - + reranked[category] = sorted_items - + return reranked -def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): +def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], + log_file: str = None): """Log search query information using the logger. Args: @@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None, """ # Ensure the query text is plain and clean before logging cleaned_query = extract_plain_query(query_text) - + # Log using the standard logger logger.info( f"Search query: query='{cleaned_query}', type={search_type}, " @@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any: def apply_reranker_placeholder( - results: Dict[str, List[Dict[str, Any]]], - query_text: str, + results: Dict[str, List[Dict[str, Any]]], + query_text: str, ) -> Dict[str, List[Dict[str, Any]]]: """ Placeholder for a cross-encoder reranker. @@ -483,7 +499,7 @@ def apply_reranker_placeholder( # ) -> Dict[str, List[Dict[str, Any]]]: # """ # Apply LLM-based reranking to search results. - + # Args: # results: Search results organized by category # query_text: Original search query @@ -491,7 +507,7 @@ def apply_reranker_placeholder( # llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) # top_k: Maximum number of items to rerank per category # batch_size: Number of items to process concurrently - + # Returns: # Reranked results with final_score and reranker_model fields # """ @@ -501,18 +517,18 @@ def apply_reranker_placeholder( # # except Exception as e: # # logger.debug(f"Failed to load reranker config: {e}") # # rc = {} - + # # Check if reranking is enabled # enabled = rc.get("enabled", False) # if not enabled: # logger.debug("LLM reranking is disabled in configuration") # return results - + # # Load configuration parameters with defaults # llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) # top_k = top_k if top_k is not None else rc.get("top_k", 20) # batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) - + # # Initialize reranker client if not provided # if reranker_client is None: # try: @@ -520,10 +536,10 @@ def apply_reranker_placeholder( # except Exception as e: # logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") # return results - + # # Get model name for metadata # model_name = getattr(reranker_client, 'model_name', 'unknown') - + # # Process each category # reranked_results = {} # for category in ["statements", "chunks", "entities", "summaries"]: @@ -531,38 +547,38 @@ def apply_reranker_placeholder( # if not items: # reranked_results[category] = [] # continue - + # # Select top K items by combined_score for reranking # sorted_items = sorted( # items, # key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), # reverse=True # ) - + # top_items = sorted_items[:top_k] # remaining_items = sorted_items[top_k:] - + # # Extract text content from each item # def extract_text(item: Dict[str, Any]) -> str: # """Extract text content from a result item.""" # # Try different text fields based on category # text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" # return str(text).strip() - + # # Batch items for concurrent processing # batches = [] # for i in range(0, len(top_items), batch_size): # batch = top_items[i:i + batch_size] # batches.append(batch) - + # # Process batches concurrently # async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # """Process a batch of items with LLM relevance scoring.""" # scored_batch = [] - + # for item in batch: # item_text = extract_text(item) - + # # Skip items with no text # if not item_text: # item_copy = item.copy() @@ -572,7 +588,7 @@ def apply_reranker_placeholder( # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) # continue - + # # Create relevance scoring prompt # prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. @@ -585,15 +601,15 @@ def apply_reranker_placeholder( # - 1.0 means perfectly relevant # Relevance score:""" - + # # Send request to LLM # try: # messages = [{"role": "user", "content": prompt}] # response = await reranker_client.chat(messages) - + # # Parse LLM response to extract relevance score # response_text = str(response.content if hasattr(response, 'content') else response).strip() - + # # Try to extract a float from the response # try: # # Remove any non-numeric characters except decimal point @@ -608,11 +624,11 @@ def apply_reranker_placeholder( # except (ValueError, AttributeError) as e: # logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") # llm_score = None - + # # Calculate final score # item_copy = item.copy() # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) - + # if llm_score is not None: # final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score # item_copy["llm_relevance_score"] = llm_score @@ -620,7 +636,7 @@ def apply_reranker_placeholder( # # Use combined_score as fallback # final_score = combined_score # item_copy["llm_relevance_score"] = combined_score - + # item_copy["final_score"] = final_score # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) @@ -632,14 +648,14 @@ def apply_reranker_placeholder( # item_copy["llm_relevance_score"] = combined_score # item_copy["reranker_model"] = model_name # scored_batch.append(item_copy) - + # return scored_batch - + # # Process all batches concurrently # try: # batch_tasks = [process_batch(batch) for batch in batches] # batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) - + # # Merge batch results # scored_items = [] # for result in batch_results: @@ -647,7 +663,7 @@ def apply_reranker_placeholder( # logger.warning(f"Batch processing failed: {result}") # continue # scored_items.extend(result) - + # # Add remaining items (not in top K) with their combined_score as final_score # for item in remaining_items: # item_copy = item.copy() @@ -655,11 +671,11 @@ def apply_reranker_placeholder( # item_copy["final_score"] = combined_score # item_copy["reranker_model"] = model_name # scored_items.append(item_copy) - + # # Sort all items by final_score in descending order # scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) # reranked_results[category] = scored_items - + # except Exception as e: # logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") # # Return original items with combined_score as final_score @@ -668,22 +684,22 @@ def apply_reranker_placeholder( # item["final_score"] = combined_score # item["reranker_model"] = model_name # reranked_results[category] = items - + # return reranked_results async def run_hybrid_search( - query_text: str, - search_type: str, - end_user_id: str | None, - limit: int, - include: List[str], - output_path: str | None, - memory_config: "MemoryConfig", - rerank_alpha: float = 0.6, - activation_boost_factor: float = 0.8, - use_forgetting_rerank: bool = False, - use_llm_rerank: bool = False, + query_text: str, + search_type: str, + end_user_id: str | None, + limit: int, + include: List[str], + output_path: str | None, + memory_config: "MemoryConfig", + rerank_alpha: float = 0.6, + activation_boost_factor: float = 0.8, + use_forgetting_rerank: bool = False, + use_llm_rerank: bool = False, ): """ @@ -699,7 +715,7 @@ async def run_hybrid_search( # Clean and normalize the incoming query before use/logging query_text = extract_plain_query(query_text) - + # Validate query is not empty after cleaning if not query_text or not query_text.strip(): logger.warning("Empty query after cleaning, returning empty results") @@ -716,7 +732,7 @@ async def run_hybrid_search( "error": "Empty query" } } - + # Log the search query log_search_query(query_text, search_type, end_user_id, limit, include) @@ -747,7 +763,7 @@ async def run_hybrid_search( # Embedding-based search logger.info("[PERF] Starting embedding search...") embedding_start = time.time() - + # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig config_load_start = time.time() try: @@ -768,7 +784,7 @@ async def run_hybrid_search( embedder = OpenAIEmbedderClient(model_config=rb_config) embedder_init_time = time.time() - embedder_init_start logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") - + embedding_task = asyncio.create_task( search_graph_by_embedding( connector=connector, @@ -788,7 +804,7 @@ async def run_hybrid_search( if keyword_task: keyword_results = await keyword_task - keyword_latency = time.time() - keyword_start + keyword_latency = time.time() - search_start_time latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") if search_type == "keyword": @@ -798,7 +814,7 @@ async def run_hybrid_search( if embedding_task: embedding_results = await embedding_task - embedding_latency = time.time() - embedding_start + embedding_latency = time.time() - search_start_time latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") if search_type == "embedding": @@ -810,7 +826,8 @@ async def run_hybrid_search( if search_type == "hybrid": results["combined_summary"] = { "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), + "total_embedding_results": sum( + len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "search_query": query_text, "search_timestamp": datetime.now().isoformat() } @@ -818,7 +835,7 @@ async def run_hybrid_search( # Apply two-stage reranking with ACTR activation calculation rerank_start = time.time() logger.info("[PERF] Using two-stage reranking with ACTR activation") - + # 加载遗忘引擎配置 config_start = time.time() try: @@ -829,7 +846,7 @@ async def run_hybrid_search( forgetting_cfg = ForgettingEngineConfig() config_time = time.time() - config_start logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s") - + # 统一使用激活度重排序(两阶段:检索 + ACTR计算) rerank_compute_start = time.time() reranked_results = rerank_with_activation( @@ -842,14 +859,14 @@ async def run_hybrid_search( ) rerank_compute_time = time.time() - rerank_compute_start logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s") - + rerank_latency = time.time() - rerank_start latency_metrics["reranking_latency"] = round(rerank_latency, 4) logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s") - + # Optional: apply reranker placeholder if enabled via config reranked_results = apply_reranker_placeholder(reranked_results, query_text) - + # Apply LLM reranking if enabled llm_rerank_applied = False # if use_llm_rerank: @@ -862,11 +879,12 @@ async def run_hybrid_search( # logger.info("LLM reranking applied successfully") # except Exception as e: # logger.warning(f"LLM reranking failed: {e}, using previous scores") - + results["reranked_results"] = reranked_results results["combined_summary"] = { "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), - "total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), + "total_embedding_results": sum( + len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), "search_query": query_text, "search_timestamp": datetime.now().isoformat(), @@ -879,13 +897,13 @@ async def run_hybrid_search( # Calculate total latency total_latency = time.time() - search_start_time latency_metrics["total_latency"] = round(total_latency, 4) - + # Add latency metrics to results if "combined_summary" in results: results["combined_summary"]["latency_metrics"] = latency_metrics else: results["latency_metrics"] = latency_metrics - + logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") logger.info(f"[PERF] Total search completed in {total_latency:.4f}s") logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}") @@ -908,8 +926,10 @@ async def run_hybrid_search( # Log search completion with result count if search_type == "hybrid": result_counts = { - "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, - "embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} + "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in + keyword_results.items()}, + "embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in + embedding_results.items()} } else: result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} @@ -927,12 +947,12 @@ async def run_hybrid_search( async def search_by_temporal( - end_user_id: Optional[str] = "test", - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, + end_user_id: Optional[str] = "test", + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 1, ): """ Temporal search across Statements. @@ -968,13 +988,13 @@ async def search_by_temporal( async def search_by_keyword_temporal( - query_text: str, - end_user_id: Optional[str] = "test", - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 1, + query_text: str, + end_user_id: Optional[str] = "test", + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 1, ): """ Temporal keyword search across Statements. @@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal( async def search_chunk_by_chunk_id( - chunk_id: str, - end_user_id: Optional[str] = "test", - limit: int = 1, + chunk_id: str, + end_user_id: Optional[str] = "test", + limit: int = 1, ): """ Search for Chunks by chunk_id. @@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id( limit=limit ) return {"chunks": chunks} - diff --git a/api/app/main.py b/api/app/main.py index 9e501f11..a8223a49 100644 --- a/api/app/main.py +++ b/api/app/main.py @@ -62,6 +62,7 @@ async def lifespan(app: FastAPI): else: logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") await create_all_indexes() + logger.info("All neo4j indexes and constraints created successfully!") logger.info("应用程序启动完成") diff --git a/api/app/repositories/neo4j/create_indexes.py b/api/app/repositories/neo4j/create_indexes.py index 5132aa09..7caeea8a 100644 --- a/api/app/repositories/neo4j/create_indexes.py +++ b/api/app/repositories/neo4j/create_indexes.py @@ -1,17 +1,17 @@ -import asyncio from app.repositories.neo4j.neo4j_connector import Neo4jConnector + + async def create_fulltext_indexes(): """Create full-text indexes for keyword search with BM25 scoring.""" connector = Neo4jConnector() try: - # 创建 Statements 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # # 创建 Dialogues 索引 # await connector.execute_query(""" # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] @@ -21,27 +21,35 @@ async def create_fulltext_indexes(): await connector.execute_query(""" CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # 创建 Chunks 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) - + """) + # 创建 MemorySummary 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } - """) + """) # 创建 Community 索引 await connector.execute_query(""" CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } """) - + + # 创建 Perceptual 感知记忆索引 + await connector.execute_query(""" + CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain] + OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } + """) + finally: await connector.close() + + async def create_vector_indexes(): """Create vector indexes for fast embedding similarity search. @@ -50,8 +58,7 @@ async def create_vector_indexes(): """ connector = Neo4jConnector() try: - - + # Statement embedding index await connector.execute_query(""" CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS @@ -62,8 +69,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - - + # Chunk embedding index await connector.execute_query(""" CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS @@ -75,7 +81,6 @@ async def create_vector_indexes(): }} """) - # Entity name embedding index await connector.execute_query(""" CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS @@ -86,8 +91,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - - + # Memory summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS @@ -98,7 +102,7 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - + # Community summary embedding index await connector.execute_query(""" CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS @@ -108,8 +112,8 @@ async def create_vector_indexes(): `vector.dimensions`: 1024, `vector.similarity_function`: 'cosine' }} - """) - + """) + # Dialogue embedding index (optional) await connector.execute_query(""" CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS @@ -120,15 +124,27 @@ async def create_vector_indexes(): `vector.similarity_function`: 'cosine' }} """) - + + # Perceptual summary embedding index + await connector.execute_query(""" + CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS + FOR (p:Perceptual) + ON p.summary_embedding + OPTIONS {indexConfig: { + `vector.dimensions`: 1024, + `vector.similarity_function`: 'cosine' + }} + """) finally: await connector.close() + + async def create_unique_constraints(): """Create uniqueness constraints for core node identifiers. Ensures concurrent MERGE operations remain safe and prevents duplicates. """ connector = Neo4jConnector() - try: + try: # Dialogue.id unique await connector.execute_query( """ @@ -136,7 +152,7 @@ async def create_unique_constraints(): FOR (d:Dialogue) REQUIRE d.id IS UNIQUE """ ) - + # Statement.id unique await connector.execute_query( """ @@ -144,7 +160,7 @@ async def create_unique_constraints(): FOR (s:Statement) REQUIRE s.id IS UNIQUE """ ) - + # Chunk.id unique await connector.execute_query( """ @@ -152,13 +168,13 @@ async def create_unique_constraints(): FOR (c:Chunk) REQUIRE c.id IS UNIQUE """ ) - + finally: await connector.close() + + async def create_all_indexes(): """Create all indexes and constraints in one go.""" await create_fulltext_indexes() await create_vector_indexes() await create_unique_constraints() - print("✓ All indexes and constraints created successfully!") - diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 26ffe350..aa246829 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id, r.created_at = edge.created_at RETURN elementId(r) AS uuid """ + +SEARCH_PERCEPTUAL_BY_KEYWORD = """ +CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score +WHERE p.end_user_id = $end_user_id +RETURN p.id AS id, + p.end_user_id AS end_user_id, + p.perceptual_type AS perceptual_type, + p.file_path AS file_path, + p.file_name AS file_name, + p.file_ext AS file_ext, + p.summary AS summary, + p.keywords AS keywords, + p.topic AS topic, + p.domain AS domain, + p.created_at AS created_at, + p.file_type AS file_type, + score +ORDER BY score DESC +LIMIT $limit +""" + +PERCEPTUAL_EMBEDDING_SEARCH = """ +CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding) +YIELD node AS p, score +WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id +RETURN p.id AS id, + p.end_user_id AS end_user_id, + p.perceptual_type AS perceptual_type, + p.file_path AS file_path, + p.file_name AS file_name, + p.file_ext AS file_ext, + p.summary AS summary, + p.keywords AS keywords, + p.topic AS topic, + p.domain AS domain, + p.created_at AS created_at, + p.file_type AS file_type, + score +ORDER BY score DESC +LIMIT $limit +""" diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index c5d3bcca..32ec4474 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import ( ENTITY_EMBEDDING_SEARCH, EXPAND_COMMUNITY_STATEMENTS, MEMORY_SUMMARY_EMBEDDING_SEARCH, + PERCEPTUAL_EMBEDDING_SEARCH, SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNKS_BY_CONTENT, SEARCH_COMMUNITIES_BY_KEYWORD, @@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import ( SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + SEARCH_PERCEPTUAL_BY_KEYWORD, SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -34,11 +36,11 @@ logger = logging.getLogger(__name__) async def _update_activation_values_batch( - connector: Neo4jConnector, - nodes: List[Dict[str, Any]], - node_label: str, - end_user_id: Optional[str] = None, - max_retries: int = 3 + connector: Neo4jConnector, + nodes: List[Dict[str, Any]], + node_label: str, + end_user_id: Optional[str] = None, + max_retries: int = 3 ) -> List[Dict[str, Any]]: """ 批量更新节点的激活值 @@ -58,7 +60,7 @@ async def _update_activation_values_batch( """ if not nodes: return [] - + # 延迟导入以避免循环依赖 from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( AccessHistoryManager, @@ -66,7 +68,7 @@ async def _update_activation_values_batch( from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( ACTRCalculator, ) - + # 创建计算器和管理器实例 actr_calculator = ACTRCalculator() access_manager = AccessHistoryManager( @@ -74,7 +76,7 @@ async def _update_activation_values_batch( actr_calculator=actr_calculator, max_retries=max_retries ) - + # 提取节点ID列表并去重(保持原始顺序) seen_ids = set() unique_node_ids = [] @@ -83,7 +85,7 @@ async def _update_activation_values_batch( if node_id and node_id not in seen_ids: seen_ids.add(node_id) unique_node_ids.append(node_id) - + if not unique_node_ids: logger.warning(f"批量更新激活值:没有有效的节点ID") return nodes @@ -95,7 +97,7 @@ async def _update_activation_values_batch( f"批量更新激活值:检测到重复节点,具有有效ID的节点数量={id_nodes_count}, " f"去重后唯一ID数量={len(unique_node_ids)}" ) - + # 批量记录访问 try: updated_nodes = await access_manager.record_batch_access( @@ -103,14 +105,14 @@ async def _update_activation_values_batch( node_label=node_label, end_user_id=end_user_id ) - + logger.info( f"批量更新激活值成功: {node_label}, " f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}" ) - + return updated_nodes - + except Exception as e: logger.error( f"批量更新激活值失败: {node_label}, 错误: {str(e)}" @@ -120,9 +122,9 @@ async def _update_activation_values_batch( async def _update_search_results_activation( - connector: Neo4jConnector, - results: Dict[str, List[Dict[str, Any]]], - end_user_id: Optional[str] = None + connector: Neo4jConnector, + results: Dict[str, List[Dict[str, Any]]], + end_user_id: Optional[str] = None ) -> Dict[str, List[Dict[str, Any]]]: """ 更新搜索结果中所有知识节点的激活值 @@ -144,11 +146,11 @@ async def _update_search_results_activation( 'entities': 'ExtractedEntity', 'summaries': 'MemorySummary' } - + # 并行更新所有类型的节点 update_tasks = [] update_keys = [] - + for key, label in knowledge_node_types.items(): if key in results and results[key]: update_tasks.append( @@ -160,13 +162,13 @@ async def _update_search_results_activation( ) ) update_keys.append(key) - + if not update_tasks: return results - + # 并行执行所有更新 update_results = await asyncio.gather(*update_tasks, return_exceptions=True) - + # 更新结果字典,保留原始搜索分数 updated_results = results.copy() for key, update_result in zip(update_keys, update_results): @@ -175,10 +177,10 @@ async def _update_search_results_activation( # 保留原始的 score 字段(BM25/Embedding 分数) original_nodes = results[key] updated_nodes = update_result - + # 创建 ID 到更新节点的映射(用于快速查找激活值数据) updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')} - + # 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充 merged_nodes = [] for original_node in original_nodes: @@ -186,7 +188,7 @@ async def _update_search_results_activation( if node_id and node_id in updated_map: # 从原始节点开始,用更新后的激活值数据覆盖 merged_node = original_node.copy() - + # 更新激活值相关字段 activation_fields = { 'activation_value', @@ -196,35 +198,35 @@ async def _update_search_results_activation( 'importance_score', 'version', 'statement', # Statement 节点的内容字段 - 'content' # MemorySummary 节点的内容字段 + 'content' # MemorySummary 节点的内容字段 } - + # 只更新激活值相关字段,保留原始节点的其他字段 for field in activation_fields: if field in updated_map[node_id]: merged_node[field] = updated_map[node_id][field] - + merged_nodes.append(merged_node) else: # 如果没有更新数据,保留原始节点 merged_nodes.append(original_node) - + updated_results[key] = merged_nodes else: # 更新失败,记录错误但保留原始结果 logger.warning( f"更新 {key} 激活值失败: {str(update_result)}" ) - + return updated_results async def search_graph( - connector: Neo4jConnector, - q: str, - end_user_id: Optional[str] = None, - limit: int = 50, - include: List[str] = None, + connector: Neo4jConnector, + q: str, + end_user_id: Optional[str] = None, + limit: int = 50, + include: List[str] = None, ) -> Dict[str, List[Dict[str, Any]]]: """ Search across Statements, Entities, Chunks, and Summaries using a free-text query. @@ -249,41 +251,45 @@ async def search_graph( """ if include is None: include = ["statements", "chunks", "entities", "summaries"] - + # Prepare tasks for parallel execution tasks = [] task_keys = [] - + if "statements" in include: tasks.append(connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD, + json_format=True, q=q, end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") - + if "entities" in include: tasks.append(connector.execute_query( SEARCH_ENTITIES_BY_NAME_OR_ALIAS, + json_format=True, q=q, end_user_id=end_user_id, limit=limit, )) task_keys.append("entities") - + if "chunks" in include: tasks.append(connector.execute_query( SEARCH_CHUNKS_BY_CONTENT, + json_format=True, q=q, end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") - + if "summaries" in include: tasks.append(connector.execute_query( SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, + json_format=True, q=q, end_user_id=end_user_id, limit=limit, @@ -293,15 +299,16 @@ async def search_graph( if "communities" in include: tasks.append(connector.execute_query( SEARCH_COMMUNITIES_BY_KEYWORD, + json_format=True, q=q, end_user_id=end_user_id, limit=limit, )) task_keys.append("communities") - + # Execute all queries in parallel task_results = await asyncio.gather(*tasks, return_exceptions=True) - + # Build results dictionary results = {} for key, result in zip(task_keys, task_results): @@ -310,14 +317,14 @@ async def search_graph( results[key] = [] else: results[key] = result - + # Deduplicate results before updating activation values # This prevents duplicates from propagating through the pipeline from app.core.memory.src.search import _deduplicate_results for key in results: if isinstance(results[key], list): results[key] = _deduplicate_results(results[key]) - + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) # Skip activation updates if only searching summaries (optimization) needs_activation_update = any( @@ -331,17 +338,17 @@ async def search_graph( results=results, end_user_id=end_user_id ) - + return results async def search_graph_by_embedding( - connector: Neo4jConnector, - embedder_client, - query_text: str, - end_user_id: Optional[str] = None, - limit: int = 50, - include: List[str] = ["statements", "chunks", "entities","summaries"], + connector: Neo4jConnector, + embedder_client, + query_text: str, + end_user_id: Optional[str] = None, + limit: int = 50, + include: List[str] = ["statements", "chunks", "entities", "summaries"], ) -> Dict[str, List[Dict[str, Any]]]: """ Embedding-based semantic search across Statements, Chunks, and Entities. @@ -355,13 +362,13 @@ async def search_graph_by_embedding( - Returns up to 'limit' per included type """ import time - + # Get embedding for the query embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: {embed_time:.4f}s") - + logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s") + if not embeddings or not embeddings[0]: logger.warning( f"search_graph_by_embedding: embedding 生成失败或为空," @@ -378,6 +385,7 @@ async def search_graph_by_embedding( if "statements" in include: tasks.append(connector.execute_query( STATEMENT_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -388,6 +396,7 @@ async def search_graph_by_embedding( if "chunks" in include: tasks.append(connector.execute_query( CHUNK_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -398,6 +407,7 @@ async def search_graph_by_embedding( if "entities" in include: tasks.append(connector.execute_query( ENTITY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -408,6 +418,7 @@ async def search_graph_by_embedding( if "summaries" in include: tasks.append(connector.execute_query( MEMORY_SUMMARY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -418,6 +429,7 @@ async def search_graph_by_embedding( if "communities" in include: tasks.append(connector.execute_query( COMMUNITY_EMBEDDING_SEARCH, + json_format=True, embedding=embedding, end_user_id=end_user_id, limit=limit, @@ -428,8 +440,8 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") - + logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { "statements": [], @@ -438,7 +450,7 @@ async def search_graph_by_embedding( "summaries": [], "communities": [], } - + for key, result in zip(task_keys, task_results): if isinstance(result, Exception): logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}") @@ -473,13 +485,15 @@ async def search_graph_by_embedding( logger.info(f"[PERF] Skipping activation updates (only summaries)") return results + + async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 - connector: Neo4jConnector, - end_user_id: str, - entities: List[Dict[str, Any]], - use_contains_fallback: bool = True, - batch_size: int = 500, - max_concurrency: int = 5, + connector: Neo4jConnector, + end_user_id: str, + entities: List[Dict[str, Any]], + use_contains_fallback: bool = True, + batch_size: int = 500, + max_concurrency: int = 5, ) -> Dict[str, List[Dict[str, Any]]]: """ 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries): @@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 async def search_graph_by_keyword_temporal( - connector: Neo4jConnector, - query_text: str, - end_user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 50, + connector: Neo4jConnector, + query_text: str, + end_user_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 50, ) -> Dict[str, List[Any]]: """ Temporal keyword search across Statements. @@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - print(f"query_text不能为空") + logger.warning(f"query_text不能为空") return {"statements": []} statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal( invalid_date=invalid_date, limit=limit, ) - print(f"查询结果为:\n{statements}") + logger.debug(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal( async def search_graph_by_temporal( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - start_date: Optional[str] = None, - end_date: Optional[str] = None, - valid_date: Optional[str] = None, - invalid_date: Optional[str] = None, - limit: int = 10, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + start_date: Optional[str] = None, + end_date: Optional[str] = None, + valid_date: Optional[str] = None, + invalid_date: Optional[str] = None, + limit: int = 10, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -632,10 +646,6 @@ async def search_graph_by_temporal( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -643,15 +653,15 @@ async def search_graph_by_temporal( results=results, end_user_id=end_user_id ) - + return results async def search_graph_by_dialog_id( - connector: Neo4jConnector, - dialog_id: str, - end_user_id: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + dialog_id: str, + end_user_id: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Dialogues. @@ -661,7 +671,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - print(f"dialog_id不能为空") + logger.warning(f"dialog_id不能为空") return {"dialogues": []} dialogues = await connector.execute_query( @@ -674,13 +684,13 @@ async def search_graph_by_dialog_id( async def search_graph_by_chunk_id( - connector: Neo4jConnector, - chunk_id : str, - end_user_id: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + chunk_id: str, + end_user_id: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - print(f"chunk_id不能为空") + logger.warning(f"chunk_id不能为空") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -692,10 +702,10 @@ async def search_graph_by_chunk_id( async def search_graph_community_expand( - connector: Neo4jConnector, - community_ids: List[str], - end_user_id: str, - limit: int = 10, + connector: Neo4jConnector, + community_ids: List[str], + end_user_id: str, + limit: int = 10, ) -> Dict[str, List[Dict[str, Any]]]: """ 三期:社区展开检索 —— 主题 → 细节两级检索。 @@ -748,12 +758,11 @@ async def search_graph_community_expand( async def search_graph_by_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -767,16 +776,11 @@ async def search_graph_by_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_BY_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -784,16 +788,16 @@ async def search_graph_by_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_by_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -807,16 +811,11 @@ async def search_graph_by_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_BY_VALID_AT, end_user_id=end_user_id, - - + valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -824,16 +823,16 @@ async def search_graph_by_valid_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_g_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -847,16 +846,11 @@ async def search_graph_g_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_G_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -864,16 +858,16 @@ async def search_graph_g_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_g_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -887,16 +881,10 @@ async def search_graph_g_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_G_VALID_AT, end_user_id=end_user_id, - - valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -904,16 +892,16 @@ async def search_graph_g_valid_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_l_created_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - created_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -927,16 +915,11 @@ async def search_graph_l_created_at( statements = await connector.execute_query( SEARCH_STATEMENTS_L_CREATED_AT, end_user_id=end_user_id, - - + created_at=created_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -944,16 +927,16 @@ async def search_graph_l_created_at( results=results, end_user_id=end_user_id ) - + return results + async def search_graph_l_valid_at( - connector: Neo4jConnector, - end_user_id: Optional[str] = None, - - - valid_at: Optional[str] = None, - limit: int = 1, + connector: Neo4jConnector, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, + limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. @@ -967,16 +950,11 @@ async def search_graph_l_valid_at( statements = await connector.execute_query( SEARCH_STATEMENTS_L_VALID_AT, end_user_id=end_user_id, - - + valid_at=valid_at, limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") - # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( @@ -984,5 +962,89 @@ async def search_graph_l_valid_at( results=results, end_user_id=end_user_id ) - + return results + + +async def search_perceptual( + connector: Neo4jConnector, + q: str, + end_user_id: Optional[str] = None, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + Search Perceptual memory nodes using fulltext keyword search. + + Matches against summary, topic, and domain fields via the perceptualFulltext index. + + Args: + connector: Neo4j connector + q: Query text + end_user_id: Optional user filter + limit: Max results + + Returns: + Dictionary with 'perceptuals' key containing matched perceptual memory nodes + """ + try: + perceptuals = await connector.execute_query( + SEARCH_PERCEPTUAL_BY_KEYWORD, + q=q, + end_user_id=end_user_id, + limit=limit, + ) + except Exception as e: + logger.warning(f"search_perceptual: keyword search failed: {e}") + perceptuals = [] + + # Deduplicate + from app.core.memory.src.search import _deduplicate_results + perceptuals = _deduplicate_results(perceptuals) + + return {"perceptuals": perceptuals} + + +async def search_perceptual_by_embedding( + connector: Neo4jConnector, + embedder_client, + query_text: str, + end_user_id: Optional[str] = None, + limit: int = 10, +) -> Dict[str, List[Dict[str, Any]]]: + """ + Search Perceptual memory nodes using embedding-based semantic search. + + Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index. + + Args: + connector: Neo4j connector + embedder_client: Embedding client with async response() method + query_text: Query text to embed + end_user_id: Optional user filter + limit: Max results + + Returns: + Dictionary with 'perceptuals' key containing matched perceptual memory nodes + """ + embeddings = await embedder_client.response([query_text]) + if not embeddings or not embeddings[0]: + logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'") + return {"perceptuals": []} + + embedding = embeddings[0] + + try: + perceptuals = await connector.execute_query( + PERCEPTUAL_EMBEDDING_SEARCH, + embedding=embedding, + end_user_id=end_user_id, + limit=limit, + ) + except Exception as e: + logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}") + perceptuals = [] + + from app.core.memory.src.search import _deduplicate_results + perceptuals = _deduplicate_results(perceptuals) + + return {"perceptuals": perceptuals} diff --git a/api/app/repositories/neo4j/neo4j_connector.py b/api/app/repositories/neo4j/neo4j_connector.py index d96e4431..ea8fa917 100644 --- a/api/app/repositories/neo4j/neo4j_connector.py +++ b/api/app/repositories/neo4j/neo4j_connector.py @@ -11,10 +11,28 @@ Classes: from typing import Any, List, Dict from neo4j import AsyncGraphDatabase, basic_auth +from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration from app.core.config import settings +def _convert_neo4j_types(value: Any) -> Any: + """递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。""" + if isinstance(value, Neo4jDateTime): + return value.to_native().isoformat() if value.tzinfo else value.iso_format() + if isinstance(value, Neo4jDate): + return value.iso_format() + if isinstance(value, Neo4jTime): + return value.iso_format() + if isinstance(value, Neo4jDuration): + return str(value) + if isinstance(value, dict): + return {k: _convert_neo4j_types(v) for k, v in value.items()} + if isinstance(value, list): + return [_convert_neo4j_types(item) for item in value] + return value + + class Neo4jConnector: """Neo4j数据库连接器 @@ -59,11 +77,12 @@ class Neo4jConnector: """ await self.driver.close() - async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: + async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]: """执行Cypher查询 Args: query: Cypher查询语句 + json_format: json格式化 **kwargs: 查询参数,将作为参数传递给Cypher查询 Returns: @@ -78,7 +97,10 @@ class Neo4jConnector: **kwargs ) records, summary, keys = result - return [record.data() for record in records] + if json_format: + return [_convert_neo4j_types(record.data()) for record in records] + else: + return [record.data() for record in records] async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: """在写事务中执行操作 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index c27a75be..b12bb48a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -462,11 +462,6 @@ class MemoryAgentService: logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") - # 导入审计日志记录器 - - - - config_load_start = time.time() try: # Use a separate database session to avoid transaction failures @@ -507,10 +502,13 @@ class MemoryAgentService: async with make_read_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, - "end_user_id": end_user_id - , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, - "memory_config": memory_config} + initial_state = { + "messages": [HumanMessage(content=message)], + "search_switch": search_switch, + "end_user_id": end_user_id + , "storage_type": storage_type, + "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} # 获取节点更新信息 _intermediate_outputs = [] summary = '' @@ -522,7 +520,7 @@ class MemoryAgentService: for node_name, node_data in update_event.items(): # if 'save_neo4j' == node_name: # massages = node_data - print(f"处理节点: {node_name}") + logger.info(f"处理节点: {node_name}") # 处理不同Summary节点的返回结构 if 'Summary' in node_name: @@ -549,6 +547,11 @@ class MemoryAgentService: if retrieve_node and retrieve_node != [] and retrieve_node != {}: _intermediate_outputs.extend(retrieve_node) + # Perceptual_Retrieve 节点 + perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None) + if perceptual_node and perceptual_node != [] and perceptual_node != {}: + _intermediate_outputs.append(perceptual_node) + # Verify 节点 verify_n = node_data.get('verify', {}).get('_intermediate', None) if verify_n and verify_n != [] and verify_n != {}: