memory_BUG
This commit is contained in:
@@ -11,14 +11,17 @@ import os
|
|||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
|
from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage
|
||||||
|
from app.core.memory.agent.utils.write_tools import write
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.agent.utils.redis_tool import store
|
from app.core.memory.agent.utils.redis_tool import store
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
from app.models.models_model import ModelType
|
from app.models.models_model import ModelType
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_agent_service import (
|
from app.services.memory_agent_service import (
|
||||||
get_end_user_connected_config,
|
get_end_user_connected_config,
|
||||||
)
|
)
|
||||||
@@ -148,106 +151,6 @@ class LangChainAgent:
|
|||||||
messages.append(HumanMessage(content=user_content))
|
messages.append(HumanMessage(content=user_content))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
# TODO: 移到memory module
|
|
||||||
async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type):
|
|
||||||
db = next(get_db())
|
|
||||||
#TODO: 魔法数字
|
|
||||||
scope=6
|
|
||||||
|
|
||||||
try:
|
|
||||||
repo = LongTermMemoryRepository(db)
|
|
||||||
await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages,
|
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
|
||||||
|
|
||||||
# Handle case where no session exists in Redis (returns False)
|
|
||||||
if not result or result is False:
|
|
||||||
logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update")
|
|
||||||
return
|
|
||||||
|
|
||||||
if type=="chunk" or type=="aggregate":
|
|
||||||
data = await format_parsing(result, "dict")
|
|
||||||
chunk_data = data[:scope]
|
|
||||||
if len(chunk_data)==scope:
|
|
||||||
repo.upsert(end_user_id, chunk_data)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
else:
|
|
||||||
# TODO: This branch handles type="time" strategy, currently unused.
|
|
||||||
# Will be activated when time-based long-term storage is implemented.
|
|
||||||
# TODO: 魔法数字 - extract 5 to a constant
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
|
||||||
if not long_time_data or long_time_data is False:
|
|
||||||
logger.debug(f"No recent sessions in Redis for user {end_user_id}")
|
|
||||||
return
|
|
||||||
long_messages = await messages_parse(long_time_data)
|
|
||||||
repo.upsert(end_user_id, long_messages)
|
|
||||||
logger.info(f'写入短长期:')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
|
||||||
"""
|
|
||||||
写入记忆(支持结构化消息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
user_message: 用户消息内容
|
|
||||||
ai_message: AI 回复内容
|
|
||||||
user_rag_memory_id: RAG 记忆ID
|
|
||||||
actual_end_user_id: 实际用户ID
|
|
||||||
actual_config_id: 配置ID
|
|
||||||
|
|
||||||
逻辑说明:
|
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
|
||||||
- Neo4j 模式:使用结构化消息列表
|
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
|
||||||
"""
|
|
||||||
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
|
||||||
|
|
||||||
if storage_type == "rag":
|
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
else:
|
|
||||||
# Neo4j 模式:使用结构化消息列表
|
|
||||||
structured_messages = []
|
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
|
||||||
if user_message:
|
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
|
||||||
if ai_message:
|
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
|
||||||
if not structured_messages:
|
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
|
||||||
write_id = write_message_task.delay(
|
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
|
||||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
|
||||||
actual_config_id, # config_id: 配置ID
|
|
||||||
storage_type, # storage_type: "neo4j"
|
|
||||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
|
||||||
)
|
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
@@ -321,14 +224,14 @@ class LangChainAgent:
|
|||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
long_term_messages=await agent_chat_messages(message_chat,content)
|
if storage_type == "rag":
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
await write_rag(end_user_id, message_chat, content, user_rag_memory_id)
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
else:
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
long_term_messages=await agent_chat_messages(message_chat,content)
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
await long_term_storage(long_term_type="chunk",langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=2)
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
'''长期'''
|
||||||
await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
await term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk")
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -459,14 +362,15 @@ class LangChainAgent:
|
|||||||
yield total_tokens
|
yield total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
if memory_flag:
|
||||||
# TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable.
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
# This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j
|
await write_rag(end_user_id, message_chat, full_content, user_rag_memory_id)
|
||||||
# when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph.
|
else:
|
||||||
# Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time.
|
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||||
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
CHUNK=AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
SCOPE=AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
# Batched long-term memory storage (Redis buffer + Neo4j when window full)
|
long_term_messages = await agent_chat_messages(message_chat, full_content)
|
||||||
await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk")
|
await long_term_storage(long_term_type=CHUNK,langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=SCOPE)
|
||||||
|
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK,scope=SCOPE)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import json
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
@@ -10,46 +11,111 @@ from app.core.memory.agent.utils.redis_tool import write_store
|
|||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context, get_db
|
||||||
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
from app.services.task_service import get_task_memory_write_result
|
||||||
|
from app.tasks import write_message_task
|
||||||
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
async def write_rag(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||||
|
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||||
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
|
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||||
|
actual_config_id, long_term_messages=[]):
|
||||||
|
"""
|
||||||
|
写入记忆(支持结构化消息)
|
||||||
|
|
||||||
async def write_messages(end_user_id,langchain_messages,memory_config):
|
Args:
|
||||||
'''
|
storage_type: 存储类型 (neo4j/rag)
|
||||||
写入数据到neo4j:
|
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
end_user_id: 终端用户ID
|
||||||
memory_config: 内存配置对象
|
user_message: 用户消息内容
|
||||||
langchain_messages:原始数据LIST
|
ai_message: AI 回复内容
|
||||||
'''
|
user_rag_memory_id: RAG 记忆ID
|
||||||
|
actual_end_user_id: 实际用户ID
|
||||||
|
actual_config_id: 配置ID
|
||||||
|
|
||||||
|
逻辑说明:
|
||||||
|
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||||
|
- Neo4j 模式:使用结构化消息列表
|
||||||
|
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||||
|
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||||
|
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
db = next(get_db())
|
||||||
try:
|
try:
|
||||||
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
|
# Neo4j 模式:使用结构化消息列表
|
||||||
|
structured_messages = []
|
||||||
|
|
||||||
async with make_write_graph() as graph:
|
# 始终添加用户消息(如果不为空)
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
# 初始状态 - 包含所有必要字段
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
initial_state = {
|
|
||||||
"messages": langchain_messages,
|
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||||
"end_user_id": end_user_id,
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
"memory_config": memory_config
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
}
|
|
||||||
|
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||||
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
|
structured_messages = long_term_messages
|
||||||
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
|
# 如果是 JSON 字符串,先解析
|
||||||
|
try:
|
||||||
|
structured_messages = json.loads(long_term_messages)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
|
# 如果没有消息,直接返回
|
||||||
|
if not structured_messages:
|
||||||
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
|
write_id = write_message_task.delay(
|
||||||
|
actual_end_user_id, # end_user_id: 用户ID
|
||||||
|
structured_messages, # message: JSON 字符串格式的消息列表
|
||||||
|
str(actual_config_id), # config_id: 配置ID字符串
|
||||||
|
storage_type, # storage_type: "neo4j"
|
||||||
|
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||||
|
)
|
||||||
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||||
|
db = next(get_db())
|
||||||
|
try:
|
||||||
|
repo = LongTermMemoryRepository(db)
|
||||||
|
await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages,
|
||||||
|
memory_config=actual_config_id, end_user_id=end_user_id, scope=scope)
|
||||||
|
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
|
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
|
data = await format_parsing(result, "dict")
|
||||||
|
chunk_data = data[:scope]
|
||||||
|
if len(chunk_data)==scope:
|
||||||
|
repo.upsert(end_user_id, chunk_data)
|
||||||
|
logger.info(f'---------写入短长期-----------')
|
||||||
|
else:
|
||||||
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||||
|
long_messages = await messages_parse(long_time_data)
|
||||||
|
repo.upsert(end_user_id, long_messages)
|
||||||
|
logger.info(f'写入短长期:')
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
# 获取节点更新信息
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j' == node_name:
|
|
||||||
massages = node_data
|
|
||||||
# TODO:删除
|
|
||||||
massagesstatus = massages.get('write_result')['status']
|
|
||||||
contents = massages.get('write_result')
|
|
||||||
print(contents)
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
'''根据窗口'''
|
'''根据窗口'''
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||||
'''
|
'''
|
||||||
@@ -61,25 +127,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|||||||
scope:窗口大小
|
scope:窗口大小
|
||||||
'''
|
'''
|
||||||
scope=scope
|
scope=scope
|
||||||
redis_messages = []
|
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
if is_end_user_id is not False:
|
if is_end_user_id is not False:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||||
print(is_end_user_id)
|
|
||||||
is_end_user_id += 1
|
is_end_user_id += 1
|
||||||
langchain_messages += redis_messages
|
langchain_messages += redis_messages
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
print('写入长期记忆,并且设置为0')
|
logger.info('写入长期记忆NEO4J')
|
||||||
print(is_end_user_id)
|
formatted_messages = (redis_messages)
|
||||||
formatted_messages = await chat_data_format(redis_messages)
|
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||||
print(100*'-')
|
if hasattr(memory_config, 'config_id'):
|
||||||
print(formatted_messages)
|
config_id = memory_config.config_id
|
||||||
print(100*'-')
|
else:
|
||||||
await write_messages(end_user_id, formatted_messages, memory_config)
|
config_id = memory_config
|
||||||
count_store.update_sessions_count(end_user_id, 0, '')
|
|
||||||
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
config_id, formatted_messages)
|
||||||
|
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
else:
|
else:
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
|
|
||||||
@@ -93,12 +160,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time):
|
|||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
'''
|
'''
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
format_messages = (long_time_data)
|
||||||
if not long_time_data or long_time_data is False:
|
messages=[]
|
||||||
return
|
memory_config=memory_config.config_id
|
||||||
format_messages = await chat_data_format(long_time_data)
|
for i in format_messages:
|
||||||
|
message=json.loads(i['Query'])
|
||||||
|
messages+= message
|
||||||
if format_messages!=[]:
|
if format_messages!=[]:
|
||||||
await write_messages(end_user_id, format_messages, memory_config)
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config, messages)
|
||||||
'''聚合判断'''
|
'''聚合判断'''
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -109,13 +179,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
memory_config: 内存配置对象
|
memory_config: 内存配置对象
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. 获取历史会话数据(使用新方法)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
|
history = await format_parsing(result)
|
||||||
# Handle case where no session exists in Redis (returns False or empty)
|
if not result:
|
||||||
if not result or result is False:
|
|
||||||
history = []
|
history = []
|
||||||
else:
|
else:
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
@@ -154,7 +223,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
}
|
}
|
||||||
if not structured.is_same_event:
|
if not structured.is_same_event:
|
||||||
logger.info(result_dict)
|
logger.info(result_dict)
|
||||||
await write_messages(end_user_id, output_value, memory_config)
|
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config.config_id, output_value)
|
||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'):
|
|||||||
role = content['role']
|
role = content['role']
|
||||||
content = content['content']
|
content = content['content']
|
||||||
if type == "string":
|
if type == "string":
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
content = '用户:' + content
|
content = '用户:' + content
|
||||||
else:
|
else:
|
||||||
content = 'AI:' + content
|
content = 'AI:' + content
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if type == "dict":
|
if type == "dict" :
|
||||||
if role == 'human':
|
if role == 'human' or role=="user":
|
||||||
user.append( content)
|
user.append( content)
|
||||||
else:
|
else:
|
||||||
ai.append(content)
|
ai.append(content)
|
||||||
@@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict):
|
|||||||
for key, values in zip(user, ai):
|
for key, values in zip(user, ai):
|
||||||
database.append({key, values})
|
database.append({key, values})
|
||||||
return database
|
return database
|
||||||
async def chat_data_format(messages: list | dict):
|
|
||||||
"""
|
|
||||||
将消息格式化为 LangChain 消息格式
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: 消息列表或字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
LangChain 消息列表
|
|
||||||
"""
|
|
||||||
langchain_messages = []
|
|
||||||
if isinstance(messages, list):
|
|
||||||
for msg in messages:
|
|
||||||
if 'role' in msg.keys():
|
|
||||||
if msg['role'] == 'user':
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
|
||||||
elif msg['role'] == 'assistant':
|
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
|
||||||
if "Query" in msg.keys():
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['Query']))
|
|
||||||
langchain_messages.append(AIMessage(content=msg['Answer']))
|
|
||||||
if isinstance(messages, dict):
|
|
||||||
if messages['type'] == 'human':
|
|
||||||
langchain_messages.append(HumanMessage(content=messages['content']))
|
|
||||||
elif messages['type'] == 'ai':
|
|
||||||
langchain_messages.append(AIMessage(content=messages['content']))
|
|
||||||
return langchain_messages
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content,ai_content):
|
||||||
messages = [
|
messages = [
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from contextlib import asynccontextmanager
|
|||||||
from langgraph.constants import END, START
|
from langgraph.constants import END, START
|
||||||
from langgraph.graph import StateGraph
|
from langgraph.graph import StateGraph
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||||
@@ -42,9 +42,8 @@ async def make_write_graph():
|
|||||||
|
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages))
|
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
db_session = next(get_db())
|
db_session = next(get_db())
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
@@ -62,31 +61,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
|||||||
"""方案三:聚合判断"""
|
"""方案三:聚合判断"""
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
#
|
|
||||||
# async def main():
|
# async def main():
|
||||||
# """主函数 - 运行工作流"""
|
# """主函数 - 运行工作流"""
|
||||||
# langchain_messages = [
|
# langchain_messages = [
|
||||||
# {
|
# {
|
||||||
# "role": "user",
|
# "role": "user",
|
||||||
# "content": "今天周五好开心啊"
|
# "content": "今天周五去爬山"
|
||||||
# },
|
# },
|
||||||
# {
|
# {
|
||||||
# "role": "assistant",
|
# "role": "assistant",
|
||||||
# "content": "你也这么觉得,我也是耶"
|
# "content": "好耶"
|
||||||
# }
|
# }
|
||||||
#
|
#
|
||||||
# ]
|
# ]
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||||
# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||||
# from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
# result=write_store.get_session_by_userid(end_user_id)
|
|
||||||
# data=await format_parsing(result,"dict")
|
|
||||||
# chunk_data=data[:6]
|
|
||||||
#
|
#
|
||||||
# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240)
|
|
||||||
# long_=await messages_parse(long_time_data)
|
|
||||||
# print(long_)
|
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
# if __name__ == "__main__":
|
# if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
from abc import ABC
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -14,4 +15,15 @@ class UserInput(BaseModel):
|
|||||||
class Write_UserInput(BaseModel):
|
class Write_UserInput(BaseModel):
|
||||||
messages: list[dict]
|
messages: list[dict]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
class AgentMemory_Long_Term(ABC):
|
||||||
|
"""长期记忆配置常量"""
|
||||||
|
STORAGE_NEO4J = "neo4j"
|
||||||
|
STORAGE_RAG = "rag"
|
||||||
|
STRATEGY_AGGREGATE = "aggregate"
|
||||||
|
STRATEGY_CHUNK = "chunk"
|
||||||
|
STRATEGY_TIME = "time"
|
||||||
|
DEFAULT_SCOPE = 6
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user