把group_id替换end_user_id

This commit is contained in:
lixinyue
2026-01-21 20:33:22 +08:00
parent 4a4931bee2
commit f0efed8aa1
4 changed files with 100 additions and 53 deletions

View File

@@ -164,7 +164,7 @@ async def write_server(
try: try:
result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
user_input.end_user_id, user_input.end_user_id,
user_input.message, user_input.messages,
config_id, config_id,
db, db,
storage_type, storage_type,
@@ -290,7 +290,7 @@ async def read_server(
) )
if str(user_input.search_switch) == "2": if str(user_input.search_switch) == "2":
retrieve_info = result['answer'] retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
query = user_input.message query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
@@ -596,7 +596,7 @@ async def status_type(
last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type( result = await memory_agent_service.classify_message_type(
user_input.message, user_input.messages,
user_input.config_id, user_input.config_id,
db db
) )

View File

@@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState:
memory_config=state.get('memory_config', '') memory_config=state.get('memory_config', '')
try: try:
result=await write( result=await write(
content=content,
end_user_id=end_user_id, end_user_id=end_user_id,
memory_config=memory_config, memory_config=memory_config,
messages=content, # 修复:使用正确的参数名 messages
) )
logger.info(f"Write completed successfully! Config: {memory_config.config_name}") logger.info(f"Write completed successfully! Config: {memory_config.config_name}")

View File

@@ -77,10 +77,25 @@ async def write(
# Step 1: Load and chunk data # Step 1: Load and chunk data
step_start = time.time() step_start = time.time()
# Convert messages list to content string
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
if isinstance(messages, list) and len(messages) > 0:
# Extract content from the last user message or concatenate all messages
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
content = messages[-1]['content']
else:
# Fallback: concatenate all message contents
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
elif isinstance(messages, str):
content = messages
else:
content = str(messages)
chunked_dialogs = await get_chunked_dialogs( chunked_dialogs = await get_chunked_dialogs(
chunker_strategy=chunker_strategy, chunker_strategy=chunker_strategy,
end_user_id=end_user_id, end_user_id=end_user_id,
messages=messages, content=content, # 修复:使用 content 参数而不是 messages
ref_id=ref_id, ref_id=ref_id,
config_id=config_id, config_id=config_id,
) )

View File

@@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import ( from app.services.memory_konwledges_server import (
write_rag, write_rag,
) )
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import func from sqlalchemy import func
@@ -57,7 +58,6 @@ class MemoryAgentService:
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context): def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
duration = time.time() - start_time duration = time.time() - start_time
if str(messages) == 'success': if str(messages) == 'success':
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
@@ -266,7 +266,7 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources") logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic # LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
""" """
Process write operation with config_id Process write operation with config_id
@@ -319,53 +319,85 @@ class MemoryAgentService:
raise ValueError(error_msg) raise ValueError(error_msg)
try: async with make_write_graph() as graph:
if storage_type == "rag": config = {"configurable": {"thread_id": end_user_id}}
# For RAG storage, convert messages to single string # Convert structured messages to LangChain messages
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) langchain_messages = []
result = await write_rag(end_user_id, message_text, user_rag_memory_id) for msg in messages:
return result if msg['role'] == 'user':
else: langchain_messages.append(HumanMessage(content=msg['content']))
async with make_write_graph() as graph: elif msg['role'] == 'assistant':
config = {"configurable": {"thread_id": end_user_id}} langchain_messages.append(AIMessage(content=msg['content']))
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段 # 初始状态 - 包含所有必要字段
initial_state = { initial_state = {
"messages": langchain_messages, "messages": langchain_messages,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"memory_config": memory_config "memory_config": memory_config
} }
# 获取节点更新信息 # 获取节点更新信息
async for update_event in graph.astream( async for update_event in graph.astream(
initial_state, initial_state,
stream_mode="updates", stream_mode="updates",
config=config config=config
): ):
for node_name, node_data in update_event.items(): for node_name, node_data in update_event.items():
if 'save_neo4j' == node_name: if 'save_neo4j' == node_name:
massages = node_data massages = node_data
massagesstatus = massages.get('write_result')['status'] print(massages)
contents = massages.get('write_result') massagesstatus = massages.get('write_result')['status']
# Convert messages back to string for logging contents = massages.get('write_result')
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) # Convert messages back to string for logging
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
except Exception as e: return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" # try:
logger.error(error_msg) # if storage_type == "rag":
if audit_logger: # # For RAG storage, convert messages to single string
duration = time.time() - start_time # message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) # result = await write_rag(end_user_id, message_text, user_rag_memory_id)
raise ValueError(error_msg) # return result
# else:
# async with make_write_graph() as graph:
# config = {"configurable": {"thread_id": end_user_id}}
# # Convert structured messages to LangChain messages
# langchain_messages = []
# for msg in messages:
# if msg['role'] == 'user':
# langchain_messages.append(HumanMessage(content=msg['content']))
# elif msg['role'] == 'assistant':
# langchain_messages.append(AIMessage(content=msg['content']))
#
# # 初始状态 - 包含所有必要字段
# initial_state = {
# "messages": langchain_messages,
# "end_user_id": end_user_id,
# "memory_config": memory_config
# }
#
# # 获取节点更新信息
# async for update_event in graph.astream(
# initial_state,
# stream_mode="updates",
# config=config
# ):
# for node_name, node_data in update_event.items():
# if 'save_neo4j' == node_name:
# massages = node_data
# massagesstatus = massages.get('write_result')['status']
# contents = massages.get('write_result')
# # Convert messages back to string for logging
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
# except Exception as e:
# # Ensure proper error handling and logging
# error_msg = f"Write operation failed: {str(e)}"
# logger.error(error_msg)
# if audit_logger:
# duration = time.time() - start_time
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
# raise ValueError(error_msg)