feat(memory): add long-term storage task routing and batching

This commit is contained in:
Ke Sun
2026-02-03 15:52:45 +08:00
parent 9f2b6390b0
commit f27de7df35
5 changed files with 353 additions and 27 deletions

View File

@@ -64,6 +64,11 @@ celery_app.conf.update(
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker) # Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},

View File

@@ -148,6 +148,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
# TODO: 移到memory module
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db()) db = next(get_db())
scope=6 scope=6
@@ -307,9 +308,12 @@ class LangChainAgent:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag: if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content) long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期''' # Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = { response = {
"content": content, "content": content,
@@ -441,9 +445,13 @@ class LangChainAgent:
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content) long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e: except Exception as e:

View File

@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name: if 'save_neo4j' == node_name:
massages = node_data massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status'] massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result') contents = massages.get('write_result')
print(contents) print(contents)

View File

@@ -1,18 +1,14 @@
import asyncio import asyncio
import json
import sys import sys
import warnings import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from langgraph.constants import END, START from langgraph.constants import END, START
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -40,27 +36,55 @@ async def make_write_graph():
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment """Dispatch long-term memory storage to Celery background tasks.
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断""" Args:
await aggregate_judgment(end_user_id, langchain_messages, memory_config) long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
# async def main(): # async def main():

View File

@@ -1066,6 +1066,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务") f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
except Exception as e: except Exception as e:
db.rollback() # Rollback failed transaction to allow next query
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}") api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
all_reflection_results.append({ all_reflection_results.append({
"workspace_id": str(workspace_id), "workspace_id": str(workspace_id),
@@ -1204,3 +1205,290 @@ def run_forgetting_cycle_task(self, config_id: Optional[uuid.UUID] = None) -> Di
return result return result
finally: finally:
loop.close() loop.close()
# =============================================================================
# Long-term Memory Storage Tasks (Batched Write Strategies)
# =============================================================================
@celery_app.task(name="app.core.memory.agent.long_term_storage.window", bind=True)
def long_term_storage_window_task(
self,
end_user_id: str,
langchain_messages: List[Dict[str, Any]],
config_id: str,
scope: int = 6
) -> Dict[str, Any]:
"""Celery task for window-based long-term memory storage.
Accumulates messages in Redis buffer until window size (scope) is reached,
then writes batched messages to Neo4j.
Args:
end_user_id: End user identifier
langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
config_id: Memory configuration ID
scope: Window size (number of messages before triggering write)
Returns:
Dict containing task status and metadata
"""
from app.core.logging_config import get_logger
logger = get_logger(__name__)
logger.info(f"[LONG_TERM_WINDOW] Starting task - end_user_id={end_user_id}, scope={scope}")
start_time = time.time()
async def _run() -> Dict[str, Any]:
from app.core.memory.agent.langgraph_graph.routing.write_router import window_dialogue
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
from app.services.memory_config_service import MemoryConfigService
db = next(get_db())
try:
# Save to Redis buffer first
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# Load memory config
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="LongTermStorageTask"
)
# Execute window-based dialogue storage
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
return {"status": "SUCCESS", "strategy": "window", "scope": scope}
finally:
db.close()
try:
import nest_asyncio
nest_asyncio.apply()
except ImportError:
pass
try:
loop = asyncio.get_event_loop()
if loop.is_closed():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
result = loop.run_until_complete(_run())
elapsed_time = time.time() - start_time
logger.info(f"[LONG_TERM_WINDOW] Task completed - elapsed_time={elapsed_time:.2f}s")
return {
**result,
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
except Exception as e:
elapsed_time = time.time() - start_time
logger.error(f"[LONG_TERM_WINDOW] Task failed - error={str(e)}", exc_info=True)
return {
"status": "FAILURE",
"strategy": "window",
"error": str(e),
"end_user_id": end_user_id,
"config_id": config_id,
"elapsed_time": elapsed_time,
"task_id": self.request.id
}
# @celery_app.task(name="app.core.memory.agent.long_term_storage.time", bind=True)
# def long_term_storage_time_task(
# self,
# end_user_id: str,
# config_id: str,
# time_window: int = 5
# ) -> Dict[str, Any]:
# """Celery task for time-based long-term memory storage.
# Retrieves recent sessions from Redis within time window and writes to Neo4j.
# Args:
# end_user_id: End user identifier
# config_id: Memory configuration ID
# time_window: Time window in minutes for retrieving recent sessions
# Returns:
# Dict containing task status and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_TIME] Starting task - end_user_id={end_user_id}, time_window={time_window}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute time-based storage
# await memory_long_term_storage(end_user_id, memory_config, time_window)
# return {"status": "SUCCESS", "strategy": "time", "time_window": time_window}
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_TIME] Task completed - elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_TIME] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "time",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# @celery_app.task(name="app.core.memory.agent.long_term_storage.aggregate", bind=True)
# def long_term_storage_aggregate_task(
# self,
# end_user_id: str,
# langchain_messages: List[Dict[str, Any]],
# config_id: str
# ) -> Dict[str, Any]:
# """Celery task for aggregate-based long-term memory storage.
# Uses LLM to determine if new messages describe the same event as history.
# Only writes to Neo4j if messages represent new information (not duplicates).
# Args:
# end_user_id: End user identifier
# langchain_messages: List of messages [{"role": "user/assistant", "content": "..."}]
# config_id: Memory configuration ID
# Returns:
# Dict containing task status, is_same_event flag, and metadata
# """
# from app.core.logging_config import get_logger
# logger = get_logger(__name__)
# logger.info(f"[LONG_TERM_AGGREGATE] Starting task - end_user_id={end_user_id}")
# start_time = time.time()
# async def _run() -> Dict[str, Any]:
# from app.core.memory.agent.langgraph_graph.routing.write_router import aggregate_judgment
# from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
# from app.core.memory.agent.utils.redis_tool import write_store
# from app.services.memory_config_service import MemoryConfigService
# db = next(get_db())
# try:
# # Save to Redis buffer first
# write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# # Load memory config
# config_service = MemoryConfigService(db)
# memory_config = config_service.load_memory_config(
# config_id=config_id,
# service_name="LongTermStorageTask"
# )
# # Execute aggregate judgment
# result = await aggregate_judgment(end_user_id, langchain_messages, memory_config)
# return {
# "status": "SUCCESS",
# "strategy": "aggregate",
# "is_same_event": result.get("is_same_event", False),
# "wrote_to_neo4j": not result.get("is_same_event", False)
# }
# finally:
# db.close()
# try:
# import nest_asyncio
# nest_asyncio.apply()
# except ImportError:
# pass
# try:
# loop = asyncio.get_event_loop()
# if loop.is_closed():
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# except RuntimeError:
# loop = asyncio.new_event_loop()
# asyncio.set_event_loop(loop)
# try:
# result = loop.run_until_complete(_run())
# elapsed_time = time.time() - start_time
# logger.info(f"[LONG_TERM_AGGREGATE] Task completed - is_same_event={result.get('is_same_event')}, elapsed_time={elapsed_time:.2f}s")
# return {
# **result,
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }
# except Exception as e:
# elapsed_time = time.time() - start_time
# logger.error(f"[LONG_TERM_AGGREGATE] Task failed - error={str(e)}", exc_info=True)
# return {
# "status": "FAILURE",
# "strategy": "aggregate",
# "error": str(e),
# "end_user_id": end_user_id,
# "config_id": config_id,
# "elapsed_time": elapsed_time,
# "task_id": self.request.id
# }