feat(memory): add perceptual memory retrieval service with BM25+embedding fusion

This commit is contained in:
Eternity
2026-04-01 17:19:03 +08:00
parent 75bb96d4e7
commit 9cbe9d5edc
13 changed files with 1042 additions and 409 deletions

View File

@@ -0,0 +1,408 @@
"""
Perceptual Memory Retrieval Node & Service
Provides PerceptualSearchService for searching perceptual memories (vision, audio,
text, conversation) from Neo4j using keyword fulltext + embedding semantic search
with BM25+embedding fusion reranking.
Also provides the perceptual_retrieve_node for use as a LangGraph node.
"""
import asyncio
import math
from typing import List, Dict, Any, Optional
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import (
search_perceptual,
search_perceptual_by_embedding,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = get_agent_logger(__name__)
class PerceptualSearchService:
"""
感知记忆检索服务。
封装关键词全文检索 + 向量语义检索 + BM25/embedding 融合排序的完整流程。
调用方只需提供 query / keywords、end_user_id、memory_config即可获得
格式化并排序后的感知记忆列表和拼接文本。
Usage:
service = PerceptualSearchService(end_user_id=..., memory_config=...)
results = await service.search(query="...", keywords=[...], limit=10)
# results = {"memories": [...], "content": "...", "keyword_raw": N, "embedding_raw": M}
"""
DEFAULT_ALPHA = 0.6
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
def __init__(
self,
end_user_id: str,
memory_config: Any,
alpha: float = DEFAULT_ALPHA,
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD,
):
self.end_user_id = end_user_id
self.memory_config = memory_config
self.alpha = alpha
self.content_score_threshold = content_score_threshold
async def search(
self,
query: str,
keywords: Optional[List[str]] = None,
limit: int = 10,
) -> Dict[str, Any]:
"""
执行感知记忆检索(关键词 + 向量并行),融合排序后返回结果。
对 embedding 命中但 keyword 未命中的结果,补查全文索引获取 BM25 分数,
确保所有结果都同时具备 BM25 和 embedding 两个维度的评分。
Args:
query: 原始用户查询(用于向量检索和 BM25 补查)
keywords: 关键词列表(用于全文检索),为 None 时使用 [query]
limit: 最大返回数量
Returns:
{
"memories": [格式化后的记忆 dict, ...],
"content": "拼接的纯文本摘要",
"keyword_raw": int,
"embedding_raw": int,
}
"""
if keywords is None:
keywords = [query] if query else []
connector = Neo4jConnector()
try:
kw_task = self._keyword_search(connector, keywords, limit)
emb_task = self._embedding_search(connector, query, limit)
kw_results, emb_results = await asyncio.gather(
kw_task, emb_task, return_exceptions=True
)
if isinstance(kw_results, Exception):
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
kw_results = []
if isinstance(emb_results, Exception):
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
emb_results = []
# 补查 BM25找出 embedding 命中但 keyword 未命中的 id
# 用原始 query 对这些节点补查全文索引拿 BM25 score
kw_ids = {r.get("id") for r in kw_results if r.get("id")}
emb_only_ids = {r.get("id") for r in emb_results if r.get("id") and r.get("id") not in kw_ids}
if emb_only_ids and query:
backfill = await self._bm25_backfill(connector, query, emb_only_ids, limit)
# 把补查到的 BM25 score 注入到 embedding 结果中
backfill_map = {r["id"]: r.get("score", 0) for r in backfill}
for r in emb_results:
rid = r.get("id", "")
if rid in backfill_map:
r["bm25_backfill_score"] = backfill_map[rid]
logger.info(
f"[PerceptualSearch] BM25 backfill: {len(emb_only_ids)} embedding-only ids, "
f"{len(backfill_map)} got BM25 scores"
)
reranked = self._rerank(kw_results, emb_results, limit)
memories = []
content_parts = []
for record in reranked:
fmt = self._format_result(record)
fmt["score"] = round(record.get("content_score", 0), 4)
memories.append(fmt)
content_parts.append(self._build_content_text(fmt))
logger.info(
f"[PerceptualSearch] {len(memories)} results after rerank "
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
)
return {
"memories": memories,
"content": "\n\n".join(content_parts),
"keyword_raw": len(kw_results),
"embedding_raw": len(emb_results),
}
finally:
await connector.close()
async def _bm25_backfill(
self,
connector: Neo4jConnector,
query: str,
target_ids: set,
limit: int,
) -> List[dict]:
"""
对指定 id 集合补查全文索引 BM25 score。
用原始 query 查全文索引,只保留 id 在 target_ids 中的结果。
"""
escaped = escape_lucene_query(query)
if not escaped.strip():
return []
try:
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率
)
all_hits = r.get("perceptuals", [])
return [h for h in all_hits if h.get("id") in target_ids]
except Exception as e:
logger.warning(f"[PerceptualSearch] BM25 backfill failed: {e}")
return []
async def _keyword_search(
self,
connector: Neo4jConnector,
keywords: List[str],
limit: int,
) -> List[dict]:
"""并发对每个关键词做全文检索,去重后按 score 降序返回 top N 原始结果。"""
seen_ids: set = set()
all_results: List[dict] = []
async def _one(kw: str):
escaped = escape_lucene_query(kw)
if not escaped.strip():
return []
r = await search_perceptual(
connector=connector, q=escaped,
end_user_id=self.end_user_id, limit=limit,
)
return r.get("perceptuals", [])
tasks = [_one(kw) for kw in keywords[:10]]
batch = await asyncio.gather(*tasks, return_exceptions=True)
for result in batch:
if isinstance(result, Exception):
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
continue
for rec in result:
rid = rec.get("id", "")
if rid and rid not in seen_ids:
seen_ids.add(rid)
all_results.append(rec)
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
return all_results[:limit]
async def _embedding_search(
self,
connector: Neo4jConnector,
query_text: str,
limit: int,
) -> List[dict]:
"""向量语义检索,返回原始结果(不做阈值过滤)。"""
try:
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
with get_db_context() as db:
cfg = MemoryConfigService(db).get_embedder_config(
str(self.memory_config.embedding_model_id)
)
client = OpenAIEmbedderClient(RedBearModelConfig(**cfg))
r = await search_perceptual_by_embedding(
connector=connector, embedder_client=client,
query_text=query_text, end_user_id=self.end_user_id,
limit=limit,
)
return r.get("perceptuals", [])
except Exception as e:
logger.warning(f"[PerceptualSearch] embedding search failed: {e}")
return []
def _rerank(
self,
keyword_results: List[dict],
embedding_results: List[dict],
limit: int,
) -> List[dict]:
"""BM25 + embedding 融合排序。
对 embedding 结果中带有 bm25_backfill_score 的条目,
将其与 keyword 结果合并后统一归一化,确保 BM25 分数在同一尺度上。
"""
# 把补查的 BM25 score 合并到 keyword_results 中统一归一化
emb_backfill_items = []
for item in embedding_results:
backfill_score = item.get("bm25_backfill_score")
if backfill_score is not None and item.get("id"):
emb_backfill_items.append({"id": item["id"], "score": backfill_score})
# 合并后统一归一化 BM25 scores
all_bm25_items = keyword_results + emb_backfill_items
all_bm25_items = self._normalize_scores(all_bm25_items)
# 建立 id -> normalized BM25 score 的映射
bm25_norm_map: Dict[str, float] = {}
for item in all_bm25_items:
item_id = item.get("id", "")
if item_id:
bm25_norm_map[item_id] = float(item.get("normalized_score", 0))
# 归一化 embedding scores
embedding_results = self._normalize_scores(embedding_results)
# 合并
combined: Dict[str, dict] = {}
for item in keyword_results:
item_id = item.get("id", "")
if not item_id:
continue
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = 0.0
for item in embedding_results:
item_id = item.get("id", "")
if not item_id:
continue
if item_id in combined:
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined[item_id] = item.copy()
combined[item_id]["bm25_score"] = bm25_norm_map.get(item_id, 0)
combined[item_id]["embedding_score"] = item.get("normalized_score", 0)
for item in combined.values():
bm25 = float(item.get("bm25_score", 0) or 0)
emb = float(item.get("embedding_score", 0) or 0)
item["content_score"] = self.alpha * bm25 + (1 - self.alpha) * emb
results = list(combined.values())
before = len(results)
results = [r for r in results if r["content_score"] >= self.content_score_threshold]
results.sort(key=lambda x: x["content_score"], reverse=True)
results = results[:limit]
logger.info(
f"[PerceptualSearch] rerank: merged={before}, after_threshold={len(results)} "
f"(alpha={self.alpha}, threshold={self.content_score_threshold})"
)
return results
@staticmethod
def _normalize_scores(items: List[dict], field: str = "score") -> List[dict]:
"""Z-score + sigmoid 归一化。"""
if not items:
return items
scores = [float(it.get(field, 0) or 0) for it in items]
if len(scores) <= 1:
for it in items:
it[f"normalized_{field}"] = 1.0
return items
mean = sum(scores) / len(scores)
var = sum((s - mean) ** 2 for s in scores) / len(scores)
std = math.sqrt(var)
if std == 0:
for it in items:
it[f"normalized_{field}"] = 1.0
else:
for it, s in zip(items, scores):
z = (s - mean) / std
it[f"normalized_{field}"] = 1 / (1 + math.exp(-z))
return items
@staticmethod
def _format_result(record: dict) -> dict:
return {
"id": record.get("id", ""),
"perceptual_type": record.get("perceptual_type", ""),
"file_name": record.get("file_name", ""),
"file_path": record.get("file_path", ""),
"summary": record.get("summary", ""),
"topic": record.get("topic", ""),
"domain": record.get("domain", ""),
"keywords": record.get("keywords", []),
"created_at": str(record.get("created_at", "")),
"file_type": record.get("file_type", ""),
"score": record.get("score", 0),
}
@staticmethod
def _build_content_text(formatted: dict) -> str:
parts = []
if formatted["summary"]:
parts.append(formatted["summary"])
if formatted["topic"]:
parts.append(f"[主题: {formatted['topic']}]")
if formatted["keywords"]:
kw_list = formatted["keywords"]
if isinstance(kw_list, list):
parts.append(f"[关键词: {', '.join(kw_list)}]")
if formatted["file_name"]:
parts.append(f"[文件: {formatted['file_name']}]")
return " ".join(parts)
def _extract_keywords_from_problems(problem_extension: dict) -> List[str]:
"""Extract search keywords from problem extension results."""
keywords = []
context = problem_extension.get("context", {})
if isinstance(context, dict):
for original_q, extended_qs in context.items():
keywords.append(original_q)
if isinstance(extended_qs, list):
keywords.extend(extended_qs)
return keywords
async def perceptual_retrieve_node(state: ReadState) -> ReadState:
"""
LangGraph node: perceptual memory retrieval.
Uses PerceptualSearchService to run keyword + embedding search with
BM25 fusion reranking, then writes results to state['perceptual_data'].
"""
end_user_id = state.get("end_user_id", "")
problem_extension = state.get("problem_extension", {})
original_query = state.get("data", "")
memory_config = state.get("memory_config", None)
logger.info(f"Perceptual_Retrieve: start, end_user_id={end_user_id}")
keywords = _extract_keywords_from_problems(problem_extension)
if not keywords:
keywords = [original_query] if original_query else []
logger.info(f"Perceptual_Retrieve: {len(keywords)} keywords extracted")
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
search_result = await service.search(
query=original_query,
keywords=keywords,
limit=10,
)
result = {
"memories": search_result["memories"],
"content": search_result["content"],
"_intermediate": {
"type": "perceptual_retrieve",
"title": "感知记忆检索",
"data": search_result["memories"],
"query": original_query,
"result_count": len(search_result["memories"]),
},
}
return {"perceptual_data": result}

