feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}") logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend # Emit intermediate output for frontend
print(time.time() - start)
result = { result = {
"context": aggregated_dict, "context": aggregated_dict,
"original": data, "original": data,

View File

@@ -1,7 +1,11 @@
import asyncio
import os import os
import time import time
from app.core.logging_config import get_agent_logger, log_time from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
SummaryResponse, SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try: try:
if storage_type != "rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params, **search_params,
memory_config=memory_config, memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement expand_communities=False,
) )
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量 # 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict): if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {}) reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e) "error": str(e)
} }
end = time.time() end = time.time()
try: duration = end - start
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary": summary} return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str, # Merge perceptual memory content
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
history = await summary_history(state) history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,

View File

@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes, retrieve_nodes,
) )
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary, Input_Summary,
Retrieve_Summary, Retrieve_Summary,
@@ -55,6 +58,7 @@ async def make_read_graph():
workflow.add_node("Input_Summary", Input_Summary) workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes) workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve) # workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify) workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary) workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue) workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END) workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension") workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve") # After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END) workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
# Compile workflow # Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph yield graph
except Exception as e: except Exception as e:
print(f"创建工作流失败: {e}") logger.error(f"创建工作流失败: {e}")
raise raise
finally:
print("工作流创建完成")

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化) # 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements( async def expand_communities_to_statements(
community_results: List[dict], community_results: List[dict],
end_user_id: str, end_user_id: str,
existing_content: str = "", existing_content: str = "",
limit: int = 10, limit: int = 10,
) -> Tuple[List[dict], List[str]]: ) -> Tuple[List[dict], List[str]]:
""" """
社区展开 helper给定命中的 community 列表,拉取关联 Statement。 社区展开 helper给定命中的 community 列表,拉取关联 Statement。
@@ -76,7 +75,8 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines if s.get("statement") and s["statement"] not in existing_lines
] ]
cleaned = _clean_expand_fields(expanded_stmts) cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}") logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts return cleaned, new_texts
@@ -117,9 +117,9 @@ class SearchService:
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = ( is_community = (
node_type == "community" node_type == "community"
or 'member_count' in result or 'member_count' in result
or 'core_entities' in result or 'core_entities' in result
) )
if is_community: if is_community:
name = result.get('name', '') name = result.get('name', '')
@@ -158,7 +158,7 @@ class SearchService:
# Remove wrapping quotes # Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or ( if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"') q.startswith('"') and q.endswith('"')
): ):
q = q[1:-1] q = q[1:-1]
@@ -171,17 +171,17 @@ class SearchService:
return q return q
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
end_user_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
rerank_alpha: float = 0.4, rerank_alpha: float = 0.4,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config = None, memory_config=None,
expand_communities: bool = True, expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
@@ -269,7 +269,6 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype)) content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c]) clean_content = '\n'.join([c for c in content_list if c])

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Annotated, TypedDict from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象 memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict retrieve: dict
perceptual_data: dict
RetrieveSummary: dict RetrieveSummary: dict
InputSummary: dict InputSummary: dict
verify: dict verify: dict

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]: def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None: if value is None:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None item[f"normalized_{score_field}"] = None
return results return results
if len(valid_scores) == 1: # Single valid score, set to 1.0 if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores): for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value": if score_field in item or score_field == "activation_value":
if score is None: if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Remove duplicate items from search results based on content. Remove duplicate items from search results based on content.
@@ -157,11 +157,11 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# Extract content from various possible fields # Extract content from various possible fields
content = ( content = (
item.get("text") or item.get("text") or
item.get("content") or item.get("content") or
item.get("statement") or item.get("statement") or
item.get("name") or item.get("name") or
"" ""
) )
# Normalize content for comparison (strip whitespace and lowercase) # Normalize content for comparison (strip whitespace and lowercase)
@@ -189,13 +189,14 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
def rerank_with_activation( def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]], keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]], embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6, alpha: float = 0.6,
limit: int = 10, limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None, forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
now: datetime | None = None, now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
两阶段排序:先按内容相关性筛选,再按激活值排序。 两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用) forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8) activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算) now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回: 返回:
带评分元数据的重排序结果,按 final_score 排序 带评分元数据的重排序结果,按 final_score 排序
@@ -391,7 +394,19 @@ def rerank_with_activation(
# 无激活值:使用内容相关性分数 # 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0) item["final_score"] = item.get("base_score", 0)
# 最终去重确保没有重复项 if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items) sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items reranked[category] = sorted_items
@@ -399,7 +414,8 @@ def rerank_with_activation(
return reranked return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder( def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
query_text: str, query_text: str,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Placeholder for a cross-encoder reranker. Placeholder for a cross-encoder reranker.
@@ -673,17 +689,17 @@ def apply_reranker_placeholder(
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
end_user_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
memory_config: "MemoryConfig", memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
): ):
""" """
@@ -788,7 +804,7 @@ async def run_hybrid_search(
if keyword_task: if keyword_task:
keyword_results = await keyword_task keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword": if search_type == "keyword":
@@ -798,7 +814,7 @@ async def run_hybrid_search(
if embedding_task: if embedding_task:
embedding_results = await embedding_task embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding": if search_type == "embedding":
@@ -810,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid": if search_type == "hybrid":
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat() "search_timestamp": datetime.now().isoformat()
} }
@@ -866,7 +883,8 @@ async def run_hybrid_search(
results["reranked_results"] = reranked_results results["reranked_results"] = reranked_results
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat(), "search_timestamp": datetime.now().isoformat(),
@@ -908,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count # Log search completion with result count
if search_type == "hybrid": if search_type == "hybrid":
result_counts = { result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
} }
else: else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
Search for Chunks by chunk_id. Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}

View File

@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes() await create_all_indexes()
logger.info("All neo4j indexes and constraints created successfully!")
logger.info("应用程序启动完成") logger.info("应用程序启动完成")

View File

@@ -1,11 +1,11 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes(): async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring.""" """Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 创建 Statements 索引 # 创建 Statements 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
@@ -40,8 +40,16 @@ async def create_fulltext_indexes():
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 Perceptual 感知记忆索引
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally: finally:
await connector.close() await connector.close()
async def create_vector_indexes(): async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search. """Create vector indexes for fast embedding similarity search.
@@ -51,7 +59,6 @@ async def create_vector_indexes():
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# Statement embedding index # Statement embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -63,7 +70,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Chunk embedding index # Chunk embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Entity name embedding index # Entity name embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -87,7 +92,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Memory summary embedding index # Memory summary embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -121,8 +125,20 @@ async def create_vector_indexes():
}} }}
""") """)
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally: finally:
await connector.close() await connector.close()
async def create_unique_constraints(): async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers. """Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates. Ensures concurrent MERGE operations remain safe and prevents duplicates.
@@ -155,10 +171,10 @@ async def create_unique_constraints():
finally: finally:
await connector.close() await connector.close()
async def create_all_indexes(): async def create_all_indexes():
"""Create all indexes and constraints in one go.""" """Create all indexes and constraints in one go."""
await create_fulltext_indexes() await create_fulltext_indexes()
await create_vector_indexes() await create_vector_indexes()
await create_unique_constraints() await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at r.created_at = edge.created_at
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS, EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch( async def _update_activation_values_batch(
connector: Neo4jConnector, connector: Neo4jConnector,
nodes: List[Dict[str, Any]], nodes: List[Dict[str, Any]],
node_label: str, node_label: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_retries: int = 3 max_retries: int = 3
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
批量更新节点的激活值 批量更新节点的激活值
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation( async def _update_search_results_activation(
connector: Neo4jConnector, connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
更新搜索结果中所有知识节点的激活值 更新搜索结果中所有知识节点的激活值
@@ -196,7 +198,7 @@ async def _update_search_results_activation(
'importance_score', 'importance_score',
'version', 'version',
'statement', # Statement 节点的内容字段 'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段 'content' # MemorySummary 节点的内容字段
} }
# 只更新激活值相关字段,保留原始节点的其他字段 # 只更新激活值相关字段,保留原始节点的其他字段
@@ -220,11 +222,11 @@ async def _update_search_results_activation(
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
q: str, q: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Search across Statements, Entities, Chunks, and Summaries using a free-text query. Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -257,6 +259,7 @@ async def search_graph(
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -266,6 +269,7 @@ async def search_graph(
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -275,6 +279,7 @@ async def search_graph(
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -284,6 +289,7 @@ async def search_graph(
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -293,6 +299,7 @@ async def search_graph(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -336,12 +343,12 @@ async def search_graph(
async def search_graph_by_embedding( async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"], include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Embedding-based semantic search across Statements, Chunks, and Entities. Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -360,7 +367,7 @@ async def search_graph_by_embedding(
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s") logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
logger.warning( logger.warning(
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH, STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH, COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -428,7 +440,7 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)") logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: str, end_user_id: str,
entities: List[Dict[str, Any]], entities: List[Dict[str, Any]],
use_contains_fallback: bool = True, use_contains_fallback: bool = True,
batch_size: int = 500, batch_size: int = 500,
max_concurrency: int = 5, max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal( async def search_graph_by_keyword_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 50, limit: int = 50,
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
print(f"query_text不能为空") logger.warning(f"query_text不能为空")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
print(f"查询结果为:\n{statements}") logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal( async def search_graph_by_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -648,10 +658,10 @@ async def search_graph_by_temporal(
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
connector: Neo4jConnector, connector: Neo4jConnector,
dialog_id: str, dialog_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Dialogues. Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
print(f"dialog_id不能为空") logger.warning(f"dialog_id不能为空")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id( async def search_graph_by_chunk_id(
connector: Neo4jConnector, connector: Neo4jConnector,
chunk_id : str, chunk_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
print(f"chunk_id不能为空") logger.warning(f"chunk_id不能为空")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand( async def search_graph_community_expand(
connector: Neo4jConnector, connector: Neo4jConnector,
community_ids: List[str], community_ids: List[str],
end_user_id: str, end_user_id: str,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
三期:社区展开检索 —— 主题 → 细节两级检索。 三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -768,15 +777,10 @@ async def search_graph_by_created_at(
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -787,13 +791,13 @@ async def search_graph_by_created_at(
return results return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -808,15 +812,10 @@ async def search_graph_by_valid_at(
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -827,13 +826,13 @@ async def search_graph_by_valid_at(
return results return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -848,15 +847,10 @@ async def search_graph_g_created_at(
SEARCH_STATEMENTS_G_CREATED_AT, SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -867,13 +861,13 @@ async def search_graph_g_created_at(
return results return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -907,13 +895,13 @@ async def search_graph_g_valid_at(
return results return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -928,15 +916,10 @@ async def search_graph_l_created_at(
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -947,13 +930,13 @@ async def search_graph_l_created_at(
return results return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -968,15 +951,10 @@ async def search_graph_l_valid_at(
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -986,3 +964,87 @@ async def search_graph_l_valid_at(
) )
return results return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector: class Neo4jConnector:
"""Neo4j数据库连接器 """Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
""" """
await self.driver.close() await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询 """执行Cypher查询
Args: Args:
query: Cypher查询语句 query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询 **kwargs: 查询参数将作为参数传递给Cypher查询
Returns: Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs **kwargs
) )
records, summary, keys = result records, summary, keys = result
return [record.data() for record in records] if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作 """在写事务中执行操作

View File

@@ -462,11 +462,6 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -507,10 +502,13 @@ class MemoryAgentService:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, initial_state = {
"end_user_id": end_user_id "messages": [HumanMessage(content=message)],
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "search_switch": search_switch,
"memory_config": memory_config} "end_user_id": end_user_id
, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
@@ -522,7 +520,7 @@ class MemoryAgentService:
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name: # if 'save_neo4j' == node_name:
# massages = node_data # massages = node_data
print(f"处理节点: {node_name}") logger.info(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构 # 处理不同Summary节点的返回结构
if 'Summary' in node_name: if 'Summary' in node_name:
@@ -549,6 +547,11 @@ class MemoryAgentService:
if retrieve_node and retrieve_node != [] and retrieve_node != {}: if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node) _intermediate_outputs.extend(retrieve_node)
# Perceptual_Retrieve 节点
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
_intermediate_outputs.append(perceptual_node)
# Verify 节点 # Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None) verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}: