From 2e504f9c485aef4411c5d5621bb7e9f43cdfd6e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 13:55:32 +0800 Subject: [PATCH] Feature/distinction role (#165) * [feature]A set of information for role recognition writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [changes]Disable the function of batch writing multiple groups of conversations in a cumulative manner * [fix]Addressing vulnerability risks --- .../controllers/memory_agent_controller.py | 27 ++- api/app/core/agent/langchain_agent.py | 217 +++++++++++------- .../langgraph_graph/nodes/write_nodes.py | 40 ++-- .../agent/langgraph_graph/write_graph.py | 13 +- .../core/memory/agent/utils/get_dialogs.py | 63 ++--- .../core/memory/agent/utils/write_tools.py | 9 +- .../core/memory/llm_tools/chunker_client.py | 188 +++++++-------- .../core/memory/llm_tools/openai_client.py | 15 +- api/app/core/memory/models/graph_models.py | 7 + api/app/core/memory/models/message_models.py | 30 +-- .../extraction_orchestrator.py | 20 +- .../knowledge_extraction/chunk_extraction.py | 55 ++--- .../statement_extraction.py | 42 ++-- api/app/repositories/neo4j/add_nodes.py | 6 +- api/app/schemas/memory_agent_schema.py | 2 +- api/app/services/memory_agent_service.py | 76 +++++- api/app/tasks.py | 10 +- 17 files changed, 490 insertions(+), 330 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 46fe3043..416ed710 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -160,9 +160,12 @@ async def write_server( api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + result = await memory_agent_service.write_memory( user_input.group_id, - user_input.message, + messages_list, # 传递结构化消息列表 config_id, db, storage_type, @@ -219,9 +222,12 @@ async def write_server_async( if knowledge: user_rag_memory_id = str(knowledge.id) api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + task = celery_app.send_task( "app.core.memory.agent.write_message", - args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id] + args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Write task queued: {task.id}") @@ -564,8 +570,23 @@ async def status_type( """ api_logger.info(f"Status type check requested for group {user_input.group_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + + # 将消息列表转换为字符串用于分类 + # 只取最后一条用户消息进行分类 + last_user_message = "" + for msg in reversed(messages_list): + if msg.get('role') == 'user': + last_user_message = msg.get('content', '') + break + + if not last_user_message: + # 如果没有用户消息,使用所有消息的内容 + last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) + result = await memory_agent_service.classify_message_type( - user_input.message, + last_user_message, user_input.config_id, db ) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 91445b12..87b46e6f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -145,44 +145,98 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - async def term_memory_save(self,messages,end_user_end,aimessages): - '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - end_user_end=f"Term_{end_user_end}" - print(messages) - print(aimessages) - session_id = store.save_session( - userid=end_user_end, - messages=messages, - apply_id=end_user_end, - group_id=end_user_end, - aimessages=aimessages - ) - store.delete_duplicate_sessions() - # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - return session_id - async def term_memory_redis_read(self,end_user_end): - end_user_end = f"Term_{end_user_end}" - history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # logger.info(f'Redis_Agent:{end_user_end};{history}') - messagss_list=[] - retrieved_content=[] - for messages in history: - query = messages.get("Query") - aimessages = messages.get("Answer") - messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - retrieved_content.append({query: aimessages}) - return messagss_list,retrieved_content +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_save(self,messages,end_user_end,aimessages): + # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' + # end_user_end=f"Term_{end_user_end}" + # print(messages) + # print(aimessages) + # session_id = store.save_session( + # userid=end_user_end, + # messages=messages, + # apply_id=end_user_end, + # group_id=end_user_end, + # aimessages=aimessages + # ) + # store.delete_duplicate_sessions() + # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') + # return session_id + +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_redis_read(self,end_user_end): + # end_user_end = f"Term_{end_user_end}" + # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) + # # logger.info(f'Redis_Agent:{end_user_end};{history}') + # messagss_list=[] + # retrieved_content=[] + # for messages in history: + # query = messages.get("Query") + # aimessages = messages.get("Answer") + # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') + # retrieved_content.append({query: aimessages}) + # return messagss_list,retrieved_content - - async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): + async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): + """ + 写入记忆(支持结构化消息) + + Args: + storage_type: 存储类型 (neo4j/rag) + end_user_id: 终端用户ID + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') else: - write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, - user_rag_memory_id) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if user_message: + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if ai_message: + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + # 调用 Celery 任务,传递结构化消息列表 + # 数据流: + # 1. structured_messages 传递给 write_message_task + # 2. write_message_task 调用 memory_agent_service.write_memory + # 3. write_memory 调用 write_tools.write,传递 messages 参数 + # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 + # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 + # 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段 + logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # group_id: 用户ID + structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + actual_config_id, # config_id: 配置ID + storage_type, # storage_type: "neo4j" + user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') async def chat( self, @@ -227,29 +281,30 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') +# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = history_term_memory_result[0] +# db_for_memory = next(get_db()) +# if memory_flag: +# if len(history_term_memory)>=4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# print(retrieved_content) +# # 为长期记忆操作获取新的数据库连接 +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# except Exception as e: +# logger.error(f"Failed to write to LongTermMemory: {e}") +# raise +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = history_term_memory_result[0] - db_for_memory = next(get_db()) - if memory_flag: - if len(history_term_memory)>=4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - print(retrieved_content) - # 为长期记忆操作获取新的数据库连接 - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - except Exception as e: - logger.error(f"Failed to write to LongTermMemory: {e}") - raise - finally: - db_for_memory.close() - - await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) - await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) +# # 长期记忆写入( +# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) +# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -277,8 +332,10 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id) - await self.term_memory_save(message_chat,end_user_id,content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, content) response = { "content": content, "model": self.model_name, @@ -346,27 +403,27 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") +# # TODO 乐力齐 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = history_term_memory_result[0] +# if memory_flag: +# if len(history_term_memory) >= 4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# db_for_memory = next(get_db()) +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# # 长期记忆写入 +# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) +# except Exception as e: +# logger.error(f"Failed to write to long term memory: {e}") +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = history_term_memory_result[0] - if memory_flag: - if len(history_term_memory) >= 4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - db_for_memory = next(get_db()) - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, - history_term_memory, actual_config_id) - except Exception as e: - logger.error(f"Failed to write to long term memory: {e}") - finally: - db_for_memory.close() - - await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -418,8 +475,10 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") if memory_flag: - await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id) - await self.term_memory_save(message_chat, end_user_id, full_content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, full_content) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 8421d059..6af313c3 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState: Write data to the database/file system. Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + state: WriteState containing messages, group_id, and memory_config Returns: - dict: Contains 'status', 'saved_to', and 'data' fields + dict: Contains 'write_result' with status and data fields """ - content=state.get('data','') - group_id=state.get('group_id','') - memory_config=state.get('memory_config', '') + messages = state.get('messages', []) + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', '') + + # Convert LangChain messages to structured format expected by write() + structured_messages = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # Map LangChain message types to role names + role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type + structured_messages.append({ + "role": role, + "content": msg.content # content is now guaranteed to be a string + }) + try: - result=await write( - content=content, + result = await write( + messages=structured_messages, user_id=group_id, apply_id=group_id, group_id=group_id, @@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState: ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - write_result= { + write_result = { "status": "success", - "data": content, + "data": structured_messages, "config_id": memory_config.config_id, "config_name": memory_config.config_name, } - return {"write_result":write_result} - + return {"write_result": write_result} except Exception as e: logger.error(f"Data_write failed: {e}", exc_info=True) - write_result= { + write_result = { "status": "error", "message": str(e), } diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 5a6f1e28..fe281a23 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -14,7 +14,6 @@ from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -27,18 +26,12 @@ async def make_write_graph(): """ Create a write graph workflow for memory operations. - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + The workflow directly processes messages from the initial state + and saves them to Neo4j storage. """ workflow = StateGraph(WriteState) - workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "content_input") - workflow.add_edge("content_input", "save_neo4j") + workflow.add_edge(START, "save_neo4j") workflow.add_edge("save_neo4j", END) graph = workflow.compile() diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index b03fe57c..82a41773 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -12,32 +12,49 @@ async def get_chunked_dialogs( group_id: str = "group_1", user_id: str = "user1", apply_id: str = "applyid", - content: str = "这是用户的输入", + messages: list = None, ref_id: str = "wyl_20251027", config_id: str = None ) -> List[DialogData]: - """Generate chunks from all test data entries using the specified chunker strategy. + """Generate chunks from structured messages using the specified chunker strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) group_id: Group identifier user_id: User identifier apply_id: Application identifier - content: Dialog content + messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier config_id: Configuration ID for processing Returns: - List of DialogData objects with generated chunks for each test entry + List of DialogData objects with generated chunks """ - dialog_data_list = [] - messages = [] - - messages.append(ConversationMessage(role="用户", msg=content)) - - # Create DialogData - conversation_context = ConversationContext(msgs=messages) - # Create DialogData with group_id based on the entry's id for uniqueness + from app.core.logging_config import get_agent_logger + logger = get_agent_logger(__name__) + + if not messages or not isinstance(messages, list) or len(messages) == 0: + raise ValueError("messages parameter must be a non-empty list") + + conversation_messages = [] + + for idx, msg in enumerate(messages): + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") + + role = msg['role'] + content = msg['content'] + + if role not in ['user', 'assistant']: + raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") + + if content.strip(): + conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + + if not conversation_messages: + raise ValueError("Message list cannot be empty after filtering") + + conversation_context = ConversationContext(msgs=conversation_messages) dialog_data = DialogData( context=conversation_context, ref_id=ref_id, @@ -46,25 +63,11 @@ async def get_chunked_dialogs( apply_id=apply_id, config_id=config_id ) - # Create DialogueChunker and process the dialogue + chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks + + logger.info(f"DialogData created with {len(extracted_chunks)} chunks") - dialog_data_list.append(dialog_data) - - # Convert to dict with datetime serialized - def serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - combined_output = [dd.model_dump() for dd in dialog_data_list] - - print(dialog_data_list) - - # with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f: - # json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime) - - - return dialog_data_list + return [dialog_data] diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 53c941ad..1df0b336 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -29,25 +29,22 @@ logger = get_agent_logger(__name__) async def write( - content: str, user_id: str, apply_id: str, group_id: str, memory_config: MemoryConfig, + messages: list, ref_id: str = "wyl20251027", ) -> None: """ Execute the complete knowledge extraction pipeline. - Only MemoryConfig is needed - LLM and embedding clients are constructed - internally from the config. - Args: - content: Dialogue content to process user_id: User identifier apply_id: Application identifier group_id: Group identifier memory_config: MemoryConfig object containing all configuration + messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference ID, defaults to "wyl20251027" """ # Extract config values @@ -89,7 +86,7 @@ async def write( group_id=group_id, user_id=user_id, apply_id=apply_id, - content=content, + messages=messages, ref_id=ref_id, config_id=config_id, ) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 4178ce0a..87cdb9f4 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -4,6 +4,7 @@ import os import asyncio import json import numpy as np +import logging # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk try: from app.core.memory.llm_tools.openai_client import OpenAIClient except Exception: - # 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入 OpenAIClient = Any +# Initialize logger +logger = logging.getLogger(__name__) + class LLMChunker: - """基于LLM的智能分块策略""" + """LLM-based intelligent chunking strategy""" def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): self.llm_client = llm_client self.chunk_size = chunk_size async def __call__(self, text: str) -> List[Any]: - # 使用LLM分析文本结构并进行智能分块 prompt = f""" - 请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 - 请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。 + Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long. + Return results in JSON format with a chunks array, each chunk having a text field. - 文本内容: + Text content: {text[:5000]} """ messages = [ - {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, + {"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."}, {"role": "user", "content": prompt} ] @@ -171,8 +173,6 @@ class ChunkerClient: base_chunk_size=self.chunk_size, ) elif chunker_config.chunker_strategy == "SentenceChunker": - # 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数 - # 为了兼容不同版本,这里仅传递广泛支持的参数 self.chunker = SentenceChunker( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -186,100 +186,93 @@ class ChunkerClient: async def generate_chunks(self, dialogue: DialogData): """ - 生成分块,支持异步操作 + Generate chunks following 1 Message = 1 Chunk strategy. + + Each message creates one chunk, directly inheriting role information. + If a message is too long, it will be split into multiple sub-chunks, + each maintaining the same speaker. + + Raises: + ValueError: If dialogue has no messages or chunking fails """ - try: - # 预处理文本:确保对话标记格式统一 - content = dialogue.content - content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号 - content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 - - if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__): - # 同步分块器 - chunks = self.chunker(content) + # Validate dialogue has messages + if not dialogue.context or not dialogue.context.msgs: + raise ValueError( + f"Dialogue {dialogue.ref_id} has no messages. " + f"Cannot generate chunks from empty dialogue." + ) + + dialogue.chunks = [] + + # 按消息分块:每个消息创建一个或多个 chunk,直接继承角色 + for msg_idx, msg in enumerate(dialogue.context.msgs): + # Validate message has required attributes + if not hasattr(msg, 'role') or not hasattr(msg, 'msg'): + raise ValueError( + f"Message {msg_idx} in dialogue {dialogue.ref_id} " + f"missing 'role' or 'msg' attribute" + ) + + msg_content = msg.msg.strip() + + # Skip empty messages + if not msg_content: + continue + + # 如果消息太长,可以进一步分块 + if len(msg_content) > self.chunk_size: + # 对单个消息的内容进行分块 + try: + sub_chunks = self.chunker(msg_content) + except Exception as e: + raise ValueError( + f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}" + ) + + for idx, sub_chunk in enumerate(sub_chunks): + sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk) + sub_chunk_text = sub_chunk_text.strip() + + if len(sub_chunk_text) < (self.min_characters_per_chunk or 50): + continue + + chunk = Chunk( + content=f"{msg.role}: {sub_chunk_text}", + speaker=msg.role, # 直接继承角色 + metadata={ + "message_index": msg_idx, + "message_role": msg.role, + "sub_chunk_index": idx, + "total_sub_chunks": len(sub_chunks), + "chunker_strategy": self.chunker_config.chunker_strategy, + }, + ) + dialogue.chunks.append(chunk) else: - # 异步分块器(如LLMChunker) - chunks = await self.chunker(content) - - # 过滤空块和过小的块 - valid_chunks = [] - for c in chunks: - chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c - if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50): - valid_chunks.append(c) - - dialogue.chunks = [ - Chunk( - content=c.text if hasattr(c, 'text') else str(c), + # 消息不长,直接作为一个 chunk + chunk = Chunk( + content=f"{msg.role}: {msg_content}", + speaker=msg.role, # 直接继承角色 metadata={ - "start_index": getattr(c, "start_index", None), - "end_index": getattr(c, "end_index", None), + "message_index": msg_idx, + "message_role": msg.role, "chunker_strategy": self.chunker_config.chunker_strategy, }, ) - for c in valid_chunks - ] - return dialogue - - except Exception as e: - print(f"分块失败: {e}") - - # 改进的后备方案:尝试按对话回合分割 - try: - # 简单的按对话分割 - dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' - matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL) - - class SimpleChunk: - def __init__(self, text, start_index, end_index): - self.text = text - self.start_index = start_index - self.end_index = end_index - - chunks = [] - current_chunk = "" - current_start = 0 - - for match in matches: - speaker, ct = match[0], match[1].strip() - turn_text = f"{speaker} {ct}" - - if len(current_chunk) + len(turn_text) > (self.chunk_size or 500): - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - current_chunk = turn_text - current_start = dialogue.content.find(turn_text, current_start) - else: - current_chunk += ("\n" + turn_text) if current_chunk else turn_text - - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - - dialogue.chunks = [ - Chunk( - content=c.text, - metadata={ - "start_index": c.start_index, - "end_index": c.end_index, - "chunker_strategy": "DialogueTurnFallback", - }, - ) - for c in chunks - ] - - except Exception: - # 最后的手段:单一大块 - dialogue.chunks = [Chunk( - content=dialogue.content, - metadata={"chunker_strategy": "SingleChunkFallback"}, - )] - - return dialogue + dialogue.chunks.append(chunk) + + # Validate we generated at least one chunk + if not dialogue.chunks: + raise ValueError( + f"No valid chunks generated for dialogue {dialogue.ref_id}. " + f"All messages were either empty or too short. " + f"Messages count: {len(dialogue.context.msgs)}" + ) + + return dialogue def evaluate_chunking(self, dialogue: DialogData) -> dict: - """ - 评估分块质量 - """ + """Evaluate chunking quality.""" if not getattr(dialogue, 'chunks', None): return {} @@ -304,11 +297,8 @@ class ChunkerClient: return metrics def save_chunking_results(self, dialogue: DialogData, output_path: str): - """ - 保存分块结果到文件,文件名包含策略名称 - """ + """Save chunking results to file with strategy name in filename.""" strategy_name = self.chunker_config.chunker_strategy - # 在文件名中添加策略名称 base_name, ext = os.path.splitext(output_path) strategy_output_path = f"{base_name}_{strategy_name}{ext}" diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index dce7b495..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -92,8 +92,6 @@ class OpenAIClient(LLMClient): config["callbacks"] = [self.langfuse_handler] response = await chain.ainvoke({"messages": messages}, config=config) - - logger.debug(f"LLM 响应成功: {len(str(response))} 字符") return response except Exception as e: @@ -149,13 +147,10 @@ class OpenAIClient(LLMClient): config=config ) - logger.debug(f"使用 PydanticOutputParser 解析成功") return parsed except Exception as e: - logger.warning( - f"PydanticOutputParser 解析失败,尝试其他方法: {e}" - ) + logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}") # 方法 2: 使用 LangChain 的 with_structured_output template = """{question}""" @@ -173,13 +168,17 @@ class OpenAIClient(LLMClient): # 验证并返回结果 try: - return response_model.model_validate(parsed) + result = response_model.model_validate(parsed) + return result except Exception: # 如果已经是 Pydantic 实例,直接返回 if hasattr(parsed, "model_dump"): return parsed # 尝试从 JSON 解析 - return response_model.model_validate_json(json.dumps(parsed)) + result = response_model.model_validate_json(json.dumps(parsed)) + return result + else: + logger.warning("with_structured_output 方法不可用") except Exception as e: logger.error(f"结构化输出失败: {e}") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 39d618fc..7a48d6cb 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -224,6 +224,7 @@ class StatementNode(Node): chunk_id: ID of the parent chunk this statement belongs to stmt_type: Type of the statement (from ontology) statement: The actual statement text content + speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses) emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node emotion_target: Optional emotion target (person or object name) emotion_subject: Optional emotion subject (self/other/object) @@ -249,6 +250,12 @@ class StatementNode(Node): stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") + # Speaker identification + speaker: Optional[str] = Field( + None, + description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" + ) + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( None, diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 199bdd75..bcf08999 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -25,10 +25,10 @@ class ConversationMessage(BaseModel): """Represents a single message in a conversation. Attributes: - role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant) + role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant) msg: Text content of the message """ - role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').") + role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") @@ -57,6 +57,7 @@ class Statement(BaseModel): chunk_id: ID of the parent chunk this statement belongs to group_id: Optional group ID for multi-tenancy statement: The actual statement text content + speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) statement_embedding: Optional embedding vector for the statement stmt_type: Type of the statement (from ontology) temporal_info: Temporal information extracted from the statement @@ -74,6 +75,7 @@ class Statement(BaseModel): chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") statement: str = Field(..., description="The text content of the statement.") + speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") stmt_type: StatementType = Field(..., description="The type of the statement.") temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.") @@ -118,36 +120,36 @@ class Chunk(BaseModel): Attributes: id: Unique identifier for the chunk - text: List of messages in the chunk content: The content of the chunk as a formatted string + speaker: The speaker/role for this chunk (user/assistant) statements: List of statements extracted from this chunk chunk_embedding: Optional embedding vector for the chunk metadata: Additional metadata as key-value pairs """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.") - text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.") content: str = Field(..., description="The content of the chunk as a string.") + speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod - def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None): - """Create a chunk from a list of messages. + def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None): + """Create a chunk from a single message (1 Message = 1 Chunk). Args: - messages: List of conversation messages + message: Single conversation message metadata: Optional metadata dictionary Returns: - Chunk instance with formatted content + Chunk instance with speaker directly from message.role """ - if metadata is None: - metadata = {} - # Generate content from messages - content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages]) - return cls(text=messages, content=content, metadata=metadata) - + return cls( + content=f"{message.role}: {message.msg}", + speaker=message.role, + metadata=metadata or {} + ) + class DialogData(BaseModel): """Represents the complete data structure for a dialog record. diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 75aaa7df..46ba1dde 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -550,7 +550,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取情绪信息(优化版:全局陈述句级并行) + 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) Args: dialog_data_list: 对话数据列表 @@ -558,7 +558,7 @@ class ExtractionOrchestrator: Returns: 情绪信息映射列表,每个对话对应一个字典 """ - logger.info("开始情绪信息提取(全局陈述句级并行)") + logger.info("开始情绪信息提取(仅处理用户消息)") # 收集所有陈述句及其配置 all_statements = [] @@ -597,15 +597,22 @@ class ExtractionOrchestrator: if not data_config or not data_config.emotion_enabled: logger.info("情绪提取未启用,跳过") return [{} for _ in dialog_data_list] + + # 收集所有陈述句(只收集 speaker 为 "user" 的) + total_statements = 0 + filtered_statements = 0 - # 收集所有陈述句 for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: - all_statements.append((statement, data_config)) - statement_metadata.append((d_idx, statement.id)) + total_statements += 1 + # 只处理用户的陈述句 (role 为 "user") + if hasattr(statement, 'speaker') and statement.speaker == "user": + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + filtered_statements += 1 - logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪") # 初始化情绪提取服务 from app.services.emotion_extraction_service import EmotionExtractionService @@ -1033,6 +1040,7 @@ class ExtractionOrchestrator: apply_id=dialog_data.apply_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, + speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 statement_embedding=statement.statement_embedding, valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index edb60a4d..40e98507 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -22,12 +22,12 @@ class DialogueChunker: Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker """ self.chunker_strategy = chunker_strategy chunker_config_dict = get_chunker_config(chunker_strategy) self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) - # 对于 LLMChunker,需要传入 llm_client + if self.chunker_config.chunker_strategy == "LLMChunker": self.chunker_client = ChunkerClient(self.chunker_config, llm_client) else: @@ -41,29 +41,19 @@ class DialogueChunker: Returns: A list of Chunk objects + + Raises: + ValueError: If chunking fails or returns empty chunks """ result_dialogue = await self.chunker_client.generate_chunks(dialogue) - # Defensive fallback: ensure at least one chunk is returned for non-empty content - try: - chunks = result_dialogue.chunks - except Exception: - chunks = [] + chunks = result_dialogue.chunks if not chunks or len(chunks) == 0: - # If the dialogue has content, return a single fallback chunk built from messages - content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "") - if content_str and len(content_str.strip()) > 0: - fallback_chunk = Chunk.from_messages( - dialogue.context.msgs, - metadata={ - "fallback": "single_chunk", - "chunker_strategy": self.chunker_config.chunker_strategy, - "source": "DialogueChunkerFallback", - }, - ) - return [fallback_chunk] - # No content: return empty list - return [] + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) return chunks @@ -72,22 +62,25 @@ class DialogueChunker: Args: dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output (default: chunker_output_{strategy}.txt) + output_path: Optional path to save the output Returns: The path where the output was saved """ if not output_path: - output_path = os.path.join(os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt") + output_path = os.path.join( + os.path.dirname(__file__), "..", "..", + f"chunker_output_{self.chunker_strategy.lower()}.txt" + ) - output_lines = [] - output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===") - output_lines.append(f"Dialogue ID: {dialogue.ref_id}") - output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages") - output_lines.append(f"Total characters: {len(dialogue.content)}") - - output_lines.append(f"Generated {len(dialogue.chunks)} chunks:") + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs)} messages", + f"Total characters: {len(dialogue.content)}", + f"Generated {len(dialogue.chunks)} chunks:" + ] + for i, chunk in enumerate(dialogue.chunks): output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") output_lines.append(f" Content preview: {chunk.content}...") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index 8d37f5d2..fb1b539a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -5,8 +5,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional from app.core.memory.models.message_models import DialogData, Statement - -#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 from app.core.memory.models.variate_config import StatementExtractionConfig from app.core.memory.utils.data.ontology import ( LABEL_DEFINITIONS, @@ -22,11 +20,10 @@ logger = logging.getLogger(__name__) class ExtractedStatement(BaseModel): """Schema for extracted statement from LLM""" statement: str = Field(..., description="The extracted statement text") - statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION") + statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION") temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") relevence: str = Field(..., description="RELEVANT or IRRELEVANT") -# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句) class StatementExtractionResponse(BaseModel): statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements") @@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel): return v class StatementExtractor: - """Class for extracting statements from dialog chunks using LLM (relations separated)""" + """Class for extracting statements from dialog chunks using LLM""" def __init__(self, llm_client: Any, config: StatementExtractionConfig = None): - # 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 """Initialize the StatementExtractor with an LLM client and configuration Args: @@ -71,6 +67,21 @@ class StatementExtractor: self.llm_client = llm_client self.config = config or StatementExtractionConfig() + def _get_speaker_from_chunk(self, chunk) -> Optional[str]: + """Get speaker directly from Chunk + + Args: + chunk: Chunk object containing speaker field + + Returns: + Speaker role ("user"/"assistant") or None if cannot be determined + """ + if hasattr(chunk, 'speaker') and chunk.speaker: + return chunk.speaker + + logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") + return None + async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: """Process a single chunk and return extracted statements @@ -82,10 +93,12 @@ class StatementExtractor: Returns: List of ExtractedStatement objects extracted from the chunk """ - # Prepare the chunk content for processing chunk_content = chunk.content + + if not chunk_content or len(chunk_content.strip()) < 5: + logger.warning(f"Chunk {chunk.id} content too short or empty, skipping") + return [] - # Render the prompt using helper function prompt_content = await render_statement_extraction_prompt( chunk_content=chunk_content, definitions=LABEL_DEFINITIONS, @@ -136,7 +149,9 @@ class StatementExtractor: relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT except (KeyError, ValueError): relevence_info = RelevenceInfo.RELEVANT - + + chunk_speaker = self._get_speaker_from_chunk(chunk) + chunk_statement = Statement( statement=extracted_stmt.statement, stmt_type=stmt_type, @@ -144,7 +159,9 @@ class StatementExtractor: relevence_info=relevence_info, chunk_id=chunk.id, group_id=group_id, + speaker=chunk_speaker, ) + chunk_statements.append(chunk_statement) # 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata @@ -226,12 +243,7 @@ class StatementExtractor: return output_path def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str: - """按对话分组聚合强/弱关系并写入 TXT 文件。 - - 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content` - - 在该对话段内再分为 Strong Relations / Weak Relations 两部分 - - Strong: 逐条输出 `Chunk ID` 与 `Triple` - - Weak: 逐条输出 `Chunk ID` 与 `Entity` - """ + """Group and aggregate strong/weak relations by dialogue and write to TXT file.""" print("\n=== Relations Classify ===") # 使用全局配置的输出路径 diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 1e24eeae..cf60a773 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), "statement_embedding": statement.statement_embedding if statement.statement_embedding else None, + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": statement.speaker if hasattr(statement, 'speaker') else None, # 添加情绪字段处理 "emotion_type": statement.emotion_type, "emotion_intensity": statement.emotion_intensity, @@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "sequence_number": chunk.sequence_number, "start_index": metadata.get("start_index"), - "end_index": metadata.get("end_index") + "end_index": metadata.get("end_index"), + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": chunk.speaker if hasattr(chunk, 'speaker') else None } flattened_chunks.append(flattened_chunk) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 47dc6b2a..fbc0e45c 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -12,7 +12,7 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): - message: str + messages: list[dict] group_id: str config_id: Optional[str] = None diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index fd0cb0eb..65dd628a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -20,11 +20,13 @@ from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_base_service import Translation_English from app.services.memory_config_service import MemoryConfigService @@ -260,13 +262,13 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: group_id: Group identifier (also used as end_user_id) - message: Message to write + messages: Structured message list [{"role": "user", "content": "..."}, ...] config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -287,7 +289,7 @@ class MemoryAgentService: raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): - raise # Re-raise our specific error + raise logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") @@ -315,14 +317,28 @@ class MemoryAgentService: try: if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) + # For RAG storage, convert messages to single string + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + result = await write_rag(group_id, message_text, user_rag_memory_id) return result else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # Convert structured messages to LangChain messages + langchain_messages = [] + for msg in messages: + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + from langchain_core.messages import AIMessage + langchain_messages.append(AIMessage(content=msg['content'])) + # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, - "memory_config": memory_config} + initial_state = { + "messages": langchain_messages, + "group_id": group_id, + "memory_config": memory_config + } # 获取节点更新信息 async for update_event in graph.astream( @@ -335,7 +351,9 @@ class MemoryAgentService: massages = node_data massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') - return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + # Convert messages back to string for logging + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" @@ -531,7 +549,49 @@ class MemoryAgentService: ) raise ValueError(error_msg) - + def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: + """ + Get standardized message list from user input. + + Args: + user_input: Write_UserInput object + + Returns: + list[dict]: Message list, each message contains role and content + + Raises: + ValueError: If messages is empty or format is incorrect + """ + from app.core.logging_config import get_api_logger + logger = get_api_logger() + + if len(user_input.messages) == 0: + logger.error("Validation failed: Message list cannot be empty") + raise ValueError("Message list cannot be empty") + + for idx, msg in enumerate(user_input.messages): + if not isinstance(msg, dict): + logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") + raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + + if 'role' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") + + if 'content' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") + + if msg['role'] not in ['user', 'assistant']: + logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") + raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") + + if not msg['content'] or not msg['content'].strip(): + logger.error(f"Validation failed: Message {idx} content is empty") + raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") + + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") + return user_input.messages async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index fba9f290..e375de35 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -472,13 +472,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: +def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. + 支持两种消息格式: + 1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy" + 2. 结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}] + Args: group_id: Group ID for the memory agent (also used as end_user_id) - message: Message to write + message: Message to write (str or list[dict]) config_id: Optional configuration ID + storage_type: Storage type (neo4j/rag) + user_rag_memory_id: RAG memory ID Returns: Dict containing the result and metadata