Fix/release memory bug (#306)
* memory_BUG_fix * memory_BUG * memory_BUG_long_term * memory_BUG_long_term * memory_BUG_long_term
This commit is contained in:
@@ -7,30 +7,21 @@ LangChain Agent 封装
|
|||||||
- 支持流式输出
|
- 支持流式输出
|
||||||
- 使用 RedBearLLM 支持多提供商
|
- 使用 RedBearLLM 支持多提供商
|
||||||
"""
|
"""
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
|
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -148,106 +139,6 @@ class LangChainAgent:
|
|||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
# TODO: 移到memory module
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
|
||||||
db = next(get_db())
|
|
||||||
#TODO: 魔法数字
|
|
||||||
scope=6
|
|
||||||
|
|
||||||
try:
|
|
||||||
repo = LongTermMemoryRepository(db)
|
|
||||||
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
|
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
|
||||||
|
|
||||||
# Handle case where no session exists in Redis (returns False)
|
|
||||||
if not result or result is False:
|
|
||||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
|
||||||
return
|
|
||||||
|
|
||||||
if type=="chunk" or type=="aggregate":
|
|
||||||
data = await format_parsing(result, "dict")
|
|
||||||
chunk_data = data[:scope]
|
|
||||||
if len(chunk_data)==scope:
|
|
||||||
repo.upsert(end_user_id, chunk_data)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
else:
|
|
||||||
# TODO: This branch handles type="time" strategy, currently unused.
|
|
||||||
# Will be activated when time-based long-term storage is implemented.
|
|
||||||
# TODO: 魔法数字 - extract 5 to a constant
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
|
||||||
if not long_time_data or long_time_data is False:
|
|
||||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
|
||||||
return
|
|
||||||
long_messages = await messages_parse(long_time_data)
|
|
||||||
repo.upsert(end_user_id, long_messages)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
|
||||||
"""
|
|
||||||
写入记忆(支持结构化消息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
user_message: 用户消息内容
|
|
||||||
ai_message: AI 回复内容
|
|
||||||
user_rag_memory_id: RAG 记忆ID
|
|
||||||
actual_end_user_id: 实际用户ID
|
|
||||||
actual_config_id: 配置ID
|
|
||||||
|
|
||||||
逻辑说明:
|
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
|
||||||
- Neo4j 模式:使用结构化消息列表
|
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
|
||||||
"""
|
|
||||||
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
|
||||||
|
|
||||||
if storage_type == "rag":
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
else:
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if user_message:
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if ai_message:
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
|
||||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
actual_config_id, # config_id: 配置ID
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -321,14 +212,7 @@ class LangChainAgent:
|
|||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
|
||||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
|
||||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -459,15 +343,7 @@ class LangChainAgent:
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
|
||||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
|
||||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
|
||||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
@@ -10,46 +11,115 @@ from app.core.memory.agent.utils.redis_tool import write_store
|
|||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context, get_db
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||||
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||||
|
actual_config_id, long_term_messages=[]):
|
||||||
|
"""
|
||||||
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
async def write_messages(end_user_id,langchain_messages,memory_config):
|
Args:
|
||||||
'''
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
写入数据到neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
memory_config: 内存配置对象
|
user_message: 用户消息内容
|
||||||
langchain_messages:原始数据LIST
|
ai_message: AI 回复内容
|
||||||
'''
|
user_rag_memory_id: RAG 记忆ID
|
||||||
|
actual_end_user_id: 实际用户ID
|
||||||
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
|
逻辑说明:
|
||||||
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
|
- Neo4j 模式:使用结构化消息列表
|
||||||
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
|
# Neo4j 模式:使用结构化消息列表
|
||||||
|
structured_messages = []
|
||||||
|
|
||||||
|
# 始终添加用户消息(如果不为空)
|
||||||
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
|
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||||
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
|
structured_messages = long_term_messages
|
||||||
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
|
# 如果是 JSON 字符串,先解析
|
||||||
|
try:
|
||||||
|
structured_messages = json.loads(long_term_messages)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
|
structured_messages, # message: JSON 字符串格式的消息列表
|
||||||
|
str(actual_config_id), # config_id: 配置ID字符串
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||||
|
with get_db_context() as db_session:
|
||||||
|
try:
|
||||||
|
repo = LongTermMemoryRepository(db_session)
|
||||||
|
await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages,
|
||||||
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
|
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
|
data = await format_parsing(result, "dict")
|
||||||
|
chunk_data = data[:scope]
|
||||||
|
if len(chunk_data)==scope:
|
||||||
|
repo.upsert(end_user_id, chunk_data)
|
||||||
|
logger.info(f'---------写入短长期-----------')
|
||||||
|
else:
|
||||||
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
|
long_messages = await messages_parse(long_time_data)
|
||||||
|
repo.upsert(end_user_id, long_messages)
|
||||||
|
logger.info(f'写入短长期:')
|
||||||
|
# yield db_session
|
||||||
|
finally:
|
||||||
|
if db_session.in_transaction():
|
||||||
|
db_session.rollback()
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
async with make_write_graph() as graph:
|
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
|
||||||
# 初始状态 - 包含所有必要字段
|
|
||||||
initial_state = {
|
|
||||||
"messages": langchain_messages,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"memory_config": memory_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取节点更新信息
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j' == node_name:
|
|
||||||
massages = node_data
|
|
||||||
# TODO:删除
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
print(contents)
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
'''根据窗口'''
|
'''根据窗口'''
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||||
'''
|
'''
|
||||||
@@ -61,25 +131,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|||||||
scope:窗口大小
|
scope:窗口大小
|
||||||
'''
|
'''
|
||||||
scope=scope
|
scope=scope
|
||||||
redis_messages = []
|
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
if is_end_user_id is not False:
|
if is_end_user_id is not False:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||||
print(is_end_user_id)
|
|
||||||
is_end_user_id += 1
|
is_end_user_id += 1
|
||||||
langchain_messages += redis_messages
|
langchain_messages += redis_messages
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
print('写入长期记忆,并且设置为0')
|
logger.info('写入长期记忆NEO4J')
|
||||||
print(is_end_user_id)
|
formatted_messages = (redis_messages)
|
||||||
formatted_messages = await chat_data_format(redis_messages)
|
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||||
print(100*'-')
|
if hasattr(memory_config, 'config_id'):
|
||||||
print(formatted_messages)
|
config_id = memory_config.config_id
|
||||||
print(100*'-')
|
else:
|
||||||
await write_messages(end_user_id, formatted_messages, memory_config)
|
config_id = memory_config
|
||||||
count_store.update_sessions_count(end_user_id, 0, '')
|
|
||||||
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
config_id, formatted_messages)
|
||||||
|
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
else:
|
else:
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
|
|
||||||
@@ -93,12 +164,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
|||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
'''
|
'''
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
format_messages = (long_time_data)
|
||||||
if not long_time_data or long_time_data is False:
|
messages=[]
|
||||||
return
|
memory_config=memory_config.config_id
|
||||||
format_messages = await chat_data_format(long_time_data)
|
for i in format_messages:
|
||||||
|
message=json.loads(i['Query'])
|
||||||
|
messages+= message
|
||||||
if format_messages!=[]:
|
if format_messages!=[]:
|
||||||
await write_messages(end_user_id, format_messages, memory_config)
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config, messages)
|
||||||
'''聚合判断'''
|
'''聚合判断'''
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -109,13 +183,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. 获取历史会话数据(使用新方法)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
|
history = await format_parsing(result)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
if not result:
|
||||||
if not result or result is False:
|
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
@@ -154,7 +227,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
}
|
}
|
||||||
if not structured.is_same_event:
|
if not structured.is_same_event:
|
||||||
logger.info(result_dict)
|
logger.info(result_dict)
|
||||||
await write_messages(end_user_id, output_value, memory_config)
|
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config.config_id, output_value)
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'):
|
|||||||
role = content['role']
|
role = content['role']
|
||||||
content = content['content']
|
content = content['content']
|
||||||
if type == "string":
|
if type == "string":
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
content = '用户:' + content
|
content = '用户:' + content
|
||||||
else:
|
else:
|
||||||
content = 'AI:' + content
|
content = 'AI:' + content
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if type == "dict":
|
if type == "dict" :
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
user.append( content)
|
user.append( content)
|
||||||
else:
|
else:
|
||||||
ai.append(content)
|
ai.append(content)
|
||||||
@@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict):
|
|||||||
for key, values in zip(user, ai):
|
for key, values in zip(user, ai):
|
||||||
database.append({key, values})
|
database.append({key, values})
|
||||||
return database
|
return database
|
||||||
async def chat_data_format(messages: list | dict):
|
|
||||||
"""
|
|
||||||
将消息格式化为 LangChain 消息格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表或字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LangChain 消息列表
|
|
||||||
"""
|
|
||||||
langchain_messages = []
|
|
||||||
if isinstance(messages, list):
|
|
||||||
for msg in messages:
|
|
||||||
if 'role' in msg.keys():
|
|
||||||
if msg['role'] == 'user':
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
|
||||||
elif msg['role'] == 'assistant':
|
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
|
||||||
if "Query" in msg.keys():
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['Query']))
|
|
||||||
langchain_messages.append(AIMessage(content=msg['Answer']))
|
|
||||||
if isinstance(messages, dict):
|
|
||||||
if messages['type'] == 'human':
|
|
||||||
langchain_messages.append(HumanMessage(content=messages['content']))
|
|
||||||
elif messages['type'] == 'ai':
|
|
||||||
langchain_messages.append(AIMessage(content=messages['content']))
|
|
||||||
return langchain_messages
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content,ai_content):
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
|
from app.db import get_db, get_db_context
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
@@ -35,75 +40,67 @@ async def make_write_graph():
|
|||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
|
||||||
"""Dispatch long-term memory storage to Celery background tasks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
long_term_type: Storage strategy - 'chunk' (window), 'time', or 'aggregate'
|
|
||||||
langchain_messages: List of messages to store
|
|
||||||
memory_config: Memory configuration ID (string)
|
|
||||||
end_user_id: End user identifier
|
|
||||||
scope: Window size for 'chunk' strategy (default: 6)
|
|
||||||
"""
|
|
||||||
from app.tasks import (
|
|
||||||
long_term_storage_window_task,
|
|
||||||
# TODO: Uncomment when implemented
|
|
||||||
# long_term_storage_time_task,
|
|
||||||
# long_term_storage_aggregate_task,
|
|
||||||
)
|
|
||||||
from app.core.logging_config import get_logger
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
|
||||||
|
|
||||||
# Convert config to string if needed
|
|
||||||
config_id = str(memory_config) if memory_config else ''
|
|
||||||
|
|
||||||
if long_term_type == 'chunk':
|
|
||||||
# Strategy 1: Window-based batching (6 rounds of dialogue)
|
|
||||||
logger.info(f"[LONG_TERM] Dispatching window task - end_user_id={end_user_id}, scope={scope}")
|
|
||||||
long_term_storage_window_task.delay(
|
|
||||||
end_user_id=end_user_id,
|
|
||||||
langchain_messages=langchain_messages,
|
|
||||||
config_id=config_id,
|
|
||||||
scope=scope
|
|
||||||
)
|
|
||||||
# TODO: Uncomment when time-based strategy is fully implemented
|
|
||||||
# elif long_term_type == 'time':
|
|
||||||
# # Strategy 2: Time-based retrieval
|
|
||||||
# logger.info(f"[LONG_TERM] Dispatching time task - end_user_id={end_user_id}")
|
|
||||||
# long_term_storage_time_task.delay(
|
|
||||||
# end_user_id=end_user_id,
|
|
||||||
# config_id=config_id,
|
|
||||||
# time_window=5
|
|
||||||
# )
|
|
||||||
# TODO: Uncomment when aggregate strategy is fully implemented
|
|
||||||
# elif long_term_type == 'aggregate':
|
|
||||||
# # Strategy 3: Aggregate judgment (deduplication)
|
|
||||||
# logger.info(f"[LONG_TERM] Dispatching aggregate task - end_user_id={end_user_id}")
|
|
||||||
# long_term_storage_aggregate_task.delay(
|
|
||||||
# end_user_id=end_user_id,
|
|
||||||
# langchain_messages=langchain_messages,
|
|
||||||
# config_id=config_id
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||||
|
# 获取数据库会话
|
||||||
|
with get_db_context() as db_session:
|
||||||
|
try:
|
||||||
|
config_service = MemoryConfigService(db_session)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=memory_config, # 改为整数
|
||||||
|
service_name="MemoryAgentService"
|
||||||
|
)
|
||||||
|
if long_term_type=='chunk':
|
||||||
|
'''方案一:对话窗口6轮对话'''
|
||||||
|
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||||
|
if long_term_type=='time':
|
||||||
|
"""时间"""
|
||||||
|
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||||
|
if long_term_type=='aggregate':
|
||||||
|
"""方案三:聚合判断"""
|
||||||
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
finally:
|
||||||
|
if db_session.in_transaction():
|
||||||
|
db_session.rollback()
|
||||||
|
db_session.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||||
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
|
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||||
|
else:
|
||||||
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
|
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||||
|
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||||
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||||
|
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||||
|
|
||||||
# async def main():
|
# async def main():
|
||||||
# """主函数 - 运行工作流"""
|
# """主函数 - 运行工作流"""
|
||||||
# langchain_messages = [
|
# langchain_messages = [
|
||||||
# {
|
# {
|
||||||
# "role": "user",
|
# "role": "user",
|
||||||
# "content": "今天周五好开心啊"
|
# "content": "今天周五去爬山"
|
||||||
# },
|
# },
|
||||||
# {
|
# {
|
||||||
# "role": "assistant",
|
# "role": "assistant",
|
||||||
# "content": "你也这么觉得,我也是耶"
|
# "content": "好耶"
|
||||||
# }
|
# }
|
||||||
#
|
#
|
||||||
# ]
|
# ]
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||||
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||||
# result=await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
#
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
|
|||||||
class Write_UserInput(BaseModel):
|
class Write_UserInput(BaseModel):
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
class AgentMemory_Long_Term(ABC):
|
||||||
|
"""长期记忆配置常量"""
|
||||||
|
STORAGE_NEO4J = "neo4j"
|
||||||
|
STORAGE_RAG = "rag"
|
||||||
|
STRATEGY_AGGREGATE = "aggregate"
|
||||||
|
STRATEGY_CHUNK = "chunk"
|
||||||
|
STRATEGY_TIME = "time"
|
||||||
|
DEFAULT_SCOPE = 6
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -110,6 +110,8 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
result = task_service.get_task_memory_read_result(task.id)
|
result = task_service.get_task_memory_read_result(task.id)
|
||||||
status = result.get("status")
|
status = result.get("status")
|
||||||
logger.info(f"读取任务状态:{status}")
|
logger.info(f"读取任务状态:{status}")
|
||||||
|
if memory_content:
|
||||||
|
memory_content = memory_content['answer']
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
@@ -123,7 +125,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str
|
|||||||
"content_length": len(str(memory_content))
|
"content_length": len(str(memory_content))
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
return f"检索到以下历史记忆:\n\n{memory_content}"
|
return f"检索到以下历史记忆:\n\n{memory_content}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
logger.error("长期记忆检索失败", extra={"error": str(e), "error_type": type(e).__name__})
|
||||||
|
|||||||
Reference in New Issue
Block a user