config_config替换成memory_config
This commit is contained in:
@@ -9,6 +9,7 @@ import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import redis
|
||||
@@ -266,7 +267,7 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -319,85 +320,52 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
print(massages)
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
|
||||
# try:
|
||||
# if storage_type == "rag":
|
||||
# # For RAG storage, convert messages to single string
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
# return result
|
||||
# else:
|
||||
# async with make_write_graph() as graph:
|
||||
# config = {"configurable": {"thread_id": end_user_id}}
|
||||
# # Convert structured messages to LangChain messages
|
||||
# langchain_messages = []
|
||||
# for msg in messages:
|
||||
# if msg['role'] == 'user':
|
||||
# langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
# elif msg['role'] == 'assistant':
|
||||
# langchain_messages.append(AIMessage(content=msg['content']))
|
||||
#
|
||||
# # 初始状态 - 包含所有必要字段
|
||||
# initial_state = {
|
||||
# "messages": langchain_messages,
|
||||
# "end_user_id": end_user_id,
|
||||
# "memory_config": memory_config
|
||||
# }
|
||||
#
|
||||
# # 获取节点更新信息
|
||||
# async for update_event in graph.astream(
|
||||
# initial_state,
|
||||
# stream_mode="updates",
|
||||
# config=config
|
||||
# ):
|
||||
# for node_name, node_data in update_event.items():
|
||||
# if 'save_neo4j' == node_name:
|
||||
# massages = node_data
|
||||
# massagesstatus = massages.get('write_result')['status']
|
||||
# contents = massages.get('write_result')
|
||||
# # Convert messages back to string for logging
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
# except Exception as e:
|
||||
# # Ensure proper error handling and logging
|
||||
# error_msg = f"Write operation failed: {str(e)}"
|
||||
# logger.error(error_msg)
|
||||
# if audit_logger:
|
||||
# duration = time.time() - start_time
|
||||
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
# raise ValueError(error_msg)
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
@@ -408,7 +376,7 @@ class MemoryAgentService:
|
||||
message: str,
|
||||
history: List[Dict],
|
||||
search_switch: str,
|
||||
config_id: Optional[str],
|
||||
config_id: Optional[UUID],
|
||||
db: Session,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str) -> Dict:
|
||||
@@ -685,7 +653,7 @@ class MemoryAgentService:
|
||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||
return user_input.messages
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
async def classify_message_type(self, message: str, config_id: UUID, db: Session) -> Dict:
|
||||
"""
|
||||
Determine the type of user message (read or write)
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
@@ -716,7 +684,7 @@ class MemoryAgentService:
|
||||
retrieve_info: str,
|
||||
history: List[Dict],
|
||||
query: str,
|
||||
config_id: str,
|
||||
config_id: UUID,
|
||||
db: Session
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user