View File

@@ -263,7 +263,6 @@ async def Problem_Extension(state: ReadState) -> ReadState:
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,

View File

@@ -1,7 +1,11 @@
import asyncio
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
PerceptualSearchService,
)
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
@@ -339,11 +343,45 @@ async def Input_Summary(state: ReadState) -> ReadState:
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
async def _perceptual_search():
service = PerceptualSearchService(
end_user_id=end_user_id,
memory_config=memory_config,
)
return await service.search(query=data, limit=5)
hybrid_task = SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
expand_communities=False,
)
perceptual_task = _perceptual_search()
gather_results = await asyncio.gather(
hybrid_task, perceptual_task, return_exceptions=True
)
hybrid_result = gather_results[0]
perceptual_results = gather_results[1]
# 处理 hybrid search 异常
if isinstance(hybrid_result, Exception):
raise hybrid_result
retrieve_info, question, raw_results = hybrid_result
# 处理感知记忆结果
if isinstance(perceptual_results, Exception):
logger.warning(f"[Input_Summary] perceptual search failed: {perceptual_results}")
perceptual_results = []
# 拼接感知记忆内容到 retrieve_info
if perceptual_results and isinstance(perceptual_results, dict):
perceptual_content = perceptual_results.get("content", "")
if perceptual_content:
retrieve_info = f"{retrieve_info}\n\n<history-files>\n{perceptual_content}"
count = len(perceptual_results.get("memories", []))
logger.info(f"[Input_Summary] appended {count} perceptual memories (reranked)")
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
@@ -371,10 +409,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
duration = end - start
log_time('检索', duration)
return {"summary": summary}
@@ -412,8 +447,20 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
aimessages = await summary_llm(
state,
history,
retrieve_info_str,
'direct_summary_prompt.jinja2',
'retrieve_summary', RetrieveSummaryResponse,
"1"
)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
@@ -458,6 +505,12 @@ async def Summary(state: ReadState) -> ReadState:
retrieve_info_str += i + '\n'
history = await summary_history(state)
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,
@@ -508,6 +561,13 @@ async def Summary_fails(state: ReadState) -> ReadState:
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
# Merge perceptual memory content
perceptual_data = state.get("perceptual_data", {})
perceptual_content = perceptual_data.get("content", "") if isinstance(perceptual_data, dict) else ""
if perceptual_content:
retrieve_info_str = f"{retrieve_info_str}\n\n<history-file-input>\n{perceptual_content}</history-file-input>"
data = {
"query": query,
"history": history,

View File

@@ -17,6 +17,9 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes,
)
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
Retrieve_Summary,
@@ -48,13 +51,14 @@ async def make_read_graph():
"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary)
workflow.add_node("Retrieve", retrieve_nodes)
# workflow.add_node("Retrieve", retrieve)
workflow.add_node("Perceptual_Retrieve", perceptual_retrieve_node)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
@@ -65,14 +69,15 @@ async def make_read_graph():
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
# After Problem_Extension, retrieve perceptual memory first, then main Retrieve
workflow.add_edge("Problem_Extension", "Perceptual_Retrieve")
workflow.add_edge("Perceptual_Retrieve", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# Compile workflow
@@ -80,7 +85,5 @@ async def make_read_graph():
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
logger.error(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")

View File

@@ -10,7 +10,6 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
@@ -31,10 +30,10 @@ def _clean_expand_fields(obj):
async def expand_communities_to_statements(
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
) -> Tuple[List[dict], List[str]]:
"""
社区展开 helper给定命中的 community 列表,拉取关联 Statement。
@@ -76,17 +75,18 @@ async def expand_communities_to_statements(
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
logger.info(
f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
@@ -107,19 +107,19 @@ class SearchService:
"""
if not isinstance(result, dict):
return str(result)
content_parts = []
# Statements: extract statement field
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
@@ -130,16 +130,16 @@ class SearchService:
elif 'content' in result and result['content']:
# Summaries / Chunks
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']:
# content_parts.append(result['name'])
# if result.get('fact_summary'):
# content_parts.append(result['fact_summary'])
# Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str:
"""
Clean and escape query text for Lucene.
@@ -155,33 +155,33 @@ class SearchService:
Cleaned and escaped query string
"""
q = str(query).strip()
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
# Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping
q = escape_lucene_query(q)
return q
async def execute_hybrid_search(
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config = None,
expand_communities: bool = True,
self,
end_user_id: str,
question: str,
limit: int = 5,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config=None,
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
@@ -205,10 +205,10 @@ class SearchService:
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries", "communities"]
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
answer = await run_hybrid_search(
@@ -221,18 +221,18 @@ class SearchService:
memory_config=memory_config,
rerank_alpha=rerank_alpha
)
# Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information
answer_list = []
# For hybrid search, use reranked_results
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
category_results = reranked_results[category]
@@ -242,7 +242,7 @@ class SearchService:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
@@ -261,7 +261,7 @@ class SearchService:
end_user_id=end_user_id,
)
answer_list.extend(cleaned_stmts)
# Extract clean content from all results按类型传入 node_type 区分 community
content_list = []
for ans in answer_list:
@@ -269,19 +269,18 @@ class SearchService:
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested
if return_raw_results:
return clean_content, cleaned_query, answer
else:
return clean_content, cleaned_query, None
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{end_user_id}': {e}",

View File

@@ -1,4 +1,3 @@
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
@@ -52,6 +51,7 @@ class ReadState(TypedDict):
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
retrieve: dict
perceptual_data: dict
RetrieveSummary: dict
InputSummary: dict
verify: dict

View File

@@ -43,6 +43,7 @@ load_dotenv()
logger = get_memory_logger(__name__)
def _parse_datetime(value: Any) -> Optional[datetime]:
"""Parse ISO `created_at` strings of the form 'YYYY-MM-DDTHH:MM:SS.ssssss'."""
if value is None:
@@ -75,7 +76,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if score_field == "activation_value" and score is None:
scores.append(None) # 保持 None稍后特殊处理
continue
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
@@ -83,10 +84,10 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores:
return results
# 过滤掉 None 值,只对有效分数进行归一化
valid_scores = [s for s in scores if s is not None]
if not valid_scores:
# 所有分数都是 None不进行归一化
for item in results:
@@ -94,7 +95,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
item[f"normalized_{score_field}"] = None
return results
if len(valid_scores) == 1: # Single valid score, set to 1.0
if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
@@ -132,7 +133,6 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Remove duplicate items from search results based on content.
@@ -150,52 +150,53 @@ def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
seen_ids = set()
seen_content = set()
deduplicated = []
for item in items:
# Try multiple ID fields to identify unique items
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# Extract content from various possible fields
content = (
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
item.get("text") or
item.get("content") or
item.get("statement") or
item.get("name") or
""
)
# Normalize content for comparison (strip whitespace and lowercase)
normalized_content = str(content).strip().lower() if content else ""
# Check if we've seen this ID or content before
is_duplicate = False
if item_id and item_id in seen_ids:
is_duplicate = True
elif normalized_content and normalized_content in seen_content:
# Only check content duplication if content is not empty
is_duplicate = True
if not is_duplicate:
# Mark as seen
if item_id:
seen_ids.add(item_id)
if normalized_content: # Only track non-empty content
seen_content.add(normalized_content)
deduplicated.append(item)
return deduplicated
def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
content_score_threshold: float = 0.5,
) -> Dict[str, List[Dict[str, Any]]]:
"""
两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -222,6 +223,8 @@ def rerank_with_activation(
forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算)
content_score_threshold: 内容相关性最低阈值(基于归一化后的 content_score
低于此阈值的结果会被过滤。默认 0.5。
返回:
带评分元数据的重排序结果,按 final_score 排序
@@ -229,26 +232,26 @@ def rerank_with_activation(
# 验证权重范围
if not (0 <= alpha <= 1):
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
# 初始化遗忘引擎(如果需要)
engine = None
if forgetting_config:
engine = ForgettingEngine(forgetting_config)
now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# 步骤 1: 归一化分数
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# 步骤 2: 按 ID 合并结果(去重)
combined_items: Dict[str, Dict[str, Any]] = {}
# 添加关键词结果
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -257,7 +260,7 @@ def rerank_with_activation(
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # 默认值
# 添加或更新向量嵌入结果
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
@@ -271,18 +274,18 @@ def rerank_with_activation(
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # 默认值
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# 步骤 3: 归一化激活度分数
# 为所有项准备激活度值列表
items_list = list(combined_items.values())
items_list = normalize_scores(items_list, "activation_value")
# 更新 combined_items 中的归一化激活度分数
for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
# 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0)
@@ -290,45 +293,45 @@ def rerank_with_activation(
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序)
item["activation_score"] = act_norm # 可能为 None
item["content_score"] = content_score
item["base_score"] = base_score
# 步骤 5: 应用遗忘曲线(可选)
if engine:
# 计算受激活度影响的记忆强度
importance = float(item.get("importance_score", 0.5) or 0.5)
# 获取 activation_value
activation_val = item.get("activation_value")
# 只对有激活值的节点应用遗忘曲线
if activation_val is not None and isinstance(activation_val, (int, float)):
activation_val = float(activation_val)
# 计算记忆强度importance_score × (1 + activation_value × boost_factor)
memory_strength = importance * (1 + activation_val * activation_boost_factor)
# 计算经过的时间(天数)
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# 获取遗忘权重
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# 应用到基础分数
item["forgetting_weight"] = forgetting_weight
item["final_score"] = base_score * forgetting_weight
@@ -338,7 +341,7 @@ def rerank_with_activation(
else:
# 不使用遗忘曲线
item["final_score"] = base_score
# 步骤 6: 两阶段排序和限制
# 第一阶段按内容相关性base_score排序取 Top-K
first_stage_limit = limit * 3 # 可配置取3倍候选
@@ -347,11 +350,11 @@ def rerank_with_activation(
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
reverse=True
)[:first_stage_limit]
# 第二阶段:分离有激活值和无激活值的节点
items_with_activation = []
items_without_activation = []
for item in first_stage_sorted:
activation_score = item.get("activation_score")
# 检查是否有有效的激活值(不是 None
@@ -359,14 +362,14 @@ def rerank_with_activation(
items_with_activation.append(item)
else:
items_without_activation.append(item)
# 优先按激活值排序有激活值的节点
sorted_with_activation = sorted(
items_with_activation,
key=lambda x: float(x.get("activation_score", 0) or 0),
reverse=True
)
# 如果有激活值的节点不足 limit用无激活值的节点补充
if len(sorted_with_activation) < limit:
needed = limit - len(sorted_with_activation)
@@ -374,7 +377,7 @@ def rerank_with_activation(
sorted_items = sorted_with_activation + items_without_activation[:needed]
else:
sorted_items = sorted_with_activation[:limit]
# 两阶段排序完成,更新 final_score 以反映实际排序依据
# Stage 1: 按 content_score 筛选候选(已完成)
# Stage 2: 按 activation_score 排序(已完成)
@@ -390,16 +393,29 @@ def rerank_with_activation(
else:
# 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0)
# 最终去重确保没有重复项
if content_score_threshold > 0:
before_count = len(sorted_items)
sorted_items = [
item for item in sorted_items
if float(item.get("content_score", 0) or 0) >= content_score_threshold
]
filtered_count = before_count - len(sorted_items)
if filtered_count > 0:
logger.info(
f"[rerank] {category}: filtered {filtered_count}/{before_count} "
f"items below content_score_threshold={content_score_threshold}"
)
sorted_items = _deduplicate_results(sorted_items)
reranked[category] = sorted_items
return reranked
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None):
def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str],
log_file: str = None):
"""Log search query information using the logger.
Args:
@@ -412,7 +428,7 @@ def log_search_query(query_text: str, search_type: str, end_user_id: str | None,
"""
# Ensure the query text is plain and clean before logging
cleaned_query = extract_plain_query(query_text)
# Log using the standard logger
logger.info(
f"Search query: query='{cleaned_query}', type={search_type}, "
@@ -439,8 +455,8 @@ def _remove_keys_recursive(obj: Any, keys_to_remove: List[str]) -> Any:
def apply_reranker_placeholder(
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
results: Dict[str, List[Dict[str, Any]]],
query_text: str,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Placeholder for a cross-encoder reranker.
@@ -483,7 +499,7 @@ def apply_reranker_placeholder(
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Apply LLM-based reranking to search results.
# Args:
# results: Search results organized by category
# query_text: Original search query
@@ -491,7 +507,7 @@ def apply_reranker_placeholder(
# llm_weight: Weight for LLM score (0.0-1.0, higher favors LLM)
# top_k: Maximum number of items to rerank per category
# batch_size: Number of items to process concurrently
# Returns:
# Reranked results with final_score and reranker_model fields
# """
@@ -501,18 +517,18 @@ def apply_reranker_placeholder(
# # except Exception as e:
# # logger.debug(f"Failed to load reranker config: {e}")
# # rc = {}
# # Check if reranking is enabled
# enabled = rc.get("enabled", False)
# if not enabled:
# logger.debug("LLM reranking is disabled in configuration")
# return results
# # Load configuration parameters with defaults
# llm_weight = llm_weight if llm_weight is not None else rc.get("llm_weight", 0.5)
# top_k = top_k if top_k is not None else rc.get("top_k", 20)
# batch_size = batch_size if batch_size is not None else rc.get("batch_size", 5)
# # Initialize reranker client if not provided
# if reranker_client is None:
# try:
@@ -520,10 +536,10 @@ def apply_reranker_placeholder(
# except Exception as e:
# logger.warning(f"Failed to initialize reranker client: {e}, skipping LLM reranking")
# return results
# # Get model name for metadata
# model_name = getattr(reranker_client, 'model_name', 'unknown')
# # Process each category
# reranked_results = {}
# for category in ["statements", "chunks", "entities", "summaries"]:
@@ -531,38 +547,38 @@ def apply_reranker_placeholder(
# if not items:
# reranked_results[category] = []
# continue
# # Select top K items by combined_score for reranking
# sorted_items = sorted(
# items,
# key=lambda x: float(x.get("combined_score", x.get("score", 0.0)) or 0.0),
# reverse=True
# )
# top_items = sorted_items[:top_k]
# remaining_items = sorted_items[top_k:]
# # Extract text content from each item
# def extract_text(item: Dict[str, Any]) -> str:
# """Extract text content from a result item."""
# # Try different text fields based on category
# text = item.get("text") or item.get("content") or item.get("statement") or item.get("name") or ""
# return str(text).strip()
# # Batch items for concurrent processing
# batches = []
# for i in range(0, len(top_items), batch_size):
# batch = top_items[i:i + batch_size]
# batches.append(batch)
# # Process batches concurrently
# async def process_batch(batch: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# """Process a batch of items with LLM relevance scoring."""
# scored_batch = []
# for item in batch:
# item_text = extract_text(item)
# # Skip items with no text
# if not item_text:
# item_copy = item.copy()
@@ -572,7 +588,7 @@ def apply_reranker_placeholder(
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# continue
# # Create relevance scoring prompt
# prompt = f"""Given the search query and a result item, rate the relevance of the item to the query on a scale from 0.0 to 1.0.
@@ -585,15 +601,15 @@ def apply_reranker_placeholder(
# - 1.0 means perfectly relevant
# Relevance score:"""
# # Send request to LLM
# try:
# messages = [{"role": "user", "content": prompt}]
# response = await reranker_client.chat(messages)
# # Parse LLM response to extract relevance score
# response_text = str(response.content if hasattr(response, 'content') else response).strip()
# # Try to extract a float from the response
# try:
# # Remove any non-numeric characters except decimal point
@@ -608,11 +624,11 @@ def apply_reranker_placeholder(
# except (ValueError, AttributeError) as e:
# logger.warning(f"Invalid LLM score format: {response_text}, using combined_score. Error: {e}")
# llm_score = None
# # Calculate final score
# item_copy = item.copy()
# combined_score = float(item.get("combined_score", item.get("score", 0.0)) or 0.0)
# if llm_score is not None:
# final_score = (1 - llm_weight) * combined_score + llm_weight * llm_score
# item_copy["llm_relevance_score"] = llm_score
@@ -620,7 +636,7 @@ def apply_reranker_placeholder(
# # Use combined_score as fallback
# final_score = combined_score
# item_copy["llm_relevance_score"] = combined_score
# item_copy["final_score"] = final_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
@@ -632,14 +648,14 @@ def apply_reranker_placeholder(
# item_copy["llm_relevance_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_batch.append(item_copy)
# return scored_batch
# # Process all batches concurrently
# try:
# batch_tasks = [process_batch(batch) for batch in batches]
# batch_results = await asyncio.gather(*batch_tasks, return_exceptions=True)
# # Merge batch results
# scored_items = []
# for result in batch_results:
@@ -647,7 +663,7 @@ def apply_reranker_placeholder(
# logger.warning(f"Batch processing failed: {result}")
# continue
# scored_items.extend(result)
# # Add remaining items (not in top K) with their combined_score as final_score
# for item in remaining_items:
# item_copy = item.copy()
@@ -655,11 +671,11 @@ def apply_reranker_placeholder(
# item_copy["final_score"] = combined_score
# item_copy["reranker_model"] = model_name
# scored_items.append(item_copy)
# # Sort all items by final_score in descending order
# scored_items.sort(key=lambda x: float(x.get("final_score", 0.0) or 0.0), reverse=True)
# reranked_results[category] = scored_items
# except Exception as e:
# logger.error(f"Error in LLM reranking for category {category}: {e}, returning original results")
# # Return original items with combined_score as final_score
@@ -668,22 +684,22 @@ def apply_reranker_placeholder(
# item["final_score"] = combined_score
# item["reranker_model"] = model_name
# reranked_results[category] = items
# return reranked_results
async def run_hybrid_search(
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
query_text: str,
search_type: str,
end_user_id: str | None,
limit: int,
include: List[str],
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
"""
@@ -699,7 +715,7 @@ async def run_hybrid_search(
# Clean and normalize the incoming query before use/logging
query_text = extract_plain_query(query_text)
# Validate query is not empty after cleaning
if not query_text or not query_text.strip():
logger.warning("Empty query after cleaning, returning empty results")
@@ -716,7 +732,7 @@ async def run_hybrid_search(
"error": "Empty query"
}
}
# Log the search query
log_search_query(query_text, search_type, end_user_id, limit, include)
@@ -747,7 +763,7 @@ async def run_hybrid_search(
# Embedding-based search
logger.info("[PERF] Starting embedding search...")
embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time()
try:
@@ -768,7 +784,7 @@ async def run_hybrid_search(
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
connector=connector,
@@ -788,7 +804,7 @@ async def run_hybrid_search(
if keyword_task:
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
keyword_latency = time.time() - search_start_time
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
@@ -798,7 +814,7 @@ async def run_hybrid_search(
if embedding_task:
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
embedding_latency = time.time() - search_start_time
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
@@ -810,7 +826,8 @@ async def run_hybrid_search(
if search_type == "hybrid":
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat()
}
@@ -818,7 +835,7 @@ async def run_hybrid_search(
# Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time()
logger.info("[PERF] Using two-stage reranking with ACTR activation")
# 加载遗忘引擎配置
config_start = time.time()
try:
@@ -829,7 +846,7 @@ async def run_hybrid_search(
forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# 统一使用激活度重排序(两阶段:检索 + ACTR计算
rerank_compute_start = time.time()
reranked_results = rerank_with_activation(
@@ -842,14 +859,14 @@ async def run_hybrid_search(
)
rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
# Apply LLM reranking if enabled
llm_rerank_applied = False
# if use_llm_rerank:
@@ -862,11 +879,12 @@ async def run_hybrid_search(
# logger.info("LLM reranking applied successfully")
# except Exception as e:
# logger.warning(f"LLM reranking failed: {e}, using previous scores")
results["reranked_results"] = reranked_results
results["combined_summary"] = {
"total_keyword_results": sum(len(v) if isinstance(v, list) else 0 for v in keyword_results.values()),
"total_embedding_results": sum(len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_embedding_results": sum(
len(v) if isinstance(v, list) else 0 for v in embedding_results.values()),
"total_reranked_results": sum(len(v) if isinstance(v, list) else 0 for v in reranked_results.values()),
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
@@ -879,13 +897,13 @@ async def run_hybrid_search(
# Calculate total latency
total_latency = time.time() - search_start_time
latency_metrics["total_latency"] = round(total_latency, 4)
# Add latency metrics to results
if "combined_summary" in results:
results["combined_summary"]["latency_metrics"] = latency_metrics
else:
results["latency_metrics"] = latency_metrics
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
@@ -908,8 +926,10 @@ async def run_hybrid_search(
# Log search completion with result count
if search_type == "hybrid":
result_counts = {
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in embedding_results.items()}
"keyword": {key: len(value) if isinstance(value, list) else 0 for key, value in
keyword_results.items()},
"embedding": {key: len(value) if isinstance(value, list) else 0 for key, value in
embedding_results.items()}
}
else:
result_counts = {key: len(value) if isinstance(value, list) else 0 for key, value in results.items()}
@@ -927,12 +947,12 @@ async def run_hybrid_search(
async def search_by_temporal(
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal search across Statements.
@@ -968,13 +988,13 @@ async def search_by_temporal(
async def search_by_keyword_temporal(
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
query_text: str,
end_user_id: Optional[str] = "test",
start_date: Optional[str] = None,
end_date: Optional[str] = None,
valid_date: Optional[str] = None,
invalid_date: Optional[str] = None,
limit: int = 1,
):
"""
Temporal keyword search across Statements.
@@ -1011,9 +1031,9 @@ async def search_by_keyword_temporal(
async def search_chunk_by_chunk_id(
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
chunk_id: str,
end_user_id: Optional[str] = "test",
limit: int = 1,
):
"""
Search for Chunks by chunk_id.
@@ -1026,4 +1046,3 @@ async def search_chunk_by_chunk_id(
limit=limit
)
return {"chunks": chunks}