From 3f9740412a7796c6046bf56c67a4bb9734eb24a9 Mon Sep 17 00:00:00 2001
From: Eternity <1533512157@qq.com>
Date: Wed, 6 May 2026 16:24:53 +0800
Subject: [PATCH] feat(memory): add session-based chat history and user
metadata retrieval
- Add ChatSessionCache to manage chat history per session
- Add SEARCH_USER_METADATA cypher query for retrieving user entity metadata
- Add "str" mode support to StructResponse for raw text extraction
- Add content_str field to MemorySearchResult for pre-formatted content
- Fix sandbox URL by removing hardcoded port
- Add description field to entity search results
- Remove history from UserInput schema, use session_id instead
---
api/app/celery_task_scheduler.py | 87 +++++++------
.../controllers/memory_agent_controller.py | 72 +++++------
api/app/core/memory/memory_service.py | 5 +-
api/app/core/memory/models/service_models.py | 4 +-
api/app/core/memory/pipelines/memory_read.py | 35 +++++-
.../memory/prompt/retrieval_summary.jinja2 | 15 +++
.../generate_engine/query_preprocessor.py | 4 +-
.../generate_engine/retrieval_summary.py | 24 +++-
.../search_engine/content_search.py | 18 +++
.../search_engine/result_builder.py | 117 +++++++++++++++---
api/app/core/memory/utils/llm/llm_utils.py | 4 +-
api/app/core/workflow/nodes/code/node.py | 2 +-
api/app/core/workflow/nodes/memory/node.py | 1 +
api/app/repositories/neo4j/cypher_queries.py | 16 +++
api/app/repositories/neo4j/graph_search.py | 17 ++-
api/app/schemas/memory_agent_schema.py | 5 +-
api/app/services/draft_run_service.py | 1 +
api/app/utils/__init__.py | 0
api/app/utils/tmp_session.py | 77 ++++++++++++
19 files changed, 387 insertions(+), 117 deletions(-)
create mode 100644 api/app/core/memory/prompt/retrieval_summary.jinja2
create mode 100644 api/app/utils/__init__.py
create mode 100644 api/app/utils/tmp_session.py
diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py
index e7f946b6..9b8756cd 100644
--- a/api/app/celery_task_scheduler.py
+++ b/api/app/celery_task_scheduler.py
@@ -158,12 +158,19 @@ class RedisTaskScheduler:
return {"status": status, "task_id": task_id, "result": result_content}
def _cleanup_finished(self):
- pending = self.redis.hgetall(PENDING_HASH)
- if not pending:
+ cursor = 0
+ all_pending = {}
+ while True:
+ cursor, batch = self.redis.hscan(PENDING_HASH, cursor=cursor, count=100)
+ all_pending.update(batch)
+ if cursor == 0:
+ break
+
+ if not all_pending:
return
now = time.time()
- task_ids = list(pending.keys())
+ task_ids = list(all_pending.keys())
pipe = self.redis.pipeline()
for task_id in task_ids:
@@ -176,7 +183,7 @@ class RedisTaskScheduler:
for task_id, raw_result in zip(task_ids, results):
try:
- meta = json.loads(pending[task_id])
+ meta = json.loads(all_pending[task_id])
lock_key = meta["lock_key"]
dispatched_at = meta.get("dispatched_at", 0)
age = now - dispatched_at
@@ -276,6 +283,22 @@ class RedisTaskScheduler:
return True
return stable_hash(user_id) % self._shard_count == self._shard_index
+ def _commit_post_dispatch(self, lock_key, task, msg_id, dispatch_lock):
+ pipe = self.redis.pipeline()
+ pipe.set(lock_key, task.id, ex=3600)
+ pipe.hset(PENDING_HASH, task.id, json.dumps({
+ "lock_key": lock_key,
+ "dispatched_at": time.time(),
+ "msg_id": msg_id,
+ }))
+ pipe.delete(dispatch_lock)
+ pipe.set(
+ f"task_tracker:{msg_id}",
+ json.dumps({"status": "DISPATCHED", "task_id": task.id}),
+ ex=86400,
+ )
+ pipe.execute()
+
def _dispatch(self, msg_id, msg_data) -> bool:
user_id = msg_data["user_id"]
task_name = msg_data["task_name"]
@@ -308,28 +331,17 @@ class RedisTaskScheduler:
task_name, user_id, msg_id, e, exc_info=True,
)
return False
-
- try:
- pipe = self.redis.pipeline()
- pipe.set(lock_key, task.id, ex=3600)
- pipe.hset(PENDING_HASH, task.id, json.dumps({
- "lock_key": lock_key,
- "dispatched_at": time.time(),
- "msg_id": msg_id,
- }))
- pipe.delete(dispatch_lock)
- pipe.set(
- f"task_tracker:{msg_id}",
- json.dumps({"status": "DISPATCHED", "task_id": task.id}),
- ex=86400,
- )
- pipe.execute()
- except Exception as e:
- logger.error(
- "Post-dispatch state update failed for %s: %s",
- task.id, e, exc_info=True,
- )
- self.errors += 1
+ for attempt in range(2):
+ try:
+ self._commit_post_dispatch(lock_key, task, msg_id, dispatch_lock)
+ break
+ except Exception as e:
+ logger.error(
+ "Post-dispatch state update failed for %s: %s",
+ task.id, e, exc_info=True,
+ )
+ time.sleep(0.1)
+ self.errors += 1
self.dispatched += 1
logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id)
@@ -367,22 +379,21 @@ class RedisTaskScheduler:
return
for uid, msg in candidates:
+ queue_key = f"{USER_QUEUE_PREFIX}{uid}"
if self._dispatch(msg["msg_id"], msg):
- self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}")
+ self.redis.lpop(queue_key)
+ if self.redis.llen(queue_key) > 0:
+ self.redis.sadd(READY_SET, uid)
def schedule_loop(self):
self._heartbeat()
self._cleanup_finished()
- pipe = self.redis.pipeline()
- pipe.smembers(READY_SET)
- pipe.delete(READY_SET)
- results = pipe.execute()
- ready_users = results[0] or set()
-
+ ready_users = self.redis.smembers(READY_SET) or set()
my_users = [uid for uid in ready_users if self._is_mine(uid)]
-
- if not my_users:
+ if my_users:
+ self.redis.srem(READY_SET, *my_users)
+ else:
time.sleep(0.5)
return
@@ -445,7 +456,7 @@ class RedisTaskScheduler:
"Scheduler started: instance=%s", self.instance_id,
)
- while True:
+ while self.running:
try:
self.schedule_loop()
@@ -480,9 +491,7 @@ class RedisTaskScheduler:
logger.error("Shutdown cleanup error: %s", e)
-scheduler: RedisTaskScheduler | None = None
-if scheduler is None:
- scheduler = RedisTaskScheduler()
+scheduler = RedisTaskScheduler()
if __name__ == "__main__":
import signal
diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py
index cba17f42..c9d41494 100644
--- a/api/app/controllers/memory_agent_controller.py
+++ b/api/app/controllers/memory_agent_controller.py
@@ -27,6 +27,7 @@ from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.memory_agent_service import get_end_user_connected_config as get_config
from app.services.model_service import ModelConfigService
+from app.utils.tmp_session import ChatSessionCache
load_dotenv()
api_logger = get_api_logger()
@@ -300,60 +301,39 @@ async def read_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
+ session_id = user_input.session_id.hex
+
api_logger.info(
- f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
+ f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}, session_id={session_id}")
try:
- # result = await memory_agent_service.read_memory(
- # user_input.end_user_id,
- # user_input.message,
- # user_input.history,
- # user_input.search_switch,
- # config_id,
- # db,
- # storage_type,
- # user_rag_memory_id
- # )
- # if str(user_input.search_switch) == "2":
- # retrieve_info = result['answer']
- # history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
- # user_input.end_user_id)
- # query = user_input.message
- #
- # # 调用 memory_agent_service 的方法生成最终答案
- # result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
- # end_user_id=user_input.end_user_id,
- # retrieve_info=retrieve_info,
- # history=history,
- # query=query,
- # config_id=config_id,
- # db=db
- # )
- # if "信息不足,无法回答" in result['answer']:
- # result['answer'] = retrieve_info
memory_config = get_config(user_input.end_user_id, db)
service = MemoryService(
db,
memory_config["memory_config_id"],
end_user_id=user_input.end_user_id
)
+ session_cache = ChatSessionCache(session_id)
search_result = await service.read(
user_input.message,
- SearchStrategy(user_input.search_switch)
+ SearchStrategy(user_input.search_switch),
+ history=await session_cache.get_history(),
)
intermediate_outputs = []
sub_queries = set()
for memory in search_result.memories:
sub_queries.add(str(memory.query))
+ idx = 0
if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]:
intermediate_outputs.append({
"type": "problem_split",
"title": "问题拆分",
"data": [
{
- "id": f"Q{idx+1}",
+ "id": f"Q{(idx := idx + 1)}",
"question": question
}
- for idx, question in enumerate(sub_queries)
+ for question in sub_queries
+ if question
]
})
perceptual_data = [
@@ -375,16 +355,24 @@ async def read_server(
"raw_result": search_result.memories,
"total": len(search_result.memories),
})
+ answer = await memory_agent_service.generate_summary_from_retrieve(
+ end_user_id=user_input.end_user_id,
+ retrieve_info=search_result.content,
+ history=[],
+ query=user_input.message,
+ config_id=config_id,
+ db=db
+ )
+ await session_cache.append_many(
+ [
+ {"role": "user", "content": user_input.message},
+ {"role": "assistant", "content": answer}
+ ]
+ )
result = {
- 'answer': await memory_agent_service.generate_summary_from_retrieve(
- end_user_id=user_input.end_user_id,
- retrieve_info=search_result.content,
- history=[],
- query=user_input.message,
- config_id=config_id,
- db=db
- ),
- "intermediate_outputs": intermediate_outputs
+ 'answer': answer,
+ "intermediate_outputs": intermediate_outputs,
+ "session_id": session_id,
}
return success(data=result, msg="回复对话消息成功")
@@ -480,9 +468,11 @@ async def read_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
+ session_id = user_input.session_id.hex
+ session_cache = ChatSessionCache(session_id)
task = celery_app.send_task(
"app.core.memory.agent.read_message",
- args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
+ args=[user_input.end_user_id, user_input.message, await session_cache.get_history(), user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
diff --git a/api/app/core/memory/memory_service.py b/api/app/core/memory/memory_service.py
index f695384b..80c7350d 100644
--- a/api/app/core/memory/memory_service.py
+++ b/api/app/core/memory/memory_service.py
@@ -43,10 +43,13 @@ class MemoryService:
self,
query: str,
search_switch: SearchStrategy,
+ history: list | None = None,
limit: int = 10,
) -> MemorySearchResult:
+ if history is None:
+ history = []
with get_db_context() as db:
- return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit)
+ return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit)
async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict:
raise NotImplementedError
diff --git a/api/app/core/memory/models/service_models.py b/api/app/core/memory/models/service_models.py
index 6ec0693f..02ecc2fb 100644
--- a/api/app/core/memory/models/service_models.py
+++ b/api/app/core/memory/models/service_models.py
@@ -32,10 +32,12 @@ class Memory(BaseModel):
class MemorySearchResult(BaseModel):
memories: list[Memory]
+ content_str: str = Field(default="")
- @computed_field
@property
def content(self) -> str:
+ if self.content_str:
+ return self.content_str
return "\n".join([memory.content for memory in self.memories])
@computed_field
diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py
index 0bd57b08..37f13ee1 100644
--- a/api/app/core/memory/pipelines/memory_read.py
+++ b/api/app/core/memory/pipelines/memory_read.py
@@ -1,8 +1,9 @@
from app.core.memory.enums import SearchStrategy, StorageType
from app.core.memory.models.service_models import MemorySearchResult
from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline
-from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor
+from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor
+from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService
class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
@@ -10,20 +11,30 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self,
query: str,
search_switch: SearchStrategy,
+ history: list,
limit: int = 10,
includes=None
) -> MemorySearchResult:
+ memory_l0 = None
+ if self.ctx.storage_type == StorageType.NEO4J:
+ memory_l0 = await self._get_search_service(includes).memory_l0()
+
query = QueryPreprocessor.process(query)
match search_switch:
case SearchStrategy.DEEP:
- return await self._deep_read(query, limit, includes)
+ res = await self._deep_read(query, history, limit, includes)
case SearchStrategy.NORMAL:
- return await self._normal_read(query, limit, includes)
+ res = await self._normal_read(query, history, limit, includes)
case SearchStrategy.QUICK:
- return await self._quick_read(query, limit, includes)
+ res = await self._quick_read(query, limit, includes)
case _:
raise RuntimeError("Unsupported search strategy")
+ if memory_l0 is not None:
+ res.content_str = memory_l0.content + '\n' + res.content
+ res.memories.insert(0, memory_l0)
+ return res
+
def _get_search_service(self, includes=None):
if self.ctx.storage_type == StorageType.NEO4J:
return Neo4jSearchService(
@@ -37,10 +48,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
self.db
)
- async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
+ async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
+ history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -49,12 +61,18 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
+ results.content_str = await RetrievalSummaryProcessor.summary(
+ query,
+ results.content,
+ self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
+ )
return results
- async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
+ async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult:
search_service = self._get_search_service(includes)
questions = await QueryPreprocessor.split(
query,
+ history,
self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
)
query_results = []
@@ -63,6 +81,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline):
query_results.append(search_results)
results = sum(query_results, start=MemorySearchResult(memories=[]))
results.memories.sort(key=lambda x: x.score, reverse=True)
+ results.content_str = await RetrievalSummaryProcessor.summary(
+ query,
+ results.content,
+ self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id)
+ )
return results
async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult:
diff --git a/api/app/core/memory/prompt/retrieval_summary.jinja2 b/api/app/core/memory/prompt/retrieval_summary.jinja2
new file mode 100644
index 00000000..bb0bddad
--- /dev/null
+++ b/api/app/core/memory/prompt/retrieval_summary.jinja2
@@ -0,0 +1,15 @@
+You are a Content Condenser for a memory-augmented retrieval system.
+
+Your task is to compress the retrieved content while preserving all information that is highly relevant to the user’s query.
+
+Guidelines:
+
+Focus only on content related to the query; ignore irrelevant parts.
+Remove redundancy, filler, or repeated information only for non-XML content.
+Preserve all factual details: names, dates, decisions, code snippets, technical details.
+If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact.
+Structure multiple relevant points as a compact bullet list or paragraph, depending on density.
+If no content is relevant, return exactly: "No relevant information found."
+Do not add any knowledge or facts not in the retrieved content.
+# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY.
+# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION.
\ No newline at end of file
diff --git a/api/app/core/memory/read_services/generate_engine/query_preprocessor.py b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py
index 1e234a10..49d2856b 100644
--- a/api/app/core/memory/read_services/generate_engine/query_preprocessor.py
+++ b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py
@@ -21,14 +21,14 @@ class QueryPreprocessor:
return text
@staticmethod
- async def split(query: str, llm_client: RedBearLLM):
+ async def split(query: str, history: list, llm_client: RedBearLLM):
system_prompt = prompt_manager.render(
name="problem_split",
datetime=datetime.now().strftime("%Y-%m-%d"),
)
messages = [
{"role": "system", "content": system_prompt},
- {"role": "user", "content": query},
+ {"role": "user", "content": f"{history}{query}"},
]
try:
sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json')
diff --git a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py
index c46e93f0..c189e45a 100644
--- a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py
+++ b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py
@@ -1,11 +1,29 @@
+import logging
+
from app.core.models import RedBearLLM
+from app.core.memory.prompt import prompt_manager
+from app.core.memory.utils.llm.llm_utils import StructResponse
+
+logger = logging.getLogger(__name__)
class RetrievalSummaryProcessor:
@staticmethod
- def summary(content: str, llm_client: RedBearLLM):
- return
+ async def summary(query, content: str, llm_client: RedBearLLM):
+ system_prompt = prompt_manager.render(
+ name="retrieval_summary"
+ )
+ messages = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": f"{query}{content}"},
+ ]
+ try:
+ summary = await llm_client.ainvoke(messages) | StructResponse(mode='str')
+ return summary
+ except:
+ logger.error("Failed to generate reply summary, returning original content", exc_info=True)
+ return content
@staticmethod
- def verify(content: str, llm_client: RedBearLLM):
+ async def verify(query, content: str, llm_client: RedBearLLM):
return
diff --git a/api/app/core/memory/read_services/search_engine/content_search.py b/api/app/core/memory/read_services/search_engine/content_search.py
index 4ba4dce7..b1ff1738 100644
--- a/api/app/core/memory/read_services/search_engine/content_search.py
+++ b/api/app/core/memory/read_services/search_engine/content_search.py
@@ -14,6 +14,8 @@ from app.core.rag.nlp.search import knowledge_retrieval
from app.repositories import knowledge_repository
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
+from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder
+from app.repositories.neo4j.graph_search import search_user_metadata
logger = logging.getLogger(__name__)
@@ -177,6 +179,22 @@ class Neo4jSearchService:
memories.sort(key=lambda x: x.score, reverse=True)
return MemorySearchResult(memories=memories[:limit])
+ async def memory_l0(self) -> Memory:
+ async with Neo4jConnector() as connector:
+ end_user_id = self.ctx.end_user_id
+ user_meta = await search_user_metadata(connector, end_user_id)
+ metadata = MetadataBuilder(user_meta)
+ memory = Memory(
+ score=1,
+ source=Neo4jNodeType.EXTRACTEDENTITY,
+ query='',
+ id=end_user_id,
+ content=metadata.content,
+ data=metadata.data,
+ )
+
+ return memory
+
class RAGSearchService:
def __init__(self, ctx: MemoryContext, db: Session):
diff --git a/api/app/core/memory/read_services/search_engine/result_builder.py b/api/app/core/memory/read_services/search_engine/result_builder.py
index 1ef04557..8d158ea8 100644
--- a/api/app/core/memory/read_services/search_engine/result_builder.py
+++ b/api/app/core/memory/read_services/search_engine/result_builder.py
@@ -42,7 +42,15 @@ class ChunkBuilder(BaseBuilder):
@property
def content(self) -> str:
- return self.record.get("content")
+ parts = [""]
+ fields = [
+ ("content", self.record.get("content", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
class StatementBuiler(BaseBuilder):
@@ -57,7 +65,15 @@ class StatementBuiler(BaseBuilder):
@property
def content(self) -> str:
- return self.record.get("statement")
+ parts = [""]
+ fields = [
+ ("statement", self.record.get("statement", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
class EntityBuilder(BaseBuilder):
@@ -73,10 +89,16 @@ class EntityBuilder(BaseBuilder):
@property
def content(self) -> str:
- return (f""
- f"{self.record.get("name")}"
- f"{self.record.get("description")}"
- f"")
+ parts = [""]
+ fields = [
+ ("name", self.record.get("name", "")),
+ ("description", self.record.get("description", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
class SummaryBuilder(BaseBuilder):
@@ -91,7 +113,15 @@ class SummaryBuilder(BaseBuilder):
@property
def content(self) -> str:
- return self.record.get("content")
+ parts = [""]
+ fields = [
+ ("content", self.record.get("content", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
class PerceptualBuilder(BaseBuilder):
@@ -114,15 +144,21 @@ class PerceptualBuilder(BaseBuilder):
@property
def content(self) -> str:
- return (""
- f"{self.record.get('file_name')}"
- f"{self.record.get('file_path')}"
- f"{self.record.get('summary')}"
- f"{self.record.get('topic')}"
- f"{self.record.get('domain')}"
- f"{self.record.get('keywords')}"
- f"{self.record.get('file_type')}"
- "")
+ parts = [""]
+ fields = [
+ ("file-name", self.record.get("file_name", "")),
+ ("file-path", self.record.get("file_path", "")),
+ ("summary", self.record.get("summary", "")),
+ ("topic", self.record.get("topic", "")),
+ ("domain", self.record.get("domain", "")),
+ ("keywords", self.record.get("keywords", [])),
+ ("file-type", self.record.get("file_type", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
class CommunityBuilder(BaseBuilder):
@@ -137,7 +173,54 @@ class CommunityBuilder(BaseBuilder):
@property
def content(self) -> str:
- return self.record.get("content")
+ parts = [""]
+ fields = [
+ ("content", self.record.get("content", "")),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
+
+
+class MetadataBuilder(BaseBuilder):
+ @property
+ def data(self) -> dict:
+ return {
+ "id": self.record.get("id", ""),
+ "aliases_name": self.record.get("aliases", []) or [],
+ "description": self.record.get("description", ""),
+ "anchors": self.record.get("anchors", []) or [],
+ "beliefs_or_stances": self.record.get("beliefs_or_stances", []) or [],
+ "core_facts": self.record.get("core_facts", []) or [],
+ "events": self.record.get("events", []) or [],
+ "goals": self.record.get("goals", []) or [],
+ "interests": self.record.get("interests", []) or [],
+ "relations": self.record.get("relations", []) or [],
+ "traits": self.record.get("traits", []) or [],
+ }
+
+ @property
+ def content(self) -> str:
+ parts = [""]
+ fields = [
+ ("description", self.record.get("description", "")),
+ ("aliases", self.record.get("aliases", [])),
+ ("anchors", self.record.get("anchors", [])),
+ ("beliefs_or_stances", self.record.get("beliefs_or_stances", [])),
+ ("core_facts", self.record.get("core_facts", [])),
+ ("events", self.record.get("events", [])),
+ ("goals", self.record.get("goals", [])),
+ ("interests", self.record.get("interests", [])),
+ ("relations", self.record.get("relations", [])),
+ ("traits", self.record.get("traits", [])),
+ ]
+ for tag, value in fields:
+ if value:
+ parts.append(f"<{tag}>{value}{tag}>")
+ parts.append("")
+ return "".join(parts)
def data_builder_factory(node_type, data: dict) -> T:
diff --git a/api/app/core/memory/utils/llm/llm_utils.py b/api/app/core/memory/utils/llm/llm_utils.py
index c4eee82f..1030f2cb 100644
--- a/api/app/core/memory/utils/llm/llm_utils.py
+++ b/api/app/core/memory/utils/llm/llm_utils.py
@@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict:
class StructResponse:
- def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None):
+ def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None):
self.mode = mode
if mode == "pydantic" and model is None:
raise ValueError("Pydantic model is required")
@@ -31,6 +31,8 @@ class StructResponse:
for block in other.content_blocks:
if block.get("type") == "text":
text += block.get("text", "")
+ if self.mode == "str":
+ return text
fixed_json = json_repair.repair_json(text, return_objects=True)
if self.mode == "json":
return fixed_json
diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py
index d715be7d..43f60211 100644
--- a/api/app/core/workflow/nodes/code/node.py
+++ b/api/app/core/workflow/nodes/code/node.py
@@ -132,7 +132,7 @@ class CodeNode(BaseNode):
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
- f"{settings.SANDBOX_URL}:8194/v1/sandbox/run",
+ f"{settings.SANDBOX_URL}/v1/sandbox/run",
headers={
"x-api-key": 'redbear-sandbox'
},
diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py
index 6d9fcdad..dd8b0915 100644
--- a/api/app/core/workflow/nodes/memory/node.py
+++ b/api/app/core/workflow/nodes/memory/node.py
@@ -40,6 +40,7 @@ class MemoryReadNode(BaseNode):
end_user_id=end_user_id,
user_rag_memory_id=state["user_rag_memory_id"],
)
+ # TODO: Historical Messages -> Used to refer to coreference resolution
search_result = await memory_service.read(
self._render_template(self.typed_config.message, variable_pool),
search_switch=SearchStrategy(self.typed_config.search_switch)
diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py
index a8c36e34..85e77988 100644
--- a/api/app/repositories/neo4j/cypher_queries.py
+++ b/api/app/repositories/neo4j/cypher_queries.py
@@ -1296,6 +1296,7 @@ RETURN e.id AS id,
e.name AS name,
e.end_user_id AS end_user_id,
e.entity_type AS entity_type,
+ e.description AS description,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
@@ -1479,6 +1480,21 @@ ORDER BY score DESC
LIMIT $limit
"""
+SEARCH_USER_METADATA = """
+MATCH (n:ExtractedEntity)
+WHERE (n.end_user_id = $end_user_id AND n.entity_type ='用户')
+RETURN n.description AS description,
+ n.aliases AS aliases,
+ n.anchors AS anchors,
+ n.beliefs_or_stances AS beliefs_or_stances,
+ n.core_facts AS core_facts,
+ n.events AS events,
+ n.goals AS goals,
+ n.interests AS interests,
+ n.relations AS relations,
+ n.traits AS traits
+"""
+
FULLTEXT_QUERY_CYPHER_MAPPING = {
Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD,
Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS,
diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py
index 70913267..ed97256a 100644
--- a/api/app/repositories/neo4j/graph_search.py
+++ b/api/app/repositories/neo4j/graph_search.py
@@ -27,9 +27,9 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_PERCEPTUAL_BY_USER_ID,
FULLTEXT_QUERY_CYPHER_MAPPING,
USER_ID_QUERY_CYPHER_MAPPING,
- NODE_ID_QUERY_CYPHER_MAPPING
+ NODE_ID_QUERY_CYPHER_MAPPING,
+ SEARCH_USER_METADATA
)
-
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
@@ -513,7 +513,7 @@ async def search_graph_by_embedding(
task_keys = []
for node_type in include:
- tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit*2))
+ tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit * 2))
task_keys.append(node_type.value)
task_results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -557,6 +557,17 @@ async def search_graph_by_embedding(
return results
+async def search_user_metadata(
+ connector: Neo4jConnector,
+ end_user_id: str
+) -> dict:
+ user_info = await connector.execute_query(
+ SEARCH_USER_METADATA,
+ end_user_id=end_user_id
+ )
+ return user_info[0] if user_info else {}
+
+
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector,
end_user_id: str,
diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py
index 97aa5bb5..58ac5a94 100644
--- a/api/app/schemas/memory_agent_schema.py
+++ b/api/app/schemas/memory_agent_schema.py
@@ -1,14 +1,15 @@
+import uuid
from abc import ABC
from typing import Optional
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class UserInput(BaseModel):
message: str
- history: list[dict]
search_switch: str
end_user_id: str
+ session_id: uuid.UUID = Field(default_factory=uuid.uuid4)
config_id: Optional[str] = None
diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py
index 16d856ca..8ebd21a6 100644
--- a/api/app/services/draft_run_service.py
+++ b/api/app/services/draft_run_service.py
@@ -108,6 +108,7 @@ def create_long_term_memory_tool(
try:
with get_db_context() as db:
memory_service = MemoryService(db, config_id, end_user_id)
+ # TODO: Historical Messages -> Used to refer to coreference resolution
search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK))
# memory_content = asyncio.run(
diff --git a/api/app/utils/__init__.py b/api/app/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/api/app/utils/tmp_session.py b/api/app/utils/tmp_session.py
new file mode 100644
index 00000000..14959ec0
--- /dev/null
+++ b/api/app/utils/tmp_session.py
@@ -0,0 +1,77 @@
+import json
+import logging
+
+import redis.asyncio as redis
+
+from app.aioRedis import get_redis_connection
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_TTL = 3600
+
+
+class ChatSessionCache:
+ """Cache user-AI conversation history in Redis with TTL-based expiry.
+
+ Usage::
+
+ cache = ChatSessionCache(session_id="user_123")
+ await cache.append("user", "Hello")
+ await cache.append("assistant", "Hi there!")
+ history = await cache.get_history()
+ """
+
+ def __init__(self, session_id: str, ttl: int = DEFAULT_TTL):
+ self.session_id = session_id
+ self.ttl = ttl
+ self._key = f"chat:session:{session_id}"
+
+ @staticmethod
+ async def _client() -> redis.StrictRedis:
+ return await get_redis_connection()
+
+ async def append(self, role: str, content: str) -> None:
+ r = await self._client()
+ entry = json.dumps({"role": role, "content": content}, ensure_ascii=False)
+ await r.rpush(self._key, entry)
+ await r.expire(self._key, self.ttl)
+
+ async def append_many(self, messages: list[dict[str, str]]) -> None:
+ """Batch append messages. Each dict should have ``role`` and ``content`` keys."""
+ if not messages:
+ return
+ r = await self._client()
+ entries = [
+ json.dumps(m, ensure_ascii=False)
+ for m in messages
+ if "role" in m and "content" in m
+ ]
+ if entries:
+ await r.rpush(self._key, *entries)
+ await r.expire(self._key, self.ttl)
+
+ async def get_history(self) -> list[dict[str, str]]:
+ r = await self._client()
+ raw = await r.lrange(self._key, 0, -1)
+ return [json.loads(item) for item in raw]
+
+ async def get_history_text(self, user_label: str = "User", ai_label: str = "Assistant") -> str:
+ """Return conversation as a formatted text block."""
+ history = await self.get_history()
+ lines = []
+ for msg in history:
+ role = msg.get("role", "")
+ content = msg.get("content", "")
+ label = user_label if role == "user" else ai_label if role == "assistant" else role
+ lines.append(f"{label}: {content}")
+ return "\n".join(lines)
+
+ async def reset(self) -> None:
+ """Delete the session from Redis."""
+ r = await self._client()
+ await r.delete(self._key)
+
+ async def touch(self) -> None:
+ """Refresh the TTL without modifying data."""
+ r = await self._client()
+ await r.expire(self._key, self.ttl)