feat(memory): add perceptual memory retrieval service with BM25+embedding fusion
This commit is contained in:
@@ -0,0 +1,408 @@
|
|||||||
|
"""
|
||||||
|
Perceptual Memory Retrieval Node & Service
|
||||||
|
|
||||||
|
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
|
||||||
|
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
|
||||||
|
with BM25+embedding fusion reranking.
|
||||||
|
|
||||||
|
Also provides the perceptual_retrieve_node for use as a LangGraph node.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
import math
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
from app.repositories.neo4j.graph_search import (
|
||||||
|
search_perceptual,
|
||||||
|
search_perceptual_by_embedding,
|
||||||
|
)
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualSearchService:
|
||||||
|
"""
|
||||||
|
感知记忆检索服务。
|
||||||
|
|
||||||
|
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
|
||||||
|
调用方只需提供 query / keywords、end_user_id、memory_config,即可获得
|
||||||
|
格式化并排序后的感知记忆列表和拼接文本。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
service = PerceptualSearchService(end_user_id=..., memory_config=...)
|
||||||
|
results = await service.search(query="...", keywords=[...], limit=10)
|
||||||
|
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_ALPHA = 0.6
|
||||||
|
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
memory_config: Any,
|
||||||
|
alpha: float = DEFAULT_ALPHA,
|
||||||
|
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
|
||||||
|
):
|
||||||
|
self.end_user_id = end_user_id
|
||||||
|
self.memory_config = memory_config
|
||||||
|
self.alpha = alpha
|
||||||
|
self.content_score_threshold = content_score_threshold
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
keywords: Optional[List[str]] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
|
||||||
|
|
||||||
|
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
|
||||||
|
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 原始用户查询(用于向量检索和 BM25 补查)
|
||||||
|
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
|
||||||
|
limit: 最大返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"memories": [格式化后的记忆 dict, ...],
|
||||||
|
"content": "拼接的纯文本摘要",
|
||||||
|
"keyword_raw": int,
|
||||||
|
"embedding_raw": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
if keywords is None:
|
||||||
|
keywords = [query] if query else []
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
kw_task = self._keyword_search(connector, keywords, limit)
|
||||||
|
emb_task = self._embedding_search(connector, query, limit)
|
||||||
|
|
||||||
|
kw_results, emb_results = await asyncio.gather(
|
||||||
|
kw_task, emb_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
if isinstance(kw_results, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||||
|
kw_results = []
|
||||||
|
if isinstance(emb_results, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||||
|
emb_results = []
|
||||||
|
|
||||||
|
# 补查 BM25:找出 embedding 命中但 keyword 未命中的 id,
|
||||||
|
# 用原始 query 对这些节点补查全文索引拿 BM25 score
|
||||||
|
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
|
||||||
|
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
|
||||||
|
|
||||||
|
if emb_only_ids and query:
|
||||||
|
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
|
||||||
|
# 把补查到的 BM25 score 注入到 embedding 结果中
|
||||||
|
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
|
||||||
|
for r in emb_results:
|
||||||
|
rid = r.get("id", "")
|
||||||
|
if rid in backfill_map:
|
||||||
|
r["bm25_backfill_score"] = backfill_map[rid]
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
|
||||||
|
f"{len(backfill_map)} got BM25 scores"
|
||||||
|
)
|
||||||
|
|
||||||
|
reranked = self._rerank(kw_results, emb_results, limit)
|
||||||
|
|
||||||
|
memories = []
|
||||||
|
content_parts = []
|
||||||
|
for record in reranked:
|
||||||
|
fmt = self._format_result(record)
|
||||||
|
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||||
|
memories.append(fmt)
|
||||||
|
content_parts.append(self._build_content_text(fmt))
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||||
|
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"memories": memories,
|
||||||
|
"content": "\n\n".join(content_parts),
|
||||||
|
"keyword_raw": len(kw_results),
|
||||||
|
"embedding_raw": len(emb_results),
|
||||||
|
}
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
async def _bm25_backfill(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
query: str,
|
||||||
|
target_ids: set,
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""
|
||||||
|
对指定 id 集合补查全文索引 BM25 score。
|
||||||
|
|
||||||
|
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
|
||||||
|
"""
|
||||||
|
escaped = escape_lucene_query(query)
|
||||||
|
if not escaped.strip():
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
r = await search_perceptual(
|
||||||
|
connector=connector, q=escaped,
|
||||||
|
end_user_id=self.end_user_id,
|
||||||
|
limit=limit * 5, # 多查一些以提高命中率
|
||||||
|
)
|
||||||
|
all_hits = r.get("perceptuals", [])
|
||||||
|
return [h for h in all_hits if h.get("id") in target_ids]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _keyword_search(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
keywords: List[str],
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
|
||||||
|
seen_ids: set = set()
|
||||||
|
all_results: List[dict] = []
|
||||||
|
|
||||||
|
async def _one(kw: str):
|
||||||
|
escaped = escape_lucene_query(kw)
|
||||||
|
if not escaped.strip():
|
||||||
|
return []
|
||||||
|
r = await search_perceptual(
|
||||||
|
connector=connector, q=escaped,
|
||||||
|
end_user_id=self.end_user_id, limit=limit,
|
||||||
|
)
|
||||||
|
return r.get("perceptuals", [])
|
||||||
|
|
||||||
|
tasks = [_one(kw) for kw in keywords[:10]]
|
||||||
|
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
for result in batch:
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||||
|
continue
|
||||||
|
for rec in result:
|
||||||
|
rid = rec.get("id", "")
|
||||||
|
if rid and rid not in seen_ids:
|
||||||
|
seen_ids.add(rid)
|
||||||
|
all_results.append(rec)
|
||||||
|
|
||||||
|
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||||
|
return all_results[:limit]
|
||||||
|
|
||||||
|
async def _embedding_search(
|
||||||
|
self,
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
query_text: str,
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
|
||||||
|
try:
|
||||||
|
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||||
|
from app.core.models.base import RedBearModelConfig
|
||||||
|
from app.db import get_db_context
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
with get_db_context() as db:
|
||||||
|
cfg = MemoryConfigService(db).get_embedder_config(
|
||||||
|
str(self.memory_config.embedding_model_id)
|
||||||
|
)
|
||||||
|
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
|
||||||
|
|
||||||
|
r = await search_perceptual_by_embedding(
|
||||||
|
connector=connector, embedder_client=client,
|
||||||
|
query_text=query_text, end_user_id=self.end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
return r.get("perceptuals", [])
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _rerank(
|
||||||
|
self,
|
||||||
|
keyword_results: List[dict],
|
||||||
|
embedding_results: List[dict],
|
||||||
|
limit: int,
|
||||||
|
) -> List[dict]:
|
||||||
|
"""BM25 + embedding 融合排序。
|
||||||
|
|
||||||
|
对 embedding 结果中带有 bm25_backfill_score 的条目,
|
||||||
|
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
|
||||||
|
"""
|
||||||
|
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
|
||||||
|
emb_backfill_items = []
|
||||||
|
for item in embedding_results:
|
||||||
|
backfill_score = item.get("bm25_backfill_score")
|
||||||
|
if backfill_score is not None and item.get("id"):
|
||||||
|
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
|
||||||
|
|
||||||
|
# 合并后统一归一化 BM25 scores
|
||||||
|
all_bm25_items = keyword_results + emb_backfill_items
|
||||||
|
all_bm25_items = self._normalize_scores(all_bm25_items)
|
||||||
|
|
||||||
|
# 建立 id -> normalized BM25 score 的映射
|
||||||
|
bm25_norm_map: Dict[str, float] = {}
|
||||||
|
for item in all_bm25_items:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if item_id:
|
||||||
|
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||||
|
|
||||||
|
# 归一化 embedding scores
|
||||||
|
embedding_results = self._normalize_scores(embedding_results)
|
||||||
|
|
||||||
|
# 合并
|
||||||
|
combined: Dict[str, dict] = {}
|
||||||
|
for item in keyword_results:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if not item_id:
|
||||||
|
continue
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = 0.0
|
||||||
|
|
||||||
|
for item in embedding_results:
|
||||||
|
item_id = item.get("id", "")
|
||||||
|
if not item_id:
|
||||||
|
continue
|
||||||
|
if item_id in combined:
|
||||||
|
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||||
|
else:
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||||
|
|
||||||
|
for item in combined.values():
|
||||||
|
bm25 = float(item.get("bm25_score", 0) or 0)
|
||||||
|
emb = float(item.get("embedding_score", 0) or 0)
|
||||||
|
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
|
||||||
|
|
||||||
|
results = list(combined.values())
|
||||||
|
before = len(results)
|
||||||
|
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
|
||||||
|
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||||
|
results = results[:limit]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
|
||||||
|
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
|
||||||
|
"""Z-score + sigmoid 归一化。"""
|
||||||
|
if not items:
|
||||||
|
return items
|
||||||
|
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||||
|
if len(scores) <= 1:
|
||||||
|
for it in items:
|
||||||
|
it[f"normalized_{field}"] = 1.0
|
||||||
|
return items
|
||||||
|
mean = sum(scores) / len(scores)
|
||||||
|
var = sum((s - mean) ** 2 for s in scores) / len(scores)
|
||||||
|
std = math.sqrt(var)
|
||||||
|
if std == 0:
|
||||||
|
for it in items:
|
||||||
|
it[f"normalized_{field}"] = 1.0
|
||||||
|
else:
|
||||||
|
for it, s in zip(items, scores):
|
||||||
|
z = (s - mean) / std
|
||||||
|
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
|
||||||
|
return items
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_result(record: dict) -> dict:
|
||||||
|
return {
|
||||||
|
"id": record.get("id", ""),
|
||||||
|
"perceptual_type": record.get("perceptual_type", ""),
|
||||||
|
"file_name": record.get("file_name", ""),
|
||||||
|
"file_path": record.get("file_path", ""),
|
||||||
|
"summary": record.get("summary", ""),
|
||||||
|
"topic": record.get("topic", ""),
|
||||||
|
"domain": record.get("domain", ""),
|
||||||
|
"keywords": record.get("keywords", []),
|
||||||
|
"created_at": str(record.get("created_at", "")),
|
||||||
|
"file_type": record.get("file_type", ""),
|
||||||
|
"score": record.get("score", 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_content_text(formatted: dict) -> str:
|
||||||
|
parts = []
|
||||||
|
if formatted["summary"]:
|
||||||
|
parts.append(formatted["summary"])
|
||||||
|
if formatted["topic"]:
|
||||||
|
parts.append(f"[主题: {formatted['topic']}]")
|
||||||
|
if formatted["keywords"]:
|
||||||
|
kw_list = formatted["keywords"]
|
||||||
|
if isinstance(kw_list, list):
|
||||||
|
parts.append(f"[关键词: {', '.join(kw_list)}]")
|
||||||
|
if formatted["file_name"]:
|
||||||
|
parts.append(f"[文件: {formatted['file_name']}]")
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
|
||||||
|
"""Extract search keywords from problem extension results."""
|
||||||
|
keywords = []
|
||||||
|
context = problem_extension.get("context", {})
|
||||||
|
if isinstance(context, dict):
|
||||||
|
for original_q, extended_qs in context.items():
|
||||||
|
keywords.append(original_q)
|
||||||
|
if isinstance(extended_qs, list):
|
||||||
|
keywords.extend(extended_qs)
|
||||||
|
return keywords
|
||||||
|
|
||||||
|
|
||||||
|
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
LangGraph node: perceptual memory retrieval.
|
||||||
|
|
||||||
|
Uses PerceptualSearchService to run keyword + embedding search with
|
||||||
|
BM25 fusion reranking, then writes results to state['perceptual_data'].
|
||||||
|
"""
|
||||||
|
end_user_id = state.get("end_user_id", "")
|
||||||
|
problem_extension = state.get("problem_extension", {})
|
||||||
|
original_query = state.get("data", "")
|
||||||
|
memory_config = state.get("memory_config", None)
|
||||||
|
|
||||||
|
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
|
||||||
|
|
||||||
|
keywords = _extract_keywords_from_problems(problem_extension)
|
||||||
|
if not keywords:
|
||||||
|
keywords = [original_query] if original_query else []
|
||||||
|
|
||||||
|
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
|
||||||
|
|
||||||
|
service = PerceptualSearchService(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
)
|
||||||
|
search_result = await service.search(
|
||||||
|
query=original_query,
|
||||||
|
keywords=keywords,
|
||||||
|
limit=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"memories": search_result["memories"],
|
||||||
|
"content": search_result["content"],
|
||||||
|
"_intermediate": {
|
||||||
|
"type": "perceptual_retrieve",
|
||||||
|
"title": "感知记忆检索",
|
||||||
|
"data": search_result["memories"],
|
||||||
|
"query": original_query,
|
||||||
|
"result_count": len(search_result["memories"]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return {"perceptual_data": result}
|
||||||
@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||||
|
|
||||||
# Emit intermediate output for frontend
|
# Emit intermediate output for frontend
|
||||||
print(time.time() - start)
|
|
||||||
result = {
|
result = {
|
||||||
"context": aggregated_dict,
|
"context": aggregated_dict,
|
||||||
"original": data,
|
"original": data,
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger, log_time
|
from app.core.logging_config import get_agent_logger, log_time
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||||
|
PerceptualSearchService,
|
||||||
|
)
|
||||||
from app.core.memory.agent.models.summary_models import (
|
from app.core.memory.agent.models.summary_models import (
|
||||||
RetrieveSummaryResponse,
|
RetrieveSummaryResponse,
|
||||||
SummaryResponse,
|
SummaryResponse,
|
||||||
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type != "rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
|
||||||
|
async def _perceptual_search():
|
||||||
|
service = PerceptualSearchService(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
memory_config=memory_config,
|
||||||
|
)
|
||||||
|
return await service.search(query=data, limit=5)
|
||||||
|
|
||||||
|
hybrid_task = SearchService().execute_hybrid_search(
|
||||||
**search_params,
|
**search_params,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
expand_communities=False,
|
||||||
)
|
)
|
||||||
|
perceptual_task = _perceptual_search()
|
||||||
|
|
||||||
|
gather_results = await asyncio.gather(
|
||||||
|
hybrid_task, perceptual_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
hybrid_result = gather_results[0]
|
||||||
|
perceptual_results = gather_results[1]
|
||||||
|
|
||||||
|
# 处理 hybrid search 异常
|
||||||
|
if isinstance(hybrid_result, Exception):
|
||||||
|
raise hybrid_result
|
||||||
|
retrieve_info, question, raw_results = hybrid_result
|
||||||
|
|
||||||
|
# 处理感知记忆结果
|
||||||
|
if isinstance(perceptual_results, Exception):
|
||||||
|
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
|
||||||
|
perceptual_results = []
|
||||||
|
|
||||||
|
# 拼接感知记忆内容到 retrieve_info
|
||||||
|
if perceptual_results and isinstance(perceptual_results, dict):
|
||||||
|
perceptual_content = perceptual_results.get("content", "")
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
|
||||||
|
count = len(perceptual_results.get("memories", []))
|
||||||
|
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
|
||||||
|
|
||||||
# 调试:打印 community 检索结果数量
|
# 调试:打印 community 检索结果数量
|
||||||
if raw_results and isinstance(raw_results, dict):
|
if raw_results and isinstance(raw_results, dict):
|
||||||
reranked = raw_results.get('reranked_results', {})
|
reranked = raw_results.get('reranked_results', {})
|
||||||
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"error": str(e)
|
"error": str(e)
|
||||||
}
|
}
|
||||||
end = time.time()
|
end = time.time()
|
||||||
try:
|
duration = end - start
|
||||||
duration = end - start
|
|
||||||
except Exception:
|
|
||||||
duration = 0.0
|
|
||||||
log_time('检索', duration)
|
log_time('检索', duration)
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str = list(set(retrieve_info_str))
|
retrieve_info_str = list(set(retrieve_info_str))
|
||||||
retrieve_info_str = '\n'.join(retrieve_info_str)
|
retrieve_info_str = '\n'.join(retrieve_info_str)
|
||||||
|
|
||||||
aimessages = await summary_llm(state, history, retrieve_info_str,
|
# Merge perceptual memory content
|
||||||
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
|
aimessages = await summary_llm(
|
||||||
|
state,
|
||||||
|
history,
|
||||||
|
retrieve_info_str,
|
||||||
|
'direct_summary_prompt.jinja2',
|
||||||
|
'retrieve_summary', RetrieveSummaryResponse,
|
||||||
|
"1"
|
||||||
|
)
|
||||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||||
await summary_redis_save(state, aimessages)
|
await summary_redis_save(state, aimessages)
|
||||||
if aimessages == '':
|
if aimessages == '':
|
||||||
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
retrieve_info_str += i + '\n'
|
retrieve_info_str += i + '\n'
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|
||||||
|
# Merge perceptual memory content
|
||||||
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
|
|||||||
if key == 'answer_small':
|
if key == 'answer_small':
|
||||||
for i in value:
|
for i in value:
|
||||||
retrieve_info_str += i + '\n'
|
retrieve_info_str += i + '\n'
|
||||||
|
|
||||||
|
# Merge perceptual memory content
|
||||||
|
perceptual_data = state.get("perceptual_data", {})
|
||||||
|
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
|
||||||
|
if perceptual_content:
|
||||||
|
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"history": history,
|
"history": history,
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
|||||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||||
retrieve_nodes,
|
retrieve_nodes,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||||
|
perceptual_retrieve_node,
|
||||||
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||||
Input_Summary,
|
Input_Summary,
|
||||||
Retrieve_Summary,
|
Retrieve_Summary,
|
||||||
@@ -55,6 +58,7 @@ async def make_read_graph():
|
|||||||
workflow.add_node("Input_Summary", Input_Summary)
|
workflow.add_node("Input_Summary", Input_Summary)
|
||||||
workflow.add_node("Retrieve", retrieve_nodes)
|
workflow.add_node("Retrieve", retrieve_nodes)
|
||||||
# workflow.add_node("Retrieve", retrieve)
|
# workflow.add_node("Retrieve", retrieve)
|
||||||
|
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
|
||||||
workflow.add_node("Verify", Verify)
|
workflow.add_node("Verify", Verify)
|
||||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||||
workflow.add_node("Summary", Summary)
|
workflow.add_node("Summary", Summary)
|
||||||
@@ -65,14 +69,15 @@ async def make_read_graph():
|
|||||||
workflow.add_conditional_edges("content_input", Split_continue)
|
workflow.add_conditional_edges("content_input", Split_continue)
|
||||||
workflow.add_edge("Input_Summary", END)
|
workflow.add_edge("Input_Summary", END)
|
||||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
|
||||||
|
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
|
||||||
|
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
|
||||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||||
workflow.add_edge("Retrieve_Summary", END)
|
workflow.add_edge("Retrieve_Summary", END)
|
||||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||||
workflow.add_edge("Summary_fails", END)
|
workflow.add_edge("Summary_fails", END)
|
||||||
workflow.add_edge("Summary", END)
|
workflow.add_edge("Summary", END)
|
||||||
|
|
||||||
'''-----'''
|
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
# Compile workflow
|
# Compile workflow
|
||||||
@@ -80,7 +85,5 @@ async def make_read_graph():
|
|||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"创建工作流失败: {e}")
|
logger.error(f"创建工作流失败: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
print("工作流创建完成")
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
|
|||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||||
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
|
|||||||
|
|
||||||
|
|
||||||
async def expand_communities_to_statements(
|
async def expand_communities_to_statements(
|
||||||
community_results: List[dict],
|
community_results: List[dict],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
existing_content: str = "",
|
existing_content: str = "",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> Tuple[List[dict], List[str]]:
|
) -> Tuple[List[dict], List[str]]:
|
||||||
"""
|
"""
|
||||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||||
@@ -76,7 +75,8 @@ async def expand_communities_to_statements(
|
|||||||
if s.get("statement") and s["statement"] not in existing_lines
|
if s.get("statement") and s["statement"] not in existing_lines
|
||||||
]
|
]
|
||||||
cleaned = _clean_expand_fields(expanded_stmts)
|
cleaned = _clean_expand_fields(expanded_stmts)
|
||||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
logger.info(
|
||||||
|
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||||
return cleaned, new_texts
|
return cleaned, new_texts
|
||||||
|
|
||||||
|
|
||||||
@@ -117,9 +117,9 @@ class SearchService:
|
|||||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||||
is_community = (
|
is_community = (
|
||||||
node_type == "community"
|
node_type == "community"
|
||||||
or 'member_count' in result
|
or 'member_count' in result
|
||||||
or 'core_entities' in result
|
or 'core_entities' in result
|
||||||
)
|
)
|
||||||
if is_community:
|
if is_community:
|
||||||
name = result.get('name', '')
|
name = result.get('name', '')
|
||||||
@@ -158,7 +158,7 @@ class SearchService:
|
|||||||
|
|
||||||
# Remove wrapping quotes
|
# Remove wrapping quotes
|
||||||
if (q.startswith("'") and q.endswith("'")) or (
|
if (q.startswith("'") and q.endswith("'")) or (
|
||||||
q.startswith('"') and q.endswith('"')
|
q.startswith('"') and q.endswith('"')
|
||||||
):
|
):
|
||||||
q = q[1:-1]
|
q = q[1:-1]
|
||||||
|
|
||||||
@@ -171,17 +171,17 @@ class SearchService:
|
|||||||
return q
|
return q
|
||||||
|
|
||||||
async def execute_hybrid_search(
|
async def execute_hybrid_search(
|
||||||
self,
|
self,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
question: str,
|
question: str,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
rerank_alpha: float = 0.4,
|
rerank_alpha: float = 0.4,
|
||||||
output_path: str = "search_results.json",
|
output_path: str = "search_results.json",
|
||||||
return_raw_results: bool = False,
|
return_raw_results: bool = False,
|
||||||
memory_config = None,
|
memory_config=None,
|
||||||
expand_communities: bool = True,
|
expand_communities: bool = True,
|
||||||
) -> Tuple[str, str, Optional[dict]]:
|
) -> Tuple[str, str, Optional[dict]]:
|
||||||
"""
|
"""
|
||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
@@ -269,7 +269,6 @@ class SearchService:
|
|||||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||||
|
|
||||||
|
|
||||||
# Filter out empty strings and join with newlines
|
# Filter out empty strings and join with newlines
|
||||||
clean_content = '\n'.join([c for c in content_list if c])
|
clean_content = '\n'.join([c for c in content_list if c])
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import os
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, TypedDict
|
from typing import Annotated, TypedDict
|
||||||
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
|
|||||||
embedding_id: str
|
embedding_id: str
|
||||||
memory_config: object # 新增字段用于传递内存配置对象
|
memory_config: object # 新增字段用于传递内存配置对象
|
||||||
retrieve: dict
|
retrieve: dict
|
||||||
|
perceptual_data: dict
|
||||||
RetrieveSummary: dict
|
RetrieveSummary: dict
|
||||||
InputSummary: dict
|
InputSummary: dict
|
||||||
verify: dict
|
verify: dict
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ load_dotenv()
|
|||||||
|
|
||||||
logger = get_memory_logger(__name__)
|
logger = get_memory_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _parse_datetime(value: Any) -> Optional[datetime]:
|
def _parse_datetime(value: Any) -> Optional[datetime]:
|
||||||
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
|
||||||
if value is None:
|
if value is None:
|
||||||
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
|||||||
item[f"normalized_{score_field}"] = None
|
item[f"normalized_{score_field}"] = None
|
||||||
return results
|
return results
|
||||||
|
|
||||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||||
for item, score in zip(results, scores):
|
for item, score in zip(results, scores):
|
||||||
if score_field in item or score_field == "activation_value":
|
if score_field in item or score_field == "activation_value":
|
||||||
if score is None:
|
if score is None:
|
||||||
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Remove duplicate items from search results based on content.
|
Remove duplicate items from search results based on content.
|
||||||
@@ -157,11 +157,11 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
# Extract content from various possible fields
|
# Extract content from various possible fields
|
||||||
content = (
|
content = (
|
||||||
item.get("text") or
|
item.get("text") or
|
||||||
item.get("content") or
|
item.get("content") or
|
||||||
item.get("statement") or
|
item.get("statement") or
|
||||||
item.get("name") or
|
item.get("name") or
|
||||||
""
|
""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normalize content for comparison (strip whitespace and lowercase)
|
# Normalize content for comparison (strip whitespace and lowercase)
|
||||||
@@ -189,13 +189,14 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
def rerank_with_activation(
|
def rerank_with_activation(
|
||||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||||
alpha: float = 0.6,
|
alpha: float = 0.6,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
forgetting_config: ForgettingEngineConfig | None = None,
|
forgetting_config: ForgettingEngineConfig | None = None,
|
||||||
activation_boost_factor: float = 0.8,
|
activation_boost_factor: float = 0.8,
|
||||||
now: datetime | None = None,
|
now: datetime | None = None,
|
||||||
|
content_score_threshold: float = 0.5,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||||
@@ -222,6 +223,8 @@ def rerank_with_activation(
|
|||||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||||
now: 当前时间(用于遗忘计算)
|
now: 当前时间(用于遗忘计算)
|
||||||
|
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score),
|
||||||
|
低于此阈值的结果会被过滤。默认 0.5。
|
||||||
|
|
||||||
返回:
|
返回:
|
||||||
带评分元数据的重排序结果,按 final_score 排序
|
带评分元数据的重排序结果,按 final_score 排序
|
||||||
@@ -391,7 +394,19 @@ def rerank_with_activation(
|
|||||||
# 无激活值:使用内容相关性分数
|
# 无激活值:使用内容相关性分数
|
||||||
item["final_score"] = item.get("base_score", 0)
|
item["final_score"] = item.get("base_score", 0)
|
||||||
|
|
||||||
# 最终去重确保没有重复项
|
if content_score_threshold > 0:
|
||||||
|
before_count = len(sorted_items)
|
||||||
|
sorted_items = [
|
||||||
|
item for item in sorted_items
|
||||||
|
if float(item.get("content_score", 0) or 0) >= content_score_threshold
|
||||||
|
]
|
||||||
|
filtered_count = before_count - len(sorted_items)
|
||||||
|
if filtered_count > 0:
|
||||||
|
logger.info(
|
||||||
|
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
|
||||||
|
f"items below content_score_threshold={content_score_threshold}"
|
||||||
|
)
|
||||||
|
|
||||||
sorted_items = _deduplicate_results(sorted_items)
|
sorted_items = _deduplicate_results(sorted_items)
|
||||||
|
|
||||||
reranked[category] = sorted_items
|
reranked[category] = sorted_items
|
||||||
@@ -399,7 +414,8 @@ def rerank_with_activation(
|
|||||||
return reranked
|
return reranked
|
||||||
|
|
||||||
|
|
||||||
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
|
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
|
||||||
|
log_file: str = None):
|
||||||
"""Log search query information using the logger.
|
"""Log search query information using the logger.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
|
|||||||
|
|
||||||
|
|
||||||
def apply_reranker_placeholder(
|
def apply_reranker_placeholder(
|
||||||
results: Dict[str, List[Dict[str, Any]]],
|
results: Dict[str, List[Dict[str, Any]]],
|
||||||
query_text: str,
|
query_text: str,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Placeholder for a cross-encoder reranker.
|
Placeholder for a cross-encoder reranker.
|
||||||
@@ -673,17 +689,17 @@ def apply_reranker_placeholder(
|
|||||||
|
|
||||||
|
|
||||||
async def run_hybrid_search(
|
async def run_hybrid_search(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
search_type: str,
|
search_type: str,
|
||||||
end_user_id: str | None,
|
end_user_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
include: List[str],
|
include: List[str],
|
||||||
output_path: str | None,
|
output_path: str | None,
|
||||||
memory_config: "MemoryConfig",
|
memory_config: "MemoryConfig",
|
||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
activation_boost_factor: float = 0.8,
|
activation_boost_factor: float = 0.8,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -788,7 +804,7 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
if keyword_task:
|
if keyword_task:
|
||||||
keyword_results = await keyword_task
|
keyword_results = await keyword_task
|
||||||
keyword_latency = time.time() - keyword_start
|
keyword_latency = time.time() - search_start_time
|
||||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||||
if search_type == "keyword":
|
if search_type == "keyword":
|
||||||
@@ -798,7 +814,7 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
if embedding_task:
|
if embedding_task:
|
||||||
embedding_results = await embedding_task
|
embedding_results = await embedding_task
|
||||||
embedding_latency = time.time() - embedding_start
|
embedding_latency = time.time() - search_start_time
|
||||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||||
if search_type == "embedding":
|
if search_type == "embedding":
|
||||||
@@ -810,7 +826,8 @@ async def run_hybrid_search(
|
|||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
results["combined_summary"] = {
|
results["combined_summary"] = {
|
||||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
"total_embedding_results": sum(
|
||||||
|
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||||
"search_query": query_text,
|
"search_query": query_text,
|
||||||
"search_timestamp": datetime.now().isoformat()
|
"search_timestamp": datetime.now().isoformat()
|
||||||
}
|
}
|
||||||
@@ -866,7 +883,8 @@ async def run_hybrid_search(
|
|||||||
results["reranked_results"] = reranked_results
|
results["reranked_results"] = reranked_results
|
||||||
results["combined_summary"] = {
|
results["combined_summary"] = {
|
||||||
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
|
||||||
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
"total_embedding_results": sum(
|
||||||
|
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
|
||||||
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
|
||||||
"search_query": query_text,
|
"search_query": query_text,
|
||||||
"search_timestamp": datetime.now().isoformat(),
|
"search_timestamp": datetime.now().isoformat(),
|
||||||
@@ -908,8 +926,10 @@ async def run_hybrid_search(
|
|||||||
# Log search completion with result count
|
# Log search completion with result count
|
||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
result_counts = {
|
result_counts = {
|
||||||
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
|
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||||
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
|
keyword_results.items()},
|
||||||
|
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
|
||||||
|
embedding_results.items()}
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
|
||||||
@@ -927,12 +947,12 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
|
|
||||||
async def search_by_temporal(
|
async def search_by_temporal(
|
||||||
end_user_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
invalid_date: Optional[str] = None,
|
invalid_date: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -968,13 +988,13 @@ async def search_by_temporal(
|
|||||||
|
|
||||||
|
|
||||||
async def search_by_keyword_temporal(
|
async def search_by_keyword_temporal(
|
||||||
query_text: str,
|
query_text: str,
|
||||||
end_user_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
invalid_date: Optional[str] = None,
|
invalid_date: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Temporal keyword search across Statements.
|
Temporal keyword search across Statements.
|
||||||
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
|
|||||||
|
|
||||||
|
|
||||||
async def search_chunk_by_chunk_id(
|
async def search_chunk_by_chunk_id(
|
||||||
chunk_id: str,
|
chunk_id: str,
|
||||||
end_user_id: Optional[str] = "test",
|
end_user_id: Optional[str] = "test",
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Search for Chunks by chunk_id.
|
Search for Chunks by chunk_id.
|
||||||
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
|
|||||||
limit=limit
|
limit=limit
|
||||||
)
|
)
|
||||||
return {"chunks": chunks}
|
return {"chunks": chunks}
|
||||||
|
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ async def lifespan(app: FastAPI):
|
|||||||
else:
|
else:
|
||||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||||
await create_all_indexes()
|
await create_all_indexes()
|
||||||
|
logger.info("All neo4j indexes and constraints created successfully!")
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import asyncio
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
|
||||||
async def create_fulltext_indexes():
|
async def create_fulltext_indexes():
|
||||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
||||||
# 创建 Statements 索引
|
# 创建 Statements 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||||
@@ -40,8 +40,16 @@ async def create_fulltext_indexes():
|
|||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# 创建 Perceptual 感知记忆索引
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE FULLTEXT INDEX perceptualFulltext IF NOT EXISTS FOR (p:Perceptual) ON EACH [p.summary, p.topic, p.domain]
|
||||||
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
|
""")
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_vector_indexes():
|
async def create_vector_indexes():
|
||||||
"""Create vector indexes for fast embedding similarity search.
|
"""Create vector indexes for fast embedding similarity search.
|
||||||
|
|
||||||
@@ -51,7 +59,6 @@ async def create_vector_indexes():
|
|||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
|
||||||
# Statement embedding index
|
# Statement embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX statement_embedding_index IF NOT EXISTS
|
||||||
@@ -63,7 +70,6 @@ async def create_vector_indexes():
|
|||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
# Chunk embedding index
|
# Chunk embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX chunk_embedding_index IF NOT EXISTS
|
||||||
@@ -75,7 +81,6 @@ async def create_vector_indexes():
|
|||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
# Entity name embedding index
|
# Entity name embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX entity_embedding_index IF NOT EXISTS
|
||||||
@@ -87,7 +92,6 @@ async def create_vector_indexes():
|
|||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
# Memory summary embedding index
|
# Memory summary embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX summary_embedding_index IF NOT EXISTS
|
||||||
@@ -121,8 +125,20 @@ async def create_vector_indexes():
|
|||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# Perceptual summary embedding index
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE VECTOR INDEX perceptual_summary_embedding_index IF NOT EXISTS
|
||||||
|
FOR (p:Perceptual)
|
||||||
|
ON p.summary_embedding
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
""")
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_unique_constraints():
|
async def create_unique_constraints():
|
||||||
"""Create uniqueness constraints for core node identifiers.
|
"""Create uniqueness constraints for core node identifiers.
|
||||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||||
@@ -155,10 +171,10 @@ async def create_unique_constraints():
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_all_indexes():
|
async def create_all_indexes():
|
||||||
"""Create all indexes and constraints in one go."""
|
"""Create all indexes and constraints in one go."""
|
||||||
await create_fulltext_indexes()
|
await create_fulltext_indexes()
|
||||||
await create_vector_indexes()
|
await create_vector_indexes()
|
||||||
await create_unique_constraints()
|
await create_unique_constraints()
|
||||||
print("✓ All indexes and constraints created successfully!")
|
|
||||||
|
|
||||||
|
|||||||
@@ -1449,3 +1449,44 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
|||||||
r.created_at = edge.created_at
|
r.created_at = edge.created_at
|
||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("perceptualFulltext", $q) YIELD node AS p, score
|
||||||
|
WHERE p.end_user_id = $end_user_id
|
||||||
|
RETURN p.id AS id,
|
||||||
|
p.end_user_id AS end_user_id,
|
||||||
|
p.perceptual_type AS perceptual_type,
|
||||||
|
p.file_path AS file_path,
|
||||||
|
p.file_name AS file_name,
|
||||||
|
p.file_ext AS file_ext,
|
||||||
|
p.summary AS summary,
|
||||||
|
p.keywords AS keywords,
|
||||||
|
p.topic AS topic,
|
||||||
|
p.domain AS domain,
|
||||||
|
p.created_at AS created_at,
|
||||||
|
p.file_type AS file_type,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
PERCEPTUAL_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS p, score
|
||||||
|
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
|
||||||
|
RETURN p.id AS id,
|
||||||
|
p.end_user_id AS end_user_id,
|
||||||
|
p.perceptual_type AS perceptual_type,
|
||||||
|
p.file_path AS file_path,
|
||||||
|
p.file_name AS file_name,
|
||||||
|
p.file_ext AS file_ext,
|
||||||
|
p.summary AS summary,
|
||||||
|
p.keywords AS keywords,
|
||||||
|
p.topic AS topic,
|
||||||
|
p.domain AS domain,
|
||||||
|
p.created_at AS created_at,
|
||||||
|
p.file_type AS file_type,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
ENTITY_EMBEDDING_SEARCH,
|
ENTITY_EMBEDDING_SEARCH,
|
||||||
EXPAND_COMMUNITY_STATEMENTS,
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||||
|
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
SEARCH_CHUNKS_BY_CONTENT,
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
@@ -15,6 +16,7 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
|
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||||
@@ -34,11 +36,11 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def _update_activation_values_batch(
|
async def _update_activation_values_batch(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
nodes: List[Dict[str, Any]],
|
nodes: List[Dict[str, Any]],
|
||||||
node_label: str,
|
node_label: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
max_retries: int = 3
|
max_retries: int = 3
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
批量更新节点的激活值
|
批量更新节点的激活值
|
||||||
@@ -120,9 +122,9 @@ async def _update_activation_values_batch(
|
|||||||
|
|
||||||
|
|
||||||
async def _update_search_results_activation(
|
async def _update_search_results_activation(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
results: Dict[str, List[Dict[str, Any]]],
|
results: Dict[str, List[Dict[str, Any]]],
|
||||||
end_user_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
更新搜索结果中所有知识节点的激活值
|
更新搜索结果中所有知识节点的激活值
|
||||||
@@ -196,7 +198,7 @@ async def _update_search_results_activation(
|
|||||||
'importance_score',
|
'importance_score',
|
||||||
'version',
|
'version',
|
||||||
'statement', # Statement 节点的内容字段
|
'statement', # Statement 节点的内容字段
|
||||||
'content' # MemorySummary 节点的内容字段
|
'content' # MemorySummary 节点的内容字段
|
||||||
}
|
}
|
||||||
|
|
||||||
# 只更新激活值相关字段,保留原始节点的其他字段
|
# 只更新激活值相关字段,保留原始节点的其他字段
|
||||||
@@ -220,11 +222,11 @@ async def _update_search_results_activation(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph(
|
async def search_graph(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
q: str,
|
q: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: List[str] = None,
|
include: List[str] = None,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||||
@@ -257,6 +259,7 @@ async def search_graph(
|
|||||||
if "statements" in include:
|
if "statements" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||||
|
json_format=True,
|
||||||
q=q,
|
q=q,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -266,6 +269,7 @@ async def search_graph(
|
|||||||
if "entities" in include:
|
if "entities" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||||
|
json_format=True,
|
||||||
q=q,
|
q=q,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -275,6 +279,7 @@ async def search_graph(
|
|||||||
if "chunks" in include:
|
if "chunks" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
SEARCH_CHUNKS_BY_CONTENT,
|
||||||
|
json_format=True,
|
||||||
q=q,
|
q=q,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -284,6 +289,7 @@ async def search_graph(
|
|||||||
if "summaries" in include:
|
if "summaries" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
|
json_format=True,
|
||||||
q=q,
|
q=q,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -293,6 +299,7 @@ async def search_graph(
|
|||||||
if "communities" in include:
|
if "communities" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
|
json_format=True,
|
||||||
q=q,
|
q=q,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -336,12 +343,12 @@ async def search_graph(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_embedding(
|
async def search_graph_by_embedding(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
embedder_client,
|
embedder_client,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: List[str] = ["statements", "chunks", "entities","summaries"],
|
include: List[str] = ["statements", "chunks", "entities", "summaries"],
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||||
@@ -360,7 +367,7 @@ async def search_graph_by_embedding(
|
|||||||
embed_start = time.time()
|
embed_start = time.time()
|
||||||
embeddings = await embedder_client.response([query_text])
|
embeddings = await embedder_client.response([query_text])
|
||||||
embed_time = time.time() - embed_start
|
embed_time = time.time() - embed_start
|
||||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||||
|
|
||||||
if not embeddings or not embeddings[0]:
|
if not embeddings or not embeddings[0]:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -378,6 +385,7 @@ async def search_graph_by_embedding(
|
|||||||
if "statements" in include:
|
if "statements" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
STATEMENT_EMBEDDING_SEARCH,
|
STATEMENT_EMBEDDING_SEARCH,
|
||||||
|
json_format=True,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -388,6 +396,7 @@ async def search_graph_by_embedding(
|
|||||||
if "chunks" in include:
|
if "chunks" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
CHUNK_EMBEDDING_SEARCH,
|
||||||
|
json_format=True,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -398,6 +407,7 @@ async def search_graph_by_embedding(
|
|||||||
if "entities" in include:
|
if "entities" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
ENTITY_EMBEDDING_SEARCH,
|
||||||
|
json_format=True,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -408,6 +418,7 @@ async def search_graph_by_embedding(
|
|||||||
if "summaries" in include:
|
if "summaries" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||||
|
json_format=True,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -418,6 +429,7 @@ async def search_graph_by_embedding(
|
|||||||
if "communities" in include:
|
if "communities" in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(connector.execute_query(
|
||||||
COMMUNITY_EMBEDDING_SEARCH,
|
COMMUNITY_EMBEDDING_SEARCH,
|
||||||
|
json_format=True,
|
||||||
embedding=embedding,
|
embedding=embedding,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -428,7 +440,7 @@ async def search_graph_by_embedding(
|
|||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
query_time = time.time() - query_start
|
query_time = time.time() - query_start
|
||||||
print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
||||||
|
|
||||||
# Build results dictionary
|
# Build results dictionary
|
||||||
results: Dict[str, List[Dict[str, Any]]] = {
|
results: Dict[str, List[Dict[str, Any]]] = {
|
||||||
@@ -473,13 +485,15 @@ async def search_graph_by_embedding(
|
|||||||
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
logger.info(f"[PERF] Skipping activation updates (only summaries)")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
entities: List[Dict[str, Any]],
|
entities: List[Dict[str, Any]],
|
||||||
use_contains_fallback: bool = True,
|
use_contains_fallback: bool = True,
|
||||||
batch_size: int = 500,
|
batch_size: int = 500,
|
||||||
max_concurrency: int = 5,
|
max_concurrency: int = 5,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
为第二层去重消歧批量检索候选实体(适配新版 cypher_queries):
|
||||||
@@ -560,14 +574,14 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_keyword_temporal(
|
async def search_graph_by_keyword_temporal(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
invalid_date: Optional[str] = None,
|
invalid_date: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
) -> Dict[str, List[Any]]:
|
) -> Dict[str, List[Any]]:
|
||||||
"""
|
"""
|
||||||
Temporal keyword search across Statements.
|
Temporal keyword search across Statements.
|
||||||
@@ -579,7 +593,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
- Returns up to 'limit' statements
|
- Returns up to 'limit' statements
|
||||||
"""
|
"""
|
||||||
if not query_text:
|
if not query_text:
|
||||||
print(f"query_text不能为空")
|
logger.warning(f"query_text不能为空")
|
||||||
return {"statements": []}
|
return {"statements": []}
|
||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||||
@@ -591,7 +605,7 @@ async def search_graph_by_keyword_temporal(
|
|||||||
invalid_date=invalid_date,
|
invalid_date=invalid_date,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
print(f"查询结果为:\n{statements}")
|
logger.debug(f"查询结果为:\n{statements}")
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
@@ -605,13 +619,13 @@ async def search_graph_by_keyword_temporal(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_temporal(
|
async def search_graph_by_temporal(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
start_date: Optional[str] = None,
|
start_date: Optional[str] = None,
|
||||||
end_date: Optional[str] = None,
|
end_date: Optional[str] = None,
|
||||||
valid_date: Optional[str] = None,
|
valid_date: Optional[str] = None,
|
||||||
invalid_date: Optional[str] = None,
|
invalid_date: Optional[str] = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -632,10 +646,6 @@ async def search_graph_by_temporal(
|
|||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -648,10 +658,10 @@ async def search_graph_by_temporal(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_dialog_id(
|
async def search_graph_by_dialog_id(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
dialog_id: str,
|
dialog_id: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Dialogues.
|
Temporal search across Dialogues.
|
||||||
@@ -661,7 +671,7 @@ async def search_graph_by_dialog_id(
|
|||||||
- Returns up to 'limit' dialogues
|
- Returns up to 'limit' dialogues
|
||||||
"""
|
"""
|
||||||
if not dialog_id:
|
if not dialog_id:
|
||||||
print(f"dialog_id不能为空")
|
logger.warning(f"dialog_id不能为空")
|
||||||
return {"dialogues": []}
|
return {"dialogues": []}
|
||||||
|
|
||||||
dialogues = await connector.execute_query(
|
dialogues = await connector.execute_query(
|
||||||
@@ -674,13 +684,13 @@ async def search_graph_by_dialog_id(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_chunk_id(
|
async def search_graph_by_chunk_id(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
chunk_id : str,
|
chunk_id: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 1,
|
limit: int = 1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
if not chunk_id:
|
if not chunk_id:
|
||||||
print(f"chunk_id不能为空")
|
logger.warning(f"chunk_id不能为空")
|
||||||
return {"chunks": []}
|
return {"chunks": []}
|
||||||
chunks = await connector.execute_query(
|
chunks = await connector.execute_query(
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
@@ -692,10 +702,10 @@ async def search_graph_by_chunk_id(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_community_expand(
|
async def search_graph_community_expand(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
community_ids: List[str],
|
community_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
三期:社区展开检索 —— 主题 → 细节两级检索。
|
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||||
@@ -748,12 +758,11 @@ async def search_graph_community_expand(
|
|||||||
|
|
||||||
|
|
||||||
async def search_graph_by_created_at(
|
async def search_graph_by_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
created_at: Optional[str] = None,
|
||||||
created_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -768,15 +777,10 @@ async def search_graph_by_created_at(
|
|||||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -787,13 +791,13 @@ async def search_graph_by_created_at(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_by_valid_at(
|
async def search_graph_by_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
valid_at: Optional[str] = None,
|
||||||
valid_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -808,15 +812,10 @@ async def search_graph_by_valid_at(
|
|||||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -827,13 +826,13 @@ async def search_graph_by_valid_at(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_g_created_at(
|
async def search_graph_g_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
created_at: Optional[str] = None,
|
||||||
created_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -848,15 +847,10 @@ async def search_graph_g_created_at(
|
|||||||
SEARCH_STATEMENTS_G_CREATED_AT,
|
SEARCH_STATEMENTS_G_CREATED_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -867,13 +861,13 @@ async def search_graph_g_created_at(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_g_valid_at(
|
async def search_graph_g_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
valid_at: Optional[str] = None,
|
||||||
valid_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -887,16 +881,10 @@ async def search_graph_g_valid_at(
|
|||||||
statements = await connector.execute_query(
|
statements = await connector.execute_query(
|
||||||
SEARCH_STATEMENTS_G_VALID_AT,
|
SEARCH_STATEMENTS_G_VALID_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -907,13 +895,13 @@ async def search_graph_g_valid_at(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_l_created_at(
|
async def search_graph_l_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
created_at: Optional[str] = None,
|
||||||
created_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -928,15 +916,10 @@ async def search_graph_l_created_at(
|
|||||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -947,13 +930,13 @@ async def search_graph_l_created_at(
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_l_valid_at(
|
async def search_graph_l_valid_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|
||||||
|
valid_at: Optional[str] = None,
|
||||||
valid_at: Optional[str] = None,
|
limit: int = 1,
|
||||||
limit: int = 1,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Temporal search across Statements.
|
Temporal search across Statements.
|
||||||
@@ -968,15 +951,10 @@ async def search_graph_l_valid_at(
|
|||||||
SEARCH_STATEMENTS_L_VALID_AT,
|
SEARCH_STATEMENTS_L_VALID_AT,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
|
|
||||||
|
|
||||||
valid_at=valid_at,
|
valid_at=valid_at,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
|
||||||
print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
|
||||||
print(f"查询结果为:\n{statements}")
|
|
||||||
|
|
||||||
# 更新 Statement 节点的激活值
|
# 更新 Statement 节点的激活值
|
||||||
results = {"statements": statements}
|
results = {"statements": statements}
|
||||||
results = await _update_search_results_activation(
|
results = await _update_search_results_activation(
|
||||||
@@ -986,3 +964,87 @@ async def search_graph_l_valid_at(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def search_perceptual(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
q: str,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Search Perceptual memory nodes using fulltext keyword search.
|
||||||
|
|
||||||
|
Matches against summary, topic, and domain fields via the perceptualFulltext index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: Neo4j connector
|
||||||
|
q: Query text
|
||||||
|
end_user_id: Optional user filter
|
||||||
|
limit: Max results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
perceptuals = await connector.execute_query(
|
||||||
|
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||||
|
q=q,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"search_perceptual: keyword search failed: {e}")
|
||||||
|
perceptuals = []
|
||||||
|
|
||||||
|
# Deduplicate
|
||||||
|
from app.core.memory.src.search import _deduplicate_results
|
||||||
|
perceptuals = _deduplicate_results(perceptuals)
|
||||||
|
|
||||||
|
return {"perceptuals": perceptuals}
|
||||||
|
|
||||||
|
|
||||||
|
async def search_perceptual_by_embedding(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
embedder_client,
|
||||||
|
query_text: str,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Search Perceptual memory nodes using embedding-based semantic search.
|
||||||
|
|
||||||
|
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: Neo4j connector
|
||||||
|
embedder_client: Embedding client with async response() method
|
||||||
|
query_text: Query text to embed
|
||||||
|
end_user_id: Optional user filter
|
||||||
|
limit: Max results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||||
|
"""
|
||||||
|
embeddings = await embedder_client.response([query_text])
|
||||||
|
if not embeddings or not embeddings[0]:
|
||||||
|
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
||||||
|
return {"perceptuals": []}
|
||||||
|
|
||||||
|
embedding = embeddings[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
perceptuals = await connector.execute_query(
|
||||||
|
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||||
|
embedding=embedding,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
|
||||||
|
perceptuals = []
|
||||||
|
|
||||||
|
from app.core.memory.src.search import _deduplicate_results
|
||||||
|
perceptuals = _deduplicate_results(perceptuals)
|
||||||
|
|
||||||
|
return {"perceptuals": perceptuals}
|
||||||
|
|||||||
@@ -11,10 +11,28 @@ Classes:
|
|||||||
from typing import Any, List, Dict
|
from typing import Any, List, Dict
|
||||||
|
|
||||||
from neo4j import AsyncGraphDatabase, basic_auth
|
from neo4j import AsyncGraphDatabase, basic_auth
|
||||||
|
from neo4j.time import DateTime as Neo4jDateTime, Date as Neo4jDate, Time as Neo4jTime, Duration as Neo4jDuration
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_neo4j_types(value: Any) -> Any:
|
||||||
|
"""递归将 neo4j 原生时间类型转为 Python 原生类型 / ISO 字符串,确保可被 json.dumps 序列化。"""
|
||||||
|
if isinstance(value, Neo4jDateTime):
|
||||||
|
return value.to_native().isoformat() if value.tzinfo else value.iso_format()
|
||||||
|
if isinstance(value, Neo4jDate):
|
||||||
|
return value.iso_format()
|
||||||
|
if isinstance(value, Neo4jTime):
|
||||||
|
return value.iso_format()
|
||||||
|
if isinstance(value, Neo4jDuration):
|
||||||
|
return str(value)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {k: _convert_neo4j_types(v) for k, v in value.items()}
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [_convert_neo4j_types(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class Neo4jConnector:
|
class Neo4jConnector:
|
||||||
"""Neo4j数据库连接器
|
"""Neo4j数据库连接器
|
||||||
|
|
||||||
@@ -59,11 +77,12 @@ class Neo4jConnector:
|
|||||||
"""
|
"""
|
||||||
await self.driver.close()
|
await self.driver.close()
|
||||||
|
|
||||||
async def execute_query(self, query: str, **kwargs: Any) -> List[Dict[str, Any]]:
|
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||||
"""执行Cypher查询
|
"""执行Cypher查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Cypher查询语句
|
query: Cypher查询语句
|
||||||
|
json_format: json格式化
|
||||||
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -78,7 +97,10 @@ class Neo4jConnector:
|
|||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
records, summary, keys = result
|
records, summary, keys = result
|
||||||
return [record.data() for record in records]
|
if json_format:
|
||||||
|
return [_convert_neo4j_types(record.data()) for record in records]
|
||||||
|
else:
|
||||||
|
return [record.data() for record in records]
|
||||||
|
|
||||||
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||||
"""在写事务中执行操作
|
"""在写事务中执行操作
|
||||||
|
|||||||
@@ -462,11 +462,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||||
|
|
||||||
# 导入审计日志记录器
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
# Use a separate database session to avoid transaction failures
|
# Use a separate database session to avoid transaction failures
|
||||||
@@ -507,10 +502,13 @@ class MemoryAgentService:
|
|||||||
async with make_read_graph() as graph:
|
async with make_read_graph() as graph:
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# 初始状态 - 包含所有必要字段
|
# 初始状态 - 包含所有必要字段
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
initial_state = {
|
||||||
"end_user_id": end_user_id
|
"messages": [HumanMessage(content=message)],
|
||||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
"search_switch": search_switch,
|
||||||
"memory_config": memory_config}
|
"end_user_id": end_user_id
|
||||||
|
, "storage_type": storage_type,
|
||||||
|
"user_rag_memory_id": user_rag_memory_id,
|
||||||
|
"memory_config": memory_config}
|
||||||
# 获取节点更新信息
|
# 获取节点更新信息
|
||||||
_intermediate_outputs = []
|
_intermediate_outputs = []
|
||||||
summary = ''
|
summary = ''
|
||||||
@@ -522,7 +520,7 @@ class MemoryAgentService:
|
|||||||
for node_name, node_data in update_event.items():
|
for node_name, node_data in update_event.items():
|
||||||
# if 'save_neo4j' == node_name:
|
# if 'save_neo4j' == node_name:
|
||||||
# massages = node_data
|
# massages = node_data
|
||||||
print(f"处理节点: {node_name}")
|
logger.info(f"处理节点: {node_name}")
|
||||||
|
|
||||||
# 处理不同Summary节点的返回结构
|
# 处理不同Summary节点的返回结构
|
||||||
if 'Summary' in node_name:
|
if 'Summary' in node_name:
|
||||||
@@ -549,6 +547,11 @@ class MemoryAgentService:
|
|||||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||||
_intermediate_outputs.extend(retrieve_node)
|
_intermediate_outputs.extend(retrieve_node)
|
||||||
|
|
||||||
|
# Perceptual_Retrieve 节点
|
||||||
|
perceptual_node = node_data.get('perceptual_data', {}).get('_intermediate', None)
|
||||||
|
if perceptual_node and perceptual_node != [] and perceptual_node != {}:
|
||||||
|
_intermediate_outputs.append(perceptual_node)
|
||||||
|
|
||||||
# Verify 节点
|
# Verify 节点
|
||||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||||
if verify_n and verify_n != [] and verify_n != {}:
|
if verify_n and verify_n != [] and verify_n != {}:
|
||||||
|
|||||||
Reference in New Issue
Block a user