feat(memory): implement quick search pipeline with Neo4j integration

This commit is contained in:
Eternity
2026-04-15 12:18:23 +08:00
parent dca3173ed9
commit 2716a55c7f
19 changed files with 899 additions and 574 deletions

View File

@@ -15,7 +15,7 @@ from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import ReadState from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.utils.data.text_utils import escape_lucene_query from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.graph_search import ( from app.repositories.neo4j.graph_search import (
search_perceptual, search_perceptual_by_fulltext,
search_perceptual_by_embedding, search_perceptual_by_embedding,
) )
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -152,7 +152,7 @@ class PerceptualSearchService:
if not escaped.strip(): if not escaped.strip():
return [] return []
try: try:
r = await search_perceptual( r = await search_perceptual_by_fulltext(
connector=connector, query=escaped, connector=connector, query=escaped,
end_user_id=self.end_user_id, end_user_id=self.end_user_id,
limit=limit * 5, # 多查一些以提高命中率 limit=limit * 5, # 多查一些以提高命中率
@@ -177,7 +177,7 @@ class PerceptualSearchService:
escaped = escape_lucene_query(kw) escaped = escape_lucene_query(kw)
if not escaped.strip(): if not escaped.strip():
return [] return []
r = await search_perceptual( r = await search_perceptual_by_fulltext(
connector=connector, query=escaped, connector=connector, query=escaped,
end_user_id=self.end_user_id, limit=limit, end_user_id=self.end_user_id, limit=limit,
) )

View File

@@ -19,6 +19,7 @@ from app.core.memory.agent.utils.llm_tools import (
from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.enums import Neo4jNodeType
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
@@ -338,7 +339,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点 "include": [Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY] # MemorySummary 和 Community 同为高维度概括节点
} }
try: try:

View File

@@ -1,15 +1,14 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from langchain_core.messages import HumanMessage
from langgraph.constants import START, END from langgraph.constants import START, END
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Split_The_Problem, Split_The_Problem,
Problem_Extension, Problem_Extension,
@@ -17,9 +16,6 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve_nodes, retrieve_nodes,
) )
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
perceptual_retrieve_node,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary, Input_Summary,
Retrieve_Summary, Retrieve_Summary,
@@ -32,6 +28,9 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
Retrieve_continue, Retrieve_continue,
Verify_continue, Verify_continue,
) )
from app.core.memory.agent.utils.llm_tools import ReadState
logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
@@ -51,7 +50,7 @@ async def make_read_graph():
""" """
try: try:
# Build workflow graph # Build workflow graph
workflow = StateGraph(ReadState) workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node) workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Problem_Extension", Problem_Extension)

View File

