feat(memory): add session-based chat history and user metadata retrieval

- Add ChatSessionCache to manage chat history per session
- Add SEARCH_USER_METADATA cypher query for retrieving user entity metadata
- Add "str" mode support to StructResponse for raw text extraction
- Add content_str field to MemorySearchResult for pre-formatted content
- Fix sandbox URL by removing hardcoded port
- Add description field to entity search results
- Remove history from UserInput schema, use session_id instead
This commit is contained in:
Eternity
2026-05-06 16:24:53 +08:00
parent 6b68ee9fc8
commit 3f9740412a
19 changed files with 387 additions and 117 deletions

View File

@@ -43,10 +43,13 @@ class MemoryService:
self,
query: str,
search_switch: SearchStrategy,
history: list | None = None,
limit: int = 10,
) -> MemorySearchResult:
if history is None:
history = []
with get_db_context() as db:
return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
    """Placeholder for the forget operation; it is not implemented yet.

    Args:
        max_batch: Currently unused. NOTE(review): presumably the maximum
            number of items handled per call -- confirm once implemented.
        min_days: Currently unused. NOTE(review): presumably a minimum age
            in days -- confirm once implemented.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError

View File

@@ -32,10 +32,12 @@ class Memory(BaseModel):
class MemorySearchResult(BaseModel):
memories: list[Memory]
content_str: str = Field(default="")
@computed_field
@property
def content(self) -> str:
    # A pre-formatted ``content_str`` takes precedence when it is non-empty;
    # otherwise the text is assembled from the individual memories, one per
    # line. (Plain comments are used here instead of a docstring so the
    # computed field's schema description is left untouched.)
    if not self.content_str:
        return "\n".join(item.content for item in self.memories)
    return self.content_str
@computed_field

View File

@@ -1,8 +1,9 @@
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor
from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
@@ -10,20 +11,30 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self,
query: str,
search_switch: SearchStrategy,
history: list,
limit: int = 10,
includes=None
) -> MemorySearchResult:
memory_l0 = None
if self.ctx.storage_type == StorageType.NEO4J:
memory_l0 = await self._get_search_service(includes).memory_l0()
query = QueryPreprocessor.process(query)
match search_switch:
case SearchStrategy.DEEP:
return await self._deep_read(query, limit, includes)
res = await self._deep_read(query, history, limit, includes)
case SearchStrategy.NORMAL:
return await self._normal_read(query, limit, includes)
res = await self._normal_read(query, history, limit, includes)
case SearchStrategy.QUICK:
return await self._quick_read(query, limit, includes)
res = await self._quick_read(query, limit, includes)
case _:
raise RuntimeError("Unsupported search strategy")
if memory_l0 is not None:
res.content_str = memory_l0.content + '\n' + res.content
res.memories.insert(0, memory_l0)
return res
def _get_search_service(self, includes=None):
if self.ctx.storage_type == StorageType.NEO4J:
return Neo4jSearchService(
@@ -37,10 +48,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self.db
)
async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -49,12 +61,18 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
results.content_str = await RetrievalSummaryProcessor.summary(
query,
results.content,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
return results
async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -63,6 +81,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
results.content_str = await RetrievalSummaryProcessor.summary(
query,
results.content,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
return results
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:

View File

@@ -0,0 +1,15 @@
You are a Content Condenser for a memory-augmented retrieval system.
Your task is to compress the retrieved content while preserving all information that is highly relevant to the user's query.
Guidelines:
Focus only on content related to the query; ignore irrelevant parts.
Remove redundancy, filler, and repeated information, but only from non-XML content.
Preserve all factual details: names, dates, decisions, code snippets, technical details.
If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact.
Structure multiple relevant points as a compact bullet list or paragraph, depending on density.
If no content is relevant, return exactly: "No relevant information found."
Do not add any knowledge or facts not in the retrieved content.
# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY.
# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION.

View File

@@ -21,14 +21,14 @@ class QueryPreprocessor:
return text
@staticmethod
async def split(query: str, llm_client: RedBearLLM):
async def split(query: str, history: list, llm_client: RedBearLLM):
system_prompt = prompt_manager.render(
name="problem_split",
datetime=datetime.now().strftime("%Y-%m-%d"),
)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": query},
{"role": "user", "content": f"<history>{history}</history><query>{query}</query>"},
]
try:
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')

View File

@@ -1,11 +1,29 @@
import logging
from app.core.models import RedBearLLM
from app.core.memory.prompt import prompt_manager
from app.core.memory.utils.llm.llm_utils import StructResponse
logger = logging.getLogger(__name__)
class RetrievalSummaryProcessor:
@staticmethod
def summary(content: str, llm_client: RedBearLLM):
return
async def summary(query: str, content: str, llm_client: RedBearLLM) -> str:
    """Condense retrieved content with the LLM, relative to the query.

    Renders the ``retrieval_summary`` system prompt and sends the query and
    the retrieved content to the model, wrapped in ``<query>`` and
    ``<content>`` tags. The reply is extracted as raw text.

    Args:
        query: The query the content was retrieved for.
        content: The retrieved content to condense.
        llm_client: LLM client used to generate the summary.

    Returns:
        The text returned by the model, or ``content`` unchanged when the
        model call or response extraction fails.
    """
    system_prompt = prompt_manager.render(
        name="retrieval_summary"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"<query>{query}</query><content>{content}</content>"},
    ]
    try:
        return await llm_client.ainvoke(messages) | StructResponse(mode='str')
    except Exception:
        # Best-effort fallback. ``Exception`` (not a bare ``except``) so that
        # task cancellation (asyncio.CancelledError), KeyboardInterrupt and
        # SystemExit still propagate instead of being turned into a result.
        logger.error("Failed to generate reply summary, returning original content", exc_info=True)
        return content
@staticmethod
def verify(content: str, llm_client: RedBearLLM):
async def verify(query, content: str, llm_client: RedBearLLM):
return

View File

@@ -14,6 +14,8 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder
from app.repositories.neo4j.graph_search import search_user_metadata
logger = logging.getLogger(__name__)
@@ -177,6 +179,22 @@ class Neo4jSearchService:
memories.sort(key=lambda x: x.score, reverse=True)
return MemorySearchResult(memories=memories[:limit])
async def memory_l0(self) -> Memory:
    """Build a ``Memory`` from the current end user's metadata.

    Fetches the metadata for ``self.ctx.end_user_id`` through
    ``search_user_metadata`` on a Neo4j connection and wraps it with
    ``MetadataBuilder``. The resulting memory has a fixed score of 1, an
    empty query, and the end user id as its id.
    """
    async with Neo4jConnector() as neo4j:
        user_id = self.ctx.end_user_id
        record = await search_user_metadata(neo4j, user_id)
        builder = MetadataBuilder(record)
        return Memory(
            score=1,
            source=Neo4jNodeType.EXTRACTEDENTITY,
            query='',
            id=user_id,
            content=builder.content,
            data=builder.data,
        )
class RAGSearchService:
def __init__(self, ctx: MemoryContext, db: Session):

View File

@@ -42,7 +42,15 @@ class ChunkBuilder(BaseBuilder):
@property
def content(self) -> str:
return self.record.get("content")
parts = ["<chunk>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</chunk>")
return "".join(parts)
class StatementBuiler(BaseBuilder):
@@ -57,7 +65,15 @@ class StatementBuiler(BaseBuilder):
@property
def content(self) -> str:
return self.record.get("statement")
parts = ["<statement>"]
fields = [
("statement", self.record.get("statement", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</statement>")
return "".join(parts)
class EntityBuilder(BaseBuilder):
@@ -73,10 +89,16 @@ class EntityBuilder(BaseBuilder):
@property
def content(self) -> str:
return (f"<entity>"
f"<name>{self.record.get("name")}<name>"
f"<description>{self.record.get("description")}</description>"
f"</entity>")
parts = ["<entity>"]
fields = [
("name", self.record.get("name", "")),
("description", self.record.get("description", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</entity>")
return "".join(parts)
class SummaryBuilder(BaseBuilder):
@@ -91,7 +113,15 @@ class SummaryBuilder(BaseBuilder):
@property
def content(self) -> str:
return self.record.get("content")
parts = ["<summary>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</summary>")
return "".join(parts)
class PerceptualBuilder(BaseBuilder):
@@ -114,15 +144,21 @@ class PerceptualBuilder(BaseBuilder):
@property
def content(self) -> str:
return ("<history-file-info>"
f"<file-name>{self.record.get('file_name')}</file-name>"
f"<file-path>{self.record.get('file_path')}</file-path>"
f"<summary>{self.record.get('summary')}</summary>"
f"<topic>{self.record.get('topic')}</topic>"
f"<domain>{self.record.get('domain')}</domain>"
f"<keywords>{self.record.get('keywords')}</keywords>"
f"<file-type>{self.record.get('file_type')}</file-type>"
"</history-file-info>")
parts = ["<history-file-info>"]
fields = [
("file-name", self.record.get("file_name", "")),
("file-path", self.record.get("file_path", "")),
("summary", self.record.get("summary", "")),
("topic", self.record.get("topic", "")),
("domain", self.record.get("domain", "")),
("keywords", self.record.get("keywords", [])),
("file-type", self.record.get("file_type", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</history-file-info>")
return "".join(parts)
class CommunityBuilder(BaseBuilder):
@@ -137,7 +173,54 @@ class CommunityBuilder(BaseBuilder):
@property
def content(self) -> str:
return self.record.get("content")
parts = ["<community>"]
fields = [
("content", self.record.get("content", "")),
]
for tag, value in fields:
if value:
parts.append(f"<{tag}>{value}</{tag}>")
parts.append("</community>")
return "".join(parts)
class MetadataBuilder(BaseBuilder):
    """Presents a user-metadata record as a dict and as XML-like text."""

    @property
    def data(self) -> dict:
        """Return the record's metadata fields as a plain dict.

        List-valued fields fall back to an empty list when missing or falsy.
        The record's ``aliases`` value is exposed under ``aliases_name``.
        """
        record = self.record
        list_keys = (
            "anchors",
            "beliefs_or_stances",
            "core_facts",
            "events",
            "goals",
            "interests",
            "relations",
            "traits",
        )
        payload = {
            "id": record.get("id", ""),
            "aliases_name": record.get("aliases", []) or [],
            "description": record.get("description", ""),
        }
        # Insertion order matches the key order of the dict literal this replaces.
        payload.update((key, record.get(key, []) or []) for key in list_keys)
        return payload

    @property
    def content(self) -> str:
        """Render the record as a ``<user-info>`` element.

        Each field with a truthy value becomes ``<tag>value</tag>``; fields
        that are missing or empty are omitted.
        """
        record = self.record
        list_tags = (
            "aliases",
            "anchors",
            "beliefs_or_stances",
            "core_facts",
            "events",
            "goals",
            "interests",
            "relations",
            "traits",
        )
        pairs = [("description", record.get("description", ""))]
        pairs += [(tag, record.get(tag, [])) for tag in list_tags]
        inner = "".join(
            f"<{tag}>{value}</{tag}>" for tag, value in pairs if value
        )
        return "<user-info>" + inner + "</user-info>"
def data_builder_factory(node_type, data: dict) -> T:

View File

@@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict:
class StructResponse:
def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None):
self.mode = mode
if mode == "pydantic" and model is None:
raise ValueError("Pydantic model is required")
@@ -31,6 +31,8 @@ class StructResponse:
for block in other.content_blocks:
if block.get("type") == "text":
text += block.get("text", "")
if self.mode == "str":
return text
fixed_json = json_repair.repair_json(text, return_objects=True)
if self.mode == "json":
return fixed_json

View File

@@ -132,7 +132,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
f"{settings.SANDBOX_URL}/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},

View File

@@ -40,6 +40,7 @@ class MemoryReadNode(BaseNode):
end_user_id=end_user_id,
user_rag_memory_id=state["user_rag_memory_id"],
)
# TODO: pass historical messages as `history` to read() so they can be used for coreference resolution
search_result = await memory_service.read(
self._render_template(self.typed_config.message, variable_pool),
search_switch=SearchStrategy(self.typed_config.search_switch)