Merge remote-tracking branch 'origin/develop' into develop
This commit is contained in:
@@ -160,9 +160,12 @@ async def write_server(
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
messages_list, # 传递结构化消息列表
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
@@ -219,9 +222,12 @@ async def write_server_async(
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -564,8 +570,23 @@ async def status_type(
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
for msg in reversed(messages_list):
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
|
||||
@@ -145,44 +145,98 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
end_user_end=f"Term_{end_user_end}"
|
||||
print(messages)
|
||||
print(aimessages)
|
||||
session_id = store.save_session(
|
||||
userid=end_user_end,
|
||||
messages=messages,
|
||||
apply_id=end_user_end,
|
||||
group_id=end_user_end,
|
||||
aimessages=aimessages
|
||||
)
|
||||
store.delete_duplicate_sessions()
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
return session_id
|
||||
async def term_memory_redis_read(self,end_user_end):
|
||||
end_user_end = f"Term_{end_user_end}"
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
retrieved_content=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
retrieved_content.append({query: aimessages})
|
||||
return messagss_list,retrieved_content
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -227,29 +281,30 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
db_for_memory = next(get_db())
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
print(retrieved_content)
|
||||
# 为长期记忆操作获取新的数据库连接
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
raise
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -277,8 +332,10 @@ class LangChainAgent:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
||||
await self.term_memory_save(message_chat,end_user_id,content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -346,27 +403,27 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
db_for_memory = next(get_db())
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to long term memory: {e}")
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -418,8 +475,10 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
if memory_flag:
|
||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
content: Data content to write
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
content=state.get('data','')
|
||||
group_id=state.get('group_id','')
|
||||
memory_config=state.get('memory_config', '')
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# Map LangChain message types to role names
|
||||
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
|
||||
structured_messages.append({
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
try:
|
||||
result=await write(
|
||||
content=content,
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
@@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
write_result= {
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": content,
|
||||
"data": structured_messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
return {"write_result":write_result}
|
||||
|
||||
return {"write_result": write_result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
write_result= {
|
||||
write_result = {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
@@ -27,18 +26,12 @@ async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", content_input_write)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_edge("content_input", "save_neo4j")
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
@@ -12,32 +12,49 @@ async def get_chunked_dialogs(
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
content: str = "这是用户的输入",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from all test data entries using the specified chunker strategy.
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
content: Dialog content
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks for each test entry
|
||||
List of DialogData objects with generated chunks
|
||||
"""
|
||||
dialog_data_list = []
|
||||
messages = []
|
||||
|
||||
messages.append(ConversationMessage(role="用户", msg=content))
|
||||
|
||||
# Create DialogData
|
||||
conversation_context = ConversationContext(msgs=messages)
|
||||
# Create DialogData with group_id based on the entry's id for uniqueness
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
@@ -46,25 +63,11 @@ async def get_chunked_dialogs(
|
||||
apply_id=apply_id,
|
||||
config_id=config_id
|
||||
)
|
||||
# Create DialogueChunker and process the dialogue
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
dialog_data_list.append(dialog_data)
|
||||
|
||||
# Convert to dict with datetime serialized
|
||||
def serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
|
||||
print(dialog_data_list)
|
||||
|
||||
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
|
||||
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
|
||||
|
||||
|
||||
return dialog_data_list
|
||||
return [dialog_data]
|
||||
|
||||
@@ -29,25 +29,22 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Only MemoryConfig is needed - LLM and embedding clients are constructed
|
||||
internally from the config.
|
||||
|
||||
Args:
|
||||
content: Dialogue content to process
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
"""
|
||||
# Extract config values
|
||||
@@ -89,7 +86,7 @@ async def write(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
content=content,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
except Exception:
|
||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||
OpenAIClient = Any
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMChunker:
|
||||
"""基于LLM的智能分块策略"""
|
||||
"""LLM-based intelligent chunking strategy"""
|
||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||
self.llm_client = llm_client
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
async def __call__(self, text: str) -> List[Any]:
|
||||
# 使用LLM分析文本结构并进行智能分块
|
||||
prompt = f"""
|
||||
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
|
||||
请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。
|
||||
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
|
||||
Return results in JSON format with a chunks array, each chunk having a text field.
|
||||
|
||||
文本内容:
|
||||
Text content:
|
||||
{text[:5000]}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
|
||||
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
@@ -171,8 +173,6 @@ class ChunkerClient:
|
||||
base_chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SentenceChunker":
|
||||
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
|
||||
# 为了兼容不同版本,这里仅传递广泛支持的参数
|
||||
self.chunker = SentenceChunker(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
@@ -186,100 +186,93 @@ class ChunkerClient:
|
||||
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
生成分块,支持异步操作
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
try:
|
||||
# 预处理文本:确保对话标记格式统一
|
||||
content = dialogue.content
|
||||
content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号
|
||||
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
|
||||
|
||||
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
|
||||
# 同步分块器
|
||||
chunks = self.chunker(content)
|
||||
# Validate dialogue has messages
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
|
||||
raise ValueError(
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
try:
|
||||
sub_chunks = self.chunker(msg_content)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
metadata={
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
"sub_chunk_index": idx,
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
# 异步分块器(如LLMChunker)
|
||||
chunks = await self.chunker(content)
|
||||
|
||||
# 过滤空块和过小的块
|
||||
valid_chunks = []
|
||||
for c in chunks:
|
||||
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
|
||||
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
|
||||
valid_chunks.append(c)
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text if hasattr(c, 'text') else str(c),
|
||||
# 消息不长,直接作为一个 chunk
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {msg_content}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
metadata={
|
||||
"start_index": getattr(c, "start_index", None),
|
||||
"end_index": getattr(c, "end_index", None),
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
for c in valid_chunks
|
||||
]
|
||||
return dialogue
|
||||
|
||||
except Exception as e:
|
||||
print(f"分块失败: {e}")
|
||||
|
||||
# 改进的后备方案:尝试按对话回合分割
|
||||
try:
|
||||
# 简单的按对话分割
|
||||
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
|
||||
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, start_index, end_index):
|
||||
self.text = text
|
||||
self.start_index = start_index
|
||||
self.end_index = end_index
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_start = 0
|
||||
|
||||
for match in matches:
|
||||
speaker, ct = match[0], match[1].strip()
|
||||
turn_text = f"{speaker} {ct}"
|
||||
|
||||
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
current_chunk = turn_text
|
||||
current_start = dialogue.content.find(turn_text, current_start)
|
||||
else:
|
||||
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text,
|
||||
metadata={
|
||||
"start_index": c.start_index,
|
||||
"end_index": c.end_index,
|
||||
"chunker_strategy": "DialogueTurnFallback",
|
||||
},
|
||||
)
|
||||
for c in chunks
|
||||
]
|
||||
|
||||
except Exception:
|
||||
# 最后的手段:单一大块
|
||||
dialogue.chunks = [Chunk(
|
||||
content=dialogue.content,
|
||||
metadata={"chunker_strategy": "SingleChunkFallback"},
|
||||
)]
|
||||
|
||||
return dialogue
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
"""
|
||||
评估分块质量
|
||||
"""
|
||||
"""Evaluate chunking quality."""
|
||||
if not getattr(dialogue, 'chunks', None):
|
||||
return {}
|
||||
|
||||
@@ -304,11 +297,8 @@ class ChunkerClient:
|
||||
return metrics
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
||||
"""
|
||||
保存分块结果到文件,文件名包含策略名称
|
||||
"""
|
||||
"""Save chunking results to file with strategy name in filename."""
|
||||
strategy_name = self.chunker_config.chunker_strategy
|
||||
# 在文件名中添加策略名称
|
||||
base_name, ext = os.path.splitext(output_path)
|
||||
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
||||
|
||||
|
||||
@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
||||
|
||||
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
|
||||
config=config
|
||||
)
|
||||
|
||||
logger.debug(f"使用 PydanticOutputParser 解析成功")
|
||||
return parsed
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
|
||||
)
|
||||
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
|
||||
|
||||
# 方法 2: 使用 LangChain 的 with_structured_output
|
||||
template = """{question}"""
|
||||
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
# 验证并返回结果
|
||||
try:
|
||||
return response_model.model_validate(parsed)
|
||||
result = response_model.model_validate(parsed)
|
||||
return result
|
||||
except Exception:
|
||||
# 如果已经是 Pydantic 实例,直接返回
|
||||
if hasattr(parsed, "model_dump"):
|
||||
return parsed
|
||||
# 尝试从 JSON 解析
|
||||
return response_model.model_validate_json(json.dumps(parsed))
|
||||
result = response_model.model_validate_json(json.dumps(parsed))
|
||||
return result
|
||||
else:
|
||||
logger.warning("with_structured_output 方法不可用")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}")
|
||||
|
||||
@@ -224,6 +224,7 @@ class StatementNode(Node):
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
@@ -249,6 +250,12 @@ class StatementNode(Node):
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
|
||||
@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
|
||||
"""Represents a single message in a conversation.
|
||||
|
||||
Attributes:
|
||||
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
|
||||
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
|
||||
msg: Text content of the message
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class Statement(BaseModel):
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
@@ -74,6 +75,7 @@ class Statement(BaseModel):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
stmt_type: StatementType = Field(..., description="The type of the statement.")
|
||||
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
|
||||
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the chunk
|
||||
text: List of messages in the chunk
|
||||
content: The content of the chunk as a formatted string
|
||||
speaker: The speaker/role for this chunk (user/assistant)
|
||||
statements: List of statements extracted from this chunk
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
metadata: Additional metadata as key-value pairs
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
|
||||
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
|
||||
"""Create a chunk from a list of messages.
|
||||
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
|
||||
"""Create a chunk from a single message (1 Message = 1 Chunk).
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
message: Single conversation message
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
Chunk instance with formatted content
|
||||
Chunk instance with speaker directly from message.role
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
# Generate content from messages
|
||||
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
|
||||
return cls(text=messages, content=content, metadata=metadata)
|
||||
|
||||
return cls(
|
||||
content=f"{message.role}: {message.msg}",
|
||||
speaker=message.role,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class DialogData(BaseModel):
|
||||
"""Represents the complete data structure for a dialog record.
|
||||
|
||||
@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取情绪信息(优化版:全局陈述句级并行)
|
||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
情绪信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始情绪信息提取(全局陈述句级并行)")
|
||||
logger.info("开始情绪信息提取(仅处理用户消息)")
|
||||
|
||||
# 收集所有陈述句及其配置
|
||||
all_statements = []
|
||||
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 收集所有陈述句(只收集 speaker 为 "user" 的)
|
||||
total_statements = 0
|
||||
filtered_statements = 0
|
||||
|
||||
# 收集所有陈述句
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
all_statements.append((statement, data_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
total_statements += 1
|
||||
# 只处理用户的陈述句 (role 为 "user")
|
||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||
all_statements.append((statement, data_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
filtered_statements += 1
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
|
||||
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
|
||||
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
@@ -1033,6 +1040,7 @@ class ExtractionOrchestrator:
|
||||
apply_id=dialog_data.apply_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
statement_embedding=statement.statement_embedding,
|
||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
|
||||
@@ -22,12 +22,12 @@ class DialogueChunker:
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# 对于 LLMChunker,需要传入 llm_client
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
@@ -41,29 +41,19 @@ class DialogueChunker:
|
||||
|
||||
Returns:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
# Defensive fallback: ensure at least one chunk is returned for non-empty content
|
||||
try:
|
||||
chunks = result_dialogue.chunks
|
||||
except Exception:
|
||||
chunks = []
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# If the dialogue has content, return a single fallback chunk built from messages
|
||||
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
|
||||
if content_str and len(content_str.strip()) > 0:
|
||||
fallback_chunk = Chunk.from_messages(
|
||||
dialogue.context.msgs,
|
||||
metadata={
|
||||
"fallback": "single_chunk",
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
"source": "DialogueChunkerFallback",
|
||||
},
|
||||
)
|
||||
return [fallback_chunk]
|
||||
# No content: return empty list
|
||||
return []
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -72,22 +62,25 @@ class DialogueChunker:
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
|
||||
output_path: Optional path to save the output
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = []
|
||||
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
|
||||
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
|
||||
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
|
||||
output_lines.append(f"Total characters: {len(dialogue.content)}")
|
||||
|
||||
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
|
||||
@@ -5,8 +5,6 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.data.ontology import (
|
||||
LABEL_DEFINITIONS,
|
||||
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
|
||||
class ExtractedStatement(BaseModel):
|
||||
"""Schema for extracted statement from LLM"""
|
||||
statement: str = Field(..., description="The extracted statement text")
|
||||
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
|
||||
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||
|
||||
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
|
||||
class StatementExtractionResponse(BaseModel):
|
||||
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
|
||||
|
||||
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
|
||||
return v
|
||||
|
||||
class StatementExtractor:
|
||||
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
|
||||
"""Class for extracting statements from dialog chunks using LLM"""
|
||||
|
||||
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
|
||||
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
"""Initialize the StatementExtractor with an LLM client and configuration
|
||||
|
||||
Args:
|
||||
@@ -71,6 +67,21 @@ class StatementExtractor:
|
||||
self.llm_client = llm_client
|
||||
self.config = config or StatementExtractionConfig()
|
||||
|
||||
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
|
||||
"""Get speaker directly from Chunk
|
||||
|
||||
Args:
|
||||
chunk: Chunk object containing speaker field
|
||||
|
||||
Returns:
|
||||
Speaker role ("user"/"assistant") or None if cannot be determined
|
||||
"""
|
||||
if hasattr(chunk, 'speaker') and chunk.speaker:
|
||||
return chunk.speaker
|
||||
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
@@ -82,10 +93,12 @@ class StatementExtractor:
|
||||
Returns:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
# Prepare the chunk content for processing
|
||||
chunk_content = chunk.content
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
return []
|
||||
|
||||
# Render the prompt using helper function
|
||||
prompt_content = await render_statement_extraction_prompt(
|
||||
chunk_content=chunk_content,
|
||||
definitions=LABEL_DEFINITIONS,
|
||||
@@ -136,7 +149,9 @@ class StatementExtractor:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
stmt_type=stmt_type,
|
||||
@@ -144,7 +159,9 @@ class StatementExtractor:
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
speaker=chunk_speaker,
|
||||
)
|
||||
|
||||
chunk_statements.append(chunk_statement)
|
||||
|
||||
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
|
||||
@@ -226,12 +243,7 @@ class StatementExtractor:
|
||||
return output_path
|
||||
|
||||
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
|
||||
"""按对话分组聚合强/弱关系并写入 TXT 文件。
|
||||
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
|
||||
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
|
||||
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
|
||||
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
|
||||
"""
|
||||
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
|
||||
print("\n=== Relations Classify ===")
|
||||
|
||||
# 使用全局配置的输出路径
|
||||
|
||||
@@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
# "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else []
|
||||
# }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}),
|
||||
"statement_embedding": statement.statement_embedding if statement.statement_embedding else None,
|
||||
# 添加 speaker 字段(用于基于角色的情绪提取)
|
||||
"speaker": statement.speaker if hasattr(statement, 'speaker') else None,
|
||||
# 添加情绪字段处理
|
||||
"emotion_type": statement.emotion_type,
|
||||
"emotion_intensity": statement.emotion_intensity,
|
||||
@@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) ->
|
||||
"chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None,
|
||||
"sequence_number": chunk.sequence_number,
|
||||
"start_index": metadata.get("start_index"),
|
||||
"end_index": metadata.get("end_index")
|
||||
"end_index": metadata.get("end_index"),
|
||||
# 添加 speaker 字段(用于基于角色的情绪提取)
|
||||
"speaker": chunk.speaker if hasattr(chunk, 'speaker') else None
|
||||
}
|
||||
flattened_chunks.append(flattened_chunk)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ class UserInput(BaseModel):
|
||||
|
||||
|
||||
class Write_UserInput(BaseModel):
|
||||
message: str
|
||||
messages: list[dict]
|
||||
group_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
@@ -20,11 +20,13 @@ from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph
|
||||
from app.core.memory.agent.logger_file.log_streamer import LogStreamer
|
||||
from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_agent_schema import Write_UserInput
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -260,13 +262,13 @@ class MemoryAgentService:
|
||||
logger.info("Log streaming completed, cleaning up resources")
|
||||
# LogStreamer uses context manager for file handling, so cleanup is automatic
|
||||
|
||||
async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str:
|
||||
"""
|
||||
Process write operation with config_id
|
||||
|
||||
Args:
|
||||
group_id: Group identifier (also used as end_user_id)
|
||||
message: Message to write
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
config_id: Configuration ID from database
|
||||
db: SQLAlchemy database session
|
||||
storage_type: Storage type (neo4j or rag)
|
||||
@@ -287,7 +289,7 @@ class MemoryAgentService:
|
||||
raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.")
|
||||
except Exception as e:
|
||||
if "No memory configuration found" in str(e):
|
||||
raise # Re-raise our specific error
|
||||
raise
|
||||
logger.error(f"Failed to get connected config for end_user {group_id}: {e}")
|
||||
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
|
||||
|
||||
@@ -315,14 +317,28 @@ class MemoryAgentService:
|
||||
|
||||
try:
|
||||
if storage_type == "rag":
|
||||
result = await write_rag(group_id, message, user_rag_memory_id)
|
||||
# For RAG storage, convert messages to single string
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
result = await write_rag(group_id, message_text, user_rag_memory_id)
|
||||
return result
|
||||
else:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# Convert structured messages to LangChain messages
|
||||
langchain_messages = []
|
||||
for msg in messages:
|
||||
if msg['role'] == 'user':
|
||||
langchain_messages.append(HumanMessage(content=msg['content']))
|
||||
elif msg['role'] == 'assistant':
|
||||
from langchain_core.messages import AIMessage
|
||||
langchain_messages.append(AIMessage(content=msg['content']))
|
||||
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id,
|
||||
"memory_config": memory_config}
|
||||
initial_state = {
|
||||
"messages": langchain_messages,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config
|
||||
}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
@@ -335,7 +351,9 @@ class MemoryAgentService:
|
||||
massages = node_data
|
||||
massagesstatus = massages.get('write_result')['status']
|
||||
contents = massages.get('write_result')
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents)
|
||||
# Convert messages back to string for logging
|
||||
message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
||||
return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents)
|
||||
except Exception as e:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
@@ -500,6 +518,57 @@ class MemoryAgentService:
|
||||
optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
result = reorder_output_results(optimized_outputs)
|
||||
|
||||
# 保存短期记忆到数据库
|
||||
# 只有 search_switch 不为 "2"(快速检索)时才保存
|
||||
try:
|
||||
from app.repositories.memory_short_repository import ShortTermMemoryRepository
|
||||
|
||||
retrieved_content = []
|
||||
repo = ShortTermMemoryRepository(db)
|
||||
|
||||
if str(search_switch) != "2":
|
||||
for intermediate in _intermediate_outputs:
|
||||
logger.debug(f"处理中间结果: {intermediate}")
|
||||
intermediate_type = intermediate.get('type', '')
|
||||
|
||||
if intermediate_type == "search_result":
|
||||
query = intermediate.get('query', '')
|
||||
raw_results = intermediate.get('raw_results', {})
|
||||
reranked_results = raw_results.get('reranked_results', [])
|
||||
|
||||
try:
|
||||
statements = [statement['statement'] for statement in reranked_results.get('statements', [])]
|
||||
except Exception:
|
||||
statements = []
|
||||
|
||||
# 去重
|
||||
statements = list(set(statements))
|
||||
|
||||
if query and statements:
|
||||
retrieved_content.append({query: statements})
|
||||
|
||||
# 如果 retrieved_content 为空,设置为空字符串
|
||||
if retrieved_content == []:
|
||||
retrieved_content = ''
|
||||
|
||||
# 只有当回答不是"信息不足"且不是快速检索时才保存
|
||||
if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2":
|
||||
# 使用 upsert 方法
|
||||
repo.upsert(
|
||||
end_user_id=group_id,
|
||||
messages=message,
|
||||
aimessages=summary,
|
||||
retrieved_content=retrieved_content,
|
||||
search_switch=str(search_switch)
|
||||
)
|
||||
logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}")
|
||||
else:
|
||||
logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}")
|
||||
|
||||
except Exception as save_error:
|
||||
# 保存失败不应该影响主流程,只记录错误
|
||||
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
|
||||
|
||||
# Log successful operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
@@ -531,7 +600,49 @@ class MemoryAgentService:
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
"""
|
||||
Get standardized message list from user input.
|
||||
|
||||
Args:
|
||||
user_input: Write_UserInput object
|
||||
|
||||
Returns:
|
||||
list[dict]: Message list, each message contains role and content
|
||||
|
||||
Raises:
|
||||
ValueError: If messages is empty or format is incorrect
|
||||
"""
|
||||
from app.core.logging_config import get_api_logger
|
||||
logger = get_api_logger()
|
||||
|
||||
if len(user_input.messages) == 0:
|
||||
logger.error("Validation failed: Message list cannot be empty")
|
||||
raise ValueError("Message list cannot be empty")
|
||||
|
||||
for idx, msg in enumerate(user_input.messages):
|
||||
if not isinstance(msg, dict):
|
||||
logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}")
|
||||
raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}")
|
||||
|
||||
if 'role' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}")
|
||||
raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}")
|
||||
|
||||
if 'content' not in msg:
|
||||
logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}")
|
||||
raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}")
|
||||
|
||||
if msg['role'] not in ['user', 'assistant']:
|
||||
logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}")
|
||||
raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}")
|
||||
|
||||
if not msg['content'] or not msg['content'].strip():
|
||||
logger.error(f"Validation failed: Message {idx} content is empty")
|
||||
raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}")
|
||||
|
||||
logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}")
|
||||
return user_input.messages
|
||||
|
||||
async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict:
|
||||
"""
|
||||
|
||||
@@ -472,13 +472,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.write_message", bind=True)
|
||||
def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]:
|
||||
def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]:
|
||||
"""Celery task to process a write message via MemoryAgentService.
|
||||
|
||||
支持两种消息格式:
|
||||
1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy"
|
||||
2. 结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}]
|
||||
|
||||
Args:
|
||||
group_id: Group ID for the memory agent (also used as end_user_id)
|
||||
message: Message to write
|
||||
message: Message to write (str or list[dict])
|
||||
config_id: Optional configuration ID
|
||||
storage_type: Storage type (neo4j/rag)
|
||||
user_rag_memory_id: RAG memory ID
|
||||
|
||||
Returns:
|
||||
Dict containing the result and metadata
|
||||
|
||||
Reference in New Issue
Block a user