把group_id替换end_user_id
This commit is contained in:
@@ -164,7 +164,7 @@ async def write_server(
|
|||||||
try:
|
try:
|
||||||
result = await memory_agent_service.write_memory(
|
result = await memory_agent_service.write_memory(
|
||||||
user_input.end_user_id,
|
user_input.end_user_id,
|
||||||
user_input.message,
|
user_input.messages,
|
||||||
config_id,
|
config_id,
|
||||||
db,
|
db,
|
||||||
storage_type,
|
storage_type,
|
||||||
@@ -290,7 +290,7 @@ async def read_server(
|
|||||||
)
|
)
|
||||||
if str(user_input.search_switch) == "2":
|
if str(user_input.search_switch) == "2":
|
||||||
retrieve_info = result['answer']
|
retrieve_info = result['answer']
|
||||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id)
|
||||||
query = user_input.message
|
query = user_input.message
|
||||||
|
|
||||||
# 调用 memory_agent_service 的方法生成最终答案
|
# 调用 memory_agent_service 的方法生成最终答案
|
||||||
@@ -596,7 +596,7 @@ async def status_type(
|
|||||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||||
|
|
||||||
result = await memory_agent_service.classify_message_type(
|
result = await memory_agent_service.classify_message_type(
|
||||||
user_input.message,
|
user_input.messages,
|
||||||
user_input.config_id,
|
user_input.config_id,
|
||||||
db
|
db
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState:
|
|||||||
memory_config=state.get('memory_config', '')
|
memory_config=state.get('memory_config', '')
|
||||||
try:
|
try:
|
||||||
result=await write(
|
result=await write(
|
||||||
content=content,
|
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
memory_config=memory_config,
|
memory_config=memory_config,
|
||||||
|
messages=content, # 修复:使用正确的参数名 messages
|
||||||
)
|
)
|
||||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||||
|
|
||||||
|
|||||||
@@ -77,10 +77,25 @@ async def write(
|
|||||||
|
|
||||||
# Step 1: Load and chunk data
|
# Step 1: Load and chunk data
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
|
|
||||||
|
# Convert messages list to content string
|
||||||
|
# messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
|
||||||
|
if isinstance(messages, list) and len(messages) > 0:
|
||||||
|
# Extract content from the last user message or concatenate all messages
|
||||||
|
if isinstance(messages[-1], dict) and 'content' in messages[-1]:
|
||||||
|
content = messages[-1]['content']
|
||||||
|
else:
|
||||||
|
# Fallback: concatenate all message contents
|
||||||
|
content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)])
|
||||||
|
elif isinstance(messages, str):
|
||||||
|
content = messages
|
||||||
|
else:
|
||||||
|
content = str(messages)
|
||||||
|
|
||||||
chunked_dialogs = await get_chunked_dialogs(
|
chunked_dialogs = await get_chunked_dialogs(
|
||||||
chunker_strategy=chunker_strategy,
|
chunker_strategy=chunker_strategy,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages=messages,
|
content=content, # 修复:使用 content 参数而不是 messages
|
||||||
ref_id=ref_id,
|
ref_id=ref_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService
|
|||||||
from app.services.memory_konwledges_server import (
|
from app.services.memory_konwledges_server import (
|
||||||
write_rag,
|
write_rag,
|
||||||
)
|
)
|
||||||
|
from langchain_core.messages import AIMessage
|
||||||
from langchain_core.messages import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
@@ -57,7 +58,6 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
|
def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context):
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
|
|
||||||
if str(messages) == 'success':
|
if str(messages) == 'success':
|
||||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
@@ -266,7 +266,7 @@ class MemoryAgentService:
|
|||||||
logger.info("Log streaming completed, cleaning up resources")
|
logger.info("Log streaming completed, cleaning up resources")
|
||||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||||
|
|
||||||
async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
Process write operation with config_id
|
Process write operation with config_id
|
||||||
|
|
||||||
@@ -319,53 +319,85 @@ class MemoryAgentService:
|
|||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
try:
|
async with make_write_graph() as graph:
|
||||||
if storage_type == "rag":
|
config = {"configurable": {"thread_id": end_user_id}}
|
||||||
# For RAG storage, convert messages to single string
|
# Convert structured messages to LangChain messages
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
langchain_messages = []
|
||||||
result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
for msg in messages:
|
||||||
return result
|
if msg['role'] == 'user':
|
||||||
else:
|
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
async with make_write_graph() as graph:
|
elif msg['role'] == 'assistant':
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
# Convert structured messages to LangChain messages
|
|
||||||
langchain_messages = []
|
|
||||||
for msg in messages:
|
|
||||||
if msg['role'] == 'user':
|
|
||||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
|
||||||
elif msg['role'] == 'assistant':
|
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
langchain_messages.append(AIMessage(content=msg['content']))
|
|
||||||
|
|
||||||
# 初始状态 - 包含所有必要字段
|
|
||||||
initial_state = {
|
|
||||||
"messages": langchain_messages,
|
|
||||||
"end_user_id": end_user_id,
|
|
||||||
"memory_config": memory_config
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取节点更新信息
|
# 初始状态 - 包含所有必要字段
|
||||||
async for update_event in graph.astream(
|
initial_state = {
|
||||||
initial_state,
|
"messages": langchain_messages,
|
||||||
stream_mode="updates",
|
"end_user_id": end_user_id,
|
||||||
config=config
|
"memory_config": memory_config
|
||||||
):
|
}
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
if 'save_neo4j' == node_name:
|
# 获取节点更新信息
|
||||||
massages = node_data
|
async for update_event in graph.astream(
|
||||||
massagesstatus = massages.get('write_result')['status']
|
initial_state,
|
||||||
contents = massages.get('write_result')
|
stream_mode="updates",
|
||||||
# Convert messages back to string for logging
|
config=config
|
||||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
):
|
||||||
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
for node_name, node_data in update_event.items():
|
||||||
except Exception as e:
|
if 'save_neo4j' == node_name:
|
||||||
# Ensure proper error handling and logging
|
massages = node_data
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
print(massages)
|
||||||
logger.error(error_msg)
|
massagesstatus = massages.get('write_result')['status']
|
||||||
if audit_logger:
|
contents = massages.get('write_result')
|
||||||
duration = time.time() - start_time
|
# Convert messages back to string for logging
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
raise ValueError(error_msg)
|
return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# if storage_type == "rag":
|
||||||
|
# # For RAG storage, convert messages to single string
|
||||||
|
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
|
# result = await write_rag(end_user_id, message_text, user_rag_memory_id)
|
||||||
|
# return result
|
||||||
|
# else:
|
||||||
|
# async with make_write_graph() as graph:
|
||||||
|
# config = {"configurable": {"thread_id": end_user_id}}
|
||||||
|
# # Convert structured messages to LangChain messages
|
||||||
|
# langchain_messages = []
|
||||||
|
# for msg in messages:
|
||||||
|
# if msg['role'] == 'user':
|
||||||
|
# langchain_messages.append(HumanMessage(content=msg['content']))
|
||||||
|
# elif msg['role'] == 'assistant':
|
||||||
|
# langchain_messages.append(AIMessage(content=msg['content']))
|
||||||
|
#
|
||||||
|
# # 初始状态 - 包含所有必要字段
|
||||||
|
# initial_state = {
|
||||||
|
# "messages": langchain_messages,
|
||||||
|
# "end_user_id": end_user_id,
|
||||||
|
# "memory_config": memory_config
|
||||||
|
# }
|
||||||
|
#
|
||||||
|
# # 获取节点更新信息
|
||||||
|
# async for update_event in graph.astream(
|
||||||
|
# initial_state,
|
||||||
|
# stream_mode="updates",
|
||||||
|
# config=config
|
||||||
|
# ):
|
||||||
|
# for node_name, node_data in update_event.items():
|
||||||
|
# if 'save_neo4j' == node_name:
|
||||||
|
# massages = node_data
|
||||||
|
# massagesstatus = massages.get('write_result')['status']
|
||||||
|
# contents = massages.get('write_result')
|
||||||
|
# # Convert messages back to string for logging
|
||||||
|
# message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||||
|
# return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents)
|
||||||
|
# except Exception as e:
|
||||||
|
# # Ensure proper error handling and logging
|
||||||
|
# error_msg = f"Write operation failed: {str(e)}"
|
||||||
|
# logger.error(error_msg)
|
||||||
|
# if audit_logger:
|
||||||
|
# duration = time.time() - start_time
|
||||||
|
# audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
|
||||||
|
# raise ValueError(error_msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user