feat(memory): add long-term storage task routing and batching

This commit is contained in:
Ke Sun
2026-02-03 15:52:45 +08:00
parent 9f2b6390b0
commit f27de7df35
5 changed files with 353 additions and 27 deletions

View File

@@ -148,6 +148,7 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
# TODO: 移到memory module
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
scope=6
@@ -307,9 +308,12 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
'''长期'''
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = {
"content": content,
@@ -441,9 +445,13 @@ class LangChainAgent:
yield total_tokens
break
if memory_flag:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e:

View File

@@ -43,6 +43,7 @@ async def write_messages(end_user_id,langchain_messages,memory_config):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)

View File

@@ -1,18 +1,14 @@
import asyncio
import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
@@ -40,27 +36,55 @@ async def make_write_graph():
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
# async def main():