From 3f9740412a7796c6046bf56c67a4bb9734eb24a9 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Wed, 6 May 2026 16:24:53 +0800 Subject: [PATCH] feat(memory): add session-based chat history and user metadata retrieval - Add ChatSessionCache to manage chat history per session - Add SEARCH_USER_METADATA cypher query for retrieving user entity metadata - Add "str" mode support to StructResponse for raw text extraction - Add content_str field to MemorySearchResult for pre-formatted content - Fix sandbox URL by removing hardcoded port - Add description field to entity search results - Remove history from UserInput schema, use session_id instead --- api/app/celery_task_scheduler.py | 87 +++++++------ .../controllers/memory_agent_controller.py | 72 +++++------ api/app/core/memory/memory_service.py | 5 +- api/app/core/memory/models/service_models.py | 4 +- api/app/core/memory/pipelines/memory_read.py | 35 +++++- .../memory/prompt/retrieval_summary.jinja2 | 15 +++ .../generate_engine/query_preprocessor.py | 4 +- .../generate_engine/retrieval_summary.py | 24 +++- .../search_engine/content_search.py | 18 +++ .../search_engine/result_builder.py | 117 +++++++++++++++--- api/app/core/memory/utils/llm/llm_utils.py | 4 +- api/app/core/workflow/nodes/code/node.py | 2 +- api/app/core/workflow/nodes/memory/node.py | 1 + api/app/repositories/neo4j/cypher_queries.py | 16 +++ api/app/repositories/neo4j/graph_search.py | 17 ++- api/app/schemas/memory_agent_schema.py | 5 +- api/app/services/draft_run_service.py | 1 + api/app/utils/__init__.py | 0 api/app/utils/tmp_session.py | 77 ++++++++++++ 19 files changed, 387 insertions(+), 117 deletions(-) create mode 100644 api/app/core/memory/prompt/retrieval_summary.jinja2 create mode 100644 api/app/utils/__init__.py create mode 100644 api/app/utils/tmp_session.py diff --git a/api/app/celery_task_scheduler.py b/api/app/celery_task_scheduler.py index e7f946b6..9b8756cd 100644 --- a/api/app/celery_task_scheduler.py +++ b/api/app/celery_task_scheduler.py @@ -158,12 +158,19 @@ class RedisTaskScheduler: return {"status": status, "task_id": task_id, "result": result_content} def _cleanup_finished(self): - pending = self.redis.hgetall(PENDING_HASH) - if not pending: + cursor = 0 + all_pending = {} + while True: + cursor, batch = self.redis.hscan(PENDING_HASH, cursor=cursor, count=100) + all_pending.update(batch) + if cursor == 0: + break + + if not all_pending: return now = time.time() - task_ids = list(pending.keys()) + task_ids = list(all_pending.keys()) pipe = self.redis.pipeline() for task_id in task_ids: @@ -176,7 +183,7 @@ class RedisTaskScheduler: for task_id, raw_result in zip(task_ids, results): try: - meta = json.loads(pending[task_id]) + meta = json.loads(all_pending[task_id]) lock_key = meta["lock_key"] dispatched_at = meta.get("dispatched_at", 0) age = now - dispatched_at @@ -276,6 +283,22 @@ class RedisTaskScheduler: return True return stable_hash(user_id) % self._shard_count == self._shard_index + def _commit_post_dispatch(self, lock_key, task, msg_id, dispatch_lock): + pipe = self.redis.pipeline() + pipe.set(lock_key, task.id, ex=3600) + pipe.hset(PENDING_HASH, task.id, json.dumps({ + "lock_key": lock_key, + "dispatched_at": time.time(), + "msg_id": msg_id, + })) + pipe.delete(dispatch_lock) + pipe.set( + f"task_tracker:{msg_id}", + json.dumps({"status": "DISPATCHED", "task_id": task.id}), + ex=86400, + ) + pipe.execute() + def _dispatch(self, msg_id, msg_data) -> bool: user_id = msg_data["user_id"] task_name = msg_data["task_name"] @@ -308,28 +331,17 @@ class RedisTaskScheduler: task_name, user_id, msg_id, e, exc_info=True, ) return False - - try: - pipe = self.redis.pipeline() - pipe.set(lock_key, task.id, ex=3600) - pipe.hset(PENDING_HASH, task.id, json.dumps({ - "lock_key": lock_key, - "dispatched_at": time.time(), - "msg_id": msg_id, - })) - pipe.delete(dispatch_lock) - pipe.set( - f"task_tracker:{msg_id}", - json.dumps({"status": "DISPATCHED", "task_id": task.id}), - ex=86400, - ) - pipe.execute() - except Exception as e: - logger.error( - "Post-dispatch state update failed for %s: %s", - task.id, e, exc_info=True, - ) - self.errors += 1 + for attempt in range(2): + try: + self._commit_post_dispatch(lock_key, task, msg_id, dispatch_lock) + break + except Exception as e: + logger.error( + "Post-dispatch state update failed for %s: %s", + task.id, e, exc_info=True, + ) + time.sleep(0.1) + self.errors += 1 self.dispatched += 1 logger.info("Task dispatched: %s (msg=%s)", task.id, msg_id) @@ -367,22 +379,21 @@ class RedisTaskScheduler: return for uid, msg in candidates: + queue_key = f"{USER_QUEUE_PREFIX}{uid}" if self._dispatch(msg["msg_id"], msg): - self.redis.lpop(f"{USER_QUEUE_PREFIX}{uid}") + self.redis.lpop(queue_key) + if self.redis.llen(queue_key) > 0: + self.redis.sadd(READY_SET, uid) def schedule_loop(self): self._heartbeat() self._cleanup_finished() - pipe = self.redis.pipeline() - pipe.smembers(READY_SET) - pipe.delete(READY_SET) - results = pipe.execute() - ready_users = results[0] or set() - + ready_users = self.redis.smembers(READY_SET) or set() my_users = [uid for uid in ready_users if self._is_mine(uid)] - - if not my_users: + if my_users: + self.redis.srem(READY_SET, *my_users) + else: time.sleep(0.5) return @@ -445,7 +456,7 @@ class RedisTaskScheduler: "Scheduler started: instance=%s", self.instance_id, ) - while True: + while self.running: try: self.schedule_loop() @@ -480,9 +491,7 @@ class RedisTaskScheduler: logger.error("Shutdown cleanup error: %s", e) -scheduler: RedisTaskScheduler | None = None -if scheduler is None: - scheduler = RedisTaskScheduler() +scheduler = RedisTaskScheduler() if __name__ == "__main__": import signal diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index cba17f42..c9d41494 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -27,6 +27,7 @@ from app.services import task_service, workspace_service from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import get_end_user_connected_config as get_config from app.services.model_service import ModelConfigService +from app.utils.tmp_session import ChatSessionCache load_dotenv() api_logger = get_api_logger() @@ -300,60 +301,39 @@ async def read_server( if knowledge: user_rag_memory_id = str(knowledge.id) + session_id = user_input.session_id.hex + api_logger.info( - f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}, session_id={session_id}") try: - # result = await memory_agent_service.read_memory( - # user_input.end_user_id, - # user_input.message, - # user_input.history, - # user_input.search_switch, - # config_id, - # db, - # storage_type, - # user_rag_memory_id - # ) - # if str(user_input.search_switch) == "2": - # retrieve_info = result['answer'] - # history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, - # user_input.end_user_id) - # query = user_input.message - # - # # 调用 memory_agent_service 的方法生成最终答案 - # result['answer'] = await memory_agent_service.generate_summary_from_retrieve( - # end_user_id=user_input.end_user_id, - # retrieve_info=retrieve_info, - # history=history, - # query=query, - # config_id=config_id, - # db=db - # ) - # if "信息不足,无法回答" in result['answer']: - # result['answer'] = retrieve_info memory_config = get_config(user_input.end_user_id, db) service = MemoryService( db, memory_config["memory_config_id"], end_user_id=user_input.end_user_id ) + session_cache = ChatSessionCache(session_id) search_result = await service.read( user_input.message, - SearchStrategy(user_input.search_switch) + SearchStrategy(user_input.search_switch), + history=await session_cache.get_history(), ) intermediate_outputs = [] sub_queries = set() for memory in search_result.memories: sub_queries.add(str(memory.query)) + idx = 0 if user_input.search_switch in [SearchStrategy.DEEP, SearchStrategy.NORMAL]: intermediate_outputs.append({ "type": "problem_split", "title": "问题拆分", "data": [ { - "id": f"Q{idx+1}", + "id": f"Q{(idx := idx + 1)}", "question": question } - for idx, question in enumerate(sub_queries) + for question in sub_queries + if question ] }) perceptual_data = [ @@ -375,16 +355,24 @@ async def read_server( "raw_result": search_result.memories, "total": len(search_result.memories), }) + answer = await memory_agent_service.generate_summary_from_retrieve( + end_user_id=user_input.end_user_id, + retrieve_info=search_result.content, + history=[], + query=user_input.message, + config_id=config_id, + db=db + ) + await session_cache.append_many( + [ + {"role": "user", "content": user_input.message}, + {"role": "assistant", "content": answer} + ] + ) result = { - 'answer': await memory_agent_service.generate_summary_from_retrieve( - end_user_id=user_input.end_user_id, - retrieve_info=search_result.content, - history=[], - query=user_input.message, - config_id=config_id, - db=db - ), - "intermediate_outputs": intermediate_outputs + 'answer': answer, + "intermediate_outputs": intermediate_outputs, + "session_id": session_id, } return success(data=result, msg="回复对话消息成功") @@ -480,9 +468,11 @@ async def read_server_async( if knowledge: user_rag_memory_id = str(knowledge.id) api_logger.info(f"Async read: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") try: + session_id = user_input.session_id.hex + session_cache = ChatSessionCache(session_id) task = celery_app.send_task( "app.core.memory.agent.read_message", - args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch, + args=[user_input.end_user_id, user_input.message, await session_cache.get_history(), user_input.search_switch, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Read task queued: {task.id}") diff --git a/api/app/core/memory/memory_service.py b/api/app/core/memory/memory_service.py index f695384b..80c7350d 100644 --- a/api/app/core/memory/memory_service.py +++ b/api/app/core/memory/memory_service.py @@ -43,10 +43,13 @@ class MemoryService: self, query: str, search_switch: SearchStrategy, + history: list | None = None, limit: int = 10, ) -> MemorySearchResult: + if history is None: + history = [] with get_db_context() as db: - return await ReadPipeLine(self.ctx, db).run(query, search_switch, limit) + return await ReadPipeLine(self.ctx, db).run(query, search_switch, history, limit) async def forget(self, max_batch: int = 100, min_days: int = 30) -> dict: raise NotImplementedError diff --git a/api/app/core/memory/models/service_models.py b/api/app/core/memory/models/service_models.py index 6ec0693f..02ecc2fb 100644 --- a/api/app/core/memory/models/service_models.py +++ b/api/app/core/memory/models/service_models.py @@ -32,10 +32,12 @@ class Memory(BaseModel): class MemorySearchResult(BaseModel): memories: list[Memory] + content_str: str = Field(default="") - @computed_field @property def content(self) -> str: + if self.content_str: + return self.content_str return "\n".join([memory.content for memory in self.memories]) @computed_field diff --git a/api/app/core/memory/pipelines/memory_read.py b/api/app/core/memory/pipelines/memory_read.py index 0bd57b08..37f13ee1 100644 --- a/api/app/core/memory/pipelines/memory_read.py +++ b/api/app/core/memory/pipelines/memory_read.py @@ -1,8 +1,9 @@ from app.core.memory.enums import SearchStrategy, StorageType from app.core.memory.models.service_models import MemorySearchResult from app.core.memory.pipelines.base_pipeline import ModelClientMixin, DBRequiredPipeline -from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService from app.core.memory.read_services.generate_engine.query_preprocessor import QueryPreprocessor +from app.core.memory.read_services.generate_engine.retrieval_summary import RetrievalSummaryProcessor +from app.core.memory.read_services.search_engine.content_search import Neo4jSearchService, RAGSearchService class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): @@ -10,20 +11,30 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): self, query: str, search_switch: SearchStrategy, + history: list, limit: int = 10, includes=None ) -> MemorySearchResult: + memory_l0 = None + if self.ctx.storage_type == StorageType.NEO4J: + memory_l0 = await self._get_search_service(includes).memory_l0() + query = QueryPreprocessor.process(query) match search_switch: case SearchStrategy.DEEP: - return await self._deep_read(query, limit, includes) + res = await self._deep_read(query, history, limit, includes) case SearchStrategy.NORMAL: - return await self._normal_read(query, limit, includes) + res = await self._normal_read(query, history, limit, includes) case SearchStrategy.QUICK: - return await self._quick_read(query, limit, includes) + res = await self._quick_read(query, limit, includes) case _: raise RuntimeError("Unsupported search strategy") + if memory_l0 is not None: + res.content_str = memory_l0.content + '\n' + res.content + res.memories.insert(0, memory_l0) + return res + def _get_search_service(self, includes=None): if self.ctx.storage_type == StorageType.NEO4J: return Neo4jSearchService( @@ -37,10 +48,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): self.db ) - async def _deep_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: + async def _deep_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult: search_service = self._get_search_service(includes) questions = await QueryPreprocessor.split( query, + history, self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id) ) query_results = [] @@ -49,12 +61,18 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): query_results.append(search_results) results = sum(query_results, start=MemorySearchResult(memories=[])) results.memories.sort(key=lambda x: x.score, reverse=True) + results.content_str = await RetrievalSummaryProcessor.summary( + query, + results.content, + self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id) + ) return results - async def _normal_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: + async def _normal_read(self, query: str, history: list, limit: int, includes=None) -> MemorySearchResult: search_service = self._get_search_service(includes) questions = await QueryPreprocessor.split( query, + history, self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id) ) query_results = [] @@ -63,6 +81,11 @@ class ReadPipeLine(ModelClientMixin, DBRequiredPipeline): query_results.append(search_results) results = sum(query_results, start=MemorySearchResult(memories=[])) results.memories.sort(key=lambda x: x.score, reverse=True) + results.content_str = await RetrievalSummaryProcessor.summary( + query, + results.content, + self.get_llm_client(self.db, self.ctx.memory_config.llm_model_id) + ) return results async def _quick_read(self, query: str, limit: int, includes=None) -> MemorySearchResult: diff --git a/api/app/core/memory/prompt/retrieval_summary.jinja2 b/api/app/core/memory/prompt/retrieval_summary.jinja2 new file mode 100644 index 00000000..bb0bddad --- /dev/null +++ b/api/app/core/memory/prompt/retrieval_summary.jinja2 @@ -0,0 +1,15 @@ +You are a Content Condenser for a memory-augmented retrieval system. + +Your task is to compress the retrieved content while preserving all information that is highly relevant to the user’s query. + +Guidelines: + +Focus only on content related to the query; ignore irrelevant parts. +Remove redundancy, filler, or repeated information only for non-XML content. +Preserve all factual details: names, dates, decisions, code snippets, technical details. +If relevant information is inside XML tags, do not remove, merge, or compress the XML tags or their internal text; keep them fully intact. +Structure multiple relevant points as a compact bullet list or paragraph, depending on density. +If no content is relevant, return exactly: "No relevant information found." +Do not add any knowledge or facts not in the retrieved content. +# [IMPORTANT] OUTPUT ONLY THE CONDENSED CONTENT, DO NOT ATTEMPT TO ANSWER THE QUERY. +# [IMPORTANT] DO NOT REMOVE OR PARAPHRASE HIGHLY RELEVANT INFORMATION. \ No newline at end of file diff --git a/api/app/core/memory/read_services/generate_engine/query_preprocessor.py b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py index 1e234a10..49d2856b 100644 --- a/api/app/core/memory/read_services/generate_engine/query_preprocessor.py +++ b/api/app/core/memory/read_services/generate_engine/query_preprocessor.py @@ -21,14 +21,14 @@ class QueryPreprocessor: return text @staticmethod - async def split(query: str, llm_client: RedBearLLM): + async def split(query: str, history: list, llm_client: RedBearLLM): system_prompt = prompt_manager.render( name="problem_split", datetime=datetime.now().strftime("%Y-%m-%d"), ) messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": query}, + {"role": "user", "content": f"{history}{query}"}, ] try: sub_queries = await llm_client.ainvoke(messages) | StructResponse(mode='json') diff --git a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py index c46e93f0..c189e45a 100644 --- a/api/app/core/memory/read_services/generate_engine/retrieval_summary.py +++ b/api/app/core/memory/read_services/generate_engine/retrieval_summary.py @@ -1,11 +1,29 @@ +import logging + from app.core.models import RedBearLLM +from app.core.memory.prompt import prompt_manager +from app.core.memory.utils.llm.llm_utils import StructResponse + +logger = logging.getLogger(__name__) class RetrievalSummaryProcessor: @staticmethod - def summary(content: str, llm_client: RedBearLLM): - return + async def summary(query, content: str, llm_client: RedBearLLM): + system_prompt = prompt_manager.render( + name="retrieval_summary" + ) + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"{query}{content}"}, + ] + try: + summary = await llm_client.ainvoke(messages) | StructResponse(mode='str') + return summary + except: + logger.error("Failed to generate reply summary, returning original content", exc_info=True) + return content @staticmethod - def verify(content: str, llm_client: RedBearLLM): + async def verify(query, content: str, llm_client: RedBearLLM): return diff --git a/api/app/core/memory/read_services/search_engine/content_search.py b/api/app/core/memory/read_services/search_engine/content_search.py index 4ba4dce7..b1ff1738 100644 --- a/api/app/core/memory/read_services/search_engine/content_search.py +++ b/api/app/core/memory/read_services/search_engine/content_search.py @@ -14,6 +14,8 @@ from app.core.rag.nlp.search import knowledge_retrieval from app.repositories import knowledge_repository from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.core.memory.read_services.search_engine.result_builder import MetadataBuilder +from app.repositories.neo4j.graph_search import search_user_metadata logger = logging.getLogger(__name__) @@ -177,6 +179,22 @@ class Neo4jSearchService: memories.sort(key=lambda x: x.score, reverse=True) return MemorySearchResult(memories=memories[:limit]) + async def memory_l0(self) -> Memory: + async with Neo4jConnector() as connector: + end_user_id = self.ctx.end_user_id + user_meta = await search_user_metadata(connector, end_user_id) + metadata = MetadataBuilder(user_meta) + memory = Memory( + score=1, + source=Neo4jNodeType.EXTRACTEDENTITY, + query='', + id=end_user_id, + content=metadata.content, + data=metadata.data, + ) + + return memory + class RAGSearchService: def __init__(self, ctx: MemoryContext, db: Session): diff --git a/api/app/core/memory/read_services/search_engine/result_builder.py b/api/app/core/memory/read_services/search_engine/result_builder.py index 1ef04557..8d158ea8 100644 --- a/api/app/core/memory/read_services/search_engine/result_builder.py +++ b/api/app/core/memory/read_services/search_engine/result_builder.py @@ -42,7 +42,15 @@ class ChunkBuilder(BaseBuilder): @property def content(self) -> str: - return self.record.get("content") + parts = [""] + fields = [ + ("content", self.record.get("content", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) class StatementBuiler(BaseBuilder): @@ -57,7 +65,15 @@ class StatementBuiler(BaseBuilder): @property def content(self) -> str: - return self.record.get("statement") + parts = [""] + fields = [ + ("statement", self.record.get("statement", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) class EntityBuilder(BaseBuilder): @@ -73,10 +89,16 @@ class EntityBuilder(BaseBuilder): @property def content(self) -> str: - return (f"" - f"{self.record.get("name")}" - f"{self.record.get("description")}" - f"") + parts = [""] + fields = [ + ("name", self.record.get("name", "")), + ("description", self.record.get("description", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) class SummaryBuilder(BaseBuilder): @@ -91,7 +113,15 @@ class SummaryBuilder(BaseBuilder): @property def content(self) -> str: - return self.record.get("content") + parts = [""] + fields = [ + ("content", self.record.get("content", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) class PerceptualBuilder(BaseBuilder): @@ -114,15 +144,21 @@ class PerceptualBuilder(BaseBuilder): @property def content(self) -> str: - return ("" - f"{self.record.get('file_name')}" - f"{self.record.get('file_path')}" - f"{self.record.get('summary')}" - f"{self.record.get('topic')}" - f"{self.record.get('domain')}" - f"{self.record.get('keywords')}" - f"{self.record.get('file_type')}" - "") + parts = [""] + fields = [ + ("file-name", self.record.get("file_name", "")), + ("file-path", self.record.get("file_path", "")), + ("summary", self.record.get("summary", "")), + ("topic", self.record.get("topic", "")), + ("domain", self.record.get("domain", "")), + ("keywords", self.record.get("keywords", [])), + ("file-type", self.record.get("file_type", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) class CommunityBuilder(BaseBuilder): @@ -137,7 +173,54 @@ class CommunityBuilder(BaseBuilder): @property def content(self) -> str: - return self.record.get("content") + parts = [""] + fields = [ + ("content", self.record.get("content", "")), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) + + +class MetadataBuilder(BaseBuilder): + @property + def data(self) -> dict: + return { + "id": self.record.get("id", ""), + "aliases_name": self.record.get("aliases", []) or [], + "description": self.record.get("description", ""), + "anchors": self.record.get("anchors", []) or [], + "beliefs_or_stances": self.record.get("beliefs_or_stances", []) or [], + "core_facts": self.record.get("core_facts", []) or [], + "events": self.record.get("events", []) or [], + "goals": self.record.get("goals", []) or [], + "interests": self.record.get("interests", []) or [], + "relations": self.record.get("relations", []) or [], + "traits": self.record.get("traits", []) or [], + } + + @property + def content(self) -> str: + parts = [""] + fields = [ + ("description", self.record.get("description", "")), + ("aliases", self.record.get("aliases", [])), + ("anchors", self.record.get("anchors", [])), + ("beliefs_or_stances", self.record.get("beliefs_or_stances", [])), + ("core_facts", self.record.get("core_facts", [])), + ("events", self.record.get("events", [])), + ("goals", self.record.get("goals", [])), + ("interests", self.record.get("interests", [])), + ("relations", self.record.get("relations", [])), + ("traits", self.record.get("traits", [])), + ] + for tag, value in fields: + if value: + parts.append(f"<{tag}>{value}") + parts.append("") + return "".join(parts) def data_builder_factory(node_type, data: dict) -> T: diff --git a/api/app/core/memory/utils/llm/llm_utils.py b/api/app/core/memory/utils/llm/llm_utils.py index c4eee82f..1030f2cb 100644 --- a/api/app/core/memory/utils/llm/llm_utils.py +++ b/api/app/core/memory/utils/llm/llm_utils.py @@ -17,7 +17,7 @@ async def handle_response(response: type[BaseModel]) -> dict: class StructResponse: - def __init__(self, mode: Literal["json", "pydantic"], model: Type[BaseModel] = None): + def __init__(self, mode: Literal["json", "pydantic", "str"], model: Type[BaseModel] = None): self.mode = mode if mode == "pydantic" and model is None: raise ValueError("Pydantic model is required") @@ -31,6 +31,8 @@ class StructResponse: for block in other.content_blocks: if block.get("type") == "text": text += block.get("text", "") + if self.mode == "str": + return text fixed_json = json_repair.repair_json(text, return_objects=True) if self.mode == "json": return fixed_json diff --git a/api/app/core/workflow/nodes/code/node.py b/api/app/core/workflow/nodes/code/node.py index d715be7d..43f60211 100644 --- a/api/app/core/workflow/nodes/code/node.py +++ b/api/app/core/workflow/nodes/code/node.py @@ -132,7 +132,7 @@ class CodeNode(BaseNode): async with httpx.AsyncClient(timeout=60) as client: response = await client.post( - f"{settings.SANDBOX_URL}:8194/v1/sandbox/run", + f"{settings.SANDBOX_URL}/v1/sandbox/run", headers={ "x-api-key": 'redbear-sandbox' }, diff --git a/api/app/core/workflow/nodes/memory/node.py b/api/app/core/workflow/nodes/memory/node.py index 6d9fcdad..dd8b0915 100644 --- a/api/app/core/workflow/nodes/memory/node.py +++ b/api/app/core/workflow/nodes/memory/node.py @@ -40,6 +40,7 @@ class MemoryReadNode(BaseNode): end_user_id=end_user_id, user_rag_memory_id=state["user_rag_memory_id"], ) + # TODO: Historical Messages -> Used to refer to coreference resolution search_result = await memory_service.read( self._render_template(self.typed_config.message, variable_pool), search_switch=SearchStrategy(self.typed_config.search_switch) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index a8c36e34..85e77988 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -1296,6 +1296,7 @@ RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type, + e.description AS description, COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, COALESCE(e.importance_score, 0.5) AS importance_score, e.last_access_time AS last_access_time, @@ -1479,6 +1480,21 @@ ORDER BY score DESC LIMIT $limit """ +SEARCH_USER_METADATA = """ +MATCH (n:ExtractedEntity) +WHERE (n.end_user_id = $end_user_id AND n.entity_type ='用户') +RETURN n.description AS description, + n.aliases AS aliases, + n.anchors AS anchors, + n.beliefs_or_stances AS beliefs_or_stances, + n.core_facts AS core_facts, + n.events AS events, + n.goals AS goals, + n.interests AS interests, + n.relations AS relations, + n.traits AS traits +""" + FULLTEXT_QUERY_CYPHER_MAPPING = { Neo4jNodeType.STATEMENT: SEARCH_STATEMENTS_BY_KEYWORD, Neo4jNodeType.EXTRACTEDENTITY: SEARCH_ENTITIES_BY_NAME_OR_ALIAS, diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 70913267..ed97256a 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -27,9 +27,9 @@ from app.repositories.neo4j.cypher_queries import ( SEARCH_PERCEPTUAL_BY_USER_ID, FULLTEXT_QUERY_CYPHER_MAPPING, USER_ID_QUERY_CYPHER_MAPPING, - NODE_ID_QUERY_CYPHER_MAPPING + NODE_ID_QUERY_CYPHER_MAPPING, + SEARCH_USER_METADATA ) - from app.repositories.neo4j.neo4j_connector import Neo4jConnector logger = logging.getLogger(__name__) @@ -513,7 +513,7 @@ async def search_graph_by_embedding( task_keys = [] for node_type in include: - tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit*2)) + tasks.append(search_by_embedding(connector, node_type, end_user_id, embedding, limit * 2)) task_keys.append(node_type.value) task_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -557,6 +557,17 @@ async def search_graph_by_embedding( return results +async def search_user_metadata( + connector: Neo4jConnector, + end_user_id: str +) -> dict: + user_info = await connector.execute_query( + SEARCH_USER_METADATA, + end_user_id=end_user_id + ) + return user_info[0] if user_info else {} + + async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 connector: Neo4jConnector, end_user_id: str, diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 97aa5bb5..58ac5a94 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,14 +1,15 @@ +import uuid from abc import ABC from typing import Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field class UserInput(BaseModel): message: str - history: list[dict] search_switch: str end_user_id: str + session_id: uuid.UUID = Field(default_factory=uuid.uuid4) config_id: Optional[str] = None diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 16d856ca..8ebd21a6 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -108,6 +108,7 @@ def create_long_term_memory_tool( try: with get_db_context() as db: memory_service = MemoryService(db, config_id, end_user_id) + # TODO: Historical Messages -> Used to refer to coreference resolution search_result = asyncio.run(memory_service.read(question, SearchStrategy.QUICK)) # memory_content = asyncio.run( diff --git a/api/app/utils/__init__.py b/api/app/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/utils/tmp_session.py b/api/app/utils/tmp_session.py new file mode 100644 index 00000000..14959ec0 --- /dev/null +++ b/api/app/utils/tmp_session.py @@ -0,0 +1,77 @@ +import json +import logging + +import redis.asyncio as redis + +from app.aioRedis import get_redis_connection + +logger = logging.getLogger(__name__) + +DEFAULT_TTL = 3600 + + +class ChatSessionCache: + """Cache user-AI conversation history in Redis with TTL-based expiry. + + Usage:: + + cache = ChatSessionCache(session_id="user_123") + await cache.append("user", "Hello") + await cache.append("assistant", "Hi there!") + history = await cache.get_history() + """ + + def __init__(self, session_id: str, ttl: int = DEFAULT_TTL): + self.session_id = session_id + self.ttl = ttl + self._key = f"chat:session:{session_id}" + + @staticmethod + async def _client() -> redis.StrictRedis: + return await get_redis_connection() + + async def append(self, role: str, content: str) -> None: + r = await self._client() + entry = json.dumps({"role": role, "content": content}, ensure_ascii=False) + await r.rpush(self._key, entry) + await r.expire(self._key, self.ttl) + + async def append_many(self, messages: list[dict[str, str]]) -> None: + """Batch append messages. Each dict should have ``role`` and ``content`` keys.""" + if not messages: + return + r = await self._client() + entries = [ + json.dumps(m, ensure_ascii=False) + for m in messages + if "role" in m and "content" in m + ] + if entries: + await r.rpush(self._key, *entries) + await r.expire(self._key, self.ttl) + + async def get_history(self) -> list[dict[str, str]]: + r = await self._client() + raw = await r.lrange(self._key, 0, -1) + return [json.loads(item) for item in raw] + + async def get_history_text(self, user_label: str = "User", ai_label: str = "Assistant") -> str: + """Return conversation as a formatted text block.""" + history = await self.get_history() + lines = [] + for msg in history: + role = msg.get("role", "") + content = msg.get("content", "") + label = user_label if role == "user" else ai_label if role == "assistant" else role + lines.append(f"{label}: {content}") + return "\n".join(lines) + + async def reset(self) -> None: + """Delete the session from Redis.""" + r = await self._client() + await r.delete(self._key) + + async def touch(self) -> None: + """Refresh the TTL without modifying data.""" + r = await self._client() + await r.expire(self._key, self.ttl)