@@ -7,6 +7,7 @@ and deduplication.
from typing import List, Tuple, Optional from typing import List, Tuple, Optional
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.src.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query from app.core.memory.utils.data.text_utils import escape_lucene_query
@@ -111,13 +112,13 @@ class SearchService:
content_parts = [] content_parts = []
# Statements: extract statement field # Statements: extract statement field
if 'statement' in result and result['statement']: if Neo4jNodeType.STATEMENT in result and result[Neo4jNodeType.STATEMENT]:
content_parts.append(result['statement']) content_parts.append(result[Neo4jNodeType.STATEMENT])
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 # 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
is_community = ( is_community = (
node_type == "community" node_type == Neo4jNodeType.COMMUNITY
or 'member_count' in result or 'member_count' in result
or 'core_entities' in result or 'core_entities' in result
) )
@@ -204,7 +205,7 @@ class SearchService:
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries", "communities"] include = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -231,7 +232,7 @@ class SearchService:
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities # Priority order: summaries first (most contextual), then communities, statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -241,7 +242,7 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
@@ -250,11 +251,11 @@ class SearchService:
answer_list.extend(category_results) answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements路径 "0"/"1" 需要,路径 "2" 不需要) # 对命中的 community 节点展开其成员 statements路径 "0"/"1" 需要,路径 "2" 不需要)
if expand_communities and "communities" in include: if expand_communities and Neo4jNodeType.COMMUNITY in include:
community_results = ( community_results = (
answer.get('reranked_results', {}).get('communities', []) answer.get('reranked_results', {}).get(Neo4jNodeType.COMMUNITY.value, [])
if search_type == "hybrid" if search_type == "hybrid"
else answer.get('communities', []) else answer.get(Neo4jNodeType.COMMUNITY.value, [])
) )
cleaned_stmts, new_texts = await expand_communities_to_statements( cleaned_stmts, new_texts = await expand_communities_to_statements(
community_results=community_results, community_results=community_results,
@@ -266,7 +267,7 @@ class SearchService:
content_list = [] content_list = []
for ans in answer_list: for ans in answer_list:
# community 节点有 member_count 或 core_entities 字段 # community 节点有 member_count 或 core_entities 字段
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" ntype = Neo4jNodeType.COMMUNITY if ('member_count' in ans or 'core_entities' in ans) else ""
content_list.append(self.extract_content_from_result(ans, node_type=ntype)) content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines

View File

@@ -16,3 +16,14 @@ class SearchStrategy(StrEnum):
DEEP = "0" DEEP = "0"
NORMAL = "1" NORMAL = "1"
QUICK = "2" QUICK = "2"
class Neo4jNodeType(StrEnum):
CHUNK = "Chunk"
COMMUNITY = "Community"
DIALOGUE = "Dialogue"
EXTRACTEDENTITY = "ExtractedEntity"
MEMORYSUMMARY = "MemorySummary"
PERCEPTUAL = "Perceptual"
STATEMENT = "Statement"

View File

@@ -21,6 +21,7 @@ from chonkie import (
from app.core.memory.models.config_models import ChunkerConfig from app.core.memory.models.config_models import ChunkerConfig
from app.core.memory.models.message_models import DialogData, Chunk from app.core.memory.models.message_models import DialogData, Chunk
try: try:
from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception: except Exception:
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
class LLMChunker: class LLMChunker:
"""LLM-based intelligent chunking strategy""" """LLM-based intelligent chunking strategy"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client self.llm_client = llm_client
self.chunk_size = chunk_size self.chunk_size = chunk_size
@@ -46,7 +48,8 @@ class LLMChunker:
""" """
messages = [ messages = [
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."}, {"role": "system",
"content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
{"role": "user", "content": prompt} {"role": "user", "content": prompt}
] ]
@@ -311,7 +314,7 @@ class ChunkerClient:
f.write("=" * 60 + "\n\n") f.write("=" * 60 + "\n\n")
for i, chunk in enumerate(dialogue.chunks): for i, chunk in enumerate(dialogue.chunks):
f.write(f"Chunk {i+1}:\n") f.write(f"Chunk {i + 1}:\n")
f.write(f"Size: {len(chunk.content)} characters\n") f.write(f"Size: {len(chunk.content)} characters\n")
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata: if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n") f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")

View File

@@ -1,21 +1,12 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from pydantic import BaseModel, ConfigDict
from app.core.memory.enums import StorageType from app.core.memory.enums import StorageType, SearchStrategy
from app.schemas import MemoryConfig from app.core.memory.models.service_models import Memory, MemoryContext
from app.core.memory.pipelines.memory_read import ReadPipeLine
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
class MemoryContext(BaseModel):
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
end_user_id: str
memory_config: MemoryConfig
storage_type: StorageType = StorageType.NEO4J
user_rag_memory_id: str | None = None
language: str = "zh"
class MemoryService: class MemoryService:
def __init__( def __init__(
self, self,
@@ -44,8 +35,9 @@ class MemoryService:
async def write(self, messages: list[dict]) -> str: async def write(self, messages: list[dict]) -> str:
raise NotImplementedError raise NotImplementedError
async def read(self, query: str, history: list, search_switch: str) -> dict: async def read(self, query: str, history: list, search_switch: SearchStrategy) -> list[Memory]:
raise NotImplementedError with get_db_context() as db:
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit=10)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict: async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
raise NotImplementedError raise NotImplementedError

View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel, Field, field_serializer, ConfigDict
from app.core.memory.enums import Neo4jNodeType, StorageType
from app.schemas.memory_config_schema import MemoryConfig
class MemoryContext(BaseModel):
    """Immutable per-request context passed to the memory services.

    The model is frozen, so instances cannot be mutated after creation;
    arbitrary (non-pydantic) field types are permitted.
    """

    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)

    # Identifier of the end user whose memories are being accessed.
    end_user_id: str
    # Memory configuration object for this request.
    memory_config: MemoryConfig
    # Storage backend selector; defaults to Neo4j.
    storage_type: StorageType = StorageType.NEO4J
    # Optional identifier, unset by default.
    user_rag_memory_id: str | None = None
    # Language code; defaults to "zh".
    language: str = "zh"
class Memory(BaseModel):
    """A single search hit returned by the memory read pipeline."""

    # Node type the hit came from (required).
    source: Neo4jNodeType
    # Relevance score used to rank hits.
    score: float = 0.0
    # Text rendering of the hit.
    content: str = ""
    # Structured payload for the hit.
    data: dict = Field(default_factory=dict)
    # The query that produced this hit (required).
    query: str

    @field_serializer("source")
    def serialize_source(self, v: Neo4jNodeType) -> str:
        """Serialize ``source`` as the enum's plain string value."""
        return v.value

View File

@@ -1,22 +1,32 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/4/3 11:44
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from demo.memory_alpha import MemoryContext from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.models.service_models import MemoryContext
from app.core.models import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
class ModelClientMixin(ABC): class ModelClientMixin(ABC):
def get_llm_client(self, db: Session, model_id: uuid.UUID): @staticmethod
def get_llm_client(db: Session, model_id: uuid.UUID):
pass pass
def get_embedding_client(self, db: Session, model_id: uuid.UUID): @staticmethod
pass def get_embedding_client(db: Session, model_id: uuid.UUID) -> OpenAIEmbedderClient:
config_service = MemoryConfigService(db)
embedder_client_config = config_service.get_embedder_config(str(model_id))
return OpenAIEmbedderClient(
RedBearModelConfig(
model_name=embedder_client_config["model_name"],
provider=embedder_client_config["provider"],
api_key=embedder_client_config["api_key"],
base_url=embedder_client_config["base_url"],
)
)
class BasePipeline(ABC): class BasePipeline(ABC):

View File

@@ -1,10 +1,11 @@
from app.core.memory.enums import SearchStrategy from app.core.memory.enums import SearchStrategy
from app.core.memory.pipelines.base_pipeline import BasePipeline from app.core.memory.pipelines.base_pipeline import BasePipeline, ModelClientMixin
from app.core.memory.read_services.content_search import Neo4jSearchService
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
class ReadPipeLine(BasePipeline): class ReadPipeLine(ModelClientMixin, BasePipeline):
async def run(self, query, search_switch, memory_config): async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10):
query = QueryPreprocessor.process(query) query = QueryPreprocessor.process(query)
match search_switch: match search_switch:
case SearchStrategy.DEEP: case SearchStrategy.DEEP:
@@ -12,7 +13,7 @@ class ReadPipeLine(BasePipeline):
case SearchStrategy.NORMAL: case SearchStrategy.NORMAL:
return await self._normal_read(query) return await self._normal_read(query)
case SearchStrategy.QUICK: case SearchStrategy.QUICK:
return await self._quick_read() return await self._quick_read(query, limit)
case _: case _:
raise RuntimeError("Unsupported search strategy") raise RuntimeError("Unsupported search strategy")
@@ -22,5 +23,9 @@ class ReadPipeLine(BasePipeline):
async def _normal_read(self, query): async def _normal_read(self, query):
pass pass
async def _quick_read(self): async def _quick_read(self, query, limit):
pass search_service = Neo4jSearchService(
self.ctx,
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id)
)
return await search_service.search(query, limit)

View File

@@ -0,0 +1,178 @@
import asyncio
import logging
import math
import time
from pydantic import BaseModel, Field
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.memory_service import MemoryContext
from app.core.memory.models.service_models import Memory
from app.core.memory.read_services.result_builder import data_builder_factory
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class MemorySearchResult(BaseModel):
    """Container for grouped search results."""

    # Raw result records keyed by a string category.
    memories: dict[str, list[dict]] = Field(default_factory=dict)
    # Text content associated with the result set.
    content: str = ""
    # Count associated with the result set.
    count: int = 0
class Neo4jSearchService:
    """Hybrid keyword + embedding search over Neo4j memory nodes.

    ``search`` runs a keyword query and an embedding query concurrently,
    fuses both score lists per node type, and returns the best hits across
    all requested node types as ``Memory`` objects.
    """

    DEFAULT_ALPHA = 0.6
    DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
    DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
    DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5

    def __init__(
            self,
            ctx: "MemoryContext",
            embedder: "OpenAIEmbedderClient",
            includes: "list[Neo4jNodeType] | None" = None,
            alpha: float = DEFAULT_ALPHA,
            fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
            cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD,
            content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD
    ):
        """Store the search configuration.

        Args:
            ctx: Request context; ``ctx.end_user_id`` scopes every query.
            embedder: Client handed to the embedding search.
            includes: Node types to search; ``None`` selects the default set
                listed below.
            alpha: Weight of the embedding score in the fused score; the
                keyword score gets ``1 - alpha``.
            fulltext_score_threshold: Midpoint of the sigmoid used to
                normalize keyword scores.
            cosine_score_threshold: Stored; not used by this class.
            content_score_threshold: Stored; not applied by ``_rerank``.
        """
        self.ctx = ctx
        self.alpha = alpha
        self.fulltext_score_threshold = fulltext_score_threshold
        self.cosine_score_threshold = cosine_score_threshold
        self.content_score_threshold = content_score_threshold

        self.embedder: OpenAIEmbedderClient = embedder
        # Set for the duration of ``search``; ``None`` before the first call.
        self.connector: Neo4jConnector | None = None

        self.includes = includes
        if includes is None:
            self.includes = [
                Neo4jNodeType.STATEMENT,
                Neo4jNodeType.CHUNK,
                Neo4jNodeType.EXTRACTEDENTITY,
                Neo4jNodeType.MEMORYSUMMARY,
                Neo4jNodeType.PERCEPTUAL,
                Neo4jNodeType.COMMUNITY
            ]

    async def _keyword_search(
            self,
            query: str,
            limit: int
    ):
        """Run the keyword search for ``query`` over the included node types."""
        return await search_graph(
            connector=self.connector,
            query=query,
            end_user_id=self.ctx.end_user_id,
            limit=limit,
            include=self.includes
        )

    async def _embedding_search(self, query, limit):
        """Run the embedding search for ``query`` over the included node types."""
        return await search_graph_by_embedding(
            connector=self.connector,
            embedder_client=self.embedder,
            query_text=query,
            end_user_id=self.ctx.end_user_id,
            limit=limit,
            include=self.includes
        )

    def _fuse_scores(
            self,
            keyword_results: list[dict],
            embedding_results: list[dict],
    ) -> dict:
        """Merge both result lists by ``id`` and compute ``content_score``.

        Each merged record is a copy of the first record seen for that id
        (keyword results take precedence) extended with ``kw_score``,
        ``embedding_score`` and ``content_score``.

        Raises:
            KeyError: If a record has no ``"id"`` key.
        """
        keyword_results = self._normalize_kw_scores(keyword_results or [])
        embedding_results = embedding_results or []

        combined: dict = {}
        for item in keyword_results:
            entry = item.copy()
            entry["kw_score"] = float(item.get("normalized_kw_score", 0) or 0)
            entry["embedding_score"] = 0.0
            combined[item["id"]] = entry

        for item in embedding_results:
            item_id = item["id"]
            entry = combined.get(item_id)
            if entry is None:
                # Embedding-only hit: no keyword evidence.
                entry = item.copy()
                entry["kw_score"] = 0.0
                combined[item_id] = entry
            # ``or 0`` tolerates an explicit ``None`` score.
            entry["embedding_score"] = float(item.get("score", 0) or 0)

        for entry in combined.values():
            kw = entry["kw_score"]
            emb = entry["embedding_score"]
            base = self.alpha * emb + (1 - self.alpha) * kw
            # Bonus for hits found by both searches, capped so the total
            # never exceeds 1 when both inputs are within [0, 1].
            entry["content_score"] = base + min(1 - base, 0.1 * kw * emb)
        return combined

    def _rerank(
            self,
            keyword_results: list[dict],
            embedding_results: list[dict],
            limit: int,
    ) -> list[dict]:
        """Return the top ``limit`` fused records, best ``content_score`` first."""
        combined = self._fuse_scores(keyword_results, embedding_results)

        results = sorted(combined.values(), key=lambda x: x["content_score"], reverse=True)
        # NOTE(review): self.content_score_threshold is stored but not applied
        # here; results are only truncated to ``limit``.
        results = results[:limit]

        logger.info(
            "[MemorySearch] rerank: merged=%d, after_limit=%d (alpha=%s)",
            len(combined), len(results), self.alpha,
        )
        return results

    def _normalize_kw_scores(self, items: list[dict]) -> list[dict]:
        """Add ``normalized_kw_score`` to every item, in place.

        The raw ``score`` (missing or ``None`` counts as 0) is squashed into
        (0, 1) with a sigmoid centred on ``fulltext_score_threshold``.
        """
        if not items:
            return items

        for it in items:
            raw = float(it.get("score", 0) or 0)
            it["normalized_kw_score"] = 1 / (1 + math.exp(-(raw - self.fulltext_score_threshold) / 2))
        return items

    async def search(
            self,
            query: str,
            limit: int = 10,
    ) -> "list[Memory]":
        """Search all included node types and return the best ``limit`` hits.

        A failure in one of the two searches is logged and treated as an
        empty result, so the other search can still contribute.
        """
        async with Neo4jConnector() as connector:
            self.connector = connector
            kw_results, emb_results = await asyncio.gather(
                self._keyword_search(query, limit),
                self._embedding_search(query, limit),
                return_exceptions=True,
            )

        # gather(return_exceptions=True) can also hand back CancelledError,
        # which derives from BaseException rather than Exception.
        if isinstance(kw_results, BaseException):
            logger.warning("[MemorySearch] keyword search error: %s", kw_results)
            kw_results = {}
        if isinstance(emb_results, BaseException):
            logger.warning("[MemorySearch] embedding search error: %s", emb_results)
            emb_results = {}
        kw_results = kw_results or {}
        emb_results = emb_results or {}

        memories = []
        for node_type in self.includes:
            reranked = self._rerank(
                kw_results.get(node_type) or [],
                emb_results.get(node_type) or [],
                limit
            )
            for record in reranked:
                built = data_builder_factory(node_type, record)
                memories.append(Memory(
                    score=built.score,
                    # Memory.content is a required-str field; never pass None.
                    content=built.content or "",
                    data=built.data,
                    source=node_type,
                    query=query
                ))
        memories.sort(key=lambda m: m.score, reverse=True)
        return memories[:limit]

View File

@@ -0,0 +1,150 @@
from abc import ABC, abstractmethod
from typing import TypeVar
from app.core.memory.enums import Neo4jNodeType
class BaseBuilder(ABC):
    """Base class for adapters that turn a raw search record into result fields.

    Subclasses provide ``data`` (structured payload) and ``content`` (text
    rendering); ``score`` is shared.
    """

    def __init__(self, records: dict):
        """Wrap a single record dict (stored as ``self.record``)."""
        self.record = records

    @property
    @abstractmethod
    def data(self) -> dict:
        """Structured payload built from the record."""

    @property
    @abstractmethod
    def content(self) -> str:
        """Text rendering of the record."""

    @property
    def score(self) -> float:
        """The record's ``content_score`` as a float; 0.0 when missing or falsy."""
        return float(self.record.get("content_score", 0.0) or 0.0)
T = TypeVar("T", bound=BaseBuilder)
class ChunkBuilder(BaseBuilder):
    """Builder for chunk records; the text is read from the ``content`` field."""

    @property
    def data(self) -> dict:
        """Id, raw content, and the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id"),
            "content": self.record.get("content"),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """The record's ``content``; an empty string when absent or ``None``."""
        return self.record.get("content") or ""
# NOTE(review): the class name is misspelled ("Builer"); it is kept as-is
# because data_builder_factory refers to it by this name.
class StatementBuiler(BaseBuilder):
    """Builder for statement records; the text is read from the ``statement`` field."""

    @property
    def data(self) -> dict:
        """Id, raw statement text, and the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id"),
            "content": self.record.get("statement"),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """The record's ``statement``; an empty string when absent or ``None``."""
        return self.record.get("statement") or ""
class EntityBuilder(BaseBuilder):
    """Builder for entity records; the text is read from the ``name`` field."""

    @property
    def data(self) -> dict:
        """Id, raw entity name, and the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id"),
            "content": self.record.get("name"),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """The record's ``name``; an empty string when absent or ``None``."""
        return self.record.get("name") or ""
class SummaryBuilder(BaseBuilder):
    """Builder for summary records; the text is read from the ``content`` field."""

    @property
    def data(self) -> dict:
        """Id, raw content, and the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id"),
            "content": self.record.get("content"),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """The record's ``content``; an empty string when absent or ``None``."""
        return self.record.get("content") or ""
class PerceptualBuilder(BaseBuilder):
    """Builder for perceptual (file-derived) records."""

    @property
    def data(self) -> dict:
        """File metadata fields plus the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id", ""),
            "perceptual_type": self.record.get("perceptual_type", ""),
            "file_name": self.record.get("file_name", ""),
            "file_path": self.record.get("file_path", ""),
            "summary": self.record.get("summary", ""),
            "topic": self.record.get("topic", ""),
            "domain": self.record.get("domain", ""),
            "keywords": self.record.get("keywords", []),
            "created_at": str(self.record.get("created_at", "")),
            "file_type": self.record.get("file_type", ""),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """Render the file metadata as a ``<history-file-info>`` tagged block.

        Missing fields are rendered as the text ``None``.
        """
        return ("<history-file-info>"
                f"<file-name>{self.record.get('file_name')}</file-name>"
                f"<file-path>{self.record.get('file_path')}</file-path>"
                f"<summary>{self.record.get('summary')}</summary>"
                f"<topic>{self.record.get('topic')}</topic>"
                f"<domain>{self.record.get('domain')}</domain>"
                f"<keywords>{self.record.get('keywords')}</keywords>"
                f"<file-type>{self.record.get('file_type')}</file-type>"
                # Closing tag must mirror the opening tag exactly.
                "</history-file-info>")
class CommunityBuilder(BaseBuilder):
    """Builder for community records; the text is read from the ``content`` field.

    NOTE(review): assumes community records expose a ``content`` field;
    confirm against the query that produces them.
    """

    @property
    def data(self) -> dict:
        """Id, raw content, and the keyword/embedding scores of the record."""
        return {
            "id": self.record.get("id"),
            "content": self.record.get("content"),
            "kw_score": self.record.get("kw_score", 0.0),
            "emb_score": self.record.get("embedding_score", 0.0)
        }

    @property
    def content(self) -> str:
        """The record's ``content``; an empty string when absent or ``None``."""
        return self.record.get("content") or ""
def data_builder_factory(node_type, data: dict) -> T:
match node_type:
case Neo4jNodeType.STATEMENT:
return StatementBuiler(data)
case Neo4jNodeType.CHUNK:
return ChunkBuilder(data)
case Neo4jNodeType.EXTRACTEDENTITY:
return EntityBuilder(data)
case Neo4jNodeType.MEMORYSUMMARY:
return SummaryBuilder(data)
case Neo4jNodeType.PERCEPTUAL:
return PerceptualBuilder(data)
case Neo4jNodeType.COMMUNITY:
return CommunityBuilder(data)
case _:
raise KeyError(f"Unknown node_type: {node_type}")

View File

@@ -6,6 +6,8 @@ import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from app.core.memory.enums import Neo4jNodeType
if TYPE_CHECKING: if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
@@ -131,7 +133,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
return results return results
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]: def deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
""" """
Remove duplicate items from search results based on content. Remove duplicate items from search results based on content.
@@ -194,7 +196,7 @@ def rerank_with_activation(
forgetting_config: ForgettingEngineConfig | None = None, forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8, activation_boost_factor: float = 0.8,
now: datetime | None = None, now: datetime | None = None,
content_score_threshold: float = 0.5, content_score_threshold: float = 0.1,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
两阶段排序:先按内容相关性筛选,再按激活值排序。 两阶段排序:先按内容相关性筛选,再按激活值排序。
@@ -239,7 +241,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries", "communities"]: for category in [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -405,7 +407,7 @@ def rerank_with_activation(
f"items below content_score_threshold={content_score_threshold}" f"items below content_score_threshold={content_score_threshold}"
) )
sorted_items = _deduplicate_results(sorted_items) sorted_items = deduplicate_results(sorted_items)
reranked[category] = sorted_items reranked[category] = sorted_items
@@ -691,7 +693,7 @@ async def run_hybrid_search(
search_type: str, search_type: str,
end_user_id: str | None, end_user_id: str | None,
limit: int, limit: int,
include: List[str], include: List[Neo4jNodeType],
output_path: str | None, output_path: str | None,
memory_config: "MemoryConfig", memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,

View File

@@ -7,7 +7,8 @@ from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from app.db import Base from app.db import Base
from app.schemas import FileType from app.schemas.app_schema import FileType
class PerceptualType(IntEnum): class PerceptualType(IntEnum):
VISION = 1 VISION = 1

View File

@@ -1,3 +1,4 @@
from app.core.memory.enums import Neo4jNodeType
DIALOGUE_NODE_SAVE = """ DIALOGUE_NODE_SAVE = """
UNWIND $dialogues AS dialogue UNWIND $dialogues AS dialogue
@@ -147,57 +148,6 @@ SET r.predicate = rel.predicate,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# In Neo4j 5 and later the id() function is deprecated; use elementId() instead.
# Save weak-relation entities: sets e.is_weak = true and does not maintain the
# aggregated e.relations field. Entities are merged on (id, run_id).
WEAK_ENTITY_NODE_SAVE = """
UNWIND $weak_entities AS entity
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
SET e += {
name: entity.name,
end_user_id: entity.end_user_id,
run_id: entity.run_id,
description: entity.description,
chunk_id: entity.chunk_id,
dialog_id: entity.dialog_id
}
// Independent weak flag仅标记弱关系不再维护 relations 聚合字段
SET e.is_weak = true
RETURN e.id AS id
"""
# Create/update entity nodes for the subject and object of strong-relation
# triples; only sets is_strong = true and does not maintain the e.relations field.
SAVE_STRONG_TRIPLE_ENTITIES = """
UNWIND $items AS item
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET s.is_strong = true
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
// Independent strong flag
SET o.is_strong = true
"""
# Link Dialogue nodes to Statement nodes with a MENTIONS relationship.
# The dialogue is matched by either uuid or ref_id; the relationship is merged
# on its endpoints only, so its properties are overwritten on each save.
DIALOGUE_STATEMENT_EDGE_SAVE = """
UNWIND $dialogue_statement_edges AS edge
// 支持按 uuid 或 ref_id 连接到 Dialogue避免因来源 ID 不一致而断链
MATCH (dialogue:Dialogue)
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
MATCH (statement:Statement {id: edge.target})
// 仅按端点去重,关系属性可更新
MERGE (dialogue)-[e:MENTIONS]->(statement)
SET e.uuid = edge.id,
e.end_user_id = edge.end_user_id,
e.created_at = edge.created_at,
e.expired_at = edge.expired_at
RETURN e.uuid AS uuid
"""
# 在 Neo4j 5及后续版本中id() 函数已被标记为弃用用elementId() 函数替代
CHUNK_STATEMENT_EDGE_SAVE = """ CHUNK_STATEMENT_EDGE_SAVE = """
UNWIND $chunk_statement_edges AS edge UNWIND $chunk_statement_edges AS edge
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
@@ -226,87 +176,6 @@ SET r.end_user_id = rel.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
ENTITY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
YIELD node AS e, score
WHERE e.name_embedding IS NOT NULL
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Statement.statement_embedding
STATEMENT_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
YIELD node AS s, score
WHERE s.statement_embedding IS NOT NULL
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
CHUNK_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
YIELD node AS c, score
WHERE c.chunk_embedding IS NOT NULL
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS chunk_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# 查询实体名称包含指定字符串的实体 # 查询实体名称包含指定字符串的实体
SEARCH_ENTITIES_BY_NAME = """ SEARCH_ENTITIES_BY_NAME = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
@@ -338,73 +207,6 @@ ORDER BY score DESC
LIMIT $limit LIMIT $limit
""" """
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score
With collect({entity: e, score: score}) AS fulltextResults
OPTIONAL MATCH (ae:ExtractedEntity)
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
AND ae.aliases IS NOT NULL
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
WITH fulltextResults, collect(ae) AS aliasEntities
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
CASE
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
ELSE 0.8
END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
e.statement_id AS statement_id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS chunk_id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用 # 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索 # # 同组group_id下按“精确名字或别名+可选类型一致”来检索
@@ -677,49 +479,6 @@ MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
SET n.invalid_at = $new_invalid_at SET n.invalid_at = $new_invalid_at
""" """
# MemorySummary keyword search using fulltext index
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Embedding-based search: cosine similarity on MemorySummary.summary_embedding
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
YIELD node AS m, score
WHERE m.summary_embedding IS NOT NULL
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
RETURN m.id AS id,
m.name AS name,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
MEMORY_SUMMARY_NODE_SAVE = """ MEMORY_SUMMARY_NODE_SAVE = """
UNWIND $summaries AS summary UNWIND $summaries AS summary
MERGE (m:MemorySummary {id: summary.id}) MERGE (m:MemorySummary {id: summary.id})
@@ -1030,8 +789,6 @@ RETURN DISTINCT
e.statement AS statement; e.statement AS statement;
""" """
'''获取实体'''
Memory_Space_User = """ Memory_Space_User = """
MATCH (n)-[r]->(m) MATCH (n)-[r]->(m)
WHERE n.end_user_id = $end_user_id AND m.name="用户" WHERE n.end_user_id = $end_user_id AND m.name="用户"
@@ -1363,22 +1120,6 @@ WHERE c.name IS NULL OR c.name = ''
RETURN c.community_id AS community_id RETURN c.community_id AS community_id
""" """
# Community keyword search: matches name or summary via fulltext index
SEARCH_COMMUNITIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.community_id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community 向量检索 ────────────────────────────────────────────────── # Community 向量检索 ──────────────────────────────────────────────────
# Community embedding-based search: cosine similarity on Community.summary_embedding # Community embedding-based search: cosine similarity on Community.summary_embedding
COMMUNITY_EMBEDDING_SEARCH = """ COMMUNITY_EMBEDDING_SEARCH = """
@@ -1452,13 +1193,54 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
RETURN elementId(r) AS uuid RETURN elementId(r) AS uuid
""" """
# -------------------
# search by user id
# -------------------
SEARCH_PERCEPTUAL_BY_USER_ID = """ SEARCH_PERCEPTUAL_BY_USER_ID = """
MATCH (p:Perceptual) MATCH (p:Perceptual)
WHERE p.end_user_id = $end_user_id WHERE p.end_user_id = $end_user_id
RETURN p.id AS id, RETURN p.id AS id,
p.summary_embedding AS summary_embedding p.summary_embedding AS embedding
""" """
# The four queries below share one shape: select every node of a given label whose
# end_user_id equals $end_user_id and return two columns, `id` and `embedding`.
# None of them filters out nodes whose embedding property is missing, so `embedding`
# can be null in the returned rows.

# Statement nodes: `embedding` is s.statement_embedding.
SEARCH_STATEMENTS_BY_USER_ID = """
MATCH (s:Statement)
WHERE s.end_user_id = $end_user_id
RETURN s.id AS id,
s.statement_embedding AS embedding
"""
# ExtractedEntity nodes: `embedding` is e.name_embedding.
SEARCH_ENTITIES_BY_USER_ID = """
MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id
RETURN e.id AS id,
e.name_embedding AS embedding
"""
# Chunk nodes: `embedding` is c.chunk_embedding.
SEARCH_CHUNKS_BY_USER_ID = """
MATCH (c:Chunk)
WHERE c.end_user_id = $end_user_id
RETURN c.id AS id,
c.chunk_embedding AS embedding
"""
# MemorySummary nodes: `embedding` is s.summary_embedding.
SEARCH_MEMORY_SUMMARIES_BY_USER_ID = """
MATCH (s:MemorySummary)
WHERE s.end_user_id = $end_user_id
RETURN s.id AS id,
s.summary_embedding AS embedding
"""
# Community nodes owned by $end_user_id: returns `id` (c.id) and `embedding`
# (c.summary_embedding, which may be null).
# NOTE(review): another Community query in this file returns c.community_id as the
# identifier; confirm that Community nodes also carry an `id` property, otherwise
# this query yields null ids.
SEARCH_COMMUNITIES_BY_USER_ID = """
MATCH (c:Community)
WHERE c.end_user_id = $end_user_id
RETURN c.id AS id,
c.summary_embedding AS embedding
"""
# -------------------
# search by id
# -------------------
SEARCH_PERCEPTUAL_BY_IDS = """ SEARCH_PERCEPTUAL_BY_IDS = """
MATCH (p:Perceptual) MATCH (p:Perceptual)
WHERE p.id IN $ids WHERE p.id IN $ids
@@ -1476,7 +1258,79 @@ RETURN p.id AS id,
p.file_type AS file_type p.file_type AS file_type
""" """
SEARCH_PERCEPTUAL_BY_KEYWORD = """ SEARCH_STATEMENTS_BY_IDS = """
MATCH (s:Statement)
WHERE s.id IN $ids
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
properties(s)['invalid_at'] AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count
"""
# The queries below load display fields for nodes whose `id` is in the $ids list.
# They do not filter by end_user_id and return no `score` column.

# Chunk nodes by id. activation_value falls back to 0.5 and access_count to 0
# when the property is missing.
SEARCH_CHUNKS_BY_IDS = """
MATCH (c:Chunk)
WHERE c.id IN $ids
RETURN c.id AS id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count
"""
# ExtractedEntity nodes by id. activation_value falls back to importance_score and
# then to 0.5; importance_score falls back to 0.5; access_count falls back to 0.
SEARCH_ENTITIES_BY_IDS = """
MATCH (e:ExtractedEntity)
WHERE e.id IN $ids
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count
"""
# MemorySummary nodes by id, with the same activation/importance/access_count
# fallbacks as the entity query above.
SEARCH_MEMORY_SUMMARIES_BY_IDS = """
MATCH (m:MemorySummary)
WHERE m.id IN $ids
RETURN m.id AS id,
m.name AS name,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count
"""
# Community nodes whose `id` is in $ids; the community summary is exposed under the
# `content` column. No end_user_id filter and no `score` column.
# NOTE(review): another Community query in this file identifies communities by
# c.community_id; confirm that matching on c.id is intended.
SEARCH_COMMUNITIES_BY_IDS = """
MATCH (c:Community)
WHERE c.id IN $ids
RETURN c.id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at
"""
# -------------------
# search by fulltext
# -------------------
SEARCH_PERCEPTUALS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
WHERE p.end_user_id = $end_user_id WHERE p.end_user_id = $end_user_id
RETURN p.id AS id, RETURN p.id AS id,
@@ -1495,3 +1349,155 @@ RETURN p.id AS id,
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
""" """
# Statement keyword search through the "statementsFulltext" full-text index.
# Parameters: $query (index query string), $end_user_id (null disables the user
# filter), $limit.
# Each hit is joined to the Chunk that CONTAINS it (chunk_id_from_rel) and to the ids
# of the ExtractedEntity nodes it references (entity_ids). Rows are ordered by
# full-text score, highest first.
# NOTE(review): invalid_at is read via properties(s)['invalid_at'] rather than
# s.invalid_at, presumably to tolerate the property being absent -- confirm.
SEARCH_STATEMENTS_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN s.id AS id,
s.statement AS statement,
s.end_user_id AS end_user_id,
s.chunk_id AS chunk_id,
s.created_at AS created_at,
s.expired_at AS expired_at,
s.valid_at AS valid_at,
properties(s)['invalid_at'] AS invalid_at,
c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Entity search that merges two sources of candidates:
#   1. hits from the "entitiesFulltext" full-text index, scored by the index;
#   2. ExtractedEntity nodes with an alias that contains $query (case-insensitive),
#      scored 1.0 for an exact alias match, 0.9 for a prefix match, 0.8 otherwise.
# The two lists are concatenated, each entity keeps its highest score, and the result
# is joined to the Statements referencing the entity and the Chunks containing those
# Statements. Parameters: $query, $end_user_id (null disables the user filter), $limit.
# NOTE(review): the alias branch compares aliases against $query verbatim; if callers
# pass a Lucene-escaped string, aliases containing the escaped characters will not
# match -- confirm what callers send.
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score
With collect({entity: e, score: score}) AS fulltextResults
OPTIONAL MATCH (ae:ExtractedEntity)
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
AND ae.aliases IS NOT NULL
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
WITH fulltextResults, collect(ae) AS aliasEntities
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
CASE
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
ELSE 0.8
END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
e.created_at AS created_at,
e.expired_at AS expired_at,
e.entity_idx AS entity_idx,
e.statement_id AS statement_id,
e.description AS description,
e.aliases AS aliases,
e.name_embedding AS name_embedding,
e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Chunk keyword search through the "chunksFulltext" full-text index.
# Parameters: $query, $end_user_id (null disables the user filter), $limit.
# Each chunk is returned with the ids of the Statements it CONTAINS and of the
# ExtractedEntity nodes those Statements reference, ordered by full-text score.
SEARCH_CHUNKS_BY_CONTENT = """
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
RETURN c.id AS id,
c.end_user_id AS end_user_id,
c.content AS content,
c.dialog_id AS dialog_id,
c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# MemorySummary keyword search using fulltext index
# Parameters: $query, $end_user_id (null disables the user filter), $limit.
# Rows are ordered by full-text score, highest first.
# NOTE(review): the OPTIONAL MATCH binds Statement `s`, but no returned column uses
# it and the RETURN has no aggregation, so a summary derived from several statements
# may come back as several identical rows -- confirm whether the match is needed.
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
RETURN m.id AS id,
m.name AS name,
m.end_user_id AS end_user_id,
m.dialog_id AS dialog_id,
m.chunk_ids AS chunk_ids,
m.content AS content,
m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count,
score
ORDER BY score DESC
LIMIT $limit
"""
# Community keyword search: matches name or summary via fulltext index
# Parameters: $query, $end_user_id (null disables the user filter), $limit.
# The community summary is exposed under the `content` column; rows are ordered by
# full-text score, highest first.
# NOTE(review): the identifier column is c.id, while another Community query in this
# file returns c.community_id; confirm that Community nodes carry an `id` property.
SEARCH_COMMUNITIES_BY_KEYWORD = """
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
RETURN c.id AS id,
c.name AS name,
c.summary AS content,
c.core_entities AS core_entities,
c.member_count AS member_count,
c.end_user_id AS end_user_id,
c.updated_at AS updated_at,
score
ORDER BY score DESC
LIMIT $limit
"""
# Dispatch tables keyed by Neo4jNodeType. All three cover the same six node types, so
# a caller can pick the query for a node type with a single dictionary lookup.
# NOTE(review): Neo4jNodeType must be imported by this module; the import is not
# visible in this part of the file -- confirm.

# Node type -> full-text (keyword) search query.
FULLTEXT_QUERY_CYPHER_MAPPING = {
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_CONTENT,
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_KEYWORD,
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUALS_BY_KEYWORD
}
# Node type -> query listing (id, embedding) for every node owned by one end user.
USER_ID_QUERY_CYPHER_MAPPING = {
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_USER_ID,
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_USER_ID,
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_USER_ID,
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_USER_ID,
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_USER_ID,
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUAL_BY_USER_ID
}
# Node type -> query loading display fields for a list of node ids.
NODE_ID_QUERY_CYPHER_MAPPING = {
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_IDS,
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_IDS,
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_IDS,
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_IDS,
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_IDS,
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUAL_BY_IDS
}

View File

@@ -1,26 +1,19 @@
import asyncio import asyncio
import logging import logging
from typing import Any, Dict, List, Optional import time
from typing import Any, Dict, List, Optional, Coroutine
from app.core.memory.utils.data.text_utils import escape_lucene_query
import numpy as np import numpy as np
from app.core.memory.enums import Neo4jNodeType
from app.core.memory.llm_tools import OpenAIEmbedderClient from app.core.memory.llm_tools import OpenAIEmbedderClient
from app.core.memory.utils.data.text_utils import escape_lucene_query
from app.repositories.neo4j.cypher_queries import ( from app.repositories.neo4j.cypher_queries import (
CHUNK_EMBEDDING_SEARCH,
COMMUNITY_EMBEDDING_SEARCH,
ENTITY_EMBEDDING_SEARCH,
EXPAND_COMMUNITY_STATEMENTS, EXPAND_COMMUNITY_STATEMENTS,
MEMORY_SUMMARY_EMBEDDING_SEARCH,
SEARCH_CHUNK_BY_CHUNK_ID, SEARCH_CHUNK_BY_CHUNK_ID,
SEARCH_CHUNKS_BY_CONTENT,
SEARCH_COMMUNITIES_BY_KEYWORD,
SEARCH_DIALOGUE_BY_DIALOG_ID, SEARCH_DIALOGUE_BY_DIALOG_ID,
SEARCH_ENTITIES_BY_NAME, SEARCH_ENTITIES_BY_NAME,
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
SEARCH_STATEMENTS_BY_CREATED_AT, SEARCH_STATEMENTS_BY_CREATED_AT,
SEARCH_STATEMENTS_BY_KEYWORD,
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
SEARCH_STATEMENTS_BY_TEMPORAL, SEARCH_STATEMENTS_BY_TEMPORAL,
SEARCH_STATEMENTS_BY_VALID_AT, SEARCH_STATEMENTS_BY_VALID_AT,
@@ -28,12 +21,14 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_STATEMENTS_G_VALID_AT, SEARCH_STATEMENTS_G_VALID_AT,
SEARCH_STATEMENTS_L_CREATED_AT, SEARCH_STATEMENTS_L_CREATED_AT,
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
STATEMENT_EMBEDDING_SEARCH, SEARCH_PERCEPTUALS_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_KEYWORD,
SEARCH_PERCEPTUAL_BY_IDS, SEARCH_PERCEPTUAL_BY_IDS,
SEARCH_PERCEPTUAL_BY_USER_ID, SEARCH_PERCEPTUAL_BY_USER_ID,
FULLTEXT_QUERY_CYPHER_MAPPING,
USER_ID_QUERY_CYPHER_MAPPING,
NODE_ID_QUERY_CYPHER_MAPPING
) )
# 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -52,7 +47,7 @@ def cosine_similarity_search(
query_norm = query / np.linalg.norm(query) query_norm = query / np.linalg.norm(query)
similarities = vectors_norm @ query_norm similarities = vectors_norm @ query_norm
similarities = (similarities + 1) / 2 similarities = np.clip(similarities, 0, 1)
top_k = min(limit, similarities.shape[0]) top_k = min(limit, similarities.shape[0])
if top_k <= 0: if top_k <= 0:
return {} return {}
@@ -60,7 +55,7 @@ def cosine_similarity_search(
top_indices = top_indices[np.argsort(-similarities[top_indices])] top_indices = top_indices[np.argsort(-similarities[top_indices])]
result = {} result = {}
for idx in top_indices: for idx in top_indices:
result[idx] = similarities[idx] result[idx] = float(similarities[idx])
return result return result
@@ -173,7 +168,10 @@ async def _update_search_results_activation(
knowledge_node_types = { knowledge_node_types = {
'statements': 'Statement', 'statements': 'Statement',
'entities': 'ExtractedEntity', 'entities': 'ExtractedEntity',
'summaries': 'MemorySummary' 'summaries': 'MemorySummary',
Neo4jNodeType.STATEMENT: Neo4jNodeType.STATEMENT.value,
Neo4jNodeType.EXTRACTEDENTITY: Neo4jNodeType.EXTRACTEDENTITY.value,
Neo4jNodeType.MEMORYSUMMARY: Neo4jNodeType.MEMORYSUMMARY.value,
} }
# 并行更新所有类型的节点 # 并行更新所有类型的节点
@@ -250,12 +248,147 @@ async def _update_search_results_activation(
return updated_results return updated_results
async def search_perceptual_by_fulltext(
        connector: Neo4jConnector,
        query: str,
        end_user_id: Optional[str] = None,
        limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Search Perceptual memory nodes with a full-text keyword query.

    The query text is passed through ``escape_lucene_query`` and then executed with
    ``SEARCH_PERCEPTUALS_BY_KEYWORD``. A failure while escaping or querying is logged
    as a warning and treated as "no matches"; it is not propagated to the caller.

    Args:
        connector: Neo4j connector used to run the query
        query: Keyword query text
        end_user_id: User whose perceptual memories are searched
        limit: Max results requested from the index

    Returns:
        Dictionary with a single 'perceptuals' key holding the deduplicated rows
    """
    # NOTE(review): the text is escaped here unconditionally, so a caller that has
    # already escaped its query ends up double-escaped -- confirm what callers pass.
    try:
        query_params = {
            "query": escape_lucene_query(query),
            "end_user_id": end_user_id,
            "limit": limit,
        }
        rows = await connector.execute_query(SEARCH_PERCEPTUALS_BY_KEYWORD, **query_params)
    except Exception as e:
        logger.warning(f"search_perceptual: keyword search failed: {e}")
        rows = []

    # Function-scope import, matching how the other search functions in this module
    # obtain deduplicate_results.
    from app.core.memory.src.search import deduplicate_results

    return {"perceptuals": deduplicate_results(rows)}
async def search_perceptual_by_embedding(
        connector: Neo4jConnector,
        embedder_client: OpenAIEmbedderClient,
        query_text: str,
        end_user_id: Optional[str] = None,
        limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Search Perceptual memory nodes using embedding-based semantic search.

    The query text is embedded, the stored embeddings of the user's Perceptual nodes
    are loaded with ``SEARCH_PERCEPTUAL_BY_USER_ID`` and ranked in-process with
    ``cosine_similarity_search``; the best ``limit`` nodes are then loaded in full with
    ``SEARCH_PERCEPTUAL_BY_IDS`` and annotated with their similarity ``score``.

    Errors raised by the embedder propagate to the caller. Errors raised while
    querying or ranking are logged as a warning and yield an empty result.

    Args:
        connector: Neo4j connector
        embedder_client: Embedding client with async response() method
        query_text: Query text to embed
        end_user_id: User whose perceptual memories are searched
        limit: Max results

    Returns:
        Dictionary with 'perceptuals' key containing matched perceptual memory nodes,
        highest score first
    """
    embeddings = await embedder_client.response([query_text])
    if not embeddings or not embeddings[0]:
        logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
        return {"perceptuals": []}
    embedding = embeddings[0]

    try:
        candidates = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_USER_ID,
            end_user_id=end_user_id,
        )
        # SEARCH_PERCEPTUAL_BY_USER_ID exposes the vector under the `embedding` alias.
        # Nodes without a stored embedding cannot be ranked, so they are dropped here.
        candidates = [item for item in candidates if item and item["embedding"] is not None]
        if not candidates:
            # Nothing to rank: skip the similarity computation and the second query.
            return {"perceptuals": []}

        ids = [item["id"] for item in candidates]
        vectors = [item["embedding"] for item in candidates]
        sim_res = cosine_similarity_search(embedding, vectors, limit=limit)
        score_by_id = {ids[idx]: score for idx, score in sim_res.items()}
        if not score_by_id:
            return {"perceptuals": []}

        perceptuals = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_IDS,
            ids=list(score_by_id.keys())
        )
        for perceptual in perceptuals:
            perceptual["score"] = score_by_id[perceptual["id"]]
        # The by-id query returns rows in no particular order; restore the ranking.
        perceptuals = sorted(perceptuals, key=lambda row: row["score"], reverse=True)
    except Exception as e:
        # Deliberate best-effort: a failed vector search degrades to "no matches".
        logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
        perceptuals = []

    from app.core.memory.src.search import deduplicate_results
    perceptuals = deduplicate_results(perceptuals)
    return {"perceptuals": perceptuals}
def search_by_fulltext(
        connector: Neo4jConnector,
        node_type: Neo4jNodeType,
        end_user_id: str,
        query: str,
        limit: int = 10,
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
    """
    Build the full-text search call for one node type.

    This is a plain (non-async) function: it looks up the Cypher text for
    ``node_type`` in ``FULLTEXT_QUERY_CYPHER_MAPPING`` and returns whatever
    ``connector.execute_query`` returns, without awaiting it. ``query`` is forwarded
    unchanged; no escaping is applied here.

    Args:
        connector: Neo4j connector used to run the query
        node_type: Node type selecting the Cypher query
        end_user_id: User filter passed to the query
        query: Full-text query string passed to the index
        limit: Max results

    Returns:
        The awaitable produced by ``connector.execute_query``

    Raises:
        KeyError: If ``node_type`` has no entry in ``FULLTEXT_QUERY_CYPHER_MAPPING``
    """
    query_params = {
        "json_format": True,
        "end_user_id": end_user_id,
        "query": query,
        "limit": limit,
    }
    return connector.execute_query(FULLTEXT_QUERY_CYPHER_MAPPING[node_type], **query_params)
async def search_by_embedding(
        connector: Neo4jConnector,
        node_type: Neo4jNodeType,
        end_user_id: str,
        query_embedding: list[float],
        limit: int = 10,
) -> list[dict[str, Any]]:
    """
    Rank one node type against a query embedding and return the best matches.

    The (id, embedding) pairs of the user's nodes are loaded with the query from
    ``USER_ID_QUERY_CYPHER_MAPPING``, ranked in-process with
    ``cosine_similarity_search``, and the best ``limit`` nodes are loaded in full with
    the query from ``NODE_ID_QUERY_CYPHER_MAPPING`` and annotated with ``score``.

    Any error raised while querying or ranking is logged as a warning and yields an
    empty list; nothing is propagated to the caller.

    Args:
        connector: Neo4j connector
        node_type: Node type selecting the Cypher queries
        end_user_id: User whose nodes are searched
        query_embedding: Embedding vector of the query
        limit: Max results

    Returns:
        Deduplicated node records with a ``score`` field, highest score first
    """
    try:
        candidates = await connector.execute_query(
            USER_ID_QUERY_CYPHER_MAPPING[node_type],
            end_user_id=end_user_id,
        )
        # Nodes without a stored embedding cannot be ranked.
        candidates = [item for item in candidates if item and item["embedding"] is not None]
        if not candidates:
            # A user with no embedded nodes is a normal case, not a failure: return
            # early instead of ranking an empty set and logging a warning.
            return []

        ids = [item["id"] for item in candidates]
        vectors = [item["embedding"] for item in candidates]
        sim_res = cosine_similarity_search(query_embedding, vectors, limit=limit)
        score_by_id = {ids[idx]: score for idx, score in sim_res.items()}
        if not score_by_id:
            # Nothing was selected, so the by-id round trip would return nothing.
            return []

        records = await connector.execute_query(
            NODE_ID_QUERY_CYPHER_MAPPING[node_type],
            ids=list(score_by_id.keys()),
            json_format=True
        )
        for record in records:
            record["score"] = score_by_id[record["id"]]
        # The by-id query returns rows in no particular order; restore the ranking.
        records = sorted(records, key=lambda row: row["score"], reverse=True)
    except Exception as e:
        # getattr keeps the handler itself from raising when node_type is not an
        # enum member (e.g. the lookup above failed on a plain string).
        logger.warning(
            f"search_graph_by_embedding: vector search failed: {e}, "
            f"node_type:{getattr(node_type, 'value', node_type)}",
            exc_info=True)
        records = []

    from app.core.memory.src.search import deduplicate_results
    records = deduplicate_results(records)
    return records
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
query: str, query: str,
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
limit: int = 50, limit: int = 50,
include: List[str] = None, include: List[Neo4jNodeType] = None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Search across Statements, Entities, Chunks, and Summaries using a free-text query. Search across Statements, Entities, Chunks, and Summaries using a free-text query.
@@ -279,7 +412,13 @@ async def search_graph(
Dictionary with search results per category (with updated activation values) Dictionary with search results per category (with updated activation values)
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = [
Neo4jNodeType.STATEMENT,
Neo4jNodeType.CHUNK,
Neo4jNodeType.EXTRACTEDENTITY,
Neo4jNodeType.MEMORYSUMMARY,
Neo4jNodeType.PERCEPTUAL
]
# Escape Lucene special characters to prevent query parse errors # Escape Lucene special characters to prevent query parse errors
escaped_query = escape_lucene_query(query) escaped_query = escape_lucene_query(query)
@@ -288,55 +427,9 @@ async def search_graph(
tasks = [] tasks = []
task_keys = [] task_keys = []
if "statements" in include: for node_type in include:
tasks.append(connector.execute_query( tasks.append(search_by_fulltext(connector, node_type, end_user_id, escaped_query, limit))
SEARCH_STATEMENTS_BY_KEYWORD, task_keys.append(node_type.value)
json_format=True,
query=escaped_query,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
if "entities" in include:
tasks.append(connector.execute_query(
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
json_format=True,
query=escaped_query,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
if "chunks" in include:
tasks.append(connector.execute_query(
SEARCH_CHUNKS_BY_CONTENT,
json_format=True,
query=escaped_query,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
if "summaries" in include:
tasks.append(connector.execute_query(
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
json_format=True,
query=escaped_query,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
if "communities" in include:
tasks.append(connector.execute_query(
SEARCH_COMMUNITIES_BY_KEYWORD,
json_format=True,
query=escaped_query,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel # Execute all queries in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -352,16 +445,16 @@ async def search_graph(
# Deduplicate results before updating activation values # Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline # This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results from app.core.memory.src.search import deduplicate_results
for key in results: for key in results:
if isinstance(results[key], list): if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key]) results[key] = deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization) # Skip activation updates if only searching summaries (optimization)
needs_activation_update = any( needs_activation_update = any(
key in include and key in results and results[key] key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks'] for key in [Neo4jNodeType.STATEMENT, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY]
) )
if needs_activation_update: if needs_activation_update:
@@ -378,7 +471,7 @@ async def search_graph_by_embedding(
connector: Neo4jConnector, connector: Neo4jConnector,
embedder_client, embedder_client,
query_text: str, query_text: str,
end_user_id: Optional[str] = None, end_user_id: str,
limit: int = 50, limit: int = 50,
include=None, include=None,
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
@@ -394,96 +487,32 @@ async def search_graph_by_embedding(
- Returns up to 'limit' per included type - Returns up to 'limit' per included type
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = [
import time Neo4jNodeType.STATEMENT,
Neo4jNodeType.CHUNK,
Neo4jNodeType.EXTRACTEDENTITY,
Neo4jNodeType.MEMORYSUMMARY,
Neo4jNodeType.PERCEPTUAL
]
# Get embedding for the query
embed_start = time.time()
embeddings = await embedder_client.response([query_text]) embeddings = await embedder_client.response([query_text])
embed_time = time.time() - embed_start
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
if not embeddings or not embeddings[0]: if not embeddings or not embeddings[0]:
logger.warning( logger.warning(f"search_graph_by_embedding: embedding generation failed for '{query_text[:50]}'")
f"search_graph_by_embedding: embedding 生成失败或为空," return {search_key: [] for search_key in include}
f"query='{query_text[:50]}', end_user_id={end_user_id},向量检索跳过"
)
return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []}
embedding = embeddings[0] embedding = embeddings[0]
# Prepare tasks for parallel execution # Prepare tasks for parallel execution
tasks = [] tasks = []
task_keys = [] task_keys = []
# Statements (embedding) for node_type in include:
if "statements" in include: tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit))
tasks.append(connector.execute_query( task_keys.append(node_type.value)
STATEMENT_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("statements")
# Chunks (embedding)
if "chunks" in include:
tasks.append(connector.execute_query(
CHUNK_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("chunks")
# Entities
if "entities" in include:
tasks.append(connector.execute_query(
ENTITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("entities")
# Memory summaries
if "summaries" in include:
tasks.append(connector.execute_query(
MEMORY_SUMMARY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("summaries")
# Communities (向量语义匹配)
if "communities" in include:
tasks.append(connector.execute_query(
COMMUNITY_EMBEDDING_SEARCH,
json_format=True,
embedding=embedding,
end_user_id=end_user_id,
limit=limit,
))
task_keys.append("communities")
# Execute all queries in parallel
query_start = time.time()
task_results = await asyncio.gather(*tasks, return_exceptions=True) task_results = await asyncio.gather(*tasks, return_exceptions=True)
query_time = time.time() - query_start
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
# Build results dictionary # Build results dictionary
results: Dict[str, List[Dict[str, Any]]] = { results: Dict[str, List[Dict[str, Any]]] = {}
"statements": [],
"chunks": [],
"entities": [],
"summaries": [],
"communities": [],
}
for key, result in zip(task_keys, task_results): for key, result in zip(task_keys, task_results):
if isinstance(result, Exception): if isinstance(result, Exception):
@@ -494,16 +523,16 @@ async def search_graph_by_embedding(
# Deduplicate results before updating activation values # Deduplicate results before updating activation values
# This prevents duplicates from propagating through the pipeline # This prevents duplicates from propagating through the pipeline
from app.core.memory.src.search import _deduplicate_results from app.core.memory.src.search import deduplicate_results
for key in results: for key in results:
if isinstance(results[key], list): if isinstance(results[key], list):
results[key] = _deduplicate_results(results[key]) results[key] = deduplicate_results(results[key])
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary # 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
# Skip activation updates if only searching summaries (optimization) # Skip activation updates if only searching summaries (optimization)
needs_activation_update = any( needs_activation_update = any(
key in include and key in results and results[key] key in include and key in results and results[key]
for key in ['statements', 'entities', 'chunks'] for key in [Neo4jNodeType.STATEMENT, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY]
) )
if needs_activation_update: if needs_activation_update:
@@ -781,12 +810,12 @@ async def search_graph_community_expand(
expanded.extend(result) expanded.extend(result)
# 按 activation_value 全局排序后去重 # 按 activation_value 全局排序后去重
from app.core.memory.src.search import _deduplicate_results from app.core.memory.src.search import deduplicate_results
expanded.sort( expanded.sort(
key=lambda x: float(x.get("activation_value") or 0), key=lambda x: float(x.get("activation_value") or 0),
reverse=True, reverse=True,
) )
expanded = _deduplicate_results(expanded) expanded = deduplicate_results(expanded)
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}") logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
return {"expanded_statements": expanded} return {"expanded_statements": expanded}
@@ -999,98 +1028,3 @@ async def search_graph_l_valid_at(
) )
return results return results
async def search_perceptual(
    connector: Neo4jConnector,
    query: str,
    end_user_id: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Keyword (full-text) lookup of Perceptual memory nodes.

    Runs the ``SEARCH_PERCEPTUAL_BY_KEYWORD`` Cypher statement through the
    connector, passing the Lucene-escaped query text together with the user
    filter and the row limit, then de-duplicates the returned rows.

    Which node fields / full-text index the statement matches against is
    defined by the Cypher constant, which is not visible in this function.

    Args:
        connector: Neo4j connector used to execute the statement
        query: Raw query text; it is escaped with ``escape_lucene_query`` here
        end_user_id: Optional user filter forwarded to the statement
        limit: Maximum number of rows requested from the statement

    Returns:
        ``{"perceptuals": [...]}`` holding the de-duplicated rows. If the
        query raises for any reason, a warning is logged and the list is
        built from an empty result instead (this function does not re-raise).
    """
    rows: List[Dict[str, Any]] = []
    try:
        # NOTE(review): the Cypher parameter is passed as the keyword `query`.
        # If the connector's `execute_query` also names its first positional
        # parameter `query`, this call raises TypeError, which the handler
        # below turns into an empty result -- confirm the connector signature.
        # NOTE(review): callers that already escape the text before calling
        # would have it escaped a second time here -- verify against callers.
        rows = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_KEYWORD,
            query=escape_lucene_query(query),
            end_user_id=end_user_id,
            limit=limit,
        )
    except Exception as e:
        # Best-effort search: a failed lookup degrades to "no matches".
        logger.warning(f"search_perceptual: keyword search failed: {e}")

    # Imported at call time rather than at module import time.
    from app.core.memory.src.search import _deduplicate_results

    return {"perceptuals": _deduplicate_results(rows)}
async def search_perceptual_by_embedding(
    connector: Neo4jConnector,
    embedder_client: OpenAIEmbedderClient,
    query_text: str,
    end_user_id: Optional[str] = None,
    limit: int = 10,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Semantic lookup of Perceptual memory nodes by embedding similarity.

    Steps performed here:
      1. Embed ``query_text`` with ``embedder_client.response``.
      2. Load candidate rows with ``SEARCH_PERCEPTUAL_BY_USER_ID`` and read
         each row's ``id`` and ``summary_embedding``.
      3. Rank the candidate vectors against the query embedding with
         ``cosine_similarity_search`` (done in this process, not in Cypher).
      4. Re-fetch the selected nodes with ``SEARCH_PERCEPTUAL_BY_IDS`` and
         attach the similarity as ``score`` on each row.
      5. De-duplicate the rows.

    Args:
        connector: Neo4j connector used to execute both statements
        embedder_client: Embedding client exposing an awaitable ``response()``
        query_text: Text to embed and search for
        end_user_id: Optional user filter forwarded to the candidate query
        limit: Forwarded to ``cosine_similarity_search`` as its ``limit``

    Returns:
        ``{"perceptuals": [...]}``. The list is empty when no embedding could
        be produced; any exception raised during steps 2-4 is logged as a
        warning and also results in an empty list before de-duplication.
    """
    embedded = await embedder_client.response([query_text])
    embedding = embedded[0] if embedded else None
    if not embedding:
        logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
        return {"perceptuals": []}

    try:
        candidates = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_USER_ID,
            end_user_id=end_user_id,
        )

        # Parallel lists: position i in `vectors` belongs to node `ids[i]`.
        ids = []
        vectors = []
        for row in candidates:
            ids.append(row['id'])
            vectors.append(row['summary_embedding'])

        # The similarity helper reports hits keyed by position in `vectors`;
        # translate those positions back to node ids.
        hits = cosine_similarity_search(embedding, vectors, limit=limit)
        score_by_id = {}
        for idx, score in hits.items():
            score_by_id[ids[idx]] = score

        # NOTE(review): rows are not re-sorted by score here; the final order
        # depends on the ids query and the dedup helper -- confirm.
        perceptuals = await connector.execute_query(
            SEARCH_PERCEPTUAL_BY_IDS,
            ids=list(score_by_id),
        )
        for node in perceptuals:
            node["score"] = score_by_id[node["id"]]
    except Exception as e:
        # Best-effort search: a failed lookup degrades to "no matches".
        logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
        perceptuals = []

    # Imported at call time rather than at module import time.
    from app.core.memory.src.search import _deduplicate_results

    return {"perceptuals": _deduplicate_results(perceptuals)}

View File

@@ -70,6 +70,12 @@ class Neo4jConnector:
auth=basic_auth(username, password) auth=basic_auth(username, password)
) )
async def __aenter__(self):
    """Enter ``async with``: hand back this same object as the managed value."""
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Leave ``async with``: await ``self.close()``.

    The exception details are ignored and ``None`` is returned, so an
    exception raised inside the ``async with`` body is not suppressed.
    """
    await self.close()
    return None
async def close(self): async def close(self):
"""关闭数据库连接 """关闭数据库连接
@@ -77,11 +83,11 @@ class Neo4jConnector:
""" """
await self.driver.close() await self.driver.close()
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]: async def execute_query(self, cypher: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
"""执行Cypher查询 """执行Cypher查询
Args: Args:
query: Cypher查询语句 cypher: Cypher查询语句
json_format: json格式化 json_format: json格式化
**kwargs: 查询参数将作为参数传递给Cypher查询 **kwargs: 查询参数将作为参数传递给Cypher查询
@@ -92,7 +98,7 @@ class Neo4jConnector:
""" """
result = await self.driver.execute_query( result = await self.driver.execute_query(
query, cypher,
database="neo4j", database="neo4j",
**kwargs **kwargs
) )