feat(memory): add session-based chat history and user metadata retrieval
- Add ChatSessionCache to manage chat history per session - Add SEARCH_USER_METADATA cypher query for retrieving user entity metadata - Add "str" mode support to StructResponse for raw text extraction - Add content_str field to MemorySearchResult for pre-formatted content - Fix sandbox URL by removing hardcoded port - Add description field to entity search results - Remove history from UserInput schema, use session_id instead
This commit is contained in:
@@ -43,10 +43,13 @@ class MemoryService:
|
||||
self,
|
||||
query: str,
|
||||
search_switch: SearchStrategy,
|
||||
history: list | None = None,
|
||||
limit: int = 10,
|
||||
) -> MemorySearchResult:
|
||||
if history is None:
|
||||
history = []
|
||||
with get_db_context() as db:
|
||||
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
|
||||
return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit)
|
||||
|
||||
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -32,10 +32,12 @@ class Memory(BaseModel):
|
||||
|
||||
class MemorySearchResult(BaseModel):
|
||||
memories: list[Memory]
|
||||
content_str: str = Field(default="")
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def content(self) -> str:
|
||||
if self.content_str:
|
||||
return self.content_str
|
||||
return "\n".join([memory.content for memory in self.memories])
|
||||
|
||||
@computed_field
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from app.core.memory.enums import SearchStrategy, StorageType
|
||||
from app.core.memory.models.service_models import MemorySearchResult
|
||||
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
|
||||
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
||||
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
|
||||
from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor
|
||||
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
|
||||
|
||||
|
||||
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
@@ -10,20 +11,30 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
self,
|
||||
query: str,
|
||||
search_switch: SearchStrategy,
|
||||
history: list,
|
||||
limit: int = 10,
|
||||
includes=None
|
||||
) -> MemorySearchResult:
|
||||
memory_l0 = None
|
||||
if self.ctx.storage_type == StorageType.NEO4J:
|
||||
memory_l0 = await self._get_search_service(includes).memory_l0()
|
||||
|
||||
query = QueryPreprocessor.process(query)
|
||||
match search_switch:
|
||||
case SearchStrategy.DEEP:
|
||||
return await self._deep_read(query, limit, includes)
|
||||
res = await self._deep_read(query, history, limit, includes)
|
||||
case SearchStrategy.NORMAL:
|
||||
return await self._normal_read(query, limit, includes)
|
||||
res = await self._normal_read(query, history, limit, includes)
|
||||
case SearchStrategy.QUICK:
|
||||
return await self._quick_read(query, limit, includes)
|
||||
res = await self._quick_read(query, limit, includes)
|
||||
case _:
|
||||
raise RuntimeError("Unsupported search strategy")
|
||||
|
||||
if memory_l0 is not None:
|
||||
res.content_str = memory_l0.content + '\n' + res.content
|
||||
res.memories.insert(0, memory_l0)
|
||||
return res
|
||||
|
||||
def _get_search_service(self, includes=None):
|
||||
if self.ctx.storage_type == StorageType.NEO4J:
|
||||
return Neo4jSearchService(
|
||||
@@ -37,10 +48,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
self.db
|
||||
)
|
||||
|
||||
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
|
||||
search_service = self._get_search_service(includes)
|
||||
questions = await QueryPreprocessor.split(
|
||||
query,
|
||||
history,
|
||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||
)
|
||||
query_results = []
|
||||
@@ -49,12 +61,18 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
query_results.append(search_results)
|
||||
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
||||
results.memories.sort(key=lambda x: x.score, reverse=True)
|
||||
results.content_str = await RetrievalSummaryProcessor.summary(
|
||||
query,
|
||||
results.content,
|
||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||
)
|
||||
return results
|
||||
|
||||
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
|
||||
search_service = self._get_search_service(includes)
|
||||
questions = await QueryPreprocessor.split(
|
||||
query,
|
||||
history,
|
||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||
)
|
||||
query_results = []
|
||||
@@ -63,6 +81,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
|
||||
query_results.append(search_results)
|
||||
results = sum(query_results, start=MemorySearchResult(memories=[]))
|
||||
results.memories.sort(key=lambda x: x.score, reverse=True)
|
||||
results.content_str = await RetrievalSummaryProcessor.summary(
|
||||
query,
|
||||
results.content,
|
||||
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
|
||||
)
|
||||
return results
|
||||
|
||||
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
|
||||
|
||||
15
api/app/core/memory/prompt/retrieval_summary.jinja2
Normal file
15
api/app/core/memory/prompt/retrieval_summary.jinja2
Normal file
@@ -0,0 +1,15 @@
|
||||
You are a Content Condenser for a memory-augmented retrieval system.
|
||||
|
||||
Your task is to compress the retrieved content while preserving all information that is highly relevant to the user’s query.
|
||||
|
||||
Guidelines:
|
||||
|
||||
Focus only on content related to the query; ignore irrelevant parts.
|
||||
Remove redundancy, filler, or repeated information only for non-XML content.
|
||||
Preserve all factual details: names, dates, decisions, code snippets, technical details.
|
||||
If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact.
|
||||
Structure multiple relevant points as a compact bullet list or paragraph, depending on density.
|
||||
If no content is relevant, return exactly: "No relevant information found."
|
||||
Do not add any knowledge or facts not in the retrieved content.
|
||||
# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY.
|
||||
# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION.
|
||||
@@ -21,14 +21,14 @@ class QueryPreprocessor:
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
async def split(query: str, llm_client: RedBearLLM):
|
||||
async def split(query: str, history: list, llm_client: RedBearLLM):
|
||||
system_prompt = prompt_manager.render(
|
||||
name="problem_split",
|
||||
datetime=datetime.now().strftime("%Y-%m-%d"),
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": query},
|
||||
{"role": "user", "content": f"<history>{history}</history><query>{query}</query>"},
|
||||
]
|
||||
try:
|
||||
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
|
||||
|
||||
@@ -1,11 +1,29 @@
|
||||
import logging
|
||||
|
||||
from app.core.models import RedBearLLM
|
||||
from app.core.memory.prompt import prompt_manager
|
||||
from app.core.memory.utils.llm.llm_utils import StructResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RetrievalSummaryProcessor:
|
||||
@staticmethod
|
||||
def summary(content: str, llm_client: RedBearLLM):
|
||||
return
|
||||
async def summary(query, content: str, llm_client: RedBearLLM):
|
||||
system_prompt = prompt_manager.render(
|
||||
name="retrieval_summary"
|
||||
)
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"<query>{query}</query><content>{content}</content>"},
|
||||
]
|
||||
try:
|
||||
summary = await llm_client.ainvoke(messages) | StructResponse(mode='str')
|
||||
return summary
|
||||
except:
|
||||
logger.error("Failed to generate reply summary, returning original content", exc_info=True)
|
||||
return content
|
||||
|
||||
@staticmethod
|
||||
def verify(content: str, llm_client: RedBearLLM):
|
||||
async def verify(query, content: str, llm_client: RedBearLLM):
|
||||
return
|
||||
|
||||
@@ -14,6 +14,8 @@ from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder
|
||||
from app.repositories.neo4j.graph_search import search_user_metadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -177,6 +179,22 @@ class Neo4jSearchService:
|
||||
memories.sort(key=lambda x: x.score, reverse=True)
|
||||
return MemorySearchResult(memories=memories[:limit])
|
||||
|
||||
async def memory_l0(self) -> Memory:
|
||||
async with Neo4jConnector() as connector:
|
||||
end_user_id = self.ctx.end_user_id
|
||||
user_meta = await search_user_metadata(connector, end_user_id)
|
||||
metadata = MetadataBuilder(user_meta)
|
||||
memory = Memory(
|
||||
score=1,
|
||||
source=Neo4jNodeType.EXTRACTEDENTITY,
|
||||
query='',
|
||||
id=end_user_id,
|
||||
content=metadata.content,
|
||||
data=metadata.data,
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
class RAGSearchService:
|
||||
def __init__(self, ctx: MemoryContext, db: Session):
|
||||
|
||||
@@ -42,7 +42,15 @@ class ChunkBuilder(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self.record.get("content")
|
||||
parts = ["<chunk>"]
|
||||
fields = [
|
||||
("content", self.record.get("content", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</chunk>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class StatementBuiler(BaseBuilder):
|
||||
@@ -57,7 +65,15 @@ class StatementBuiler(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self.record.get("statement")
|
||||
parts = ["<statement>"]
|
||||
fields = [
|
||||
("statement", self.record.get("statement", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</statement>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class EntityBuilder(BaseBuilder):
|
||||
@@ -73,10 +89,16 @@ class EntityBuilder(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return (f"<entity>"
|
||||
f"<name>{self.record.get("name")}<name>"
|
||||
f"<description>{self.record.get("description")}</description>"
|
||||
f"</entity>")
|
||||
parts = ["<entity>"]
|
||||
fields = [
|
||||
("name", self.record.get("name", "")),
|
||||
("description", self.record.get("description", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</entity>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class SummaryBuilder(BaseBuilder):
|
||||
@@ -91,7 +113,15 @@ class SummaryBuilder(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self.record.get("content")
|
||||
parts = ["<summary>"]
|
||||
fields = [
|
||||
("content", self.record.get("content", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</summary>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class PerceptualBuilder(BaseBuilder):
|
||||
@@ -114,15 +144,21 @@ class PerceptualBuilder(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return ("<history-file-info>"
|
||||
f"<file-name>{self.record.get('file_name')}</file-name>"
|
||||
f"<file-path>{self.record.get('file_path')}</file-path>"
|
||||
f"<summary>{self.record.get('summary')}</summary>"
|
||||
f"<topic>{self.record.get('topic')}</topic>"
|
||||
f"<domain>{self.record.get('domain')}</domain>"
|
||||
f"<keywords>{self.record.get('keywords')}</keywords>"
|
||||
f"<file-type>{self.record.get('file_type')}</file-type>"
|
||||
"</history-file-info>")
|
||||
parts = ["<history-file-info>"]
|
||||
fields = [
|
||||
("file-name", self.record.get("file_name", "")),
|
||||
("file-path", self.record.get("file_path", "")),
|
||||
("summary", self.record.get("summary", "")),
|
||||
("topic", self.record.get("topic", "")),
|
||||
("domain", self.record.get("domain", "")),
|
||||
("keywords", self.record.get("keywords", [])),
|
||||
("file-type", self.record.get("file_type", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</history-file-info>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class CommunityBuilder(BaseBuilder):
|
||||
@@ -137,7 +173,54 @@ class CommunityBuilder(BaseBuilder):
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
return self.record.get("content")
|
||||
parts = ["<community>"]
|
||||
fields = [
|
||||
("content", self.record.get("content", "")),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</community>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
class MetadataBuilder(BaseBuilder):
|
||||
@property
|
||||
def data(self) -> dict:
|
||||
return {
|
||||
"id": self.record.get("id", ""),
|
||||
"aliases_name": self.record.get("aliases", []) or [],
|
||||
"description": self.record.get("description", ""),
|
||||
"anchors": self.record.get("anchors", []) or [],
|
||||
"beliefs_or_stances": self.record.get("beliefs_or_stances", []) or [],
|
||||
"core_facts": self.record.get("core_facts", []) or [],
|
||||
"events": self.record.get("events", []) or [],
|
||||
"goals": self.record.get("goals", []) or [],
|
||||
"interests": self.record.get("interests", []) or [],
|
||||
"relations": self.record.get("relations", []) or [],
|
||||
"traits": self.record.get("traits", []) or [],
|
||||
}
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
parts = ["<user-info>"]
|
||||
fields = [
|
||||
("description", self.record.get("description", "")),
|
||||
("aliases", self.record.get("aliases", [])),
|
||||
("anchors", self.record.get("anchors", [])),
|
||||
("beliefs_or_stances", self.record.get("beliefs_or_stances", [])),
|
||||
("core_facts", self.record.get("core_facts", [])),
|
||||
("events", self.record.get("events", [])),
|
||||
("goals", self.record.get("goals", [])),
|
||||
("interests", self.record.get("interests", [])),
|
||||
("relations", self.record.get("relations", [])),
|
||||
("traits", self.record.get("traits", [])),
|
||||
]
|
||||
for tag, value in fields:
|
||||
if value:
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</user-info>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def data_builder_factory(node_type, data: dict) -> T:
|
||||
|
||||
@@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict:
|
||||
|
||||
|
||||
class StructResponse:
|
||||
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
|
||||
def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None):
|
||||
self.mode = mode
|
||||
if mode == "pydantic" and model is None:
|
||||
raise ValueError("Pydantic model is required")
|
||||
@@ -31,6 +31,8 @@ class StructResponse:
|
||||
for block in other.content_blocks:
|
||||
if block.get("type") == "text":
|
||||
text += block.get("text", "")
|
||||
if self.mode == "str":
|
||||
return text
|
||||
fixed_json = json_repair.repair_json(text, return_objects=True)
|
||||
if self.mode == "json":
|
||||
return fixed_json
|
||||
|
||||
Reference in New Issue
Block a user