Fix/release memory bug (#306)

* memory_BUG_fix

* memory_BUG

* memory_BUG_long_term

* memory_BUG_long_term

* memory_BUG_long_term
This commit is contained in:
lixinyue11
2026-02-04 14:34:00 +08:00
committed by GitHub
parent cdd6b80089
commit 8f0a1d9c6e
6 changed files with 202 additions and 268 deletions

View File

@@ -7,30 +7,21 @@ LangChain Agent 封装
- 支持流式输出 - 支持流式输出
- 使用 RedBearLLM 支持多提供商 - 使用 RedBearLLM 支持多提供商
""" """
import os
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
from app.db import get_db from app.db import get_db
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import ( from app.services.memory_agent_service import (
get_end_user_connected_config, get_end_user_connected_config,
) )
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from app.utils.config_utils import resolve_config_id
logger = get_business_logger() logger = get_business_logger()
@@ -148,106 +139,6 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content)) messages.append(HumanMessage(content=user_content))
return messages return messages
# TODO: 移到memory module
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
db = next(get_db())
#TODO: 魔法数字
scope=6
try:
repo = LongTermMemoryRepository(db)
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
# Handle case where no session exists in Redis (returns False)
if not result or result is False:
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
return
if type=="chunk" or type=="aggregate":
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'写入短长期:')
else:
# TODO: This branch handles type="time" strategy, currently unused.
# Will be activated when time-based long-term storage is implemented.
# TODO: 魔法数字 - extract 5 to a constant
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
# Handle case where no session exists in Redis (returns False or empty)
if not long_time_data or long_time_data is False:
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
return
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
finally:
db.close()
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id=resolve_config_id(actual_config_id, db)
if storage_type == "rag":
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def chat( async def chat(
self, self,
message: str, message: str,
@@ -321,14 +212,7 @@ class LangChainAgent:
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag: if memory_flag:
long_term_messages=await agent_chat_messages(message_chat,content) await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -459,15 +343,7 @@ class LangChainAgent:
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
long_term_messages = await agent_chat_messages(message_chat, full_content)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise raise

View File

@@ -1,8 +1,9 @@
import json
import os import os
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
@@ -10,46 +11,115 @@ from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
async def write_messages(end_user_id,langchain_messages,memory_config): Args:
''' storage_type: 存储类型 (neo4j/rag)
写入数据到neo4j
Args:
end_user_id: 终端用户ID end_user_id: 终端用户ID
memory_config: 内存配置对象 user_message: 用户消息内容
langchain_messages原始数据LIST ai_message: AI 回复内容
''' user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try: try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
try:
repo = LongTermMemoryRepository(db_session)
await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
# yield db_session
finally:
if db_session.in_transaction():
db_session.rollback()
db_session.close()
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
# TODO删除
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
print(contents)
except Exception as e:
import traceback
traceback.print_exc()
'''根据窗口''' '''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
''' '''
@@ -61,25 +131,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
scope窗口大小 scope窗口大小
''' '''
scope=scope scope=scope
redis_messages = []
is_end_user_id = count_store.get_sessions_count(end_user_id) is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False: if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1] redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope): if is_end_user_id and int(is_end_user_id) != int(scope):
print(is_end_user_id)
is_end_user_id += 1 is_end_user_id += 1
langchain_messages += redis_messages langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
print('写入长期记忆并且设置为0') logger.info('写入长期记忆NEO4J')
print(is_end_user_id) formatted_messages = (redis_messages)
formatted_messages = await chat_data_format(redis_messages) # 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
print(100*'-') if hasattr(memory_config, 'config_id'):
print(formatted_messages) config_id = memory_config.config_id
print(100*'-') else:
await write_messages(end_user_id, formatted_messages, memory_config) config_id = memory_config
count_store.update_sessions_count(end_user_id, 0, '')
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else: else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages) count_store.save_sessions_count(end_user_id, 1, langchain_messages)
@@ -93,12 +164,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
memory_config: 内存配置对象 memory_config: 内存配置对象
''' '''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time) long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
# Handle case where no session exists in Redis (returns False or empty) format_messages = (long_time_data)
if not long_time_data or long_time_data is False: messages=[]
return memory_config=memory_config.config_id
format_messages = await chat_data_format(long_time_data) for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]: if format_messages!=[]:
await write_messages(end_user_id, format_messages, memory_config) await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断''' '''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
""" """
@@ -109,13 +183,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象 memory_config: 内存配置对象
""" """
try: try:
# 1. 获取历史会话数据(使用新方法) # 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id) result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
# Handle case where no session exists in Redis (returns False or empty) if not result:
if not result or result is False:
history = [] history = []
else: else:
history = await format_parsing(result) history = await format_parsing(result)
@@ -154,7 +227,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
} }
if not structured.is_same_event: if not structured.is_same_event:
logger.info(result_dict) logger.info(result_dict)
await write_messages(end_user_id, output_value, memory_config) await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict return result_dict
except Exception as e: except Exception as e:

View File

@@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'):
role = content['role'] role = content['role']
content = content['content'] content = content['content']
if type == "string": if type == "string":
if role == 'human': if role == 'human' or role=="user":
content = '用户:' + content content = '用户:' + content
else: else:
content = 'AI:' + content content = 'AI:' + content
result.append(content) result.append(content)
if type == "dict": if type == "dict" :
if role == 'human': if role == 'human' or role=="user":
user.append( content) user.append( content)
else: else:
ai.append(content) ai.append(content)
@@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict):
for key, values in zip(user, ai): for key, values in zip(user, ai):
database.append({key, values}) database.append({key, values})
return database return database
async def chat_data_format(messages: list | dict):
"""
将消息格式化为 LangChain 消息格式
Args:
messages: 消息列表或字典
Returns:
LangChain 消息列表
"""
langchain_messages = []
if isinstance(messages, list):
for msg in messages:
if 'role' in msg.keys():
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
if "Query" in msg.keys():
langchain_messages.append(HumanMessage(content=msg['Query']))
langchain_messages.append(AIMessage(content=msg['Answer']))
if isinstance(messages, dict):
if messages['type'] == 'human':
langchain_messages.append(HumanMessage(content=messages['content']))
elif messages['type'] == 'ai':
langchain_messages.append(AIMessage(content=messages['content']))
return langchain_messages
async def agent_chat_messages(user_content,ai_content): async def agent_chat_messages(user_content,ai_content):
messages = [ messages = [

View File

@@ -1,14 +1,19 @@
import asyncio import asyncio
import json
import sys import sys
import warnings import warnings
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from langgraph.constants import END, START from langgraph.constants import END, START
from langgraph.graph import StateGraph from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
@@ -35,75 +40,67 @@ async def make_write_graph():
graph = workflow.compile() graph = workflow.compile()
yield graph yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
"""Dispatch long-term memory storage to Celery background tasks.
Args:
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
langchain_messages: List of messages to store
memory_config: Memory configuration ID (string)
end_user_id: End user identifier
scope: Window size for 'chunk' strategy (default: 6)
"""
from app.tasks import (
long_term_storage_window_task,
# TODO: Uncomment when implemented
# long_term_storage_time_task,
# long_term_storage_aggregate_task,
)
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# Convert config to string if needed
config_id = str(memory_config) if memory_config else ''
if long_term_type == 'chunk':
# Strategy 1: Window-based batching (6 rounds of dialogue)
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
long_term_storage_window_task.delay(
end_user_id=end_user_id,
langchain_messages=langchain_messages,
config_id=config_id,
scope=scope
)
# TODO: Uncomment when time-based strategy is fully implemented
# elif long_term_type == 'time':
# # Strategy 2: Time-based retrieval
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
# long_term_storage_time_task.delay(
# end_user_id=end_user_id,
# config_id=config_id,
# time_window=5
# )
# TODO: Uncomment when aggregate strategy is fully implemented
# elif long_term_type == 'aggregate':
# # Strategy 3: Aggregate judgment (deduplication)
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
# long_term_storage_aggregate_task.delay(
# end_user_id=end_user_id,
# langchain_messages=langchain_messages,
# config_id=config_id
# )
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
with get_db_context() as db_session:
try:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
finally:
if db_session.in_transaction():
db_session.rollback()
db_session.close()
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main(): # async def main():
# """主函数 - 运行工作流""" # """主函数 - 运行工作流"""
# langchain_messages = [ # langchain_messages = [
# { # {
# "role": "user", # "role": "user",
# "content": "今天周五好开心啊" # "content": "今天周五去爬山"
# }, # },
# { # {
# "role": "assistant", # "role": "assistant",
# "content": "你也这么觉得,我也是耶" # "content": "耶"
# } # }
# #
# ] # ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) #
# #
# #
# if __name__ == "__main__": # if __name__ == "__main__":

View File

@@ -1,3 +1,4 @@
from abc import ABC
from typing import Optional from typing import Optional
from pydantic import BaseModel from pydantic import BaseModel
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel): class Write_UserInput(BaseModel):
messages: list[dict] messages: list[dict]
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量"""
STORAGE_NEO4J = "neo4j"
STORAGE_RAG = "rag"
STRATEGY_AGGREGATE = "aggregate"
STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6

View File

@@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
result = task_service.get_task_memory_read_result(task.id) result = task_service.get_task_memory_read_result(task.id)
status = result.get("status") status = result.get("status")
logger.info(f"读取任务状态:{status}") logger.info(f"读取任务状态:{status}")
if memory_content:
memory_content = memory_content['answer']
finally: finally:
db.close() db.close()
@@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
"content_length": len(str(memory_content)) "content_length": len(str(memory_content))
} }
) )
return f"检索到以下历史记忆:\n\n{memory_content}" return f"检索到以下历史记忆:\n\n{memory_content}"
except Exception as e: except Exception as e:
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__}) logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})