diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 3e7db8cb..727cd041 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -64,6 +64,11 @@ celery_app.conf.update( 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + # Long-term storage tasks → memory_tasks queue (batched write strategies) + 'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, + # Document tasks → document_tasks queue (prefork worker) 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 441609ac..58cf9443 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -148,6 +148,7 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages + # TODO: 移到memory module async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): db = next(get_db()) scope=6 @@ -307,9 +308,12 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: long_term_messages=await agent_chat_messages(message_chat,content) - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - '''长期''' + # Batched long-term memory storage (Redis buffer + Neo4j when window full) await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, @@ -441,9 +445,13 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. + # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j + # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. + # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. long_term_messages = await agent_chat_messages(message_chat, full_content) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) + # Batched long-term memory storage (Redis buffer + Neo4j when window full) await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index d6fbbb38..6650b9b0 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config): for node_name, node_data in update_event.items(): if 'save_neo4j' == node_name: massages = node_data + # TODO:删除 massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') print(contents) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index d0e8a45d..84ea9381 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,18 +1,14 @@ import asyncio -import json import sys import warnings from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse -from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) @@ -40,27 +36,55 @@ async def make_write_graph(): yield graph async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment - from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format - from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) - # 获取数据库会话 - db_session = next(get_db()) - config_service = MemoryConfigService(db_session) - memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 - service_name="MemoryAgentService" + """Dispatch long-term memory storage to Celery background tasks. + + Args: + long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate' + langchain_messages: List of messages to store + memory_config: Memory configuration ID (string) + end_user_id: End user identifier + scope: Window size for 'chunk' strategy (default: 6) + """ + from app.tasks import ( + long_term_storage_window_task, + # TODO: Uncomment when implemented + # long_term_storage_time_task, + # long_term_storage_aggregate_task, ) - if long_term_type=='chunk': - '''方案一:对话窗口6轮对话''' - await window_dialogue(end_user_id,langchain_messages,memory_config,scope) - if long_term_type=='time': - """时间""" - await memory_long_term_storage(end_user_id, memory_config,5) - if long_term_type=='aggregate': - - """方案三:聚合判断""" - await aggregate_judgment(end_user_id, langchain_messages, memory_config) + from app.core.logging_config import get_logger + + logger = get_logger(__name__) + + # Convert config to string if needed + config_id = str(memory_config) if memory_config else '' + + if long_term_type == 'chunk': + # Strategy 1: Window-based batching (6 rounds of dialogue) + logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}") + long_term_storage_window_task.delay( + end_user_id=end_user_id, + langchain_messages=langchain_messages, + config_id=config_id, + scope=scope + ) + # TODO: Uncomment when time-based strategy is fully implemented + # elif long_term_type == 'time': + # # Strategy 2: Time-based retrieval + # logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}") + # long_term_storage_time_task.delay( + # end_user_id=end_user_id, + # config_id=config_id, + # time_window=5 + # ) + # TODO: Uncomment when aggregate strategy is fully implemented + # elif long_term_type == 'aggregate': + # # Strategy 3: Aggregate judgment (deduplication) + # logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}") + # long_term_storage_aggregate_task.delay( + # end_user_id=end_user_id, + # langchain_messages=langchain_messages, + # config_id=config_id + # ) # async def main(): diff --git a/api/app/tasks.py b/api/app/tasks.py index 48b41e4f..db332816 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1066,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") except Exception as e: + db.rollback() # Rollback failed transaction to allow next query api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") all_reflection_results.append({ "workspace_id": str(workspace_id), @@ -1204,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di return result finally: loop.close() + + +# ============================================================================= +# Long-term Memory Storage Tasks (Batched Write Strategies) +# ============================================================================= + +@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True) +def long_term_storage_window_task( + self, + end_user_id: str, + langchain_messages: List[Dict[str, Any]], + config_id: str, + scope: int = 6 +) -> Dict[str, Any]: + """Celery task for window-based long-term memory storage. + + Accumulates messages in Redis buffer until window size (scope) is reached, + then writes batched messages to Neo4j. + + Args: + end_user_id: End user identifier + langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] + config_id: Memory configuration ID + scope: Window size (number of messages before triggering write) + + Returns: + Dict containing task status and metadata + """ + from app.core.logging_config import get_logger + logger = get_logger(__name__) + + logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}") + start_time = time.time() + + async def _run() -> Dict[str, Any]: + from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue + from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format + from app.core.memory.agent.utils.redis_tool import write_store + from app.services.memory_config_service import MemoryConfigService + + db = next(get_db()) + try: + # Save to Redis buffer first + write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + + # Load memory config + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="LongTermStorageTask" + ) + + # Execute window-based dialogue storage + await window_dialogue(end_user_id, langchain_messages, memory_config, scope) + + return {"status": "SUCCESS", "strategy": "window", "scope": scope} + finally: + db.close() + + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + try: + result = loop.run_until_complete(_run()) + elapsed_time = time.time() - start_time + + logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s") + + return { + **result, + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + except Exception as e: + elapsed_time = time.time() - start_time + logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True) + + return { + "status": "FAILURE", + "strategy": "window", + "error": str(e), + "end_user_id": end_user_id, + "config_id": config_id, + "elapsed_time": elapsed_time, + "task_id": self.request.id + } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True) +# def long_term_storage_time_task( +# self, +# end_user_id: str, +# config_id: str, +# time_window: int = 5 +# ) -> Dict[str, Any]: +# """Celery task for time-based long-term memory storage. + +# Retrieves recent sessions from Redis within time window and writes to Neo4j. + +# Args: +# end_user_id: End user identifier +# config_id: Memory configuration ID +# time_window: Time window in minutes for retrieving recent sessions + +# Returns: +# Dict containing task status and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute time-based storage +# await memory_long_term_storage(end_user_id, memory_config, time_window) + +# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window} +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "time", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } + + +# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True) +# def long_term_storage_aggregate_task( +# self, +# end_user_id: str, +# langchain_messages: List[Dict[str, Any]], +# config_id: str +# ) -> Dict[str, Any]: +# """Celery task for aggregate-based long-term memory storage. + +# Uses LLM to determine if new messages describe the same event as history. +# Only writes to Neo4j if messages represent new information (not duplicates). + +# Args: +# end_user_id: End user identifier +# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}] +# config_id: Memory configuration ID + +# Returns: +# Dict containing task status, is_same_event flag, and metadata +# """ +# from app.core.logging_config import get_logger +# logger = get_logger(__name__) + +# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}") +# start_time = time.time() + +# async def _run() -> Dict[str, Any]: +# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment +# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format +# from app.core.memory.agent.utils.redis_tool import write_store +# from app.services.memory_config_service import MemoryConfigService + +# db = next(get_db()) +# try: +# # Save to Redis buffer first +# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + +# # Load memory config +# config_service = MemoryConfigService(db) +# memory_config = config_service.load_memory_config( +# config_id=config_id, +# service_name="LongTermStorageTask" +# ) + +# # Execute aggregate judgment +# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config) + +# return { +# "status": "SUCCESS", +# "strategy": "aggregate", +# "is_same_event": result.get("is_same_event", False), +# "wrote_to_neo4j": not result.get("is_same_event", False) +# } +# finally: +# db.close() + +# try: +# import nest_asyncio +# nest_asyncio.apply() +# except ImportError: +# pass + +# try: +# loop = asyncio.get_event_loop() +# if loop.is_closed(): +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) +# except RuntimeError: +# loop = asyncio.new_event_loop() +# asyncio.set_event_loop(loop) + +# try: +# result = loop.run_until_complete(_run()) +# elapsed_time = time.time() - start_time + +# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s") + +# return { +# **result, +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# } +# except Exception as e: +# elapsed_time = time.time() - start_time +# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True) + +# return { +# "status": "FAILURE", +# "strategy": "aggregate", +# "error": str(e), +# "end_user_id": end_user_id, +# "config_id": config_id, +# "elapsed_time": elapsed_time, +# "task_id": self.request.id +# }