Merge remote-tracking branch 'origin/develop' into develop

This commit is contained in:
lixinyue
2026-01-21 16:04:56 +08:00
17 changed files with 541 additions and 330 deletions

View File

@@ -160,9 +160,12 @@ async def write_server(
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.group_id,
user_input.message,
messages_list, # 传递结构化消息列表
config_id,
db,
storage_type,
@@ -219,9 +222,12 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
@@ -564,8 +570,23 @@ async def status_type(
"""
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
user_input.message,
last_user_message,
user_input.config_id,
db
)

View File

@@ -145,44 +145,98 @@ class LangChainAgent:
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
session_id = store.save_session(
userid=end_user_end,
messages=messages,
apply_id=end_user_end,
group_id=end_user_end,
aimessages=aimessages
)
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_save(self,messages,end_user_end,aimessages):
# '''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
# end_user_end=f"Term_{end_user_end}"
# print(messages)
# print(aimessages)
# session_id = store.save_session(
# userid=end_user_end,
# messages=messages,
# apply_id=end_user_end,
# group_id=end_user_end,
# aimessages=aimessages
# )
# store.delete_duplicate_sessions()
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
# return session_id
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# async def term_memory_redis_read(self,end_user_end):
# end_user_end = f"Term_{end_user_end}"
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
# messagss_list=[]
# retrieved_content=[]
# for messages in history:
# query = messages.get("Query")
# aimessages = messages.get("Answer")
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
# retrieved_content.append({query: aimessages})
# return messagss_list,retrieved_content
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
user_rag_memory_id)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if user_message:
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if ai_message:
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
# 调用 Celery 任务,传递结构化消息列表
# 数据流:
# 1. structured_messages 传递给 write_message_task
# 2. write_message_task 调用 memory_agent_service.write_memory
# 3. write_memory 调用 write_tools.write传递 messages 参数
# 4. write_tools.write 调用 get_chunked_dialogs传递 messages 参数
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk设置 speaker 字段
# 6. 每个 Chunk 保存到 Neo4j包含 speaker 字段
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # group_id: 用户ID
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
actual_config_id, # config_id: 配置ID
storage_type, # storage_type: "neo4j"
user_rag_memory_id # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def chat(
self,
@@ -227,29 +281,30 @@ class LangChainAgent:
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# db_for_memory = next(get_db())
# if memory_flag:
# if len(history_term_memory)>=4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# print(retrieved_content)
# # 为长期记忆操作获取新的数据库连接
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# except Exception as e:
# logger.error(f"Failed to write to LongTermMemory: {e}")
# raise
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
# # 长期记忆写入(
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -277,8 +332,10 @@ class LangChainAgent:
elapsed_time = time.time() - start_time
if memory_flag:
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
await self.term_memory_save(message_chat,end_user_id,content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, content)
response = {
"content": content,
"model": self.model_name,
@@ -346,27 +403,27 @@ class LangChainAgent:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# # TODO 乐力齐
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
# history_term_memory = history_term_memory_result[0]
# if memory_flag:
# if len(history_term_memory) >= 4 and storage_type != "rag":
# history_term_memory = ';'.join(history_term_memory)
# retrieved_content = history_term_memory_result[1]
# db_for_memory = next(get_db())
# try:
# repo = LongTermMemoryRepository(db_for_memory)
# repo.upsert(end_user_id, retrieved_content)
# logger.info(
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
# # 长期记忆写入
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
# except Exception as e:
# logger.error(f"Failed to write to long term memory: {e}")
# finally:
# db_for_memory.close()
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try:
# 准备消息列表
messages = self._prepare_messages(message, history, context)
@@ -418,8 +475,10 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
if memory_flag:
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
await self.term_memory_save(message_chat, end_user_id, full_content)
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
# await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)

View File

@@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState:
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
state: WriteState containing messages, group_id, and memory_config
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
dict: Contains 'write_result' with status and data fields
"""
content=state.get('data','')
group_id=state.get('group_id','')
memory_config=state.get('memory_config', '')
messages = state.get('messages', [])
group_id = state.get('group_id', '')
memory_config = state.get('memory_config', '')
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result=await write(
content=content,
result = await write(
messages=structured_messages,
user_id=group_id,
apply_id=group_id,
group_id=group_id,
@@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState:
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
write_result= {
write_result = {
"status": "success",
"data": content,
"data": structured_messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
return {"write_result":write_result}
return {"write_result": write_result}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
write_result= {
write_result = {
"status": "error",
"message": str(e),
}

View File

@@ -14,7 +14,6 @@ from app.db import get_db
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
from app.services.memory_config_service import MemoryConfigService
warnings.filterwarnings("ignore", category=RuntimeWarning)
@@ -27,18 +26,12 @@ async def make_write_graph():
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
The workflow directly processes messages from the initial state
and saves them to Neo4j storage.
"""
workflow = StateGraph(WriteState)
workflow.add_node("content_input", content_input_write)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()

View File

@@ -12,32 +12,49 @@ async def get_chunked_dialogs(
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
content: str = "这是用户的输入",
messages: list = None,
ref_id: str = "wyl_20251027",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from all test data entries using the specified chunker strategy.
"""Generate chunks from structured messages using the specified chunker strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
content: Dialog content
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference identifier
config_id: Configuration ID for processing
Returns:
List of DialogData objects with generated chunks for each test entry
List of DialogData objects with generated chunks
"""
dialog_data_list = []
messages = []
messages.append(ConversationMessage(role="用户", msg=content))
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
conversation_messages = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
@@ -46,25 +63,11 @@ async def get_chunked_dialogs(
apply_id=apply_id,
config_id=config_id
)
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
dialog_data_list.append(dialog_data)
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list
return [dialog_data]

View File

@@ -29,25 +29,22 @@ logger = get_agent_logger(__name__)
async def write(
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
messages: list,
ref_id: str = "wyl20251027",
) -> None:
"""
Execute the complete knowledge extraction pipeline.
Only MemoryConfig is needed - LLM and embedding clients are constructed
internally from the config.
Args:
content: Dialogue content to process
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "wyl20251027"
"""
# Extract config values
@@ -89,7 +86,7 @@ async def write(
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
content=content,
messages=messages,
ref_id=ref_id,
config_id=config_id,
)

View File

@@ -4,6 +4,7 @@ import os
import asyncio
import json
import numpy as np
import logging
# Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
try:
from app.core.memory.llm_tools.openai_client import OpenAIClient
except Exception:
# 在测试或无可用依赖(如 langfuse环境下允许惰性导入
OpenAIClient = Any
# Initialize logger
logger = logging.getLogger(__name__)
class LLMChunker:
"""基于LLM的智能分块策略"""
"""LLM-based intelligent chunking strategy"""
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
self.llm_client = llm_client
self.chunk_size = chunk_size
async def __call__(self, text: str) -> List[Any]:
# 使用LLM分析文本结构并进行智能分块
prompt = f"""
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
请以JSON格式返回结果包含chunks数组每个chunk有text字段。
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
Return results in JSON format with a chunks array, each chunk having a text field.
文本内容:
Text content:
{text[:5000]}
"""
messages = [
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
{"role": "user", "content": prompt}
]
@@ -171,8 +173,6 @@ class ChunkerClient:
base_chunk_size=self.chunk_size,
)
elif chunker_config.chunker_strategy == "SentenceChunker":
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
# 为了兼容不同版本,这里仅传递广泛支持的参数
self.chunker = SentenceChunker(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
@@ -186,100 +186,93 @@ class ChunkerClient:
async def generate_chunks(self, dialogue: DialogData):
"""
生成分块,支持异步操作
Generate chunks following 1 Message = 1 Chunk strategy.
Each message creates one chunk, directly inheriting role information.
If a message is too long, it will be split into multiple sub-chunks,
each maintaining the same speaker.
Raises:
ValueError: If dialogue has no messages or chunking fails
"""
try:
# 预处理文本:确保对话标记格式统一
content = dialogue.content
content = content.replace('AI', 'AI:').replace('用户:', '用户:') # 统一冒号
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
# 同步分块器
chunks = self.chunker(content)
# Validate dialogue has messages
if not dialogue.context or not dialogue.context.msgs:
raise ValueError(
f"Dialogue {dialogue.ref_id} has no messages. "
f"Cannot generate chunks from empty dialogue."
)
dialogue.chunks = []
# 按消息分块:每个消息创建一个或多个 chunk直接继承角色
for msg_idx, msg in enumerate(dialogue.context.msgs):
# Validate message has required attributes
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
raise ValueError(
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
f"missing 'role' or 'msg' attribute"
)
msg_content = msg.msg.strip()
# Skip empty messages
if not msg_content:
continue
# 如果消息太长,可以进一步分块
if len(msg_content) > self.chunk_size:
# 对单个消息的内容进行分块
try:
sub_chunks = self.chunker(msg_content)
except Exception as e:
raise ValueError(
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
)
for idx, sub_chunk in enumerate(sub_chunks):
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
sub_chunk_text = sub_chunk_text.strip()
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
continue
chunk = Chunk(
content=f"{msg.role}: {sub_chunk_text}",
speaker=msg.role, # 直接继承角色
metadata={
"message_index": msg_idx,
"message_role": msg.role,
"sub_chunk_index": idx,
"total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
dialogue.chunks.append(chunk)
else:
# 异步分块器如LLMChunker
chunks = await self.chunker(content)
# 过滤空块和过小的块
valid_chunks = []
for c in chunks:
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
valid_chunks.append(c)
dialogue.chunks = [
Chunk(
content=c.text if hasattr(c, 'text') else str(c),
# 消息不长,直接作为一个 chunk
chunk = Chunk(
content=f"{msg.role}: {msg_content}",
speaker=msg.role, # 直接继承角色
metadata={
"start_index": getattr(c, "start_index", None),
"end_index": getattr(c, "end_index", None),
"message_index": msg_idx,
"message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy,
},
)
for c in valid_chunks
]
return dialogue
except Exception as e:
print(f"分块失败: {e}")
# 改进的后备方案:尝试按对话回合分割
try:
# 简单的按对话分割
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
class SimpleChunk:
def __init__(self, text, start_index, end_index):
self.text = text
self.start_index = start_index
self.end_index = end_index
chunks = []
current_chunk = ""
current_start = 0
for match in matches:
speaker, ct = match[0], match[1].strip()
turn_text = f"{speaker} {ct}"
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
current_chunk = turn_text
current_start = dialogue.content.find(turn_text, current_start)
else:
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
if current_chunk:
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
dialogue.chunks = [
Chunk(
content=c.text,
metadata={
"start_index": c.start_index,
"end_index": c.end_index,
"chunker_strategy": "DialogueTurnFallback",
},
)
for c in chunks
]
except Exception:
# 最后的手段:单一大块
dialogue.chunks = [Chunk(
content=dialogue.content,
metadata={"chunker_strategy": "SingleChunkFallback"},
)]
return dialogue
dialogue.chunks.append(chunk)
# Validate we generated at least one chunk
if not dialogue.chunks:
raise ValueError(
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
f"All messages were either empty or too short. "
f"Messages count: {len(dialogue.context.msgs)}"
)
return dialogue
def evaluate_chunking(self, dialogue: DialogData) -> dict:
"""
评估分块质量
"""
"""Evaluate chunking quality."""
if not getattr(dialogue, 'chunks', None):
return {}
@@ -304,11 +297,8 @@ class ChunkerClient:
return metrics
def save_chunking_results(self, dialogue: DialogData, output_path: str):
"""
保存分块结果到文件,文件名包含策略名称
"""
"""Save chunking results to file with strategy name in filename."""
strategy_name = self.chunker_config.chunker_strategy
# 在文件名中添加策略名称
base_name, ext = os.path.splitext(output_path)
strategy_output_path = f"{base_name}_{strategy_name}{ext}"

View File

@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
config["callbacks"] = [self.langfuse_handler]
response = await chain.ainvoke({"messages": messages}, config=config)
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
return response
except Exception as e:
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
config=config
)
logger.debug(f"使用 PydanticOutputParser 解析成功")
return parsed
except Exception as e:
logger.warning(
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
)
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
# 方法 2: 使用 LangChain 的 with_structured_output
template = """{question}"""
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
# 验证并返回结果
try:
return response_model.model_validate(parsed)
result = response_model.model_validate(parsed)
return result
except Exception:
# 如果已经是 Pydantic 实例,直接返回
if hasattr(parsed, "model_dump"):
return parsed
# 尝试从 JSON 解析
return response_model.model_validate_json(json.dumps(parsed))
result = response_model.model_validate_json(json.dumps(parsed))
return result
else:
logger.warning("with_structured_output 方法不可用")
except Exception as e:
logger.error(f"结构化输出失败: {e}")

View File

@@ -224,6 +224,7 @@ class StatementNode(Node):
chunk_id: ID of the parent chunk this statement belongs to
stmt_type: Type of the statement (from ontology)
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
emotion_target: Optional emotion target (person or object name)
emotion_subject: Optional emotion subject (self/other/object)
@@ -249,6 +250,12 @@ class StatementNode(Node):
stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content")
# Speaker identification
speaker: Optional[str] = Field(
None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
)
# Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field(
None,

View File

@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
"""Represents a single message in a conversation.
Attributes:
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
msg: Text content of the message
"""
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.")
@@ -57,6 +57,7 @@ class Statement(BaseModel):
chunk_id: ID of the parent chunk this statement belongs to
group_id: Optional group ID for multi-tenancy
statement: The actual statement text content
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
statement_embedding: Optional embedding vector for the statement
stmt_type: Type of the statement (from ontology)
temporal_info: Temporal information extracted from the statement
@@ -74,6 +75,7 @@ class Statement(BaseModel):
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
statement: str = Field(..., description="The text content of the statement.")
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
stmt_type: StatementType = Field(..., description="The type of the statement.")
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
Attributes:
id: Unique identifier for the chunk
text: List of messages in the chunk
content: The content of the chunk as a formatted string
speaker: The speaker/role for this chunk (user/assistant)
statements: List of statements extracted from this chunk
chunk_embedding: Optional embedding vector for the chunk
metadata: Additional metadata as key-value pairs
"""
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a list of messages.
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
"""Create a chunk from a single message (1 Message = 1 Chunk).
Args:
messages: List of conversation messages
message: Single conversation message
metadata: Optional metadata dictionary
Returns:
Chunk instance with formatted content
Chunk instance with speaker directly from message.role
"""
if metadata is None:
metadata = {}
# Generate content from messages
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
return cls(text=messages, content=content, metadata=metadata)
return cls(
content=f"{message.role}: {message.msg}",
speaker=message.role,
metadata=metadata or {}
)
class DialogData(BaseModel):
"""Represents the complete data structure for a dialog record.

View File

@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
self, dialog_data_list: List[DialogData]
) -> List[Dict[str, Any]]:
"""
从对话中提取情绪信息(优化版:全局陈述句级并行)
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
Args:
dialog_data_list: 对话数据列表
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
Returns:
情绪信息映射列表,每个对话对应一个字典
"""
logger.info("开始情绪信息提取(全局陈述句级并行")
logger.info("开始情绪信息提取(仅处理用户消息")
# 收集所有陈述句及其配置
all_statements = []
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
if not data_config or not data_config.emotion_enabled:
logger.info("情绪提取未启用,跳过")
return [{} for _ in dialog_data_list]
# 收集所有陈述句(只收集 speaker 为 "user" 的)
total_statements = 0
filtered_statements = 0
# 收集所有陈述句
for d_idx, dialog in enumerate(dialog_data_list):
for chunk in dialog.chunks:
for statement in chunk.statements:
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
total_statements += 1
# 只处理用户的陈述句 (role 为 "user")
if hasattr(statement, 'speaker') and statement.speaker == "user":
all_statements.append((statement, data_config))
statement_metadata.append((d_idx, statement.id))
filtered_statements += 1
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
# 初始化情绪提取服务
from app.services.emotion_extraction_service import EmotionExtractionService
@@ -1033,6 +1040,7 @@ class ExtractionOrchestrator:
apply_id=dialog_data.apply_id,
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
statement=statement.statement,
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
statement_embedding=statement.statement_embedding,
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,

View File

@@ -22,12 +22,12 @@ class DialogueChunker:
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
"""
self.chunker_strategy = chunker_strategy
chunker_config_dict = get_chunker_config(chunker_strategy)
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
# 对于 LLMChunker需要传入 llm_client
if self.chunker_config.chunker_strategy == "LLMChunker":
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
else:
@@ -41,29 +41,19 @@ class DialogueChunker:
Returns:
A list of Chunk objects
Raises:
ValueError: If chunking fails or returns empty chunks
"""
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
# Defensive fallback: ensure at least one chunk is returned for non-empty content
try:
chunks = result_dialogue.chunks
except Exception:
chunks = []
chunks = result_dialogue.chunks
if not chunks or len(chunks) == 0:
# If the dialogue has content, return a single fallback chunk built from messages
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
if content_str and len(content_str.strip()) > 0:
fallback_chunk = Chunk.from_messages(
dialogue.context.msgs,
metadata={
"fallback": "single_chunk",
"chunker_strategy": self.chunker_config.chunker_strategy,
"source": "DialogueChunkerFallback",
},
)
return [fallback_chunk]
# No content: return empty list
return []
raise ValueError(
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
f"Strategy: {self.chunker_config.chunker_strategy}"
)
return chunks
@@ -72,22 +62,25 @@ class DialogueChunker:
Args:
dialogue: The processed DialogData object with chunks
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
output_path: Optional path to save the output
Returns:
The path where the output was saved
"""
if not output_path:
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt")
output_path = os.path.join(
os.path.dirname(__file__), "..", "..",
f"chunker_output_{self.chunker_strategy.lower()}.txt"
)
output_lines = []
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
output_lines.append(f"Total characters: {len(dialogue.content)}")
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
output_lines = [
f"=== Chunking Results ({self.chunker_strategy}) ===",
f"Dialogue ID: {dialogue.ref_id}",
f"Original conversation has {len(dialogue.context.msgs)} messages",
f"Total characters: {len(dialogue.content)}",
f"Generated {len(dialogue.chunks)} chunks:"
]
for i, chunk in enumerate(dialogue.chunks):
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
output_lines.append(f" Content preview: {chunk.content}...")

View File

@@ -5,8 +5,6 @@ from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.models.message_models import DialogData, Statement
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
from app.core.memory.models.variate_config import StatementExtractionConfig
from app.core.memory.utils.data.ontology import (
LABEL_DEFINITIONS,
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
class ExtractedStatement(BaseModel):
"""Schema for extracted statement from LLM"""
statement: str = Field(..., description="The extracted statement text")
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
class StatementExtractionResponse(BaseModel):
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
return v
class StatementExtractor:
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
"""Class for extracting statements from dialog chunks using LLM"""
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
"""Initialize the StatementExtractor with an LLM client and configuration
Args:
@@ -71,6 +67,21 @@ class StatementExtractor:
self.llm_client = llm_client
self.config = config or StatementExtractionConfig()
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
"""Get speaker directly from Chunk
Args:
chunk: Chunk object containing speaker field
Returns:
Speaker role ("user"/"assistant") or None if cannot be determined
"""
if hasattr(chunk, 'speaker') and chunk.speaker:
return chunk.speaker
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
return None
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
"""Process a single chunk and return extracted statements
@@ -82,10 +93,12 @@ class StatementExtractor:
Returns:
List of ExtractedStatement objects extracted from the chunk
"""
# Prepare the chunk content for processing
chunk_content = chunk.content
if not chunk_content or len(chunk_content.strip()) < 5:
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
return []
# Render the prompt using helper function
prompt_content = await render_statement_extraction_prompt(
chunk_content=chunk_content,
definitions=LABEL_DEFINITIONS,
@@ -136,7 +149,9 @@ class StatementExtractor:
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
except (KeyError, ValueError):
relevence_info = RelevenceInfo.RELEVANT
chunk_speaker = self._get_speaker_from_chunk(chunk)
chunk_statement = Statement(
statement=extracted_stmt.statement,
stmt_type=stmt_type,
@@ -144,7 +159,9 @@ class StatementExtractor:
relevence_info=relevence_info,
chunk_id=chunk.id,
group_id=group_id,
speaker=chunk_speaker,
)
chunk_statements.append(chunk_statement)
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
@@ -226,12 +243,7 @@ class StatementExtractor:
return output_path
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
"""按对话分组聚合强/弱关系并写入 TXT 文件。
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
"""
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
print("\n=== Relations Classify ===")
# 使用全局配置的输出路径

View File

@@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
# 添加情绪字段处理
"emotion_type": statement.emotion_type,
"emotion_intensity": statement.emotion_intensity,
@@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
"sequence_number": chunk.sequence_number,
"start_index": metadata.get("start_index"),
"end_index": metadata.get("end_index")
"end_index": metadata.get("end_index"),
# 添加 speaker 字段(用于基于角色的情绪提取)
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
}
flattened_chunks.append(flattened_chunk)

View File

@@ -12,7 +12,7 @@ class UserInput(BaseModel):
class Write_UserInput(BaseModel):
message: str
messages: list[dict]
group_id: str
config_id: Optional[str] = None

View File

@@ -20,11 +20,13 @@ from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_agent_schema import Write_UserInput
from app.schemas.memory_config_schema import ConfigurationError
from app.services.memory_base_service import Translation_English
from app.services.memory_config_service import MemoryConfigService
@@ -260,13 +262,13 @@ class MemoryAgentService:
logger.info("Log streaming completed, cleaning up resources")
# LogStreamer uses context manager for file handling, so cleanup is automatic
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
"""
Process write operation with config_id
Args:
group_id: Group identifier (also used as end_user_id)
message: Message to write
messages: Structured message list [{"role": "user", "content": "..."}, ...]
config_id: Configuration ID from database
db: SQLAlchemy database session
storage_type: Storage type (neo4j or rag)
@@ -287,7 +289,7 @@ class MemoryAgentService:
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
raise
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
@@ -315,14 +317,28 @@ class MemoryAgentService:
try:
if storage_type == "rag":
result = await write_rag(group_id, message, user_rag_memory_id)
# For RAG storage, convert messages to single string
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
result = await write_rag(group_id, message_text, user_rag_memory_id)
return result
else:
async with make_write_graph() as graph:
config = {"configurable": {"thread_id": group_id}}
# Convert structured messages to LangChain messages
langchain_messages = []
for msg in messages:
if msg['role'] == 'user':
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
from langchain_core.messages import AIMessage
langchain_messages.append(AIMessage(content=msg['content']))
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
"memory_config": memory_config}
initial_state = {
"messages": langchain_messages,
"group_id": group_id,
"memory_config": memory_config
}
# 获取节点更新信息
async for update_event in graph.astream(
@@ -335,7 +351,9 @@ class MemoryAgentService:
massages = node_data
massagesstatus = massages.get('write_result')['status']
contents = massages.get('write_result')
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
# Convert messages back to string for logging
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
except Exception as e:
# Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}"
@@ -500,6 +518,57 @@ class MemoryAgentService:
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
result = reorder_output_results(optimized_outputs)
# 保存短期记忆到数据库
# 只有 search_switch 不为 "2"(快速检索)时才保存
try:
from app.repositories.memory_short_repository import ShortTermMemoryRepository
retrieved_content = []
repo = ShortTermMemoryRepository(db)
if str(search_switch) != "2":
for intermediate in _intermediate_outputs:
logger.debug(f"处理中间结果: {intermediate}")
intermediate_type = intermediate.get('type', '')
if intermediate_type == "search_result":
query = intermediate.get('query', '')
raw_results = intermediate.get('raw_results', {})
reranked_results = raw_results.get('reranked_results', [])
try:
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
except Exception:
statements = []
# 去重
statements = list(set(statements))
if query and statements:
retrieved_content.append({query: statements})
# 如果 retrieved_content 为空,设置为空字符串
if retrieved_content == []:
retrieved_content = ''
# 只有当回答不是"信息不足"且不是快速检索时才保存
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
# 使用 upsert 方法
repo.upsert(
end_user_id=group_id,
messages=message,
aimessages=summary,
retrieved_content=retrieved_content,
search_switch=str(search_switch)
)
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
else:
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
except Exception as save_error:
# 保存失败不应该影响主流程,只记录错误
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation
if audit_logger:
duration = time.time() - start_time
@@ -531,7 +600,49 @@ class MemoryAgentService:
)
raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
"""
Get standardized message list from user input.
Args:
user_input: Write_UserInput object
Returns:
list[dict]: Message list, each message contains role and content
Raises:
ValueError: If messages is empty or format is incorrect
"""
from app.core.logging_config import get_api_logger
logger = get_api_logger()
if len(user_input.messages) == 0:
logger.error("Validation failed: Message list cannot be empty")
raise ValueError("Message list cannot be empty")
for idx, msg in enumerate(user_input.messages):
if not isinstance(msg, dict):
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
if 'role' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
if 'content' not in msg:
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
if msg['role'] not in ['user', 'assistant']:
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
if not msg['content'] or not msg['content'].strip():
logger.error(f"Validation failed: Message {idx} content is empty")
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
return user_input.messages
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
"""

View File

@@ -472,13 +472,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
"""Celery task to process a write message via MemoryAgentService.
支持两种消息格式:
1. 字符串格式向后兼容message="user: xxx\nassistant: yyy"
2. 结构化消息列表推荐message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
Args:
group_id: Group ID for the memory agent (also used as end_user_id)
message: Message to write
message: Message to write (str or list[dict])
config_id: Optional configuration ID
storage_type: Storage type (neo4j/rag)
user_rag_memory_id: RAG memory ID
Returns:
Dict containing the result and metadata