feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}") logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend # Emit intermediate output for frontend
print(time.time() - start)
result = { result = {
"context": aggregated_dict, "context": aggregated_dict,
"original": data, "original": data,

View File

@@ -1,7 +1,11 @@
import asyncio
import os import os
import time import time
from app.core.logging_config import get_agent_logger, log_time from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import ( from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse, RetrieveSummaryResponse,
SummaryResponse, SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try: try:
if storage_type != "rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params, **search_params,
memory_config=memory_config, memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement expand_communities=False,
) )
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量 # 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict): if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {}) reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e) "error": str(e)
} }
end = time.time() end = time.time()
try: duration = end - start
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration) log_time('检索', duration)
return {"summary": summary} return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str)) retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str) retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str, # Merge perceptual memory content
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages) await summary_redis_save(state, aimessages)
if aimessages == '': if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
history = await summary_history(state) history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small': if key == 'answer_small':
for i in value: for i in value:
retrieve_info_str += i + '\n' retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = { data = {
"query": query, "query": query,
"history": history, "history": history,

View File

@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes, retrieve_nodes,
) )
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary, Input_Summary,
Retrieve_Summary, Retrieve_Summary,
@@ -48,13 +51,14 @@ async def make_read_graph():
""" """
try: try:
# Build workflow graph # Build workflow graph
workflow = StateGraph(ReadState) workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node) workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary) workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes) workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve) # workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify) workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary) workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue) workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END) workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension") workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve") # After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END) workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue) workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END) workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END) # workflow.add_edge("Retrieve", END)
# Compile workflow # Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph yield graph
except Exception as e: except Exception as e:
print(f"创建工作流失败: {e}") logger.error(f"创建工作流失败: {e}")
raise raise
finally:
print("工作流创建完成")

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化) # 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements( async def expand_communities_to_statements(
community_results: List[dict], community_results: List[dict],
end_user_id: str, end_user_id: str,
existing_content: str = "", existing_content: str = "",
limit: int = 10, limit: int = 10,
) -> Tuple[List[dict], List[str]]: ) -> Tuple[List[dict], List[str]]:
""" """
社区展开 helper给定命中的 community 列表,拉取关联 Statement。 社区展开 helper给定命中的 community 列表,拉取关联 Statement。
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines if s.get("statement") and s["statement"] not in existing_lines
] ]
cleaned = _clean_expand_fields(expanded_stmts) cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}") logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts return cleaned, new_texts
class SearchService: class SearchService:
"""Service for executing hybrid search and processing results.""" """Service for executing hybrid search and processing results."""
def __init__(self): def __init__(self):
"""Initialize the search service.""" """Initialize the search service."""
logger.info("SearchService initialized") logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict, node_type: str = "") -> str: def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
""" """
Extract only meaningful content from search results, dropping all metadata. Extract only meaningful content from search results, dropping all metadata.
@@ -107,19 +107,19 @@ class SearchService:
""" """
if not isinstance(result, dict): if not isinstance(result, dict):
return str(result) return str(result)
content_parts = [] content_parts = []
# Statements: extract statement field # Statements: extract statement field
if 'statement' in result and result['statement']: if 'statement' in result and result['statement']:
content_parts.append(result['statement']) content_parts.append(result['statement'])
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = ( is_community = (
node_type == "community" node_type == "community"
or 'member_count' in result or 'member_count' in result
or 'core_entities' in result or 'core_entities' in result
) )
if is_community: if is_community:
name = result.get('name', '') name = result.get('name', '')
@@ -130,16 +130,16 @@ class SearchService:
elif 'content' in result and result['content']: elif 'content' in result and result['content']:
# Summaries / Chunks # Summaries / Chunks
content_parts.append(result['content']) content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original) # Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']: # if 'name' in result and result['name']:
# content_parts.append(result['name']) # content_parts.append(result['name'])
# if result.get('fact_summary'): # if result.get('fact_summary'):
# content_parts.append(result['fact_summary']) # content_parts.append(result['fact_summary'])
# Return concatenated content or empty string # Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else "" return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str: def clean_query(self, query: str) -> str:
""" """
Clean and escape query text for Lucene. Clean and escape query text for Lucene.
@@ -155,33 +155,33 @@ class SearchService:
Cleaned and escaped query string Cleaned and escaped query string
""" """
q = str(query).strip() q = str(query).strip()
# Remove wrapping quotes # Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or ( if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"') q.startswith('"') and q.endswith('"')
): ):
q = q[1:-1] q = q[1:-1]
# Remove newlines and carriage returns # Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip() q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping # Apply Lucene escaping
q = escape_lucene_query(q) q = escape_lucene_query(q)
return q return q
async def execute_hybrid_search( async def execute_hybrid_search(
self, self,
end_user_id: str, end_user_id: str,
question: str, question: str,
limit: int = 5, limit: int = 5,
search_type: str = "hybrid", search_type: str = "hybrid",
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
rerank_alpha: float = 0.4, rerank_alpha: float = 0.4,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config = None, memory_config=None,
expand_communities: bool = True, expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
@@ -205,10 +205,10 @@ class SearchService:
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries", "communities"] include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
try: try:
# Execute search # Execute search
answer = await run_hybrid_search( answer = await run_hybrid_search(
@@ -221,18 +221,18 @@ class SearchService:
memory_config=memory_config, memory_config=memory_config,
rerank_alpha=rerank_alpha rerank_alpha=rerank_alpha
) )
# Extract results based on search type and include parameter # Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information # Prioritize summaries as they contain synthesized contextual information
answer_list = [] answer_list = []
# For hybrid search, use reranked_results # For hybrid search, use reranked_results
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
category_results = reranked_results[category] category_results = reranked_results[category]
@@ -242,7 +242,7 @@ class SearchService:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
category_results = answer[category] category_results = answer[category]
@@ -261,7 +261,7 @@ class SearchService:
end_user_id=end_user_id, end_user_id=end_user_id,
) )
answer_list.extend(cleaned_stmts) answer_list.extend(cleaned_stmts)
# Extract clean content from all results按类型传入 node_type 区分 community # Extract clean content from all results按类型传入 node_type 区分 community
content_list = [] content_list = []
for ans in answer_list: for ans in answer_list:
@@ -269,19 +269,18 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype)) content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c]) clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars # Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...") logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested # Return raw results if requested
if return_raw_results: if return_raw_results:
return clean_content, cleaned_query, answer return clean_content, cleaned_query, answer
else: else:
return clean_content, cleaned_query, None return clean_content, cleaned_query, None
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Search failed for query '{question}' in group '{end_user_id}': {e}", f"Search failed for query '{question}' in group '{end_user_id}': {e}",

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from typing import Annotated, TypedDict from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象 memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict retrieve: dict
perceptual_data: dict
RetrieveSummary: dict RetrieveSummary: dict
InputSummary: dict InputSummary: dict
verify: dict verify: dict

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__) logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]: def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'.""" """Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None: if value is None:
@@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if score_field == "activation_value" and score is None: if score_field == "activation_value" and score is None:
scores.append(None) # 保持 None稍后特殊处理 scores.append(None) # 保持 None稍后特殊处理
continue continue
if score is not None and isinstance(score, (int, float)): if score is not None and isinstance(score, (int, float)):
scores.append(float(score)) scores.append(float(score))
else: else:
@@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores: if not scores:
return results return results
# 过滤掉 None 值,只对有效分数进行归一化 # 过滤掉 None 值,只对有效分数进行归一化
valid_scores = [s for s in scores if s is not None] valid_scores = [s for s in scores if s is not None]
if not valid_scores: if not valid_scores:
# 所有分数都是 None不进行归一化 # 所有分数都是 None不进行归一化
for item in results: for item in results:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None item[f"normalized_{score_field}"] = None
return results return results
if len(valid_scores) == 1: # Single valid score, set to 1.0 if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores): for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value": if score_field in item or score_field == "activation_value":
if score is None: if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Remove duplicate items from search results based on content. Remove duplicate items from search results based on content.
@@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
seen_ids = set() seen_ids = set()
seen_content = set() seen_content = set()
deduplicated = [] deduplicated = []
for item in items: for item in items:
# Try multiple ID fields to identify unique items # Try multiple ID fields to identify unique items
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# Extract content from various possible fields # Extract content from various possible fields
content = ( content = (
item.get("text") or item.get("text") or
item.get("content") or item.get("content") or
item.get("statement") or item.get("statement") or
item.get("name") or item.get("name") or
"" ""
) )
# Normalize content for comparison (strip whitespace and lowercase) # Normalize content for comparison (strip whitespace and lowercase)
normalized_content = str(content).strip().lower() if content else "" normalized_content = str(content).strip().lower() if content else ""
# Check if we've seen this ID or content before # Check if we've seen this ID or content before
is_duplicate = False is_duplicate = False
if item_id and item_id in seen_ids: if item_id and item_id in seen_ids:
is_duplicate = True is_duplicate = True
elif normalized_content and normalized_content in seen_content: elif normalized_content and normalized_content in seen_content:
# Only check content duplication if content is not empty # Only check content duplication if content is not empty
is_duplicate = True is_duplicate = True
if not is_duplicate: if not is_duplicate:
# Mark as seen # Mark as seen
if item_id: if item_id:
seen_ids.add(item_id) seen_ids.add(item_id)
if normalized_content: # Only track non-empty content if normalized_content: # Only track non-empty content
seen_content.add(normalized_content) seen_content.add(normalized_content)
deduplicated.append(item) deduplicated.append(item)
return deduplicated return deduplicated
def rerank_with_activation( def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]], keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]], embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6, alpha: float = 0.6,
limit: int = 10, limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None, forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
now: datetime | None = None, now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
两阶段排序:先按内容相关性筛选,再按激活值排序。 两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用) forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8) activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算) now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回: 返回:
带评分元数据的重排序结果,按 final_score 排序 带评分元数据的重排序结果,按 final_score 排序
@@ -229,26 +232,26 @@ def rerank_with_activation(
# 验证权重范围 # 验证权重范围
if not (0 <= alpha <= 1): if not (0 <= alpha <= 1):
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}") raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
# 初始化遗忘引擎(如果需要) # 初始化遗忘引擎(如果需要)
engine = None engine = None
if forgetting_config: if forgetting_config:
engine = ForgettingEngine(forgetting_config) engine = ForgettingEngine(forgetting_config)
now_dt = now or datetime.now() now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries", "communities"]: for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
# 步骤 1: 归一化分数 # 步骤 1: 归一化分数
keyword_items = normalize_scores(keyword_items, "score") keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score") embedding_items = normalize_scores(embedding_items, "score")
# 步骤 2: 按 ID 合并结果(去重) # 步骤 2: 按 ID 合并结果(去重)
combined_items: Dict[str, Dict[str, Any]] = {} combined_items: Dict[str, Dict[str, Any]] = {}
# 添加关键词结果 # 添加关键词结果
for item in keyword_items: for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -257,7 +260,7 @@ def rerank_with_activation(
combined_items[item_id] = item.copy() combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # 默认值 combined_items[item_id]["embedding_score"] = 0 # 默认值
# 添加或更新向量嵌入结果 # 添加或更新向量嵌入结果
for item in embedding_items: for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -271,18 +274,18 @@ def rerank_with_activation(
combined_items[item_id] = item.copy() combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # 默认值 combined_items[item_id]["bm25_score"] = 0 # 默认值
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# 步骤 3: 归一化激活度分数 # 步骤 3: 归一化激活度分数
# 为所有项准备激活度值列表 # 为所有项准备激活度值列表
items_list = list(combined_items.values()) items_list = list(combined_items.values())
items_list = normalize_scores(items_list, "activation_value") items_list = normalize_scores(items_list, "activation_value")
# 更新 combined_items 中的归一化激活度分数 # 更新 combined_items 中的归一化激活度分数
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
@@ -290,45 +293,45 @@ def rerank_with_activation(
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 # normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value") raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序) # 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm # 可能为 None item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score
# 步骤 5: 应用遗忘曲线(可选) # 步骤 5: 应用遗忘曲线(可选)
if engine: if engine:
# 计算受激活度影响的记忆强度 # 计算受激活度影响的记忆强度
importance = float(item.get("importance_score", 0.5) or 0.5) importance = float(item.get("importance_score", 0.5) or 0.5)
# 获取 activation_value # 获取 activation_value
activation_val = item.get("activation_value") activation_val = item.get("activation_value")
# 只对有激活值的节点应用遗忘曲线 # 只对有激活值的节点应用遗忘曲线
if activation_val is not None and isinstance(activation_val, (int, float)): if activation_val is not None and isinstance(activation_val, (int, float)):
activation_val = float(activation_val) activation_val = float(activation_val)
# 计算记忆强度importance_score × (1 + activation_value × boost_factor) # 计算记忆强度importance_score × (1 + activation_value × boost_factor)
memory_strength = importance * (1 + activation_val * activation_boost_factor) memory_strength = importance * (1 + activation_val * activation_boost_factor)
# 计算经过的时间(天数) # 计算经过的时间(天数)
dt = _parse_datetime(item.get("created_at")) dt = _parse_datetime(item.get("created_at"))
if dt is None: if dt is None:
time_elapsed_days = 0.0 time_elapsed_days = 0.0
else: else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# 获取遗忘权重 # 获取遗忘权重
forgetting_weight = engine.calculate_weight( forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days, time_elapsed=time_elapsed_days,
memory_strength=memory_strength memory_strength=memory_strength
) )
# 应用到基础分数 # 应用到基础分数
item["forgetting_weight"] = forgetting_weight item["forgetting_weight"] = forgetting_weight
item["final_score"] = base_score * forgetting_weight item["final_score"] = base_score * forgetting_weight
@@ -338,7 +341,7 @@ def rerank_with_activation(
else: else:
# 不使用遗忘曲线 # 不使用遗忘曲线
item["final_score"] = base_score item["final_score"] = base_score
# 步骤 6: 两阶段排序和限制 # 步骤 6: 两阶段排序和限制
# 第一阶段按内容相关性base_score排序取 Top-K # 第一阶段按内容相关性base_score排序取 Top-K
first_stage_limit = limit * 3 # 可配置取3倍候选 first_stage_limit = limit * 3 # 可配置取3倍候选
@@ -347,11 +350,11 @@ def rerank_with_activation(
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序 key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
reverse=True reverse=True
)[:first_stage_limit] )[:first_stage_limit]
# 第二阶段:分离有激活值和无激活值的节点 # 第二阶段:分离有激活值和无激活值的节点
items_with_activation = [] items_with_activation = []
items_without_activation = [] items_without_activation = []
for item in first_stage_sorted: for item in first_stage_sorted:
activation_score = item.get("activation_score") activation_score = item.get("activation_score")
# 检查是否有有效的激活值(不是 None # 检查是否有有效的激活值(不是 None
@@ -359,14 +362,14 @@ def rerank_with_activation(
items_with_activation.append(item) items_with_activation.append(item)
else: else:
items_without_activation.append(item) items_without_activation.append(item)
# 优先按激活值排序有激活值的节点 # 优先按激活值排序有激活值的节点
sorted_with_activation = sorted( sorted_with_activation = sorted(
items_with_activation, items_with_activation,
key=lambda x: float(x.get("activation_score", 0) or 0), key=lambda x: float(x.get("activation_score", 0) or 0),
reverse=True reverse=True
) )
# 如果有激活值的节点不足 limit用无激活值的节点补充 # 如果有激活值的节点不足 limit用无激活值的节点补充
if len(sorted_with_activation) < limit: if len(sorted_with_activation) < limit:
needed = limit - len(sorted_with_activation) needed = limit - len(sorted_with_activation)
@@ -374,7 +377,7 @@ def rerank_with_activation(
sorted_items = sorted_with_activation + items_without_activation[:needed] sorted_items = sorted_with_activation + items_without_activation[:needed]
else: else:
sorted_items = sorted_with_activation[:limit] sorted_items = sorted_with_activation[:limit]
# 两阶段排序完成,更新 final_score 以反映实际排序依据 # 两阶段排序完成,更新 final_score 以反映实际排序依据
# Stage 1: 按 content_score 筛选候选(已完成) # Stage 1: 按 content_score 筛选候选(已完成)
# Stage 2: 按 activation_score 排序(已完成) # Stage 2: 按 activation_score 排序(已完成)
@@ -390,16 +393,29 @@ def rerank_with_activation(
else: else:
# 无激活值:使用内容相关性分数 # 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0) item["final_score"] = item.get("base_score", 0)
# 最终去重确保没有重复项 if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items) sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items reranked[category] = sorted_items
return reranked return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger. """Log search query information using the logger.
Args: Args:
@@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
""" """
# Ensure the query text is plain and clean before logging # Ensure the query text is plain and clean before logging
cleaned_query = extract_plain_query(query_text) cleaned_query = extract_plain_query(query_text)
# Log using the standard logger # Log using the standard logger
logger.info( logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, " f"Search query: query='{cleaned_query}', type={search_type}, "
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder( def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
query_text: str, query_text: str,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Placeholder for a cross-encoder reranker. Placeholder for a cross-encoder reranker.
@@ -483,7 +499,7 @@ def apply_reranker_placeholder(
# ) -> Dict[str, List[Dict[str, Any]]]: # ) -> Dict[str, List[Dict[str, Any]]]:
# """ # """
# Apply LLM-based reranking to search results. # Apply LLM-based reranking to search results.
# Args: # Args:
# results: Search results organized by category # results: Search results organized by category
# query_text: Original search query # query_text: Original search query
@@ -491,7 +507,7 @@ def apply_reranker_placeholder(
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM) # llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
# top_k: Maximum number of items to rerank per category # top_k: Maximum number of items to rerank per category
# batch_size: Number of items to process concurrently # batch_size: Number of items to process concurrently
# Returns: # Returns:
# Reranked results with final_score and reranker_model fields # Reranked results with final_score and reranker_model fields
# """ # """
@@ -501,18 +517,18 @@ def apply_reranker_placeholder(
# # except Exception as e: # # except Exception as e:
# # logger.debug(f"Failed to load reranker config: {e}") # # logger.debug(f"Failed to load reranker config: {e}")
# # rc = {} # # rc = {}
# # Check if reranking is enabled # # Check if reranking is enabled
# enabled = rc.get("enabled", False) # enabled = rc.get("enabled", False)
# if not enabled: # if not enabled:
# logger.debug("LLM reranking is disabled in configuration") # logger.debug("LLM reranking is disabled in configuration")
# return results # return results
# # Load configuration parameters with defaults # # Load configuration parameters with defaults
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5) # llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
# top_k = top_k if top_k is not None else rc.get("top_k", 20) # top_k = top_k if top_k is not None else rc.get("top_k", 20)
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5) # batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# # Initialize reranker client if not provided # # Initialize reranker client if not provided
# if reranker_client is None: # if reranker_client is None:
# try: # try:
@@ -520,10 +536,10 @@ def apply_reranker_placeholder(
# except Exception as e: # except Exception as e:
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking") # logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
# return results # return results
# # Get model name for metadata # # Get model name for metadata
# model_name = getattr(reranker_client, 'model_name', 'unknown') # model_name = getattr(reranker_client, 'model_name', 'unknown')
# # Process each category # # Process each category
# reranked_results = {} # reranked_results = {}
# for category in ["statements", "chunks", "entities", "summaries"]: # for category in ["statements", "chunks", "entities", "summaries"]:
@@ -531,38 +547,38 @@ def apply_reranker_placeholder(
# if not items: # if not items:
# reranked_results[category] = [] # reranked_results[category] = []
# continue # continue
# # Select top K items by combined_score for reranking # # Select top K items by combined_score for reranking
# sorted_items = sorted( # sorted_items = sorted(
# items, # items,
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0), # key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
# reverse=True # reverse=True
# ) # )
# top_items = sorted_items[:top_k] # top_items = sorted_items[:top_k]
# remaining_items = sorted_items[top_k:] # remaining_items = sorted_items[top_k:]
# # Extract text content from each item # # Extract text content from each item
# def extract_text(item: Dict[str, Any]) -> str: # def extract_text(item: Dict[str, Any]) -> str:
# """Extract text content from a result item.""" # """Extract text content from a result item."""
# # Try different text fields based on category # # Try different text fields based on category
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or "" # text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
# return str(text).strip() # return str(text).strip()
# # Batch items for concurrent processing # # Batch items for concurrent processing
# batches = [] # batches = []
# for i in range(0, len(top_items), batch_size): # for i in range(0, len(top_items), batch_size):
# batch = top_items[i:i + batch_size] # batch = top_items[i:i + batch_size]
# batches.append(batch) # batches.append(batch)
# # Process batches concurrently # # Process batches concurrently
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]: # async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# """Process a batch of items with LLM relevance scoring.""" # """Process a batch of items with LLM relevance scoring."""
# scored_batch = [] # scored_batch = []
# for item in batch: # for item in batch:
# item_text = extract_text(item) # item_text = extract_text(item)
# # Skip items with no text # # Skip items with no text
# if not item_text: # if not item_text:
# item_copy = item.copy() # item_copy = item.copy()
@@ -572,7 +588,7 @@ def apply_reranker_placeholder(
# item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy) # scored_batch.append(item_copy)
# continue # continue
# # Create relevance scoring prompt # # Create relevance scoring prompt
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0. # prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
@@ -585,15 +601,15 @@ def apply_reranker_placeholder(
# - 1.0 means perfectly relevant # - 1.0 means perfectly relevant
# Relevance score:""" # Relevance score:"""
# # Send request to LLM # # Send request to LLM
# try: # try:
# messages = [{"role": "user", "content": prompt}] # messages = [{"role": "user", "content": prompt}]
# response = await reranker_client.chat(messages) # response = await reranker_client.chat(messages)
# # Parse LLM response to extract relevance score # # Parse LLM response to extract relevance score
# response_text = str(response.content if hasattr(response, 'content') else response).strip() # response_text = str(response.content if hasattr(response, 'content') else response).strip()
# # Try to extract a float from the response # # Try to extract a float from the response
# try: # try:
# # Remove any non-numeric characters except decimal point # # Remove any non-numeric characters except decimal point
@@ -608,11 +624,11 @@ def apply_reranker_placeholder(
# except (ValueError, AttributeError) as e: # except (ValueError, AttributeError) as e:
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}") # logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
# llm_score = None # llm_score = None
# # Calculate final score # # Calculate final score
# item_copy = item.copy() # item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0) # combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# if llm_score is not None: # if llm_score is not None:
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score # final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
# item_copy["llm_relevance_score"] = llm_score # item_copy["llm_relevance_score"] = llm_score
@@ -620,7 +636,7 @@ def apply_reranker_placeholder(
# # Use combined_score as fallback # # Use combined_score as fallback
# final_score = combined_score # final_score = combined_score
# item_copy["llm_relevance_score"] = combined_score # item_copy["llm_relevance_score"] = combined_score
# item_copy["final_score"] = final_score # item_copy["final_score"] = final_score
# item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy) # scored_batch.append(item_copy)
@@ -632,14 +648,14 @@ def apply_reranker_placeholder(
# item_copy["llm_relevance_score"] = combined_score # item_copy["llm_relevance_score"] = combined_score
# item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy) # scored_batch.append(item_copy)
# return scored_batch # return scored_batch
# # Process all batches concurrently # # Process all batches concurrently
# try: # try:
# batch_tasks = [process_batch(batch) for batch in batches] # batch_tasks = [process_batch(batch) for batch in batches]
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True) # batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# # Merge batch results # # Merge batch results
# scored_items = [] # scored_items = []
# for result in batch_results: # for result in batch_results:
@@ -647,7 +663,7 @@ def apply_reranker_placeholder(
# logger.warning(f"Batch processing failed: {result}") # logger.warning(f"Batch processing failed: {result}")
# continue # continue
# scored_items.extend(result) # scored_items.extend(result)
# # Add remaining items (not in top K) with their combined_score as final_score # # Add remaining items (not in top K) with their combined_score as final_score
# for item in remaining_items: # for item in remaining_items:
# item_copy = item.copy() # item_copy = item.copy()
@@ -655,11 +671,11 @@ def apply_reranker_placeholder(
# item_copy["final_score"] = combined_score # item_copy["final_score"] = combined_score
# item_copy["reranker_model"] = model_name # item_copy["reranker_model"] = model_name
# scored_items.append(item_copy) # scored_items.append(item_copy)
# # Sort all items by final_score in descending order # # Sort all items by final_score in descending order
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True) # scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
# reranked_results[category] = scored_items # reranked_results[category] = scored_items
# except Exception as e: # except Exception as e:
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results") # logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# # Return original items with combined_score as final_score # # Return original items with combined_score as final_score
@@ -668,22 +684,22 @@ def apply_reranker_placeholder(
# item["final_score"] = combined_score # item["final_score"] = combined_score
# item["reranker_model"] = model_name # item["reranker_model"] = model_name
# reranked_results[category] = items # reranked_results[category] = items
# return reranked_results # return reranked_results
async def run_hybrid_search( async def run_hybrid_search(
query_text: str, query_text: str,
search_type: str, search_type: str,
end_user_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[str],
output_path: str | None, output_path: str | None,
memory_config: "MemoryConfig", memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
): ):
""" """
@@ -699,7 +715,7 @@ async def run_hybrid_search(
# Clean and normalize the incoming query before use/logging # Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text) query_text = extract_plain_query(query_text)
# Validate query is not empty after cleaning # Validate query is not empty after cleaning
if not query_text or not query_text.strip(): if not query_text or not query_text.strip():
logger.warning("Empty query after cleaning, returning empty results") logger.warning("Empty query after cleaning, returning empty results")
@@ -716,7 +732,7 @@ async def run_hybrid_search(
"error": "Empty query" "error": "Empty query"
} }
} }
# Log the search query # Log the search query
log_search_query(query_text, search_type, end_user_id, limit, include) log_search_query(query_text, search_type, end_user_id, limit, include)
@@ -747,7 +763,7 @@ async def run_hybrid_search(
# Embedding-based search # Embedding-based search
logger.info("[PERF] Starting embedding search...") logger.info("[PERF] Starting embedding search...")
embedding_start = time.time() embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig # 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time() config_load_start = time.time()
try: try:
@@ -768,7 +784,7 @@ async def run_hybrid_search(
embedder = OpenAIEmbedderClient(model_config=rb_config) embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task( embedding_task = asyncio.create_task(
search_graph_by_embedding( search_graph_by_embedding(
connector=connector, connector=connector,
@@ -788,7 +804,7 @@ async def run_hybrid_search(
if keyword_task: if keyword_task:
keyword_results = await keyword_task keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword": if search_type == "keyword":
@@ -798,7 +814,7 @@ async def run_hybrid_search(
if embedding_task: if embedding_task:
embedding_results = await embedding_task embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding": if search_type == "embedding":
@@ -810,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid": if search_type == "hybrid":
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat() "search_timestamp": datetime.now().isoformat()
} }
@@ -818,7 +835,7 @@ async def run_hybrid_search(
# Apply two-stage reranking with ACTR activation calculation # Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time() rerank_start = time.time()
logger.info("[PERF] Using two-stage reranking with ACTR activation") logger.info("[PERF] Using two-stage reranking with ACTR activation")
# 加载遗忘引擎配置 # 加载遗忘引擎配置
config_start = time.time() config_start = time.time()
try: try:
@@ -829,7 +846,7 @@ async def run_hybrid_search(
forgetting_cfg = ForgettingEngineConfig() forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s") logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# 统一使用激活度重排序(两阶段:检索 + ACTR计算 # 统一使用激活度重排序(两阶段:检索 + ACTR计算
rerank_compute_start = time.time() rerank_compute_start = time.time()
reranked_results = rerank_with_activation( reranked_results = rerank_with_activation(
@@ -842,14 +859,14 @@ async def run_hybrid_search(
) )
rerank_compute_time = time.time() - rerank_compute_start rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s") logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
rerank_latency = time.time() - rerank_start rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4) latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s") logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config # Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text) reranked_results = apply_reranker_placeholder(reranked_results, query_text)
# Apply LLM reranking if enabled # Apply LLM reranking if enabled
llm_rerank_applied = False llm_rerank_applied = False
# if use_llm_rerank: # if use_llm_rerank:
@@ -862,11 +879,12 @@ async def run_hybrid_search(
# logger.info("LLM reranking applied successfully") # logger.info("LLM reranking applied successfully")
# except Exception as e: # except Exception as e:
# logger.warning(f"LLM reranking failed: {e}, using previous scores") # logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results results["reranked_results"] = reranked_results
results["combined_summary"] = { results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()), "total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()), "total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()), "total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat(), "search_timestamp": datetime.now().isoformat(),
@@ -879,13 +897,13 @@ async def run_hybrid_search(
# Calculate total latency # Calculate total latency
total_latency = time.time() - search_start_time total_latency = time.time() - search_start_time
latency_metrics["total_latency"] = round(total_latency, 4) latency_metrics["total_latency"] = round(total_latency, 4)
# Add latency metrics to results # Add latency metrics to results
if "combined_summary" in results: if "combined_summary" in results:
results["combined_summary"]["latency_metrics"] = latency_metrics results["combined_summary"]["latency_metrics"] = latency_metrics
else: else:
results["latency_metrics"] = latency_metrics results["latency_metrics"] = latency_metrics
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s") logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}") logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
@@ -908,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count # Log search completion with result count
if search_type == "hybrid": if search_type == "hybrid":
result_counts = { result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()}, "keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()} keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
} }
else: else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()} result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal( async def search_by_temporal(
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal( async def search_by_keyword_temporal(
query_text: str, query_text: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 1, limit: int = 1,
): ):
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id( async def search_chunk_by_chunk_id(
chunk_id: str, chunk_id: str,
end_user_id: Optional[str] = "test", end_user_id: Optional[str] = "test",
limit: int = 1, limit: int = 1,
): ):
""" """
Search for Chunks by chunk_id. Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit limit=limit
) )
return {"chunks": chunks} return {"chunks": chunks}

View File

@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes() await create_all_indexes()
logger.info("All neo4j indexes and constraints created successfully!")
logger.info("应用程序启动完成") logger.info("应用程序启动完成")

View File

@@ -1,17 +1,17 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes(): async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring.""" """Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# 创建 Statements 索引 # 创建 Statements 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# # 创建 Dialogues 索引 # # 创建 Dialogues 索引
# await connector.execute_query(""" # await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
@@ -21,27 +21,35 @@ async def create_fulltext_indexes():
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 Chunks 索引 # 创建 Chunks 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 MemorySummary 索引 # 创建 MemorySummary 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 Community 索引 # 创建 Community 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
# 创建 Perceptual 感知记忆索引
await connector.execute_query("""
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""")
finally: finally:
await connector.close() await connector.close()
async def create_vector_indexes(): async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search. """Create vector indexes for fast embedding similarity search.
@@ -50,8 +58,7 @@ async def create_vector_indexes():
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# Statement embedding index # Statement embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
@@ -62,8 +69,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
# Chunk embedding index # Chunk embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
@@ -75,7 +81,6 @@ async def create_vector_indexes():
}} }}
""") """)
# Entity name embedding index # Entity name embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
@@ -86,8 +91,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
# Memory summary embedding index # Memory summary embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
@@ -98,7 +102,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
# Community summary embedding index # Community summary embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
@@ -108,8 +112,8 @@ async def create_vector_indexes():
`vector.dimensions`: 1024, `vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
# Dialogue embedding index (optional) # Dialogue embedding index (optional)
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
@@ -120,15 +124,27 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
# Perceptual summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
FOR (p:Perceptual)
ON p.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
finally: finally:
await connector.close() await connector.close()
async def create_unique_constraints(): async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers. """Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates. Ensures concurrent MERGE operations remain safe and prevents duplicates.
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
# Dialogue.id unique # Dialogue.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -136,7 +152,7 @@ async def create_unique_constraints():
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
""" """
) )
# Statement.id unique # Statement.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -144,7 +160,7 @@ async def create_unique_constraints():
FOR (s:Statement) REQUIRE s.id IS UNIQUE FOR (s:Statement) REQUIRE s.id IS UNIQUE
""" """
) )
# Chunk.id unique # Chunk.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -152,13 +168,13 @@ async def create_unique_constraints():
FOR (c:Chunk) REQUIRE c.id IS UNIQUE FOR (c:Chunk) REQUIRE c.id IS UNIQUE
""" """
) )
finally: finally:
await connector.close() await connector.close()
async def create_all_indexes(): async def create_all_indexes():
"""Create all indexes and constraints in one go.""" """Create all indexes and constraints in one go."""
await create_fulltext_indexes() await create_fulltext_indexes()
await create_vector_indexes() await create_vector_indexes()
await create_unique_constraints() await create_unique_constraints()
print("✓ All indexes and constraints created successfully!")

View File

@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
r.created_at = edge.created_at r.created_at = edge.created_at
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
SEARCH_PERCEPTUAL_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""
PERCEPTUAL_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
YIELD node AS p, score
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
RETURN p.id AS id,
p.end_user_id AS end_user_id,
p.perceptual_type AS perceptual_type,
p.file_path AS file_path,
p.file_name AS file_name,
p.file_ext AS file_ext,
p.summary AS summary,
p.keywords AS keywords,
p.topic AS topic,
p.domain AS domain,
p.created_at AS created_at,
p.file_type AS file_type,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS, EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
PERCEPTUAL_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
async def _update_activation_values_batch( async def _update_activation_values_batch(
connector: Neo4jConnector, connector: Neo4jConnector,
nodes: List[Dict[str, Any]], nodes: List[Dict[str, Any]],
node_label: str, node_label: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
max_retries: int = 3 max_retries: int = 3
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
批量更新节点的激活值 批量更新节点的激活值
@@ -58,7 +60,7 @@ async def _update_activation_values_batch(
""" """
if not nodes: if not nodes:
return [] return []
# 延迟导入以避免循环依赖 # 延迟导入以避免循环依赖
from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager, AccessHistoryManager,
@@ -66,7 +68,7 @@ async def _update_activation_values_batch(
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator, ACTRCalculator,
) )
# 创建计算器和管理器实例 # 创建计算器和管理器实例
actr_calculator = ACTRCalculator() actr_calculator = ACTRCalculator()
access_manager = AccessHistoryManager( access_manager = AccessHistoryManager(
@@ -74,7 +76,7 @@ async def _update_activation_values_batch(
actr_calculator=actr_calculator, actr_calculator=actr_calculator,
max_retries=max_retries max_retries=max_retries
) )
# 提取节点ID列表并去重保持原始顺序 # 提取节点ID列表并去重保持原始顺序
seen_ids = set() seen_ids = set()
unique_node_ids = [] unique_node_ids = []
@@ -83,7 +85,7 @@ async def _update_activation_values_batch(
if node_id and node_id not in seen_ids: if node_id and node_id not in seen_ids:
seen_ids.add(node_id) seen_ids.add(node_id)
unique_node_ids.append(node_id) unique_node_ids.append(node_id)
if not unique_node_ids: if not unique_node_ids:
logger.warning(f"批量更新激活值没有有效的节点ID") logger.warning(f"批量更新激活值没有有效的节点ID")
return nodes return nodes
@@ -95,7 +97,7 @@ async def _update_activation_values_batch(
f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, " f"批量更新激活值检测到重复节点具有有效ID的节点数量={id_nodes_count}, "
f"去重后唯一ID数量={len(unique_node_ids)}" f"去重后唯一ID数量={len(unique_node_ids)}"
) )
# 批量记录访问 # 批量记录访问
try: try:
updated_nodes = await access_manager.record_batch_access( updated_nodes = await access_manager.record_batch_access(
@@ -103,14 +105,14 @@ async def _update_activation_values_batch(
node_label=node_label, node_label=node_label,
end_user_id=end_user_id end_user_id=end_user_id
) )
logger.info( logger.info(
f"批量更新激活值成功: {node_label}, " f"批量更新激活值成功: {node_label}, "
f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}" f"更新数量={len(updated_nodes)}/{len(unique_node_ids)}"
) )
return updated_nodes return updated_nodes
except Exception as e: except Exception as e:
logger.error( logger.error(
f"批量更新激活值失败: {node_label}, 错误: {str(e)}" f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
async def _update_search_results_activation( async def _update_search_results_activation(
connector: Neo4jConnector, connector: Neo4jConnector,
results: Dict[str, List[Dict[str, Any]]], results: Dict[str, List[Dict[str, Any]]],
end_user_id: Optional[str] = None end_user_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
更新搜索结果中所有知识节点的激活值 更新搜索结果中所有知识节点的激活值
@@ -144,11 +146,11 @@ async def _update_search_results_activation(
'entities': 'ExtractedEntity', 'entities': 'ExtractedEntity',
'summaries': 'MemorySummary' 'summaries': 'MemorySummary'
} }
# 并行更新所有类型的节点 # 并行更新所有类型的节点
update_tasks = [] update_tasks = []
update_keys = [] update_keys = []
for key, label in knowledge_node_types.items(): for key, label in knowledge_node_types.items():
if key in results and results[key]: if key in results and results[key]:
update_tasks.append( update_tasks.append(
@@ -160,13 +162,13 @@ async def _update_search_results_activation(
) )
) )
update_keys.append(key) update_keys.append(key)
if not update_tasks: if not update_tasks:
return results return results
# 并行执行所有更新 # 并行执行所有更新
update_results = await asyncio.gather(*update_tasks, return_exceptions=True) update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
# 更新结果字典,保留原始搜索分数 # 更新结果字典,保留原始搜索分数
updated_results = results.copy() updated_results = results.copy()
for key, update_result in zip(update_keys, update_results): for key, update_result in zip(update_keys, update_results):
@@ -175,10 +177,10 @@ async def _update_search_results_activation(
# 保留原始的 score 字段BM25/Embedding 分数) # 保留原始的 score 字段BM25/Embedding 分数)
original_nodes = results[key] original_nodes = results[key]
updated_nodes = update_result updated_nodes = update_result
# 创建 ID 到更新节点的映射(用于快速查找激活值数据) # 创建 ID 到更新节点的映射(用于快速查找激活值数据)
updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')} updated_map = {node.get('id'): node for node in updated_nodes if node.get('id')}
# 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充 # 合并数据:保留所有原始节点(包括重复的),用更新后的激活值数据填充
merged_nodes = [] merged_nodes = []
for original_node in original_nodes: for original_node in original_nodes:
@@ -186,7 +188,7 @@ async def _update_search_results_activation(
if node_id and node_id in updated_map: if node_id and node_id in updated_map:
# 从原始节点开始,用更新后的激活值数据覆盖 # 从原始节点开始,用更新后的激活值数据覆盖
merged_node = original_node.copy() merged_node = original_node.copy()
# 更新激活值相关字段 # 更新激活值相关字段
activation_fields = { activation_fields = {
'activation_value', 'activation_value',
@@ -196,35 +198,35 @@ async def _update_search_results_activation(
'importance_score', 'importance_score',
'version', 'version',
'statement', # Statement 节点的内容字段 'statement', # Statement 节点的内容字段
'content' # MemorySummary 节点的内容字段 'content' # MemorySummary 节点的内容字段
} }
# 只更新激活值相关字段,保留原始节点的其他字段 # 只更新激活值相关字段,保留原始节点的其他字段
for field in activation_fields: for field in activation_fields:
if field in updated_map[node_id]: if field in updated_map[node_id]:
merged_node[field] = updated_map[node_id][field] merged_node[field] = updated_map[node_id][field]
merged_nodes.append(merged_node) merged_nodes.append(merged_node)
else: else:
# 如果没有更新数据,保留原始节点 # 如果没有更新数据,保留原始节点
merged_nodes.append(original_node) merged_nodes.append(original_node)
updated_results[key] = merged_nodes updated_results[key] = merged_nodes
else: else:
# 更新失败,记录错误但保留原始结果 # 更新失败,记录错误但保留原始结果
logger.warning( logger.warning(
f"更新 {key} 激活值失败: {str(update_result)}" f"更新 {key} 激活值失败: {str(update_result)}"
) )
return updated_results return updated_results
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
q: str, q: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[str] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Search across Statements, Entities, Chunks, and Summaries using a free-text query. Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -249,41 +251,45 @@ async def search_graph(
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries"]
# Prepare tasks for parallel execution # Prepare tasks for parallel execution
tasks = [] tasks = []
task_keys = [] task_keys = []
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD, SEARCH_STATEMENTS_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("statements") task_keys.append("statements")
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS, SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("entities") task_keys.append("entities")
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT, SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("chunks") task_keys.append("chunks")
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -293,15 +299,16 @@ async def search_graph(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD, SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
q=q, q=q,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
)) ))
task_keys.append("communities") task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Build results dictionary # Build results dictionary
results = {} results = {}
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
@@ -310,14 +317,14 @@ async def search_graph(
results[key] = [] results[key] = []
else: else:
results[key] = result results[key] = result
# Deduplicate results before updating activation values # Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline # This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results from app.core.memory.src.search import _deduplicate_results
for key in results: for key in results:
if isinstance(results[key], list): if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key]) results[key] = _deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization) # Skip activation updates if only searching summaries (optimization)
needs_activation_update = any( needs_activation_update = any(
@@ -331,17 +338,17 @@ async def search_graph(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_by_embedding( async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = ["statements", "chunks", "entities","summaries"], include: List[str] = ["statements", "chunks", "entities", "summaries"],
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Embedding-based semantic search across Statements, Chunks, and Entities. Embedding-based semantic search across Statements, Chunks, and Entities.
@@ -355,13 +362,13 @@ async def search_graph_by_embedding(
- Returns up to 'limit' per included type - Returns up to 'limit' per included type
""" """
import time import time
# Get embedding for the query # Get embedding for the query
embed_start = time.time() embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start embed_time = time.time() - embed_start
print(f"[PERF] Embedding generation took: {embed_time:.4f}s") logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
logger.warning( logger.warning(
f"search_graph_by_embedding: embedding 生成失败或为空," f"search_graph_by_embedding: embedding 生成失败或为空,"
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
if "statements" in include: if "statements" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
STATEMENT_EMBEDDING_SEARCH, STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
if "chunks" in include: if "chunks" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH, CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
if "entities" in include: if "entities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH, ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
if "summaries" in include: if "summaries" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH, MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
if "communities" in include: if "communities" in include:
tasks.append(connector.execute_query( tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH, COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding, embedding=embedding,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
@@ -428,8 +440,8 @@ async def search_graph_by_embedding(
query_start = time.time() query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start query_time = time.time() - query_start
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {
"statements": [], "statements": [],
@@ -438,7 +450,7 @@ async def search_graph_by_embedding(
"summaries": [], "summaries": [],
"communities": [], "communities": [],
} }
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}") logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
logger.info(f"[PERF] Skipping activation updates (only summaries)") logger.info(f"[PERF] Skipping activation updates (only summaries)")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: str, end_user_id: str,
entities: List[Dict[str, Any]], entities: List[Dict[str, Any]],
use_contains_fallback: bool = True, use_contains_fallback: bool = True,
batch_size: int = 500, batch_size: int = 500,
max_concurrency: int = 5, max_concurrency: int = 5,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
async def search_graph_by_keyword_temporal( async def search_graph_by_keyword_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 50, limit: int = 50,
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
- Returns up to 'limit' statements - Returns up to 'limit' statements
""" """
if not query_text: if not query_text:
print(f"query_text不能为空") logger.warning(f"query_text不能为空")
return {"statements": []} return {"statements": []}
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
invalid_date=invalid_date, invalid_date=invalid_date,
limit=limit, limit=limit,
) )
print(f"查询结果为:\n{statements}") logger.debug(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
async def search_graph_by_temporal( async def search_graph_by_temporal(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
start_date: Optional[str] = None, start_date: Optional[str] = None,
end_date: Optional[str] = None, end_date: Optional[str] = None,
valid_date: Optional[str] = None, valid_date: Optional[str] = None,
invalid_date: Optional[str] = None, invalid_date: Optional[str] = None,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -643,15 +653,15 @@ async def search_graph_by_temporal(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
connector: Neo4jConnector, connector: Neo4jConnector,
dialog_id: str, dialog_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Dialogues. Temporal search across Dialogues.
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
- Returns up to 'limit' dialogues - Returns up to 'limit' dialogues
""" """
if not dialog_id: if not dialog_id:
print(f"dialog_id不能为空") logger.warning(f"dialog_id不能为空")
return {"dialogues": []} return {"dialogues": []}
dialogues = await connector.execute_query( dialogues = await connector.execute_query(
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
async def search_graph_by_chunk_id( async def search_graph_by_chunk_id(
connector: Neo4jConnector, connector: Neo4jConnector,
chunk_id : str, chunk_id: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 1, limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
if not chunk_id: if not chunk_id:
print(f"chunk_id不能为空") logger.warning(f"chunk_id不能为空")
return {"chunks": []} return {"chunks": []}
chunks = await connector.execute_query( chunks = await connector.execute_query(
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
async def search_graph_community_expand( async def search_graph_community_expand(
connector: Neo4jConnector, connector: Neo4jConnector,
community_ids: List[str], community_ids: List[str],
end_user_id: str, end_user_id: str,
limit: int = 10, limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
三期:社区展开检索 —— 主题 → 细节两级检索。 三期:社区展开检索 —— 主题 → 细节两级检索。
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
async def search_graph_by_created_at( async def search_graph_by_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -767,16 +776,11 @@ async def search_graph_by_created_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -784,16 +788,16 @@ async def search_graph_by_created_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -807,16 +811,11 @@ async def search_graph_by_valid_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id} valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -824,16 +823,16 @@ async def search_graph_by_valid_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -847,16 +846,11 @@ async def search_graph_g_created_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_CREATED_AT, SEARCH_STATEMENTS_G_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -864,16 +858,16 @@ async def search_graph_g_created_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -904,16 +892,16 @@ async def search_graph_g_valid_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
created_at: Optional[str] = None,
created_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -927,16 +915,11 @@ async def search_graph_l_created_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
created_at=created_at, created_at=created_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -944,16 +927,16 @@ async def search_graph_l_created_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
valid_at: Optional[str] = None,
valid_at: Optional[str] = None, limit: int = 1,
limit: int = 1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
@@ -967,16 +950,11 @@ async def search_graph_l_valid_at(
statements = await connector.execute_query( statements = await connector.execute_query(
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
end_user_id=end_user_id, end_user_id=end_user_id,
valid_at=valid_at, valid_at=valid_at,
limit=limit, limit=limit,
) )
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}")
# 更新 Statement 节点的激活值 # 更新 Statement 节点的激活值
results = {"statements": statements} results = {"statements": statements}
results = await _update_search_results_activation( results = await _update_search_results_activation(
@@ -984,5 +962,89 @@ async def search_graph_l_valid_at(
results=results, results=results,
end_user_id=end_user_id end_user_id=end_user_id
) )
return results return results
async def search_perceptual(
connector: Neo4jConnector,
q: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using fulltext keyword search.
Matches against summary, topic, and domain fields via the perceptualFulltext index.
Args:
connector: Neo4j connector
q: Query text
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
try:
perceptuals = await connector.execute_query(
SEARCH_PERCEPTUAL_BY_KEYWORD,
q=q,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual: keyword search failed: {e}")
perceptuals = []
# Deduplicate
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}
async def search_perceptual_by_embedding(
connector: Neo4jConnector,
embedder_client,
query_text: str,
end_user_id: Optional[str] = None,
limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Search Perceptual memory nodes using embedding-based semantic search.
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
Args:
connector: Neo4j connector
embedder_client: Embedding client with async response() method
query_text: Query text to embed
end_user_id: Optional user filter
limit: Max results
Returns:
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
"""
embeddings = await embedder_client.response([query_text])
if not embeddings or not embeddings[0]:
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
return {"perceptuals": []}
embedding = embeddings[0]
try:
perceptuals = await connector.execute_query(
PERCEPTUAL_EMBEDDING_SEARCH,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
perceptuals = []
from app.core.memory.src.search import _deduplicate_results
perceptuals = _deduplicate_results(perceptuals)
return {"perceptuals": perceptuals}

View File

@@ -11,10 +11,28 @@ Classes:
from typing import Any, List, Dict from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth from neo4j import AsyncGraphDatabase, basic_auth
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
from app.core.config import settings from app.core.config import settings
def _convert_neo4j_types(value: Any) -> Any:
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
if isinstance(value, Neo4jDateTime):
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
if isinstance(value, Neo4jDate):
return value.iso_format()
if isinstance(value, Neo4jTime):
return value.iso_format()
if isinstance(value, Neo4jDuration):
return str(value)
if isinstance(value, dict):
return {k: _convert_neo4j_types(v) for k, v in value.items()}
if isinstance(value, list):
return [_convert_neo4j_types(item) for item in value]
return value
class Neo4jConnector: class Neo4jConnector:
"""Neo4j数据库连接器 """Neo4j数据库连接器
@@ -59,11 +77,12 @@ class Neo4jConnector:
""" """
await self.driver.close() await self.driver.close()
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]: async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询 """执行Cypher查询
Args: Args:
query: Cypher查询语句 query: Cypher查询语句
json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询 **kwargs: 查询参数将作为参数传递给Cypher查询
Returns: Returns:
@@ -78,7 +97,10 @@ class Neo4jConnector:
**kwargs **kwargs
) )
records, summary, keys = result records, summary, keys = result
return [record.data() for record in records] if json_format:
return [_convert_neo4j_types(record.data()) for record in records]
else:
return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
"""在写事务中执行操作 """在写事务中执行操作

View File

@@ -462,11 +462,6 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器
config_load_start = time.time() config_load_start = time.time()
try: try:
# Use a separate database session to avoid transaction failures # Use a separate database session to avoid transaction failures
@@ -507,10 +502,13 @@ class MemoryAgentService:
async with make_read_graph() as graph: async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}} config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, initial_state = {
"end_user_id": end_user_id "messages": [HumanMessage(content=message)],
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "search_switch": search_switch,
"memory_config": memory_config} "end_user_id": end_user_id
, "storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息 # 获取节点更新信息
_intermediate_outputs = [] _intermediate_outputs = []
summary = '' summary = ''
@@ -522,7 +520,7 @@ class MemoryAgentService:
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name: # if 'save_neo4j' == node_name:
# massages = node_data # massages = node_data
print(f"处理节点: {node_name}") logger.info(f"处理节点: {node_name}")
# 处理不同Summary节点的返回结构 # 处理不同Summary节点的返回结构
if 'Summary' in node_name: if 'Summary' in node_name:
@@ -549,6 +547,11 @@ class MemoryAgentService:
if retrieve_node and retrieve_node != [] and retrieve_node != {}: if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node) _intermediate_outputs.extend(retrieve_node)
# Perceptual_Retrieve 节点
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
_intermediate_outputs.append(perceptual_node)
# Verify 节点 # Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None) verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}: if verify_n and verify_n != [] and verify_n != {}: