feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,

View File

@@ -1,7 +1,11 @@
import asyncio
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
expand_communities=False,
)
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
duration = end - start
log_time('检索', duration)
return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n'
history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,

View File

@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes,
)
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
Retrieve_Summary,
@@ -55,6 +58,7 @@ async def make_read_graph():
workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
logger.error(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements(
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
) -> Tuple[List[dict], List[str]]:
"""
社区展开 helper给定命中的 community 列表,拉取关联 Statement。
@@ -76,7 +75,8 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
@@ -117,9 +117,9 @@ class SearchService:
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
@@ -158,7 +158,7 @@ class SearchService:
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
@@ -171,17 +171,17 @@ class SearchService:
return q
async def execute_hybrid_search(
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config = None,
expand_communities: bool = True,
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config=None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -269,7 +269,6 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict
perceptual_data: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None
return results
if len(valid_scores) == 1: # Single valid score, set to 1.0
if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Remove duplicate items from search results based on content.
@@ -157,11 +157,11 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Extract content from various possible fields
content = (
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
)
# Normalize content for comparison (strip whitespace and lowercase)
@@ -189,13 +189,14 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回:
带评分元数据的重排序结果,按 final_score 排序
@@ -391,7 +394,19 @@ def rerank_with_activation(
# 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0)
# 最终去重确保没有重复项
if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items
@@ -399,7 +414,8 @@ def rerank_with_activation(
return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger.
Args:
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Placeholder for a cross-encoder reranker.
@@ -673,17 +689,17 @@ def apply_reranker_placeholder(
async def run_hybrid_search(
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
"""
@@ -788,7 +804,7 @@ async def run_hybrid_search(
if keyword_task:
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
@@ -798,7 +814,7 @@ async def run_hybrid_search(
if embedding_task:
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
@@ -810,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid":
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat()
}
@@ -866,7 +883,8 @@ async def run_hybrid_search(
results["reranked_results"] = reranked_results
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
@@ -908,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count
if search_type == "hybrid":
result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
}
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal(
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit
)
return {"chunks": chunks}

View File

@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes()
logger.info("All neo4j indexes and constraints created successfully!")
logger.info("应用程序启动完成")

View File

@@ -1,11 +1,11 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector()
try:
# 创建 Statements 索引
await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
@@ -40,8 +40,16 @@ async def create_fulltext_indexes():
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
# 创建 Perceptual 感知记忆索引
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally:
await connector.close()
async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search.
@@ -51,7 +59,6 @@ async def create_vector_indexes():
connector = Neo4jConnector()
try:
# Statement embedding index
await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -63,7 +70,6 @@ async def create_vector_indexes():
}}
""")
# Chunk embedding index
await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}}
""")
# Entity name embedding index
await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -87,7 +92,6 @@ async def create_vector_indexes():
}}
""")
# Memory summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -121,8 +125,20 @@ async def create_vector_indexes():
}}
""")
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally:
await connector.close()
async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates.
@@ -155,10 +171,10 @@ async def create_unique_constraints():
finally:
await connector.close()
async def create_all_indexes():
"""Create all indexes and constraints in one go."""
await create_fulltext_indexes()
await create_vector_indexes()
await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at
RETURN elementId(r) AS uuid
"""
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch(
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
connector: Neo4jConnector,
nodes: List[Dict[str, Any]],
node_label: str,
end_user_id: Optional[str] = None,
max_retries: int = 3
) -> List[Dict[str, Any]]:
"""
批量更新节点的激活值
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation(
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
"""
更新搜索结果中所有知识节点的激活值
@@ -196,7 +198,7 @@ async def _update_search_results_activation(
'importance_score',
'version',
'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段
'content' # MemorySummary 节点的内容字段
}
# 只更新激活值相关字段,保留原始节点的其他字段
@@ -220,11 +222,11 @@ async def _update_search_results_activation(
async def search_graph(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -257,6 +259,7 @@ async def search_graph(
if "statements" in include:
tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -266,6 +269,7 @@ async def search_graph(
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -275,6 +279,7 @@ async def search_graph(
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -284,6 +289,7 @@ async def search_graph(
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -293,6 +299,7 @@ async def search_graph(
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q,
end_user_id=end_user_id,
limit=limit,
@@ -336,12 +343,12 @@ async def search_graph(
async def search_graph_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"],
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 50,
include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]:
"""
Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -360,7 +367,7 @@ async def search_graph_by_embedding(
embed_start = time.time()
embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]:
logger.warning(
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include:
tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
@@ -428,7 +440,7 @@ async def search_graph_by_embedding(
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = {
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
connector: Neo4jConnector,
end_user_id: str,
entities: List[Dict[str, Any]],
use_contains_fallback: bool = True,
batch_size: int = 500,
max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal(
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
connector: Neo4jConnector,
query_text: str,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 50,
) -> Dict[str, List[Any]]:
"""
Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements
"""
if not query_text:
print(f"query_text不能为空")
logger.warning(f"query_text不能为空")
return {"statements": []}
statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date,
limit=limit,
)
print(f"查询结果为:\n{statements}")
logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -648,10 +658,10 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id(
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
dialog_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues
"""
if not dialog_id:
print(f"dialog_id不能为空")
logger.warning(f"dialog_id不能为空")
return {"dialogues": []}
dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id(
connector: Neo4jConnector,
chunk_id : str,
end_user_id: Optional[str] = None,
limit: int = 1,
connector: Neo4jConnector,
chunk_id: str,
end_user_id: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id:
print(f"chunk_id不能为空")
logger.warning(f"chunk_id不能为空")
return {"chunks": []}
chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand(
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
connector: Neo4jConnector,
community_ids: List[str],
end_user_id: str,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -768,15 +777,10 @@ async def search_graph_by_created_at(
SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -787,13 +791,13 @@ async def search_graph_by_created_at(
return results
async def search_graph_by_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -808,15 +812,10 @@ async def search_graph_by_valid_at(
SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -827,13 +826,13 @@ async def search_graph_by_valid_at(
return results
async def search_graph_g_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -848,15 +847,10 @@ async def search_graph_g_created_at(
SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -867,13 +861,13 @@ async def search_graph_g_created_at(
return results
async def search_graph_g_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -907,13 +895,13 @@ async def search_graph_g_valid_at(
return results
async def search_graph_l_created_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
limit: int = 1,
created_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -928,15 +916,10 @@ async def search_graph_l_created_at(
SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id,
created_at=created_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -947,13 +930,13 @@ async def search_graph_l_created_at(
return results
async def search_graph_l_valid_at(
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
connector: Neo4jConnector,
end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
limit: int = 1,
valid_at: Optional[str] = None,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Temporal search across Statements.
@@ -968,15 +951,10 @@ async def search_graph_l_valid_at(
SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id,
valid_at=valid_at,
limit=limit,
)
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
@@ -986,3 +964,87 @@ async def search_graph_l_valid_at(
)
return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector:
"""Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
"""
await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询
Args:
query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询
Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs
)
records, summary, keys = result
return [record.data() for record in records]
if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作

View File

@@ -462,11 +462,6 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
config_load_start = time.time()
try:
# Use a separate database session to avoid transaction failures
@@ -507,10 +502,13 @@ class MemoryAgentService:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
initial_state = {
"messages": [HumanMessage(content=message)],
"search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
@@ -522,7 +520,7 @@ class MemoryAgentService:
for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
print(f"处理节点: {node_name}")
logger.info(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
@@ -549,6 +547,11 @@ class MemoryAgentService:
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
# Perceptual_Retrieve 节点
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
_intermediate_outputs.append(perceptual_node)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}: