refactor(memory): restructure memory search architecture
- Replace storage_services/search with new read_services/memory_search structure - Implement content_search and perceptual_search strategies - Add query_preprocessor for search optimization - Create memory_service as unified interface - Update celery_app and graph_search for new architecture - Add enums for memory operations - Implement base_pipeline and memory_read pipeline patterns
This commit is contained in:
@@ -101,7 +101,6 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
|
||||
18
api/app/core/memory/enums.py
Normal file
18
api/app/core/memory/enums.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class StorageType(StrEnum):
|
||||
NEO4J = 'neo4j'
|
||||
RAG = 'rag'
|
||||
|
||||
|
||||
class Neo4jStorageStrategy(StrEnum):
|
||||
WINDOW = 'window'
|
||||
TIMELINE = 'timeline'
|
||||
AGGREGATE = "aggregate"
|
||||
|
||||
|
||||
class SearchStrategy(StrEnum):
|
||||
DEEP = "0"
|
||||
NORMAL = "1"
|
||||
QUICK = "2"
|
||||
57
api/app/core/memory/memory_service.py
Normal file
57
api/app/core/memory/memory_service.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from app.core.memory.enums import StorageType
|
||||
from app.schemas import MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
class MemoryContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||
|
||||
end_user_id: str
|
||||
memory_config: MemoryConfig
|
||||
storage_type: StorageType = StorageType.NEO4J
|
||||
user_rag_memory_id: str | None = None
|
||||
language: str = "zh"
|
||||
|
||||
|
||||
class MemoryService:
|
||||
def __init__(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: str,
|
||||
end_user_id: str,
|
||||
workspace_id: str | None = None,
|
||||
storage_type: str = "neo4j",
|
||||
user_rag_memory_id: str | None = None,
|
||||
language: str = "zh",
|
||||
):
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id,
|
||||
service_name="MemoryService",
|
||||
)
|
||||
self.ctx = MemoryContext(
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
storage_type=StorageType(storage_type),
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
language=language,
|
||||
)
|
||||
|
||||
async def write(self, messages: list[dict]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
async def read(self, query: str, history: list, search_switch: str) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
async def reflect(self) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
async def cluster(self, new_entity_ids: list[str] = None) -> None:
|
||||
raise NotImplementedError
|
||||
0
api/app/core/memory/pipelines/__init__.py
Normal file
0
api/app/core/memory/pipelines/__init__.py
Normal file
29
api/app/core/memory/pipelines/base_pipeline.py
Normal file
29
api/app/core/memory/pipelines/base_pipeline.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/4/3 11:44
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from demo.memory_alpha import MemoryContext
|
||||
|
||||
|
||||
class ModelClientMixin(ABC):
|
||||
def get_llm_client(self, db: Session, model_id: uuid.UUID):
|
||||
pass
|
||||
|
||||
def get_embedding_client(self, db: Session, model_id: uuid.UUID):
|
||||
pass
|
||||
|
||||
|
||||
class BasePipeline(ABC):
|
||||
def __init__(self, ctx: MemoryContext, db: Session):
|
||||
self.ctx = ctx
|
||||
self.db = db
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, *args, **kwargs) -> Any:
|
||||
pass
|
||||
26
api/app/core/memory/pipelines/memory_read.py
Normal file
26
api/app/core/memory/pipelines/memory_read.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from app.core.memory.enums import SearchStrategy
|
||||
from app.core.memory.pipelines.base_pipeline import BasePipeline
|
||||
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
||||
|
||||
|
||||
class ReadPipeLine(BasePipeline):
|
||||
async def run(self, query, search_switch, memory_config):
|
||||
query = QueryPreprocessor.process(query)
|
||||
match search_switch:
|
||||
case SearchStrategy.DEEP:
|
||||
return await self._deep_read()
|
||||
case SearchStrategy.NORMAL:
|
||||
return await self._normal_read(query)
|
||||
case SearchStrategy.QUICK:
|
||||
return await self._quick_read()
|
||||
case _:
|
||||
raise RuntimeError("Unsupported search strategy")
|
||||
|
||||
async def _deep_read(self):
|
||||
pass
|
||||
|
||||
async def _normal_read(self, query):
|
||||
pass
|
||||
|
||||
async def _quick_read(self):
|
||||
pass
|
||||
@@ -0,0 +1,14 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/4/9 16:48
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.memory_service import MemoryContext
|
||||
|
||||
|
||||
class ContentSearch:
|
||||
def __init__(self, ctx: MemoryContext):
|
||||
self.ctx = ctx
|
||||
|
||||
async def search(self, query):
|
||||
pass
|
||||
@@ -0,0 +1,228 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.core.memory.memory_service import MemoryContext
|
||||
from app.core.memory.utils.data import escape_lucene_query
|
||||
from app.repositories.neo4j.graph_search import search_perceptual, search_perceptual_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerceptualResult(BaseModel):
|
||||
memories: list[dict[str, Any]] = []
|
||||
content: str = ""
|
||||
keyword_raw: int = 0
|
||||
embedding_raw: int = 0
|
||||
|
||||
|
||||
class PerceptualRetrieverService:
|
||||
DEFAULT_ALPHA = 0.6
|
||||
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 0.5
|
||||
DEFAULT_COSINE_SCORE_THRESHOLD = 0.7
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ctx: MemoryContext,
|
||||
embedder: OpenAIEmbedderClient,
|
||||
alpha: float = DEFAULT_ALPHA,
|
||||
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
|
||||
cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD
|
||||
):
|
||||
self.ctx = ctx
|
||||
self.alpha = alpha
|
||||
self.fulltext_score_threshold = fulltext_score_threshold
|
||||
self.cosine_score_threshold = cosine_score_threshold
|
||||
|
||||
self.embedder: OpenAIEmbedderClient = embedder
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query: str,
|
||||
keywords: list[str] | None = None,
|
||||
limit: int = 10
|
||||
) -> PerceptualResult:
|
||||
if keywords is None:
|
||||
keywords = [query] if query else []
|
||||
|
||||
try:
|
||||
kw_task = self._keyword_search(keywords, limit)
|
||||
emb_task = self._embedding_search(query, limit)
|
||||
kw_results, emb_results = await asyncio.gather(kw_task, emb_task, return_exceptions=True)
|
||||
if isinstance(kw_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword search error: {kw_results}")
|
||||
kw_results = []
|
||||
if isinstance(emb_results, Exception):
|
||||
logger.warning(f"[PerceptualSearch] embedding search error: {emb_results}")
|
||||
emb_results = []
|
||||
|
||||
reranked = self._rerank(kw_results, emb_results, limit)
|
||||
|
||||
memories = []
|
||||
content_parts = []
|
||||
for record in reranked:
|
||||
fmt = self._format_result(record)
|
||||
fmt["score"] = round(record.get("content_score", 0), 4)
|
||||
memories.append(fmt)
|
||||
content_parts.append(self._build_content_text(fmt))
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] {len(memories)} results after rerank "
|
||||
f"(keyword_raw={len(kw_results)}, embedding_raw={len(emb_results)})"
|
||||
)
|
||||
return PerceptualResult(
|
||||
memories=memories,
|
||||
content="\n\n".join(content_parts),
|
||||
keyword_raw=len(kw_results),
|
||||
embedding_raw=len(emb_results),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[PerceptualSearch] search failed: {e}", exc_info=True)
|
||||
return PerceptualResult()
|
||||
finally:
|
||||
await self.connector.close()
|
||||
|
||||
async def _keyword_search(
|
||||
self,
|
||||
keywords: list[str],
|
||||
limit: int
|
||||
) -> list[dict]:
|
||||
seen_ids: set = set()
|
||||
all_results: list[dict] = []
|
||||
|
||||
async def _one(kw: str):
|
||||
escaped = escape_lucene_query(kw)
|
||||
if not escaped.strip():
|
||||
return []
|
||||
r = await search_perceptual(
|
||||
connector=self.connector, q=escaped,
|
||||
end_user_id=self.ctx.end_user_id, limit=limit
|
||||
)
|
||||
perceptuals = r.get("perceptuals", [])
|
||||
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.fulltext_score_threshold]
|
||||
|
||||
tasks = [_one(kw) for kw in keywords]
|
||||
batch = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for result in batch:
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"[PerceptualSearch] keyword sub-query error: {result}")
|
||||
continue
|
||||
for rec in result:
|
||||
rid = rec.get("id", "")
|
||||
if rid and rid not in seen_ids:
|
||||
seen_ids.add(rid)
|
||||
all_results.append(rec)
|
||||
all_results.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
|
||||
return all_results[:limit]
|
||||
|
||||
async def _embedding_search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int
|
||||
) -> list[dict]:
|
||||
r = await search_perceptual_by_embedding(
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder,
|
||||
query_text=query,
|
||||
end_user_id=self.ctx.end_user_id,
|
||||
limit=limit
|
||||
)
|
||||
perceptuals = r.get("perceptuals", [])
|
||||
return [perceptual for perceptual in perceptuals if perceptual["score"] > self.cosine_score_threshold]
|
||||
|
||||
def _rerank(
|
||||
self,
|
||||
keyword_results: list[dict],
|
||||
embedding_results: list[dict],
|
||||
limit: int,
|
||||
) -> list[dict]:
|
||||
keyword_results = self._normalize_scores(keyword_results)
|
||||
embedding_results = self._normalize_scores(embedding_results)
|
||||
|
||||
kw_norm_map = {}
|
||||
for item in keyword_results:
|
||||
item_id = item["id"]
|
||||
kw_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
emb_norm_map = {}
|
||||
for item in embedding_results:
|
||||
item_id = item["id"]
|
||||
emb_norm_map[item_id] = float(item.get("normalized_score", 0))
|
||||
|
||||
combined = {}
|
||||
for item in keyword_results:
|
||||
item_id = item["id"]
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
|
||||
for item in embedding_results:
|
||||
item_id = item["id"]
|
||||
if item_id in combined:
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
else:
|
||||
combined[item_id] = item.copy()
|
||||
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||
|
||||
for item in combined.values():
|
||||
kw = float(item.get("kw_score", 0) or 0)
|
||||
emb = float(item.get("embedding_score", 0) or 0)
|
||||
item["content_score"] = self.alpha * emb + (1 - self.alpha) * kw
|
||||
|
||||
results = list(combined.values())
|
||||
results.sort(key=lambda x: x["content_score"], reverse=True)
|
||||
results = results[:limit]
|
||||
|
||||
logger.info(
|
||||
f"[PerceptualSearch] rerank: merged={len(combined)}, after_threshold={len(results)} "
|
||||
f"(alpha={self.alpha})"
|
||||
)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scores(items: list[dict], field: str = "score") -> list[dict]:
|
||||
"""Min-max 归一化,将分数线性映射到 [0, 1]。"""
|
||||
if not items:
|
||||
return items
|
||||
scores = [float(it.get(field, 0) or 0) for it in items]
|
||||
min_s = min(scores)
|
||||
max_s = max(scores)
|
||||
diff = max_s - min_s
|
||||
for it, s in zip(items, scores):
|
||||
it[f"normalized_{field}"] = (s - min_s) / diff if diff > 0 else 1.0
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def _format_result(record: dict) -> dict:
|
||||
return {
|
||||
"id": record.get("id", ""),
|
||||
"perceptual_type": record.get("perceptual_type", ""),
|
||||
"file_name": record.get("file_name", ""),
|
||||
"file_path": record.get("file_path", ""),
|
||||
"summary": record.get("summary", ""),
|
||||
"topic": record.get("topic", ""),
|
||||
"domain": record.get("domain", ""),
|
||||
"keywords": record.get("keywords", []),
|
||||
"created_at": str(record.get("created_at", "")),
|
||||
"file_type": record.get("file_type", ""),
|
||||
"score": record.get("score", 0),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_content_text(formatted: dict) -> str:
|
||||
content_text = (f"<history-file-info>"
|
||||
f"<file-name>{formatted["file_name"]}</file-name>"
|
||||
f"<file-path>{formatted["file_path"]}</file-path>"
|
||||
f"<file-type>{formatted["file_type"]}</file-type>"
|
||||
f"<file-topic>{formatted["topic"]}</file-topic>"
|
||||
f"<file-domain>{formatted["keywords"]}</file-domain>"
|
||||
f"<file-summary>{formatted["summary"]}</file-summary>"
|
||||
f"</history-file-info>")
|
||||
return content_text
|
||||
18
api/app/core/memory/read_services/query_preprocessor.py
Normal file
18
api/app/core/memory/read_services/query_preprocessor.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/4/8 18:11
|
||||
import re
|
||||
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
|
||||
class QueryPreprocessor:
|
||||
@staticmethod
|
||||
def process(query: str) -> str:
|
||||
text = query.strip()
|
||||
if not text:
|
||||
return text
|
||||
|
||||
text = re.sub(rf"{"|".join(AgentMemoryDataset.PRONOUN)}", AgentMemoryDataset.NAME, text)
|
||||
return text
|
||||
@@ -1,143 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""搜索服务模块
|
||||
|
||||
本模块提供统一的搜索服务接口,支持关键词搜索、语义搜索和混合搜索。
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
from app.core.memory.storage_services.search.hybrid_search import HybridSearchStrategy
|
||||
from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.storage_services.search.semantic_search import (
|
||||
SemanticSearchStrategy,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SearchStrategy",
|
||||
"SearchResult",
|
||||
"KeywordSearchStrategy",
|
||||
"SemanticSearchStrategy",
|
||||
"HybridSearchStrategy",
|
||||
]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 向后兼容的函数式API
|
||||
# ============================================================================
|
||||
# 为了兼容旧代码,提供与 src/search.py 相同的函数式接口
|
||||
|
||||
|
||||
async def run_hybrid_search(
|
||||
query_text: str,
|
||||
search_type: str = "hybrid",
|
||||
end_user_id: str | None = None,
|
||||
apply_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
include: list[str] | None = None,
|
||||
alpha: float = 0.6,
|
||||
use_forgetting_curve: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
"""运行混合搜索(向后兼容的函数式API)
|
||||
|
||||
这是一个向后兼容的包装函数,将旧的函数式API转换为新的基于类的API。
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型("hybrid", "keyword", "semantic")
|
||||
end_user_id: 组ID过滤
|
||||
apply_id: 应用ID过滤
|
||||
user_id: 用户ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
alpha: BM25分数权重(0.0-1.0)
|
||||
use_forgetting_curve: 是否使用遗忘曲线
|
||||
memory_config: MemoryConfig object containing embedding_model_id
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
dict: 搜索结果字典,格式与旧API兼容
|
||||
"""
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
if not memory_config:
|
||||
raise ValueError("memory_config is required for search")
|
||||
|
||||
# 初始化客户端
|
||||
connector = Neo4jConnector()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
|
||||
try:
|
||||
# 根据搜索类型选择策略
|
||||
if search_type == "keyword":
|
||||
strategy = KeywordSearchStrategy(connector=connector)
|
||||
elif search_type == "semantic":
|
||||
strategy = SemanticSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
else: # hybrid
|
||||
strategy = HybridSearchStrategy(
|
||||
connector=connector,
|
||||
embedder_client=embedder_client,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve
|
||||
)
|
||||
|
||||
# 执行搜索
|
||||
result = await strategy.search(
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
alpha=alpha,
|
||||
use_forgetting_curve=use_forgetting_curve,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
# 转换为旧格式
|
||||
result_dict = result.to_dict()
|
||||
|
||||
# 保存到文件(如果指定了output_path)
|
||||
output_path = kwargs.get('output_path', 'search_results.json')
|
||||
if output_path:
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
# 确保目录存在
|
||||
out_dir = os.path.dirname(output_path)
|
||||
if out_dir:
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
# 保存结果
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result_dict, f, ensure_ascii=False, indent=2, default=str)
|
||||
print(f"Search results saved to {output_path}")
|
||||
except Exception as e:
|
||||
print(f"Error saving search results: {e}")
|
||||
return result_dict
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
__all__.append("run_hybrid_search")
|
||||
@@ -1,408 +0,0 @@
|
||||
# # -*- coding: utf-8 -*-
|
||||
# """混合搜索策略
|
||||
|
||||
# 结合关键词搜索和语义搜索的混合检索方法。
|
||||
# 支持结果重排序和遗忘曲线加权。
|
||||
# """
|
||||
|
||||
# from typing import List, Dict, Any, Optional
|
||||
# import math
|
||||
# from datetime import datetime
|
||||
# from app.core.logging_config import get_memory_logger
|
||||
# from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
# from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
# from app.core.memory.storage_services.search.keyword_search import KeywordSearchStrategy
|
||||
# from app.core.memory.storage_services.search.semantic_search import SemanticSearchStrategy
|
||||
# from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
# from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
|
||||
# logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
# class HybridSearchStrategy(SearchStrategy):
|
||||
# """混合搜索策略
|
||||
|
||||
# 结合关键词搜索和语义搜索的优势:
|
||||
# - 关键词搜索:精确匹配,适合已知术语
|
||||
# - 语义搜索:语义理解,适合概念查询
|
||||
# - 混合重排序:综合两种搜索的结果
|
||||
# - 遗忘曲线:根据时间衰减调整相关性
|
||||
# """
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# connector: Optional[Neo4jConnector] = None,
|
||||
# embedder_client: Optional[OpenAIEmbedderClient] = None,
|
||||
# alpha: float = 0.6,
|
||||
# use_forgetting_curve: bool = False,
|
||||
# forgetting_config: Optional[ForgettingEngineConfig] = None
|
||||
# ):
|
||||
# """初始化混合搜索策略
|
||||
|
||||
# Args:
|
||||
# connector: Neo4j连接器
|
||||
# embedder_client: 嵌入模型客户端
|
||||
# alpha: BM25分数权重(0.0-1.0),1-alpha为嵌入分数权重
|
||||
# use_forgetting_curve: 是否使用遗忘曲线
|
||||
# forgetting_config: 遗忘引擎配置
|
||||
# """
|
||||
# self.connector = connector
|
||||
# self.embedder_client = embedder_client
|
||||
# self.alpha = alpha
|
||||
# self.use_forgetting_curve = use_forgetting_curve
|
||||
# self.forgetting_config = forgetting_config or ForgettingEngineConfig()
|
||||
# self._owns_connector = connector is None
|
||||
|
||||
# # 创建子策略
|
||||
# self.keyword_strategy = KeywordSearchStrategy(connector=connector)
|
||||
# self.semantic_strategy = SemanticSearchStrategy(
|
||||
# connector=connector,
|
||||
# embedder_client=embedder_client
|
||||
# )
|
||||
|
||||
# async def __aenter__(self):
|
||||
# """异步上下文管理器入口"""
|
||||
# if self._owns_connector:
|
||||
# self.connector = Neo4jConnector()
|
||||
# self.keyword_strategy.connector = self.connector
|
||||
# self.semantic_strategy.connector = self.connector
|
||||
# return self
|
||||
|
||||
# async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
# """异步上下文管理器出口"""
|
||||
# if self._owns_connector and self.connector:
|
||||
# await self.connector.close()
|
||||
|
||||
# async def search(
|
||||
# self,
|
||||
# query_text: str,
|
||||
# end_user_id: Optional[str] = None,
|
||||
# limit: int = 50,
|
||||
# include: Optional[List[str]] = None,
|
||||
# **kwargs
|
||||
# ) -> SearchResult:
|
||||
# """执行混合搜索
|
||||
|
||||
# Args:
|
||||
# query_text: 查询文本
|
||||
# end_user_id: 可选的组ID过滤
|
||||
# limit: 每个类别的最大结果数
|
||||
# include: 要包含的搜索类别列表
|
||||
# **kwargs: 其他搜索参数(如alpha, use_forgetting_curve)
|
||||
|
||||
# Returns:
|
||||
# SearchResult: 搜索结果对象
|
||||
# """
|
||||
# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# # 从kwargs中获取参数
|
||||
# alpha = kwargs.get("alpha", self.alpha)
|
||||
# use_forgetting = kwargs.get("use_forgetting_curve", self.use_forgetting_curve)
|
||||
|
||||
# # 获取有效的搜索类别
|
||||
# include_list = self._get_include_list(include)
|
||||
|
||||
# try:
|
||||
# # 并行执行关键词搜索和语义搜索
|
||||
# keyword_result = await self.keyword_strategy.search(
|
||||
# query_text=query_text,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# semantic_result = await self.semantic_strategy.search(
|
||||
# query_text=query_text,
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list
|
||||
# )
|
||||
|
||||
# # 重排序结果
|
||||
# if use_forgetting:
|
||||
# reranked_results = self._rerank_with_forgetting_curve(
|
||||
# keyword_result=keyword_result,
|
||||
# semantic_result=semantic_result,
|
||||
# alpha=alpha,
|
||||
# limit=limit
|
||||
# )
|
||||
# else:
|
||||
# reranked_results = self._rerank_hybrid_results(
|
||||
# keyword_result=keyword_result,
|
||||
# semantic_result=semantic_result,
|
||||
# alpha=alpha,
|
||||
# limit=limit
|
||||
# )
|
||||
|
||||
# # 创建元数据
|
||||
# metadata = self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# include=include_list,
|
||||
# alpha=alpha,
|
||||
# use_forgetting_curve=use_forgetting
|
||||
# )
|
||||
|
||||
# # 添加结果统计
|
||||
# metadata["keyword_results"] = keyword_result.metadata.get("result_counts", {})
|
||||
# metadata["semantic_results"] = semantic_result.metadata.get("result_counts", {})
|
||||
# metadata["total_keyword_results"] = keyword_result.total_results()
|
||||
# metadata["total_semantic_results"] = semantic_result.total_results()
|
||||
# metadata["total_reranked_results"] = reranked_results.total_results()
|
||||
|
||||
# reranked_results.metadata = metadata
|
||||
|
||||
# logger.info(f"混合搜索完成: 共找到 {reranked_results.total_results()} 条结果")
|
||||
# return reranked_results
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"混合搜索失败: {e}", exc_info=True)
|
||||
# # 返回空结果但包含错误信息
|
||||
# return SearchResult(
|
||||
# metadata=self._create_metadata(
|
||||
# query_text=query_text,
|
||||
# search_type="hybrid",
|
||||
# end_user_id=end_user_id,
|
||||
# limit=limit,
|
||||
# error=str(e)
|
||||
# )
|
||||
# )
|
||||
|
||||
# def _normalize_scores(
|
||||
# self,
|
||||
# results: List[Dict[str, Any]],
|
||||
# score_field: str = "score"
|
||||
# ) -> List[Dict[str, Any]]:
|
||||
# """使用z-score标准化和sigmoid转换归一化分数
|
||||
|
||||
# Args:
|
||||
# results: 结果列表
|
||||
# score_field: 分数字段名
|
||||
|
||||
# Returns:
|
||||
# List[Dict[str, Any]]: 归一化后的结果列表
|
||||
# """
|
||||
# if not results:
|
||||
# return results
|
||||
|
||||
# # 提取分数
|
||||
# scores = []
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# score = item.get(score_field)
|
||||
# if score is not None and isinstance(score, (int, float)):
|
||||
# scores.append(float(score))
|
||||
# else:
|
||||
# scores.append(0.0)
|
||||
|
||||
# if not scores or len(scores) == 1:
|
||||
# # 单个分数或无分数,设置为1.0
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# item[f"normalized_{score_field}"] = 1.0
|
||||
# return results
|
||||
|
||||
# # 计算均值和标准差
|
||||
# mean_score = sum(scores) / len(scores)
|
||||
# variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
# std_dev = math.sqrt(variance)
|
||||
|
||||
# if std_dev == 0:
|
||||
# # 所有分数相同,设置为1.0
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# item[f"normalized_{score_field}"] = 1.0
|
||||
# else:
|
||||
# # z-score标准化 + sigmoid转换
|
||||
# for item in results:
|
||||
# if score_field in item:
|
||||
# score = item[score_field]
|
||||
# if score is None or not isinstance(score, (int, float)):
|
||||
# score = 0.0
|
||||
# z_score = (score - mean_score) / std_dev
|
||||
# normalized = 1 / (1 + math.exp(-z_score))
|
||||
# item[f"normalized_{score_field}"] = normalized
|
||||
|
||||
# return results
|
||||
|
||||
# def _rerank_hybrid_results(
|
||||
# self,
|
||||
# keyword_result: SearchResult,
|
||||
# semantic_result: SearchResult,
|
||||
# alpha: float,
|
||||
# limit: int
|
||||
# ) -> SearchResult:
|
||||
# """重排序混合搜索结果
|
||||
|
||||
# Args:
|
||||
# keyword_result: 关键词搜索结果
|
||||
# semantic_result: 语义搜索结果
|
||||
# alpha: BM25分数权重
|
||||
# limit: 结果限制
|
||||
|
||||
# Returns:
|
||||
# SearchResult: 重排序后的结果
|
||||
# """
|
||||
# reranked_data = {}
|
||||
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# keyword_items = getattr(keyword_result, category, [])
|
||||
# semantic_items = getattr(semantic_result, category, [])
|
||||
|
||||
# # 归一化分数
|
||||
# keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
# semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
|
||||
# # 合并结果
|
||||
# combined_items = {}
|
||||
|
||||
# # 添加关键词结果
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
|
||||
# # 添加或更新语义结果
|
||||
# for item in semantic_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# # 计算组合分数
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
|
||||
# # 排序并限制结果
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
|
||||
# reranked_data[category] = sorted_items
|
||||
|
||||
# return SearchResult(
|
||||
# statements=reranked_data.get("statements", []),
|
||||
# chunks=reranked_data.get("chunks", []),
|
||||
# entities=reranked_data.get("entities", []),
|
||||
# summaries=reranked_data.get("summaries", [])
|
||||
# )
|
||||
|
||||
# def _parse_datetime(self, value: Any) -> Optional[datetime]:
|
||||
# """解析日期时间字符串"""
|
||||
# if value is None:
|
||||
# return None
|
||||
# if isinstance(value, datetime):
|
||||
# return value
|
||||
# if isinstance(value, str):
|
||||
# s = value.strip()
|
||||
# if not s:
|
||||
# return None
|
||||
# try:
|
||||
# return datetime.fromisoformat(s)
|
||||
# except Exception:
|
||||
# return None
|
||||
# return None
|
||||
|
||||
# def _rerank_with_forgetting_curve(
|
||||
# self,
|
||||
# keyword_result: SearchResult,
|
||||
# semantic_result: SearchResult,
|
||||
# alpha: float,
|
||||
# limit: int
|
||||
# ) -> SearchResult:
|
||||
# """使用遗忘曲线重排序混合搜索结果
|
||||
|
||||
# Args:
|
||||
# keyword_result: 关键词搜索结果
|
||||
# semantic_result: 语义搜索结果
|
||||
# alpha: BM25分数权重
|
||||
# limit: 结果限制
|
||||
|
||||
# Returns:
|
||||
# SearchResult: 重排序后的结果
|
||||
# """
|
||||
# engine = ForgettingEngine(self.forgetting_config)
|
||||
# now_dt = datetime.now()
|
||||
|
||||
# reranked_data = {}
|
||||
|
||||
# for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
# keyword_items = getattr(keyword_result, category, [])
|
||||
# semantic_items = getattr(semantic_result, category, [])
|
||||
|
||||
# # 归一化分数
|
||||
# keyword_items = self._normalize_scores(keyword_items, "score")
|
||||
# semantic_items = self._normalize_scores(semantic_items, "score")
|
||||
|
||||
# # 合并结果
|
||||
# combined_items = {}
|
||||
|
||||
# for src_items, is_embedding in [(keyword_items, False), (semantic_items, True)]:
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid")
|
||||
# if not item_id:
|
||||
# continue
|
||||
|
||||
# if item_id not in combined_items:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# # 计算分数并应用遗忘权重
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
|
||||
# # 计算时间衰减
|
||||
# dt = self._parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# memory_strength = 1.0 # 默认强度
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days,
|
||||
# memory_strength=memory_strength
|
||||
# )
|
||||
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
# item["forgetting_weight"] = forgetting_weight
|
||||
# item["time_elapsed_days"] = time_elapsed_days
|
||||
|
||||
# # 排序并限制结果
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
|
||||
# reranked_data[category] = sorted_items
|
||||
|
||||
# return SearchResult(
|
||||
# statements=reranked_data.get("statements", []),
|
||||
# chunks=reranked_data.get("chunks", []),
|
||||
# entities=reranked_data.get("entities", []),
|
||||
# summaries=reranked_data.get("summaries", [])
|
||||
# )
|
||||
@@ -1,122 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""关键词搜索策略
|
||||
|
||||
实现基于关键词的全文搜索功能。
|
||||
使用Neo4j的全文索引进行高效的文本匹配。
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.search.search_strategy import SearchStrategy, SearchResult
|
||||
from app.repositories.neo4j.graph_search import search_graph
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class KeywordSearchStrategy(SearchStrategy):
|
||||
"""关键词搜索策略
|
||||
|
||||
使用Neo4j全文索引进行关键词匹配搜索。
|
||||
支持跨陈述句、实体、分块和摘要的搜索。
|
||||
"""
|
||||
|
||||
def __init__(self, connector: Optional[Neo4jConnector] = None):
|
||||
"""初始化关键词搜索策略
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器,如果为None则创建新连接
|
||||
"""
|
||||
self.connector = connector
|
||||
self._owns_connector = connector is None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
if self._owns_connector:
|
||||
self.connector = Neo4jConnector()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
if self._owns_connector and self.connector:
|
||||
await self.connector.close()
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> SearchResult:
|
||||
"""执行关键词搜索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
|
||||
# 确保连接器已初始化
|
||||
if not self.connector:
|
||||
self.connector = Neo4jConnector()
|
||||
|
||||
try:
|
||||
# 调用底层的关键词搜索函数
|
||||
results_dict = await search_graph(
|
||||
connector=self.connector,
|
||||
query=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
|
||||
# 创建元数据
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
|
||||
# 添加结果统计
|
||||
metadata["result_counts"] = {
|
||||
category: len(results_dict.get(category, []))
|
||||
for category in include_list
|
||||
}
|
||||
metadata["total_results"] = sum(metadata["result_counts"].values())
|
||||
|
||||
# 构建SearchResult对象
|
||||
search_result = SearchResult(
|
||||
statements=results_dict.get("statements", []),
|
||||
chunks=results_dict.get("chunks", []),
|
||||
entities=results_dict.get("entities", []),
|
||||
summaries=results_dict.get("summaries", []),
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.info(f"关键词搜索完成: 共找到 {search_result.total_results()} 条结果")
|
||||
return search_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"关键词搜索失败: {e}", exc_info=True)
|
||||
# 返回空结果但包含错误信息
|
||||
return SearchResult(
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="keyword",
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,125 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""搜索策略基类
|
||||
|
||||
定义搜索策略的抽象接口和统一的搜索结果数据结构。
|
||||
遵循策略模式(Strategy Pattern)和开放-关闭原则(OCP)。
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
"""统一的搜索结果数据结构
|
||||
|
||||
Attributes:
|
||||
statements: 陈述句搜索结果列表
|
||||
chunks: 分块搜索结果列表
|
||||
entities: 实体搜索结果列表
|
||||
summaries: 摘要搜索结果列表
|
||||
metadata: 搜索元数据(如查询时间、结果数量等)
|
||||
"""
|
||||
statements: List[Dict[str, Any]] = Field(default_factory=list, description="陈述句搜索结果")
|
||||
chunks: List[Dict[str, Any]] = Field(default_factory=list, description="分块搜索结果")
|
||||
entities: List[Dict[str, Any]] = Field(default_factory=list, description="实体搜索结果")
|
||||
summaries: List[Dict[str, Any]] = Field(default_factory=list, description="摘要搜索结果")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="搜索元数据")
|
||||
|
||||
def total_results(self) -> int:
|
||||
"""返回所有类别的结果总数"""
|
||||
return (
|
||||
len(self.statements) +
|
||||
len(self.chunks) +
|
||||
len(self.entities) +
|
||||
len(self.summaries)
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"statements": self.statements,
|
||||
"chunks": self.chunks,
|
||||
"entities": self.entities,
|
||||
"summaries": self.summaries,
|
||||
"metadata": self.metadata
|
||||
}
|
||||
|
||||
|
||||
class SearchStrategy(ABC):
|
||||
"""搜索策略抽象基类
|
||||
|
||||
定义所有搜索策略必须实现的接口。
|
||||
遵循依赖反转原则(DIP):高层模块依赖抽象而非具体实现。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> SearchResult:
|
||||
"""执行搜索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表(statements, chunks, entities, summaries)
|
||||
**kwargs: 其他搜索参数
|
||||
|
||||
Returns:
|
||||
SearchResult: 统一的搜索结果对象
|
||||
"""
|
||||
pass
|
||||
|
||||
def _create_metadata(
|
||||
self,
|
||||
query_text: str,
|
||||
search_type: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""创建搜索元数据
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
search_type: 搜索类型
|
||||
end_user_id: 组ID
|
||||
limit: 结果限制
|
||||
**kwargs: 其他元数据
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 元数据字典
|
||||
"""
|
||||
metadata = {
|
||||
"query": query_text,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id,
|
||||
"limit": limit,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
metadata.update(kwargs)
|
||||
return metadata
|
||||
|
||||
def _get_include_list(self, include: Optional[List[str]] = None) -> List[str]:
|
||||
"""获取要包含的搜索类别列表
|
||||
|
||||
Args:
|
||||
include: 用户指定的类别列表
|
||||
|
||||
Returns:
|
||||
List[str]: 有效的类别列表
|
||||
"""
|
||||
default_include = ["statements", "chunks", "entities", "summaries"]
|
||||
if include is None:
|
||||
return default_include
|
||||
|
||||
# 验证并过滤有效的类别
|
||||
valid_categories = set(default_include)
|
||||
return [cat for cat in include if cat in valid_categories]
|
||||
@@ -1,166 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""语义搜索策略
|
||||
|
||||
实现基于向量嵌入的语义搜索功能。
|
||||
使用余弦相似度进行语义匹配。
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.storage_services.search.search_strategy import (
|
||||
SearchResult,
|
||||
SearchStrategy,
|
||||
)
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
class SemanticSearchStrategy(SearchStrategy):
|
||||
"""语义搜索策略
|
||||
|
||||
使用向量嵌入和余弦相似度进行语义搜索。
|
||||
支持跨陈述句、分块、实体和摘要的语义匹配。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Optional[Neo4jConnector] = None,
|
||||
embedder_client: Optional[OpenAIEmbedderClient] = None
|
||||
):
|
||||
"""初始化语义搜索策略
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器,如果为None则创建新连接
|
||||
embedder_client: 嵌入模型客户端,如果为None则根据配置创建
|
||||
"""
|
||||
self.connector = connector
|
||||
self.embedder_client = embedder_client
|
||||
self._owns_connector = connector is None
|
||||
self._owns_embedder = embedder_client is None
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
if self._owns_connector:
|
||||
self.connector = Neo4jConnector()
|
||||
if self._owns_embedder:
|
||||
self.embedder_client = self._create_embedder_client()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
if self._owns_connector and self.connector:
|
||||
await self.connector.close()
|
||||
|
||||
def _create_embedder_client(self) -> OpenAIEmbedderClient:
|
||||
"""创建嵌入模型客户端
|
||||
|
||||
Returns:
|
||||
OpenAIEmbedderClient: 嵌入模型客户端实例
|
||||
"""
|
||||
try:
|
||||
# 从数据库读取嵌入器配置
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(config_defs.SELECTED_EMBEDDING_ID)
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
return OpenAIEmbedderClient(model_config=rb_config)
|
||||
except Exception as e:
|
||||
logger.error(f"创建嵌入模型客户端失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def search(
|
||||
self,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> SearchResult:
|
||||
"""执行语义搜索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
end_user_id: 可选的组ID过滤
|
||||
limit: 每个类别的最大结果数
|
||||
include: 要包含的搜索类别列表
|
||||
**kwargs: 其他搜索参数
|
||||
|
||||
Returns:
|
||||
SearchResult: 搜索结果对象
|
||||
"""
|
||||
logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}")
|
||||
|
||||
# 获取有效的搜索类别
|
||||
include_list = self._get_include_list(include)
|
||||
|
||||
# 确保连接器和嵌入器已初始化
|
||||
if not self.connector:
|
||||
self.connector = Neo4jConnector()
|
||||
if not self.embedder_client:
|
||||
self.embedder_client = self._create_embedder_client()
|
||||
|
||||
try:
|
||||
# 调用底层的语义搜索函数
|
||||
results_dict = await search_graph_by_embedding(
|
||||
connector=self.connector,
|
||||
embedder_client=self.embedder_client,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
|
||||
# 创建元数据
|
||||
metadata = self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include_list
|
||||
)
|
||||
|
||||
# 添加结果统计
|
||||
metadata["result_counts"] = {
|
||||
category: len(results_dict.get(category, []))
|
||||
for category in include_list
|
||||
}
|
||||
metadata["total_results"] = sum(metadata["result_counts"].values())
|
||||
|
||||
# 构建SearchResult对象
|
||||
search_result = SearchResult(
|
||||
statements=results_dict.get("statements", []),
|
||||
chunks=results_dict.get("chunks", []),
|
||||
entities=results_dict.get("entities", []),
|
||||
summaries=results_dict.get("summaries", []),
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
logger.info(f"语义搜索完成: 共找到 {search_result.total_results()} 条结果")
|
||||
return search_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"语义搜索失败: {e}", exc_info=True)
|
||||
# 返回空结果但包含错误信息
|
||||
return SearchResult(
|
||||
metadata=self._create_metadata(
|
||||
query_text=query_text,
|
||||
search_type="semantic",
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
error=str(e)
|
||||
)
|
||||
)
|
||||
@@ -1452,6 +1452,30 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
||||
RETURN elementId(r) AS uuid
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_USER_ID = """
|
||||
MATCH (p:Perceptual)
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.summary_embedding AS summary_embedding
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_IDS = """
|
||||
MATCH (p:Perceptual)
|
||||
WHERE p.id IN $ids
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type
|
||||
"""
|
||||
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
|
||||
WHERE p.end_user_id = $end_user_id
|
||||
@@ -1471,24 +1495,3 @@ RETURN p.id AS id,
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
PERCEPTUAL_EMBEDDING_SEARCH = """
|
||||
CALL db.index.vector.queryNodes('perceptual_summary_embedding_index', $limit * 100, $embedding)
|
||||
YIELD node AS p, score
|
||||
WHERE p.summary_embedding IS NOT NULL AND p.end_user_id = $end_user_id
|
||||
RETURN p.id AS id,
|
||||
p.end_user_id AS end_user_id,
|
||||
p.perceptual_type AS perceptual_type,
|
||||
p.file_path AS file_path,
|
||||
p.file_name AS file_name,
|
||||
p.file_ext AS file_ext,
|
||||
p.summary AS summary,
|
||||
p.keywords AS keywords,
|
||||
p.topic AS topic,
|
||||
p.domain AS domain,
|
||||
p.created_at AS created_at,
|
||||
p.file_type AS file_type,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
@@ -3,13 +3,15 @@ import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
import numpy as np
|
||||
|
||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
CHUNK_EMBEDDING_SEARCH,
|
||||
COMMUNITY_EMBEDDING_SEARCH,
|
||||
ENTITY_EMBEDDING_SEARCH,
|
||||
EXPAND_COMMUNITY_STATEMENTS,
|
||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||
SEARCH_CHUNKS_BY_CONTENT,
|
||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||
@@ -17,7 +19,6 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_ENTITIES_BY_NAME,
|
||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||
@@ -28,14 +29,41 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
STATEMENT_EMBEDDING_SEARCH,
|
||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
||||
SEARCH_PERCEPTUAL_BY_IDS,
|
||||
SEARCH_PERCEPTUAL_BY_USER_ID,
|
||||
)
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def cosine_similarity_search(
|
||||
query: list[float],
|
||||
vectors: list[list[float]],
|
||||
limit: int
|
||||
) -> dict[int, float]:
|
||||
if not vectors:
|
||||
return {}
|
||||
vectors: np.ndarray = np.array(vectors, dtype=np.float32)
|
||||
vectors_norm = vectors / np.linalg.norm(vectors, axis=1, keepdims=True)
|
||||
query: np.ndarray = np.array(query, dtype=np.float32)
|
||||
query_norm = query / np.linalg.norm(query)
|
||||
|
||||
similarities = vectors_norm @ query_norm
|
||||
similarities = (similarities + 1) / 2
|
||||
top_k = min(limit, similarities.shape[0])
|
||||
if top_k <= 0:
|
||||
return {}
|
||||
top_indices = np.argpartition(-similarities, top_k - 1)[-top_k:]
|
||||
top_indices = top_indices[np.argsort(-similarities[top_indices])]
|
||||
result = {}
|
||||
for idx in top_indices:
|
||||
result[idx] = similarities[idx]
|
||||
return result
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
@@ -352,7 +380,7 @@ async def search_graph_by_embedding(
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
include: List[str] = ["statements", "chunks", "entities", "summaries"],
|
||||
include=None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
@@ -365,6 +393,8 @@ async def search_graph_by_embedding(
|
||||
- Filters by end_user_id if provided
|
||||
- Returns up to 'limit' per included type
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
import time
|
||||
|
||||
# Get embedding for the query
|
||||
@@ -1011,7 +1041,7 @@ async def search_perceptual(
|
||||
|
||||
async def search_perceptual_by_embedding(
|
||||
connector: Neo4jConnector,
|
||||
embedder_client,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
query_text: str,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 10,
|
||||
@@ -1040,11 +1070,22 @@ async def search_perceptual_by_embedding(
|
||||
|
||||
try:
|
||||
perceptuals = await connector.execute_query(
|
||||
PERCEPTUAL_EMBEDDING_SEARCH,
|
||||
embedding=embedding,
|
||||
SEARCH_PERCEPTUAL_BY_USER_ID,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
ids = [item['id'] for item in perceptuals]
|
||||
vectors = [item['summary_embedding'] for item in perceptuals]
|
||||
sim_res = cosine_similarity_search(embedding, vectors, limit=limit)
|
||||
perceptual_res = {
|
||||
ids[idx]: score
|
||||
for idx, score in sim_res.items()
|
||||
}
|
||||
perceptuals = await connector.execute_query(
|
||||
SEARCH_PERCEPTUAL_BY_IDS,
|
||||
ids=list(perceptual_res.keys())
|
||||
)
|
||||
for perceptual in perceptuals:
|
||||
perceptual["score"] = perceptual_res[perceptual["id"]]
|
||||
except Exception as e:
|
||||
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
|
||||
perceptuals = []
|
||||
|
||||
Reference in New Issue
Block a user