feat(memory): implement quick search pipeline with Neo4j integration
This commit is contained in:
@@ -15,7 +15,7 @@ from app.core.logging_config import get_agent_logger
|
|||||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
from app.repositories.neo4j.graph_search import (
|
from app.repositories.neo4j.graph_search import (
|
||||||
search_perceptual,
|
search_perceptual_by_fulltext,
|
||||||
search_perceptual_by_embedding,
|
search_perceptual_by_embedding,
|
||||||
)
|
)
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -152,7 +152,7 @@ class PerceptualSearchService:
|
|||||||
if not escaped.strip():
|
if not escaped.strip():
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
r = await search_perceptual(
|
r = await search_perceptual_by_fulltext(
|
||||||
connector=connector, query=escaped,
|
connector=connector, query=escaped,
|
||||||
end_user_id=self.end_user_id,
|
end_user_id=self.end_user_id,
|
||||||
limit=limit * 5, # 多查一些以提高命中率
|
limit=limit * 5, # 多查一些以提高命中率
|
||||||
@@ -177,7 +177,7 @@ class PerceptualSearchService:
|
|||||||
escaped = escape_lucene_query(kw)
|
escaped = escape_lucene_query(kw)
|
||||||
if not escaped.strip():
|
if not escaped.strip():
|
||||||
return []
|
return []
|
||||||
r = await search_perceptual(
|
r = await search_perceptual_by_fulltext(
|
||||||
connector=connector, query=escaped,
|
connector=connector, query=escaped,
|
||||||
end_user_id=self.end_user_id, limit=limit,
|
end_user_id=self.end_user_id, limit=limit,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from app.core.memory.agent.utils.llm_tools import (
|
|||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.memory.agent.utils.session_tools import SessionService
|
from app.core.memory.agent.utils.session_tools import SessionService
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
|
|
||||||
@@ -338,7 +339,7 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True,
|
"return_raw_results": True,
|
||||||
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
"include": [Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY] # MemorySummary 和 Community 同为高维度概括节点
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -1,15 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
import logging
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage
|
|
||||||
from langgraph.constants import START, END
|
from langgraph.constants import START, END
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from app.db import get_db
|
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||||
|
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
||||||
|
perceptual_retrieve_node,
|
||||||
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||||
Split_The_Problem,
|
Split_The_Problem,
|
||||||
Problem_Extension,
|
Problem_Extension,
|
||||||
@@ -17,9 +16,6 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
|||||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||||
retrieve_nodes,
|
retrieve_nodes,
|
||||||
)
|
)
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.perceptual_retrieve_node import (
|
|
||||||
perceptual_retrieve_node,
|
|
||||||
)
|
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||||
Input_Summary,
|
Input_Summary,
|
||||||
Retrieve_Summary,
|
Retrieve_Summary,
|
||||||
@@ -32,6 +28,9 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
Retrieve_continue,
|
Retrieve_continue,
|
||||||
Verify_continue,
|
Verify_continue,
|
||||||
)
|
)
|
||||||
|
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -51,7 +50,7 @@ async def make_read_graph():
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# Build workflow graph
|
# Build workflow graph
|
||||||
workflow = StateGraph(ReadState)
|
workflow = StateGraph(ReadState)
|
||||||
workflow.add_node("content_input", content_input_node)
|
workflow.add_node("content_input", content_input_node)
|
||||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ and deduplication.
|
|||||||
from typing import List, Tuple, Optional
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
|
|
||||||
@@ -111,13 +112,13 @@ class SearchService:
|
|||||||
content_parts = []
|
content_parts = []
|
||||||
|
|
||||||
# Statements: extract statement field
|
# Statements: extract statement field
|
||||||
if 'statement' in result and result['statement']:
|
if Neo4jNodeType.STATEMENT in result and result[Neo4jNodeType.STATEMENT]:
|
||||||
content_parts.append(result['statement'])
|
content_parts.append(result[Neo4jNodeType.STATEMENT])
|
||||||
|
|
||||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||||
is_community = (
|
is_community = (
|
||||||
node_type == "community"
|
node_type == Neo4jNodeType.COMMUNITY
|
||||||
or 'member_count' in result
|
or 'member_count' in result
|
||||||
or 'core_entities' in result
|
or 'core_entities' in result
|
||||||
)
|
)
|
||||||
@@ -204,7 +205,7 @@ class SearchService:
|
|||||||
raw_results is None if return_raw_results=False
|
raw_results is None if return_raw_results=False
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
include = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
# Clean query
|
# Clean query
|
||||||
cleaned_query = self.clean_query(question)
|
cleaned_query = self.clean_query(question)
|
||||||
@@ -231,7 +232,7 @@ class SearchService:
|
|||||||
reranked_results = answer.get('reranked_results', {})
|
reranked_results = answer.get('reranked_results', {})
|
||||||
|
|
||||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in reranked_results:
|
if category in include and category in reranked_results:
|
||||||
@@ -241,7 +242,7 @@ class SearchService:
|
|||||||
else:
|
else:
|
||||||
# For keyword or embedding search, results are directly in answer dict
|
# For keyword or embedding search, results are directly in answer dict
|
||||||
# Apply same priority order
|
# Apply same priority order
|
||||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
priority_order = [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in answer:
|
if category in include and category in answer:
|
||||||
@@ -250,11 +251,11 @@ class SearchService:
|
|||||||
answer_list.extend(category_results)
|
answer_list.extend(category_results)
|
||||||
|
|
||||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||||
if expand_communities and "communities" in include:
|
if expand_communities and Neo4jNodeType.COMMUNITY in include:
|
||||||
community_results = (
|
community_results = (
|
||||||
answer.get('reranked_results', {}).get('communities', [])
|
answer.get('reranked_results', {}).get(Neo4jNodeType.COMMUNITY.value, [])
|
||||||
if search_type == "hybrid"
|
if search_type == "hybrid"
|
||||||
else answer.get('communities', [])
|
else answer.get(Neo4jNodeType.COMMUNITY.value, [])
|
||||||
)
|
)
|
||||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||||
community_results=community_results,
|
community_results=community_results,
|
||||||
@@ -266,7 +267,7 @@ class SearchService:
|
|||||||
content_list = []
|
content_list = []
|
||||||
for ans in answer_list:
|
for ans in answer_list:
|
||||||
# community 节点有 member_count 或 core_entities 字段
|
# community 节点有 member_count 或 core_entities 字段
|
||||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
ntype = Neo4jNodeType.COMMUNITY if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||||
|
|
||||||
# Filter out empty strings and join with newlines
|
# Filter out empty strings and join with newlines
|
||||||
|
|||||||
@@ -16,3 +16,14 @@ class SearchStrategy(StrEnum):
|
|||||||
DEEP = "0"
|
DEEP = "0"
|
||||||
NORMAL = "1"
|
NORMAL = "1"
|
||||||
QUICK = "2"
|
QUICK = "2"
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jNodeType(StrEnum):
|
||||||
|
CHUNK = "Chunk"
|
||||||
|
COMMUNITY = "Community"
|
||||||
|
DIALOGUE = "Dialogue"
|
||||||
|
EXTRACTEDENTITY = "ExtractedEntity"
|
||||||
|
MEMORYSUMMARY = "MemorySummary"
|
||||||
|
PERCEPTUAL = "Perceptual"
|
||||||
|
STATEMENT = "Statement"
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from chonkie import (
|
|||||||
|
|
||||||
from app.core.memory.models.config_models import ChunkerConfig
|
from app.core.memory.models.config_models import ChunkerConfig
|
||||||
from app.core.memory.models.message_models import DialogData, Chunk
|
from app.core.memory.models.message_models import DialogData, Chunk
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -32,6 +33,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class LLMChunker:
|
class LLMChunker:
|
||||||
"""LLM-based intelligent chunking strategy"""
|
"""LLM-based intelligent chunking strategy"""
|
||||||
|
|
||||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||||
self.llm_client = llm_client
|
self.llm_client = llm_client
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@@ -46,7 +48,8 @@ class LLMChunker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
{"role": "system",
|
||||||
|
"content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
||||||
{"role": "user", "content": prompt}
|
{"role": "user", "content": prompt}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -311,7 +314,7 @@ class ChunkerClient:
|
|||||||
f.write("=" * 60 + "\n\n")
|
f.write("=" * 60 + "\n\n")
|
||||||
|
|
||||||
for i, chunk in enumerate(dialogue.chunks):
|
for i, chunk in enumerate(dialogue.chunks):
|
||||||
f.write(f"Chunk {i+1}:\n")
|
f.write(f"Chunk {i + 1}:\n")
|
||||||
f.write(f"Size: {len(chunk.content)} characters\n")
|
f.write(f"Size: {len(chunk.content)} characters\n")
|
||||||
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
if hasattr(chunk, 'metadata') and 'start_index' in chunk.metadata:
|
||||||
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
f.write(f"Position: {chunk.metadata.get('start_index')}-{chunk.metadata.get('end_index')}\n")
|
||||||
|
|||||||
@@ -1,21 +1,12 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from pydantic import BaseModel, ConfigDict
|
|
||||||
|
|
||||||
from app.core.memory.enums import StorageType
|
from app.core.memory.enums import StorageType, SearchStrategy
|
||||||
from app.schemas import MemoryConfig
|
from app.core.memory.models.service_models import Memory, MemoryContext
|
||||||
|
from app.core.memory.pipelines.memory_read import ReadPipeLine
|
||||||
|
from app.db import get_db_context
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
class MemoryContext(BaseModel):
|
|
||||||
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
|
||||||
|
|
||||||
end_user_id: str
|
|
||||||
memory_config: MemoryConfig
|
|
||||||
storage_type: StorageType = StorageType.NEO4J
|
|
||||||
user_rag_memory_id: str | None = None
|
|
||||||
language: str = "zh"
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryService:
|
class MemoryService:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -44,8 +35,9 @@ class MemoryService:
|
|||||||
async def write(self, messages: list[dict]) -> str:
|
async def write(self, messages: list[dict]) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def read(self, query: str, history: list, search_switch: str) -> dict:
|
async def read(self, query: str, history: list, search_switch: SearchStrategy) -> list[Memory]:
|
||||||
raise NotImplementedError
|
with get_db_context() as db:
|
||||||
|
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit=10)
|
||||||
|
|
||||||
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
26
api/app/core/memory/models/service_models.py
Normal file
26
api/app/core/memory/models/service_models.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
from pydantic import BaseModel, Field, field_serializer, ConfigDict
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType, StorageType
|
||||||
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryContext(BaseModel):
|
||||||
|
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
end_user_id: str
|
||||||
|
memory_config: MemoryConfig
|
||||||
|
storage_type: StorageType = StorageType.NEO4J
|
||||||
|
user_rag_memory_id: str | None = None
|
||||||
|
language: str = "zh"
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(BaseModel):
|
||||||
|
source: Neo4jNodeType = Field(...)
|
||||||
|
score: float = Field(default=0.0)
|
||||||
|
content: str = Field(default="")
|
||||||
|
data: dict = Field(default_factory=dict)
|
||||||
|
query: str = Field(...)
|
||||||
|
|
||||||
|
@field_serializer("source")
|
||||||
|
def serialize_source(self, v) -> str:
|
||||||
|
return v.value
|
||||||
@@ -1,22 +1,32 @@
|
|||||||
# -*- coding: UTF-8 -*-
|
|
||||||
# Author: Eternity
|
|
||||||
# @Email: 1533512157@qq.com
|
|
||||||
# @Time : 2026/4/3 11:44
|
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from demo.memory_alpha import MemoryContext
|
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||||
|
from app.core.memory.models.service_models import MemoryContext
|
||||||
|
from app.core.models import RedBearModelConfig
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
class ModelClientMixin(ABC):
|
class ModelClientMixin(ABC):
|
||||||
def get_llm_client(self, db: Session, model_id: uuid.UUID):
|
@staticmethod
|
||||||
|
def get_llm_client(db: Session, model_id: uuid.UUID):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_embedding_client(self, db: Session, model_id: uuid.UUID):
|
@staticmethod
|
||||||
pass
|
def get_embedding_client(db: Session, model_id: uuid.UUID) -> OpenAIEmbedderClient:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
embedder_client_config = config_service.get_embedder_config(str(model_id))
|
||||||
|
return OpenAIEmbedderClient(
|
||||||
|
RedBearModelConfig(
|
||||||
|
model_name=embedder_client_config["model_name"],
|
||||||
|
provider=embedder_client_config["provider"],
|
||||||
|
api_key=embedder_client_config["api_key"],
|
||||||
|
base_url=embedder_client_config["base_url"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BasePipeline(ABC):
|
class BasePipeline(ABC):
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from app.core.memory.enums import SearchStrategy
|
from app.core.memory.enums import SearchStrategy
|
||||||
from app.core.memory.pipelines.base_pipeline import BasePipeline
|
from app.core.memory.pipelines.base_pipeline import BasePipeline, ModelClientMixin
|
||||||
|
from app.core.memory.read_services.content_search import Neo4jSearchService
|
||||||
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
from app.core.memory.read_services.query_preprocessor import QueryPreprocessor
|
||||||
|
|
||||||
|
|
||||||
class ReadPipeLine(BasePipeline):
|
class ReadPipeLine(ModelClientMixin, BasePipeline):
|
||||||
async def run(self, query, search_switch, memory_config):
|
async def run(self, query: str, search_switch: SearchStrategy, limit: int = 10):
|
||||||
query = QueryPreprocessor.process(query)
|
query = QueryPreprocessor.process(query)
|
||||||
match search_switch:
|
match search_switch:
|
||||||
case SearchStrategy.DEEP:
|
case SearchStrategy.DEEP:
|
||||||
@@ -12,7 +13,7 @@ class ReadPipeLine(BasePipeline):
|
|||||||
case SearchStrategy.NORMAL:
|
case SearchStrategy.NORMAL:
|
||||||
return await self._normal_read(query)
|
return await self._normal_read(query)
|
||||||
case SearchStrategy.QUICK:
|
case SearchStrategy.QUICK:
|
||||||
return await self._quick_read()
|
return await self._quick_read(query, limit)
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unsupported search strategy")
|
raise RuntimeError("Unsupported search strategy")
|
||||||
|
|
||||||
@@ -22,5 +23,9 @@ class ReadPipeLine(BasePipeline):
|
|||||||
async def _normal_read(self, query):
|
async def _normal_read(self, query):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _quick_read(self):
|
async def _quick_read(self, query, limit):
|
||||||
pass
|
search_service = Neo4jSearchService(
|
||||||
|
self.ctx,
|
||||||
|
self.get_embedding_client(self.db, self.ctx.memory_config.embedding_model_id)
|
||||||
|
)
|
||||||
|
return await search_service.search(query, limit)
|
||||||
|
|||||||
0
api/app/core/memory/read_services/__init__.py
Normal file
0
api/app/core/memory/read_services/__init__.py
Normal file
178
api/app/core/memory/read_services/content_search.py
Normal file
178
api/app/core/memory/read_services/content_search.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import time
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
|
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||||
|
from app.core.memory.memory_service import MemoryContext
|
||||||
|
from app.core.memory.models.service_models import Memory
|
||||||
|
from app.core.memory.read_services.result_builder import data_builder_factory
|
||||||
|
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemorySearchResult(BaseModel):
|
||||||
|
memories: dict[str, list[dict]] = Field(default_factory=dict)
|
||||||
|
content: str = Field(default="")
|
||||||
|
count: int = Field(default=0)
|
||||||
|
|
||||||
|
|
||||||
|
class Neo4jSearchService:
|
||||||
|
DEFAULT_ALPHA = 0.6
|
||||||
|
DEFAULT_FULLTEXT_SCORE_THRESHOLD = 1
|
||||||
|
DEFAULT_COSINE_SCORE_THRESHOLD = 0.5
|
||||||
|
DEFAULT_CONTENT_SCORE_THRESHOLD = 0.5
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ctx: MemoryContext,
|
||||||
|
embedder: OpenAIEmbedderClient,
|
||||||
|
includes: list[Neo4jNodeType] | None = None,
|
||||||
|
alpha: float = DEFAULT_ALPHA,
|
||||||
|
fulltext_score_threshold: float = DEFAULT_FULLTEXT_SCORE_THRESHOLD,
|
||||||
|
cosine_score_threshold: float = DEFAULT_COSINE_SCORE_THRESHOLD,
|
||||||
|
content_score_threshold: float = DEFAULT_CONTENT_SCORE_THRESHOLD
|
||||||
|
):
|
||||||
|
self.ctx = ctx
|
||||||
|
self.alpha = alpha
|
||||||
|
self.fulltext_score_threshold = fulltext_score_threshold
|
||||||
|
self.cosine_score_threshold = cosine_score_threshold
|
||||||
|
self.content_score_threshold = content_score_threshold
|
||||||
|
|
||||||
|
self.embedder: OpenAIEmbedderClient = embedder
|
||||||
|
self.connector: Neo4jConnector | None = None
|
||||||
|
|
||||||
|
self.includes = includes
|
||||||
|
if includes is None:
|
||||||
|
self.includes = [
|
||||||
|
Neo4jNodeType.STATEMENT,
|
||||||
|
Neo4jNodeType.CHUNK,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY,
|
||||||
|
Neo4jNodeType.PERCEPTUAL,
|
||||||
|
Neo4jNodeType.COMMUNITY
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _keyword_search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int
|
||||||
|
):
|
||||||
|
return await search_graph(
|
||||||
|
connector=self.connector,
|
||||||
|
query=query,
|
||||||
|
end_user_id=self.ctx.end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
include=self.includes
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _embedding_search(self, query, limit):
|
||||||
|
return await search_graph_by_embedding(
|
||||||
|
connector=self.connector,
|
||||||
|
embedder_client=self.embedder,
|
||||||
|
query_text=query,
|
||||||
|
end_user_id=self.ctx.end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
include=self.includes
|
||||||
|
)
|
||||||
|
|
||||||
|
def _rerank(
|
||||||
|
self,
|
||||||
|
keyword_results: list[dict],
|
||||||
|
embedding_results: list[dict],
|
||||||
|
limit: int,
|
||||||
|
) -> list[dict]:
|
||||||
|
keyword_results = self._normalize_kw_scores(keyword_results)
|
||||||
|
embedding_results = embedding_results
|
||||||
|
|
||||||
|
kw_norm_map = {}
|
||||||
|
for item in keyword_results:
|
||||||
|
item_id = item["id"]
|
||||||
|
kw_norm_map[item_id] = float(item.get("normalized_kw_score", 0))
|
||||||
|
|
||||||
|
emb_norm_map = {}
|
||||||
|
for item in embedding_results:
|
||||||
|
item_id = item["id"]
|
||||||
|
emb_norm_map[item_id] = float(item.get("score", 0))
|
||||||
|
|
||||||
|
combined = {}
|
||||||
|
for item in keyword_results:
|
||||||
|
item_id = item["id"]
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||||
|
|
||||||
|
for item in embedding_results:
|
||||||
|
item_id = item["id"]
|
||||||
|
if item_id in combined:
|
||||||
|
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||||
|
else:
|
||||||
|
combined[item_id] = item.copy()
|
||||||
|
combined[item_id]["kw_score"] = kw_norm_map.get(item_id, 0)
|
||||||
|
combined[item_id]["embedding_score"] = emb_norm_map.get(item_id, 0)
|
||||||
|
|
||||||
|
for item in combined.values():
|
||||||
|
item_id = item["id"]
|
||||||
|
kw = float(combined[item_id].get("kw_score", 0) or 0)
|
||||||
|
emb = float(combined[item_id].get("embedding_score", 0) or 0)
|
||||||
|
base = self.alpha * emb + (1 - self.alpha) * kw
|
||||||
|
combined[item_id]["content_score"] = base + min(1 - base, 0.1 * kw * emb)
|
||||||
|
results = sorted(combined.values(), key=lambda x: x["content_score"], reverse=True)
|
||||||
|
# results = [res for res in results if res["content_score"] > self.content_score_threshold]
|
||||||
|
results = results[:limit]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[MemorySearch] rerank: merged={len(combined)}, after_threshold={len(results)} "
|
||||||
|
f"(alpha={self.alpha})"
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _normalize_kw_scores(self, items: list[dict]) -> list[dict]:
|
||||||
|
if not items:
|
||||||
|
return items
|
||||||
|
scores = [float(it.get("score", 0) or 0) for it in items]
|
||||||
|
for it, s in zip(items, scores):
|
||||||
|
it[f"normalized_kw_score"] = 1 / (1 + math.exp(-(s - self.fulltext_score_threshold) / 2))
|
||||||
|
return items
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[Memory]:
|
||||||
|
async with Neo4jConnector() as connector:
|
||||||
|
self.connector = connector
|
||||||
|
kw_task = self._keyword_search(query, limit)
|
||||||
|
emb_task = self._embedding_search(query, limit)
|
||||||
|
kw_results, emb_results = await asyncio.gather(kw_task, emb_task, return_exceptions=True)
|
||||||
|
|
||||||
|
if isinstance(kw_results, Exception):
|
||||||
|
logger.warning(f"[MemorySearch] keyword search error: {kw_results}")
|
||||||
|
kw_results = {}
|
||||||
|
if isinstance(emb_results, Exception):
|
||||||
|
logger.warning(f"[MemorySearch] embedding search error: {emb_results}")
|
||||||
|
emb_results = {}
|
||||||
|
|
||||||
|
memories = []
|
||||||
|
for node_type in self.includes:
|
||||||
|
reranked = self._rerank(
|
||||||
|
kw_results.get(node_type, []),
|
||||||
|
emb_results.get(node_type, []),
|
||||||
|
limit
|
||||||
|
)
|
||||||
|
for record in reranked:
|
||||||
|
memory = data_builder_factory(node_type, record)
|
||||||
|
memories.append(Memory(
|
||||||
|
score=memory.score,
|
||||||
|
content=memory.content,
|
||||||
|
data=memory.data,
|
||||||
|
source=node_type,
|
||||||
|
query=query
|
||||||
|
))
|
||||||
|
memories.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return memories[:limit]
|
||||||
150
api/app/core/memory/read_services/result_builder.py
Normal file
150
api/app/core/memory/read_services/result_builder.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
|
|
||||||
|
|
||||||
|
class BaseBuilder(ABC):
|
||||||
|
def __init__(self, records: dict):
|
||||||
|
self.record = records
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def data(self) -> dict:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def content(self) -> str:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
def score(self) -> float:
|
||||||
|
return self.record.get("content_score", 0.0) or 0.0
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id"),
|
||||||
|
"content": self.record.get("content"),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.record.get("content")
|
||||||
|
|
||||||
|
|
||||||
|
class StatementBuiler(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id"),
|
||||||
|
"content": self.record.get("statement"),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.record.get("statement")
|
||||||
|
|
||||||
|
|
||||||
|
class EntityBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id"),
|
||||||
|
"content": self.record.get("name"),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.record.get("name")
|
||||||
|
|
||||||
|
|
||||||
|
class SummaryBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id"),
|
||||||
|
"content": self.record.get("content"),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.record.get("content")
|
||||||
|
|
||||||
|
|
||||||
|
class PerceptualBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id", ""),
|
||||||
|
"perceptual_type": self.record.get("perceptual_type", ""),
|
||||||
|
"file_name": self.record.get("file_name", ""),
|
||||||
|
"file_path": self.record.get("file_path", ""),
|
||||||
|
"summary": self.record.get("summary", ""),
|
||||||
|
"topic": self.record.get("topic", ""),
|
||||||
|
"domain": self.record.get("domain", ""),
|
||||||
|
"keywords": self.record.get("keywords", []),
|
||||||
|
"created_at": str(self.record.get("created_at", "")),
|
||||||
|
"file_type": self.record.get("file_type", ""),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return ("<history-file-info>"
|
||||||
|
f"<file-name>{self.record.get('file_name')}</file-name>"
|
||||||
|
f"<file-path>{self.record.get('file_path')}</file-path>"
|
||||||
|
f"<summary>{self.record.get('summary')}</summary>"
|
||||||
|
f"<topic>{self.record.get('topic')}</topic>"
|
||||||
|
f"<domain>{self.record.get('domain')}</domain>"
|
||||||
|
f"<keywords>{self.record.get('keywords')}</keywords>"
|
||||||
|
f"<file-type>{self.record.get('file_type')}</file-type>"
|
||||||
|
"</<history-file-info>")
|
||||||
|
|
||||||
|
|
||||||
|
class CommunityBuilder(BaseBuilder):
|
||||||
|
@property
|
||||||
|
def data(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.record.get("id"),
|
||||||
|
"content": self.record.get("content"),
|
||||||
|
"kw_score": self.record.get("kw_score", 0.0),
|
||||||
|
"emb_score": self.record.get("embedding_score", 0.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content(self) -> str:
|
||||||
|
return self.record.get("content")
|
||||||
|
|
||||||
|
|
||||||
|
def data_builder_factory(node_type, data: dict) -> T:
|
||||||
|
match node_type:
|
||||||
|
case Neo4jNodeType.STATEMENT:
|
||||||
|
return StatementBuiler(data)
|
||||||
|
case Neo4jNodeType.CHUNK:
|
||||||
|
return ChunkBuilder(data)
|
||||||
|
case Neo4jNodeType.EXTRACTEDENTITY:
|
||||||
|
return EntityBuilder(data)
|
||||||
|
case Neo4jNodeType.MEMORYSUMMARY:
|
||||||
|
return SummaryBuilder(data)
|
||||||
|
case Neo4jNodeType.PERCEPTUAL:
|
||||||
|
return PerceptualBuilder(data)
|
||||||
|
case Neo4jNodeType.COMMUNITY:
|
||||||
|
return CommunityBuilder(data)
|
||||||
|
case _:
|
||||||
|
raise KeyError(f"Unknown node_type: {node_type}")
|
||||||
@@ -6,6 +6,8 @@ import time
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
@@ -131,7 +133,7 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
|||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Remove duplicate items from search results based on content.
|
Remove duplicate items from search results based on content.
|
||||||
|
|
||||||
@@ -194,7 +196,7 @@ def rerank_with_activation(
|
|||||||
forgetting_config: ForgettingEngineConfig | None = None,
|
forgetting_config: ForgettingEngineConfig | None = None,
|
||||||
activation_boost_factor: float = 0.8,
|
activation_boost_factor: float = 0.8,
|
||||||
now: datetime | None = None,
|
now: datetime | None = None,
|
||||||
content_score_threshold: float = 0.5,
|
content_score_threshold: float = 0.1,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||||
@@ -239,7 +241,7 @@ def rerank_with_activation(
|
|||||||
|
|
||||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
for category in [Neo4jNodeType.STATEMENT, Neo4jNodeType.CHUNK, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY, Neo4jNodeType.COMMUNITY]:
|
||||||
keyword_items = keyword_results.get(category, [])
|
keyword_items = keyword_results.get(category, [])
|
||||||
embedding_items = embedding_results.get(category, [])
|
embedding_items = embedding_results.get(category, [])
|
||||||
|
|
||||||
@@ -405,7 +407,7 @@ def rerank_with_activation(
|
|||||||
f"items below content_score_threshold={content_score_threshold}"
|
f"items below content_score_threshold={content_score_threshold}"
|
||||||
)
|
)
|
||||||
|
|
||||||
sorted_items = _deduplicate_results(sorted_items)
|
sorted_items = deduplicate_results(sorted_items)
|
||||||
|
|
||||||
reranked[category] = sorted_items
|
reranked[category] = sorted_items
|
||||||
|
|
||||||
@@ -691,7 +693,7 @@ async def run_hybrid_search(
|
|||||||
search_type: str,
|
search_type: str,
|
||||||
end_user_id: str | None,
|
end_user_id: str | None,
|
||||||
limit: int,
|
limit: int,
|
||||||
include: List[str],
|
include: List[Neo4jNodeType],
|
||||||
output_path: str | None,
|
output_path: str | None,
|
||||||
memory_config: "MemoryConfig",
|
memory_config: "MemoryConfig",
|
||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
|
|||||||
@@ -7,7 +7,8 @@ from sqlalchemy.dialects.postgresql import UUID
|
|||||||
from sqlalchemy.dialects.postgresql import JSONB
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
|
||||||
from app.db import Base
|
from app.db import Base
|
||||||
from app.schemas import FileType
|
from app.schemas.app_schema import FileType
|
||||||
|
|
||||||
|
|
||||||
class PerceptualType(IntEnum):
|
class PerceptualType(IntEnum):
|
||||||
VISION = 1
|
VISION = 1
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
|
|
||||||
DIALOGUE_NODE_SAVE = """
|
DIALOGUE_NODE_SAVE = """
|
||||||
UNWIND $dialogues AS dialogue
|
UNWIND $dialogues AS dialogue
|
||||||
@@ -147,57 +148,6 @@ SET r.predicate = rel.predicate,
|
|||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代
|
|
||||||
|
|
||||||
# 保存弱关系实体,设置 e.is_weak = true;不维护 e.relations 聚合字段
|
|
||||||
WEAK_ENTITY_NODE_SAVE = """
|
|
||||||
UNWIND $weak_entities AS entity
|
|
||||||
MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id})
|
|
||||||
SET e += {
|
|
||||||
name: entity.name,
|
|
||||||
end_user_id: entity.end_user_id,
|
|
||||||
run_id: entity.run_id,
|
|
||||||
description: entity.description,
|
|
||||||
chunk_id: entity.chunk_id,
|
|
||||||
dialog_id: entity.dialog_id
|
|
||||||
}
|
|
||||||
// Independent weak flag,仅标记弱关系,不再维护 relations 聚合字段
|
|
||||||
SET e.is_weak = true
|
|
||||||
RETURN e.id AS id
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 为强关系三元组中的主语和宾语创建/更新实体节点,仅设置 e.is_strong = true,不维护 e.relations 字段
|
|
||||||
SAVE_STRONG_TRIPLE_ENTITIES = """
|
|
||||||
UNWIND $items AS item
|
|
||||||
MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id})
|
|
||||||
SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id}
|
|
||||||
// Independent strong flag
|
|
||||||
SET s.is_strong = true
|
|
||||||
MERGE (o:ExtractedEntity {id: item.target_id, run_id: item.run_id})
|
|
||||||
SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id}
|
|
||||||
// Independent strong flag
|
|
||||||
SET o.is_strong = true
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DIALOGUE_STATEMENT_EDGE_SAVE = """
|
|
||||||
UNWIND $dialogue_statement_edges AS edge
|
|
||||||
// 支持按 uuid 或 ref_id 连接到 Dialogue,避免因来源 ID 不一致而断链
|
|
||||||
MATCH (dialogue:Dialogue)
|
|
||||||
WHERE dialogue.uuid = edge.source OR dialogue.ref_id = edge.source
|
|
||||||
MATCH (statement:Statement {id: edge.target})
|
|
||||||
// 仅按端点去重,关系属性可更新
|
|
||||||
MERGE (dialogue)-[e:MENTIONS]->(statement)
|
|
||||||
SET e.uuid = edge.id,
|
|
||||||
e.end_user_id = edge.end_user_id,
|
|
||||||
e.created_at = edge.created_at,
|
|
||||||
e.expired_at = edge.expired_at
|
|
||||||
RETURN e.uuid AS uuid
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 在 Neo4j 5及后续版本中,id() 函数已被标记为弃用,用elementId() 函数替代
|
|
||||||
|
|
||||||
|
|
||||||
CHUNK_STATEMENT_EDGE_SAVE = """
|
CHUNK_STATEMENT_EDGE_SAVE = """
|
||||||
UNWIND $chunk_statement_edges AS edge
|
UNWIND $chunk_statement_edges AS edge
|
||||||
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
MATCH (statement:Statement {id: edge.source, run_id: edge.run_id})
|
||||||
@@ -226,87 +176,6 @@ SET r.end_user_id = rel.end_user_id,
|
|||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ENTITY_EMBEDDING_SEARCH = """
|
|
||||||
CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding)
|
|
||||||
YIELD node AS e, score
|
|
||||||
WHERE e.name_embedding IS NOT NULL
|
|
||||||
AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
|
||||||
RETURN e.id AS id,
|
|
||||||
e.name AS name,
|
|
||||||
e.end_user_id AS end_user_id,
|
|
||||||
e.entity_type AS entity_type,
|
|
||||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
|
||||||
e.last_access_time AS last_access_time,
|
|
||||||
COALESCE(e.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
# Embedding-based search: cosine similarity on Statement.statement_embedding
|
|
||||||
STATEMENT_EMBEDDING_SEARCH = """
|
|
||||||
CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding)
|
|
||||||
YIELD node AS s, score
|
|
||||||
WHERE s.statement_embedding IS NOT NULL
|
|
||||||
AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
|
||||||
RETURN s.id AS id,
|
|
||||||
s.statement AS statement,
|
|
||||||
s.end_user_id AS end_user_id,
|
|
||||||
s.chunk_id AS chunk_id,
|
|
||||||
s.created_at AS created_at,
|
|
||||||
s.expired_at AS expired_at,
|
|
||||||
s.valid_at AS valid_at,
|
|
||||||
s.invalid_at AS invalid_at,
|
|
||||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
|
||||||
s.last_access_time AS last_access_time,
|
|
||||||
COALESCE(s.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Embedding-based search: cosine similarity on Chunk.chunk_embedding
|
|
||||||
CHUNK_EMBEDDING_SEARCH = """
|
|
||||||
CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding)
|
|
||||||
YIELD node AS c, score
|
|
||||||
WHERE c.chunk_embedding IS NOT NULL
|
|
||||||
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
|
||||||
RETURN c.id AS chunk_id,
|
|
||||||
c.end_user_id AS end_user_id,
|
|
||||||
c.content AS content,
|
|
||||||
c.dialog_id AS dialog_id,
|
|
||||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
|
||||||
c.last_access_time AS last_access_time,
|
|
||||||
COALESCE(c.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD = """
|
|
||||||
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
|
||||||
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
|
||||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
|
||||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
|
||||||
RETURN s.id AS id,
|
|
||||||
s.statement AS statement,
|
|
||||||
s.end_user_id AS end_user_id,
|
|
||||||
s.chunk_id AS chunk_id,
|
|
||||||
s.created_at AS created_at,
|
|
||||||
s.expired_at AS expired_at,
|
|
||||||
s.valid_at AS valid_at,
|
|
||||||
s.invalid_at AS invalid_at,
|
|
||||||
c.id AS chunk_id_from_rel,
|
|
||||||
collect(DISTINCT e.id) AS entity_ids,
|
|
||||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
|
||||||
s.last_access_time AS last_access_time,
|
|
||||||
COALESCE(s.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
# 查询实体名称包含指定字符串的实体
|
# 查询实体名称包含指定字符串的实体
|
||||||
SEARCH_ENTITIES_BY_NAME = """
|
SEARCH_ENTITIES_BY_NAME = """
|
||||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||||
@@ -338,73 +207,6 @@ ORDER BY score DESC
|
|||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
|
||||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
|
||||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
|
||||||
WITH e, score
|
|
||||||
With collect({entity: e, score: score}) AS fulltextResults
|
|
||||||
|
|
||||||
OPTIONAL MATCH (ae:ExtractedEntity)
|
|
||||||
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
|
||||||
AND ae.aliases IS NOT NULL
|
|
||||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
|
|
||||||
WITH fulltextResults, collect(ae) AS aliasEntities
|
|
||||||
|
|
||||||
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
|
||||||
CASE
|
|
||||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
|
|
||||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
|
|
||||||
ELSE 0.8
|
|
||||||
END
|
|
||||||
}]) AS row
|
|
||||||
WITH row.entity AS e, row.score AS score
|
|
||||||
WITH DISTINCT e, MAX(score) AS score
|
|
||||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
|
||||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
|
||||||
RETURN e.id AS id,
|
|
||||||
e.name AS name,
|
|
||||||
e.end_user_id AS end_user_id,
|
|
||||||
e.entity_type AS entity_type,
|
|
||||||
e.created_at AS created_at,
|
|
||||||
e.expired_at AS expired_at,
|
|
||||||
e.entity_idx AS entity_idx,
|
|
||||||
e.statement_id AS statement_id,
|
|
||||||
e.description AS description,
|
|
||||||
e.aliases AS aliases,
|
|
||||||
e.name_embedding AS name_embedding,
|
|
||||||
e.connect_strength AS connect_strength,
|
|
||||||
collect(DISTINCT s.id) AS statement_ids,
|
|
||||||
collect(DISTINCT c.id) AS chunk_ids,
|
|
||||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
|
||||||
e.last_access_time AS last_access_time,
|
|
||||||
COALESCE(e.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
SEARCH_CHUNKS_BY_CONTENT = """
|
|
||||||
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
|
|
||||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
|
||||||
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
|
||||||
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
|
||||||
RETURN c.id AS chunk_id,
|
|
||||||
c.end_user_id AS end_user_id,
|
|
||||||
c.content AS content,
|
|
||||||
c.dialog_id AS dialog_id,
|
|
||||||
c.sequence_number AS sequence_number,
|
|
||||||
collect(DISTINCT s.id) AS statement_ids,
|
|
||||||
collect(DISTINCT e.id) AS entity_ids,
|
|
||||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
|
||||||
c.last_access_time AS last_access_time,
|
|
||||||
COALESCE(c.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
# 以下是关于第二层去重消歧与数据库进行检索的语句,在最近的规划中不再使用
|
||||||
|
|
||||||
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
# # 同组group_id下按“精确名字或别名+可选类型一致”来检索
|
||||||
@@ -677,49 +479,6 @@ MATCH (n:Statement {end_user_id: $end_user_id, id: $id})
|
|||||||
SET n.invalid_at = $new_invalid_at
|
SET n.invalid_at = $new_invalid_at
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# MemorySummary keyword search using fulltext index
|
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
|
||||||
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
|
|
||||||
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
|
||||||
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
|
||||||
RETURN m.id AS id,
|
|
||||||
m.name AS name,
|
|
||||||
m.end_user_id AS end_user_id,
|
|
||||||
m.dialog_id AS dialog_id,
|
|
||||||
m.chunk_ids AS chunk_ids,
|
|
||||||
m.content AS content,
|
|
||||||
m.created_at AS created_at,
|
|
||||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
|
||||||
m.last_access_time AS last_access_time,
|
|
||||||
COALESCE(m.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Embedding-based search: cosine similarity on MemorySummary.summary_embedding
|
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH = """
|
|
||||||
CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding)
|
|
||||||
YIELD node AS m, score
|
|
||||||
WHERE m.summary_embedding IS NOT NULL
|
|
||||||
AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
|
||||||
RETURN m.id AS id,
|
|
||||||
m.name AS name,
|
|
||||||
m.end_user_id AS end_user_id,
|
|
||||||
m.dialog_id AS dialog_id,
|
|
||||||
m.chunk_ids AS chunk_ids,
|
|
||||||
m.content AS content,
|
|
||||||
m.created_at AS created_at,
|
|
||||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
|
||||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
|
||||||
m.last_access_time AS last_access_time,
|
|
||||||
COALESCE(m.access_count, 0) AS access_count,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
MEMORY_SUMMARY_NODE_SAVE = """
|
MEMORY_SUMMARY_NODE_SAVE = """
|
||||||
UNWIND $summaries AS summary
|
UNWIND $summaries AS summary
|
||||||
MERGE (m:MemorySummary {id: summary.id})
|
MERGE (m:MemorySummary {id: summary.id})
|
||||||
@@ -1030,8 +789,6 @@ RETURN DISTINCT
|
|||||||
e.statement AS statement;
|
e.statement AS statement;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
'''获取实体'''
|
|
||||||
|
|
||||||
Memory_Space_User = """
|
Memory_Space_User = """
|
||||||
MATCH (n)-[r]->(m)
|
MATCH (n)-[r]->(m)
|
||||||
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
WHERE n.end_user_id = $end_user_id AND m.name="用户"
|
||||||
@@ -1363,22 +1120,6 @@ WHERE c.name IS NULL OR c.name = ''
|
|||||||
RETURN c.community_id AS community_id
|
RETURN c.community_id AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Community keyword search: matches name or summary via fulltext index
|
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
|
||||||
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
|
|
||||||
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
|
||||||
RETURN c.community_id AS id,
|
|
||||||
c.name AS name,
|
|
||||||
c.summary AS content,
|
|
||||||
c.core_entities AS core_entities,
|
|
||||||
c.member_count AS member_count,
|
|
||||||
c.end_user_id AS end_user_id,
|
|
||||||
c.updated_at AS updated_at,
|
|
||||||
score
|
|
||||||
ORDER BY score DESC
|
|
||||||
LIMIT $limit
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Community 向量检索 ──────────────────────────────────────────────────
|
# Community 向量检索 ──────────────────────────────────────────────────
|
||||||
# Community embedding-based search: cosine similarity on Community.summary_embedding
|
# Community embedding-based search: cosine similarity on Community.summary_embedding
|
||||||
COMMUNITY_EMBEDDING_SEARCH = """
|
COMMUNITY_EMBEDDING_SEARCH = """
|
||||||
@@ -1452,13 +1193,54 @@ ON CREATE SET r.end_user_id = edge.end_user_id,
|
|||||||
RETURN elementId(r) AS uuid
|
RETURN elementId(r) AS uuid
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# search by user id
|
||||||
|
# -------------------
|
||||||
SEARCH_PERCEPTUAL_BY_USER_ID = """
|
SEARCH_PERCEPTUAL_BY_USER_ID = """
|
||||||
MATCH (p:Perceptual)
|
MATCH (p:Perceptual)
|
||||||
WHERE p.end_user_id = $end_user_id
|
WHERE p.end_user_id = $end_user_id
|
||||||
RETURN p.id AS id,
|
RETURN p.id AS id,
|
||||||
p.summary_embedding AS summary_embedding
|
p.summary_embedding AS embedding
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SEARCH_STATEMENTS_BY_USER_ID = """
|
||||||
|
MATCH (s:Statement)
|
||||||
|
WHERE s.end_user_id = $end_user_id
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.statement_embedding AS embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_ENTITIES_BY_USER_ID = """
|
||||||
|
MATCH (e:ExtractedEntity)
|
||||||
|
WHERE e.end_user_id = $end_user_id
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name_embedding AS embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_CHUNKS_BY_USER_ID = """
|
||||||
|
MATCH (c:Chunk)
|
||||||
|
WHERE c.end_user_id = $end_user_id
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.chunk_embedding AS embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_MEMORY_SUMMARIES_BY_USER_ID = """
|
||||||
|
MATCH (s:MemorySummary)
|
||||||
|
WHERE s.end_user_id = $end_user_id
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.summary_embedding AS embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_COMMUNITIES_BY_USER_ID = """
|
||||||
|
MATCH (c:Community)
|
||||||
|
WHERE c.end_user_id = $end_user_id
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.summary_embedding AS embedding
|
||||||
|
"""
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# search by id
|
||||||
|
# -------------------
|
||||||
SEARCH_PERCEPTUAL_BY_IDS = """
|
SEARCH_PERCEPTUAL_BY_IDS = """
|
||||||
MATCH (p:Perceptual)
|
MATCH (p:Perceptual)
|
||||||
WHERE p.id IN $ids
|
WHERE p.id IN $ids
|
||||||
@@ -1476,7 +1258,79 @@ RETURN p.id AS id,
|
|||||||
p.file_type AS file_type
|
p.file_type AS file_type
|
||||||
"""
|
"""
|
||||||
|
|
||||||
SEARCH_PERCEPTUAL_BY_KEYWORD = """
|
SEARCH_STATEMENTS_BY_IDS = """
|
||||||
|
MATCH (s:Statement)
|
||||||
|
WHERE s.id IN $ids
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.statement AS statement,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.chunk_id AS chunk_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.expired_at AS expired_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
properties(s)['invalid_at'] AS invalid_at,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
s.last_access_time AS last_access_time,
|
||||||
|
COALESCE(s.access_count, 0) AS access_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_CHUNKS_BY_IDS = """
|
||||||
|
MATCH (c:Chunk)
|
||||||
|
WHERE c.id IN $ids
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.content AS content,
|
||||||
|
c.dialog_id AS dialog_id,
|
||||||
|
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||||
|
c.last_access_time AS last_access_time,
|
||||||
|
COALESCE(c.access_count, 0) AS access_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_ENTITIES_BY_IDS = """
|
||||||
|
MATCH (e:ExtractedEntity)
|
||||||
|
WHERE e.id IN $ids
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name AS name,
|
||||||
|
e.end_user_id AS end_user_id,
|
||||||
|
e.entity_type AS entity_type,
|
||||||
|
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||||
|
e.last_access_time AS last_access_time,
|
||||||
|
COALESCE(e.access_count, 0) AS access_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_MEMORY_SUMMARIES_BY_IDS = """
|
||||||
|
MATCH (m:MemorySummary)
|
||||||
|
WHERE m.id IN $ids
|
||||||
|
RETURN m.id AS id,
|
||||||
|
m.name AS name,
|
||||||
|
m.end_user_id AS end_user_id,
|
||||||
|
m.dialog_id AS dialog_id,
|
||||||
|
m.chunk_ids AS chunk_ids,
|
||||||
|
m.content AS content,
|
||||||
|
m.created_at AS created_at,
|
||||||
|
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||||
|
m.last_access_time AS last_access_time,
|
||||||
|
COALESCE(m.access_count, 0) AS access_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_COMMUNITIES_BY_IDS = """
|
||||||
|
MATCH (c:Community)
|
||||||
|
WHERE c.id IN $ids
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.name AS name,
|
||||||
|
c.summary AS content,
|
||||||
|
c.core_entities AS core_entities,
|
||||||
|
c.member_count AS member_count,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.updated_at AS updated_at
|
||||||
|
"""
|
||||||
|
# -------------------
|
||||||
|
# search by fulltext
|
||||||
|
# -------------------
|
||||||
|
SEARCH_PERCEPTUALS_BY_KEYWORD = """
|
||||||
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
|
CALL db.index.fulltext.queryNodes("perceptualFulltext", $query) YIELD node AS p, score
|
||||||
WHERE p.end_user_id = $end_user_id
|
WHERE p.end_user_id = $end_user_id
|
||||||
RETURN p.id AS id,
|
RETURN p.id AS id,
|
||||||
@@ -1495,3 +1349,155 @@ RETURN p.id AS id,
|
|||||||
ORDER BY score DESC
|
ORDER BY score DESC
|
||||||
LIMIT $limit
|
LIMIT $limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
SEARCH_STATEMENTS_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("statementsFulltext", $query) YIELD node AS s, score
|
||||||
|
WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id)
|
||||||
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
|
RETURN s.id AS id,
|
||||||
|
s.statement AS statement,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.chunk_id AS chunk_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.expired_at AS expired_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
properties(s)['invalid_at'] AS invalid_at,
|
||||||
|
c.id AS chunk_id_from_rel,
|
||||||
|
collect(DISTINCT e.id) AS entity_ids,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
s.last_access_time AS last_access_time,
|
||||||
|
COALESCE(s.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
||||||
|
CALL db.index.fulltext.queryNodes("entitiesFulltext", $query) YIELD node AS e, score
|
||||||
|
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
|
WITH e, score
|
||||||
|
With collect({entity: e, score: score}) AS fulltextResults
|
||||||
|
|
||||||
|
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||||
|
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||||
|
AND ae.aliases IS NOT NULL
|
||||||
|
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($query))
|
||||||
|
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||||
|
|
||||||
|
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||||
|
CASE
|
||||||
|
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($query)) THEN 1.0
|
||||||
|
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($query)) THEN 0.9
|
||||||
|
ELSE 0.8
|
||||||
|
END
|
||||||
|
}]) AS row
|
||||||
|
WITH row.entity AS e, row.score AS score
|
||||||
|
WITH DISTINCT e, MAX(score) AS score
|
||||||
|
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name AS name,
|
||||||
|
e.end_user_id AS end_user_id,
|
||||||
|
e.entity_type AS entity_type,
|
||||||
|
e.created_at AS created_at,
|
||||||
|
e.expired_at AS expired_at,
|
||||||
|
e.entity_idx AS entity_idx,
|
||||||
|
e.statement_id AS statement_id,
|
||||||
|
e.description AS description,
|
||||||
|
e.aliases AS aliases,
|
||||||
|
e.name_embedding AS name_embedding,
|
||||||
|
e.connect_strength AS connect_strength,
|
||||||
|
collect(DISTINCT s.id) AS statement_ids,
|
||||||
|
collect(DISTINCT c.id) AS chunk_ids,
|
||||||
|
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||||
|
e.last_access_time AS last_access_time,
|
||||||
|
COALESCE(e.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
SEARCH_CHUNKS_BY_CONTENT = """
|
||||||
|
CALL db.index.fulltext.queryNodes("chunksFulltext", $query) YIELD node AS c, score
|
||||||
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement)
|
||||||
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity)
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.content AS content,
|
||||||
|
c.dialog_id AS dialog_id,
|
||||||
|
c.sequence_number AS sequence_number,
|
||||||
|
collect(DISTINCT s.id) AS statement_ids,
|
||||||
|
collect(DISTINCT e.id) AS entity_ids,
|
||||||
|
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||||
|
c.last_access_time AS last_access_time,
|
||||||
|
COALESCE(c.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# MemorySummary keyword search using fulltext index
|
||||||
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("summariesFulltext", $query) YIELD node AS m, score
|
||||||
|
WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id)
|
||||||
|
OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement)
|
||||||
|
RETURN m.id AS id,
|
||||||
|
m.name AS name,
|
||||||
|
m.end_user_id AS end_user_id,
|
||||||
|
m.dialog_id AS dialog_id,
|
||||||
|
m.chunk_ids AS chunk_ids,
|
||||||
|
m.content AS content,
|
||||||
|
m.created_at AS created_at,
|
||||||
|
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||||
|
m.last_access_time AS last_access_time,
|
||||||
|
COALESCE(m.access_count, 0) AS access_count,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Community keyword search: matches name or summary via fulltext index
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("communitiesFulltext", $query) YIELD node AS c, score
|
||||||
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
RETURN c.id AS id,
|
||||||
|
c.name AS name,
|
||||||
|
c.summary AS content,
|
||||||
|
c.core_entities AS core_entities,
|
||||||
|
c.member_count AS member_count,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.updated_at AS updated_at,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
FULLTEXT_QUERY_CYPHER_MAPPING = {
|
||||||
|
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
||||||
|
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_CONTENT,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
|
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
|
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUALS_BY_KEYWORD
|
||||||
|
}
|
||||||
|
USER_ID_QUERY_CYPHER_MAPPING = {
|
||||||
|
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_USER_ID,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_USER_ID,
|
||||||
|
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_USER_ID,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_USER_ID,
|
||||||
|
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_USER_ID,
|
||||||
|
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUAL_BY_USER_ID
|
||||||
|
}
|
||||||
|
NODE_ID_QUERY_CYPHER_MAPPING = {
|
||||||
|
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_IDS,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_IDS,
|
||||||
|
Neo4jNodeType.CHUNK: SEARCH_CHUNKS_BY_IDS,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY: SEARCH_MEMORY_SUMMARIES_BY_IDS,
|
||||||
|
Neo4jNodeType.COMMUNITY: SEARCH_COMMUNITIES_BY_IDS,
|
||||||
|
Neo4jNodeType.PERCEPTUAL: SEARCH_PERCEPTUAL_BY_IDS
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,26 +1,19 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
import time
|
||||||
|
from typing import Any, Dict, List, Optional, Coroutine
|
||||||
|
|
||||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from app.core.memory.enums import Neo4jNodeType
|
||||||
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
from app.core.memory.llm_tools import OpenAIEmbedderClient
|
||||||
|
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||||
from app.repositories.neo4j.cypher_queries import (
|
from app.repositories.neo4j.cypher_queries import (
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
|
||||||
COMMUNITY_EMBEDDING_SEARCH,
|
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
|
||||||
EXPAND_COMMUNITY_STATEMENTS,
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
|
||||||
SEARCH_STATEMENTS_BY_CREATED_AT,
|
SEARCH_STATEMENTS_BY_CREATED_AT,
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL,
|
||||||
SEARCH_STATEMENTS_BY_TEMPORAL,
|
SEARCH_STATEMENTS_BY_TEMPORAL,
|
||||||
SEARCH_STATEMENTS_BY_VALID_AT,
|
SEARCH_STATEMENTS_BY_VALID_AT,
|
||||||
@@ -28,12 +21,14 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
SEARCH_STATEMENTS_G_VALID_AT,
|
SEARCH_STATEMENTS_G_VALID_AT,
|
||||||
SEARCH_STATEMENTS_L_CREATED_AT,
|
SEARCH_STATEMENTS_L_CREATED_AT,
|
||||||
SEARCH_STATEMENTS_L_VALID_AT,
|
SEARCH_STATEMENTS_L_VALID_AT,
|
||||||
STATEMENT_EMBEDDING_SEARCH,
|
SEARCH_PERCEPTUALS_BY_KEYWORD,
|
||||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
|
||||||
SEARCH_PERCEPTUAL_BY_IDS,
|
SEARCH_PERCEPTUAL_BY_IDS,
|
||||||
SEARCH_PERCEPTUAL_BY_USER_ID,
|
SEARCH_PERCEPTUAL_BY_USER_ID,
|
||||||
|
FULLTEXT_QUERY_CYPHER_MAPPING,
|
||||||
|
USER_ID_QUERY_CYPHER_MAPPING,
|
||||||
|
NODE_ID_QUERY_CYPHER_MAPPING
|
||||||
)
|
)
|
||||||
# 使用新的仓储层
|
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -52,7 +47,7 @@ def cosine_similarity_search(
|
|||||||
query_norm = query / np.linalg.norm(query)
|
query_norm = query / np.linalg.norm(query)
|
||||||
|
|
||||||
similarities = vectors_norm @ query_norm
|
similarities = vectors_norm @ query_norm
|
||||||
similarities = (similarities + 1) / 2
|
similarities = np.clip(similarities, 0, 1)
|
||||||
top_k = min(limit, similarities.shape[0])
|
top_k = min(limit, similarities.shape[0])
|
||||||
if top_k <= 0:
|
if top_k <= 0:
|
||||||
return {}
|
return {}
|
||||||
@@ -60,7 +55,7 @@ def cosine_similarity_search(
|
|||||||
top_indices = top_indices[np.argsort(-similarities[top_indices])]
|
top_indices = top_indices[np.argsort(-similarities[top_indices])]
|
||||||
result = {}
|
result = {}
|
||||||
for idx in top_indices:
|
for idx in top_indices:
|
||||||
result[idx] = similarities[idx]
|
result[idx] = float(similarities[idx])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
@@ -173,7 +168,10 @@ async def _update_search_results_activation(
|
|||||||
knowledge_node_types = {
|
knowledge_node_types = {
|
||||||
'statements': 'Statement',
|
'statements': 'Statement',
|
||||||
'entities': 'ExtractedEntity',
|
'entities': 'ExtractedEntity',
|
||||||
'summaries': 'MemorySummary'
|
'summaries': 'MemorySummary',
|
||||||
|
Neo4jNodeType.STATEMENT: Neo4jNodeType.STATEMENT.value,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY: Neo4jNodeType.EXTRACTEDENTITY.value,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY: Neo4jNodeType.MEMORYSUMMARY.value,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 并行更新所有类型的节点
|
# 并行更新所有类型的节点
|
||||||
@@ -250,12 +248,147 @@ async def _update_search_results_activation(
|
|||||||
return updated_results
|
return updated_results
|
||||||
|
|
||||||
|
|
||||||
|
async def search_perceptual_by_fulltext(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
query: str,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
try:
|
||||||
|
perceptuals = await connector.execute_query(
|
||||||
|
SEARCH_PERCEPTUALS_BY_KEYWORD,
|
||||||
|
query=escape_lucene_query(query),
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"search_perceptual: keyword search failed: {e}")
|
||||||
|
perceptuals = []
|
||||||
|
|
||||||
|
# Deduplicate
|
||||||
|
from app.core.memory.src.search import deduplicate_results
|
||||||
|
perceptuals = deduplicate_results(perceptuals)
|
||||||
|
|
||||||
|
return {"perceptuals": perceptuals}
|
||||||
|
|
||||||
|
|
||||||
|
async def search_perceptual_by_embedding(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
embedder_client: OpenAIEmbedderClient,
|
||||||
|
query_text: str,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
Search Perceptual memory nodes using embedding-based semantic search.
|
||||||
|
|
||||||
|
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: Neo4j connector
|
||||||
|
embedder_client: Embedding client with async response() method
|
||||||
|
query_text: Query text to embed
|
||||||
|
end_user_id: Optional user filter
|
||||||
|
limit: Max results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
||||||
|
"""
|
||||||
|
embeddings = await embedder_client.response([query_text])
|
||||||
|
if not embeddings or not embeddings[0]:
|
||||||
|
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
||||||
|
return {"perceptuals": []}
|
||||||
|
|
||||||
|
embedding = embeddings[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
perceptuals = await connector.execute_query(
|
||||||
|
SEARCH_PERCEPTUAL_BY_USER_ID,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
ids = [item['id'] for item in perceptuals]
|
||||||
|
vectors = [item['summary_embedding'] for item in perceptuals]
|
||||||
|
sim_res = cosine_similarity_search(embedding, vectors, limit=limit)
|
||||||
|
perceptual_res = {
|
||||||
|
ids[idx]: score
|
||||||
|
for idx, score in sim_res.items()
|
||||||
|
}
|
||||||
|
perceptuals = await connector.execute_query(
|
||||||
|
SEARCH_PERCEPTUAL_BY_IDS,
|
||||||
|
ids=list(perceptual_res.keys())
|
||||||
|
)
|
||||||
|
for perceptual in perceptuals:
|
||||||
|
perceptual["score"] = perceptual_res[perceptual["id"]]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
|
||||||
|
perceptuals = []
|
||||||
|
|
||||||
|
from app.core.memory.src.search import deduplicate_results
|
||||||
|
perceptuals = deduplicate_results(perceptuals)
|
||||||
|
|
||||||
|
return {"perceptuals": perceptuals}
|
||||||
|
|
||||||
|
|
||||||
|
def search_by_fulltext(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
node_type: Neo4jNodeType,
|
||||||
|
end_user_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
|
||||||
|
cypher = FULLTEXT_QUERY_CYPHER_MAPPING[node_type]
|
||||||
|
return connector.execute_query(
|
||||||
|
cypher,
|
||||||
|
json_format=True,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
query=query,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def search_by_embedding(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
node_type: Neo4jNodeType,
|
||||||
|
end_user_id: str,
|
||||||
|
query_embedding: list[float],
|
||||||
|
limit: int = 10,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
try:
|
||||||
|
records = await connector.execute_query(
|
||||||
|
USER_ID_QUERY_CYPHER_MAPPING[node_type],
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
records = [record for record in records if record if record["embedding"] is not None]
|
||||||
|
ids = [item['id'] for item in records]
|
||||||
|
vectors = [item['embedding'] for item in records]
|
||||||
|
sim_res = cosine_similarity_search(query_embedding, vectors, limit=limit)
|
||||||
|
records_score_map = {
|
||||||
|
ids[idx]: score
|
||||||
|
for idx, score in sim_res.items()
|
||||||
|
}
|
||||||
|
records = await connector.execute_query(
|
||||||
|
NODE_ID_QUERY_CYPHER_MAPPING[node_type],
|
||||||
|
ids=list(records_score_map.keys()),
|
||||||
|
json_format=True
|
||||||
|
)
|
||||||
|
for record in records:
|
||||||
|
record["score"] = records_score_map[record["id"]]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"search_graph_by_embedding: vector search failed: {e}, node_type:{node_type.value}",
|
||||||
|
exc_info=True)
|
||||||
|
records = []
|
||||||
|
|
||||||
|
from app.core.memory.src.search import deduplicate_results
|
||||||
|
records = deduplicate_results(records)
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
async def search_graph(
|
async def search_graph(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
query: str,
|
query: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include: List[str] = None,
|
include: List[Neo4jNodeType] = None,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||||
@@ -279,7 +412,13 @@ async def search_graph(
|
|||||||
Dictionary with search results per category (with updated activation values)
|
Dictionary with search results per category (with updated activation values)
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries"]
|
include = [
|
||||||
|
Neo4jNodeType.STATEMENT,
|
||||||
|
Neo4jNodeType.CHUNK,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY,
|
||||||
|
Neo4jNodeType.PERCEPTUAL
|
||||||
|
]
|
||||||
|
|
||||||
# Escape Lucene special characters to prevent query parse errors
|
# Escape Lucene special characters to prevent query parse errors
|
||||||
escaped_query = escape_lucene_query(query)
|
escaped_query = escape_lucene_query(query)
|
||||||
@@ -288,55 +427,9 @@ async def search_graph(
|
|||||||
tasks = []
|
tasks = []
|
||||||
task_keys = []
|
task_keys = []
|
||||||
|
|
||||||
if "statements" in include:
|
for node_type in include:
|
||||||
tasks.append(connector.execute_query(
|
tasks.append(search_by_fulltext(connector, node_type, end_user_id, escaped_query, limit))
|
||||||
SEARCH_STATEMENTS_BY_KEYWORD,
|
task_keys.append(node_type.value)
|
||||||
json_format=True,
|
|
||||||
query=escaped_query,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("statements")
|
|
||||||
|
|
||||||
if "entities" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
|
|
||||||
json_format=True,
|
|
||||||
query=escaped_query,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("entities")
|
|
||||||
|
|
||||||
if "chunks" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
|
||||||
json_format=True,
|
|
||||||
query=escaped_query,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("chunks")
|
|
||||||
|
|
||||||
if "summaries" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
|
||||||
json_format=True,
|
|
||||||
query=escaped_query,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("summaries")
|
|
||||||
|
|
||||||
if "communities" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
SEARCH_COMMUNITIES_BY_KEYWORD,
|
|
||||||
json_format=True,
|
|
||||||
query=escaped_query,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("communities")
|
|
||||||
|
|
||||||
# Execute all queries in parallel
|
# Execute all queries in parallel
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -352,16 +445,16 @@ async def search_graph(
|
|||||||
|
|
||||||
# Deduplicate results before updating activation values
|
# Deduplicate results before updating activation values
|
||||||
# This prevents duplicates from propagating through the pipeline
|
# This prevents duplicates from propagating through the pipeline
|
||||||
from app.core.memory.src.search import _deduplicate_results
|
from app.core.memory.src.search import deduplicate_results
|
||||||
for key in results:
|
for key in results:
|
||||||
if isinstance(results[key], list):
|
if isinstance(results[key], list):
|
||||||
results[key] = _deduplicate_results(results[key])
|
results[key] = deduplicate_results(results[key])
|
||||||
|
|
||||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||||
# Skip activation updates if only searching summaries (optimization)
|
# Skip activation updates if only searching summaries (optimization)
|
||||||
needs_activation_update = any(
|
needs_activation_update = any(
|
||||||
key in include and key in results and results[key]
|
key in include and key in results and results[key]
|
||||||
for key in ['statements', 'entities', 'chunks']
|
for key in [Neo4jNodeType.STATEMENT, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY]
|
||||||
)
|
)
|
||||||
|
|
||||||
if needs_activation_update:
|
if needs_activation_update:
|
||||||
@@ -378,7 +471,7 @@ async def search_graph_by_embedding(
|
|||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
embedder_client,
|
embedder_client,
|
||||||
query_text: str,
|
query_text: str,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: str,
|
||||||
limit: int = 50,
|
limit: int = 50,
|
||||||
include=None,
|
include=None,
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
@@ -394,96 +487,32 @@ async def search_graph_by_embedding(
|
|||||||
- Returns up to 'limit' per included type
|
- Returns up to 'limit' per included type
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries"]
|
include = [
|
||||||
import time
|
Neo4jNodeType.STATEMENT,
|
||||||
|
Neo4jNodeType.CHUNK,
|
||||||
|
Neo4jNodeType.EXTRACTEDENTITY,
|
||||||
|
Neo4jNodeType.MEMORYSUMMARY,
|
||||||
|
Neo4jNodeType.PERCEPTUAL
|
||||||
|
]
|
||||||
|
|
||||||
# Get embedding for the query
|
|
||||||
embed_start = time.time()
|
|
||||||
embeddings = await embedder_client.response([query_text])
|
embeddings = await embedder_client.response([query_text])
|
||||||
embed_time = time.time() - embed_start
|
|
||||||
logger.debug(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
|
||||||
|
|
||||||
if not embeddings or not embeddings[0]:
|
if not embeddings or not embeddings[0]:
|
||||||
logger.warning(
|
logger.warning(f"search_graph_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
||||||
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
return {search_key: [] for search_key in include}
|
||||||
f"query='{query_text[:50]}', end_user_id={end_user_id},向量检索跳过"
|
|
||||||
)
|
|
||||||
return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []}
|
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
|
|
||||||
# Prepare tasks for parallel execution
|
# Prepare tasks for parallel execution
|
||||||
tasks = []
|
tasks = []
|
||||||
task_keys = []
|
task_keys = []
|
||||||
|
|
||||||
# Statements (embedding)
|
for node_type in include:
|
||||||
if "statements" in include:
|
tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit))
|
||||||
tasks.append(connector.execute_query(
|
task_keys.append(node_type.value)
|
||||||
STATEMENT_EMBEDDING_SEARCH,
|
|
||||||
json_format=True,
|
|
||||||
embedding=embedding,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("statements")
|
|
||||||
|
|
||||||
# Chunks (embedding)
|
|
||||||
if "chunks" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
|
||||||
json_format=True,
|
|
||||||
embedding=embedding,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("chunks")
|
|
||||||
|
|
||||||
# Entities
|
|
||||||
if "entities" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
|
||||||
json_format=True,
|
|
||||||
embedding=embedding,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("entities")
|
|
||||||
|
|
||||||
# Memory summaries
|
|
||||||
if "summaries" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
|
||||||
json_format=True,
|
|
||||||
embedding=embedding,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("summaries")
|
|
||||||
|
|
||||||
# Communities (向量语义匹配)
|
|
||||||
if "communities" in include:
|
|
||||||
tasks.append(connector.execute_query(
|
|
||||||
COMMUNITY_EMBEDDING_SEARCH,
|
|
||||||
json_format=True,
|
|
||||||
embedding=embedding,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
))
|
|
||||||
task_keys.append("communities")
|
|
||||||
|
|
||||||
# Execute all queries in parallel
|
|
||||||
query_start = time.time()
|
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
query_time = time.time() - query_start
|
|
||||||
logger.debug(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s")
|
|
||||||
|
|
||||||
# Build results dictionary
|
# Build results dictionary
|
||||||
results: Dict[str, List[Dict[str, Any]]] = {
|
results: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
"statements": [],
|
|
||||||
"chunks": [],
|
|
||||||
"entities": [],
|
|
||||||
"summaries": [],
|
|
||||||
"communities": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
for key, result in zip(task_keys, task_results):
|
for key, result in zip(task_keys, task_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
@@ -494,16 +523,16 @@ async def search_graph_by_embedding(
|
|||||||
|
|
||||||
# Deduplicate results before updating activation values
|
# Deduplicate results before updating activation values
|
||||||
# This prevents duplicates from propagating through the pipeline
|
# This prevents duplicates from propagating through the pipeline
|
||||||
from app.core.memory.src.search import _deduplicate_results
|
from app.core.memory.src.search import deduplicate_results
|
||||||
for key in results:
|
for key in results:
|
||||||
if isinstance(results[key], list):
|
if isinstance(results[key], list):
|
||||||
results[key] = _deduplicate_results(results[key])
|
results[key] = deduplicate_results(results[key])
|
||||||
|
|
||||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||||
# Skip activation updates if only searching summaries (optimization)
|
# Skip activation updates if only searching summaries (optimization)
|
||||||
needs_activation_update = any(
|
needs_activation_update = any(
|
||||||
key in include and key in results and results[key]
|
key in include and key in results and results[key]
|
||||||
for key in ['statements', 'entities', 'chunks']
|
for key in [Neo4jNodeType.STATEMENT, Neo4jNodeType.EXTRACTEDENTITY, Neo4jNodeType.MEMORYSUMMARY]
|
||||||
)
|
)
|
||||||
|
|
||||||
if needs_activation_update:
|
if needs_activation_update:
|
||||||
@@ -781,12 +810,12 @@ async def search_graph_community_expand(
|
|||||||
expanded.extend(result)
|
expanded.extend(result)
|
||||||
|
|
||||||
# 按 activation_value 全局排序后去重
|
# 按 activation_value 全局排序后去重
|
||||||
from app.core.memory.src.search import _deduplicate_results
|
from app.core.memory.src.search import deduplicate_results
|
||||||
expanded.sort(
|
expanded.sort(
|
||||||
key=lambda x: float(x.get("activation_value") or 0),
|
key=lambda x: float(x.get("activation_value") or 0),
|
||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
expanded = _deduplicate_results(expanded)
|
expanded = deduplicate_results(expanded)
|
||||||
|
|
||||||
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
|
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
|
||||||
return {"expanded_statements": expanded}
|
return {"expanded_statements": expanded}
|
||||||
@@ -999,98 +1028,3 @@ async def search_graph_l_valid_at(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
||||||
async def search_perceptual(
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
query: str,
|
|
||||||
end_user_id: Optional[str] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
|
||||||
"""
|
|
||||||
Search Perceptual memory nodes using fulltext keyword search.
|
|
||||||
|
|
||||||
Matches against summary, topic, and domain fields via the perceptualFulltext index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connector: Neo4j connector
|
|
||||||
query: Query text for full-text search
|
|
||||||
end_user_id: Optional user filter
|
|
||||||
limit: Max results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
perceptuals = await connector.execute_query(
|
|
||||||
SEARCH_PERCEPTUAL_BY_KEYWORD,
|
|
||||||
query=escape_lucene_query(query),
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"search_perceptual: keyword search failed: {e}")
|
|
||||||
perceptuals = []
|
|
||||||
|
|
||||||
# Deduplicate
|
|
||||||
from app.core.memory.src.search import _deduplicate_results
|
|
||||||
perceptuals = _deduplicate_results(perceptuals)
|
|
||||||
|
|
||||||
return {"perceptuals": perceptuals}
|
|
||||||
|
|
||||||
|
|
||||||
async def search_perceptual_by_embedding(
|
|
||||||
connector: Neo4jConnector,
|
|
||||||
embedder_client: OpenAIEmbedderClient,
|
|
||||||
query_text: str,
|
|
||||||
end_user_id: Optional[str] = None,
|
|
||||||
limit: int = 10,
|
|
||||||
) -> Dict[str, List[Dict[str, Any]]]:
|
|
||||||
"""
|
|
||||||
Search Perceptual memory nodes using embedding-based semantic search.
|
|
||||||
|
|
||||||
Uses cosine similarity on summary_embedding via the perceptual_summary_embedding_index.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
connector: Neo4j connector
|
|
||||||
embedder_client: Embedding client with async response() method
|
|
||||||
query_text: Query text to embed
|
|
||||||
end_user_id: Optional user filter
|
|
||||||
limit: Max results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with 'perceptuals' key containing matched perceptual memory nodes
|
|
||||||
"""
|
|
||||||
embeddings = await embedder_client.response([query_text])
|
|
||||||
if not embeddings or not embeddings[0]:
|
|
||||||
logger.warning(f"search_perceptual_by_embedding: embedding generation failed for '{query_text[:50]}'")
|
|
||||||
return {"perceptuals": []}
|
|
||||||
|
|
||||||
embedding = embeddings[0]
|
|
||||||
|
|
||||||
try:
|
|
||||||
perceptuals = await connector.execute_query(
|
|
||||||
SEARCH_PERCEPTUAL_BY_USER_ID,
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
)
|
|
||||||
ids = [item['id'] for item in perceptuals]
|
|
||||||
vectors = [item['summary_embedding'] for item in perceptuals]
|
|
||||||
sim_res = cosine_similarity_search(embedding, vectors, limit=limit)
|
|
||||||
perceptual_res = {
|
|
||||||
ids[idx]: score
|
|
||||||
for idx, score in sim_res.items()
|
|
||||||
}
|
|
||||||
perceptuals = await connector.execute_query(
|
|
||||||
SEARCH_PERCEPTUAL_BY_IDS,
|
|
||||||
ids=list(perceptual_res.keys())
|
|
||||||
)
|
|
||||||
for perceptual in perceptuals:
|
|
||||||
perceptual["score"] = perceptual_res[perceptual["id"]]
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"search_perceptual_by_embedding: vector search failed: {e}")
|
|
||||||
perceptuals = []
|
|
||||||
|
|
||||||
from app.core.memory.src.search import _deduplicate_results
|
|
||||||
perceptuals = _deduplicate_results(perceptuals)
|
|
||||||
|
|
||||||
return {"perceptuals": perceptuals}
|
|
||||||
|
|||||||
@@ -70,6 +70,12 @@ class Neo4jConnector:
|
|||||||
auth=basic_auth(username, password)
|
auth=basic_auth(username, password)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.close()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""关闭数据库连接
|
"""关闭数据库连接
|
||||||
|
|
||||||
@@ -77,11 +83,11 @@ class Neo4jConnector:
|
|||||||
"""
|
"""
|
||||||
await self.driver.close()
|
await self.driver.close()
|
||||||
|
|
||||||
async def execute_query(self, query: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
async def execute_query(self, cypher: str, json_format=False, **kwargs: Any) -> List[Dict[str, Any]]:
|
||||||
"""执行Cypher查询
|
"""执行Cypher查询
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: Cypher查询语句
|
cypher: Cypher查询语句
|
||||||
json_format: json格式化
|
json_format: json格式化
|
||||||
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
**kwargs: 查询参数,将作为参数传递给Cypher查询
|
||||||
|
|
||||||
@@ -92,7 +98,7 @@ class Neo4jConnector:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
result = await self.driver.execute_query(
|
result = await self.driver.execute_query(
|
||||||
query,
|
cypher,
|
||||||
database="neo4j",
|
database="neo4j",
|
||||||
**kwargs
|
**kwargs
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user