把group_id替换end_user_id

This commit is contained in:
lixinyue
2026-01-21 20:33:22 +08:00
parent 4a4931bee2
commit f0efed8aa1
4 changed files with 100 additions and 53 deletions

View File

@@ -164,7 +164,7 @@ async def write_server(
try:
result = await memory_agent_service.write_memory(
user_input.end_user_id,
user_input.message,
user_input.messages,
config_id,
db,
storage_type,
@@ -290,7 +290,7 @@ async def read_server(
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
@@ -596,7 +596,7 @@ async def status_type(
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
user_input.message,
user_input.messages,
user_input.config_id,
db
)

View File

@@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState:
memory_config=state.get('memory_config', '')
try:
result=await write(
content=content,
end_user_id=end_user_id,
memory_config=memory_config,
messages=content, # 修复:使用正确的参数名 messages
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -77,10 +77,25 @@ async def write(
# Step 1: Load and chunk data
step_start = time.time()
# Convert messages list to content string
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
if isinstance(messages, list) and len(messages) > 0:
# Extract content from the last user message or concatenate all messages
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
content = messages[-1]['content']
else:
# Fallback: concatenate all message contents
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
elif isinstance(messages, str):
content = messages
else:
content = str(messages)
chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy,
end_user_id=end_user_id,
messages=messages,
content=content, # 修复:使用 content 参数而不是 messages
ref_id=ref_id,
config_id=config_id,
)

View File

@@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import (
write_rag,
)
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field
from sqlalchemy import func
@@ -57,7 +58,6 @@ class MemoryAgentService:
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
duration = time.time() - start_time
if str(messages) == 'success':
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作
@@ -266,7 +266,7 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
@@ -319,53 +319,85 @@ class MemoryAgentService:
raise ValueError(error_msg)
try:
if storage_type == "rag":
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
raise ValueError(error_msg)
# 初始状态 - 包含所有必要字段
initial_state = {
"messages": langchain_messages,
"end_user_id": end_user_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name:
massages = node_data
print(massages)
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# try:
# if storage_type == "rag":
# # For RAG storage, convert messages to single string
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
# return result
# else:
# async with make_write_graph() as graph:
# config = {"configurable": {"thread_id": end_user_id}}
# # Convert structured messages to LangChain messages
# langchain_messages = []
# for msg in messages:
# if msg['role'] == 'user':
# langchain_messages.append(HumanMessage(content=msg['content']))
# elif msg['role'] == 'assistant':
# langchain_messages.append(AIMessage(content=msg['content']))
#
# # 初始状态 - 包含所有必要字段
# initial_state = {
# "messages": langchain_messages,
# "end_user_id": end_user_id,
# "memory_config": memory_config
# }
#
# # 获取节点更新信息
# async for update_event in graph.astream(
# initial_state,
# stream_mode="updates",
# config=config
# ):
# for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
# massagesstatus = massages.get('write_result')['status']
# contents = massages.get('write_result')
# # Convert messages back to string for logging
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# except Exception as e:
# # Ensure proper error handling and logging
# error_msg = f"Write operation failed: {str(e)}"
# logger.error(error_msg)
# if audit_logger:
# duration = time.time() - start_time
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# raise ValueError(error_msg)