diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 7e0015ae..e4204e83 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,14 +11,17 @@ import os import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence +from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages, format_parsing, messages_parse from app.core.memory.agent.langgraph_graph.write_graph import long_term_storage +from app.core.memory.agent.utils.write_tools import write from app.db import get_db from app.core.logging_config import get_business_logger from app.core.memory.agent.utils.redis_tool import store from app.core.models import RedBearLLM, RedBearModelConfig from app.models.models_model import ModelType from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_agent_service import ( get_end_user_connected_config, ) @@ -148,106 +151,6 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - # TODO: 移到memory module - async def term_memory_save(self,long_term_messages,actual_config_id,end_user_id,type): - db = next(get_db()) - #TODO: 魔法数字 - scope=6 - - try: - repo = LongTermMemoryRepository(db) - await long_term_storage(long_term_type="chunk", langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) - - from app.core.memory.agent.utils.redis_tool import write_store - result = write_store.get_session_by_userid(end_user_id) - - # Handle case where no session exists in Redis (returns False) - if not result or result is False: - logger.debug(f"No existing session in Redis for user {end_user_id}, skipping short-term memory update") - return - - if type=="chunk" or type=="aggregate": - data = await format_parsing(result, "dict") - chunk_data = data[:scope] - if len(chunk_data)==scope: - repo.upsert(end_user_id, chunk_data) - logger.info(f'写入短长期:') - else: - # TODO: This branch handles type="time" strategy, currently unused. - # Will be activated when time-based long-term storage is implemented. - # TODO: 魔法数字 - extract 5 to a constant - long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - logger.debug(f"No recent sessions in Redis for user {end_user_id}") - return - long_messages = await messages_parse(long_time_data) - repo.upsert(end_user_id, long_messages) - logger.info(f'写入短长期:') - finally: - db.close() - - async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): - """ - 写入记忆(支持结构化消息) - - Args: - storage_type: 存储类型 (neo4j/rag) - end_user_id: 终端用户ID - user_message: 用户消息内容 - ai_message: AI 回复内容 - user_rag_memory_id: RAG 记忆ID - actual_end_user_id: 实际用户ID - actual_config_id: 配置ID - - 逻辑说明: - - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - - Neo4j 模式:使用结构化消息列表 - 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] - 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) - 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 - """ - - db = next(get_db()) - try: - actual_config_id=resolve_config_id(actual_config_id, db) - - if storage_type == "rag": - # RAG 模式:组合消息为字符串格式(保持原有逻辑) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - else: - # Neo4j 模式:使用结构化消息列表 - structured_messages = [] - - # 始终添加用户消息(如果不为空) - if user_message: - structured_messages.append({"role": "user", "content": user_message}) - - # 只有当 AI 回复不为空时才添加 assistant 消息 - if ai_message: - structured_messages.append({"role": "assistant", "content": ai_message}) - - # 如果没有消息,直接返回 - if not structured_messages: - logger.warning(f"No messages to write for user {actual_end_user_id}") - return - - logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") - write_id = write_message_task.delay( - actual_end_user_id, # end_user_id: 用户ID - structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] - actual_config_id, # config_id: 配置ID - storage_type, # storage_type: "neo4j" - user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) - ) - logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") - write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') - finally: - db.close() async def chat( self, message: str, @@ -321,14 +224,14 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - long_term_messages=await agent_chat_messages(message_chat,content) - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") + if storage_type == "rag": + await write_rag(end_user_id, message_chat, content, user_rag_memory_id) + else: + long_term_messages=await agent_chat_messages(message_chat,content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await long_term_storage(long_term_type="chunk",langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=2) + '''长期''' + await term_memory_save(long_term_messages,actual_config_id,end_user_id,"chunk") response = { "content": content, "model": self.model_name, @@ -459,14 +362,15 @@ class LangChainAgent: yield total_tokens break if memory_flag: - # TODO: DUPLICATE WRITE - Remove this immediate write once batched write (term_memory_save) is verified stable. - # This writes to Neo4j immediately via Celery task, but term_memory_save also writes to Neo4j - # when the window buffer reaches scope (6 messages). This causes duplicate entities in the graph. - # Recommended: Keep only term_memory_save for batched efficiency, or only self.write for real-time. - long_term_messages = await agent_chat_messages(message_chat, full_content) - await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) - # Batched long-term memory storage (Redis buffer + Neo4j when window full) - await self.term_memory_save(long_term_messages, actual_config_id, end_user_id, "chunk") + if storage_type == AgentMemory_Long_Term.STORAGE_RAG: + await write_rag(end_user_id, message_chat, full_content, user_rag_memory_id) + else: + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + CHUNK=AgentMemory_Long_Term.STRATEGY_CHUNK + SCOPE=AgentMemory_Long_Term.DEFAULT_SCOPE + long_term_messages = await agent_chat_messages(message_chat, full_content) + await long_term_storage(long_term_type=CHUNK,langchain_messages=long_term_messages,memory_config=actual_config_id,end_user_id=end_user_id,scope=SCOPE) + await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK,scope=SCOPE) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index e9de02b6..ab65caa7 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -1,8 +1,9 @@ +import json import os from app.core.logging_config import get_agent_logger -from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format, format_parsing -from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse +from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_ @@ -10,46 +11,111 @@ from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import count_store from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context +from app.db import get_db_context, get_db +from app.repositories.memory_short_repository import LongTermMemoryRepository +from app.schemas.memory_agent_schema import AgentMemory_Long_Term +from app.services.memory_konwledges_server import write_rag +from app.services.task_service import get_task_memory_write_result +from app.tasks import write_message_task +from app.utils.config_utils import resolve_config_id + logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') +async def write_rag(end_user_id, user_message, ai_message, user_rag_memory_id): + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) + logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') +async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, + actual_config_id, long_term_messages=[]): + """ + 写入记忆(支持结构化消息) -async def write_messages(end_user_id,langchain_messages,memory_config): - ''' - 写入数据到neo4j: - Args: + Args: + storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID - memory_config: 内存配置对象 - langchain_messages:原始数据LIST - ''' + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ + + db = next(get_db()) try: + actual_config_id = resolve_config_id(actual_config_id, db) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } + # 始终添加用户消息(如果不为空) + if isinstance(user_message, str) and user_message.strip() != "": + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if isinstance(ai_message, str) and ai_message.strip() != "": + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果提供了 long_term_messages,使用它替代 structured_messages + if long_term_messages and isinstance(long_term_messages, list): + structured_messages = long_term_messages + elif long_term_messages and isinstance(long_term_messages, str): + # 如果是 JSON 字符串,先解析 + try: + structured_messages = json.loads(long_term_messages) + except json.JSONDecodeError: + logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}") + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + logger.info( + f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # end_user_id: 用户ID + structured_messages, # message: JSON 字符串格式的消息列表 + str(actual_config_id), # config_id: 配置ID字符串 + storage_type, # storage_type: "neo4j" + user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") + write_status = get_task_memory_write_result(str(write_id)) + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') + finally: + db.close() + +async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope): + db = next(get_db()) + try: + repo = LongTermMemoryRepository(db) + await long_term_storage(long_term_type=AgentMemory_Long_Term.STRATEGY_CHUNK, langchain_messages=long_term_messages, + memory_config=actual_config_id, end_user_id=end_user_id, scope=scope) + + from app.core.memory.agent.utils.redis_tool import write_store + result = write_store.get_session_by_userid(end_user_id) + if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + data = await format_parsing(result, "dict") + chunk_data = data[:scope] + if len(chunk_data)==scope: + repo.upsert(end_user_id, chunk_data) + logger.info(f'---------写入短长期-----------') + else: + long_time_data = write_store.find_user_recent_sessions(end_user_id, 5) + long_messages = await messages_parse(long_time_data) + repo.upsert(end_user_id, long_messages) + logger.info(f'写入短长期:') + finally: + db.close() - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - # TODO:删除 - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - print(contents) - except Exception as e: - import traceback - traceback.print_exc() '''根据窗口''' async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): ''' @@ -61,25 +127,26 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope): scope:窗口大小 ''' scope=scope - redis_messages = [] is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_id is not False: is_end_user_id = count_store.get_sessions_count(end_user_id)[0] redis_messages = count_store.get_sessions_count(end_user_id)[1] if is_end_user_id and int(is_end_user_id) != int(scope): - print(is_end_user_id) is_end_user_id += 1 langchain_messages += redis_messages count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) elif int(is_end_user_id) == int(scope): - print('写入长期记忆,并且设置为0') - print(is_end_user_id) - formatted_messages = await chat_data_format(redis_messages) - print(100*'-') - print(formatted_messages) - print(100*'-') - await write_messages(end_user_id, formatted_messages, memory_config) - count_store.update_sessions_count(end_user_id, 0, '') + logger.info('写入长期记忆NEO4J') + formatted_messages = (redis_messages) + # 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用) + if hasattr(memory_config, 'config_id'): + config_id = memory_config.config_id + else: + config_id = memory_config + + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + config_id, formatted_messages) + count_store.update_sessions_count(end_user_id, 1, langchain_messages) else: count_store.save_sessions_count(end_user_id, 1, langchain_messages) @@ -93,12 +160,15 @@ async def memory_long_term_storage(end_user_id,memory_config,time): memory_config: 内存配置对象 ''' long_time_data = write_store.find_user_recent_sessions(end_user_id, time) - # Handle case where no session exists in Redis (returns False or empty) - if not long_time_data or long_time_data is False: - return - format_messages = await chat_data_format(long_time_data) + format_messages = (long_time_data) + messages=[] + memory_config=memory_config.config_id + for i in format_messages: + message=json.loads(i['Query']) + messages+= message if format_messages!=[]: - await write_messages(end_user_id, format_messages, memory_config) + await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id, + memory_config, messages) '''聚合判断''' async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict: """ @@ -109,13 +179,12 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] memory_config: 内存配置对象 """ - + try: # 1. 获取历史会话数据(使用新方法) result = write_store.get_all_sessions_by_end_user_id(end_user_id) - - # Handle case where no session exists in Redis (returns False or empty) - if not result or result is False: + history = await format_parsing(result) + if not result: history = [] else: history = await format_parsing(result) @@ -154,7 +223,8 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config } if not structured.is_same_event: logger.info(result_dict) - await write_messages(end_user_id, output_value, memory_config) + await write("neo4j", end_user_id, "", "", None, end_user_id, + memory_config.config_id, output_value) return result_dict except Exception as e: diff --git a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py index a1fb8226..d0be8e5c 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/write_tool.py @@ -26,13 +26,13 @@ async def format_parsing(messages: list,type:str='string'): role = content['role'] content = content['content'] if type == "string": - if role == 'human': + if role == 'human' or role=="user": content = '用户:' + content else: content = 'AI:' + content result.append(content) - if type == "dict": - if role == 'human': + if type == "dict" : + if role == 'human' or role=="user": user.append( content) else: ai.append(content) @@ -57,33 +57,7 @@ async def messages_parse(messages: list | dict): for key, values in zip(user, ai): database.append({key, values}) return database -async def chat_data_format(messages: list | dict): - """ - 将消息格式化为 LangChain 消息格式 - - Args: - messages: 消息列表或字典 - - Returns: - LangChain 消息列表 - """ - langchain_messages = [] - if isinstance(messages, list): - for msg in messages: - if 'role' in msg.keys(): - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - langchain_messages.append(AIMessage(content=msg['content'])) - if "Query" in msg.keys(): - langchain_messages.append(HumanMessage(content=msg['Query'])) - langchain_messages.append(AIMessage(content=msg['Answer'])) - if isinstance(messages, dict): - if messages['type'] == 'human': - langchain_messages.append(HumanMessage(content=messages['content'])) - elif messages['type'] == 'ai': - langchain_messages.append(AIMessage(content=messages['content'])) - return langchain_messages + async def agent_chat_messages(user_content,ai_content): messages = [ diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 9547c866..64a1296c 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -7,7 +7,7 @@ from contextlib import asynccontextmanager from langgraph.constants import END, START from langgraph.graph import StateGraph -from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, chat_data_format, messages_parse +from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState @@ -42,9 +42,8 @@ async def make_write_graph(): async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6): from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment - from app.core.memory.agent.langgraph_graph.tools.write_tool import chat_data_format from app.core.memory.agent.utils.redis_tool import write_store - write_store.save_session_write(end_user_id, await chat_data_format(langchain_messages)) + write_store.save_session_write(end_user_id, (langchain_messages)) # 获取数据库会话 db_session = next(get_db()) config_service = MemoryConfigService(db_session) @@ -62,31 +61,24 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[ """方案三:聚合判断""" await aggregate_judgment(end_user_id, langchain_messages, memory_config) -# + # async def main(): # """主函数 - 运行工作流""" # langchain_messages = [ # { # "role": "user", -# "content": "今天周五好开心啊" +# "content": "今天周五去爬山" # }, # { # "role": "assistant", -# "content": "你也这么觉得,我也是耶" +# "content": "好耶" # } # # ] # end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID # memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# # await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# from app.core.memory.agent.utils.redis_tool import write_store -# result=write_store.get_session_by_userid(end_user_id) -# data=await format_parsing(result,"dict") -# chunk_data=data[:6] +# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) # -# long_time_data = write_store.find_user_recent_sessions(end_user_id, 240) -# long_=await messages_parse(long_time_data) -# print(long_) # # # if __name__ == "__main__": diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b6f50dd7..1a5017eb 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import Optional from pydantic import BaseModel @@ -14,4 +15,15 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): messages: list[dict] end_user_id: str - config_id: Optional[str] = None \ No newline at end of file + config_id: Optional[str] = None + +class AgentMemory_Long_Term(ABC): + """长期记忆配置常量""" + STORAGE_NEO4J = "neo4j" + STORAGE_RAG = "rag" + STRATEGY_AGGREGATE = "aggregate" + STRATEGY_CHUNK = "chunk" + STRATEGY_TIME = "time" + DEFAULT_SCOPE = 6 + +