diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index a1337085..ad5e3048 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -164,7 +164,7 @@ async def write_server( try: result = await memory_agent_service.write_memory( user_input.end_user_id, - user_input.message, + user_input.messages, config_id, db, storage_type, @@ -290,7 +290,7 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) query = user_input.message # 调用 memory_agent_service 的方法生成最终答案 @@ -596,7 +596,7 @@ async def status_type( last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) result = await memory_agent_service.classify_message_type( - user_input.message, + user_input.messages, user_input.config_id, db ) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index e2a61045..1dab1b0a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState: memory_config=state.get('memory_config', '') try: result=await write( - content=content, end_user_id=end_user_id, memory_config=memory_config, + messages=content, # 修复:使用正确的参数名 messages ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index ce55286e..b8bc58eb 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -77,10 +77,25 @@ async def write( # Step 1: Load and chunk data step_start = time.time() + + # Convert messages list to content string + # messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...] + if isinstance(messages, list) and len(messages) > 0: + # Extract content from the last user message or concatenate all messages + if isinstance(messages[-1], dict) and 'content' in messages[-1]: + content = messages[-1]['content'] + else: + # Fallback: concatenate all message contents + content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)]) + elif isinstance(messages, str): + content = messages + else: + content = str(messages) + chunked_dialogs = await get_chunked_dialogs( chunker_strategy=chunker_strategy, end_user_id=end_user_id, - messages=messages, + content=content, # 修复:使用 content 参数而不是 messages ref_id=ref_id, config_id=config_id, ) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index e475bef0..8a0d5a39 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import AIMessage from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func @@ -57,7 +58,6 @@ class MemoryAgentService: def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context): duration = time.time() - start_time - if str(messages) == 'success': logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 @@ -266,7 +266,7 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id @@ -319,53 +319,85 @@ class MemoryAgentService: raise ValueError(error_msg) - try: - if storage_type == "rag": - # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result - else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - from langchain_core.messages import AIMessage - langchain_messages.append(AIMessage(content=msg['content'])) - - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": memory_config - } + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": end_user_id}} + # Convert structured messages to LangChain messages + langchain_messages = [] + for msg in messages: + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + langchain_messages.append(AIMessage(content=msg['content'])) - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) - except Exception as e: - # Ensure proper error handling and logging - error_msg = f"Write operation failed: {str(e)}" - logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) - raise ValueError(error_msg) + # 初始状态 - 包含所有必要字段 + initial_state = { + "messages": langchain_messages, + "end_user_id": end_user_id, + "memory_config": memory_config + } + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + print(massages) + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + # Convert messages back to string for logging + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + + # try: + # if storage_type == "rag": + # # For RAG storage, convert messages to single string + # message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + # result = await write_rag(end_user_id, message_text, user_rag_memory_id) + # return result + # else: + # async with make_write_graph() as graph: + # config = {"configurable": {"thread_id": end_user_id}} + # # Convert structured messages to LangChain messages + # langchain_messages = [] + # for msg in messages: + # if msg['role'] == 'user': + # langchain_messages.append(HumanMessage(content=msg['content'])) + # elif msg['role'] == 'assistant': + # langchain_messages.append(AIMessage(content=msg['content'])) + # + # # 初始状态 - 包含所有必要字段 + # initial_state = { + # "messages": langchain_messages, + # "end_user_id": end_user_id, + # "memory_config": memory_config + # } + # + # # 获取节点更新信息 + # async for update_event in graph.astream( + # initial_state, + # stream_mode="updates", + # config=config + # ): + # for node_name, node_data in update_event.items(): + # if 'save_neo4j' == node_name: + # massages = node_data + # massagesstatus = massages.get('write_result')['status'] + # contents = massages.get('write_result') + # # Convert messages back to string for logging + # message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + # return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + # except Exception as e: + # # Ensure proper error handling and logging + # error_msg = f"Write operation failed: {str(e)}" + # logger.error(error_msg) + # if audit_logger: + # duration = time.time() - start_time + # audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + # raise ValueError(error_msg)