把group_id替换end_user_id
This commit is contained in:
@@ -164,7 +164,7 @@ async def write_server(
|
||||
try:
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.end_user_id,
|
||||
user_input.message,
|
||||
user_input.messages,
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
@@ -290,7 +290,7 @@ async def read_server(
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
@@ -596,7 +596,7 @@ async def status_type(
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
user_input.messages,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
|
||||
@@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
memory_config=state.get('memory_config', '')
|
||||
try:
|
||||
result=await write(
|
||||
content=content,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
messages=content, # 修复:使用正确的参数名 messages
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
|
||||
@@ -77,10 +77,25 @@ async def write(
|
||||
|
||||
# Step 1: Load and chunk data
|
||||
step_start = time.time()
|
||||
|
||||
# Convert messages list to content string
|
||||
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
||||
if isinstance(messages, list) and len(messages) > 0:
|
||||
# Extract content from the last user message or concatenate all messages
|
||||
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
|
||||
content = messages[-1]['content']
|
||||
else:
|
||||
# Fallback: concatenate all message contents
|
||||
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
|
||||
elif isinstance(messages, str):
|
||||
content = messages
|
||||
else:
|
||||
content = str(messages)
|
||||
|
||||
chunked_dialogs = await get_chunked_dialogs(
|
||||
chunker_strategy=chunker_strategy,
|
||||
end_user_id=end_user_id,
|
||||
messages=messages,
|
||||
content=content, # 修复:使用 content 参数而不是 messages
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
@@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import (
|
||||
write_rag,
|
||||
)
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import func
|
||||
@@ -57,7 +58,6 @@ class MemoryAgentService:
|
||||
|
||||
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
|
||||
duration = time.time() - start_time
|
||||
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
@@ -266,7 +266,7 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
@@ -319,53 +319,85 @@ class MemoryAgentService:
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
from langchain_core.messages import AIMessage
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"end_user_id": end_user_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j' == node_name:
|
||||
massages = node_data
|
||||
print(massages)
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
|
||||
# try:
|
||||
# if storage_type == "rag":
|
||||
# # For RAG storage, convert messages to single string
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||
# return result
|
||||
# else:
|
||||
# async with make_write_graph() as graph:
|
||||
# config = {"configurable": {"thread_id": end_user_id}}
|
||||
# # Convert structured messages to LangChain messages
|
||||
# langchain_messages = []
|
||||
# for msg in messages:
|
||||
# if msg['role'] == 'user':
|
||||
# langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
# elif msg['role'] == 'assistant':
|
||||
# langchain_messages.append(AIMessage(content=msg['content']))
|
||||
#
|
||||
# # 初始状态 - 包含所有必要字段
|
||||
# initial_state = {
|
||||
# "messages": langchain_messages,
|
||||
# "end_user_id": end_user_id,
|
||||
# "memory_config": memory_config
|
||||
# }
|
||||
#
|
||||
# # 获取节点更新信息
|
||||
# async for update_event in graph.astream(
|
||||
# initial_state,
|
||||
# stream_mode="updates",
|
||||
# config=config
|
||||
# ):
|
||||
# for node_name, node_data in update_event.items():
|
||||
# if 'save_neo4j' == node_name:
|
||||
# massages = node_data
|
||||
# massagesstatus = massages.get('write_result')['status']
|
||||
# contents = massages.get('write_result')
|
||||
# # Convert messages back to string for logging
|
||||
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||
# except Exception as e:
|
||||
# # Ensure proper error handling and logging
|
||||
# error_msg = f"Write operation failed: {str(e)}"
|
||||
# logger.error(error_msg)
|
||||
# if audit_logger:
|
||||
# duration = time.time() - start_time
|
||||
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||
# raise ValueError(error_msg)
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user