config_config替换成memory_config

This commit is contained in:
lixinyue
2026-01-22 18:43:22 +08:00
parent f3f9211c9c
commit 8db4f914d8
21 changed files with 158 additions and 201 deletions

View File

@@ -9,6 +9,7 @@ import os
import re
import time
import uuid
from uuid import UUID
from typing import Any, AsyncGenerator, Dict, List, Optional
import redis
@@ -266,7 +267,7 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
@@ -319,85 +320,52 @@ class MemoryAgentService:
raise ValueError(error_msg)
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
print(massages)
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# try:
# if storage_type == "rag":
# # For RAG storage, convert messages to single string
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
# return result
# else:
# async with make_write_graph() as graph:
# config = {"configurable": {"thread_id": end_user_id}}
# # Convert structured messages to LangChain messages
# langchain_messages = []
# for msg in messages:
# if msg['role'] == 'user':
# langchain_messages.append(HumanMessage(content=msg['content']))
# elif msg['role'] == 'assistant':
# langchain_messages.append(AIMessage(content=msg['content']))
#
# # 初始状态 - 包含所有必要字段
# initial_state = {
# "messages": langchain_messages,
# "end_user_id": end_user_id,
# "memory_config": memory_config
# }
#
# # 获取节点更新信息
# async for update_event in graph.astream(
# initial_state,
# stream_mode="updates",
# config=config
# ):
# for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
# massagesstatus = massages.get('write_result')['status']
# contents = massages.get('write_result')
# # Convert messages back to string for logging
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# except Exception as e:
# # Ensure proper error handling and logging
# error_msg = f"Write operation failed: {str(e)}"
# logger.error(error_msg)
# if audit_logger:
# duration = time.time() - start_time
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# raise ValueError(error_msg)
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
@@ -408,7 +376,7 @@ class MemoryAgentService:
message: str,
history: List[Dict],
search_switch: str,
config_id: Optional[str],
config_id: Optional[UUID],
db: Session,
storage_type: str,
user_rag_memory_id: str) -> Dict:
@@ -685,7 +653,7 @@ class MemoryAgentService:
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
"""
Determine the type of user message (read or write)
Updated to eliminate global variables in favor of explicit parameters.
@@ -716,7 +684,7 @@ class MemoryAgentService:
retrieve_info: str,
history: List[Dict],
query: str,
config_id: str,
config_id: UUID,
db: Session
) -> str:
"""