From 2e504f9c485aef4411c5d5621bb7e9f43cdfd6e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 13:55:32 +0800 Subject: [PATCH 1/8] Feature/distinction role (#165) * [feature]A set of information for role recognition writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [changes]Disable the function of batch writing multiple groups of conversations in a cumulative manner * [fix]Addressing vulnerability risks --- .../controllers/memory_agent_controller.py | 27 ++- api/app/core/agent/langchain_agent.py | 217 +++++++++++------- .../langgraph_graph/nodes/write_nodes.py | 40 ++-- .../agent/langgraph_graph/write_graph.py | 13 +- .../core/memory/agent/utils/get_dialogs.py | 63 ++--- .../core/memory/agent/utils/write_tools.py | 9 +- .../core/memory/llm_tools/chunker_client.py | 188 +++++++-------- .../core/memory/llm_tools/openai_client.py | 15 +- api/app/core/memory/models/graph_models.py | 7 + api/app/core/memory/models/message_models.py | 30 +-- .../extraction_orchestrator.py | 20 +- .../knowledge_extraction/chunk_extraction.py | 55 ++--- .../statement_extraction.py | 42 ++-- api/app/repositories/neo4j/add_nodes.py | 6 +- api/app/schemas/memory_agent_schema.py | 2 +- api/app/services/memory_agent_service.py | 76 +++++- api/app/tasks.py | 10 +- 17 files changed, 490 insertions(+), 330 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 46fe3043..416ed710 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -160,9 +160,12 @@ async def write_server( api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, 
user_rag_memory_id: {user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + result = await memory_agent_service.write_memory( user_input.group_id, - user_input.message, + messages_list, # 传递结构化消息列表 config_id, db, storage_type, @@ -219,9 +222,12 @@ async def write_server_async( if knowledge: user_rag_memory_id = str(knowledge.id) api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + task = celery_app.send_task( "app.core.memory.agent.write_message", - args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id] + args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Write task queued: {task.id}") @@ -564,8 +570,23 @@ async def status_type( """ api_logger.info(f"Status type check requested for group {user_input.group_id}") try: + # 获取标准化的消息列表 + messages_list = memory_agent_service.get_messages_list(user_input) + + # 将消息列表转换为字符串用于分类 + # 只取最后一条用户消息进行分类 + last_user_message = "" + for msg in reversed(messages_list): + if msg.get('role') == 'user': + last_user_message = msg.get('content', '') + break + + if not last_user_message: + # 如果没有用户消息,使用所有消息的内容 + last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) + result = await memory_agent_service.classify_message_type( - user_input.message, + last_user_message, user_input.config_id, db ) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 91445b12..87b46e6f 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -145,44 +145,98 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - async def term_memory_save(self,messages,end_user_end,aimessages): - 
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - end_user_end=f"Term_{end_user_end}" - print(messages) - print(aimessages) - session_id = store.save_session( - userid=end_user_end, - messages=messages, - apply_id=end_user_end, - group_id=end_user_end, - aimessages=aimessages - ) - store.delete_duplicate_sessions() - # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - return session_id - async def term_memory_redis_read(self,end_user_end): - end_user_end = f"Term_{end_user_end}" - history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # logger.info(f'Redis_Agent:{end_user_end};{history}') - messagss_list=[] - retrieved_content=[] - for messages in history: - query = messages.get("Query") - aimessages = messages.get("Answer") - messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - retrieved_content.append({query: aimessages}) - return messagss_list,retrieved_content +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_save(self,messages,end_user_end,aimessages): + # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' + # end_user_end=f"Term_{end_user_end}" + # print(messages) + # print(aimessages) + # session_id = store.save_session( + # userid=end_user_end, + # messages=messages, + # apply_id=end_user_end, + # group_id=end_user_end, + # aimessages=aimessages + # ) + # store.delete_duplicate_sessions() + # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') + # return session_id + +# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # async def term_memory_redis_read(self,end_user_end): + # end_user_end = f"Term_{end_user_end}" + # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) + # # logger.info(f'Redis_Agent:{end_user_end};{history}') + # messagss_list=[] + # retrieved_content=[] + # for messages in history: + # query = messages.get("Query") + # aimessages = messages.get("Answer") + # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') + # retrieved_content.append({query: aimessages}) + # 
return messagss_list,retrieved_content - - async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id): + async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): + """ + 写入记忆(支持结构化消息) + + Args: + storage_type: 存储类型 (neo4j/rag) + end_user_id: 终端用户ID + user_message: 用户消息内容 + ai_message: AI 回复内容 + user_rag_memory_id: RAG 记忆ID + actual_end_user_id: 实际用户ID + actual_config_id: 配置ID + + 逻辑说明: + - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 + - Neo4j 模式:使用结构化消息列表 + 1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant] + 2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景) + 3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段 + """ if storage_type == "rag": - await write_rag(end_user_id, message, user_rag_memory_id) + # RAG 模式:组合消息为字符串格式(保持原有逻辑) + combined_message = f"user: {user_message}\nassistant: {ai_message}" + await write_rag(end_user_id, combined_message, user_rag_memory_id) logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') else: - write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type, - user_rag_memory_id) + # Neo4j 模式:使用结构化消息列表 + structured_messages = [] + + # 始终添加用户消息(如果不为空) + if user_message: + structured_messages.append({"role": "user", "content": user_message}) + + # 只有当 AI 回复不为空时才添加 assistant 消息 + if ai_message: + structured_messages.append({"role": "assistant", "content": ai_message}) + + # 如果没有消息,直接返回 + if not structured_messages: + logger.warning(f"No messages to write for user {actual_end_user_id}") + return + + # 调用 Celery 任务,传递结构化消息列表 + # 数据流: + # 1. structured_messages 传递给 write_message_task + # 2. write_message_task 调用 memory_agent_service.write_memory + # 3. write_memory 调用 write_tools.write,传递 messages 参数 + # 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数 + # 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段 + # 6. 
每个 Chunk 保存到 Neo4j,包含 speaker 字段 + logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}") + write_id = write_message_task.delay( + actual_end_user_id, # group_id: 用户ID + structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}] + actual_config_id, # config_id: 配置ID + storage_type, # storage_type: "neo4j" + user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用) + ) + logger.info(f"[WRITE] Celery task submitted - task_id={write_id}") write_status = get_task_memory_write_result(str(write_id)) - logger.info(f'Agent:{actual_end_user_id};{write_status}') + logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') async def chat( self, @@ -227,29 +281,30 @@ class LangChainAgent: actual_end_user_id = end_user_id if end_user_id is not None else "unknown" logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') +# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = history_term_memory_result[0] +# db_for_memory = next(get_db()) +# if memory_flag: +# if len(history_term_memory)>=4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# print(retrieved_content) +# # 为长期记忆操作获取新的数据库连接 +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# except Exception as e: +# logger.error(f"Failed to write to LongTermMemory: {e}") +# raise +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = 
history_term_memory_result[0] - db_for_memory = next(get_db()) - if memory_flag: - if len(history_term_memory)>=4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - print(retrieved_content) - # 为长期记忆操作获取新的数据库连接 - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - except Exception as e: - logger.error(f"Failed to write to LongTermMemory: {e}") - raise - finally: - db_for_memory.close() - - await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id) - await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id) +# # 长期记忆写入( +# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id) +# # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -277,8 +332,10 @@ class LangChainAgent: elapsed_time = time.time() - start_time if memory_flag: - await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id) - await self.term_memory_save(message_chat,end_user_id,content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, content) response = { "content": content, "model": self.model_name, @@ -346,27 +403,27 @@ class LangChainAgent: db.close() except Exception as e: logger.warning(f"Failed to get db session: {e}") +# # TODO 乐力齐 +# history_term_memory_result = await self.term_memory_redis_read(end_user_id) +# history_term_memory = 
history_term_memory_result[0] +# if memory_flag: +# if len(history_term_memory) >= 4 and storage_type != "rag": +# history_term_memory = ';'.join(history_term_memory) +# retrieved_content = history_term_memory_result[1] +# db_for_memory = next(get_db()) +# try: +# repo = LongTermMemoryRepository(db_for_memory) +# repo.upsert(end_user_id, retrieved_content) +# logger.info( +# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') +# # 长期记忆写入 +# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id) +# except Exception as e: +# logger.error(f"Failed to write to long term memory: {e}") +# finally: +# db_for_memory.close() - history_term_memory_result = await self.term_memory_redis_read(end_user_id) - history_term_memory = history_term_memory_result[0] - if memory_flag: - if len(history_term_memory) >= 4 and storage_type != "rag": - history_term_memory = ';'.join(history_term_memory) - retrieved_content = history_term_memory_result[1] - db_for_memory = next(get_db()) - try: - repo = LongTermMemoryRepository(db_for_memory) - repo.upsert(end_user_id, retrieved_content) - logger.info( - f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}') - await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id, - history_term_memory, actual_config_id) - except Exception as e: - logger.error(f"Failed to write to long term memory: {e}") - finally: - db_for_memory.close() - - await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id) + # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表 messages = self._prepare_messages(message, history, context) @@ -418,8 +475,10 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") if memory_flag: - await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id) 
- await self.term_memory_save(message_chat, end_user_id, full_content) + # AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话) + await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id) + # TODO 乐力齐 - 累积多组对话批量写入功能已禁用 + # await self.term_memory_save(message_chat, end_user_id, full_content) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 8421d059..6af313c3 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -9,22 +9,29 @@ async def write_node(state: WriteState) -> WriteState: Write data to the database/file system. Args: - ctx: FastMCP context for dependency injection - content: Data content to write - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + state: WriteState containing messages, group_id, and memory_config Returns: - dict: Contains 'status', 'saved_to', and 'data' fields + dict: Contains 'write_result' with status and data fields """ - content=state.get('data','') - group_id=state.get('group_id','') - memory_config=state.get('memory_config', '') + messages = state.get('messages', []) + group_id = state.get('group_id', '') + memory_config = state.get('memory_config', '') + + # Convert LangChain messages to structured format expected by write() + structured_messages = [] + for msg in messages: + if hasattr(msg, 'type') and hasattr(msg, 'content'): + # Map LangChain message types to role names + role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type + structured_messages.append({ + "role": role, + "content": msg.content # content is now guaranteed to be a string + }) + try: - result=await write( - 
content=content, + result = await write( + messages=structured_messages, user_id=group_id, apply_id=group_id, group_id=group_id, @@ -32,18 +39,17 @@ async def write_node(state: WriteState) -> WriteState: ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") - write_result= { + write_result = { "status": "success", - "data": content, + "data": structured_messages, "config_id": memory_config.config_id, "config_name": memory_config.config_name, } - return {"write_result":write_result} - + return {"write_result": write_result} except Exception as e: logger.error(f"Data_write failed: {e}", exc_info=True) - write_result= { + write_result = { "status": "error", "message": str(e), } diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index 5a6f1e28..fe281a23 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -14,7 +14,6 @@ from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node -from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -27,18 +26,12 @@ async def make_write_graph(): """ Create a write graph workflow for memory operations. - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - group_id: Group identifier - memory_config: MemoryConfig object containing all configuration + The workflow directly processes messages from the initial state + and saves them to Neo4j storage. 
""" workflow = StateGraph(WriteState) - workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "content_input") - workflow.add_edge("content_input", "save_neo4j") + workflow.add_edge(START, "save_neo4j") workflow.add_edge("save_neo4j", END) graph = workflow.compile() diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index b03fe57c..82a41773 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -12,32 +12,49 @@ async def get_chunked_dialogs( group_id: str = "group_1", user_id: str = "user1", apply_id: str = "applyid", - content: str = "这是用户的输入", + messages: list = None, ref_id: str = "wyl_20251027", config_id: str = None ) -> List[DialogData]: - """Generate chunks from all test data entries using the specified chunker strategy. + """Generate chunks from structured messages using the specified chunker strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) group_id: Group identifier user_id: User identifier apply_id: Application identifier - content: Dialog content + messages: Structured message list [{"role": "user", "content": "..."}, ...] 
ref_id: Reference identifier config_id: Configuration ID for processing Returns: - List of DialogData objects with generated chunks for each test entry + List of DialogData objects with generated chunks """ - dialog_data_list = [] - messages = [] - - messages.append(ConversationMessage(role="用户", msg=content)) - - # Create DialogData - conversation_context = ConversationContext(msgs=messages) - # Create DialogData with group_id based on the entry's id for uniqueness + from app.core.logging_config import get_agent_logger + logger = get_agent_logger(__name__) + + if not messages or not isinstance(messages, list) or len(messages) == 0: + raise ValueError("messages parameter must be a non-empty list") + + conversation_messages = [] + + for idx, msg in enumerate(messages): + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") + + role = msg['role'] + content = msg['content'] + + if role not in ['user', 'assistant']: + raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") + + if content.strip(): + conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) + + if not conversation_messages: + raise ValueError("Message list cannot be empty after filtering") + + conversation_context = ConversationContext(msgs=conversation_messages) dialog_data = DialogData( context=conversation_context, ref_id=ref_id, @@ -46,25 +63,11 @@ async def get_chunked_dialogs( apply_id=apply_id, config_id=config_id ) - # Create DialogueChunker and process the dialogue + chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks + + logger.info(f"DialogData created with {len(extracted_chunks)} chunks") - dialog_data_list.append(dialog_data) - - # Convert to dict with datetime serialized - def serialize_datetime(obj): - if isinstance(obj, datetime): - 
return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - combined_output = [dd.model_dump() for dd in dialog_data_list] - - print(dialog_data_list) - - # with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f: - # json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime) - - - return dialog_data_list + return [dialog_data] diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 53c941ad..1df0b336 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -29,25 +29,22 @@ logger = get_agent_logger(__name__) async def write( - content: str, user_id: str, apply_id: str, group_id: str, memory_config: MemoryConfig, + messages: list, ref_id: str = "wyl20251027", ) -> None: """ Execute the complete knowledge extraction pipeline. - Only MemoryConfig is needed - LLM and embedding clients are constructed - internally from the config. - Args: - content: Dialogue content to process user_id: User identifier apply_id: Application identifier group_id: Group identifier memory_config: MemoryConfig object containing all configuration + messages: Structured message list [{"role": "user", "content": "..."}, ...] 
ref_id: Reference ID, defaults to "wyl20251027" """ # Extract config values @@ -89,7 +86,7 @@ async def write( group_id=group_id, user_id=user_id, apply_id=apply_id, - content=content, + messages=messages, ref_id=ref_id, config_id=config_id, ) diff --git a/api/app/core/memory/llm_tools/chunker_client.py b/api/app/core/memory/llm_tools/chunker_client.py index 4178ce0a..87cdb9f4 100644 --- a/api/app/core/memory/llm_tools/chunker_client.py +++ b/api/app/core/memory/llm_tools/chunker_client.py @@ -4,6 +4,7 @@ import os import asyncio import json import numpy as np +import logging # Fix tokenizer parallelism warning os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk try: from app.core.memory.llm_tools.openai_client import OpenAIClient except Exception: - # 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入 OpenAIClient = Any +# Initialize logger +logger = logging.getLogger(__name__) + class LLMChunker: - """基于LLM的智能分块策略""" + """LLM-based intelligent chunking strategy""" def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000): self.llm_client = llm_client self.chunk_size = chunk_size async def __call__(self, text: str) -> List[Any]: - # 使用LLM分析文本结构并进行智能分块 prompt = f""" - 请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。 - 请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。 + Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long. + Return results in JSON format with a chunks array, each chunk having a text field. 
- 文本内容: + Text content: {text[:5000]} """ messages = [ - {"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"}, + {"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."}, {"role": "user", "content": prompt} ] @@ -171,8 +173,6 @@ class ChunkerClient: base_chunk_size=self.chunk_size, ) elif chunker_config.chunker_strategy == "SentenceChunker": - # 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数 - # 为了兼容不同版本,这里仅传递广泛支持的参数 self.chunker = SentenceChunker( chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, @@ -186,100 +186,93 @@ class ChunkerClient: async def generate_chunks(self, dialogue: DialogData): """ - 生成分块,支持异步操作 + Generate chunks following 1 Message = 1 Chunk strategy. + + Each message creates one chunk, directly inheriting role information. + If a message is too long, it will be split into multiple sub-chunks, + each maintaining the same speaker. + + Raises: + ValueError: If dialogue has no messages or chunking fails """ - try: - # 预处理文本:确保对话标记格式统一 - content = dialogue.content - content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号 - content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行 - - if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__): - # 同步分块器 - chunks = self.chunker(content) + # Validate dialogue has messages + if not dialogue.context or not dialogue.context.msgs: + raise ValueError( + f"Dialogue {dialogue.ref_id} has no messages. " + f"Cannot generate chunks from empty dialogue." 
+ ) + + dialogue.chunks = [] + + # 按消息分块:每个消息创建一个或多个 chunk,直接继承角色 + for msg_idx, msg in enumerate(dialogue.context.msgs): + # Validate message has required attributes + if not hasattr(msg, 'role') or not hasattr(msg, 'msg'): + raise ValueError( + f"Message {msg_idx} in dialogue {dialogue.ref_id} " + f"missing 'role' or 'msg' attribute" + ) + + msg_content = msg.msg.strip() + + # Skip empty messages + if not msg_content: + continue + + # 如果消息太长,可以进一步分块 + if len(msg_content) > self.chunk_size: + # 对单个消息的内容进行分块 + try: + sub_chunks = self.chunker(msg_content) + except Exception as e: + raise ValueError( + f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}" + ) + + for idx, sub_chunk in enumerate(sub_chunks): + sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk) + sub_chunk_text = sub_chunk_text.strip() + + if len(sub_chunk_text) < (self.min_characters_per_chunk or 50): + continue + + chunk = Chunk( + content=f"{msg.role}: {sub_chunk_text}", + speaker=msg.role, # 直接继承角色 + metadata={ + "message_index": msg_idx, + "message_role": msg.role, + "sub_chunk_index": idx, + "total_sub_chunks": len(sub_chunks), + "chunker_strategy": self.chunker_config.chunker_strategy, + }, + ) + dialogue.chunks.append(chunk) else: - # 异步分块器(如LLMChunker) - chunks = await self.chunker(content) - - # 过滤空块和过小的块 - valid_chunks = [] - for c in chunks: - chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c - if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50): - valid_chunks.append(c) - - dialogue.chunks = [ - Chunk( - content=c.text if hasattr(c, 'text') else str(c), + # 消息不长,直接作为一个 chunk + chunk = Chunk( + content=f"{msg.role}: {msg_content}", + speaker=msg.role, # 直接继承角色 metadata={ - "start_index": getattr(c, "start_index", None), - "end_index": getattr(c, "end_index", None), + "message_index": msg_idx, + "message_role": msg.role, "chunker_strategy": 
self.chunker_config.chunker_strategy, }, ) - for c in valid_chunks - ] - return dialogue - - except Exception as e: - print(f"分块失败: {e}") - - # 改进的后备方案:尝试按对话回合分割 - try: - # 简单的按对话分割 - dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)' - matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL) - - class SimpleChunk: - def __init__(self, text, start_index, end_index): - self.text = text - self.start_index = start_index - self.end_index = end_index - - chunks = [] - current_chunk = "" - current_start = 0 - - for match in matches: - speaker, ct = match[0], match[1].strip() - turn_text = f"{speaker} {ct}" - - if len(current_chunk) + len(turn_text) > (self.chunk_size or 500): - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - current_chunk = turn_text - current_start = dialogue.content.find(turn_text, current_start) - else: - current_chunk += ("\n" + turn_text) if current_chunk else turn_text - - if current_chunk: - chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk))) - - dialogue.chunks = [ - Chunk( - content=c.text, - metadata={ - "start_index": c.start_index, - "end_index": c.end_index, - "chunker_strategy": "DialogueTurnFallback", - }, - ) - for c in chunks - ] - - except Exception: - # 最后的手段:单一大块 - dialogue.chunks = [Chunk( - content=dialogue.content, - metadata={"chunker_strategy": "SingleChunkFallback"}, - )] - - return dialogue + dialogue.chunks.append(chunk) + + # Validate we generated at least one chunk + if not dialogue.chunks: + raise ValueError( + f"No valid chunks generated for dialogue {dialogue.ref_id}. " + f"All messages were either empty or too short. 
" + f"Messages count: {len(dialogue.context.msgs)}" + ) + + return dialogue def evaluate_chunking(self, dialogue: DialogData) -> dict: - """ - 评估分块质量 - """ + """Evaluate chunking quality.""" if not getattr(dialogue, 'chunks', None): return {} @@ -304,11 +297,8 @@ class ChunkerClient: return metrics def save_chunking_results(self, dialogue: DialogData, output_path: str): - """ - 保存分块结果到文件,文件名包含策略名称 - """ + """Save chunking results to file with strategy name in filename.""" strategy_name = self.chunker_config.chunker_strategy - # 在文件名中添加策略名称 base_name, ext = os.path.splitext(output_path) strategy_output_path = f"{base_name}_{strategy_name}{ext}" diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index dce7b495..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -92,8 +92,6 @@ class OpenAIClient(LLMClient): config["callbacks"] = [self.langfuse_handler] response = await chain.ainvoke({"messages": messages}, config=config) - - logger.debug(f"LLM 响应成功: {len(str(response))} 字符") return response except Exception as e: @@ -149,13 +147,10 @@ class OpenAIClient(LLMClient): config=config ) - logger.debug(f"使用 PydanticOutputParser 解析成功") return parsed except Exception as e: - logger.warning( - f"PydanticOutputParser 解析失败,尝试其他方法: {e}" - ) + logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}") # 方法 2: 使用 LangChain 的 with_structured_output template = """{question}""" @@ -173,13 +168,17 @@ class OpenAIClient(LLMClient): # 验证并返回结果 try: - return response_model.model_validate(parsed) + result = response_model.model_validate(parsed) + return result except Exception: # 如果已经是 Pydantic 实例,直接返回 if hasattr(parsed, "model_dump"): return parsed # 尝试从 JSON 解析 - return response_model.model_validate_json(json.dumps(parsed)) + result = response_model.model_validate_json(json.dumps(parsed)) + return result + else: + logger.warning("with_structured_output 方法不可用") 
except Exception as e: logger.error(f"结构化输出失败: {e}") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 39d618fc..7a48d6cb 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -224,6 +224,7 @@ class StatementNode(Node): chunk_id: ID of the parent chunk this statement belongs to stmt_type: Type of the statement (from ontology) statement: The actual statement text content + speaker: Optional speaker identifier ('user' for user messages, 'assistant' for AI responses) emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node emotion_target: Optional emotion target (person or object name) emotion_subject: Optional emotion subject (self/other/object) @@ -249,6 +250,12 @@ class StatementNode(Node): stmt_type: str = Field(..., description="Type of the statement") statement: str = Field(..., description="The statement text content") + # Speaker identification + speaker: Optional[str] = Field( + None, + description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" + ) + # Emotion fields (ordered as requested, emotion_intensity first for display) emotion_intensity: Optional[float] = Field( None, diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index 199bdd75..bcf08999 100644 --- a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -25,10 +25,10 @@ class ConversationMessage(BaseModel): """Represents a single message in a conversation.
Attributes: - role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant) + role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant) msg: Text content of the message """ - role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').") + role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") msg: str = Field(..., description="The text content of the message.") @@ -57,6 +57,7 @@ class Statement(BaseModel): chunk_id: ID of the parent chunk this statement belongs to group_id: Optional group ID for multi-tenancy statement: The actual statement text content + speaker: Optional speaker identifier ('user' for user messages, 'assistant' for AI responses) statement_embedding: Optional embedding vector for the statement stmt_type: Type of the statement (from ontology) temporal_info: Temporal information extracted from the statement @@ -74,6 +75,7 @@ class Statement(BaseModel): chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") statement: str = Field(..., description="The text content of the statement.") + speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") stmt_type: StatementType = Field(..., description="The type of the statement.") temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.") @@ -118,36 +120,36 @@ class Chunk(BaseModel): Attributes: id: Unique identifier for the chunk - text: List of messages in the chunk content: The content of the chunk as a formatted string + speaker: The speaker/role for this chunk (user/assistant) statements: List of statements extracted from this chunk chunk_embedding: Optional embedding vector for the chunk 
metadata: Additional metadata as key-value pairs """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.") - text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.") content: str = Field(..., description="The content of the chunk as a string.") + speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") @classmethod - def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None): - """Create a chunk from a list of messages. + def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None): + """Create a chunk from a single message (1 Message = 1 Chunk). Args: - messages: List of conversation messages + message: Single conversation message metadata: Optional metadata dictionary Returns: - Chunk instance with formatted content + Chunk instance with speaker directly from message.role """ - if metadata is None: - metadata = {} - # Generate content from messages - content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages]) - return cls(text=messages, content=content, metadata=metadata) - + return cls( + content=f"{message.role}: {message.msg}", + speaker=message.role, + metadata=metadata or {} + ) + class DialogData(BaseModel): """Represents the complete data structure for a dialog record. 
diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 75aaa7df..46ba1dde 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -550,7 +550,7 @@ class ExtractionOrchestrator: self, dialog_data_list: List[DialogData] ) -> List[Dict[str, Any]]: """ - 从对话中提取情绪信息(优化版:全局陈述句级并行) + 从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行) Args: dialog_data_list: 对话数据列表 @@ -558,7 +558,7 @@ class ExtractionOrchestrator: Returns: 情绪信息映射列表,每个对话对应一个字典 """ - logger.info("开始情绪信息提取(全局陈述句级并行)") + logger.info("开始情绪信息提取(仅处理用户消息)") # 收集所有陈述句及其配置 all_statements = [] @@ -597,15 +597,22 @@ class ExtractionOrchestrator: if not data_config or not data_config.emotion_enabled: logger.info("情绪提取未启用,跳过") return [{} for _ in dialog_data_list] + + # 收集所有陈述句(只收集 speaker 为 "user" 的) + total_statements = 0 + filtered_statements = 0 - # 收集所有陈述句 for d_idx, dialog in enumerate(dialog_data_list): for chunk in dialog.chunks: for statement in chunk.statements: - all_statements.append((statement, data_config)) - statement_metadata.append((d_idx, statement.id)) + total_statements += 1 + # 只处理用户的陈述句 (role 为 "user") + if hasattr(statement, 'speaker') and statement.speaker == "user": + all_statements.append((statement, data_config)) + statement_metadata.append((d_idx, statement.id)) + filtered_statements += 1 - logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪") + logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪") # 初始化情绪提取服务 from app.services.emotion_extraction_service import EmotionExtractionService @@ -1033,6 +1040,7 @@ class ExtractionOrchestrator: apply_id=dialog_data.apply_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, + speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 
statement_embedding=statement.statement_embedding, valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None, diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index edb60a4d..40e98507 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -22,12 +22,12 @@ class DialogueChunker: Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker """ self.chunker_strategy = chunker_strategy chunker_config_dict = get_chunker_config(chunker_strategy) self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) - # 对于 LLMChunker,需要传入 llm_client + if self.chunker_config.chunker_strategy == "LLMChunker": self.chunker_client = ChunkerClient(self.chunker_config, llm_client) else: @@ -41,29 +41,19 @@ class DialogueChunker: Returns: A list of Chunk objects + + Raises: + ValueError: If chunking fails or returns empty chunks """ result_dialogue = await self.chunker_client.generate_chunks(dialogue) - # Defensive fallback: ensure at least one chunk is returned for non-empty content - try: - chunks = result_dialogue.chunks - except Exception: - chunks = [] + chunks = result_dialogue.chunks if not chunks or len(chunks) == 0: - # If the dialogue has content, return a single fallback chunk built from messages - content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, 
"content", "") - if content_str and len(content_str.strip()) > 0: - fallback_chunk = Chunk.from_messages( - dialogue.context.msgs, - metadata={ - "fallback": "single_chunk", - "chunker_strategy": self.chunker_config.chunker_strategy, - "source": "DialogueChunkerFallback", - }, - ) - return [fallback_chunk] - # No content: return empty list - return [] + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) return chunks @@ -72,22 +62,25 @@ class DialogueChunker: Args: dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output (default: chunker_output_{strategy}.txt) + output_path: Optional path to save the output Returns: The path where the output was saved """ if not output_path: - output_path = os.path.join(os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt") + output_path = os.path.join( + os.path.dirname(__file__), "..", "..", + f"chunker_output_{self.chunker_strategy.lower()}.txt" + ) - output_lines = [] - output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===") - output_lines.append(f"Dialogue ID: {dialogue.ref_id}") - output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages") - output_lines.append(f"Total characters: {len(dialogue.content)}") - - output_lines.append(f"Generated {len(dialogue.chunks)} chunks:") + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs)} messages", + f"Total characters: {len(dialogue.content)}", + f"Generated {len(dialogue.chunks)} chunks:" + ] + for i, chunk in enumerate(dialogue.chunks): output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") output_lines.append(f" Content preview: {chunk.content}...") 
diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index 8d37f5d2..fb1b539a 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -5,8 +5,6 @@ from datetime import datetime from typing import Any, Dict, List, Optional from app.core.memory.models.message_models import DialogData, Statement - -#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 from app.core.memory.models.variate_config import StatementExtractionConfig from app.core.memory.utils.data.ontology import ( LABEL_DEFINITIONS, @@ -22,11 +20,10 @@ logger = logging.getLogger(__name__) class ExtractedStatement(BaseModel): """Schema for extracted statement from LLM""" statement: str = Field(..., description="The extracted statement text") - statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION") + statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION") temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL") relevence: str = Field(..., description="RELEVANT or IRRELEVANT") -# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句) class StatementExtractionResponse(BaseModel): statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements") @@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel): return v class StatementExtractor: - """Class for extracting statements from dialog chunks using LLM (relations separated)""" + """Class for extracting statements from dialog chunks using LLM""" def __init__(self, llm_client: Any, config: StatementExtractionConfig = None): - # 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 
ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。 """Initialize the StatementExtractor with an LLM client and configuration Args: @@ -71,6 +67,21 @@ class StatementExtractor: self.llm_client = llm_client self.config = config or StatementExtractionConfig() + def _get_speaker_from_chunk(self, chunk) -> Optional[str]: + """Get speaker directly from Chunk + + Args: + chunk: Chunk object containing speaker field + + Returns: + Speaker role ("user"/"assistant") or None if cannot be determined + """ + if hasattr(chunk, 'speaker') and chunk.speaker: + return chunk.speaker + + logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") + return None + async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: """Process a single chunk and return extracted statements @@ -82,10 +93,12 @@ class StatementExtractor: Returns: List of ExtractedStatement objects extracted from the chunk """ - # Prepare the chunk content for processing chunk_content = chunk.content + + if not chunk_content or len(chunk_content.strip()) < 5: + logger.warning(f"Chunk {chunk.id} content too short or empty, skipping") + return [] - # Render the prompt using helper function prompt_content = await render_statement_extraction_prompt( chunk_content=chunk_content, definitions=LABEL_DEFINITIONS, @@ -136,7 +149,9 @@ class StatementExtractor: relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT except (KeyError, ValueError): relevence_info = RelevenceInfo.RELEVANT - + + chunk_speaker = self._get_speaker_from_chunk(chunk) + chunk_statement = Statement( statement=extracted_stmt.statement, stmt_type=stmt_type, @@ -144,7 +159,9 @@ class StatementExtractor: relevence_info=relevence_info, chunk_id=chunk.id, group_id=group_id, + speaker=chunk_speaker, ) + chunk_statements.append(chunk_statement) # 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata @@ -226,12 +243,7 @@ 
class StatementExtractor: return output_path def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str: - """按对话分组聚合强/弱关系并写入 TXT 文件。 - - 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content` - - 在该对话段内再分为 Strong Relations / Weak Relations 两部分 - - Strong: 逐条输出 `Chunk ID` 与 `Triple` - - Weak: 逐条输出 `Chunk ID` 与 `Entity` - """ + """Group and aggregate strong/weak relations by dialogue and write to TXT file.""" print("\n=== Relations Classify ===") # 使用全局配置的输出路径 diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index 1e24eeae..cf60a773 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -101,6 +101,8 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC # "entities": [entity.model_dump() for entity in statement.triplet_extraction_info.entities] if statement.triplet_extraction_info else [] # }) if statement.triplet_extraction_info else json.dumps({"triplets": [], "entities": []}), "statement_embedding": statement.statement_embedding if statement.statement_embedding else None, + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": statement.speaker if hasattr(statement, 'speaker') else None, # 添加情绪字段处理 "emotion_type": statement.emotion_type, "emotion_intensity": statement.emotion_intensity, @@ -163,7 +165,9 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> "chunk_embedding": chunk.chunk_embedding if chunk.chunk_embedding else None, "sequence_number": chunk.sequence_number, "start_index": metadata.get("start_index"), - "end_index": metadata.get("end_index") + "end_index": metadata.get("end_index"), + # 添加 speaker 字段(用于基于角色的情绪提取) + "speaker": chunk.speaker if hasattr(chunk, 'speaker') else None } flattened_chunks.append(flattened_chunk) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index 47dc6b2a..fbc0e45c 100644 --- a/api/app/schemas/memory_agent_schema.py +++ 
b/api/app/schemas/memory_agent_schema.py @@ -12,7 +12,7 @@ class UserInput(BaseModel): class Write_UserInput(BaseModel): - message: str + messages: list[dict] group_id: str config_id: Optional[str] = None diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index fd0cb0eb..65dd628a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -20,11 +20,13 @@ from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle +from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector +from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_base_service import Translation_English from app.services.memory_config_service import MemoryConfigService @@ -260,13 +262,13 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: group_id: Group identifier (also 
used as end_user_id) - message: Message to write + messages: Structured message list [{"role": "user", "content": "..."}, ...] config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -287,7 +289,7 @@ class MemoryAgentService: raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): - raise # Re-raise our specific error + raise logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") @@ -315,14 +317,28 @@ class MemoryAgentService: try: if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) + # For RAG storage, convert messages to single string + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + result = await write_rag(group_id, message_text, user_rag_memory_id) return result else: async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # Convert structured messages to LangChain messages + langchain_messages = [] + for msg in messages: + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + from langchain_core.messages import AIMessage + langchain_messages.append(AIMessage(content=msg['content'])) + # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, - "memory_config": memory_config} + initial_state = { + "messages": langchain_messages, + "group_id": group_id, + "memory_config": memory_config + } # 获取节点更新信息 async for update_event in graph.astream( @@ -335,7 +351,9 @@ class MemoryAgentService: massages = node_data massagesstatus = massages.get('write_result')['status'] contents = massages.get('write_result') - return 
self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + # Convert messages back to string for logging + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" @@ -531,7 +549,49 @@ class MemoryAgentService: ) raise ValueError(error_msg) - + def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: + """ + Get standardized message list from user input. + + Args: + user_input: Write_UserInput object + + Returns: + list[dict]: Message list, each message contains role and content + + Raises: + ValueError: If messages is empty or format is incorrect + """ + from app.core.logging_config import get_api_logger + logger = get_api_logger() + + if len(user_input.messages) == 0: + logger.error("Validation failed: Message list cannot be empty") + raise ValueError("Message list cannot be empty") + + for idx, msg in enumerate(user_input.messages): + if not isinstance(msg, dict): + logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") + raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + + if 'role' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") + + if 'content' not in msg: + logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") + raise ValueError(f"Message format error: Message must contain 'content' field. 
Error message index: {idx}") + + if msg['role'] not in ['user', 'assistant']: + logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") + raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") + + if not msg['content'] or not msg['content'].strip(): + logger.error(f"Validation failed: Message {idx} content is empty") + raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") + + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") + return user_input.messages async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index fba9f290..e375de35 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -472,13 +472,19 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: +def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. + 支持两种消息格式: + 1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy" + 2. 
结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}] + Args: group_id: Group ID for the memory agent (also used as end_user_id) - message: Message to write + message: Message to write (str or list[dict]) config_id: Optional configuration ID + storage_type: Storage type (neo4j/rag) + user_rag_memory_id: RAG memory ID Returns: Dict containing the result and metadata From 37ef497f4cbd342c4ccb472efe808bdc4a6ca636 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 16:04:16 +0800 Subject: [PATCH 2/8] Feature/distinction role (#167) * [feature]A set of information for role recognition writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [fix]Based on the AI review to fix the code * [changes]Disable the function of batch writing multiple groups of conversations in a cumulative manner * [fix]Addressing vulnerability risks * [fix]Fixing short-term memory writing * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. * [feature]A set of information for role recognition writing * [fix]Fix the code after rebasing. 
* [fix]Based on the AI review to fix the code * [fix]Fixing short-term memory writing --- api/app/services/memory_agent_service.py | 51 ++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 65dd628a..692e9a9a 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -518,6 +518,57 @@ class MemoryAgentService: optimized_outputs = merge_multiple_search_results(_intermediate_outputs) result = reorder_output_results(optimized_outputs) + # 保存短期记忆到数据库 + # 只有 search_switch 不为 "2"(快速检索)时才保存 + try: + from app.repositories.memory_short_repository import ShortTermMemoryRepository + + retrieved_content = [] + repo = ShortTermMemoryRepository(db) + + if str(search_switch) != "2": + for intermediate in _intermediate_outputs: + logger.debug(f"处理中间结果: {intermediate}") + intermediate_type = intermediate.get('type', '') + + if intermediate_type == "search_result": + query = intermediate.get('query', '') + raw_results = intermediate.get('raw_results', {}) + reranked_results = raw_results.get('reranked_results', []) + + try: + statements = [statement['statement'] for statement in reranked_results.get('statements', [])] + except Exception: + statements = [] + + # 去重 + statements = list(set(statements)) + + if query and statements: + retrieved_content.append({query: statements}) + + # 如果 retrieved_content 为空,设置为空字符串 + if retrieved_content == []: + retrieved_content = '' + + # 只有当回答不是"信息不足"且不是快速检索时才保存 + if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": + # 使用 upsert 方法 + repo.upsert( + end_user_id=group_id, + messages=message, + aimessages=summary, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}") + else: + logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") + 
+ except Exception as save_error: + # 保存失败不应该影响主流程,只记录错误 + logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) + # Log successful operation if audit_logger: duration = time.time() - start_time From c24fb731472804b1f2a7266cdc46bdd8a9c5ebb4 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Wed, 21 Jan 2026 17:58:46 +0800 Subject: [PATCH 3/8] Fix/memory celery fix (#168) * refactor(celery): optimize task routing and worker configuration - Simplify Celery queue configuration with single default 'io_tasks' queue - Implement task routing strategy separating IO-bound and CPU-bound tasks - Add Flower monitoring support with task event tracking enabled - Add summary node search optimization to only retrieve summary nodes - Clean up unused imports and reorganize import statements for consistency - Update docker-compose configuration to support multi-queue worker setup * chore(celery): simplify flower configuration and add gevent dependency * chore(dependencies): add gevent dependency to requirements - Add gevent==24.11.1 to api/requirements.txt - Gevent is required for async worker support in Celery - Complements existing flower and celery configuration * refactor(celery): simplify async event loop handling and reorganize task queues - Replace complex nest_asyncio and manual event loop management with asyncio.run() in read_message_task, write_message_task, regenerate_memory_cache, and workspace_reflection_task - Rename task queues from io_tasks/cpu_tasks to memory_tasks/document_tasks for better semantic clarity - Update task routing configuration to reflect new queue names for memory agent tasks and document processing tasks - Remove redundant exception handling comments and simplify error handling logic - Update README with improved community support section including GitHub Issues, Pull Requests, Discussions, and WeChat community links - Simplifies event loop management by leveraging asyncio.run() which handles loop creation 
and cleanup automatically, reducing code complexity and potential race conditions --- README.md | 13 +- api/app/celery_app.py | 75 ++++---- .../langgraph_graph/nodes/summary_nodes.py | 8 +- .../validators/memory_config_validators.py | 33 ++-- api/app/repositories/neo4j/graph_search.py | 92 ++++++---- api/app/services/draft_run_service.py | 15 +- api/app/services/memory_agent_service.py | 23 ++- api/app/services/memory_config_service.py | 33 ++-- api/app/tasks.py | 173 +++++------------- api/docker-compose.yml | 45 ++++- api/pyproject.toml | 2 + api/requirements.txt | 1 + 12 files changed, 254 insertions(+), 259 deletions(-) diff --git a/README.md b/README.md index 7d26f7f7..32a779d2 100644 --- a/README.md +++ b/README.md @@ -334,7 +334,12 @@ step6: Log In to the Frontend Interface. ## License This project is licensed under the Apache License 2.0. For details, see the LICENSE file. -## Acknowledgements & Community -- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions. -- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines. -- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file +## Community & Support + +Join our community to ask questions, share your work, and connect with fellow developers. + +- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues). +- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls). +- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions). +- **WeChat**: Scan the QR code below to join our WeChat community group. 
+- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 85ad0643..185d746c 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -1,4 +1,5 @@ import os +import platform from datetime import timedelta from urllib.parse import quote @@ -14,27 +15,12 @@ celery_app = Celery( backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", ) -# 配置使用本地队列,避免与远程 worker 冲突 -celery_app.conf.task_default_queue = 'localhost_test_wyl' -celery_app.conf.task_default_exchange = 'localhost_test_wyl' -celery_app.conf.task_default_routing_key = 'localhost_test_wyl' +# Default queue for unrouted tasks +celery_app.conf.task_default_queue = 'memory_tasks' # macOS 兼容性配置 -import platform - -if platform.system() == 'Darwin': # macOS - # 设置环境变量解决 fork 问题 +if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') - - # 使用 solo 池避免多进程问题 - celery_app.conf.worker_pool = 'solo' - - # 设置唯一的节点名称 - import socket - import time - hostname = socket.gethostname() - timestamp = int(time.time()) - celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}" # Celery 配置 celery_app.conf.update( @@ -52,36 +38,47 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=30 * 60, # 30 分钟硬超时 - task_soft_time_limit=25 * 60, # 25 分钟软超时 + task_time_limit=1800, # 30分钟硬超时 + task_soft_time_limit=1500, # 25分钟软超时 - # Worker 设置 - 针对 macOS 优化 - worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积 - worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏 - worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker + # Worker 设置 (per-worker settings are in docker-compose command line) + worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution # 结果过期时间 - result_expires=3600, # 结果保存 1 小时 + result_expires=3600, 
# 结果保存1小时 # 任务确认设置 - task_acks_late=True, # 任务完成后才确认,避免任务丢失 - worker_disable_rate_limits=True, # 禁用速率限制 + task_acks_late=True, + task_reject_on_worker_lost=True, + worker_disable_rate_limits=True, - # 任务路由(可选,用于不同队列) - # task_routes={ - # 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'}, - # 'app.core.memory.agent.read_message': {'queue': 'memory_processing'}, - # 'app.core.memory.agent.write_message': {'queue': 'memory_processing'}, - # 'tasks.process_item': {'queue': 'default'}, - # }, + # FLower setting + worker_send_task_events=True, + task_send_sent_event=True, + + # task routing + task_routes={ + # Memory tasks → memory_tasks queue (threads worker) + 'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.read_message': {'queue': 'memory_tasks'}, + 'app.core.memory.agent.write_message': {'queue': 'memory_tasks'}, + + # Document tasks → document_tasks queue (prefork worker) + 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, + 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, + + # Beat/periodic tasks → document_tasks queue (prefork worker) + 'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'}, + 'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'}, + 'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'}, + 'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'}, + }, ) # 自动发现任务模块 celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS) -health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME @@ -89,12 
+86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘 # 构建定时任务配置 beat_schedule_config = { - - # "check-read-service": { - # "task": "app.core.memory.agent.health.check_read_service", - # "schedule": health_schedule, - # "args": (), - # }, "run-workspace-reflection": { "task": "app.tasks.workspace_reflection_task", "schedule": workspace_reflection_schedule, diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0d0b57b0..44f89c6a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -4,12 +4,11 @@ import os import time from app.core.logging_config import get_agent_logger, log_time -from app.db import get_db - from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, SummaryResponse, ) +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.search_service import SearchService from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, @@ -18,7 +17,7 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) @@ -182,7 +181,8 @@ async def Input_Summary(state: ReadState) -> ReadState: search_params = { "group_id": group_id, "question": data, - "return_raw_results": True + "return_raw_results": True, + "include": ["summaries"] # Only search summary nodes for faster performance } try: diff --git a/api/app/core/validators/memory_config_validators.py 
b/api/app/core/validators/memory_config_validators.py index 6ccf3ddb..333572e6 100644 --- a/api/app/core/validators/memory_config_validators.py +++ b/api/app/core/validators/memory_config_validators.py @@ -89,14 +89,15 @@ def validate_model_exists_and_active( start_time = time.time() try: - # First check if model exists at all (without tenant filtering) - model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) - - # Then check with tenant filtering + # OPTIMIZED: Single query with tenant filter + # We'll check tenant mismatch in the error handling model = ModelConfigRepository.get_by_id(db, model_id, tenant_id) elapsed_ms = (time.time() - start_time) * 1000 if not model: + # Model not found with tenant filter - check if it exists without filter + model_without_tenant = ModelConfigRepository.get_by_id(db, model_id, tenant_id=None) + if model_without_tenant: # Model exists but belongs to different tenant logger.warning( @@ -208,8 +209,11 @@ def validate_embedding_model( db: Session, tenant_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None -) -> UUID: - """Validate that embedding model is available and return its UUID. +) -> tuple[UUID, str]: + """Validate that embedding model is available and return its UUID and name. 
+ + Returns: + Tuple of (embedding_uuid, embedding_name) Raises: InvalidConfigError: If embedding_id is not provided or invalid @@ -225,14 +229,19 @@ def validate_embedding_model( workspace_id=workspace_id ) - embedding_uuid, _ = validate_and_resolve_model_id( + embedding_uuid, embedding_name = validate_and_resolve_model_id( embedding_id, "embedding", db, tenant_id, required=True, config_id=config_id, workspace_id=workspace_id ) - print(100*'-') - print(embedding_uuid) - print(_) - print(100*'-') + + logger.debug( + "Embedding model validated", + extra={ + "embedding_uuid": str(embedding_uuid), + "embedding_name": embedding_name, + "config_id": config_id + } + ) if embedding_uuid is None: raise InvalidConfigError( @@ -243,7 +252,7 @@ def validate_embedding_model( workspace_id=workspace_id ) - return embedding_uuid + return embedding_uuid, embedding_name def validate_llm_model( diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 0b6a27c6..6f5764b4 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -305,12 +305,19 @@ async def search_graph( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) + if needs_activation_update: + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + return results @@ -339,7 +346,7 @@ async def search_graph_by_embedding( embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - print(f"[PERF] Embedding generation took: 
{embed_time:.4f}s") + logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") if not embeddings or not embeddings[0]: return {"statements": [], "chunks": [], "entities": [], "summaries": []} @@ -393,7 +400,7 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { @@ -417,14 +424,23 @@ async def search_graph_by_embedding( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - update_start = time.time() - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id + # Skip activation updates if only searching summaries (optimization) + needs_activation_update = any( + key in include and key in results and results[key] + for key in ['statements', 'entities', 'chunks'] ) - update_time = time.time() - update_start - print(f"[PERF] Activation value updates took: {update_time:.4f}s") + + if needs_activation_update: + update_start = time.time() + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + update_time = time.time() - update_start + logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") + else: + logger.info(f"[PERF] Skipping activation updates (only summaries)") return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 @@ -535,7 +551,7 @@ async def search_graph_by_keyword_temporal( - Returns up to 'limit' statements """ if not query_text: - print(f"query_text不能为空") + logger.warning(f"query_text cannot be empty") return {"statements": []} statements = await connector.execute_query( 
SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, @@ -549,7 +565,7 @@ async def search_graph_by_keyword_temporal( invalid_date=invalid_date, limit=limit, ) - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal keyword search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -594,9 +610,9 @@ async def search_graph_by_temporal( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") + logger.debug(f"Temporal search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -623,7 +639,7 @@ async def search_graph_by_dialog_id( - Returns up to 'limit' dialogues """ if not dialog_id: - print(f"dialog_id不能为空") + logger.warning(f"dialog_id cannot be empty") return {"dialogues": []} dialogues = await connector.execute_query( @@ -642,7 +658,7 @@ async def search_graph_by_chunk_id( limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - print(f"chunk_id不能为空") + logger.warning(f"chunk_id cannot be empty") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, @@ -679,9 +695,9 @@ async def search_graph_by_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by created_at query: 
{SEARCH_STATEMENTS_BY_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -719,9 +735,9 @@ async def search_graph_by_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -759,9 +775,9 @@ async def search_graph_g_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -799,9 +815,9 @@ async def search_graph_g_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, 
user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -839,9 +855,9 @@ async def search_graph_l_created_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} @@ -879,9 +895,9 @@ async def search_graph_l_valid_at( limit=limit, ) - print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") - print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") - print(f"查询结果为:\n{statements}") + logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") + logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") + logger.debug(f"Search results: {len(statements)} statements found") # 更新 Statement 节点的激活值 results = {"statements": statements} diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 50934226..46bda5f6 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -10,11 +10,6 @@ import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -from langchain.tools import tool -from pydantic import BaseModel, Field -from sqlalchemy import select -from sqlalchemy.orm import Session - from app.celery_app import celery_app from app.core.error_codes import BizCode from 
app.core.exceptions import BusinessException @@ -28,6 +23,10 @@ from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger from app.services.tool_service import ToolService +from langchain.tools import tool +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session logger = get_business_logger() class KnowledgeRetrievalInput(BaseModel): @@ -107,9 +106,9 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str "app.core.memory.agent.read_message", args=[end_user_id, question, [], "1", config_id, storage_type, user_rag_memory_id] ) - result = task_service.get_task_memory_read_result(task.id) - status = result.get("status") - logger.info(f"读取任务状态:{status}") + # result = task_service.get_task_memory_read_result(task.id) + # status = result.get("status") + # logger.info(f"读取任务状态:{status}") finally: db.close() diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 692e9a9a..6748d6c7 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -10,15 +10,17 @@ import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional -import redis -from langchain_core.messages import HumanMessage +import redis from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results +from app.core.memory.agent.utils.messages_tools import ( + merge_multiple_search_results, + reorder_output_results, +) from 
app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags @@ -33,6 +35,7 @@ from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -404,6 +407,7 @@ class MemoryAgentService: import time start_time = time.time() + logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") # Resolve config_id if None using end_user's connected config if config_id is None: @@ -427,13 +431,15 @@ class MemoryAgentService: audit_logger = None + config_load_start = time.time() try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + config_load_time = time.time() - config_load_start + logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) @@ -457,6 +463,7 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow + graph_exec_start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} @@ -513,6 +520,9 @@ class MemoryAgentService: if summary_n and summary_n != [] and summary_n != {}: _intermediate_outputs.append(summary_n) + graph_exec_time = time.time() - graph_exec_start + logger.info(f"[PERF] Graph execution completed in 
{graph_exec_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] optimized_outputs = merge_multiple_search_results(_intermediate_outputs) @@ -570,6 +580,8 @@ class MemoryAgentService: logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation + total_time = time.time() - start_time + logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -587,7 +599,8 @@ class MemoryAgentService: except Exception as e: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" - logger.error(error_msg) + total_time = time.time() - start_time + logger.error(f"[PERF] read_memory failed after {total_time:.4f}s: {error_msg}") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index 09e980a0..0099eb18 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -125,7 +125,11 @@ class MemoryConfigService: try: validated_config_id = _validate_config_id(config_id) + # Step 1: Get config and workspace + db_query_start = time.time() result = DataConfigRepository.get_config_with_workspace(self.db, validated_config_id) + db_query_time = time.time() - db_query_start + logger.info(f"[PERF] Config+Workspace query: {db_query_time:.4f}s") if not result: elapsed_ms = (time.time() - start_time) * 1000 config_logger.error( @@ -144,16 +148,20 @@ class MemoryConfigService: memory_config, workspace = result - # Validate embedding model - embedding_uuid = validate_embedding_model( + # Step 2: Validate embedding model (returns both UUID and name) + embed_start = time.time() + embedding_uuid, embedding_name = validate_embedding_model( 
validated_config_id, memory_config.embedding_id, self.db, workspace.tenant_id, workspace.id, ) + embed_time = time.time() - embed_start + logger.info(f"[PERF] Embedding validation: {embed_time:.4f}s") - # Resolve LLM model + # Step 3: Resolve LLM model + llm_start = time.time() llm_uuid, llm_name = validate_and_resolve_model_id( memory_config.llm_id, "llm", @@ -163,8 +171,11 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + llm_time = time.time() - llm_start + logger.info(f"[PERF] LLM validation: {llm_time:.4f}s") - # Resolve optional rerank model + # Step 4: Resolve optional rerank model + rerank_start = time.time() rerank_uuid = None rerank_name = None if memory_config.rerank_id: @@ -177,16 +188,12 @@ class MemoryConfigService: config_id=validated_config_id, workspace_id=workspace.id, ) + rerank_time = time.time() - rerank_start + if memory_config.rerank_id: + logger.info(f"[PERF] Rerank validation: {rerank_time:.4f}s") - # Get embedding model name - embedding_name, _ = validate_model_exists_and_active( - embedding_uuid, - "embedding", - self.db, - workspace.tenant_id, - config_id=validated_config_id, - workspace_id=workspace.id, - ) + # Note: embedding_name is now returned from validate_embedding_model above + # No need for redundant query! 
# Create immutable MemoryConfig object config = MemoryConfig( diff --git a/api/app/tasks.py b/api/app/tasks.py index e375de35..fa9d1fdf 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -425,24 +425,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time return { @@ -455,7 +438,6 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, } except BaseException as e: elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -528,24 +510,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ db.close() try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") @@ -560,7 +525,6 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ } except BaseException as e: 
elapsed_time = time.time() - start_time - # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -600,53 +564,53 @@ def reflection_timer_task() -> None: """ reflection_engine() - -@celery_app.task(name="app.core.memory.agent.health.check_read_service") -def check_read_service_task() -> Dict[str, str]: - """Call read_service and write latest status to Redis. +# unused task +# @celery_app.task(name="app.core.memory.agent.health.check_read_service") +# def check_read_service_task() -> Dict[str, str]: +# """Call read_service and write latest status to Redis. - Returns status data dict that gets written to Redis. - """ - client = redis.Redis( - host=settings.REDIS_HOST, - port=settings.REDIS_PORT, - db=settings.REDIS_DB, - password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None - ) - try: - api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" - payload = { - "user_id": "健康检查", - "apply_id": "健康检查", - "group_id": "健康检查", - "message": "你好", - "history": [], - "search_switch": "2", - } - resp = requests.post(api_url, json=payload, timeout=15) - ok = resp.status_code == 200 - status = "Success" if ok else "Fail" - msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" - error = "" if ok else resp.text - code = 0 if ok else 500 - except Exception as e: - status = "Fail" - msg = "接口请求失败" - error = str(e) - code = 500 +# Returns status data dict that gets written to Redis. 
+# """ +# client = redis.Redis( +# host=settings.REDIS_HOST, +# port=settings.REDIS_PORT, +# db=settings.REDIS_DB, +# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None +# ) +# try: +# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" +# payload = { +# "user_id": "健康检查", +# "apply_id": "健康检查", +# "group_id": "健康检查", +# "message": "你好", +# "history": [], +# "search_switch": "2", +# } +# resp = requests.post(api_url, json=payload, timeout=15) +# ok = resp.status_code == 200 +# status = "Success" if ok else "Fail" +# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" +# error = "" if ok else resp.text +# code = 0 if ok else 500 +# except Exception as e: +# status = "Fail" +# msg = "接口请求失败" +# error = str(e) +# code = 500 - data = { - "status": status, - "msg": msg, - "error": error, - "code": str(code), - "time": str(int(time.time())), - } +# data = { +# "status": status, +# "msg": msg, +# "error": error, +# "code": str(code), +# "time": str(int(time.time())), +# } - client.hset("memsci:health:read_service", mapping=data) - client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) +# client.hset("memsci:health:read_service", mapping=data) +# client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) - return data +# return data @celery_app.task(name="app.controllers.memory_storage_controller.search_all") @@ -911,24 +875,7 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = 
elapsed_time result["task_id"] = self.request.id @@ -1055,24 +1002,7 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - # 使用 nest_asyncio 来避免事件循环冲突 - try: - import nest_asyncio - nest_asyncio.apply() - except ImportError: - pass - - # 尝试获取现有事件循环,如果不存在则创建新的 - try: - loop = asyncio.get_event_loop() - if loop.is_closed(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - result = loop.run_until_complete(_run()) + result = asyncio.run(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1148,11 +1078,4 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str "duration_seconds": duration } - # 运行异步函数 - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - result = loop.run_until_complete(_run()) - return result - finally: - loop.close() + return asyncio.run(_run()) diff --git a/api/docker-compose.yml b/api/docker-compose.yml index 8bc19f3a..a7337689 100644 --- a/api/docker-compose.yml +++ b/api/docker-compose.yml @@ -7,10 +7,6 @@ services: - "8002:8000" env_file: - .env - environment: - - SERVER_IP=0.0.0.0 - # 如果代码里必须要 MCP_SERVER_URL,可以先注释或指向占位 - # - MCP_SERVER_URL= volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro @@ -19,20 +15,53 @@ services: networks: - default - celery + depends_on: + - worker-memory + - worker-document - # Celery worker - worker: + # Memory worker - Memory read/write tasks (threads pool for asyncio) + worker-memory: image: redbear-mem-open:latest - container_name: worker + container_name: worker-memory env_file: - .env volumes: - ./files:/files - /etc/localtime:/etc/localtime:ro - command: celery -A app.celery_worker.celery_app worker --loglevel=info + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=threads --concurrency=100 --queues=memory_tasks -n memory_worker@%h 
restart: unless-stopped networks: - celery + # Document worker - Document parsing tasks (prefork for CPU-bound) + worker-document: + image: redbear-mem-open:latest + container_name: worker-document + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app worker -E --loglevel=info --pool=prefork --concurrency=4 --queues=document_tasks --max-tasks-per-child=100 -n document_worker@%h + restart: unless-stopped + networks: + - celery + + # Celery Beat - scheduler + beat: + image: redbear-mem-open:latest + container_name: celery-beat + env_file: + - .env + volumes: + - ./files:/files + - /etc/localtime:/etc/localtime:ro + command: celery -A app.celery_worker.celery_app beat --loglevel=info + restart: unless-stopped + networks: + - celery + depends_on: + - worker-memory + networks: celery: diff --git a/api/pyproject.toml b/api/pyproject.toml index 6da684de..414ba372 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "bcrypt==5.0.0", "billiard==4.2.2", "celery==5.5.3", + "flower==2.0.1", "cffi==2.0.0", "click==8.3.0", "click-didyoumean==0.3.1", @@ -138,6 +139,7 @@ dependencies = [ "python-calamine>=0.4.0", "xlrd==2.0.2", "deprecated>=1.3.1", + "flower>=2.0.1", ] [tool.pytest.ini_options] diff --git a/api/requirements.txt b/api/requirements.txt index 99252e09..444a194b 100644 --- a/api/requirements.txt +++ b/api/requirements.txt @@ -6,6 +6,7 @@ async-timeout==5.0.1 bcrypt==5.0.0 billiard==4.2.2 celery==5.5.3 +flower==2.0.1 cffi==2.0.0 click==8.3.0 click-didyoumean==0.3.1 From 1e5acd85ffbed63cbddd70b54a2e7b8cb333dd28 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:11:50 +0800 Subject: [PATCH 4/8] Update community links in README.md --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 32a779d2..2f53a996 100644 --- a/README.md +++ 
b/README.md @@ -338,8 +338,9 @@ This project is licensed under the Apache License 2.0. For details, see the LICE Join our community to ask questions, share your work, and connect with fellow developers. -- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/redbear-ai/memorybear/issues). -- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/redbear-ai/memorybear/pulls). -- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/redbear-ai/memorybear/discussions). +- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues). +- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls). +- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions). - **WeChat**: Scan the QR code below to join our WeChat community group. 
-- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com \ No newline at end of file +- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6) +- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com From b6e6dbf27f04cd577482a8e876342c9f6f65d9fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= <162269739+lanceyq@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:20:28 +0800 Subject: [PATCH 5/8] Fix/memory interface (#169) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [changes]《Modify the interface》 1.Remove the "/search/entity_graph" interface 2.Reconstruct the "/updated_end_user/profile" interface 3.Remove the "Update Username" interface 4.Fix the batch query of user association memory configuration * [changes]《Modify the interface》 1.Remove the "/search/entity_graph" interface 2.Reconstruct the "/updated_end_user/profile" interface 3.Remove the "Update Username" interface 4.Fix the batch query of user association memory configuration * [fix]Fix the error response type --- .../controllers/memory_agent_controller.py | 2 +- .../memory_dashboard_controller.py | 48 ---------- .../controllers/memory_storage_controller.py | 17 +--- .../controllers/user_memory_controllers.py | 70 ++++---------- .../repositories/data_config_repository.py | 32 ------- api/app/repositories/end_user_repository.py | 36 ------- api/app/schemas/memory_agent_schema.py | 4 - api/app/services/memory_agent_service.py | 8 +- api/app/services/memory_storage_service.py | 21 ---- api/app/services/user_memory_service.py | 95 +++++++++++++++++++ 10 files changed, 119 insertions(+), 214 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 
416ed710..7707522c 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -682,7 +682,7 @@ async def get_user_profile_api( current_user: User = Depends(get_current_user) ): """ - 获取用户详情,包含: + 获取工作空间下Popular Memory Tags,包含: - name: 用户名字(直接使用 end_user_id) - tags: 3个用户特征标签(从语句和实体中LLM总结) - hot_tags: 4个热门记忆标签 diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 2afff491..e03c1846 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -5,7 +5,6 @@ from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User -from app.schemas.memory_agent_schema import End_User_Information from app.schemas.response_schema import ApiResponse from app.services import memory_dashboard_service, memory_storage_service, workspace_service @@ -40,54 +39,7 @@ def get_workspace_total_end_users( api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}") return success(data=total_end_users, msg="用户数量获取成功") -@router.post("/update/end_users", response_model=ApiResponse) -async def update_workspace_end_users( - user_input: End_User_Information, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user), -): - """ - 更新工作空间的宿主信息 - """ - username = user_input.end_user_name # 要更新的用户名 - end_user_input_id = user_input.id # 宿主ID - workspace_id = current_user.current_workspace_id - - api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息") - api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}") - try: - # 导入更新函数 - from app.repositories.end_user_repository import update_end_user_other_name - import uuid - - # 转换 end_user_id 为 UUID 类型 - end_user_uuid = uuid.UUID(end_user_input_id) - - # 直接更新数据库中的 other_name 字段 - updated_count = 
update_end_user_other_name( - db=db, - end_user_id=end_user_uuid, - other_name=username - ) - - api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}") - - return success( - data={ - "updated_count": updated_count, - "end_user_id": end_user_input_id, - "updated_other_name": username - }, - msg=f"成功更新 {updated_count} 个宿主的信息" - ) - - except Exception as e: - api_logger.error(f"更新宿主信息失败: {str(e)}") - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"更新宿主信息失败: {str(e)}" - ) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 63d9078a..f4175923 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -28,7 +28,6 @@ from app.services.memory_storage_service import ( search_dialogue, search_edges, search_entity, - search_entity_graph, search_statement, ) from fastapi import APIRouter, Depends @@ -412,21 +411,7 @@ async def search_entity_edges( api_logger.error(f"Search edges failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) -@router.get("/search/entity_graph", response_model=ApiResponse) -async def search_for_entity_graph( - end_user_id: Optional[str] = None, - current_user: User = Depends(get_current_user), - ) -> dict: - """ - 搜索所有实体之间的关系网络 - """ - api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}") - try: - result = await search_entity_graph(end_user_id) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Search entity graph failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e)) + @router.get("/analytics/hot_memory_tags", response_model=ApiResponse) diff --git a/api/app/controllers/user_memory_controllers.py b/api/app/controllers/user_memory_controllers.py index d99eb47e..3b7345b6 100644 --- a/api/app/controllers/user_memory_controllers.py +++ 
b/api/app/controllers/user_memory_controllers.py @@ -351,12 +351,11 @@ async def update_end_user_profile( 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。 所有字段都是可选的,只更新提供的字段。 - """ workspace_id = current_user.current_workspace_id end_user_id = profile_update.end_user_id - # 检查用户是否已选择工作空间 + # 验证工作空间 if workspace_id is None: api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") @@ -366,57 +365,24 @@ async def update_end_user_profile( f"workspace={workspace_id}" ) - try: - # 查询终端用户 - end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() + # 调用 Service 层处理业务逻辑 + result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update) - if not end_user: - api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}") - return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}") - - # 更新字段(只更新提供的字段,排除 end_user_id) - # 允许 None 值来重置字段(如 hire_date) - update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) - - # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime - if 'hire_date' in update_data: - hire_date_timestamp = update_data['hire_date'] - if hire_date_timestamp is not None: - update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) - # 如果是 None,保持 None(允许清空) - - for field, value in update_data.items(): - setattr(end_user, field, value) - - # 更新 updated_at 时间戳 - end_user.updated_at = datetime.datetime.now() - - # 更新 updatetime_profile 为当前时间 - end_user.updatetime_profile = datetime.datetime.now() - - # 提交更改 - db.commit() - db.refresh(end_user) - - # 构建响应数据 - profile_data = EndUserProfileResponse( - id=end_user.id, - other_name=end_user.other_name, - position=end_user.position, - department=end_user.department, - contact=end_user.contact, - phone=end_user.phone, - hire_date=end_user.hire_date, - updatetime_profile=end_user.updatetime_profile - ) - - api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, 
updated_fields={list(update_data.keys())}") - return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功") - - except Exception as e: - db.rollback() - api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") - return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e)) + if result["success"]: + api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}") + return success(data=result["data"], msg="更新成功") + else: + error_msg = result["error"] + api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") + + # 根据错误类型映射到合适的业务错误码 + if error_msg == "终端用户不存在": + return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg) + elif error_msg == "无效的用户ID格式": + return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg) + else: + # 只有未预期的错误才使用 INTERNAL_ERROR + return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg) @router.get("/memory_space/timeline_memories", response_model=ApiResponse) async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str="zh", diff --git a/api/app/repositories/data_config_repository.py b/api/app/repositories/data_config_repository.py index d26058b2..3df7f800 100644 --- a/api/app/repositories/data_config_repository.py +++ b/api/app/repositories/data_config_repository.py @@ -104,38 +104,6 @@ class DataConfigRepository: r.statement AS statement """ - # Entity graph within group (source node, edge, target node) - SEARCH_FOR_ENTITY_GRAPH = """ - MATCH (n:ExtractedEntity)-[r]->(m:ExtractedEntity) - WHERE n.group_id = $group_id - RETURN - { - entity_idx: n.entity_idx, - connect_strength: n.connect_strength, - description: n.description, - entity_type: n.entity_type, - name: n.name, - fact_summary: COALESCE(n.fact_summary, ''), - id: n.id - } AS sourceNode, - { - rel_id: elementId(r), - source_id: startNode(r).id, - target_id: endNode(r).id, - predicate: r.predicate, - statement_id: r.statement_id, - statement: r.statement - } AS edge, - { - entity_idx: 
m.entity_idx, - connect_strength: m.connect_strength, - description: m.description, - entity_type: m.entity_type, - name: m.name, - fact_summary: COALESCE(m.fact_summary, ''), - id: m.id - } AS targetNode - """ @staticmethod def update_reflection_config( db: Session, diff --git a/api/app/repositories/end_user_repository.py b/api/app/repositories/end_user_repository.py index b9e82693..c7d13f8f 100644 --- a/api/app/repositories/end_user_repository.py +++ b/api/app/repositories/end_user_repository.py @@ -276,42 +276,6 @@ def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser] end_user = repo.get_end_user_by_id(end_user_id) return end_user -def update_end_user_other_name( - db: Session, - end_user_id: uuid.UUID, - other_name: str -) -> int: - """ - 通过 end_user_id 更新 end_user 表中的 other_name 字段 - - Args: - db: 数据库会话 - end_user_id: 宿主ID - other_name: 要更新的用户名 - - Returns: - int: 更新的记录数 - """ - try: - # 执行更新 - updated_count = ( - db.query(EndUser) - .filter(EndUser.id == end_user_id) - .update( - {EndUser.other_name: other_name}, - synchronize_session=False - ) - ) - - db.commit() - db_logger.info(f"成功更新宿主 {end_user_id} 的 other_name 为: {other_name}") - return updated_count - - except Exception as e: - db.rollback() - db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}") - raise - # 新增的缓存操作函数(保持与类方法一致的接口) def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]: """根据ID获取终端用户(用于缓存操作)""" diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index fbc0e45c..d4354c40 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -15,7 +15,3 @@ class Write_UserInput(BaseModel): messages: list[dict] group_id: str config_id: Optional[str] = None - -class End_User_Information(BaseModel): - end_user_name: str # 这是要更新的用户名 - id: str # 宿主ID,用于匹配条件 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 
6748d6c7..d744b766 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -1157,7 +1157,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) """ from app.models.app_release_model import AppRelease from app.models.end_user_model import EndUser - from app.models.memory_config_model import MemoryConfig + from app.models.data_config_model import DataConfig from sqlalchemy import select logger.info(f"Batch getting connected configs for {len(end_user_ids)} end_users") @@ -1215,8 +1215,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 批量查询 memory_config_name config_id_to_name = {} if memory_config_ids: - memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() - config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} + memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs} # 4. 
构建最终结果 for end_user_id, app_id in user_to_app.items(): @@ -1233,7 +1233,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称 - memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None + memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None result[end_user_id] = { "memory_config_id": memory_config_id, diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 9cac26ec..83d5923d 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -506,27 +506,6 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any] return result -async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]: - """搜索所有实体之间的关系网络(group 维度)。""" - result = await _neo4j_connector.execute_query( - DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH, - group_id=end_user_id, - ) - # 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”) - for item in result: - source_fact = item["sourceNode"]["fact_summary"] - target_fact = item["targetNode"]["fact_summary"] - # 截取前三条“来源” - item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else [] - item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else [] - # 与现有返回风格保持一致,携带搜索类型、数量与详情 - data = { - "search_for": "entity_graph", - "num": len(result), - "detials": result, - } - return data - async def analytics_hot_memory_tags( db: Session, diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index ae07256a..863bccb0 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -357,6 +357,101 @@ class UserMemoryService: data[key] = 
UserMemoryService._datetime_to_timestamp(original_value) return data + def update_end_user_profile( + self, + db: Session, + end_user_id: str, + profile_update: Any + ) -> Dict[str, Any]: + """ + 更新终端用户的基本信息 + + Args: + db: 数据库会话 + end_user_id: 终端用户ID (UUID) + profile_update: 包含更新字段的 Pydantic 模型 + + Returns: + { + "success": bool, + "data": dict, # 更新后的用户档案数据 + "error": Optional[str] + } + """ + try: + # 转换为UUID并查询用户 + user_uuid = uuid.UUID(end_user_id) + repo = EndUserRepository(db) + end_user = repo.get_by_id(user_uuid) + + if not end_user: + logger.warning(f"终端用户不存在: end_user_id={end_user_id}") + return { + "success": False, + "data": None, + "error": "终端用户不存在" + } + + # 获取更新数据(排除 end_user_id 字段) + update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'}) + + # 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime + if 'hire_date' in update_data: + hire_date_timestamp = update_data['hire_date'] + if hire_date_timestamp is not None: + from app.core.api_key_utils import timestamp_to_datetime + update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp) + # 如果是 None,保持 None(允许清空) + + # 更新字段 + for field, value in update_data.items(): + setattr(end_user, field, value) + + # 更新时间戳 + end_user.updated_at = datetime.now() + end_user.updatetime_profile = datetime.now() + + # 提交更改 + db.commit() + db.refresh(end_user) + + # 构建响应数据 + from app.schemas.end_user_schema import EndUserProfileResponse + profile_data = EndUserProfileResponse( + id=end_user.id, + other_name=end_user.other_name, + position=end_user.position, + department=end_user.department, + contact=end_user.contact, + phone=end_user.phone, + hire_date=end_user.hire_date, + updatetime_profile=end_user.updatetime_profile + ) + + logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}") + + return { + "success": True, + "data": self.convert_profile_to_dict_with_timestamp(profile_data), + "error": None + } + + except ValueError: + logger.error(f"无效的 end_user_id 
格式: {end_user_id}") + return { + "success": False, + "data": None, + "error": "无效的用户ID格式" + } + except Exception as e: + db.rollback() + logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}") + return { + "success": False, + "data": None, + "error": str(e) + } + async def get_cached_memory_insight( self, db: Session, From fb25495f1b44a5d6c62744113d463562cd00e2d3 Mon Sep 17 00:00:00 2001 From: lixinyue11 <94037597+lixinyue11@users.noreply.github.com> Date: Wed, 21 Jan 2026 18:21:51 +0800 Subject: [PATCH 6/8] Fix/memory mcp2 1 (#170) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 去掉MCP框架,重构 * feat(celery): add comprehensive logging to worker and write task - Initialize logging system in Celery worker entry point with LoggingConfig - Add logger instance and startup message to celery_worker.py - Reorganize imports in tasks.py for better readability and consistency - Add detailed logging to write_message_task for debugging and monitoring - Log task start with group_id, config_id, and storage_type parameters - Log service execution and completion status with results - Add exception handling with error logging and stack trace capture - Log task completion time and Celery task ID for performance tracking - Improves observability and troubleshooting of async task execution * 去掉MCP框架,重构 * 去掉MCP框架,重构 * 快速检索,需要在接口部分添加LLM整合 * 快速检索,需要在接口部分添加LLM整合 --------- Co-authored-by: Ke Sun --- .../controllers/memory_agent_controller.py | 15 +++++ .../langgraph_graph/nodes/problem_nodes.py | 41 ++++++------ .../agent/langgraph_graph/read_graph.py | 1 - .../agent/services/optimized_llm_service.py | 4 +- api/app/services/memory_agent_service.py | 62 ++++++++++++++++++- 5 files changed, 99 insertions(+), 24 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 7707522c..22830890 100644 
--- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -9,6 +9,8 @@ from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User +from app.core.memory.agent.utils.session_tools import SessionService +from app.core.memory.agent.utils.redis_tool import store from app.repositories import knowledge_repository, WorkspaceRepository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse @@ -291,6 +293,19 @@ async def read_server( storage_type, user_rag_memory_id ) + if str(user_input.search_switch) == "2": + retrieve_info = result['answer'] + history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) + query = user_input.message + + # 调用 memory_agent_service 的方法生成最终答案 + result['answer'] = await memory_agent_service.generate_summary_from_retrieve( + retrieve_info=retrieve_info, + history=history, + query=query, + config_id=config_id, + db=db + ) return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index e02ef62b..697a13bd 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -18,16 +18,19 @@ template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') db_session = next(get_db()) logger = get_agent_logger(__name__) + class ProblemNodeService(LLMServiceMixin): """问题处理节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 problem_service = ProblemNodeService() + async def 
Split_The_Problem(state: ReadState) -> ReadState: """问题分解节点""" # 从状态中获取数据 @@ -36,10 +39,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: memory_config = state.get('memory_config', None) history = await SessionService(store).get_history(group_id, group_id, group_id) - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() - + system_prompt = await problem_service.template_service.render_template( template_name='problem_breakdown_prompt.jinja2', operation_name='split_the_problem', @@ -47,7 +50,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: sentence=content, json_schema=json_schema ) - + try: # 使用优化的LLM服务 structured = await problem_service.call_llm_structured( @@ -57,10 +60,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: response_model=ProblemExtensionResponse, fallback_value=[] ) - + # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") - + # 验证结构化响应 if not structured or not hasattr(structured, 'root'): logger.warning("Split_The_Problem: 结构化响应为空或格式不正确") @@ -73,17 +76,17 @@ async def Split_The_Problem(state: ReadState) -> ReadState: [item.model_dump() for item in structured.root], ensure_ascii=False ) - + split_result_dict = [] for index, item in enumerate(json.loads(split_result)): split_data = { - "id": f"Q{index+1}", + "id": f"Q{index + 1}", "question": item['extended_question'], "type": item['type'], "reason": item['reason'] } split_result_dict.append(split_data) - + logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项") result = { @@ -96,13 +99,13 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "original_query": content } } - + except Exception as e: logger.error( f"Split_The_Problem failed: {e}", exc_info=True ) - + # 提供更详细的错误信息 error_details = { "error_type": type(e).__name__, @@ -110,9 +113,9 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "content_length": len(content), 
"llm_model_id": memory_config.llm_model_id if memory_config else None } - + logger.error(f"Split_The_Problem error details: {error_details}") - + # 创建默认的空结果 result = { "context": json.dumps([], ensure_ascii=False), @@ -126,10 +129,11 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "error": error_details } } - + # 返回更新后的状态,包含spit_context字段 return {"spit_data": result} + async def Problem_Extension(state: ReadState) -> ReadState: """问题扩展节点""" # 获取原始数据和分解结果 @@ -153,10 +157,10 @@ async def Problem_Extension(state: ReadState) -> ReadState: data = [] history = await SessionService(store).get_history(group_id, group_id, group_id) - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() - + system_prompt = await problem_service.template_service.render_template( template_name='Problem_Extension_prompt.jinja2', operation_name='problem_extension', @@ -242,7 +246,4 @@ async def Problem_Extension(state: ReadState) -> ReadState: } } - return {"problem_extension": result} - - - + return {"problem_extension": result} \ No newline at end of file diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index c01889a9..19011a5f 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -59,7 +59,6 @@ async def make_read_graph(): workflow.add_conditional_edges("Retrieve", Retrieve_continue) workflow.add_edge("Retrieve_Summary", END) workflow.add_conditional_edges("Verify", Verify_continue) - workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) diff --git a/api/app/core/memory/agent/services/optimized_llm_service.py b/api/app/core/memory/agent/services/optimized_llm_service.py index 68919c4a..6942d421 100644 --- a/api/app/core/memory/agent/services/optimized_llm_service.py +++ b/api/app/core/memory/agent/services/optimized_llm_service.py @@ -162,7 +162,7 @@ class 
OptimizedLLMService: return fallback_value elif isinstance(fallback_value, dict): return response_model(**fallback_value) - + # 尝试创建空的响应模型 if hasattr(response_model, 'root'): # RootModel类型 @@ -170,7 +170,7 @@ class OptimizedLLMService: else: # 普通BaseModel类型 return response_model() - + except Exception as e: logger.error(f"创建降级响应失败: {e}") # 最后的降级策略 diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index d744b766..8170bdd8 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -683,7 +683,67 @@ class MemoryAgentService: logger.debug(f"Message type: {status}") return status - # ==================== 新增的三个接口方法 ==================== + async def generate_summary_from_retrieve( + self, + retrieve_info: str, + history: List[Dict], + query: str, + config_id: str, + db: Session + ) -> str: + """ + 基于检索信息、历史对话和查询生成最终答案 + + 使用 Retrieve_Summary_prompt.jinja2 模板调用大模型生成答案 + + Args: + retrieve_info: 检索到的信息 + history: 历史对话记录 + query: 用户查询 + config_id: 配置ID + db: 数据库会话 + + Returns: + 生成的答案文本 + """ + logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") + + try: + # 加载配置 + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) + + # 导入必要的模块 + from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import summary_llm + from app.core.memory.agent.models.summary_models import RetrieveSummaryResponse + + # 构建状态对象 + state = { + "data": query, + "memory_config": memory_config + } + + # 直接调用 summary_llm 函数 + answer = await summary_llm( + state=state, + history=history, + retrieve_info=retrieve_info, + template_name='Retrieve_Summary_prompt.jinja2', + operation_name='retrieve_summary', + response_model=RetrieveSummaryResponse, + search_mode="1" + ) + + logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") + return answer if answer 
else "信息不足,无法回答。" + + except Exception as e: + logger.error(f"生成摘要失败: {str(e)}", exc_info=True) + return "信息不足,无法回答。" + async def get_knowledge_type_stats( self, From 4a4931bee228a40cd139eba456ab528ac0d91634 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 21 Jan 2026 19:37:03 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=B8=AD=E7=BF=BB?= =?UTF-8?q?=E8=8B=B1=E5=8A=9F=E8=83=BD=EF=BC=88=E8=AE=B0=E5=BF=86=E6=97=B6?= =?UTF-8?q?=E9=97=B4=E7=BA=BF=EF=BC=89(=E7=94=A8=E6=88=B7=E6=91=98?= =?UTF-8?q?=E8=A6=81)(=E5=85=B4=E8=B6=A3=E5=88=86=E5=B8=83=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3)(=E6=9F=A5=E8=AF=A2=E6=A0=B8=E5=BF=83=E6=A1=A3?= =?UTF-8?q?=E6=A1=88)(=E8=AE=B0=E5=BF=86=E6=B4=9E=E5=AF=9F)-=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=B7=BB=E5=8A=A0=E7=BF=BB=E8=AF=91=E5=AD=97=E6=AE=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/app/__init__.py | 0 .../controllers/memory_agent_controller.py | 51 ++-- .../service/memory_api_controller.py | 3 +- api/app/core/agent/langchain_agent.py | 69 +++-- .../langgraph_graph/nodes/problem_nodes.py | 8 +- .../langgraph_graph/nodes/retrieve_nodes.py | 18 +- .../langgraph_graph/nodes/summary_nodes.py | 16 +- .../nodes/verification_nodes.py | 6 +- .../langgraph_graph/nodes/write_nodes.py | 41 +-- .../agent/langgraph_graph/read_graph.py | 6 +- .../agent/langgraph_graph/tools/tool.py | 30 +- .../agent/langgraph_graph/write_graph.py | 19 +- .../agent/services/parameter_builder.py | 6 +- .../memory/agent/services/search_service.py | 8 +- .../memory/agent/services/session_service.py | 18 +- .../core/memory/agent/utils/get_dialogs.py | 75 +++-- api/app/core/memory/agent/utils/llm_tools.py | 10 +- api/app/core/memory/agent/utils/redis_tool.py | 26 +- .../core/memory/agent/utils/session_tools.py | 18 +- .../core/memory/agent/utils/write_tools.py | 14 +- .../core/memory/analytics/hot_memory_tags.py | 36 +-- .../analytics/implicit_memory/data_source.py | 4 +- 
.../memory/evaluation/dialogue_queries.py | 4 +- .../memory/evaluation/extraction_utils.py | 12 +- .../evaluation/locomo/locomo_benchmark.py | 26 +- .../memory/evaluation/locomo/locomo_test.py | 2 +- .../memory/evaluation/locomo/locomo_utils.py | 18 +- .../evaluation/locomo/qwen_search_eval.py | 24 +- .../longmemeval/qwen_search_eval.py | 52 ++-- .../evaluation/longmemeval/test_eval.py | 52 ++-- .../memory/evaluation/memsciqa/evaluate_qa.py | 14 +- .../evaluation/memsciqa/memsciqa-test.py | 14 +- api/app/core/memory/evaluation/run_eval.py | 20 +- api/app/core/memory/models/config_models.py | 4 +- api/app/core/memory/models/graph_models.py | 16 +- api/app/core/memory/models/message_models.py | 16 +- api/app/core/memory/src/search.py | 108 +++++-- .../data_preprocessing/data_preprocessor.py | 10 +- .../deduplication/deduped_and_disamb.py | 18 +- .../deduplication/entity_dedup_llm.py | 18 +- .../deduplication/second_layer_dedup.py | 8 +- .../deduplication/two_stage_dedup.py | 14 +- .../extraction_orchestrator.py | 50 ++-- .../knowledge_extraction/memory_summary.py | 6 +- .../statement_extraction.py | 16 +- .../temporal_extraction.py | 2 +- .../triplet_extraction.py | 2 +- .../access_history_manager.py | 86 +++--- .../forgetting_engine/forgetting_scheduler.py | 28 +- .../forgetting_engine/forgetting_strategy.py | 20 +- .../storage_services/search/__init__.py | 6 +- .../storage_services/search/hybrid_search.py | 14 +- .../storage_services/search/keyword_search.py | 12 +- .../search/search_strategy.py | 10 +- .../search/semantic_search.py | 12 +- api/app/core/memory/utils/config/get_data.py | 4 +- api/app/core/memory/utils/log/audit_logger.py | 12 +- api/app/core/rag/vdb/field.py | 2 +- api/app/repositories/neo4j/add_edges.py | 4 +- api/app/repositories/neo4j/add_nodes.py | 22 +- .../neo4j/base_neo4j_repository.py | 2 +- api/app/repositories/neo4j/cypher_queries.py | 165 ++++------- .../repositories/neo4j/dialog_repository.py | 32 +-- 
.../repositories/neo4j/emotion_repository.py | 22 +- api/app/repositories/neo4j/graph_saver.py | 12 +- api/app/repositories/neo4j/graph_search.py | 266 ++++++++---------- .../neo4j/memory_summary_repository.py | 30 +- api/app/repositories/neo4j/neo4j_connector.py | 16 +- .../neo4j/statement_repository.py | 2 +- api/app/schemas/memory_agent_schema.py | 4 +- api/app/services/draft_run_service.py | 2 +- api/app/services/emotion_analytics_service.py | 12 +- api/app/services/memory_agent_service.py | 121 ++++---- api/app/services/memory_api_service.py | 19 +- api/app/services/memory_base_service.py | 18 +- .../memory_entity_relationship_service.py | 4 +- api/app/services/memory_episodic_service.py | 30 +- api/app/services/memory_explicit_service.py | 16 +- api/app/services/memory_forget_service.py | 64 ++--- api/app/services/memory_konwledges_server.py | 14 +- api/app/services/memory_storage_service.py | 39 ++- api/app/services/pilot_run_service.py | 2 +- api/app/services/user_memory_service.py | 42 +-- api/app/tasks.py | 209 +++++++++----- 84 files changed, 1193 insertions(+), 1190 deletions(-) create mode 100644 api/app/__init__.py diff --git a/api/app/__init__.py b/api/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index 22830890..a1337085 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -125,7 +125,7 @@ async def write_server( Write service endpoint - processes write operations synchronously Args: - user_input: Write request containing message and group_id + user_input: Write request containing message and end_user_id Returns: Response with write operation status @@ -160,14 +160,11 @@ async def write_server( api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") storage_type = 'neo4j' - api_logger.info(f"Write service requested for group {user_input.group_id}, 
storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") + api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - result = await memory_agent_service.write_memory( - user_input.group_id, - messages_list, # 传递结构化消息列表 + user_input.end_user_id, + user_input.message, config_id, db, storage_type, @@ -196,7 +193,7 @@ async def write_server_async( Async write service endpoint - enqueues write processing to Celery Args: - user_input: Write request containing message and group_id + user_input: Write request containing message and end_user_id Returns: Task ID for tracking async operation @@ -224,12 +221,9 @@ async def write_server_async( if knowledge: user_rag_memory_id = str(knowledge.id) api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") try: - # 获取标准化的消息列表 - messages_list = memory_agent_service.get_messages_list(user_input) - task = celery_app.send_task( "app.core.memory.agent.write_message", - args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id] + args=[user_input.end_user_id, user_input.message, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Write task queued: {task.id}") @@ -255,7 +249,7 @@ async def read_server( - "2": Direct answer based on context Args: - user_input: Read request with message, history, search_switch, and group_id + user_input: Read request with message, history, search_switch, and end_user_id Returns: Response with query answer @@ -279,12 +273,13 @@ async def read_server( name="USER_RAG_MERORY", workspace_id=workspace_id ) - if knowledge: user_rag_memory_id = str(knowledge.id) + if knowledge: + user_rag_memory_id = str(knowledge.id) - api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, 
user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") try: result = await memory_agent_service.read_memory( - user_input.group_id, + user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch, @@ -297,7 +292,7 @@ async def read_server( retrieve_info = result['answer'] history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) query = user_input.message - + # 调用 memory_agent_service 的方法生成最终答案 result['answer'] = await memory_agent_service.generate_summary_from_retrieve( retrieve_info=retrieve_info, @@ -403,7 +398,7 @@ async def read_server_async( try: task = celery_app.send_task( "app.core.memory.agent.read_message", - args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch, + args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch, config_id, storage_type, user_rag_memory_id] ) api_logger.info(f"Read task queued: {task.id}") @@ -447,7 +442,7 @@ async def get_read_task_result( return success( data={ "result": task_result.get("result"), - "group_id": task_result.get("group_id"), + "end_user_id": task_result.get("end_user_id"), "elapsed_time": task_result.get("elapsed_time"), "task_id": task_id }, @@ -524,7 +519,7 @@ async def get_write_task_result( return success( data={ "result": task_result.get("result"), - "group_id": task_result.get("group_id"), + "end_user_id": task_result.get("end_user_id"), "elapsed_time": task_result.get("elapsed_time"), "task_id": task_id }, @@ -578,16 +573,16 @@ async def status_type( Determine the type of user message (read or write) Args: - user_input: Request containing user message and group_id + user_input: Request containing user message and end_user_id Returns: Type classification result """ - 
api_logger.info(f"Status type check requested for group {user_input.group_id}") + api_logger.info(f"Status type check requested for group {user_input.end_user_id}") try: # 获取标准化的消息列表 messages_list = memory_agent_service.get_messages_list(user_input) - + # 将消息列表转换为字符串用于分类 # 只取最后一条用户消息进行分类 last_user_message = "" @@ -595,13 +590,13 @@ async def status_type( if msg.get('role') == 'user': last_user_message = msg.get('content', '') break - + if not last_user_message: # 如果没有用户消息,使用所有消息的内容 last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) - + result = await memory_agent_service.classify_message_type( - last_user_message, + user_input.message, user_input.config_id, db ) @@ -624,7 +619,7 @@ async def get_knowledge_type_stats_api( 会对缺失类型补 0,返回字典形式。 可选按状态过滤。 - 知识库类型根据当前用户的 current_workspace_id 过滤 - - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤 + - memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (end_user_id) 过滤 - 如果用户没有当前工作空间或未提供 end_user_id,对应的统计返回 0 """ api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") @@ -697,7 +692,7 @@ async def get_user_profile_api( current_user: User = Depends(get_current_user) ): """ - 获取工作空间下Popular Memory Tags,包含: + 获取用户详情,包含: - name: 用户名字(直接使用 end_user_id) - tags: 3个用户特征标签(从语句和实体中LLM总结) - hot_tags: 4个热门记忆标签 diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 30ca1306..87c1aa20 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. 
""" - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") memory_api_service = MemoryAPIService(db) @@ -50,6 +50,7 @@ async def write_memory_api_service( config_id=payload.config_id, storage_type=payload.storage_type, user_rag_memory_id=payload.user_rag_memory_id, + tenant_id=api_key_auth.tenant_id, ) logger.info(f"Memory write successful for end_user: {payload.end_user_id}") diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 87b46e6f..e6c59a79 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -145,41 +145,38 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_save(self,messages,end_user_end,aimessages): - # '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' - # end_user_end=f"Term_{end_user_end}" - # print(messages) - # print(aimessages) - # session_id = store.save_session( - # userid=end_user_end, - # messages=messages, - # apply_id=end_user_end, - # group_id=end_user_end, - # aimessages=aimessages - # ) - # store.delete_duplicate_sessions() - # # logger.info(f'Redis_Agent:{end_user_end};{session_id}') - # return session_id - -# TODO 乐力齐 - 累积多组对话批量写入功能已禁用 - # async def term_memory_redis_read(self,end_user_end): - # end_user_end = f"Term_{end_user_end}" - # history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) - # # logger.info(f'Redis_Agent:{end_user_end};{history}') - # messagss_list=[] - # retrieved_content=[] - # for messages in history: - # query = messages.get("Query") - # aimessages = messages.get("Answer") - # messagss_list.append(f'用户:{query}。AI回复:{aimessages}') - # retrieved_content.append({query: aimessages}) - # return messagss_list,retrieved_content + async def 
term_memory_save(self,messages,end_user_end,aimessages): + '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j''' + end_user_end=f"Term_{end_user_end}" + print(messages) + print(aimessages) + session_id = store.save_session( + userid=end_user_end, + messages=messages, + apply_id=end_user_end, + end_user_id=end_user_end, + aimessages=aimessages + ) + store.delete_duplicate_sessions() + # logger.info(f'Redis_Agent:{end_user_end};{session_id}') + return session_id + async def term_memory_redis_read(self,end_user_end): + end_user_end = f"Term_{end_user_end}" + history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end) + # logger.info(f'Redis_Agent:{end_user_end};{history}') + messagss_list=[] + retrieved_content=[] + for messages in history: + query = messages.get("Query") + aimessages = messages.get("Answer") + messagss_list.append(f'用户:{query}。AI回复:{aimessages}') + retrieved_content.append({query: aimessages}) + return messagss_list,retrieved_content async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id): """ 写入记忆(支持结构化消息) - + Args: storage_type: 存储类型 (neo4j/rag) end_user_id: 终端用户ID @@ -188,7 +185,7 @@ class LangChainAgent: user_rag_memory_id: RAG 记忆ID actual_end_user_id: 实际用户ID actual_config_id: 配置ID - + 逻辑说明: - RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变 - Neo4j 模式:使用结构化消息列表 @@ -204,20 +201,20 @@ class LangChainAgent: else: # Neo4j 模式:使用结构化消息列表 structured_messages = [] - + # 始终添加用户消息(如果不为空) if user_message: structured_messages.append({"role": "user", "content": user_message}) - + # 只有当 AI 回复不为空时才添加 assistant 消息 if ai_message: structured_messages.append({"role": "assistant", "content": ai_message}) - + # 如果没有消息,直接返回 if not structured_messages: logger.warning(f"No messages to write for user {actual_end_user_id}") return - + # 调用 Celery 任务,传递结构化消息列表 # 数据流: # 1. 
structured_messages 传递给 write_message_task diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index 697a13bd..bb8f3ae5 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -35,10 +35,10 @@ async def Split_The_Problem(state: ReadState) -> ReadState: """问题分解节点""" # 从状态中获取数据 content = state.get('data', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() @@ -140,7 +140,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: start = time.time() content = state.get('data', '') data = state.get('spit_data', '')['context'] - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') memory_config = state.get('memory_config', None) @@ -156,7 +156,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: databasets = {} data = [] - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = ProblemExtensionResponse.model_json_schema() diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 14f8fa8b..1880357c 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py 
@@ -52,9 +52,9 @@ async def rag_config(state): return kb_config async def rag_knowledge(state,question): kb_config = await rag_config(state) - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') user_rag_memory_id=state.get("user_rag_memory_id",'') - retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)]) + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] clean_content = '\n\n'.join(retrieval_knowledge) @@ -159,7 +159,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: problem_extension=state.get('problem_extension', '')['context'] storage_type=state.get('storage_type', '') user_rag_memory_id=state.get('user_rag_memory_id', '') - group_id=state.get('group_id', '') + end_user_id=state.get('end_user_id', '') memory_config = state.get('memory_config', None) original=state.get('data', '') problem_list=[] @@ -172,7 +172,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: try: # Prepare search parameters based on storage type search_params = { - "group_id": group_id, + "end_user_id": end_user_id, "question": question, "return_raw_results": True } @@ -263,13 +263,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState: async def retrieve(state: ReadState) -> ReadState: - # 从state中获取group_id + # 从state中获取end_user_id import time start=time.time() problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) original = state.get('data', '') problem_list = [] @@ -295,13 +295,13 @@ async def retrieve(state: ReadState) -> ReadState: temperature=0.2, ) - time_retrieval_tool = create_time_retrieval_tool(group_id) - search_params = { "group_id": 
group_id, "return_raw_results": True } + time_retrieval_tool = create_time_retrieval_tool(end_user_id) + search_params = { "end_user_id": end_user_id, "return_raw_results": True } hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, tools=[time_retrieval_tool,hybrid_retrieval], - system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}" + system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) # 创建异步任务处理单个问题 diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 44f89c6a..8ccad579 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -34,8 +34,8 @@ class SummaryNodeService(LLMServiceMixin): summary_service = SummaryNodeService() async def summary_history(state: ReadState) -> ReadState: - group_id = state.get("group_id", '') - history = await SessionService(store).get_history(group_id, group_id, group_id) + end_user_id = state.get("end_user_id", '') + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: @@ -122,12 +122,12 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o async def summary_redis_save(state: ReadState,aimessages) -> ReadState: data = state.get("data", '') - group_id = state.get("group_id", '') + end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( - user_id=group_id, + user_id=end_user_id, query=data, - apply_id=group_id, - group_id=group_id, + apply_id=end_user_id, + end_user_id=end_user_id, ai_response=aimessages ) await SessionService(store).cleanup_duplicates() @@ -175,11 +175,11 @@ async def Input_Summary(state: ReadState) -> 
ReadState: memory_config = state.get('memory_config', None) user_rag_memory_id=state.get("user_rag_memory_id",'') data=state.get("data", '') - group_id=state.get("group_id", '') + end_user_id=state.get("end_user_id", '') logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") history = await summary_history( state) search_params = { - "group_id": group_id, + "end_user_id": end_user_id, "question": data, "return_raw_results": True, "include": ["summaries"] # Only search summary nodes for faster performance diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index dac7ea14..ad605ec9 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -62,12 +62,12 @@ async def Verify(state: ReadState): logger.info("=== Verify 节点开始执行 ===") try: content = state.get('data', '') - group_id = state.get('group_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}") + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") - history = await SessionService(store).get_history(group_id, group_id, group_id) + history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") retrieve = state.get("retrieve", {}) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 6af313c3..e2a61045 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -9,47 +9,36 @@ async def write_node(state: WriteState) 
-> WriteState: Write data to the database/file system. Args: - state: WriteState containing messages, group_id, and memory_config + content: Data content to write + end_user_id: End user identifier + memory_config: MemoryConfig object containing all configuration Returns: - dict: Contains 'write_result' with status and data fields + dict: Contains 'status', 'saved_to', and 'data' fields """ - messages = state.get('messages', []) - group_id = state.get('group_id', '') - memory_config = state.get('memory_config', '') - - # Convert LangChain messages to structured format expected by write() - structured_messages = [] - for msg in messages: - if hasattr(msg, 'type') and hasattr(msg, 'content'): - # Map LangChain message types to role names - role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type - structured_messages.append({ - "role": role, - "content": msg.content # content is now guaranteed to be a string - }) - + content=state.get('data','') + end_user_id=state.get('end_user_id','') + memory_config=state.get('memory_config', '') try: - result = await write( - messages=structured_messages, - user_id=group_id, - apply_id=group_id, - group_id=group_id, + result=await write( + content=content, + end_user_id=end_user_id, memory_config=memory_config, ) logger.info(f"Write completed successfully! 
Config: {memory_config.config_name}") - write_result = { + write_result= { "status": "success", - "data": structured_messages, + "data": content, "config_id": memory_config.config_id, "config_name": memory_config.config_name, } - return {"write_result": write_result} + return {"write_result":write_result} + except Exception as e: logger.error(f"Data_write failed: {e}", exc_info=True) - write_result = { + write_result= { "status": "error", "message": str(e), } diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 19011a5f..3476d0ec 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -79,7 +79,7 @@ async def make_read_graph(): async def main(): """主函数 - 运行工作流""" message = "昨天有什么好看的电影" - group_id = '88a459f5_text09' # 组ID + end_user_id = '88a459f5_text09' # 组ID storage_type = 'neo4j' # 存储类型 search_switch = '1' # 搜索开关 user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID @@ -95,9 +95,9 @@ async def main(): start=time.time() try: async with make_read_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id + initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} # 获取节点更新信息 _intermediate_outputs = [] diff --git a/api/app/core/memory/agent/langgraph_graph/tools/tool.py b/api/app/core/memory/agent/langgraph_graph/tools/tool.py index ce6d5dd4..c4814de1 100644 --- a/api/app/core/memory/agent/langgraph_graph/tools/tool.py +++ b/api/app/core/memory/agent/langgraph_graph/tools/tool.py @@ -48,11 +48,11 @@ def extract_tool_message_content(response): class 
TimeRetrievalInput(BaseModel): """时间检索工具的输入模式""" context: str = Field(description="用户输入的查询内容") - group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") + end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果") -def create_time_retrieval_tool(group_id: str): +def create_time_retrieval_tool(end_user_id: str): """ - 创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) + 创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements) """ def clean_temporal_result_fields(data): @@ -93,26 +93,26 @@ def create_time_retrieval_tool(group_id: str): return data @tool - def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str: + def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str: """ 优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段 显式接收参数: - context: 查询上下文内容 - start_date: 开始时间(可选,格式:YYYY-MM-DD) - end_date: 结束时间(可选,格式:YYYY-MM-DD) - - group_id_param: 组ID(可选,用于覆盖默认组ID) + - end_user_id_param: 组ID(可选,用于覆盖默认组ID) - clean_output: 是否清理输出中的元数据字段 -end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d") """ async def _async_search(): # 使用传入的参数或默认值 - actual_group_id = group_id_param or group_id + actual_end_user_id = end_user_id_param or end_user_id actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d") actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d") # 基本时间搜索 results = await search_by_temporal( - group_id=actual_group_id, + end_user_id=actual_end_user_id, start_date=actual_start_date, end_date=actual_end_date, limit=10 @@ -147,7 +147,7 @@ def create_time_retrieval_tool(group_id: str): # 关键词时间搜索 results = await search_by_keyword_temporal( query_text=context, - group_id=group_id, + end_user_id=end_user_id, start_date=actual_start_date, end_date=actual_end_date, limit=15 @@ -172,7 +172,7 @@ def 
create_hybrid_retrieval_tool_async(memory_config, **search_params): Args: memory_config: 内存配置对象 - **search_params: 搜索参数,包含group_id, limit, include等 + **search_params: 搜索参数,包含end_user_id, limit, include等 """ def clean_result_fields(data): @@ -211,7 +211,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): context: str, search_type: str = "hybrid", limit: int = 10, - group_id: str = None, + end_user_id: str = None, rerank_alpha: float = 0.6, use_forgetting_rerank: bool = False, use_llm_rerank: bool = False, @@ -224,7 +224,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): context: 查询内容 search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') limit: 结果数量限制 - group_id: 组ID,用于过滤搜索结果 + end_user_id: 组ID,用于过滤搜索结果 rerank_alpha: 重排序权重参数 use_forgetting_rerank: 是否使用遗忘重排序 use_llm_rerank: 是否使用LLM重排序 @@ -238,7 +238,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params): final_params = { "query_text": context, "search_type": search_type, - "group_id": group_id or search_params.get("group_id"), + "end_user_id": end_user_id or search_params.get("end_user_id"), "limit": limit or search_params.get("limit", 10), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]), "output_path": None, # 不保存到文件 @@ -291,7 +291,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): context: str, search_type: str = "hybrid", limit: int = 10, - group_id: str = None, + end_user_id: str = None, clean_output: bool = True ) -> str: """ @@ -301,7 +301,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): context: 查询内容 search_type: 搜索类型 ('keyword', 'embedding', 'hybrid') limit: 结果数量限制 - group_id: 组ID,用于过滤搜索结果 + end_user_id: 组ID,用于过滤搜索结果 clean_output: 是否清理输出中的元数据字段 """ async def _async_search(): @@ -311,7 +311,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params): "context": context, "search_type": search_type, "limit": limit, - "group_id": 
group_id, + "end_user_id": end_user_id, "clean_output": clean_output }) diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index fe281a23..d8fcf210 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -14,6 +14,7 @@ from app.db import get_db from app.core.logging_config import get_agent_logger from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write from app.services.memory_config_service import MemoryConfigService warnings.filterwarnings("ignore", category=RuntimeWarning) @@ -26,12 +27,18 @@ async def make_write_graph(): """ Create a write graph workflow for memory operations. - The workflow directly processes messages from the initial state - and saves them to Neo4j storage. 
+ Args: + user_id: User identifier + tools: MCP tools loaded from session + apply_id: Application identifier + end_user_id: Group identifier + memory_config: MemoryConfig object containing all configuration """ workflow = StateGraph(WriteState) + workflow.add_node("content_input", content_input_write) workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "save_neo4j") + workflow.add_edge(START, "content_input") + workflow.add_edge("content_input", "save_neo4j") workflow.add_edge("save_neo4j", END) graph = workflow.compile() @@ -42,7 +49,7 @@ async def make_write_graph(): async def main(): """主函数 - 运行工作流""" message = "今天周一" - group_id = 'new_2025test1103' # 组ID + end_user_id = 'new_2025test1103' # 组ID # 获取数据库会话 @@ -54,9 +61,9 @@ async def main(): ) try: async with make_write_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config} + initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config} # 获取节点更新信息 async for update_event in graph.astream( diff --git a/api/app/core/memory/agent/services/parameter_builder.py b/api/app/core/memory/agent/services/parameter_builder.py index a58fcf1a..74382ade 100644 --- a/api/app/core/memory/agent/services/parameter_builder.py +++ b/api/app/core/memory/agent/services/parameter_builder.py @@ -24,7 +24,7 @@ class ParameterBuilder: tool_call_id: str, search_switch: str, apply_id: str, - group_id: str, + end_user_id: str, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: @@ -44,7 +44,7 @@ class ParameterBuilder: tool_call_id: Extracted tool call identifier search_switch: Search routing parameter apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier storage_type: Storage 
type for the workspace (optional) user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional) @@ -55,7 +55,7 @@ class ParameterBuilder: base_args = { "usermessages": tool_call_id, "apply_id": apply_id, - "group_id": group_id + "end_user_id": end_user_id } # Always add storage_type and user_rag_memory_id (with defaults if None) diff --git a/api/app/core/memory/agent/services/search_service.py b/api/app/core/memory/agent/services/search_service.py index 8a2e7cfe..4fc4256e 100644 --- a/api/app/core/memory/agent/services/search_service.py +++ b/api/app/core/memory/agent/services/search_service.py @@ -91,7 +91,7 @@ class SearchService: async def execute_hybrid_search( self, - group_id: str, + end_user_id: str, question: str, limit: int = 5, search_type: str = "hybrid", @@ -105,7 +105,7 @@ class SearchService: Execute hybrid search and return clean content. Args: - group_id: Group identifier for filtering results + end_user_id: Group identifier for filtering results question: Search query text limit: Maximum number of results to return (default: 5) search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") @@ -130,7 +130,7 @@ class SearchService: answer = await run_hybrid_search( query_text=cleaned_query, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, output_path=output_path, @@ -186,7 +186,7 @@ class SearchService: except Exception as e: logger.error( - f"Search failed for query '{question}' in group '{group_id}': {e}", + f"Search failed for query '{question}' in group '{end_user_id}': {e}", exc_info=True ) # Return empty results on failure diff --git a/api/app/core/memory/agent/services/session_service.py b/api/app/core/memory/agent/services/session_service.py index b2d4f0ff..f7389984 100644 --- a/api/app/core/memory/agent/services/session_service.py +++ b/api/app/core/memory/agent/services/session_service.py @@ -59,7 +59,7 @@ class SessionService: self, user_id: str, 
apply_id: str, - group_id: str + end_user_id: str ) -> List[dict]: """ Retrieve conversation history from Redis. @@ -67,20 +67,20 @@ class SessionService: Args: user_id: User identifier apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier Returns: List of conversation history items with Query and Answer keys Returns empty list if no history found or on error """ try: - history = self.store.find_user_apply_group(user_id, apply_id, group_id) + history = self.store.find_user_apply_group(user_id, apply_id, end_user_id) # Validate history structure if not isinstance(history, list): logger.warning( f"Invalid history format for user {user_id}, " - f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}" ) return [] @@ -89,7 +89,7 @@ class SessionService: except Exception as e: logger.error( f"Failed to retrieve history for user {user_id}, " - f"apply {apply_id}, group {group_id}: {e}", + f"apply {apply_id}, group {end_user_id}: {e}", exc_info=True ) # Return empty list on error to allow execution to continue @@ -100,7 +100,7 @@ class SessionService: user_id: str, query: str, apply_id: str, - group_id: str, + end_user_id: str, ai_response: str ) -> Optional[str]: """ @@ -110,7 +110,7 @@ class SessionService: user_id: User identifier query: User query/message apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier ai_response: AI response/answer Returns: @@ -131,7 +131,7 @@ class SessionService: userid=user_id, messages=query, apply_id=apply_id, - group_id=group_id, + end_user_id=end_user_id, aimessages=ai_response ) @@ -152,7 +152,7 @@ class SessionService: Duplicates are identified by matching: - sessionid - user_id (id field) - - group_id + - end_user_id - messages - aimessages diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index 
82a41773..4751f18c 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -9,65 +9,56 @@ from app.core.memory.models.message_models import DialogData, ConversationContex async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", - user_id: str = "user1", - apply_id: str = "applyid", - messages: list = None, + end_user_id: str = "group_1", + content: str = "这是用户的输入", ref_id: str = "wyl_20251027", config_id: str = None ) -> List[DialogData]: - """Generate chunks from structured messages using the specified chunker strategy. + """Generate chunks from all test data entries using the specified chunker strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - group_id: Group identifier - user_id: User identifier - apply_id: Application identifier - messages: Structured message list [{"role": "user", "content": "..."}, ...] + end_user_id: End user identifier + content: Dialog content ref_id: Reference identifier config_id: Configuration ID for processing Returns: - List of DialogData objects with generated chunks + List of DialogData objects with generated chunks for each test entry """ - from app.core.logging_config import get_agent_logger - logger = get_agent_logger(__name__) - - if not messages or not isinstance(messages, list) or len(messages) == 0: - raise ValueError("messages parameter must be a non-empty list") - - conversation_messages = [] - - for idx, msg in enumerate(messages): - if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: - raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields") - - role = msg['role'] - content = msg['content'] - - if role not in ['user', 'assistant']: - raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") - - if content.strip(): - conversation_messages.append(ConversationMessage(role=role, msg=content.strip())) 
- - if not conversation_messages: - raise ValueError("Message list cannot be empty after filtering") - - conversation_context = ConversationContext(msgs=conversation_messages) + dialog_data_list = [] + messages = [] + + messages.append(ConversationMessage(role="用户", msg=content)) + + # Create DialogData + conversation_context = ConversationContext(msgs=messages) + # Create DialogData with end_user_id dialog_data = DialogData( context=conversation_context, ref_id=ref_id, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, + end_user_id=end_user_id, config_id=config_id ) - + # Create DialogueChunker and process the dialogue chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) dialog_data.chunks = extracted_chunks - - logger.info(f"DialogData created with {len(extracted_chunks)} chunks") - return [dialog_data] + dialog_data_list.append(dialog_data) + + # Convert to dict with datetime serialized + def serialize_datetime(obj): + if isinstance(obj, datetime): + return obj.isoformat() + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + combined_output = [dd.model_dump() for dd in dialog_data_list] + + print(dialog_data_list) + + # with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f: + # json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime) + + + return dialog_data_list diff --git a/api/app/core/memory/agent/utils/llm_tools.py b/api/app/core/memory/agent/utils/llm_tools.py index 8dd2f1d3..aca7fdd7 100644 --- a/api/app/core/memory/agent/utils/llm_tools.py +++ b/api/app/core/memory/agent/utils/llm_tools.py @@ -12,13 +12,11 @@ class WriteState(TypedDict): Langgrapg Writing TypedDict ''' messages: Annotated[list[AnyMessage], add_messages] - user_id:str - apply_id:str - group_id:str + end_user_id: str errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}] 
memory_config: object write_result: dict - data:str + data: str class ReadState(TypedDict): """ @@ -28,7 +26,7 @@ class ReadState(TypedDict): messages: 消息列表,支持自动追加 loop_count: 遍历次数 search_switch: 搜索类型开关 - group_id: 组标识 + end_user_id: 组标识 config_id: 配置ID,用于过滤结果 data: 从content_input_node传递的内容数据 spit_data: 从Split_The_Problem传递的分解结果 @@ -39,7 +37,7 @@ class ReadState(TypedDict): messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式 loop_count: int search_switch: str - group_id: str + end_user_id: str config_id: str data: str # 新增字段用于传递内容 spit_data: dict # 新增字段用于传递问题分解结果 diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index 31a76a11..505545b3 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -28,7 +28,7 @@ class RedisSessionStore: return text # 修改后的 save_session 方法 - def save_session(self, userid, messages, aimessages, apply_id, group_id): + def save_session(self, userid, messages, aimessages, apply_id, end_user_id): """ 写入一条会话数据,返回 session_id 优化版本:确保写入时间不超过1秒 @@ -46,7 +46,7 @@ class RedisSessionStore: "id": self.uudi, "sessionid": userid, "apply_id": apply_id, - "group_id": group_id, + "end_user_id": end_user_id, "messages": messages, "aimessages": aimessages, "starttime": starttime @@ -67,7 +67,7 @@ class RedisSessionStore: def save_sessions_batch(self, sessions_data): """ 批量写入多条会话数据,返回 session_id 列表 - sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, group_id + sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id 优化版本:批量操作,大幅提升性能 """ try: @@ -83,7 +83,7 @@ class RedisSessionStore: "id": self.uudi, "sessionid": session.get('userid'), "apply_id": session.get('apply_id'), - "group_id": session.get('group_id'), + "end_user_id": session.get('end_user_id'), "messages": session.get('messages'), "aimessages": session.get('aimessages'), "starttime": starttime @@ -108,9 +108,9 @@ 
class RedisSessionStore: data = self.r.hgetall(key) return data if data else None - def get_session_apply_group(self, sessionid, apply_id, group_id): + def get_session_apply_group(self, sessionid, apply_id, end_user_id): """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据 + 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据 """ result_items = [] @@ -124,7 +124,7 @@ class RedisSessionStore: # 检查三个条件是否都匹配 if (data.get('sessionid') == sessionid and data.get('apply_id') == apply_id and - data.get('group_id') == group_id): + data.get('end_user_id') == end_user_id): result_items.append(data) return result_items @@ -172,7 +172,7 @@ class RedisSessionStore: def delete_duplicate_sessions(self): """ 删除重复会话数据,条件: - "sessionid"、"user_id"、"group_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 + "sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除 优化版本:使用 pipeline 批量操作,确保在1秒内完成 """ import time @@ -202,12 +202,12 @@ class RedisSessionStore: # 获取五个字段的值 sessionid = data.get('sessionid', '') user_id = data.get('id', '') - group_id = data.get('group_id', '') + end_user_id = data.get('end_user_id', '') messages = data.get('messages', '') aimessages = data.get('aimessages', '') # 用五元组作为唯一标识 - identifier = (sessionid, user_id, group_id, messages, aimessages) + identifier = (sessionid, user_id, end_user_id, messages, aimessages) if identifier in seen: # 重复,标记为待删除 @@ -248,9 +248,9 @@ class RedisSessionStore: result_items = [] return (result_items) - def find_user_apply_group(self, sessionid, apply_id, group_id): + def find_user_apply_group(self, sessionid, apply_id, end_user_id): """ - 根据 sessionid、apply_id 和 group_id 三个条件查询会话数据,返回最新的6条 + 根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条 """ import time start_time = time.time() @@ -276,7 +276,7 @@ class RedisSessionStore: # 检查是否符合三个条件 if (data.get('apply_id') == apply_id and - data.get('group_id') == group_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配 sessionid 或者完全匹配 if sessionid in 
data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append({ diff --git a/api/app/core/memory/agent/utils/session_tools.py b/api/app/core/memory/agent/utils/session_tools.py index b2d4f0ff..f7389984 100644 --- a/api/app/core/memory/agent/utils/session_tools.py +++ b/api/app/core/memory/agent/utils/session_tools.py @@ -59,7 +59,7 @@ class SessionService: self, user_id: str, apply_id: str, - group_id: str + end_user_id: str ) -> List[dict]: """ Retrieve conversation history from Redis. @@ -67,20 +67,20 @@ class SessionService: Args: user_id: User identifier apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier Returns: List of conversation history items with Query and Answer keys Returns empty list if no history found or on error """ try: - history = self.store.find_user_apply_group(user_id, apply_id, group_id) + history = self.store.find_user_apply_group(user_id, apply_id, end_user_id) # Validate history structure if not isinstance(history, list): logger.warning( f"Invalid history format for user {user_id}, " - f"apply {apply_id}, group {group_id}: expected list, got {type(history)}" + f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}" ) return [] @@ -89,7 +89,7 @@ class SessionService: except Exception as e: logger.error( f"Failed to retrieve history for user {user_id}, " - f"apply {apply_id}, group {group_id}: {e}", + f"apply {apply_id}, group {end_user_id}: {e}", exc_info=True ) # Return empty list on error to allow execution to continue @@ -100,7 +100,7 @@ class SessionService: user_id: str, query: str, apply_id: str, - group_id: str, + end_user_id: str, ai_response: str ) -> Optional[str]: """ @@ -110,7 +110,7 @@ class SessionService: user_id: User identifier query: User query/message apply_id: Application identifier - group_id: Group identifier + end_user_id: Group identifier ai_response: AI response/answer Returns: @@ -131,7 +131,7 @@ class SessionService: 
userid=user_id, messages=query, apply_id=apply_id, - group_id=group_id, + end_user_id=end_user_id, aimessages=ai_response ) @@ -152,7 +152,7 @@ class SessionService: Duplicates are identified by matching: - sessionid - user_id (id field) - - group_id + - end_user_id - messages - aimessages diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 1df0b336..ce55286e 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -29,9 +29,7 @@ logger = get_agent_logger(__name__) async def write( - user_id: str, - apply_id: str, - group_id: str, + end_user_id: str, memory_config: MemoryConfig, messages: list, ref_id: str = "wyl20251027", @@ -40,9 +38,7 @@ async def write( Execute the complete knowledge extraction pipeline. Args: - user_id: User identifier - apply_id: Application identifier - group_id: Group identifier + end_user_id: End user identifier memory_config: MemoryConfig object containing all configuration messages: Structured message list [{"role": "user", "content": "..."}, ...] 
ref_id: Reference ID, defaults to "wyl20251027" @@ -58,7 +54,7 @@ async def write( logger.info(f"LLM model: {memory_config.llm_model_name}") logger.info(f"Embedding model: {memory_config.embedding_model_name}") logger.info(f"Chunker strategy: {chunker_strategy}") - logger.info(f"Group ID: {group_id}") + logger.info(f"End User ID: {end_user_id}") # Construct clients from memory_config using factory pattern with db session with get_db_context() as db: @@ -83,9 +79,7 @@ async def write( step_start = time.time() chunked_dialogs = await get_chunked_dialogs( chunker_strategy=chunker_strategy, - group_id=group_id, - user_id=user_id, - apply_id=apply_id, + end_user_id=end_user_id, messages=messages, ref_id=ref_id, config_id=config_id, diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index cab6cacd..95302726 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -16,13 +16,13 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") -async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: +async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 Args: tags: 原始标签列表 - group_id: 用户组ID,用于获取配置 + end_user_id: 用户组ID,用于获取配置 Returns: 筛选后的标签列表 @@ -37,12 +37,12 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: get_end_user_connected_config, ) - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") if not config_id: raise ValueError( - f"No memory_config_id found for group_id: {group_id}. " + f"No memory_config_id found for end_user_id: {end_user_id}. " "Please ensure the user has a valid memory configuration." 
) @@ -87,7 +87,7 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]: async def get_raw_tags_from_db( connector: Neo4jConnector, - group_id: str, + end_user_id: str, limit: int, by_user: bool = False ) -> List[Tuple[str, int]]: @@ -99,9 +99,9 @@ async def get_raw_tags_from_db( Args: connector: Neo4j连接器实例 - group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id + end_user_id: 如果by_user=False,则为end_user_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 - by_user: 是否按user_id查询(默认False,按group_id查询) + by_user: 是否按user_id查询(默认False,按end_user_id查询) Returns: List[Tuple[str, int]]: 标签名称和频率的元组列表 @@ -119,7 +119,7 @@ async def get_raw_tags_from_db( else: query = ( "MATCH (e:ExtractedEntity) " - "WHERE e.group_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " + "WHERE e.end_user_id = $id AND e.entity_type <> '人物' AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC " "LIMIT $limit" @@ -128,44 +128,44 @@ async def get_raw_tags_from_db( # 使用项目的Neo4jConnector执行查询 results = await connector.execute_query( query, - id=group_id, + id=end_user_id, limit=limit, names_to_exclude=names_to_exclude ) return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 Args: - group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id limit: 返回的标签数量限制 - by_user: 是否按user_id查询(默认False,按group_id查询) + by_user: 是否按user_id查询(默认False,按end_user_id查询) Raises: - ValueError: 如果group_id未提供或为空 + ValueError: 如果end_user_id未提供或为空 """ - # 验证group_id必须提供且不为空 - if not group_id or 
not group_id.strip(): + # 验证end_user_id必须提供且不为空 + if not end_user_id or not end_user_id.strip(): raise ValueError( - "group_id is required. Please provide a valid group_id or user_id." + "end_user_id is required. Please provide a valid end_user_id or user_id." ) # 使用项目的Neo4jConnector connector = Neo4jConnector() try: # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user) + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user) if not raw_tags_with_freq: return [] raw_tag_names = [tag for tag, freq in raw_tags_with_freq] # 2. 初始化LLM客户端并使用LLM筛选出有意义的标签 - meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id) + meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, end_user_id) # 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序) final_tags = [] diff --git a/api/app/core/memory/analytics/implicit_memory/data_source.py b/api/app/core/memory/analytics/implicit_memory/data_source.py index d277a05e..18678a55 100644 --- a/api/app/core/memory/analytics/implicit_memory/data_source.py +++ b/api/app/core/memory/analytics/implicit_memory/data_source.py @@ -75,8 +75,8 @@ class MemoryDataSource: start_date = time_range.start_date if time_range else None end_date = time_range.end_date if time_range else None - summary_dicts = await self.memory_summary_repo.find_by_group_id( - group_id=user_id, + summary_dicts = await self.memory_summary_repo.find_by_end_user_id( + end_user_id=user_id, limit=limit, start_date=start_date, end_date=end_date diff --git a/api/app/core/memory/evaluation/dialogue_queries.py b/api/app/core/memory/evaluation/dialogue_queries.py index fd7fa671..25abe64e 100644 --- a/api/app/core/memory/evaluation/dialogue_queries.py +++ b/api/app/core/memory/evaluation/dialogue_queries.py @@ -41,7 +41,7 @@ DIALOGUE_EMBEDDING_SEARCH = """ WITH $embedding AS q MATCH (d:Dialogue) WHERE d.dialog_embedding IS NOT NULL - AND ($group_id IS NULL OR d.group_id = 
$group_id) + AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id) WITH d, q, d.dialog_embedding AS v WITH d, reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, @@ -50,7 +50,7 @@ WITH d, WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score WHERE score > $threshold RETURN d.id AS dialog_id, - d.group_id AS group_id, + d.end_user_id AS end_user_id, d.content AS content, d.created_at AS created_at, d.expired_at AS expired_at, diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py index 9afa228c..9e70bc28 100644 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ b/api/app/core/memory/evaluation/extraction_utils.py @@ -36,7 +36,7 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector async def ingest_contexts_via_full_pipeline( contexts: List[str], - group_id: str, + end_user_id: str, chunker_strategy: str | None = None, embedding_name: str | None = None, save_chunk_output: bool = False, @@ -48,7 +48,7 @@ async def ingest_contexts_via_full_pipeline( This function mirrors the steps in main(), but starts from raw text contexts. Args: contexts: List of dialogue texts, each containing lines like "role: message". - group_id: Group ID to assign to generated DialogData and graph nodes. + end_user_id: Group ID to assign to generated DialogData and graph nodes. chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. 
@@ -109,7 +109,7 @@ async def ingest_contexts_via_full_pipeline( dialog = DialogData( context=context_model, ref_id=f"pipeline_item_{idx}", - group_id=group_id, + end_user_id=end_user_id, user_id="default_user", apply_id="default_application", ) @@ -318,16 +318,16 @@ async def handle_context_processing(args): print("No contexts provided for processing.") return False - return await main_from_contexts(contexts, args.context_group_id) + return await main_from_contexts(contexts, args.context_end_user_id) -async def main_from_contexts(contexts: List[str], group_id: str): +async def main_from_contexts(contexts: List[str], end_user_id: str): """Run the pipeline from provided dialogue contexts instead of test data.""" print("=== Running pipeline from provided contexts ===") success = await ingest_contexts_via_full_pipeline( contexts=contexts, - group_id=group_id, + end_user_id=end_user_id, chunker_strategy=SELECTED_CHUNKER_STRATEGY, embedding_name=SELECTED_EMBEDDING_ID, save_chunk_output=True diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py index b7d988c5..1c70c28e 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py @@ -47,7 +47,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.utils.definitions import ( PROJECT_ROOT, SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, + SELECTED_end_user_id, SELECTED_LLM_ID, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -59,7 +59,7 @@ from app.services.memory_config_service import MemoryConfigService async def run_locomo_benchmark( sample_size: int = 20, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, search_type: str = "hybrid", search_limit: int = 12, context_char_budget: int = 8000, @@ -85,7 +85,7 @@ async def run_locomo_benchmark( Args: sample_size: Number of QA pairs to evaluate 
(from first conversation) - group_id: Database group ID for retrieval (uses default if None) + end_user_id: Database group ID for retrieval (uses default if None) search_type: "keyword", "embedding", or "hybrid" search_limit: Max documents to retrieve per query context_char_budget: Max characters for context @@ -96,8 +96,8 @@ async def run_locomo_benchmark( Returns: Dictionary with evaluation results including metrics, timing, and samples """ - # Use default group_id if not provided - group_id = group_id or SELECTED_GROUP_ID + # Use default end_user_id if not provided + end_user_id = end_user_id or SELECTED_end_user_id # Determine data path data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") @@ -110,7 +110,7 @@ async def run_locomo_benchmark( print(f"{'='*60}") print("📊 Configuration:") print(f" Sample size: {sample_size}") - print(f" Group ID: {group_id}") + print(f" Group ID: {end_user_id}") print(f" Search type: {search_type}") print(f" Search limit: {search_limit}") print(f" Context budget: {context_char_budget} chars") @@ -134,7 +134,7 @@ async def run_locomo_benchmark( # Step 2: Extract conversations and ingest if needed if skip_ingest: print("⏭️ Skipping data ingestion (using existing data in Neo4j)") - print(f" Group ID: {group_id}\n") + print(f" Group ID: {end_user_id}\n") else: print("💾 Checking database ingestion...") try: @@ -142,10 +142,10 @@ async def run_locomo_benchmark( print(f"📝 Extracted {len(conversations)} conversations") # Always ingest for now (ingestion check not implemented) - print(f"🔄 Ingesting conversations into group '{group_id}'...") + print(f"🔄 Ingesting conversations into group '{end_user_id}'...") success = await ingest_conversations_if_needed( conversations=conversations, - group_id=group_id, + end_user_id=end_user_id, reset=reset_group ) @@ -224,7 +224,7 @@ async def run_locomo_benchmark( try: retrieved_info = await retrieve_relevant_information( question=question, - group_id=group_id, + end_user_id=end_user_id, 
search_type=search_type, search_limit=search_limit, connector=connector, @@ -409,7 +409,7 @@ async def run_locomo_benchmark( "sample_size": len(qa_items), "timestamp": datetime.now().isoformat(), "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_type": search_type, "search_limit": search_limit, "context_char_budget": context_char_budget, @@ -467,7 +467,7 @@ def main(): help="Number of QA pairs to evaluate" ) parser.add_argument( - "--group_id", + "--end_user_id", type=str, default=None, help="Database group ID for retrieval (uses default if not specified)" @@ -516,7 +516,7 @@ def main(): # Run benchmark result = asyncio.run(run_locomo_benchmark( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_type=args.search_type, search_limit=args.search_limit, context_char_budget=args.context_char_budget, diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py index b5ad5820..b871fb9c 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ b/api/app/core/memory/evaluation/locomo/locomo_test.py @@ -555,7 +555,7 @@ async def run_enhanced_evaluation(): search_results = await run_hybrid_search( query_text=q, search_type="hybrid", - group_id="locomo_sk", + end_user_id="locomo_sk", limit=20, include=["statements", "chunks", "entities", "summaries"], alpha=0.6, # BM25权重 diff --git a/api/app/core/memory/evaluation/locomo/locomo_utils.py b/api/app/core/memory/evaluation/locomo/locomo_utils.py index 69be5da9..d3b74947 100644 --- a/api/app/core/memory/evaluation/locomo/locomo_utils.py +++ b/api/app/core/memory/evaluation/locomo/locomo_utils.py @@ -348,7 +348,7 @@ def select_and_format_information( async def retrieve_relevant_information( question: str, - group_id: str, + end_user_id: str, search_type: str, search_limit: int, connector: Any, @@ -368,7 +368,7 @@ async def retrieve_relevant_information( Args: question: Question to search for - 
group_id: Database group ID (identifies which conversation memory to search) + end_user_id: Database group ID (identifies which conversation memory to search) search_type: "keyword", "embedding", or "hybrid" search_limit: Max memory pieces to retrieve connector: Neo4j connector instance @@ -396,7 +396,7 @@ async def retrieve_relevant_information( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -455,7 +455,7 @@ async def retrieve_relevant_information( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit ) @@ -491,7 +491,7 @@ async def retrieve_relevant_information( search_results = await run_hybrid_search( query_text=question, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], output_path=None, @@ -524,7 +524,7 @@ async def retrieve_relevant_information( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -584,7 +584,7 @@ async def retrieve_relevant_information( async def ingest_conversations_if_needed( conversations: List[str], - group_id: str, + end_user_id: str, reset: bool = False ) -> bool: """ @@ -603,7 +603,7 @@ async def ingest_conversations_if_needed( Args: conversations: List of raw conversation texts from LoCoMo dataset Example: ["User: I went to Paris. AI: When was that?", ...] 
- group_id: Target group ID for database storage + end_user_id: Target group ID for database storage reset: Whether to clear existing data first (not implemented in wrapper) Returns: @@ -617,7 +617,7 @@ async def ingest_conversations_if_needed( try: success = await ingest_contexts_via_full_pipeline( contexts=conversations, - group_id=group_id, + end_user_id=end_user_id, save_chunk_output=True ) return success diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py index 87a70a29..3147e880 100644 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py @@ -30,7 +30,7 @@ from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.utils.config.definitions import ( PROJECT_ROOT, SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, + SELECTED_end_user_id, SELECTED_LLM_ID, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -249,7 +249,7 @@ def get_search_params_by_category(category: str): async def run_locomo_eval( sample_size: int = 1, - group_id: str | None = None, + end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, # 保持默认值不变 llm_temperature: float = 0.0, @@ -262,7 +262,7 @@ async def run_locomo_eval( ) -> Dict[str, Any]: # 函数内部使用三路检索逻辑,但保持参数签名不变 - group_id = group_id or SELECTED_GROUP_ID + end_user_id = end_user_id or SELECTED_end_user_id data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") if not os.path.exists(data_path): data_path = os.path.join(os.getcwd(), "data", "locomo10.json") @@ -340,7 +340,7 @@ async def run_locomo_eval( # 关键修复:强制重新摄入纯净的对话数据 print("🔄 强制重新摄入纯净的对话数据...") - await ingest_contexts_via_full_pipeline(contents, group_id, save_chunk_output=True) + await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) # 使用异步LLM客户端 with get_db_context() as db: @@ -405,7 +405,7 @@ async def 
run_locomo_eval( connector=connector, embedder_client=embedder, query_text=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型 ) @@ -456,7 +456,7 @@ async def run_locomo_eval( search_results = await search_graph( connector=connector, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit ) dialogs = search_results.get("dialogues", []) @@ -486,7 +486,7 @@ async def run_locomo_eval( search_results = await run_hybrid_search( query_text=q, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], output_path=None, @@ -524,7 +524,7 @@ async def run_locomo_eval( connector=connector, embedder_client=embedder, query_text=q, - group_id=group_id, + end_user_id=end_user_id, limit=adjusted_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -597,7 +597,7 @@ async def run_locomo_eval( "dialogues": [ { "uuid": d.get("uuid", ""), - "group_id": d.get("group_id", ""), + "end_user_id": d.get("end_user_id", ""), "content": d.get("content", "")[:200] + "..." 
if len(d.get("content", "")) > 200 else d.get("content", ""), "score": d.get("score", 0.0) } @@ -795,7 +795,7 @@ async def run_locomo_eval( }, "samples": samples, "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "search_type": search_type, @@ -825,7 +825,7 @@ async def run_locomo_eval( def main(): parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") - parser.add_argument("--group_id", type=str, default=None, help="Group ID for retrieval") + parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval") parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") @@ -841,7 +841,7 @@ def main(): result = asyncio.run(run_locomo_eval( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py index 53c5ce19..320f9de7 100644 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py @@ -523,11 +523,11 @@ def generate_query_keywords_cn(question: str) -> List[str]: # 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: +async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: 
int) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] try: for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) + rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) if rows: results.extend(rows) except Exception: @@ -547,15 +547,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st # 通过对话/陈述中的entity_ids反查实体名称 _FETCH_ENTITIES_BY_IDS = """ MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type """ -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: +async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: if not ids: return [] try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) + rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) return rows or [] except Exception: return [] @@ -565,18 +565,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou _TIME_ENTITY_SEARCH = """ MATCH (e:ExtractedEntity) WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type LIMIT $limit """ 
-async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: +async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: """专门搜索时间相关的实体""" try: date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" rows = await connector.execute_query(_TIME_ENTITY_SEARCH, date_pattern=date_pattern, - group_id=group_id, + end_user_id=end_user_id, limit=limit) return rows or [] except Exception: @@ -623,7 +623,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: async def run_longmemeval_test( sample_size: int = 3, - group_id: str = "longmemeval_zh_bak_3", + end_user_id: str = "longmemeval_zh_bak_3", search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -677,13 +677,13 @@ async def run_longmemeval_test( contexts.extend(selected) print(f"📥 摄入 {len(contexts)} 个上下文到数据库") - if reset_group_before_ingest and group_id: + if reset_group_before_ingest and end_user_id: try: _tmp_conn = Neo4jConnector() - await _tmp_conn.delete_group(group_id) - print(f"🧹 已清空组 {group_id} 的历史图数据") + await _tmp_conn.delete_group(end_user_id) + print(f"🧹 已清空组 {end_user_id} 的历史图数据") except Exception as _e: - print(f"⚠️ 清空组数据失败(忽略继续): {group_id} - {_e}") + print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}") finally: try: await _tmp_conn.close() @@ -695,7 +695,7 @@ async def run_longmemeval_test( else: await _ingest_fn( contexts, - group_id, + end_user_id, save_chunk_output=save_chunk_output, save_chunk_output_path=save_chunk_output_path, ) @@ -750,7 +750,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -795,7 +795,7 @@ async def run_longmemeval_test( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + 
end_user_id=end_user_id, limit=search_limit, ) chunks = search_results.get("chunks", []) @@ -830,7 +830,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], ) @@ -848,7 +848,7 @@ async def run_longmemeval_test( kw_res = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) if isinstance(kw_res, dict): @@ -859,7 +859,7 @@ async def run_longmemeval_test( # 时间推理问题的特殊处理 if is_temporal: # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, group_id, search_limit//2) + time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) if time_entities: kw_entities.extend(time_entities) # 添加时间相关关键词检索 @@ -869,7 +869,7 @@ async def run_longmemeval_test( time_res = await search_graph( connector=connector, q=tk, - group_id=group_id, + end_user_id=end_user_id, limit=2, ) if isinstance(time_res, dict): @@ -880,7 +880,7 @@ async def run_longmemeval_test( # 中文关键词拆分后做别名匹配 cn_tokens = _extract_cn_tokens(question) - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) + alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) if alias_entities: kw_entities.extend(alias_entities) @@ -894,7 +894,7 @@ async def run_longmemeval_test( except Exception: pass if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, group_id) + id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) if id_entities: kw_entities.extend(id_entities) @@ -908,7 +908,7 @@ async def run_longmemeval_test( sub_res = await search_graph( connector=connector, q=str(kw), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(sub_res, dict): @@ -927,7 +927,7 @@ async def run_longmemeval_test( 
opt_res = await search_graph( connector=connector, q=str(opt), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(opt_res, dict): diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py index 08a763e3..a49d48d0 100644 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ b/api/app/core/memory/evaluation/longmemeval/test_eval.py @@ -498,11 +498,11 @@ def smart_context_selection(contexts: List[str], question: str, max_chars: int = # 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], group_id: str | None, limit: int) -> List[Dict[str, Any]]: +async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: results: List[Dict[str, Any]] = [] try: for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, group_id=group_id, limit=limit) + rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) if rows: results.extend(rows) except Exception: @@ -522,15 +522,15 @@ async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[st # 通过对话/陈述中的entity_ids反查实体名称 _FETCH_ENTITIES_BY_IDS = """ MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type """ -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], group_id: str | None) -> List[Dict[str, Any]]: +async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: if not ids: return [] 
try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), group_id=group_id) + rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) return rows or [] except Exception: return [] @@ -540,18 +540,18 @@ async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], grou _TIME_ENTITY_SEARCH = """ MATCH (e:ExtractedEntity) WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($group_id IS NULL OR e.group_id = $group_id) -RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type +AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) +RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type LIMIT $limit """ -async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: +async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: """专门搜索时间相关的实体""" try: date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" rows = await connector.execute_query(_TIME_ENTITY_SEARCH, date_pattern=date_pattern, - group_id=group_id, + end_user_id=end_user_id, limit=limit) return rows or [] except Exception: @@ -559,25 +559,25 @@ async def _search_time_entities(connector: Neo4jConnector, group_id: str | None, # 技术术语专门检索 -async def _search_tech_terms(connector: Neo4jConnector, question: str, group_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: +async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: """专门搜索技术术语相关的实体""" tech_entities = [] try: # GPS相关 if any(term in question for term in ["GPS", "导航", "定位系统"]): - gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", group_id=group_id, limit=limit) + gps_rows = await 
connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit) if gps_rows: tech_entities.extend(gps_rows) # 活动相关 if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]): - workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", group_id=group_id, limit=limit) + workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit) if workshop_rows: tech_entities.extend(workshop_rows) # 时间顺序相关 if any(term in question for term in ["先", "后", "第一个"]): - time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", group_id=group_id, limit=limit) + time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit) if time_rows: tech_entities.extend(time_rows) @@ -627,7 +627,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: async def run_longmemeval_test( sample_size: int = 3, - group_id: str = "longmemeval_zh_bak_2", + end_user_id: str = "longmemeval_zh_bak_2", search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -707,7 +707,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], ) @@ -746,7 +746,7 @@ async def run_longmemeval_test( search_results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) dialogs = search_results.get("dialogues", []) @@ -776,7 +776,7 @@ async def run_longmemeval_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], ) @@ -792,7 +792,7 @@ async def run_longmemeval_test( kw_res = await search_graph( connector=connector, q=question, - 
group_id=group_id, + end_user_id=end_user_id, limit=search_limit, ) if isinstance(kw_res, dict): @@ -801,14 +801,14 @@ async def run_longmemeval_test( kw_entities = kw_res.get("entities", []) or [] # 技术术语专门检索 - tech_entities = await _search_tech_terms(connector, question, group_id, search_limit//2) + tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2) if tech_entities: kw_entities.extend(tech_entities) # 时间推理问题的特殊处理 if is_temporal: # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, group_id, search_limit//2) + time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) if time_entities: kw_entities.extend(time_entities) # 添加时间相关关键词检索 @@ -818,7 +818,7 @@ async def run_longmemeval_test( time_res = await search_graph( connector=connector, q=tk, - group_id=group_id, + end_user_id=end_user_id, limit=2, ) if isinstance(time_res, dict): @@ -829,7 +829,7 @@ async def run_longmemeval_test( # 中文关键词拆分后做别名匹配 cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, group_id, search_limit) + alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) if alias_entities: kw_entities.extend(alias_entities) @@ -843,7 +843,7 @@ async def run_longmemeval_test( except Exception: pass if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, group_id) + id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) if id_entities: kw_entities.extend(id_entities) @@ -857,7 +857,7 @@ async def run_longmemeval_test( sub_res = await search_graph( connector=connector, q=str(kw), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit // 2), ) if isinstance(sub_res, dict): @@ -876,7 +876,7 @@ async def run_longmemeval_test( opt_res = await search_graph( connector=connector, q=str(opt), - group_id=group_id, + end_user_id=end_user_id, limit=max(3, search_limit
// 2), ) if isinstance(opt_res, dict): diff --git a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py index 6efb66ff..ec147f3c 100644 --- a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py +++ b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py @@ -27,7 +27,7 @@ from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.utils.config.definitions import ( PROJECT_ROOT, SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, + SELECTED_end_user_id, SELECTED_LLM_ID, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -135,8 +135,8 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any return merged -async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: - group_id = group_id or SELECTED_GROUP_ID +async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: + end_user_id = end_user_id or SELECTED_end_user_id # Load data data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") if not os.path.exists(data_path): @@ -147,7 +147,7 @@ async def run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 contexts: List[str] = [build_context_from_dialog(item) for item in items] - await ingest_contexts_via_full_pipeline(contexts, group_id) + await ingest_contexts_via_full_pipeline(contexts, end_user_id) # LLM client (使用异步调用) with get_db_context() as db: @@ -173,7 +173,7 @@ async def 
run_memsciqa_eval(sample_size: int = 1, group_id: str | None = None, s results = await run_hybrid_search( query_text=question, search_type=search_type, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["dialogues", "statements", "entities"], output_path=None, @@ -298,7 +298,7 @@ def main(): load_dotenv() parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") - parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json") + parser.add_argument("--group-id", dest="end_user_id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") @@ -309,7 +309,7 @@ def main(): result = asyncio.run( run_memsciqa_eval( sample_size=args.sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py index 279f4042..631035aa 100644 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py @@ -33,7 +33,7 @@ from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.utils.config.definitions import ( PROJECT_ROOT, SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, + SELECTED_end_user_id, SELECTED_LLM_ID, ) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -198,7 +198,7 @@ def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]: async def run_memsciqa_test( sample_size: int = 3, - group_id: str | None = None, +
end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, @@ -216,7 +216,7 @@ async def run_memsciqa_test( """ # 默认使用指定的 memsci 组 ID - group_id = group_id or "group_memsci" + end_user_id = end_user_id or "group_memsci" # 数据路径解析(项目根与当前工作目录兜底) if not data_path: @@ -282,7 +282,7 @@ async def run_memsciqa_test( connector=connector, embedder_client=embedder, query_text=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues ) @@ -291,7 +291,7 @@ async def run_memsciqa_test( results = await search_graph( connector=connector, q=question, - group_id=group_id, + end_user_id=end_user_id, limit=search_limit, include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues ) @@ -499,7 +499,7 @@ async def run_memsciqa_test( }, "samples": samples, "params": { - "group_id": group_id, + "end_user_id": end_user_id, "search_limit": search_limit, "context_char_budget": context_char_budget, "llm_temperature": llm_temperature, @@ -542,7 +542,7 @@ def main(): result = asyncio.run( run_memsciqa_test( sample_size=sample_size, - group_id=args.group_id, + end_user_id=args.end_user_id, search_limit=args.search_limit, context_char_budget=args.context_char_budget, llm_temperature=args.llm_temperature, diff --git a/api/app/core/memory/evaluation/run_eval.py b/api/app/core/memory/evaluation/run_eval.py index 1de3de89..f665bdb8 100644 --- a/api/app/core/memory/evaluation/run_eval.py +++ b/api/app/core/memory/evaluation/run_eval.py @@ -15,7 +15,7 @@ except Exception: return None from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT +from app.core.memory.utils.config.definitions import SELECTED_end_user_id, PROJECT_ROOT from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval from 
app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test @@ -26,7 +26,7 @@ async def run( dataset: str, sample_size: int, reset_group: bool, - group_id: str | None, + end_user_id: str | None, judge_model: str | None = None, search_limit: int | None = None, context_char_budget: int | None = None, @@ -37,17 +37,17 @@ async def run( max_contexts_per_item: int | None = None, ) -> Dict[str, Any]: # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 - group_id = group_id or SELECTED_GROUP_ID + end_user_id = end_user_id or SELECTED_end_user_id if reset_group: connector = Neo4jConnector() try: - await connector.delete_group(group_id) + await connector.delete_group(end_user_id) finally: await connector.close() if dataset == "locomo": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -61,7 +61,7 @@ async def run( return await run_locomo_eval(**kwargs) if dataset == "memsciqa": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -75,7 +75,7 @@ async def run( return await run_memsciqa_eval(**kwargs) if dataset == "longmemeval": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "group_id": group_id} + kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} if search_limit is not None: kwargs["search_limit"] = search_limit if context_char_budget is not None: @@ -99,8 +99,8 @@ def main(): parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo") parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) parser.add_argument("--sample-size", type=int, 
default=1, help="先用一条数据跑通") - parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 group_id 的图数据") - parser.add_argument("--group-id", type=str, default=None, help="可选 group_id,默认取 runtime.json") + parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据") + parser.add_argument("--group-id", dest="end_user_id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名") parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") @@ -117,7 +117,7 @@ def main(): args.dataset, args.sample_size, args.reset_group, - args.group_id, + args.end_user_id, args.judge_model, args.search_limit, args.context_char_budget, diff --git a/api/app/core/memory/models/config_models.py b/api/app/core/memory/models/config_models.py index f3341cc5..ca1780aa 100644 --- a/api/app/core/memory/models/config_models.py +++ b/api/app/core/memory/models/config_models.py @@ -72,7 +72,7 @@ class TemporalSearchParams(BaseModel): """Parameters for temporal search queries in the knowledge graph.
Attributes: - group_id: Group ID to filter search results (default: 'test') + end_user_id: Group ID to filter search results (default: 'test') apply_id: Application ID to filter search results user_id: User ID to filter search results start_date: Start date for temporal filtering (format: 'YYYY-MM-DD') @@ -81,7 +81,7 @@ class TemporalSearchParams(BaseModel): invalid_date: Date when memory should be invalid (format: 'YYYY-MM-DD') limit: Maximum number of results to return (default: 3) """ - group_id: Optional[str] = Field("test", description="The group ID to filter the search.") + end_user_id: Optional[str] = Field("test", description="The group ID to filter the search.") apply_id: Optional[str] = Field(None, description="The apply ID to filter the search.") user_id: Optional[str] = Field(None, description="The user ID to filter the search.") start_date: Optional[str] = Field(None, description="The start date for the search.") diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 7a48d6cb..79b88fdc 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -103,9 +103,7 @@ class Edge(BaseModel): id: Unique identifier for the edge source: ID of the source node target: ID of the target node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy run_id: Unique identifier for the pipeline run that created this edge created_at: Timestamp when the edge was created (system perspective) expired_at: Optional timestamp when the edge expires (system perspective) @@ -113,9 +111,7 @@ class Edge(BaseModel): id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the edge.") source: str = Field(..., description="The ID of the source node.") target: str = Field(..., description="The ID of the target node.") - group_id: str 
= Field(..., description="The group ID of the edge.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") + end_user_id: str = Field(..., description="The end user ID of the edge.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.") @@ -185,18 +181,14 @@ class Node(BaseModel): Attributes: id: Unique identifier for the node name: Name of the node - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy run_id: Unique identifier for the pipeline run that created this node created_at: Timestamp when the node was created (system perspective) expired_at: Optional timestamp when the node expires (system perspective) """ id: str = Field(..., description="The unique identifier for the node.") name: str = Field(..., description="The name of the node.") - group_id: str = Field(..., description="The group ID of the node.") - user_id: str = Field(..., description="The user ID of the edge.") - apply_id: str = Field(..., description="The apply ID of the edge.") + end_user_id: str = Field(..., description="The end user ID of the node.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(..., description="The valid time of the node from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the node from system perspective.") diff --git a/api/app/core/memory/models/message_models.py b/api/app/core/memory/models/message_models.py index bcf08999..c660d841 100644 --- 
a/api/app/core/memory/models/message_models.py +++ b/api/app/core/memory/models/message_models.py @@ -55,7 +55,7 @@ class Statement(BaseModel): Attributes: id: Unique identifier for the statement chunk_id: ID of the parent chunk this statement belongs to - group_id: Optional group ID for multi-tenancy + end_user_id: Optional group ID for multi-tenancy statement: The actual statement text content speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses) statement_embedding: Optional embedding vector for the statement @@ -73,7 +73,7 @@ class Statement(BaseModel): """ id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the statement.") chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.") - group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") + end_user_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.") statement: str = Field(..., description="The text content of the statement.") speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses") statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.") @@ -159,9 +159,7 @@ class DialogData(BaseModel): context: Full conversation context dialog_embedding: Optional embedding vector for the entire dialog ref_id: Reference ID linking to external dialog system - group_id: Group ID for multi-tenancy - user_id: User ID for user-specific data - apply_id: Application ID for application-specific data + end_user_id: End user ID for multi-tenancy created_at: Timestamp when the dialog was created expired_at: Timestamp when the dialog expires (default: far future) metadata: Additional metadata as key-value pairs @@ -175,9 +173,7 @@ class DialogData(BaseModel): context: ConversationContext = Field(..., description="The full conversation context as a 
single string.") dialog_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the dialog.") ref_id: str = Field(..., description="Refer to external dialog id. This is used to link to the original dialog.") - group_id: str = Field(default=..., description="Group ID of dialogue data") - user_id: str = Field(..., description="USER ID of dialogue data") - apply_id: str = Field(..., description="APPLY ID of dialogue data") + end_user_id: str = Field(default=..., description="End user ID of dialogue data") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") created_at: datetime = Field(default_factory=datetime.now, description="The timestamp when the dialog was created.") expired_at: datetime = Field(default_factory=lambda: datetime(9999, 12, 31), description="The timestamp when the dialog expires.") @@ -256,5 +252,5 @@ class DialogData(BaseModel): """ for chunk in self.chunks: for statement in chunk.statements: - if statement.group_id is None: - statement.group_id = self.group_id + if statement.end_user_id is None: + statement.end_user_id = self.end_user_id diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 91e47eae..345cd69b 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -6,6 +6,7 @@ import os import time from datetime import datetime from typing import TYPE_CHECKING, Any, Dict, List, Optional +from uuid import UUID if TYPE_CHECKING: from app.schemas.memory_config_schema import MemoryConfig @@ -396,13 +397,13 @@ def rerank_with_activation( return reranked -def log_search_query(query_text: str, search_type: str, group_id: str | None, limit: int, include: List[str], log_file: str = None): +def log_search_query(query_text: str, search_type: str, end_user_id: str | None, limit: int, include: List[str], log_file: str = None): """Log search query information using the logger. 
Args: query_text: The search query text search_type: Type of search (keyword, embedding, hybrid) - group_id: Group identifier for filtering + end_user_id: Group identifier for filtering limit: Maximum number of results include: List of result types to include log_file: Deprecated parameter, kept for backward compatibility @@ -413,7 +414,7 @@ def log_search_query(query_text: str, search_type: str, group_id: str | None, li # Log using the standard logger logger.info( f"Search query: query='{cleaned_query}', type={search_type}, " - f"group_id={group_id}, limit={limit}, include={include}" + f"end_user_id={end_user_id}, limit={limit}, include={include}" ) @@ -672,7 +673,7 @@ def apply_reranker_placeholder( async def run_hybrid_search( query_text: str, search_type: str, - group_id: str | None, + end_user_id: str | None, limit: int, include: List[str], output_path: str | None, @@ -692,6 +693,9 @@ async def run_hybrid_search( # Start overall timing search_start_time = time.time() latency_metrics = {} + print(100*'-') + print(memory_config) + print(100 * '-') logger.info(f"using embedding_id:{memory_config.embedding_model_id}...") # Clean and normalize the incoming query before use/logging @@ -715,7 +719,7 @@ async def run_hybrid_search( } # Log the search query - log_search_query(query_text, search_type, group_id, limit, include) + log_search_query(query_text, search_type, end_user_id, limit, include) connector = Neo4jConnector() results = {} @@ -732,7 +736,7 @@ async def run_hybrid_search( search_graph( connector=connector, q=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include ) @@ -769,7 +773,7 @@ async def run_hybrid_search( connector=connector, embedder_client=embedder, query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, ) @@ -916,9 +920,7 @@ async def run_hybrid_search( async def search_by_temporal( - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: 
Optional[str] = None, + end_user_id: Optional[str] = "test", start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -929,7 +931,7 @@ async def search_by_temporal( Temporal search across Statements. - Matches statements created between start_date and end_date - - Optionally filters by group_id + - Optionally filters by end_user_id - Returns up to 'limit' statements """ connector = Neo4jConnector() @@ -939,9 +941,7 @@ async def search_by_temporal( end_date = normalize_date_safe(end_date) params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, + "end_user_id": end_user_id, "start_date": start_date, "end_date": end_date, "valid_date": valid_date, @@ -950,9 +950,7 @@ async def search_by_temporal( }) statements = await search_graph_by_temporal( connector=connector, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, + end_user_id=params.end_user_id, start_date=params.start_date, end_date=params.end_date, valid_date=params.valid_date, @@ -964,9 +962,7 @@ async def search_by_temporal( async def search_by_keyword_temporal( query_text: str, - group_id: Optional[str] = "test", - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = "test", start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -987,9 +983,7 @@ async def search_by_keyword_temporal( invalid_date = normalize_date_safe(invalid_date) params = TemporalSearchParams.model_validate({ - "group_id": group_id, - "apply_id": apply_id, - "user_id": user_id, + "end_user_id": end_user_id, "start_date": start_date, "end_date": end_date, "valid_date": valid_date, @@ -999,9 +993,7 @@ async def search_by_keyword_temporal( statements = await search_graph_by_keyword_temporal( connector=connector, query_text=query_text, - group_id=params.group_id, - apply_id=params.apply_id, - user_id=params.user_id, + 
end_user_id=params.end_user_id, start_date=params.start_date, end_date=params.end_date, valid_date=params.valid_date, @@ -1013,7 +1005,7 @@ async def search_by_keyword_temporal( async def search_chunk_by_chunk_id( chunk_id: str, - group_id: Optional[str] = "test", + end_user_id: Optional[str] = "test", limit: int = 1, ): """ @@ -1023,8 +1015,68 @@ async def search_chunk_by_chunk_id( chunks = await search_graph_by_chunk_id( connector=connector, chunk_id=chunk_id, - group_id=group_id, + end_user_id=end_user_id, limit=limit ) return {"chunks": chunks} +if __name__ == '__main__': + # 测试混合检索功能 + from app.schemas.memory_config_schema import MemoryConfig + from app.db import get_db + from app.services.memory_config_service import MemoryConfigService + + # 从数据库获取真实配置 + db = next(get_db()) + try: + config_service = MemoryConfigService(db) + + # 使用 config_id=17 获取配置 + memory_config = config_service.load_memory_config(config_id=17) + + if not memory_config: + print("错误:找不到 config_id=17 的配置") + print("请先在数据库中创建配置,或修改 config_id") + exit(1) + + print(f"✓ 成功加载配置: {memory_config.config_name}") + print(f" - Workspace: {memory_config.workspace_name}") + print(f" - LLM Model: {memory_config.llm_model_name}") + print(f" - Embedding Model: {memory_config.embedding_model_name}") + print(f" - Storage Type: {memory_config.storage_type}") + print() + + # 修改这里的参数进行测试 + test_end_user_id = "021886bc-fab9-4fd5-b607-497b262e0381" # 修改为你的 end_user_id + test_query = "小明擅长什么?" 
# 修改为你的查询 + + print(f"开始测试检索...") + print(f" - Query: {test_query}") + print(f" - End User ID: {test_end_user_id}") + print(f" - Search Type: hybrid") + print() + + results = asyncio.run(run_hybrid_search( + query_text=test_query, + search_type="hybrid", # 可选: "keyword", "embedding", "hybrid" + end_user_id=test_end_user_id, + limit=10, + include=["statements", "entities", "chunks", "summaries"], + output_path=None, + memory_config=memory_config, + rerank_alpha=0.6, + use_forgetting_rerank=False, + use_llm_rerank=False + )) + + print("=" * 80) + print("检索结果:") + print("=" * 80) + print(results) + + except Exception as e: + print(f"错误: {e}") + import traceback + traceback.print_exc() + finally: + db.close() diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py index f5e72517..4dafd3ed 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_preprocessor.py @@ -555,8 +555,8 @@ class DataPreprocessor: dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - # 获取group_id,如果不存在则生成默认值 - group_id = item.get('group_id', f'group_default_{i}') + # 获取end_user_id,如果不存在则生成默认值 + end_user_id = item.get('end_user_id', f'group_default_{i}') user_id = item.get('user_id', f'user_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}') @@ -574,7 +574,7 @@ class DataPreprocessor: dialog_data = DialogData( context=context, ref_id=dialog_id, - group_id=group_id, + end_user_id=end_user_id, user_id=user_id, apply_id=apply_id, metadata=metadata @@ -644,7 +644,7 @@ class DataPreprocessor: context = ConversationContext(msgs=messages) dialog_id = item.get('dialog_id', item.get('ref_id', item.get('id', f'dialog_{i}'))) - group_id = item.get('group_id', f'group_default_{i}') 
+ end_user_id = item.get('end_user_id', f'group_default_{i}') user_id = item.get('user_id', f'user_default_{i}') apply_id = item.get('apply_id', f'apply_default_{i}') @@ -657,7 +657,7 @@ class DataPreprocessor: dialog_data = DialogData( context=context, ref_id=dialog_id, - group_id=group_id, + end_user_id=end_user_id, user_id=user_id, apply_id=apply_id, metadata=metadata diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py index 62b656b0..a425e0ed 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/deduped_and_disamb.py @@ -199,7 +199,7 @@ def accurate_match( entity_nodes: List[ExtractedEntityNode] ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: """ - 精确匹配:按 (group_id, name, entity_type) 合并实体并建立重定向与合并记录。 + 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。 返回: (deduped_entities, id_redirect, exact_merge_map) """ exact_merge_map: Dict[str, Dict] = {} @@ -210,8 +210,8 @@ def accurate_match( for ent in entity_nodes: name_norm = (getattr(ent, "name", "") or "").strip() type_norm = (getattr(ent, "entity_type", "") or "").strip() - key = f"{getattr(ent, 'group_id', None)}|{name_norm}|{type_norm}" - # 为避免跨业务组误并,明确以 group_id 为范围边界 + key = f"{getattr(ent, 'end_user_id', None)}|{name_norm}|{type_norm}" + # 为避免跨业务组误并,明确以 end_user_id 为范围边界 if key not in canonical_map: canonical_map[key] = ent id_redirect[ent.id] = ent.id @@ -223,11 +223,11 @@ def accurate_match( id_redirect[ent.id] = canonical.id # 记录精确匹配的合并项(使用规范化键,避免外层变量误用) try: - k = f"{canonical.group_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" + k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}" if k not in exact_merge_map: 
exact_merge_map[k] = { "canonical_id": canonical.id, - "group_id": canonical.group_id, + "end_user_id": canonical.end_user_id, "name": canonical.name, "entity_type": canonical.entity_type, "merged_ids": set(), @@ -596,7 +596,7 @@ def fuzzy_match( b = deduped_entities[j] # 跳过不同业务组的实体 - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None): j += 1 continue @@ -671,7 +671,7 @@ def fuzzy_match( merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" merge_reason = "[别名匹配]" if alias_match_merge else "[模糊]" fuzzy_merge_records.append( - f"{merge_reason} 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type}) | " + f"{merge_reason} 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type}) | " f"s_name={s_name:.3f}, s_type={s_type:.3f}, overall={overall:.3f}, exact_alias={has_exact_match}" ) except Exception: @@ -779,7 +779,7 @@ async def LLM_decision( # 决策中包含去重和消歧的功能 # 记录 LLM 融合日志 try: llm_records.append( - f"[LLM融合] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" + f"[LLM融合] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})" ) # 详细的“同类名称相似”记录改由 LLM 去重模块统一生成以携带 conf/reason except Exception: @@ -847,7 +847,7 @@ async def LLM_disamb_decision( id_redirect[k] = a.id try: disamb_records.append( - f"[DISAMB合并应用] 规范实体 {a.id} ({a.group_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.group_id}|{b.name}|{b.entity_type})" + f"[DISAMB合并应用] 规范实体 {a.id} ({a.end_user_id}|{a.name}|{a.entity_type}) <- 合并实体 {b.id} ({b.end_user_id}|{b.name}|{b.entity_type})" ) except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py index 
734f7b69..0249ac1f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/entity_dedup_llm.py @@ -174,7 +174,7 @@ async def _judge_pair( pass # 3. 构建LLM判断的“上下文信息”(规则层计算的所有特征) 判断上下文特征有助于实体消歧首先判断的类型关系 ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), + "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "type_similarity": _type_similarity(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "name_text_sim": name_text_sim, @@ -235,7 +235,7 @@ async def _judge_pair_disamb( except Exception: pass ctx = { - "same_group": getattr(a, "group_id", None) == getattr(b, "group_id", None), + "same_group": getattr(a, "end_user_id", None) == getattr(b, "end_user_id", None), "type_ok": _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)), "name_text_sim": name_text_sim, "name_embed_sim": name_embed_sim, @@ -317,8 +317,8 @@ async def llm_dedup_entities( # 保留对偶判断作为子流程,是为了 a = entity_nodes[i] for j in range(i + 1, len(entity_nodes)): b = entity_nodes[j] - # 规则1:必须属于同一组(group_id相同,不同组的实体不重复) - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + # 规则1:必须属于同一组(end_user_id相同,不同组的实体不重复) + if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None): continue # 规则2:类型必须兼容(调用_simple_type_ok判断) if not _simple_type_ok(getattr(a, "entity_type", None), getattr(b, "entity_type", None)): @@ -474,7 +474,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 - max_rounds: upper bound for iterative passes (default 3) - auto_merge_threshold: decision confidence for auto-merge when no co-occurrence (default 0.90) - co_ctx_threshold: lower threshold when co-occurrence is detected (default 0.83) - - shuffle_each_round: whether to shuffle entities 
within group_id each round to vary block composition + - shuffle_each_round: whether to shuffle entities within end_user_id each round to vary block composition Returns: - global_redirect: dict losing_id -> canonical_id accumulated across rounds @@ -509,7 +509,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 def _partition_blocks(nodes: List[ExtractedEntityNode]) -> List[List[ExtractedEntityNode]]: """ - 按 group_id 分块,避免跨组实体在同一块,减少无效候选对 + 按 end_user_id 分块,避免跨组实体在同一块,减少无效候选对 Args: nodes: 实体节点列表 @@ -519,7 +519,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 """ groups: Dict[str, List[ExtractedEntityNode]] = {} for e in nodes: - gid = getattr(e, "group_id", None) + gid = getattr(e, "end_user_id", None) groups.setdefault(str(gid), []).append(e) blocks: List[List[ExtractedEntityNode]] = [] for gid, arr in groups.items(): @@ -559,7 +559,7 @@ async def llm_dedup_entities_iterative_blocks( # 迭代分块并发 LLM 去重 # Collapse nodes to canonical reps before each round to avoid redundant comparisons # 步骤1:折叠实体(合并已确定的重复实体,减少后续计算量) current_nodes = _collapse_nodes(current_nodes) - # 步骤2:分块(按group_id分块,避免跨组处理) + # 步骤2:分块(按end_user_id分块,避免跨组处理) blocks = _partition_blocks(current_nodes) if not blocks: # 无块可处理(实体已全部折叠),退出循环 break @@ -645,7 +645,7 @@ async def llm_disambiguate_pairs_iterative( a = entity_nodes[i] b = entity_nodes[j] # 必须同组 - if getattr(a, "group_id", None) != getattr(b, "group_id", None): + if getattr(a, "end_user_id", None) != getattr(b, "end_user_id", None): continue ta = getattr(a, "entity_type", None) tb = getattr(b, "entity_type", None) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py index b41f35a4..dbc697d9 100644 --- a/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py +++ 
b/api/app/core/memory/storage_services/extraction_engine/deduplication/second_layer_dedup.py @@ -61,7 +61,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: return ExtractedEntityNode( id=row.get("id"), name=row.get("name") or "", - group_id=row.get("group_id") or "", + end_user_id=row.get("end_user_id") or "", user_id=row.get("user_id") or "", apply_id=row.get("apply_id") or "", created_at=_parse_dt(row.get("created_at")), @@ -79,7 +79,7 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode: async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑,与 Neo4j 中同组实体联合去重 connector: Neo4jConnector, - group_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重 + end_user_id: str, # 用于定位neo4j中同一组的实体,确保只在同组内去重 entity_nodes: List[ExtractedEntityNode], # 输入的实体节点列表,包含待去重的实体 statement_entity_edges: List[StatementEntityEdge], # 输入的语句实体边列表,用于处理实体之间的关系 entity_entity_edges: List[EntityEntityEdge], # 输入的实体实体边列表,用于处理实体之间的关系 @@ -88,7 +88,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 ) -> Tuple[List[ExtractedEntityNode], List[StatementEntityEdge], List[EntityEntityEdge]]: """ 第二层去重消歧: - - 以第一层结果为索引,检索相同 group_id 下的 DB 候选实体 + - 以第一层结果为索引,检索相同 end_user_id 下的 DB 候选实体 - 将 DB 候选与当前实体集合联合,按既有精确/模糊/LLM 决策进行融合 - 返回融合后的实体与重定向后的边(边已指向规范 ID,优先 DB ID) """ @@ -102,7 +102,7 @@ async def second_layer_dedup_and_merge_with_neo4j( # 二层去重的核心逻辑 ] candidates_map = await get_dedup_candidates_for_entities( # 从 Neo4j 中查询候选实体,并将结果赋值给candidates_map(等待异步操作完成)。 - connector=connector, group_id=group_id, + connector=connector, end_user_id=end_user_id, entities=incoming_rows, # 传入参数:第一层实体的核心信息(作为查询索引) use_contains_fallback=True # 传入参数:启用 “包含关系” 作为匹配失败的降级策略(若精确匹配无结果,用包含关系召回候选),与src\database\cypher_queries.py的307产生联动 ) diff --git a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py index 11845d7d..f28b8a5f 100644 --- 
a/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py +++ b/api/app/core/memory/storage_services/extraction_engine/deduplication/two_stage_dedup.py @@ -57,11 +57,11 @@ async def dedup_layers_and_merge_and_return( if pipeline_config is None: raise ValueError("pipeline_config is required for dedup_layers_and_merge_and_return") - # 先探测 group_id,决定报告写入策略 - group_id: Optional[str] = None + # 先探测 end_user_id,决定报告写入策略 + end_user_id: Optional[str] = None for dd in dialog_data_list: - group_id = getattr(dd, "group_id", None) - if group_id: + end_user_id = getattr(dd, "end_user_id", None) + if end_user_id: break # 第一层去重消歧 @@ -82,11 +82,11 @@ async def dedup_layers_and_merge_and_return( # 第二层去重消歧:与 Neo4j 中同组实体联合融合 try: - if group_id: + if end_user_id: if connector: fused_entity_nodes, fused_statement_entity_edges, fused_entity_entity_edges = await second_layer_dedup_and_merge_with_neo4j( connector=connector, - group_id=group_id, + end_user_id=end_user_id, entity_nodes=dedup_entity_nodes, statement_entity_edges=dedup_statement_entity_edges, entity_entity_edges=dedup_entity_entity_edges, @@ -96,7 +96,7 @@ async def dedup_layers_and_merge_and_return( else: print("Skip second-layer dedup: missing connector") else: - print("Skip second-layer dedup: missing group_id") + print("Skip second-layer dedup: missing end_user_id") except Exception as e: print(f"Second-layer dedup failed: {e}") diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 46ba1dde..c2c5d54e 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -287,7 +287,7 @@ class ExtractionOrchestrator: for d_idx, dialog in enumerate(dialog_data_list): dialogue_content = dialog.content if 
self.config.statement_extraction.include_dialogue_context else None for c_idx, chunk in enumerate(dialog.chunks): - all_chunks.append((chunk, dialog.group_id, dialogue_content)) + all_chunks.append((chunk, dialog.end_user_id, dialogue_content)) chunk_metadata.append((d_idx, c_idx)) logger.info(f"收集到 {len(all_chunks)} 个分块,开始全局并行提取") @@ -299,9 +299,9 @@ class ExtractionOrchestrator: # 全局并行处理所有分块 async def extract_for_chunk(chunk_data, chunk_index): nonlocal completed_chunks - chunk, group_id, dialogue_content = chunk_data + chunk, end_user_id, dialogue_content = chunk_data try: - statements = await self.statement_extractor._extract_statements(chunk, group_id, dialogue_content) + statements = await self.statement_extractor._extract_statements(chunk, end_user_id, dialogue_content) # 流式输出:每提取完一个分块的陈述句,立即发送进度 # 注意:只在试运行模式下发送陈述句详情,正式模式不发送 @@ -992,9 +992,7 @@ class ExtractionOrchestrator: id=dialog_data.id, name=f"Dialog_{dialog_data.id}", # 添加必需的 name 字段 ref_id=dialog_data.ref_id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=dialog_data.context.content if dialog_data.context else "", dialog_embedding=dialog_data.dialog_embedding if hasattr(dialog_data, 'dialog_embedding') else None, @@ -1012,9 +1010,7 @@ class ExtractionOrchestrator: id=chunk.id, name=f"Chunk_{chunk.id}", # 添加必需的 name 字段 dialog_id=dialog_data.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id content=chunk.content, chunk_embedding=chunk.chunk_embedding, @@ -1035,9 +1031,7 @@ class ExtractionOrchestrator: stmt_type=getattr(statement, 'stmt_type', 'general'), # 添加必需的 stmt_type 字段 temporal_info=getattr(statement, 'temporal_info', TemporalInfo.ATEMPORAL), # 添加必需的 temporal_info 字段 
connect_strength=statement.connect_strength if statement.connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段 - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id statement=statement.statement, speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段 @@ -1060,9 +1054,7 @@ class ExtractionOrchestrator: statement_chunk_edge = StatementChunkEdge( source=statement.id, target=chunk.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, ) @@ -1095,9 +1087,7 @@ class ExtractionOrchestrator: aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases name_embedding=getattr(entity, 'name_embedding', None), is_explicit_memory=getattr(entity, 'is_explicit_memory', False), # 新增:传递语义记忆标记 - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, @@ -1112,9 +1102,7 @@ class ExtractionOrchestrator: source=statement.id, target=entity.id, connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, ) @@ -1134,9 +1122,7 @@ class ExtractionOrchestrator: relation_type=triplet.predicate, statement=statement.statement, source_statement_id=statement.id, - group_id=dialog_data.group_id, - user_id=dialog_data.user_id, - apply_id=dialog_data.apply_id, + end_user_id=dialog_data.end_user_id, 
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id created_at=dialog_data.created_at, expired_at=dialog_data.expired_at, @@ -1763,14 +1749,14 @@ class ExtractionOrchestrator: async def get_chunked_dialogs( chunker_strategy: str = "RecursiveChunker", - group_id: str = "group_1", + end_user_id: str = "group_1", indices: Optional[List[int]] = None, ) -> List[DialogData]: """从测试数据生成分块对话 Args: chunker_strategy: 分块策略(默认: RecursiveChunker) - group_id: 组ID + end_user_id: 组ID indices: 要处理的数据索引列表(可选) Returns: @@ -1834,7 +1820,7 @@ async def get_chunked_dialogs( dialog_data = DialogData( context=conversation_context, ref_id=data['id'], - group_id=group_id, + end_user_id=end_user_id, metadata=dialog_metadata, ) @@ -1936,7 +1922,7 @@ async def get_chunked_dialogs_from_preprocessed( async def get_chunked_dialogs_with_preprocessing( chunker_strategy: str = "RecursiveChunker", - group_id: str = "default", + end_user_id: str = "default", user_id: str = "default", apply_id: str = "default", indices: Optional[List[int]] = None, @@ -1948,7 +1934,7 @@ async def get_chunked_dialogs_with_preprocessing( Args: chunker_strategy: 分块策略 - group_id: 组ID + end_user_id: 组ID user_id: 用户ID apply_id: 应用ID indices: 要处理的数据索引列表 @@ -1976,11 +1962,9 @@ async def get_chunked_dialogs_with_preprocessing( indices=indices, ) - # 设置 group_id, user_id, apply_id + # 设置 end_user_id for dd in preprocessed_data: - dd.group_id = group_id - dd.user_id = user_id - dd.apply_id = apply_id + dd.end_user_id = end_user_id # 步骤2: 语义剪枝 try: diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py index 7e75fd2d..f39313a8 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/memory_summary.py @@ -193,9 +193,9 @@ async def _process_chunk_summary( 
node = MemorySummaryNode( id=uuid4().hex, name=title if title else f"MemorySummaryChunk_{chunk.id}", - group_id=dialog.group_id, - user_id=dialog.user_id, - apply_id=dialog.apply_id, + end_user_id=dialog.end_user_id, + user_id=dialog.end_user_id, + apply_id=dialog.end_user_id, run_id=dialog.run_id, # 使用 dialog 的 run_id created_at=datetime.now(), expired_at=datetime(9999, 12, 31), diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py index fb1b539a..b06bd70f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/statement_extraction.py @@ -82,12 +82,12 @@ class StatementExtractor: logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty") return None - async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: + async def _extract_statements(self, chunk, end_user_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]: """Process a single chunk and return extracted statements Args: chunk: Chunk object to process - group_id: Group ID to assign to all statements in this chunk + end_user_id: Group ID to assign to all statements in this chunk dialogue_content: Full dialogue content to provide as context Returns: @@ -158,7 +158,7 @@ class StatementExtractor: temporal_info=temporal_type, relevence_info=relevence_info, chunk_id=chunk.id, - group_id=group_id, + end_user_id=end_user_id, speaker=chunk_speaker, ) @@ -184,10 +184,10 @@ class StatementExtractor: logger.info(f"Processing {len(chunks_to_process)} chunks for statement extraction") - # Process all chunks concurrently, passing the group_id and dialogue content from dialog_data + # Process all chunks 
concurrently, passing the end_user_id and dialogue content from dialog_data dialogue_content = dialog_data.content if self.config.include_dialogue_context else None results = await asyncio.gather( - *[self._extract_statements(chunk, dialog_data.group_id, dialogue_content) for chunk in chunks_to_process], + *[self._extract_statements(chunk, dialog_data.end_user_id, dialogue_content) for chunk in chunks_to_process], return_exceptions=True ) @@ -225,7 +225,7 @@ class StatementExtractor: for i, statement in enumerate(statements, 1): f.write(f"Statement {i}:\n") f.write(f"Id: {statement.id}\n") - f.write(f"Group Id: {statement.group_id}\n") + f.write(f"Group Id: {statement.end_user_id}\n") f.write(f"Content: {statement.statement}\n") f.write(f"Type: {statement.stmt_type.value}\n") f.write(f"Temporal Info: {statement.temporal_info.value}\n") @@ -298,7 +298,7 @@ class StatementExtractor: dialog_sections.append({ "dialog_id": dialog.ref_id, - "group_id": dialog.group_id, + "end_user_id": dialog.end_user_id, "content": dialog.content if getattr(dialog, "content", None) else "", "strong": strong_relations, "weak": weak_relations, @@ -312,7 +312,7 @@ class StatementExtractor: for idx, section in enumerate(dialog_sections, 1): f.write(f"Dialog {idx}:\n") f.write(f"Dialog ID: {section.get('dialog_id', '')}\n") - f.write(f"Group ID: {section.get('group_id', '')}\n") + f.write(f"Group ID: {section.get('end_user_id', '')}\n") f.write("Content:\n") f.write(f"{section.get('content', '')}\n") f.write("-" * 40 + "\n\n") diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py index 9528e638..499027a4 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/temporal_extraction.py @@ -132,7 +132,7 
@@ class TemporalExtractor: prompt_logger.info("") prompt_logger.info("=== TEMPORAL EXTRACTION RESULTS ===") prompt_logger.info( - f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}" + f"[Temporal] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}" ) except Exception: pass diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py index d3d059b0..bfc0bc88 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/triplet_extraction.py @@ -116,7 +116,7 @@ class TripletExtractor: logger.info(f"Processing {len(all_statements)} statements for triplet extraction...") try: prompt_logger.info( - f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, group_id={getattr(dialog_data, 'group_id', None)}, statements_to_process={len(all_statements)}" + f"[Triplet] Dialog ref_id={getattr(dialog_data, 'ref_id', None)}, end_user_id={getattr(dialog_data, 'end_user_id', None)}, statements_to_process={len(all_statements)}" ) except Exception: pass diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index 5722769a..a71c0957 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -75,7 +75,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, current_time: Optional[datetime] = None ) -> Dict[str, Any]: """ @@ -91,7 +91,7 @@ class 
AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) - group_id: 组ID(可选,用于过滤) + end_user_id: 组ID(可选,用于过滤) current_time: 当前时间(可选,默认使用系统时间) Returns: @@ -123,7 +123,7 @@ class AccessHistoryManager: for attempt in range(self.max_retries): try: # 步骤1:读取当前节点状态 - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: raise ValueError( @@ -142,7 +142,7 @@ class AccessHistoryManager: node_id=node_id, node_label=node_label, update_data=update_data, - group_id=group_id + end_user_id=end_user_id ) logger.info( @@ -172,7 +172,7 @@ class AccessHistoryManager: self, node_ids: List[str], node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, current_time: Optional[datetime] = None ) -> List[Dict[str, Any]]: """ @@ -184,7 +184,7 @@ class AccessHistoryManager: Args: node_ids: 节点ID列表 node_label: 节点标签(所有节点必须是同一类型) - group_id: 组ID(可选) + end_user_id: 组ID(可选) current_time: 当前时间(可选) Returns: @@ -202,7 +202,7 @@ class AccessHistoryManager: task = self.record_access( node_id=node_id, node_label=node_label, - group_id=group_id, + end_user_id=end_user_id, current_time=current_time ) tasks.append(task) @@ -235,7 +235,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Tuple[ConsistencyCheckResult, Optional[str]]: """ 检查节点数据的一致性 @@ -249,14 +249,14 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Tuple[ConsistencyCheckResult, Optional[str]]: - 一致性检查结果枚举 - 错误描述(如果不一致) """ - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: return ConsistencyCheckResult.CONSISTENT, None @@ -305,7 +305,7 @@ class AccessHistoryManager: async def check_batch_consistency( 
self, node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1000 ) -> Dict[str, Any]: """ @@ -313,7 +313,7 @@ class AccessHistoryManager: Args: node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) limit: 检查的最大节点数 Returns: @@ -329,16 +329,16 @@ class AccessHistoryManager: MATCH (n:{node_label}) WHERE n.access_history IS NOT NULL """ - if group_id: - query += " AND n.group_id = $group_id" + if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ RETURN n.id as id LIMIT $limit """ params = {"limit": limit} - if group_id: - params["group_id"] = group_id + if end_user_id: + params["end_user_id"] = end_user_id results = await self.connector.execute_query(query, **params) node_ids = [r['id'] for r in results] @@ -351,7 +351,7 @@ class AccessHistoryManager: result, message = await self.check_consistency( node_id=node_id, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) if result == ConsistencyCheckResult.CONSISTENT: @@ -387,7 +387,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> bool: """ 自动修复节点的数据不一致问题 @@ -401,7 +401,7 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: bool: 修复成功返回True,否则返回False @@ -411,7 +411,7 @@ class AccessHistoryManager: result, message = await self.check_consistency( node_id=node_id, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) if result == ConsistencyCheckResult.CONSISTENT: @@ -419,7 +419,7 @@ class AccessHistoryManager: return True # 获取节点数据 - node_data = await self._fetch_node(node_id, node_label, group_id) + node_data = await self._fetch_node(node_id, node_label, end_user_id) if not node_data: logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]") return False @@ -457,8 +457,8 @@ class AccessHistoryManager: query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if 
group_id: - query += " WHERE n.group_id = $group_id" + if end_user_id: + query += " WHERE n.end_user_id = $end_user_id" query += """ SET n += $repair_data RETURN n @@ -468,8 +468,8 @@ class AccessHistoryManager: 'node_id': node_id, 'repair_data': repair_data } - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id await self.connector.execute_query(query, **params) @@ -491,7 +491,7 @@ class AccessHistoryManager: self, node_id: str, node_label: str, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Optional[Dict[str, Any]]: """ 获取节点数据 @@ -499,7 +499,7 @@ class AccessHistoryManager: Args: node_id: 节点ID node_label: 节点标签 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Optional[Dict[str, Any]]: 节点数据,如果不存在返回None @@ -507,8 +507,8 @@ class AccessHistoryManager: query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if group_id: - query += " WHERE n.group_id = $group_id" + if end_user_id: + query += " WHERE n.end_user_id = $end_user_id" query += """ RETURN n.id as id, n.importance_score as importance_score, @@ -519,8 +519,8 @@ class AccessHistoryManager: """ params = {'node_id': node_id} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) @@ -585,7 +585,7 @@ class AccessHistoryManager: node_id: str, node_label: str, update_data: Dict[str, Any], - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Dict[str, Any]: """ 原子性更新节点(使用乐观锁) @@ -597,7 +597,7 @@ class AccessHistoryManager: node_id: 节点ID node_label: 节点标签 update_data: 更新数据 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Dict[str, Any]: 更新后的节点数据 @@ -606,13 +606,13 @@ class AccessHistoryManager: RuntimeError: 如果更新失败或发生版本冲突 """ # 定义事务函数 - async def update_transaction(tx, node_id, node_label, update_data, group_id): + async def update_transaction(tx, node_id, node_label, update_data, end_user_id): # 
步骤1:读取当前节点并获取版本号 read_query = f""" MATCH (n:{node_label} {{id: $node_id}}) """ - if group_id: - read_query += " WHERE n.group_id = $group_id" + if end_user_id: + read_query += " WHERE n.end_user_id = $end_user_id" read_query += """ RETURN n.id as id, n.version as version, @@ -624,8 +624,8 @@ class AccessHistoryManager: """ read_params = {'node_id': node_id} - if group_id: - read_params['group_id'] = group_id + if end_user_id: + read_params['end_user_id'] = end_user_id read_result = await tx.run(read_query, **read_params) current_node = await read_result.single() @@ -656,8 +656,8 @@ class AccessHistoryManager: # 构建 WHERE 子句 where_conditions = [] - if group_id: - where_conditions.append("n.group_id = $group_id") + if end_user_id: + where_conditions.append("n.end_user_id = $end_user_id") # 添加版本检查 if current_version > 0: @@ -695,8 +695,8 @@ class AccessHistoryManager: 'last_access_time': update_data['last_access_time'], 'access_count': update_data['access_count'] } - if group_id: - update_params['group_id'] = group_id + if end_user_id: + update_params['end_user_id'] = end_user_id update_result = await tx.run(update_query, **update_params) updated_node = await update_result.single() @@ -720,7 +720,7 @@ class AccessHistoryManager: node_id=node_id, node_label=node_label, update_data=update_data, - group_id=group_id + end_user_id=end_user_id ) return result except Exception as e: diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py index 6d42af53..e9d4c144 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py @@ -66,7 +66,7 @@ class ForgettingScheduler: async def run_forgetting_cycle( self, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, max_merge_batch_size: int = 100, min_days_since_access: int = 30, 
config_id: Optional[int] = None, @@ -77,7 +77,7 @@ class ForgettingScheduler: Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) max_merge_batch_size: 单次最大融合节点对数(默认 100) min_days_since_access: 最小未访问天数(默认 30 天) config_id: 配置ID(可选,用于获取 llm_id) @@ -107,19 +107,19 @@ class ForgettingScheduler: start_time_iso = start_time.isoformat() logger.info( - f"开始遗忘周期: group_id={group_id}, " + f"开始遗忘周期: end_user_id={end_user_id}, " f"max_batch={max_merge_batch_size}, " f"min_days={min_days_since_access}" ) try: # 步骤1:统计遗忘前的节点数量 - nodes_before = await self._count_knowledge_nodes(group_id) + nodes_before = await self._count_knowledge_nodes(end_user_id) logger.info(f"遗忘前节点总数: {nodes_before}") # 步骤2:识别可遗忘的节点对 forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes( - group_id=group_id, + end_user_id=end_user_id, min_days_since_access=min_days_since_access ) @@ -213,7 +213,7 @@ class ForgettingScheduler: 'statement_text': pair['statement_text'], 'statement_activation': pair['statement_activation'], 'statement_importance': pair['statement_importance'], - 'group_id': group_id + 'end_user_id': end_user_id } entity_node = { @@ -222,7 +222,7 @@ class ForgettingScheduler: 'entity_type': pair['entity_type'], 'entity_activation': pair['entity_activation'], 'entity_importance': pair['entity_importance'], - 'group_id': group_id + 'end_user_id': end_user_id } # 融合节点 @@ -262,7 +262,7 @@ class ForgettingScheduler: continue # 步骤6:统计遗忘后的节点数量 - nodes_after = await self._count_knowledge_nodes(group_id) + nodes_after = await self._count_knowledge_nodes(end_user_id) logger.info(f"遗忘后节点总数: {nodes_after}") # 步骤7:生成遗忘报告 @@ -315,7 +315,7 @@ class ForgettingScheduler: async def _count_knowledge_nodes( self, - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> int: """ 统计知识层节点总数 @@ -323,7 +323,7 @@ class ForgettingScheduler: 统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。 Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) 
Returns: int: 知识层节点总数 @@ -333,16 +333,16 @@ class ForgettingScheduler: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) """ - if group_id: - query += " AND n.group_id = $group_id" + if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ RETURN count(n) as total """ params = {} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py index ccd8d2ca..6b2d9e99 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py @@ -90,7 +90,7 @@ class ForgettingStrategy: async def find_forgettable_nodes( self, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, min_days_since_access: int = 30 ) -> List[Dict[str, Any]]: """ @@ -102,7 +102,7 @@ class ForgettingStrategy: 3. 
Statement 和 Entity 之间存在关系边 Args: - group_id: 组 ID(可选,用于过滤特定组的节点) + end_user_id: 组 ID(可选,用于过滤特定组的节点) min_days_since_access: 最小未访问天数(默认 30 天) Returns: @@ -136,8 +136,8 @@ class ForgettingStrategy: AND (e.entity_type IS NULL OR e.entity_type <> 'Person') """ - if group_id: - query += " AND s.group_id = $group_id AND e.group_id = $group_id" + if end_user_id: + query += " AND s.end_user_id = $end_user_id AND e.end_user_id = $end_user_id" query += """ RETURN s.id as statement_id, @@ -159,8 +159,8 @@ class ForgettingStrategy: 'threshold': self.forgetting_threshold, 'cutoff_time': cutoff_time_iso } - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await self.connector.execute_query(query, **params) @@ -247,8 +247,8 @@ class ForgettingStrategy: entity_activation = entity_node['entity_activation'] entity_importance = entity_node['entity_importance'] - # 获取 group_id(从 statement 或 entity 节点) - group_id = statement_node.get('group_id') or entity_node.get('group_id') + # 获取 end_user_id(从 statement 或 entity 节点) + end_user_id = statement_node.get('end_user_id') or entity_node.get('end_user_id') # 生成摘要内容 summary_text = await self._generate_summary( @@ -325,7 +325,7 @@ class ForgettingStrategy: last_access_time: $current_time, access_count: 1, version: 1, - group_id: $group_id, + end_user_id: $end_user_id, created_at: datetime($current_time), merged_at: datetime($current_time) }) @@ -423,7 +423,7 @@ class ForgettingStrategy: 'inherited_activation': inherited_activation, 'inherited_importance': inherited_importance, 'current_time': current_time_iso, - 'group_id': group_id + 'end_user_id': end_user_id } try: diff --git a/api/app/core/memory/storage_services/search/__init__.py b/api/app/core/memory/storage_services/search/__init__.py index 2bec5bf1..c12c39b0 100644 --- a/api/app/core/memory/storage_services/search/__init__.py +++ b/api/app/core/memory/storage_services/search/__init__.py @@ -37,7 +37,7 @@ __all__ = [ async 
def run_hybrid_search( query_text: str, search_type: str = "hybrid", - group_id: str | None = None, + end_user_id: str | None = None, apply_id: str | None = None, user_id: str | None = None, limit: int = 50, @@ -54,7 +54,7 @@ async def run_hybrid_search( Args: query_text: 查询文本 search_type: 搜索类型("hybrid", "keyword", "semantic") - group_id: 组ID过滤 + end_user_id: 组ID过滤 apply_id: 应用ID过滤 user_id: 用户ID过滤 limit: 每个类别的最大结果数 @@ -104,7 +104,7 @@ async def run_hybrid_search( # 执行搜索 result = await strategy.search( query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include, alpha=alpha, diff --git a/api/app/core/memory/storage_services/search/hybrid_search.py b/api/app/core/memory/storage_services/search/hybrid_search.py index 43215df5..4111b09c 100644 --- a/api/app/core/memory/storage_services/search/hybrid_search.py +++ b/api/app/core/memory/storage_services/search/hybrid_search.py @@ -77,7 +77,7 @@ # async def search( # self, # query_text: str, -# group_id: Optional[str] = None, +# end_user_id: Optional[str] = None, # limit: int = 50, # include: Optional[List[str]] = None, # **kwargs @@ -86,7 +86,7 @@ # Args: # query_text: 查询文本 -# group_id: 可选的组ID过滤 +# end_user_id: 可选的组ID过滤 # limit: 每个类别的最大结果数 # include: 要包含的搜索类别列表 # **kwargs: 其他搜索参数(如alpha, use_forgetting_curve) @@ -94,7 +94,7 @@ # Returns: # SearchResult: 搜索结果对象 # """ -# logger.info(f"执行混合搜索: query='{query_text}', group_id={group_id}, limit={limit}") +# logger.info(f"执行混合搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # # 从kwargs中获取参数 # alpha = kwargs.get("alpha", self.alpha) @@ -107,14 +107,14 @@ # # 并行执行关键词搜索和语义搜索 # keyword_result = await self.keyword_strategy.search( # query_text=query_text, -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list # ) # semantic_result = await self.semantic_strategy.search( # query_text=query_text, -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list # ) @@ 
-139,7 +139,7 @@ # metadata = self._create_metadata( # query_text=query_text, # search_type="hybrid", -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # include=include_list, # alpha=alpha, @@ -165,7 +165,7 @@ # metadata=self._create_metadata( # query_text=query_text, # search_type="hybrid", -# group_id=group_id, +# end_user_id=end_user_id, # limit=limit, # error=str(e) # ) diff --git a/api/app/core/memory/storage_services/search/keyword_search.py b/api/app/core/memory/storage_services/search/keyword_search.py index 95dd0581..d2591945 100644 --- a/api/app/core/memory/storage_services/search/keyword_search.py +++ b/api/app/core/memory/storage_services/search/keyword_search.py @@ -44,7 +44,7 @@ class KeywordSearchStrategy(SearchStrategy): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -53,7 +53,7 @@ class KeywordSearchStrategy(SearchStrategy): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表 **kwargs: 其他搜索参数 @@ -61,7 +61,7 @@ class KeywordSearchStrategy(SearchStrategy): Returns: SearchResult: 搜索结果对象 """ - logger.info(f"执行关键词搜索: query='{query_text}', group_id={group_id}, limit={limit}") + logger.info(f"执行关键词搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # 获取有效的搜索类别 include_list = self._get_include_list(include) @@ -75,7 +75,7 @@ class KeywordSearchStrategy(SearchStrategy): results_dict = await search_graph( connector=self.connector, q=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -84,7 +84,7 @@ class KeywordSearchStrategy(SearchStrategy): metadata = self._create_metadata( query_text=query_text, search_type="keyword", - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -115,7 +115,7 @@ class KeywordSearchStrategy(SearchStrategy): metadata=self._create_metadata( 
query_text=query_text, search_type="keyword", - group_id=group_id, + end_user_id=end_user_id, limit=limit, error=str(e) ) diff --git a/api/app/core/memory/storage_services/search/search_strategy.py b/api/app/core/memory/storage_services/search/search_strategy.py index 27c02c89..3a670dd6 100644 --- a/api/app/core/memory/storage_services/search/search_strategy.py +++ b/api/app/core/memory/storage_services/search/search_strategy.py @@ -58,7 +58,7 @@ class SearchStrategy(ABC): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -67,7 +67,7 @@ class SearchStrategy(ABC): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表(statements, chunks, entities, summaries) **kwargs: 其他搜索参数 @@ -81,7 +81,7 @@ class SearchStrategy(ABC): self, query_text: str, search_type: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, **kwargs ) -> Dict[str, Any]: @@ -90,7 +90,7 @@ class SearchStrategy(ABC): Args: query_text: 查询文本 search_type: 搜索类型 - group_id: 组ID + end_user_id: 组ID limit: 结果限制 **kwargs: 其他元数据 @@ -100,7 +100,7 @@ class SearchStrategy(ABC): metadata = { "query": query_text, "search_type": search_type, - "group_id": group_id, + "end_user_id": end_user_id, "limit": limit, "timestamp": datetime.now().isoformat() } diff --git a/api/app/core/memory/storage_services/search/semantic_search.py b/api/app/core/memory/storage_services/search/semantic_search.py index b20f90a5..8d4eb05f 100644 --- a/api/app/core/memory/storage_services/search/semantic_search.py +++ b/api/app/core/memory/storage_services/search/semantic_search.py @@ -85,7 +85,7 @@ class SemanticSearchStrategy(SearchStrategy): async def search( self, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: Optional[List[str]] = None, **kwargs @@ -94,7 
+94,7 @@ class SemanticSearchStrategy(SearchStrategy): Args: query_text: 查询文本 - group_id: 可选的组ID过滤 + end_user_id: 可选的组ID过滤 limit: 每个类别的最大结果数 include: 要包含的搜索类别列表 **kwargs: 其他搜索参数 @@ -102,7 +102,7 @@ class SemanticSearchStrategy(SearchStrategy): Returns: SearchResult: 搜索结果对象 """ - logger.info(f"执行语义搜索: query='{query_text}', group_id={group_id}, limit={limit}") + logger.info(f"执行语义搜索: query='{query_text}', end_user_id={end_user_id}, limit={limit}") # 获取有效的搜索类别 include_list = self._get_include_list(include) @@ -119,7 +119,7 @@ class SemanticSearchStrategy(SearchStrategy): connector=self.connector, embedder_client=self.embedder_client, query_text=query_text, - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -128,7 +128,7 @@ class SemanticSearchStrategy(SearchStrategy): metadata = self._create_metadata( query_text=query_text, search_type="semantic", - group_id=group_id, + end_user_id=end_user_id, limit=limit, include=include_list ) @@ -159,7 +159,7 @@ class SemanticSearchStrategy(SearchStrategy): metadata=self._create_metadata( query_text=query_text, search_type="semantic", - group_id=group_id, + end_user_id=end_user_id, limit=limit, error=str(e) ) diff --git a/api/app/core/memory/utils/config/get_data.py b/api/app/core/memory/utils/config/get_data.py index 1de6f6aa..e37ad723 100644 --- a/api/app/core/memory/utils/config/get_data.py +++ b/api/app/core/memory/utils/config/get_data.py @@ -23,7 +23,7 @@ async def _load_(data: List[Any]) -> List[Dict]: target_keys = [ "id", "statement", - "group_id", + "end_user_id", "chunk_id", "created_at", "expired_at", @@ -75,7 +75,7 @@ async def get_data(result): """ EXCLUDE_FIELDS = { "user_id", - "group_id", + "end_user_id", "entity_type", "connect_strength", "relationship_type", diff --git a/api/app/core/memory/utils/log/audit_logger.py b/api/app/core/memory/utils/log/audit_logger.py index 9010aad5..f80ad4d5 100644 --- a/api/app/core/memory/utils/log/audit_logger.py +++ 
b/api/app/core/memory/utils/log/audit_logger.py @@ -62,7 +62,7 @@ class ConfigAuditLogger: self, config_id: str, user_id: Optional[str] = None, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, success: bool = True, details: Optional[Dict[str, Any]] = None ): @@ -72,14 +72,14 @@ class ConfigAuditLogger: Args: config_id: 配置 ID user_id: 用户 ID(可选) - group_id: 组 ID(可选) + end_user_id: 组 ID(可选) success: 是否成功 details: 详细信息(可选) """ result = "SUCCESS" if success else "FAILED" msg = ( f"CONFIG_LOAD config_id={config_id} " - f"user={user_id or 'N/A'} group={group_id or 'N/A'} " + f"user={user_id or 'N/A'} group={end_user_id or 'N/A'} " f"result={result}" ) if details: @@ -121,7 +121,7 @@ class ConfigAuditLogger: self, operation: str, config_id: str, - group_id: str, + end_user_id: str, success: bool = True, duration: Optional[float] = None, error: Optional[str] = None, @@ -133,7 +133,7 @@ class ConfigAuditLogger: Args: operation: 操作类型(WRITE, READ 等) config_id: 配置 ID - group_id: 组 ID + end_user_id: 组 ID success: 是否成功 duration: 操作耗时(秒) error: 错误信息(可选) @@ -142,7 +142,7 @@ class ConfigAuditLogger: result = "SUCCESS" if success else "FAILED" msg = ( f"{operation.upper()} config_id={config_id} " - f"group={group_id} result={result}" + f"group={end_user_id} result={result}" ) if duration is not None: msg += f" duration={duration:.2f}s" diff --git a/api/app/core/rag/vdb/field.py b/api/app/core/rag/vdb/field.py index 86d39060..99d872c2 100644 --- a/api/app/core/rag/vdb/field.py +++ b/api/app/core/rag/vdb/field.py @@ -4,7 +4,7 @@ from enum import StrEnum, auto class Field(StrEnum): CONTENT_KEY = "page_content" METADATA_KEY = "metadata" - GROUP_KEY = "group_id" + GROUP_KEY = "end_user_id" VECTOR = auto() # Sparse Vector aims to support full text search SPARSE_VECTOR = auto() diff --git a/api/app/repositories/neo4j/add_edges.py b/api/app/repositories/neo4j/add_edges.py index 3b45867e..162bf411 100644 --- a/api/app/repositories/neo4j/add_edges.py +++ 
b/api/app/repositories/neo4j/add_edges.py @@ -32,7 +32,7 @@ async def add_chunk_statement_edges(chunks: List[Chunk], connector: Neo4jConnect "id": stable_edge_id, "source": chunk.id, "target": stmt.id, - "group_id": getattr(stmt, 'group_id', None), + "end_user_id": getattr(stmt, 'end_user_id', None), "user_id":getattr(stmt, 'user_id', None), "apply_id": getattr(stmt, 'apply_id', None), "run_id": getattr(stmt, 'run_id', None) or getattr(chunk, 'run_id', None), @@ -83,7 +83,7 @@ async def add_memory_summary_statement_edges(summaries: List[MemorySummaryNode], edges.append({ "summary_id": s.id, "chunk_id": chunk_id, - "group_id": s.group_id, + "end_user_id": s.end_user_id, "run_id": s.run_id, "created_at": s.created_at.isoformat() if s.created_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None, diff --git a/api/app/repositories/neo4j/add_nodes.py b/api/app/repositories/neo4j/add_nodes.py index cf60a773..fcf700b5 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -6,10 +6,10 @@ from app.core.memory.models.graph_models import DialogueNode, StatementNode, Chu from app.repositories.neo4j.neo4j_connector import Neo4jConnector -async def delete_all_nodes(group_id: str, connector: Neo4jConnector): +async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector): """Delete all nodes in the database.""" - result = await connector.execute_query(f"MATCH (n {{group_id: '{group_id}'}}) DETACH DELETE n") - print(f"All group_id: {group_id} node and edge deleted successfully") + result = await connector.execute_query(f"MATCH (n {{end_user_id: '{end_user_id}'}}) DETACH DELETE n") + print(f"All end_user_id: {end_user_id} node and edge deleted successfully") return result async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConnector) -> Optional[List[str]]: @@ -32,9 +32,7 @@ async def add_dialogue_nodes(dialogues: List[DialogueNode], connector: Neo4jConn for dialogue in 
dialogues: flattened_dialogues.append({ "id": dialogue.id, - "group_id": dialogue.group_id, - "user_id": dialogue.user_id, - "apply_id": dialogue.apply_id, + "end_user_id": dialogue.end_user_id, "run_id": dialogue.run_id, "ref_id": dialogue.ref_id, "name": dialogue.name, @@ -79,9 +77,7 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC flattened_statement = { "id": statement.id, "name": statement.name, - "group_id": statement.group_id, - "user_id": statement.user_id, - "apply_id": statement.apply_id, + "end_user_id": statement.end_user_id, "run_id": statement.run_id, "chunk_id": statement.chunk_id, # "created_at": statement.created_at.isoformat(), @@ -154,9 +150,7 @@ async def add_chunk_nodes(chunks: List[ChunkNode], connector: Neo4jConnector) -> flattened_chunk = { "id": chunk.id, "name": chunk.name, - "group_id": chunk.group_id, - "user_id": chunk.user_id, - "apply_id": chunk.apply_id, + "end_user_id": chunk.end_user_id, "run_id": chunk.run_id, "created_at": chunk.created_at.isoformat() if chunk.created_at else None, "expired_at": chunk.expired_at.isoformat() if chunk.expired_at else None, @@ -206,9 +200,7 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector flattened.append({ "id": s.id, "name": s.name, - "group_id": s.group_id, - "user_id": s.user_id, - "apply_id": s.apply_id, + "end_user_id": s.end_user_id, "run_id": s.run_id, "created_at": s.created_at.isoformat() if s.created_at else None, "expired_at": s.expired_at.isoformat() if s.expired_at else None, diff --git a/api/app/repositories/neo4j/base_neo4j_repository.py b/api/app/repositories/neo4j/base_neo4j_repository.py index 959a1e68..df953eb9 100644 --- a/api/app/repositories/neo4j/base_neo4j_repository.py +++ b/api/app/repositories/neo4j/base_neo4j_repository.py @@ -152,7 +152,7 @@ class BaseNeo4jRepository(BaseRepository[T]): Example: >>> results = await repository.find( - ... {"group_id": "group_123", "user_id": "user_456"}, + ... 
{"end_user_id": "group_123", "user_id": "user_456"}, ... limit=50 ... ) """ diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index cd3cbed7..eaef1e7a 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -3,9 +3,7 @@ DIALOGUE_NODE_SAVE = """ UNWIND $dialogues AS dialogue MERGE (n:Dialogue {id: dialogue.id}) SET n.uuid = coalesce(n.uuid, dialogue.id), - n.group_id = dialogue.group_id, - n.user_id = dialogue.user_id, - n.apply_id = dialogue.apply_id, + n.end_user_id = dialogue.end_user_id, n.run_id = dialogue.run_id, n.ref_id = dialogue.ref_id, n.created_at = dialogue.created_at, @@ -22,9 +20,7 @@ SET s += { id: statement.id, run_id: statement.run_id, chunk_id: statement.chunk_id, - group_id: statement.group_id, - user_id: statement.user_id, - apply_id: statement.apply_id, + end_user_id: statement.end_user_id, stmt_type: statement.stmt_type, statement: statement.statement, emotion_intensity: statement.emotion_intensity, @@ -54,9 +50,7 @@ MERGE (c:Chunk {id: chunk.id}) SET c += { id: chunk.id, name: chunk.name, - group_id: chunk.group_id, - user_id: chunk.user_id, - apply_id: chunk.apply_id, + end_user_id: chunk.end_user_id, run_id: chunk.run_id, created_at: chunk.created_at, expired_at: chunk.expired_at, @@ -76,9 +70,7 @@ EXTRACTED_ENTITY_NODE_SAVE = """ UNWIND $entities AS entity MERGE (e:ExtractedEntity {id: entity.id}) SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity.name ELSE e.name END, - e.group_id = CASE WHEN entity.group_id IS NOT NULL AND entity.group_id <> '' THEN entity.group_id ELSE e.group_id END, - e.user_id = CASE WHEN entity.user_id IS NOT NULL AND entity.user_id <> '' THEN entity.user_id ELSE e.user_id END, - e.apply_id = CASE WHEN entity.apply_id IS NOT NULL AND entity.apply_id <> '' THEN entity.apply_id ELSE e.apply_id END, + e.end_user_id = CASE WHEN entity.end_user_id IS NOT NULL AND entity.end_user_id 
<> '' THEN entity.end_user_id ELSE e.end_user_id END, e.run_id = CASE WHEN entity.run_id IS NOT NULL AND entity.run_id <> '' THEN entity.run_id ELSE e.run_id END, e.created_at = CASE WHEN entity.created_at IS NOT NULL AND (e.created_at IS NULL OR entity.created_at < e.created_at) @@ -134,9 +126,9 @@ RETURN e.id AS uuid # Add back ENTITY_RELATIONSHIP_SAVE to be used by graph_saver.save_entities_and_relationships ENTITY_RELATIONSHIP_SAVE = """ UNWIND $relationships AS rel -// Match entities by stable id within group, do not constrain by run_id -MATCH (subject:ExtractedEntity {id: rel.source_id, group_id: rel.group_id}) -MATCH (object:ExtractedEntity {id: rel.target_id, group_id: rel.group_id}) +// Match entities by stable id within end_user_id, do not constrain by run_id +MATCH (subject:ExtractedEntity {id: rel.source_id, end_user_id: rel.end_user_id}) +MATCH (object:ExtractedEntity {id: rel.target_id, end_user_id: rel.end_user_id}) // Avoid duplicate edges across runs for the same endpoints MERGE (subject)-[r:EXTRACTED_RELATIONSHIP]->(object) SET r.predicate = rel.predicate, @@ -148,7 +140,7 @@ SET r.predicate = rel.predicate, r.created_at = rel.created_at, r.expired_at = rel.expired_at, r.run_id = rel.run_id, - r.group_id = rel.group_id + r.end_user_id = rel.end_user_id RETURN elementId(r) AS uuid """ @@ -160,7 +152,7 @@ UNWIND $weak_entities AS entity MERGE (e:ExtractedEntity {id: entity.id, run_id: entity.run_id}) SET e += { name: entity.name, - group_id: entity.group_id, + end_user_id: entity.end_user_id, run_id: entity.run_id, description: entity.description, chunk_id: entity.chunk_id, @@ -175,11 +167,11 @@ RETURN e.id AS id SAVE_STRONG_TRIPLE_ENTITIES = """ UNWIND $items AS item MERGE (s:ExtractedEntity {id: item.source_id, run_id: item.run_id}) -SET s += {name: item.subject, group_id: item.group_id, run_id: item.run_id} +SET s += {name: item.subject, end_user_id: item.end_user_id, run_id: item.run_id} // Independent strong flag SET s.is_strong = true MERGE 
(o:ExtractedEntity {id: item.target_id, run_id: item.run_id}) -SET o += {name: item.object, group_id: item.group_id, run_id: item.run_id} +SET o += {name: item.object, end_user_id: item.end_user_id, run_id: item.run_id} // Independent strong flag SET o.is_strong = true """ @@ -194,7 +186,7 @@ DIALOGUE_STATEMENT_EDGE_SAVE = """ // 仅按端点去重,关系属性可更新 MERGE (dialogue)-[e:MENTIONS]->(statement) SET e.uuid = edge.id, - e.group_id = edge.group_id, + e.end_user_id = edge.end_user_id, e.created_at = edge.created_at, e.expired_at = edge.expired_at RETURN e.uuid AS uuid @@ -208,7 +200,7 @@ CHUNK_STATEMENT_EDGE_SAVE = """ MATCH (statement:Statement {id: edge.source, run_id: edge.run_id}) MATCH (chunk:Chunk {id: edge.target, run_id: edge.run_id}) MERGE (chunk)-[e:CONTAINS {id: edge.id}]->(statement) - SET e.group_id = edge.group_id, + SET e.end_user_id = edge.end_user_id, e.run_id = edge.run_id, e.created_at = edge.created_at, e.expired_at = edge.expired_at @@ -218,13 +210,12 @@ CHUNK_STATEMENT_EDGE_SAVE = """ STATEMENT_ENTITY_EDGE_SAVE = """ UNWIND $relationships AS rel // Statement nodes are per-run; keep run_id constraint on statements -// Statement nodes are per-run; keep run_id constraint on statements MATCH (statement:Statement {id: rel.source, run_id: rel.run_id}) -// Entities are shared across runs within a group; do not constrain by run_id -MATCH (entity:ExtractedEntity {id: rel.target, group_id: rel.group_id}) +// Entities are shared across runs within end_user_id; do not constrain by run_id +MATCH (entity:ExtractedEntity {id: rel.target, end_user_id: rel.end_user_id}) // Avoid duplicate edges across runs for same endpoints MERGE (statement)-[r:REFERENCES_ENTITY]->(entity) -SET r.group_id = rel.group_id, +SET r.end_user_id = rel.end_user_id, r.run_id = rel.run_id, r.created_at = rel.created_at, r.expired_at = rel.expired_at, @@ -236,10 +227,10 @@ ENTITY_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('entity_embedding_index', $limit * 100, $embedding) YIELD node 
AS e, score WHERE e.name_embedding IS NOT NULL - AND ($group_id IS NULL OR e.group_id = $group_id) + AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) RETURN e.id AS id, e.name AS name, - e.group_id AS group_id, + e.end_user_id AS end_user_id, e.entity_type AS entity_type, COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, COALESCE(e.importance_score, 0.5) AS importance_score, @@ -254,10 +245,10 @@ STATEMENT_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('statement_embedding_index', $limit * 100, $embedding) YIELD node AS s, score WHERE s.statement_embedding IS NOT NULL - AND ($group_id IS NULL OR s.group_id = $group_id) + AND ($end_user_id IS NULL OR s.end_user_id = $end_user_id) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.expired_at AS expired_at, @@ -277,9 +268,9 @@ CHUNK_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('chunk_embedding_index', $limit * 100, $embedding) YIELD node AS c, score WHERE c.chunk_embedding IS NOT NULL - AND ($group_id IS NULL OR c.group_id = $group_id) + AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id) RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, COALESCE(c.activation_value, 0.5) AS activation_value, @@ -292,12 +283,12 @@ LIMIT $limit SEARCH_STATEMENTS_BY_KEYWORD = """ CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.expired_at AS expired_at, @@ -316,15 +307,13 @@ LIMIT $limit # 
查询实体名称包含指定字符串的实体 SEARCH_ENTITIES_BY_NAME = """ CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score -WHERE ($group_id IS NULL OR e.group_id = $group_id) +WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) RETURN e.id AS id, e.name AS name, - e.group_id AS group_id, + e.end_user_id AS end_user_id, e.entity_type AS entity_type, - e.apply_id AS apply_id, - e.user_id AS user_id, e.created_at AS created_at, e.expired_at AS expired_at, e.entity_idx AS entity_idx, @@ -347,11 +336,11 @@ LIMIT $limit SEARCH_CHUNKS_BY_CONTENT = """ CALL db.index.fulltext.queryNodes("chunksFulltext", $q) YIELD node AS c, score -WHERE ($group_id IS NULL OR c.group_id = $group_id) +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) OPTIONAL MATCH (c)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, c.sequence_number AS sequence_number, @@ -413,10 +402,10 @@ LIMIT $limit SEARCH_DIALOGUE_BY_DIALOG_ID = """ MATCH (d:Dialogue) -WHERE ($group_id IS NULL OR d.group_id = $group_id) +WHERE ($end_user_id IS NULL OR d.end_user_id = $end_user_id) AND d.id = $dialog_id RETURN d.id AS dialog_id, - d.group_id AS group_id, + d.end_user_id AS end_user_id, d.content AS content, d.created_at AS created_at, d.expired_at AS expired_at @@ -426,10 +415,10 @@ LIMIT $limit SEARCH_CHUNK_BY_CHUNK_ID = """ MATCH (c:Chunk) -WHERE ($group_id IS NULL OR c.group_id = $group_id) +WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id) AND c.id = $chunk_id RETURN c.id AS chunk_id, - c.group_id AS group_id, + c.end_user_id AS end_user_id, c.content AS content, c.dialog_id AS dialog_id, c.created_at AS created_at, @@ -441,18 +430,14 @@ LIMIT $limit SEARCH_STATEMENTS_BY_TEMPORAL = """ MATCH (s:Statement) -WHERE 
($group_id IS NULL OR s.group_id = $group_id) - AND ($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR datetime(s.created_at) >= datetime($start_date)) AND ($end_date IS NULL OR datetime(s.created_at) <= datetime($end_date))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) AND ($invalid_date IS NULL OR (s.invalid_at IS NOT NULL AND datetime(s.invalid_at) <= datetime($invalid_date))))) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.valid_at AS valid_at, @@ -468,9 +453,7 @@ LIMIT $limit SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL = """ CALL db.index.fulltext.queryNodes("statementsFulltext", $q) YIELD node AS s, score -WHERE ($group_id IS NULL OR s.group_id = $group_id) - AND ($apply_id IS NULL OR s.apply_id = $apply_id) - AND ($user_id IS NULL OR s.user_id = $user_id) +WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND ((($start_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) >= datetime($start_date))) AND ($end_date IS NULL OR (s.created_at IS NOT NULL AND datetime(s.created_at) <= datetime($end_date)))) OR (($valid_date IS NULL OR (s.valid_at IS NOT NULL AND datetime(s.valid_at) >= datetime($valid_date))) @@ -479,9 +462,7 @@ OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(e:ExtractedEntity) RETURN s.id AS id, s.statement AS statement, - s.group_id AS group_id, - s.apply_id AS apply_id, - s.user_id AS user_id, + s.end_user_id AS end_user_id, s.chunk_id AS chunk_id, s.created_at AS created_at, s.valid_at AS valid_at, @@ -499,15 +480,11 @@ LIMIT $limit SEARCH_STATEMENTS_BY_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = 
$group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 10)) = date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -519,15 +496,11 @@ LIMIT $limit SEARCH_STATEMENTS_BY_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) = date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -539,15 +512,11 @@ LIMIT $limit SEARCH_STATEMENTS_G_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) = date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -559,15 +528,11 @@ LIMIT $limit SEARCH_STATEMENTS_L_CREATED_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE 
($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($created_at IS NOT NULL AND date(substring(n.created_at, 0, 19)) < date($created_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -579,15 +544,11 @@ LIMIT $limit SEARCH_STATEMENTS_G_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) > date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -599,15 +560,11 @@ LIMIT $limit SEARCH_STATEMENTS_L_VALID_AT = """ MATCH (n:Statement) -WHERE ($group_id IS NULL OR n.group_id = $group_id) - AND ($apply_id IS NULL OR n.apply_id = $apply_id) - AND ($user_id IS NULL OR n.user_id = $user_id) +WHERE ($end_user_id IS NULL OR n.end_user_id = $end_user_id) AND ($valid_at IS NOT NULL AND date(substring(n.valid_at, 0, 10)) < date($valid_at)) RETURN n.id AS id, n.statement AS statement, - n.group_id AS group_id, - n.apply_id AS apply_id, - n.user_id AS user_id, + n.end_user_id AS end_user_id, n.chunk_id AS chunk_id, n.created_at AS created_at, n.valid_at AS valid_at, @@ -665,18 +622,18 @@ LIMIT $limit # 根据id修改句子的invalid_at的值 UPDATE_STATEMENT_INVALID_AT = """ -MATCH (n:Statement {group_id: $group_id, id: $id}) +MATCH (n:Statement {end_user_id: $end_user_id, id: $id}) SET n.invalid_at = $new_invalid_at """ # MemorySummary keyword search using fulltext index SEARCH_MEMORY_SUMMARIES_BY_KEYWORD = """ CALL 
db.index.fulltext.queryNodes("summariesFulltext", $q) YIELD node AS m, score -WHERE ($group_id IS NULL OR m.group_id = $group_id) +WHERE ($end_user_id IS NULL OR m.end_user_id = $end_user_id) OPTIONAL MATCH (m)-[:DERIVED_FROM_STATEMENT]->(s:Statement) RETURN m.id AS id, m.name AS name, - m.group_id AS group_id, + m.end_user_id AS end_user_id, m.dialog_id AS dialog_id, m.chunk_ids AS chunk_ids, m.content AS content, @@ -695,10 +652,10 @@ MEMORY_SUMMARY_EMBEDDING_SEARCH = """ CALL db.index.vector.queryNodes('summary_embedding_index', $limit * 100, $embedding) YIELD node AS m, score WHERE m.summary_embedding IS NOT NULL - AND ($group_id IS NULL OR m.group_id = $group_id) + AND ($end_user_id IS NULL OR m.end_user_id = $end_user_id) RETURN m.id AS id, m.name AS name, - m.group_id AS group_id, + m.end_user_id AS end_user_id, m.dialog_id AS dialog_id, m.chunk_ids AS chunk_ids, m.content AS content, @@ -718,9 +675,7 @@ MERGE (m:MemorySummary {id: summary.id}) SET m += { id: summary.id, name: summary.name, - group_id: summary.group_id, - user_id: summary.user_id, - apply_id: summary.apply_id, + end_user_id: summary.end_user_id, run_id: summary.run_id, created_at: summary.created_at, expired_at: summary.expired_at, @@ -814,7 +769,7 @@ RETURN count(losing) as deleted neo4j_statement_part = ''' MATCH (n:Statement) -WHERE n.group_id = "{}" +WHERE n.end_user_id = "{}" AND datetime(n.created_at) >= datetime() - duration('P3D') RETURN n.statement as statement_name, @@ -824,7 +779,7 @@ RETURN ''' neo4j_statement_all = ''' MATCH (n:Statement) -WHERE n.group_id = "{}" +WHERE n.end_user_id = "{}" RETURN n.statement as statement_name, n.id as statement_id @@ -832,7 +787,7 @@ RETURN ''' neo4j_query_part = """ MATCH (n)-[r]-(m:ExtractedEntity) - WHERE n.group_id = "{}" + WHERE n.end_user_id = "{}" AND datetime(n.created_at) >= datetime() - duration('P3D') WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) @@ -853,7 +808,7 @@ neo4j_query_part = """ """ neo4j_query_all = """ 
MATCH (n)-[r]-(m:ExtractedEntity) - WHERE n.group_id = "{}" + WHERE n.end_user_id = "{}" WITH DISTINCT m OPTIONAL MATCH (m)-[rel]-(other:ExtractedEntity) RETURN @@ -1027,14 +982,14 @@ RETURN DISTINCT Memory_Space_User=""" MATCH (n)-[r]->(m) -WHERE n.group_id = $group_id AND m.name="用户" +WHERE n.end_user_id = $end_user_id AND m.name="用户" return DISTINCT elementId(m) as id """ Memory_Space_Entity=""" MATCH (n)-[]-(m) WHERE elementId(m) = $id AND m.entity_type = "Person" RETURN -DISTINCT m.name as name,m.group_id as group_id +DISTINCT m.name as name,m.end_user_id as end_user_id """ Memory_Space_Associative=""" MATCH (u)-[]-(x)-[]-(h) diff --git a/api/app/repositories/neo4j/dialog_repository.py b/api/app/repositories/neo4j/dialog_repository.py index ccb3d94c..48376c2a 100644 --- a/api/app/repositories/neo4j/dialog_repository.py +++ b/api/app/repositories/neo4j/dialog_repository.py @@ -19,7 +19,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """对话仓储 管理对话节点的创建、查询、更新和删除操作。 - 提供按group_id、user_id、ref_id等条件查询对话的方法。 + 提供按end_user_id、user_id、ref_id等条件查询对话的方法。 Attributes: connector: Neo4j连接器实例 @@ -54,17 +54,17 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): return DialogueNode(**n) - async def find_by_group_id(self, group_id: str, limit: int = 100) -> List[DialogueNode]: - """根据group_id查询对话 + async def find_by_end_user_id(self, end_user_id: str, limit: int = 100) -> List[DialogueNode]: + """根据end_user_id查询对话 Args: - group_id: 组ID + end_user_id: 组ID limit: 返回结果的最大数量 Returns: List[DialogueNode]: 对话列表 """ - return await self.find({"group_id": group_id}, limit=limit) + return await self.find({"end_user_id": end_user_id}, limit=limit) async def find_by_user_id(self, user_id: str, limit: int = 100) -> List[DialogueNode]: """根据user_id查询对话 @@ -94,14 +94,14 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): async def find_by_group_and_user( self, - group_id: str, + end_user_id: str, user_id: str, limit: int = 100 ) -> List[DialogueNode]: - 
"""根据group_id和user_id查询对话 + """根据end_user_id和user_id查询对话 Args: - group_id: 组ID + end_user_id: 组ID user_id: 用户ID limit: 返回结果的最大数量 @@ -109,20 +109,20 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): List[DialogueNode]: 对话列表 """ return await self.find( - {"group_id": group_id, "user_id": user_id}, + {"end_user_id": end_user_id, "user_id": user_id}, limit=limit ) async def find_recent_dialogs( self, - group_id: str, + end_user_id: str, days: int = 7, limit: int = 100 ) -> List[DialogueNode]: """查询最近的对话 Args: - group_id: 组ID + end_user_id: 组ID days: 查询最近多少天的对话 limit: 返回结果的最大数量 @@ -131,7 +131,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND n.created_at >= datetime() - duration({{days: $days}}) RETURN n ORDER BY n.created_at DESC @@ -139,7 +139,7 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): """ results = await self.connector.execute_query( query, - group_id=group_id, + end_user_id=end_user_id, days=days, limit=limit ) @@ -164,16 +164,16 @@ class DialogRepository(BaseNeo4jRepository[DialogueNode]): async def find_by_config_and_group( self, config_id: str, - group_id: str, + end_user_id: str, limit: int = 100 ) -> List[DialogueNode]: - """根据config_id和group_id查询对话 + """根据config_id和end_user_id查询对话 支持按配置ID和组ID同时过滤,确保只返回使用特定配置处理的对话。 Args: config_id: 配置ID - group_id: 组ID + end_user_id: 组ID limit: 返回结果的最大数量 Returns: diff --git a/api/app/repositories/neo4j/emotion_repository.py b/api/app/repositories/neo4j/emotion_repository.py index d445c8d4..7a8ebcf9 100644 --- a/api/app/repositories/neo4j/emotion_repository.py +++ b/api/app/repositories/neo4j/emotion_repository.py @@ -40,7 +40,7 @@ class EmotionRepository: async def get_emotion_tags( self, - group_id: str, + end_user_id: str, emotion_type: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, @@ -51,7 +51,7 @@ class EmotionRepository: 
查询指定用户的情绪类型分布,包括计数、百分比和平均强度。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 用户组ID(宿主ID) emotion_type: 可选的情绪类型过滤(joy/sadness/anger/fear/surprise/neutral) start_date: 可选的开始日期(ISO格式字符串) end_date: 可选的结束日期(ISO格式字符串) @@ -65,8 +65,8 @@ class EmotionRepository: - avg_intensity: 平均强度 """ # 构建查询条件 - where_clauses = ["s.group_id = $group_id", "s.emotion_type IS NOT NULL"] - params = {"group_id": group_id, "limit": limit} + where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_type IS NOT NULL"] + params = {"end_user_id": end_user_id, "limit": limit} if emotion_type: where_clauses.append("s.emotion_type = $emotion_type") @@ -119,7 +119,7 @@ class EmotionRepository: async def get_emotion_wordcloud( self, - group_id: str, + end_user_id: str, emotion_type: Optional[str] = None, limit: int = 50 ) -> List[Dict[str, Any]]: @@ -128,7 +128,7 @@ class EmotionRepository: 查询情绪关键词及其频率,用于生成词云可视化。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 用户组ID(宿主ID) emotion_type: 可选的情绪类型过滤 limit: 返回关键词的最大数量 @@ -140,8 +140,8 @@ class EmotionRepository: - avg_intensity: 平均强度 """ # 构建查询条件 - where_clauses = ["s.group_id = $group_id", "s.emotion_keywords IS NOT NULL"] - params = {"group_id": group_id, "limit": limit} + where_clauses = ["s.end_user_id = $end_user_id", "s.emotion_keywords IS NOT NULL"] + params = {"end_user_id": end_user_id, "limit": limit} if emotion_type: where_clauses.append("s.emotion_type = $emotion_type") @@ -186,7 +186,7 @@ class EmotionRepository: async def get_emotions_in_range( self, - group_id: str, + end_user_id: str, time_range: str = "30d" ) -> List[Dict[str, Any]]: """获取时间范围内的情绪数据 @@ -194,7 +194,7 @@ class EmotionRepository: 查询指定时间范围内的所有情绪数据,用于健康指数计算。 Args: - group_id: 用户组ID(宿主ID) + end_user_id: 用户组ID(宿主ID) time_range: 时间范围(7d/30d/90d) Returns: @@ -214,7 +214,7 @@ class EmotionRepository: # 优化的 Cypher 查询:使用字符串比较避免时区问题 query = """ MATCH (s:Statement) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id AND s.emotion_type IS NOT NULL AND s.created_at >= 
$start_date RETURN s.id as statement_id, diff --git a/api/app/repositories/neo4j/graph_saver.py b/api/app/repositories/neo4j/graph_saver.py index 13215e0f..1575315f 100644 --- a/api/app/repositories/neo4j/graph_saver.py +++ b/api/app/repositories/neo4j/graph_saver.py @@ -44,9 +44,7 @@ async def save_entities_and_relationships( 'created_at': edge.created_at.isoformat(), 'expired_at': edge.expired_at.isoformat(), 'run_id': edge.run_id, - 'group_id': edge.group_id, - 'user_id': edge.user_id, - 'apply_id': edge.apply_id, + 'end_user_id': edge.end_user_id, } all_relationships.append(relationship) @@ -101,9 +99,7 @@ async def save_statement_chunk_edges( "id": edge.id, "source": edge.source, "target": edge.target, - "group_id": edge.group_id, - "user_id": edge.user_id, - "apply_id": edge.apply_id, + "end_user_id": edge.end_user_id, "run_id": edge.run_id, "created_at": edge.created_at.isoformat() if edge.created_at else None, "expired_at": edge.expired_at.isoformat() if edge.expired_at else None, @@ -132,9 +128,7 @@ async def save_statement_entity_edges( edge_data = { "source": edge.source, "target": edge.target, - "group_id": edge.group_id, - "user_id": edge.user_id, - "apply_id": edge.apply_id, + "end_user_id": edge.end_user_id, "run_id": edge.run_id, "connect_strength": edge.connect_strength, "created_at": edge.created_at.isoformat() if edge.created_at else None, diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index 6f5764b4..9660f6cb 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -33,7 +33,7 @@ async def _update_activation_values_batch( connector: Neo4jConnector, nodes: List[Dict[str, Any]], node_label: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, max_retries: int = 3 ) -> List[Dict[str, Any]]: """ @@ -46,7 +46,7 @@ async def _update_activation_values_batch( connector: Neo4j连接器 nodes: 节点列表,每个节点必须包含 'id' 字段 node_label: 
节点标签(Statement, ExtractedEntity, MemorySummary) - group_id: 组ID(可选) + end_user_id: 组ID(可选) max_retries: 最大重试次数 Returns: @@ -97,7 +97,7 @@ async def _update_activation_values_batch( updated_nodes = await access_manager.record_batch_access( node_ids=unique_node_ids, node_label=node_label, - group_id=group_id + end_user_id=end_user_id ) logger.info( @@ -118,7 +118,7 @@ async def _update_activation_values_batch( async def _update_search_results_activation( connector: Neo4jConnector, results: Dict[str, List[Dict[str, Any]]], - group_id: Optional[str] = None + end_user_id: Optional[str] = None ) -> Dict[str, List[Dict[str, Any]]]: """ 更新搜索结果中所有知识节点的激活值 @@ -129,7 +129,7 @@ async def _update_search_results_activation( Args: connector: Neo4j连接器 results: 搜索结果字典,包含不同类型节点的列表 - group_id: 组ID(可选) + end_user_id: 组ID(可选) Returns: Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果 @@ -152,7 +152,7 @@ async def _update_search_results_activation( connector=connector, nodes=results[key], node_label=label, - group_id=group_id + end_user_id=end_user_id ) ) update_keys.append(key) @@ -218,7 +218,7 @@ async def _update_search_results_activation( async def search_graph( connector: Neo4jConnector, q: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: List[str] = None, ) -> Dict[str, List[Dict[str, Any]]]: @@ -236,7 +236,7 @@ async def search_graph( Args: connector: Neo4j connector q: Query text - group_id: Optional group filter + end_user_id: Optional group filter limit: Max results per category include: List of categories to search (default: all) @@ -254,7 +254,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") @@ -263,7 +263,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) 
task_keys.append("entities") @@ -272,7 +272,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_CHUNKS_BY_CONTENT, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") @@ -281,7 +281,7 @@ async def search_graph( tasks.append(connector.execute_query( SEARCH_MEMORY_SUMMARIES_BY_KEYWORD, q=q, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("summaries") @@ -305,19 +305,12 @@ async def search_graph( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - # Skip activation updates if only searching summaries (optimization) - needs_activation_update = any( - key in include and key in results and results[key] - for key in ['statements', 'entities', 'chunks'] + results = await _update_search_results_activation( + connector=connector, + results=results, + end_user_id=end_user_id ) - if needs_activation_update: - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id - ) - return results @@ -325,7 +318,7 @@ async def search_graph_by_embedding( connector: Neo4jConnector, embedder_client, query_text: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 50, include: List[str] = ["statements", "chunks", "entities","summaries"], ) -> Dict[str, List[Dict[str, Any]]]: @@ -337,7 +330,7 @@ async def search_graph_by_embedding( - Computes query embedding with the provided embedder_client - Ranks by cosine similarity in Cypher - - Filters by group_id if provided + - Filters by end_user_id if provided - Returns up to 'limit' per included type """ import time @@ -346,7 +339,7 @@ async def search_graph_by_embedding( embed_start = time.time() embeddings = await embedder_client.response([query_text]) embed_time = time.time() - embed_start - logger.info(f"[PERF] Embedding generation took: {embed_time:.4f}s") + print(f"[PERF] Embedding generation took: 
{embed_time:.4f}s") if not embeddings or not embeddings[0]: return {"statements": [], "chunks": [], "entities": [], "summaries": []} @@ -361,7 +354,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( STATEMENT_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("statements") @@ -371,7 +364,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( CHUNK_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("chunks") @@ -381,7 +374,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( ENTITY_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("entities") @@ -391,7 +384,7 @@ async def search_graph_by_embedding( tasks.append(connector.execute_query( MEMORY_SUMMARY_EMBEDDING_SEARCH, embedding=embedding, - group_id=group_id, + end_user_id=end_user_id, limit=limit, )) task_keys.append("summaries") @@ -400,7 +393,7 @@ async def search_graph_by_embedding( query_start = time.time() task_results = await asyncio.gather(*tasks, return_exceptions=True) query_time = time.time() - query_start - logger.info(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") + print(f"[PERF] Neo4j queries (parallel) took: {query_time:.4f}s") # Build results dictionary results: Dict[str, List[Dict[str, Any]]] = { @@ -424,28 +417,19 @@ async def search_graph_by_embedding( results[key] = _deduplicate_results(results[key]) # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) - # Skip activation updates if only searching summaries (optimization) - needs_activation_update = any( - key in include and key in results and results[key] - for key in ['statements', 'entities', 'chunks'] + update_start = time.time() + results = await _update_search_results_activation( + connector=connector, + results=results, + 
end_user_id=end_user_id ) - - if needs_activation_update: - update_start = time.time() - results = await _update_search_results_activation( - connector=connector, - results=results, - group_id=group_id - ) - update_time = time.time() - update_start - logger.info(f"[PERF] Activation value updates took: {update_time:.4f}s") - else: - logger.info(f"[PERF] Skipping activation updates (only summaries)") + update_time = time.time() - update_start + print(f"[PERF] Activation value updates took: {update_time:.4f}s") return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 connector: Neo4jConnector, - group_id: str, + end_user_id: str, entities: List[Dict[str, Any]], use_contains_fallback: bool = True, batch_size: int = 500, @@ -453,7 +437,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 ) -> Dict[str, List[Dict[str, Any]]]: """ 为第二层去重消歧批量检索候选实体(适配新版 cypher_queries): - - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (group_id, name) 检索候选; + - 使用全文索引查询 `SEARCH_ENTITIES_BY_NAME` 按 (end_user_id, name) 检索候选; - 保留并发控制与返回结构(incoming_id -> [db_entity_props...]); - 若提供 `entity_type`,在本地对返回结果做类型过滤; - `use_contains_fallback` 保留形参以兼容,必要时可扩展二次查询策略。 @@ -477,7 +461,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=name, - group_id=group_id, + end_user_id=end_user_id, limit=100, ) except Exception: @@ -501,7 +485,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 rows = await connector.execute_query( SEARCH_ENTITIES_BY_NAME, q=name.lower(), - group_id=group_id, + end_user_id=end_user_id, limit=100, ) for r in rows: @@ -532,9 +516,7 @@ async def get_dedup_candidates_for_entities( # 适配新版查询:使用全 async def search_graph_by_keyword_temporal( connector: Neo4jConnector, query_text: str, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, start_date: Optional[str] = None, end_date: 
Optional[str] = None, valid_date: Optional[str] = None, @@ -547,32 +529,30 @@ async def search_graph_by_keyword_temporal( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements containing query_text created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ if not query_text: - logger.warning(f"query_text cannot be empty") + print(f"query_text不能为空") return {"statements": []} statements = await connector.execute_query( SEARCH_STATEMENTS_BY_KEYWORD_TEMPORAL, q=query_text, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, start_date=start_date, end_date=end_date, valid_date=valid_date, invalid_date=invalid_date, limit=limit, ) - logger.debug(f"Temporal keyword search results: {len(statements)} statements found") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results @@ -580,9 +560,7 @@ async def search_graph_by_keyword_temporal( async def search_graph_by_temporal( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None, valid_date: Optional[str] = None, @@ -595,14 +573,12 @@ async def search_graph_by_temporal( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created between start_date and end_date - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_TEMPORAL, - group_id=group_id, - apply_id=apply_id, - 
user_id=user_id, + end_user_id=end_user_id, start_date=start_date, end_date=end_date, valid_date=valid_date, @@ -610,16 +586,16 @@ async def search_graph_by_temporal( limit=limit, ) - logger.debug(f"Temporal search query: {SEARCH_STATEMENTS_BY_TEMPORAL}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, start_date={start_date}, end_date={end_date}, valid_date={valid_date}, invalid_date={invalid_date}, limit={limit}") - logger.debug(f"Temporal search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results @@ -628,23 +604,23 @@ async def search_graph_by_temporal( async def search_graph_by_dialog_id( connector: Neo4jConnector, dialog_id: str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Dialogues. 
- Matches dialogues with dialog_id - - Optionally filters by group_id + - Optionally filters by end_user_id - Returns up to 'limit' dialogues """ if not dialog_id: - logger.warning(f"dialog_id cannot be empty") + print(f"dialog_id不能为空") return {"dialogues": []} dialogues = await connector.execute_query( SEARCH_DIALOGUE_BY_DIALOG_ID, - group_id=group_id, + end_user_id=end_user_id, dialog_id=dialog_id, limit=limit, ) @@ -654,15 +630,15 @@ async def search_graph_by_dialog_id( async def search_graph_by_chunk_id( connector: Neo4jConnector, chunk_id : str, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: if not chunk_id: - logger.warning(f"chunk_id cannot be empty") + print(f"chunk_id不能为空") return {"chunks": []} chunks = await connector.execute_query( SEARCH_CHUNK_BY_CHUNK_ID, - group_id=group_id, + end_user_id=end_user_id, chunk_id=chunk_id, limit=limit, ) @@ -671,9 +647,9 @@ async def search_graph_by_chunk_id( async def search_graph_by_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -683,37 +659,37 @@ async def search_graph_by_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search by created_at query: {SEARCH_STATEMENTS_BY_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, 
user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id} created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_by_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -723,37 +699,37 @@ async def search_graph_by_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_BY_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search by valid_at query: {SEARCH_STATEMENTS_BY_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return 
results async def search_graph_g_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -763,37 +739,37 @@ async def search_graph_g_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_G_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search greater than created_at query: {SEARCH_STATEMENTS_G_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_g_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -803,37 +779,37 @@ async def search_graph_g_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - 
Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_G_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search greater than valid_at query: {SEARCH_STATEMENTS_G_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_l_created_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + created_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -843,37 +819,37 @@ async def search_graph_l_created_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_L_CREATED_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + created_at=created_at, limit=limit, ) - logger.debug(f"Search less than created_at query: {SEARCH_STATEMENTS_L_CREATED_AT}") - logger.debug(f"Query params: group_id={group_id}, 
apply_id={apply_id}, user_id={user_id}, created_at={created_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, created_at: {created_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + end_user_id=end_user_id ) return results async def search_graph_l_valid_at( connector: Neo4jConnector, - group_id: Optional[str] = None, - apply_id: Optional[str] = None, - user_id: Optional[str] = None, + end_user_id: Optional[str] = None, + + valid_at: Optional[str] = None, limit: int = 1, ) -> Dict[str, List[Dict[str, Any]]]: @@ -883,28 +859,28 @@ async def search_graph_l_valid_at( INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - - Optionally filters by group_id, apply_id, user_id + - Optionally filters by end_user_id, apply_id, user_id - Returns up to 'limit' statements """ statements = await connector.execute_query( SEARCH_STATEMENTS_L_VALID_AT, - group_id=group_id, - apply_id=apply_id, - user_id=user_id, + end_user_id=end_user_id, + + valid_at=valid_at, limit=limit, ) - logger.debug(f"Search less than valid_at query: {SEARCH_STATEMENTS_L_VALID_AT}") - logger.debug(f"Query params: group_id={group_id}, apply_id={apply_id}, user_id={user_id}, valid_at={valid_at}, limit={limit}") - logger.debug(f"Search results: {len(statements)} statements found") + print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") + print(f"查询参数为:\n{{end_user_id: {end_user_id}, valid_at: {valid_at}, limit: {limit}}}") + print(f"查询结果为:\n{statements}") # 更新 Statement 节点的激活值 results = {"statements": statements} results = await _update_search_results_activation( connector=connector, results=results, - group_id=group_id + 
end_user_id=end_user_id ) return results diff --git a/api/app/repositories/neo4j/memory_summary_repository.py b/api/app/repositories/neo4j/memory_summary_repository.py index fc743f33..2564aeab 100644 --- a/api/app/repositories/neo4j/memory_summary_repository.py +++ b/api/app/repositories/neo4j/memory_summary_repository.py @@ -18,7 +18,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): """Memory Summary Repository Manages CRUD operations for MemorySummary nodes. - Provides methods to query summaries by group_id, user_id, and time ranges. + Provides methods to query summaries by end_user_id, user_id, and time ranges. Attributes: connector: Neo4j connector instance @@ -51,17 +51,17 @@ class MemorySummaryRepository(BaseNeo4jRepository): return dict(n) - async def find_by_group_id( + async def find_by_end_user_id( self, - group_id: str, + end_user_id: str, limit: int = 1000, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> List[Dict[str, Any]]: - """Query memory summaries by group_id + """Query memory summaries by end_user_id Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by limit: Maximum number of results to return start_date: Optional start date filter end_date: Optional end date filter @@ -71,10 +71,10 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id """ - params = {"group_id": group_id, "limit": limit} + params = {"end_user_id": end_user_id, "limit": limit} # Add date range filters if provided if start_date: @@ -139,16 +139,16 @@ class MemorySummaryRepository(BaseNeo4jRepository): async def find_by_group_and_user( self, - group_id: str, + end_user_id: str, user_id: str, limit: int = 1000, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None ) -> List[Dict[str, Any]]: - """Query memory summaries by both group_id and user_id + """Query memory summaries by both 
end_user_id and user_id Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by user_id: User ID to filter by limit: Maximum number of results to return start_date: Optional start date filter @@ -159,10 +159,10 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id AND n.user_id = $user_id + WHERE n.end_user_id = $end_user_id AND n.user_id = $user_id """ - params = {"group_id": group_id, "user_id": user_id, "limit": limit} + params = {"end_user_id": end_user_id, "user_id": user_id, "limit": limit} # Add date range filters if provided if start_date: @@ -184,14 +184,14 @@ class MemorySummaryRepository(BaseNeo4jRepository): async def find_recent_summaries( self, - group_id: str, + end_user_id: str, days: int = 7, limit: int = 1000 ) -> List[Dict[str, Any]]: """Query recent memory summaries Args: - group_id: Group ID to filter by + end_user_id: Group ID to filter by days: Number of recent days to query limit: Maximum number of results to return @@ -200,7 +200,7 @@ class MemorySummaryRepository(BaseNeo4jRepository): """ query = f""" MATCH (n:{self.node_label}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND n.created_at >= datetime() - duration({{days: $days}}) RETURN n ORDER BY n.created_at DESC diff --git a/api/app/repositories/neo4j/neo4j_connector.py b/api/app/repositories/neo4j/neo4j_connector.py index 7c4b43b5..456c4e08 100644 --- a/api/app/repositories/neo4j/neo4j_connector.py +++ b/api/app/repositories/neo4j/neo4j_connector.py @@ -141,14 +141,14 @@ class Neo4jConnector: async with self.driver.session(database="neo4j") as session: return await session.execute_read(transaction_func, **kwargs) - async def delete_group(self, group_id: str): + async def delete_group(self, end_user_id: str): """删除指定组的所有数据 - 删除所有属于指定group_id的节点和边。 + 删除所有属于指定end_user_id的节点和边。 这是一个危险操作,会永久删除数据。 Args: - group_id: 要删除的组ID + end_user_id: 要删除的组ID Example: >>> connector = 
Neo4jConnector() @@ -157,14 +157,14 @@ class Neo4jConnector: """ # 删除节点(DETACH DELETE会同时删除相关的边) await self.driver.execute_query( - "MATCH (n) WHERE n.group_id = $group_id DETACH DELETE n", + "MATCH (n) WHERE n.end_user_id = $end_user_id DETACH DELETE n", database="neo4j", - group_id=group_id + end_user_id=end_user_id ) # 删除独立的边(如果有的话) await self.driver.execute_query( - "MATCH ()-[r]->() WHERE r.group_id = $group_id DELETE r", + "MATCH ()-[r]->() WHERE r.end_user_id = $end_user_id DELETE r", database="neo4j", - group_id=group_id + end_user_id=end_user_id ) - print(f"Group {group_id} deleted.") + print(f"Group {end_user_id} deleted.") diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index cd9f2fac..4f12af83 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -20,7 +20,7 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): """陈述句仓储 管理陈述句节点的创建、查询、更新和删除操作。 - 提供按chunk_id、group_id、向量相似度等条件查询陈述句的方法。 + 提供按chunk_id、end_user_id、向量相似度等条件查询陈述句的方法。 Attributes: connector: Neo4j连接器实例 diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index d4354c40..e7b1fe65 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -7,11 +7,11 @@ class UserInput(BaseModel): message: str history: list[dict] search_switch: str - group_id: str + end_user_id: str config_id: Optional[str] = None class Write_UserInput(BaseModel): messages: list[dict] - group_id: str + end_user_id: str config_id: Optional[str] = None diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 46bda5f6..2dd06e89 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -92,7 +92,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str try: memory_content = asyncio.run( 
MemoryAgentService().read_memory( - group_id=end_user_id, + end_user_id=end_user_id, message=question, history=[], search_switch="2", diff --git a/api/app/services/emotion_analytics_service.py b/api/app/services/emotion_analytics_service.py index 601d2921..19c6cef1 100644 --- a/api/app/services/emotion_analytics_service.py +++ b/api/app/services/emotion_analytics_service.py @@ -75,7 +75,7 @@ class EmotionAnalyticsService: # 调用仓储层查询 tags = await self.emotion_repo.get_emotion_tags( - group_id=end_user_id, + end_user_id=end_user_id, emotion_type=emotion_type, start_date=start_date, end_date=end_date, @@ -157,7 +157,7 @@ class EmotionAnalyticsService: # 调用仓储层查询 keywords = await self.emotion_repo.get_emotion_wordcloud( - group_id=end_user_id, + end_user_id=end_user_id, emotion_type=emotion_type, limit=limit ) @@ -339,7 +339,7 @@ class EmotionAnalyticsService: # 获取时间范围内的情绪数据 emotions = await self.emotion_repo.get_emotions_in_range( - group_id=end_user_id, + end_user_id=end_user_id, time_range=time_range ) @@ -519,7 +519,7 @@ class EmotionAnalyticsService: # 3. 
获取情绪数据用于模式分析 emotions = await self.emotion_repo.get_emotions_in_range( - group_id=end_user_id, + end_user_id=end_user_id, time_range="30d" ) @@ -598,13 +598,13 @@ class EmotionAnalyticsService: # 查询用户的实体和标签 query = """ MATCH (e:Entity) - WHERE e.group_id = $group_id + WHERE e.end_user_id = $end_user_id RETURN e.name as name, e.type as type ORDER BY e.created_at DESC LIMIT 20 """ - entities = await connector.execute_query(query, group_id=end_user_id) + entities = await connector.execute_query(query, end_user_id=end_user_id) # 提取兴趣标签 interests = [e["name"] for e in entities if e.get("type") in ["INTEREST", "HOBBY"]][:5] diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 8170bdd8..e475bef0 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -27,6 +27,7 @@ from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType +from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError @@ -54,25 +55,25 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): + def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context): duration = time.time() - start_time if str(messages) == 'success': - logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") + logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: - 
audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, duration=duration, details={"message_length": len(message)}) return context else: - logger.warning(f"Write operation failed for group {group_id}") + logger.warning(f"Write operation failed for group {end_user_id}") # 记录失败的操作 if audit_logger: audit_logger.log_operation( operation="WRITE", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=f"写入失败: {messages[:100]}" @@ -265,13 +266,13 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, group_id: str, messages: list[dict], config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id Args: - group_id: Group identifier (also used as end_user_id) - messages: Structured message list [{"role": "user", "content": "..."}, ...] + end_user_id: Group identifier (also used as end_user_id) + message: Message to write config_id: Configuration ID from database db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) @@ -286,15 +287,15 @@ class MemoryAgentService: # Resolve config_id if None using end_user's connected config if config_id is None: try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = connected_config.get("memory_config_id") if config_id is None: - raise ValueError(f"No memory configuration found for end_user {group_id}. 
Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): - raise - logger.error(f"Failed to get connected config for end_user {group_id}: {e}") - raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") + raise # Re-raise our specific error + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") import time start_time = time.time() @@ -314,7 +315,7 @@ class MemoryAgentService: # Log failed operation if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -322,11 +323,11 @@ class MemoryAgentService: if storage_type == "rag": # For RAG storage, convert messages to single string message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(group_id, message_text, user_rag_memory_id) + result = await write_rag(end_user_id, message_text, user_rag_memory_id) return result else: async with make_write_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # Convert structured messages to LangChain messages langchain_messages = [] for msg in messages: @@ -339,7 +340,7 @@ class MemoryAgentService: # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, - "group_id": group_id, + "end_user_id": end_user_id, "memory_config": memory_config } @@ -356,14 +357,14 @@ class 
MemoryAgentService: contents = massages.get('write_result') # Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message_text, contents) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -371,15 +372,14 @@ class MemoryAgentService: async def read_memory( self, - group_id: str, + end_user_id: str, message: str, history: List[Dict], search_switch: str, config_id: Optional[str], db: Session, storage_type: str, - user_rag_memory_id: str - ) -> Dict: + user_rag_memory_id: str) -> Dict: """ Process read operation with config_id @@ -389,7 +389,7 @@ class MemoryAgentService: - "2": Direct answer based on context Args: - group_id: Group identifier (also used as end_user_id) + end_user_id: Group identifier (also used as end_user_id) message: User message history: Conversation history search_switch: Search mode switch @@ -407,22 +407,22 @@ class MemoryAgentService: import time start_time = time.time() - logger.info(f"[PERF] read_memory started for group_id={group_id}, search_switch={search_switch}") + ori_message= message # Resolve config_id if None using end_user's connected config if config_id is None: try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) config_id = 
connected_config.get("memory_config_id") if config_id is None: - raise ValueError(f"No memory configuration found for end_user {group_id}. Please ensure the user has a connected memory configuration.") + raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error - logger.error(f"Failed to get connected config for end_user {group_id}: {e}") - raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") + logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}") + raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") - logger.info(f"Read operation for group {group_id} with config_id {config_id}") + logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") # 导入审计日志记录器 try: @@ -431,15 +431,13 @@ class MemoryAgentService: audit_logger = None - config_load_start = time.time() try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - config_load_time = time.time() - config_load_start - logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}") + logger.info(f"Configuration loaded successfully: {memory_config.config_name}") except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) @@ -450,7 +448,7 @@ class MemoryAgentService: audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=error_msg @@ -460,16 +458,16 @@ class MemoryAgentService: # Step 2: Prepare history history.append({"role": "user", "content": message}) - logger.debug(f"Group ID:{group_id}, Message:{message}, 
History:{history}, Config ID:{config_id}") + logger.debug(f"Group ID:{end_user_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow graph_exec_start = time.time() try: async with make_read_graph() as graph: - config = {"configurable": {"thread_id": group_id}} + config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, - "group_id": group_id + "end_user_id": end_user_id , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "memory_config": memory_config} # 获取节点更新信息 @@ -565,13 +563,13 @@ class MemoryAgentService: if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # 使用 upsert 方法 repo.upsert( - end_user_id=group_id, + end_user_id=end_user_id, messages=message, aimessages=summary, retrieved_content=retrieved_content, search_switch=str(search_switch) ) - logger.info(f"成功保存短期记忆: group_id={group_id}, search_switch={search_switch}") + logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") else: logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") @@ -580,14 +578,12 @@ class MemoryAgentService: logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation - total_time = time.time() - start_time - logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=True, duration=duration ) @@ -599,14 +595,13 @@ class MemoryAgentService: except Exception as e: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" - total_time = time.time() - start_time - logger.error(f"[PERF] 
read_memory failed after {total_time:.4f}s: {error_msg}") + logger.error(error_msg) if audit_logger: duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, - group_id=group_id, + end_user_id=end_user_id, success=False, duration=duration, error=error_msg @@ -755,7 +750,7 @@ class MemoryAgentService: """ 统计知识库类型分布,包含: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) - 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) + 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/end_user_id 过滤) 3. total: 所有类型的总和 参数: @@ -841,11 +836,11 @@ class MemoryAgentService: for end_user in end_users: end_user_id_str = str(end_user.id) memory_query = """ - MATCH (n:Chunk) WHERE n.group_id = $group_id RETURN count(n) AS Count + MATCH (n:Chunk) WHERE n.end_user_id = $end_user_id RETURN count(n) AS Count """ neo4j_result = await _neo4j_connector.execute_query( memory_query, - group_id=end_user_id_str, + end_user_id=end_user_id_str, ) chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 total_chunks += chunk_count @@ -885,7 +880,7 @@ class MemoryAgentService: 获取指定用户的热门记忆标签 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 + - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 - limit: 返回标签数量限制 返回格式: @@ -895,7 +890,7 @@ class MemoryAgentService: ] """ try: - # by_user=False 表示按 group_id 查询(在Neo4j中,group_id就是用户维度) + # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) payload=[] for tag, freq in tags: @@ -970,21 +965,21 @@ class MemoryAgentService: # 查询该用户的语句 query = ( "MATCH (s:Statement) " - "WHERE ($group_id IS NULL OR s.group_id = $group_id) AND s.statement IS NOT NULL " + "WHERE ($end_user_id IS NULL OR s.end_user_id = $end_user_id) AND s.statement IS NOT NULL " "RETURN s.statement AS statement " "ORDER BY s.created_at DESC LIMIT 100" ) - rows = await connector.execute_query(query, group_id=end_user_id) + rows = await 
connector.execute_query(query, end_user_id=end_user_id) statements = [r.get("statement", "") for r in rows if r.get("statement")] # 查询该用户的热门实体 entity_query = ( "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " + "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' AND e.name IS NOT NULL " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC LIMIT 20" ) - entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) + entity_rows = await connector.execute_query(entity_query, end_user_id=end_user_id) entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] await connector.close() @@ -1037,14 +1032,14 @@ class MemoryAgentService: names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria'] hot_tag_query = ( "MATCH (e:ExtractedEntity) " - "WHERE ($group_id IS NULL OR e.group_id = $group_id) AND e.entity_type <> '人物' " + "WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) AND e.entity_type <> '人物' " "AND e.name IS NOT NULL AND NOT e.name IN $names_to_exclude " "RETURN e.name AS name, count(e) AS frequency " "ORDER BY frequency DESC LIMIT 4" ) hot_tag_rows = await connector.execute_query( hot_tag_query, - group_id=end_user_id, + end_user_id=end_user_id, names_to_exclude=names_to_exclude ) await connector.close() @@ -1190,6 +1185,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An "memory_config_id": memory_config_id } + print(188*'*') + print(result) + print(188 * '*') + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") return result @@ -1230,10 +1229,10 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 1. 
批量查询所有 end_user 及其 app_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - + # 创建 end_user_id -> app_id 的映射 user_to_app = {str(eu.id): eu.app_id for eu in end_users} - + # 记录未找到的用户 found_user_ids = set(user_to_app.keys()) missing_user_ids = set(end_user_ids) - found_user_ids @@ -1275,13 +1274,13 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 批量查询 memory_config_name config_id_to_name = {} if memory_config_ids: - memory_configs = db.query(DataConfig).filter(DataConfig.config_id.in_(memory_config_ids)).all() - config_id_to_name = {str(mc.config_id): mc.config_name for mc in memory_configs} + memory_configs = db.query(MemoryConfig).filter(MemoryConfig.id.in_(memory_config_ids)).all() + config_id_to_name = {str(mc.id): mc.config_name for mc in memory_configs} # 4. 构建最终结果 for end_user_id, app_id in user_to_app.items(): release = app_to_release.get(app_id) - + if not release: logger.warning(f"No active release found for app: {app_id} (end_user: {end_user_id})") result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} @@ -1293,7 +1292,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None # 获取配置名称 - memory_config_name = config_id_to_name.get(str(memory_config_id)) if memory_config_id else None + memory_config_name = config_id_to_name.get(memory_config_id) if memory_config_id else None result[end_user_id] = { "memory_config_id": memory_config_id, diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 0ae2b965..c33c9c6b 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -25,7 +25,7 @@ class MemoryAPIService: This service provides a thin layer that: 1. Validates end_user exists and belongs to the authorized workspace - 2. Maps end_user_id to group_id for memory operations + 2. 
Maps end_user_id to end_user_id for memory operations 3. Delegates to MemoryAgentService for actual memory read/write operations """ @@ -68,7 +68,7 @@ class MemoryAPIService: ) end_user = self.db.query(EndUser).filter(EndUser.id == end_user_uuid).first() - + if not end_user: logger.warning(f"End user not found: {end_user_id}") raise ResourceNotFoundException( @@ -115,7 +115,7 @@ class MemoryAPIService: Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as group_id) + end_user_id: End user identifier (used as end_user_id) message: Message content to store config_id: Optional memory configuration ID storage_type: Storage backend (neo4j or rag) @@ -133,13 +133,12 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as group_id for memory operations - group_id = end_user_id + # Use end_user_id as end_user_id for memory operations try: # Delegate to MemoryAgentService result = await MemoryAgentService().write_memory( - group_id=group_id, + end_user_id=end_user_id, message=message, config_id=config_id, db=self.db, @@ -186,7 +185,7 @@ class MemoryAPIService: Args: workspace_id: Workspace ID for resource validation - end_user_id: End user identifier (used as group_id) + end_user_id: End user identifier (used as end_user_id) message: Query message search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search) config_id: Optional memory configuration ID @@ -205,13 +204,13 @@ class MemoryAPIService: # Validate end_user exists and belongs to workspace self.validate_end_user(end_user_id, workspace_id) - # Use end_user_id as group_id for memory operations - group_id = end_user_id + # Use end_user_id as end_user_id for memory operations + try: # Delegate to MemoryAgentService result = await MemoryAgentService().read_memory( - group_id=group_id, + end_user_id=end_user_id, message=message, history=[], 
search_switch=search_switch, diff --git a/api/app/services/memory_base_service.py b/api/app/services/memory_base_service.py index 25a8281d..bc647752 100644 --- a/api/app/services/memory_base_service.py +++ b/api/app/services/memory_base_service.py @@ -326,7 +326,7 @@ class MemoryBaseService: Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 最大emotion_intensity对应的emotion_type,如果没有则返回None @@ -334,7 +334,7 @@ class MemoryBaseService: try: query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) WHERE stmt.emotion_type IS NOT NULL AND stmt.emotion_intensity IS NOT NULL @@ -347,7 +347,7 @@ class MemoryBaseService: result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) if result and len(result) > 0: @@ -381,10 +381,10 @@ class MemoryBaseService: if end_user_id: query = """ MATCH (n:MemorySummary) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await self.neo4j_connector.execute_query(query, group_id=end_user_id) + result = await self.neo4j_connector.execute_query(query, end_user_id=end_user_id) else: query = """ MATCH (n:MemorySummary) @@ -423,12 +423,12 @@ class MemoryBaseService: if end_user_id: semantic_query = """ MATCH (e:ExtractedEntity) - WHERE e.group_id = $group_id AND e.is_explicit_memory = true + WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN count(e) as count """ semantic_result = await self.neo4j_connector.execute_query( semantic_query, - group_id=end_user_id + end_user_id=end_user_id ) else: semantic_query = """ @@ -519,7 +519,7 @@ class MemoryBaseService: """ if end_user_id: - query += " AND n.group_id = $group_id" + query += " AND n.end_user_id = $end_user_id" query += """ 
RETURN sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes @@ -528,7 +528,7 @@ class MemoryBaseService: # 设置查询参数 params = {'threshold': forgetting_threshold} if end_user_id: - params['group_id'] = end_user_id + params['end_user_id'] = end_user_id # 执行查询 result = await self.neo4j_connector.execute_query(query, **params) diff --git a/api/app/services/memory_entity_relationship_service.py b/api/app/services/memory_entity_relationship_service.py index 9b5f3c99..7081d28b 100644 --- a/api/app/services/memory_entity_relationship_service.py +++ b/api/app/services/memory_entity_relationship_service.py @@ -717,8 +717,8 @@ class MemoryInteraction: ori_data= await self.connector.execute_query(Memory_Space_Entity, id=self.id) if ori_data!=[]: # name = ori_data[0]['name'] - group_id = [i['group_id'] for i in ori_data][0] - Space_User = await self.connector.execute_query(Memory_Space_User, group_id=group_id) + end_user_id = [i['end_user_id'] for i in ori_data][0] + Space_User = await self.connector.execute_query(Memory_Space_User, end_user_id=end_user_id) if not Space_User: return [] user_id=Space_User[0]['id'] diff --git a/api/app/services/memory_episodic_service.py b/api/app/services/memory_episodic_service.py index 12eeff6e..08751fd1 100644 --- a/api/app/services/memory_episodic_service.py +++ b/api/app/services/memory_episodic_service.py @@ -34,7 +34,7 @@ class MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: (标题, 类型)元组,如果不存在则返回默认值 @@ -43,14 +43,14 @@ class MemoryEpisodicService(MemoryBaseService): # 查询Summary节点的name(作为title)和memory_type(作为type) query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id RETURN s.name AS title, s.memory_type AS type """ result = await 
self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) if not result or len(result) == 0: @@ -77,7 +77,7 @@ class MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 前3个实体的name属性列表 @@ -87,7 +87,7 @@ class MemoryEpisodicService(MemoryBaseService): # 按activation_value降序排序,返回前3个 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) MATCH (stmt)-[:REFERENCES_ENTITY]->(entity:ExtractedEntity) WHERE entity.activation_value IS NOT NULL @@ -99,7 +99,7 @@ class MemoryEpisodicService(MemoryBaseService): result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 提取实体名称 @@ -123,7 +123,7 @@ class MemoryEpisodicService(MemoryBaseService): Args: summary_id: Summary节点的ID - end_user_id: 终端用户ID (group_id) + end_user_id: 终端用户ID (end_user_id) Returns: 所有Statement节点的statement属性内容列表 @@ -132,7 +132,7 @@ class MemoryEpisodicService(MemoryBaseService): # 查询Summary节点指向的所有Statement节点 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id MATCH (s)-[:DERIVED_FROM_STATEMENT]->(stmt:Statement) WHERE stmt.statement IS NOT NULL AND stmt.statement <> '' RETURN stmt.statement AS statement @@ -141,7 +141,7 @@ class MemoryEpisodicService(MemoryBaseService): result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 提取statement内容 @@ -214,12 +214,12 @@ class MemoryEpisodicService(MemoryBaseService): # 1. 
先查询所有情景记忆的总数(不受筛选条件限制) total_all_query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id RETURN count(s) AS total_all """ total_all_result = await self.neo4j_connector.execute_query( total_all_query, - group_id=end_user_id + end_user_id=end_user_id ) total_all = total_all_result[0]["total_all"] if total_all_result else 0 @@ -229,7 +229,7 @@ class MemoryEpisodicService(MemoryBaseService): # 3. 构建Cypher查询 query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id """ # 添加时间范围过滤 @@ -248,7 +248,7 @@ class MemoryEpisodicService(MemoryBaseService): ORDER BY s.created_at DESC """ - params = {"group_id": end_user_id} + params = {"end_user_id": end_user_id} if time_filter: params["time_filter"] = time_filter if title_keyword: @@ -333,14 +333,14 @@ class MemoryEpisodicService(MemoryBaseService): # 1. 查询指定的MemorySummary节点 query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $summary_id AND s.group_id = $group_id + WHERE elementId(s) = $summary_id AND s.end_user_id = $end_user_id RETURN elementId(s) AS id, s.created_at AS created_at """ result = await self.neo4j_connector.execute_query( query, summary_id=summary_id, - group_id=end_user_id + end_user_id=end_user_id ) # 2. 如果节点不存在,返回错误 diff --git a/api/app/services/memory_explicit_service.py b/api/app/services/memory_explicit_service.py index 713215c3..f8d39ae8 100644 --- a/api/app/services/memory_explicit_service.py +++ b/api/app/services/memory_explicit_service.py @@ -60,7 +60,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 1. 
查询情景记忆(MemorySummary节点) ========== episodic_query = """ MATCH (s:MemorySummary) - WHERE s.group_id = $group_id + WHERE s.end_user_id = $end_user_id RETURN elementId(s) AS id, s.name AS title, s.content AS content, @@ -70,7 +70,7 @@ class MemoryExplicitService(MemoryBaseService): episodic_result = await self.neo4j_connector.execute_query( episodic_query, - group_id=end_user_id + end_user_id=end_user_id ) # 处理情景记忆数据 @@ -96,7 +96,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 2. 查询语义记忆(ExtractedEntity节点) ========== semantic_query = """ MATCH (e:ExtractedEntity) - WHERE e.group_id = $group_id + WHERE e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN elementId(e) AS id, e.name AS name, @@ -107,7 +107,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_result = await self.neo4j_connector.execute_query( semantic_query, - group_id=end_user_id + end_user_id=end_user_id ) # 处理语义记忆数据 @@ -189,7 +189,7 @@ class MemoryExplicitService(MemoryBaseService): # ========== 1. 
先尝试查询情景记忆 ========== episodic_query = """ MATCH (s:MemorySummary) - WHERE elementId(s) = $memory_id AND s.group_id = $group_id + WHERE elementId(s) = $memory_id AND s.end_user_id = $end_user_id RETURN s.name AS title, s.content AS content, s.created_at AS created_at @@ -198,7 +198,7 @@ class MemoryExplicitService(MemoryBaseService): episodic_result = await self.neo4j_connector.execute_query( episodic_query, memory_id=memory_id, - group_id=end_user_id + end_user_id=end_user_id ) if episodic_result and len(episodic_result) > 0: @@ -229,7 +229,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_query = """ MATCH (e:ExtractedEntity) WHERE elementId(e) = $memory_id - AND e.group_id = $group_id + AND e.end_user_id = $end_user_id AND e.is_explicit_memory = true RETURN e.name AS name, e.description AS core_definition, @@ -240,7 +240,7 @@ class MemoryExplicitService(MemoryBaseService): semantic_result = await self.neo4j_connector.execute_query( semantic_query, memory_id=memory_id, - group_id=end_user_id + end_user_id=end_user_id ) if semantic_result and len(semantic_result) > 0: diff --git a/api/app/services/memory_forget_service.py b/api/app/services/memory_forget_service.py index 2db4cdc7..558efe43 100644 --- a/api/app/services/memory_forget_service.py +++ b/api/app/services/memory_forget_service.py @@ -132,7 +132,7 @@ class MemoryForgetService: async def _get_knowledge_stats( self, connector: Neo4jConnector, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, forgetting_threshold: float = 0.3 ) -> Dict[str, Any]: """ @@ -140,7 +140,7 @@ class MemoryForgetService: Args: connector: Neo4j 连接器 - group_id: 组ID(可选) + end_user_id: 组ID(可选) forgetting_threshold: 遗忘阈值 Returns: @@ -152,8 +152,8 @@ class MemoryForgetService: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) """ - if group_id: - query += " AND n.group_id = $group_id" + if end_user_id: + query += " AND n.end_user_id = $end_user_id" query += """ WITH n, @@ -172,8 +172,8 @@ class 
MemoryForgetService: """ params = {'threshold': forgetting_threshold} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id results = await connector.execute_query(query, **params) @@ -200,7 +200,7 @@ class MemoryForgetService: async def _get_pending_forgetting_nodes( self, connector: Neo4jConnector, - group_id: str, + end_user_id: str, forgetting_threshold: float, min_days_since_access: int, limit: int = 20 @@ -212,7 +212,7 @@ class MemoryForgetService: Args: connector: Neo4j 连接器 - group_id: 组ID + end_user_id: 组ID forgetting_threshold: 遗忘阈值 min_days_since_access: 最小未访问天数 limit: 返回节点数量限制 @@ -229,7 +229,7 @@ class MemoryForgetService: query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) - AND n.group_id = $group_id + AND n.end_user_id = $end_user_id AND n.activation_value IS NOT NULL AND n.activation_value < $threshold AND n.last_access_time IS NOT NULL @@ -250,7 +250,7 @@ class MemoryForgetService: """ params = { - 'group_id': group_id, + 'end_user_id': end_user_id, 'threshold': forgetting_threshold, 'min_access_time_str': min_access_time_str, 'limit': limit @@ -291,7 +291,7 @@ class MemoryForgetService: async def trigger_forgetting_cycle( self, db: Session, - group_id: str, + end_user_id: str, max_merge_batch_size: Optional[int] = None, min_days_since_access: Optional[int] = None, config_id: Optional[int] = None @@ -303,10 +303,10 @@ class MemoryForgetService: Args: db: 数据库会话 - group_id: 组ID(即终端用户ID,必填) + end_user_id: 组ID(即终端用户ID,必填) max_merge_batch_size: 最大融合批次大小(可选) min_days_since_access: 最小未访问天数(可选) - config_id: 配置ID(必填,由控制器层通过 group_id 获取) + config_id: 配置ID(必填,由控制器层通过 end_user_id 获取) Returns: dict: 遗忘报告 @@ -319,7 +319,7 @@ class MemoryForgetService: # 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取) report = await forgetting_scheduler.run_forgetting_cycle( - group_id=group_id, + end_user_id=end_user_id, max_merge_batch_size=max_merge_batch_size, 
min_days_since_access=min_days_since_access, config_id=config_id, @@ -338,7 +338,7 @@ class MemoryForgetService: stats_query = """ MATCH (n) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) - AND n.group_id = $group_id + AND n.end_user_id = $end_user_id RETURN count(n) as total_nodes, avg(n.activation_value) as average_activation, @@ -347,7 +347,7 @@ class MemoryForgetService: stats_results = await connector.execute_query( stats_query, - group_id=group_id, + end_user_id=end_user_id, threshold=config['forgetting_threshold'] ) @@ -364,7 +364,7 @@ class MemoryForgetService: # 保存历史记录到数据库 self.history_repository.create( db=db, - end_user_id=group_id, + end_user_id=end_user_id, execution_time=execution_time, merged_count=report['merged_count'], failed_count=report['failed_count'], @@ -376,7 +376,7 @@ class MemoryForgetService: ) api_logger.info( - f"已保存遗忘周期历史记录: end_user_id={group_id}, " + f"已保存遗忘周期历史记录: end_user_id={end_user_id}, " f"merged_count={report['merged_count']}" ) @@ -465,7 +465,7 @@ class MemoryForgetService: async def get_forgetting_stats( self, db: Session, - group_id: Optional[str] = None, + end_user_id: Optional[str] = None, config_id: Optional[int] = None ) -> Dict[str, Any]: """ @@ -475,7 +475,7 @@ class MemoryForgetService: Args: db: 数据库会话 - group_id: 组ID(可选) + end_user_id: 组ID(可选) config_id: 配置ID(可选,用于获取遗忘阈值) Returns: @@ -493,8 +493,8 @@ class MemoryForgetService: WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) """ - if group_id: - activation_query += " AND n.group_id = $group_id" + if end_user_id: + activation_query += " AND n.end_user_id = $end_user_id" activation_query += """ RETURN @@ -506,8 +506,8 @@ class MemoryForgetService: """ params = {'threshold': forgetting_threshold} - if group_id: - params['group_id'] = group_id + if end_user_id: + params['end_user_id'] = end_user_id activation_results = await connector.execute_query(activation_query, **params) @@ -539,8 +539,8 @@ class MemoryForgetService: 
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk) """ - if group_id: - distribution_query += " AND n.group_id = $group_id" + if end_user_id: + distribution_query += " AND n.end_user_id = $end_user_id" distribution_query += """ WITH n, @@ -558,8 +558,8 @@ class MemoryForgetService: """ dist_params = {} - if group_id: - dist_params['group_id'] = group_id + if end_user_id: + dist_params['end_user_id'] = end_user_id distribution_results = await connector.execute_query(distribution_query, **dist_params) @@ -582,11 +582,11 @@ class MemoryForgetService: # 获取最近7个日期的历史趋势数据(每天取最后一次执行) recent_trends = [] try: - if group_id: + if end_user_id: # 查询所有历史记录 history_records = self.history_repository.get_recent_by_end_user( db=db, - end_user_id=group_id + end_user_id=end_user_id ) # 按日期分组(一天可能有多次执行,取最后一次) @@ -632,7 +632,7 @@ class MemoryForgetService: # 获取待遗忘节点列表(前20个满足遗忘条件的节点) pending_nodes = [] try: - if group_id: + if end_user_id: # 验证 min_days_since_access 配置值 min_days = config.get('min_days_since_access') if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0: @@ -643,7 +643,7 @@ class MemoryForgetService: pending_nodes = await self._get_pending_forgetting_nodes( connector=connector, - group_id=group_id, + end_user_id=end_user_id, forgetting_threshold=forgetting_threshold, min_days_since_access=int(min_days), limit=20 diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index c6297e12..420f7ca1 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -450,12 +450,12 @@ async def create_document_chunk( return success(data=chunk, msg="文档块创建成功") -async def write_rag(group_id, message, user_rag_memory_id): +async def write_rag(end_user_id, message, user_rag_memory_id): """ 将消息写入 RAG 知识库 Args: - group_id: 组ID,用作文件标题 + end_user_id: 组ID,用作文件标题 message: 消息内容 user_rag_memory_id: 知识库ID(必须是有效的UUID) @@ -487,10 +487,10 @@ async def 
write_rag(group_id, message, user_rag_memory_id): db = next(db_gen) try: - create_data = CustomTextFileCreate(title=group_id, content=message) + create_data = CustomTextFileCreate(title=end_user_id, content=message) current_user = SimpleUser(user_rag_memory_id) # 检查文档是否已存在 - document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{group_id}.txt") + document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") print('======',document) api_logger.info(f"查找文档结果: document_id={document}") if document is not None: @@ -508,7 +508,7 @@ async def write_rag(group_id, message, user_rag_memory_id): return result else: # 文档不存在,创建新文档 - api_logger.info(f"文档不存在,创建新文档: group_id={group_id}") + api_logger.info(f"文档不存在,创建新文档: end_user_id={end_user_id}") result = await memory_konwledges_up( kb_id=user_rag_memory_id, parent_id=user_rag_memory_id, @@ -520,13 +520,13 @@ async def write_rag(group_id, message, user_rag_memory_id): new_document_id = find_document_id_by_kb_and_filename( db=db, kb_id=user_rag_memory_id, - file_name=f"{group_id}.txt" + file_name=f"{end_user_id}.txt" ) if new_document_id: await parse_document_by_id(new_document_id, db=db, current_user=current_user) else: - api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}") + api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") return result finally: # 确保数据库会话被关闭 diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 83d5923d..05a84c01 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -183,7 +183,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "config_name": config.config_name, "config_desc": config.config_desc, "workspace_id": str(config.workspace_id) if config.workspace_id else None, - "group_id": config.group_id, + "end_user_id": config.end_user_id, "user_id": config.user_id, "apply_id": config.apply_id, "llm_id": 
config.llm_id, @@ -391,7 +391,7 @@ _neo4j_connector = Neo4jConnector() async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_DIALOGUE, - group_id=end_user_id, + end_user_id=end_user_id, ) data = {"search_for": "dialogue", "num": result[0]["num"]} return data @@ -400,7 +400,7 @@ async def search_dialogue(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_CHUNK, - group_id=end_user_id, + end_user_id=end_user_id, ) data = {"search_for": "chunk", "num": result[0]["num"]} return data @@ -409,7 +409,7 @@ async def search_chunk(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_STATEMENT, - group_id=end_user_id, + end_user_id=end_user_id, ) data = {"search_for": "statement", "num": result[0]["num"]} return data @@ -418,7 +418,7 @@ async def search_statement(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_ENTITY, - group_id=end_user_id, + end_user_id=end_user_id, ) data = {"search_for": "entity", "num": result[0]["num"]} return data @@ -427,7 +427,7 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, + end_user_id=end_user_id, ) # 检查结果是否为空或长度不足 @@ -462,7 +462,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A """ result = await 
_neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_ALL, - group_id=end_user_id, + end_user_id=end_user_id, ) # 检查结果是否为空或长度不足 @@ -493,7 +493,7 @@ async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, A async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_DETIALS, - group_id=end_user_id, + end_user_id=end_user_id, ) return result @@ -501,11 +501,32 @@ async def search_detials(end_user_id: Optional[str] = None) -> List[Dict[str, An async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]]: result = await _neo4j_connector.execute_query( DataConfigRepository.SEARCH_FOR_EDGES, - group_id=end_user_id, + end_user_id=end_user_id, ) return result +async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, Any]: + """搜索所有实体之间的关系网络(group 维度)。""" + result = await _neo4j_connector.execute_query( + DataConfigRepository.SEARCH_FOR_ENTITY_GRAPH, + end_user_id=end_user_id, + ) + # 对source_node 和 target_node 的 fact_summary进行截取,只截取前三条的内容(需要提取前三条“来源”) + for item in result: + source_fact = item["sourceNode"]["fact_summary"] + target_fact = item["targetNode"]["fact_summary"] + # 截取前三条“来源” + item["sourceNode"]["fact_summary"] = source_fact.split("\n")[:4] if source_fact else [] + item["targetNode"]["fact_summary"] = target_fact.split("\n")[:4] if target_fact else [] + # 与现有返回风格保持一致,携带搜索类型、数量与详情 + data = { + "search_for": "entity_graph", + "num": len(result), + "detials": result, + } + return data + async def analytics_hot_memory_tags( db: Session, diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 17dfd7eb..755dda14 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -91,7 +91,7 @@ async def run_pilot_extraction( dialog = DialogData( context=context, ref_id="pilot_dialog_1", - 
group_id=str(memory_config.workspace_id), + end_user_id=str(memory_config.workspace_id), user_id=str(memory_config.tenant_id), apply_id=str(memory_config.config_id), metadata={"source": "pilot_run", "input_type": "frontend_text"}, diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 863bccb0..3a90a821 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -155,10 +155,10 @@ class MemoryInsightHelper: """ query = """ MATCH (d:Dialogue) - WHERE d.group_id = $group_id AND d.created_at IS NOT NULL AND d.created_at <> '' + WHERE d.end_user_id = $end_user_id AND d.created_at IS NOT NULL AND d.created_at <> '' RETURN d.created_at AS creation_time """ - records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id) if not records: return [] @@ -211,17 +211,17 @@ class MemoryInsightHelper: async def get_social_connections(self) -> dict | None: """Find the user with whom the most memories are shared.""" query = """ - MATCH (c1:Chunk {group_id: $group_id}) + MATCH (c1:Chunk {end_user_id: $end_user_id}) OPTIONAL MATCH (c1)-[:CONTAINS]->(s:Statement) OPTIONAL MATCH (s)<-[:CONTAINS]-(c2:Chunk) - WHERE c1.group_id <> c2.group_id AND s IS NOT NULL AND c2 IS NOT NULL - WITH c2.group_id AS other_user_id, COUNT(DISTINCT s) AS common_statements + WHERE c1.end_user_id <> c2.end_user_id AND s IS NOT NULL AND c2 IS NOT NULL + WITH c2.end_user_id AS other_user_id, COUNT(DISTINCT s) AS common_statements WHERE common_statements > 0 RETURN other_user_id, common_statements ORDER BY common_statements DESC LIMIT 1 """ - records = await self.neo4j_connector.execute_query(query, group_id=self.user_id) + records = await self.neo4j_connector.execute_query(query, end_user_id=self.user_id) if not records or not records[0].get("other_user_id"): return None @@ -230,7 +230,7 @@ class MemoryInsightHelper: 
time_range_query = """ MATCH (c:Chunk) - WHERE c.group_id IN [$user_id, $other_user_id] + WHERE c.end_user_id IN [$user_id, $other_user_id] RETURN min(c.created_at) AS start_time, max(c.created_at) AS end_time """ time_records = await self.neo4j_connector.execute_query( @@ -294,11 +294,11 @@ class UserSummaryHelper: """Fetch recent statements authored by the user/group for context.""" query = ( "MATCH (s:Statement) " - "WHERE s.group_id = $group_id AND s.statement IS NOT NULL " + "WHERE s.end_user_id = $end_user_id AND s.statement IS NOT NULL " "RETURN s.statement AS statement, s.created_at AS created_at " "ORDER BY created_at DESC LIMIT $limit" ) - rows = await self.connector.execute_query(query, group_id=self.user_id, limit=limit) + rows = await self.connector.execute_query(query, end_user_id=self.user_id, limit=limit) records = [] for r in rows: try: @@ -1152,7 +1152,7 @@ async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, import re # 创建 UserSummaryHelper 实例 - user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_GROUP_ID", "group_123")) + user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123")) try: # 1) 收集上下文数据 @@ -1273,10 +1273,10 @@ async def analytics_node_statistics( if end_user_id: query = f""" MATCH (n:{node_type}) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await _neo4j_connector.execute_query(query, group_id=end_user_id) + result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) else: query = f""" MATCH (n:{node_type}) @@ -1387,10 +1387,10 @@ async def analytics_memory_types( # 查询 Statement 节点数量 query = """ MATCH (n:Statement) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN count(n) as count """ - result = await _neo4j_connector.execute_query(query, group_id=end_user_id) + result = await _neo4j_connector.execute_query(query, end_user_id=end_user_id) 
statement_count = result[0]["count"] if result and len(result) > 0 else 0 # 取三分之一作为隐性记忆数量 implicit_count = round(statement_count / 3) @@ -1504,7 +1504,7 @@ async def analytics_graph_data( 包含节点、边和统计信息的字典 """ try: - # 1. 获取 group_id + # 1. 获取 end_user_id user_uuid = uuid.UUID(end_user_id) repo = EndUserRepository(db) end_user = repo.get_by_id(user_uuid) @@ -1528,7 +1528,7 @@ async def analytics_graph_data( # 基于中心节点的扩展查询 node_query = f""" MATCH path = (center)-[*1..{depth}]-(connected) - WHERE center.group_id = $group_id + WHERE center.end_user_id = $end_user_id AND elementId(center) = $center_node_id WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes UNWIND all_nodes as n @@ -1539,7 +1539,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "center_node_id": center_node_id, "limit": limit } @@ -1547,7 +1547,7 @@ async def analytics_graph_data( # 按节点类型过滤查询 node_query = """ MATCH (n) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id AND labels(n)[0] IN $node_types RETURN elementId(n) as id, @@ -1556,7 +1556,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "node_types": node_types, "limit": limit } @@ -1564,7 +1564,7 @@ async def analytics_graph_data( # 查询所有节点 node_query = """ MATCH (n) - WHERE n.group_id = $group_id + WHERE n.end_user_id = $end_user_id RETURN elementId(n) as id, labels(n)[0] as label, @@ -1572,7 +1572,7 @@ async def analytics_graph_data( LIMIT $limit """ node_params = { - "group_id": end_user_id, + "end_user_id": end_user_id, "limit": limit } diff --git a/api/app/tasks.py b/api/app/tasks.py index fa9d1fdf..f4b5f78f 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -382,12 +382,12 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): @celery_app.task(name="app.core.memory.agent.read_message", bind=True) -def read_message_task(self, group_id: str, message: str, 
history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: +def read_message_task(self, end_user_id: str, message: str, history: List[Dict[str, Any]], search_switch: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: """Celery task to process a read message via MemoryAgentService. Args: - group_id: Group ID for the memory agent (also used as end_user_id) + end_user_id: Group ID for the memory agent (also used as end_user_id) message: User message to process history: Conversation history search_switch: Search switch parameter @@ -408,7 +408,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, from app.services.memory_agent_service import get_end_user_connected_config db = next(get_db()) try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") finally: db.close() @@ -420,24 +420,42 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, db = next(get_db()) try: service = MemoryAgentService() - return await service.read_memory(group_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) + return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) finally: db.close() try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time return { "status": "SUCCESS", 
"result": result, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id } except BaseException as e: elapsed_time = time.time() - start_time + # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -446,7 +464,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, return { "status": "FAILURE", "error": detailed_error, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id @@ -454,19 +472,13 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, @celery_app.task(name="app.core.memory.agent.write_message", bind=True) -def write_message_task(self, group_id: str, message, config_id: str, storage_type: str, user_rag_memory_id: str) -> Dict[str, Any]: +def write_message_task(self, end_user_id: str, message: str, config_id: str,storage_type:str,user_rag_memory_id:str) -> Dict[str, Any]: """Celery task to process a write message via MemoryAgentService. - 支持两种消息格式: - 1. 字符串格式(向后兼容):message="user: xxx\nassistant: yyy" - 2. 
结构化消息列表(推荐):message=[{"role": "user", "content": "xxx"}, {"role": "assistant", "content": "yyy"}] - Args: - group_id: Group ID for the memory agent (also used as end_user_id) - message: Message to write (str or list[dict]) + end_user_id: Group ID for the memory agent (also used as end_user_id) + message: Message to write config_id: Optional configuration ID - storage_type: Storage type (neo4j/rag) - user_rag_memory_id: RAG memory ID Returns: Dict containing the result and metadata @@ -477,7 +489,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ from app.core.logging_config import get_logger logger = get_logger(__name__) - logger.info(f"[CELERY WRITE] Starting write task - group_id={group_id}, config_id={config_id}, storage_type={storage_type}") + logger.info(f"[CELERY WRITE] Starting write task - end_user_id={end_user_id}, config_id={config_id}, storage_type={storage_type}") start_time = time.time() # Resolve config_id if None @@ -487,7 +499,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ from app.services.memory_agent_service import get_end_user_connected_config db = next(get_db()) try: - connected_config = get_end_user_connected_config(group_id, db) + connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") finally: db.close() @@ -500,7 +512,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ try: logger.info(f"[CELERY WRITE] Executing MemoryAgentService.write_memory") service = MemoryAgentService() - result = await service.write_memory(group_id, message, actual_config_id, db, storage_type, user_rag_memory_id) + result = await service.write_memory(end_user_id, message, actual_config_id, db, storage_type, user_rag_memory_id) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result except Exception as e: @@ -510,7 +522,24 @@ def write_message_task(self, group_id: 
str, message, config_id: str, storage_typ db.close() try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time logger.info(f"[CELERY WRITE] Task completed successfully - elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") @@ -518,13 +547,14 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ return { "status": "SUCCESS", "result": result, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id } except BaseException as e: elapsed_time = time.time() - start_time + # Handle ExceptionGroup from TaskGroup if hasattr(e, 'exceptions'): error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] detailed_error = "; ".join(error_messages) @@ -536,7 +566,7 @@ def write_message_task(self, group_id: str, message, config_id: str, storage_typ return { "status": "FAILURE", "error": detailed_error, - "group_id": group_id, + "end_user_id": end_user_id, "config_id": config_id, "elapsed_time": elapsed_time, "task_id": self.request.id @@ -564,53 +594,53 @@ def reflection_timer_task() -> None: """ reflection_engine() -# unused task -# @celery_app.task(name="app.core.memory.agent.health.check_read_service") -# def check_read_service_task() -> Dict[str, str]: -# """Call read_service and write latest status to Redis. + +@celery_app.task(name="app.core.memory.agent.health.check_read_service") +def check_read_service_task() -> Dict[str, str]: + """Call read_service and write latest status to Redis. 
-# Returns status data dict that gets written to Redis. -# """ -# client = redis.Redis( -# host=settings.REDIS_HOST, -# port=settings.REDIS_PORT, -# db=settings.REDIS_DB, -# password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None -# ) -# try: -# api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" -# payload = { -# "user_id": "健康检查", -# "apply_id": "健康检查", -# "group_id": "健康检查", -# "message": "你好", -# "history": [], -# "search_switch": "2", -# } -# resp = requests.post(api_url, json=payload, timeout=15) -# ok = resp.status_code == 200 -# status = "Success" if ok else "Fail" -# msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" -# error = "" if ok else resp.text -# code = 0 if ok else 500 -# except Exception as e: -# status = "Fail" -# msg = "接口请求失败" -# error = str(e) -# code = 500 + Returns status data dict that gets written to Redis. + """ + client = redis.Redis( + host=settings.REDIS_HOST, + port=settings.REDIS_PORT, + db=settings.REDIS_DB, + password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None + ) + try: + api_url = f"http://{settings.SERVER_IP}:8000/api/memory/read_service" + payload = { + "user_id": "健康检查", + "apply_id": "健康检查", + "end_user_id": "健康检查", + "message": "你好", + "history": [], + "search_switch": "2", + } + resp = requests.post(api_url, json=payload, timeout=15) + ok = resp.status_code == 200 + status = "Success" if ok else "Fail" + msg = "接口请求成功" if ok else f"接口请求失败: {resp.status_code}" + error = "" if ok else resp.text + code = 0 if ok else 500 + except Exception as e: + status = "Fail" + msg = "接口请求失败" + error = str(e) + code = 500 -# data = { -# "status": status, -# "msg": msg, -# "error": error, -# "code": str(code), -# "time": str(int(time.time())), -# } + data = { + "status": status, + "msg": msg, + "error": error, + "code": str(code), + "time": str(int(time.time())), + } -# client.hset("memsci:health:read_service", mapping=data) -# client.expire("memsci:health:read_service", 
int(settings.HEALTH_CHECK_SECONDS)) + client.hset("memsci:health:read_service", mapping=data) + client.expire("memsci:health:read_service", int(settings.HEALTH_CHECK_SECONDS)) -# return data + return data @celery_app.task(name="app.controllers.memory_storage_controller.search_all") @@ -875,7 +905,24 @@ def regenerate_memory_cache(self) -> Dict[str, Any]: } try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1002,7 +1049,24 @@ def workspace_reflection_task(self) -> Dict[str, Any]: } try: - result = asyncio.run(_run()) + # 使用 nest_asyncio 来避免事件循环冲突 + try: + import nest_asyncio + nest_asyncio.apply() + except ImportError: + pass + + # 尝试获取现有事件循环,如果不存在则创建新的 + try: + loop = asyncio.get_event_loop() + if loop.is_closed(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + result = loop.run_until_complete(_run()) elapsed_time = time.time() - start_time result["elapsed_time"] = elapsed_time result["task_id"] = self.request.id @@ -1048,7 +1112,7 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str # 运行遗忘周期 report = await forget_service.trigger_forgetting( db=db, - group_id=None, # 处理所有组 + end_user_id=None, # 处理所有组 config_id=config_id ) @@ -1078,4 +1142,11 @@ def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str "duration_seconds": duration } - return asyncio.run(_run()) + # 运行异步函数 + loop = asyncio.new_event_loop() + 
asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(_run()) + return result + finally: + loop.close() From f0efed8aa1b5ec135f61201823bc643c9853c4e5 Mon Sep 17 00:00:00 2001 From: lixinyue <2569494688@qq.com> Date: Wed, 21 Jan 2026 20:33:22 +0800 Subject: [PATCH 8/8] =?UTF-8?q?=E6=8A=8Agroup=5Fid=E6=9B=BF=E6=8D=A2end=5F?= =?UTF-8?q?user=5Fid?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/memory_agent_controller.py | 6 +- .../langgraph_graph/nodes/write_nodes.py | 2 +- .../core/memory/agent/utils/write_tools.py | 17 ++- api/app/services/memory_agent_service.py | 128 +++++++++++------- 4 files changed, 100 insertions(+), 53 deletions(-) diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index a1337085..ad5e3048 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -164,7 +164,7 @@ async def write_server( try: result = await memory_agent_service.write_memory( user_input.end_user_id, - user_input.message, + user_input.messages, config_id, db, storage_type, @@ -290,7 +290,7 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) query = user_input.message # 调用 memory_agent_service 的方法生成最终答案 @@ -596,7 +596,7 @@ async def status_type( last_user_message = " ".join([msg.get('content', '') for msg in messages_list]) result = await memory_agent_service.classify_message_type( - user_input.message, + user_input.messages, user_input.config_id, db ) diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index 
e2a61045..1dab1b0a 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -21,9 +21,9 @@ async def write_node(state: WriteState) -> WriteState: memory_config=state.get('memory_config', '') try: result=await write( - content=content, end_user_id=end_user_id, memory_config=memory_config, + messages=content, # 修复:使用正确的参数名 messages ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index ce55286e..b8bc58eb 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -77,10 +77,25 @@ async def write( # Step 1: Load and chunk data step_start = time.time() + + # Convert messages list to content string + # messages format: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...] + if isinstance(messages, list) and len(messages) > 0: + # Extract content from the last user message or concatenate all messages + if isinstance(messages[-1], dict) and 'content' in messages[-1]: + content = messages[-1]['content'] + else: + # Fallback: concatenate all message contents + content = " ".join([msg.get('content', '') for msg in messages if isinstance(msg, dict)]) + elif isinstance(messages, str): + content = messages + else: + content = str(messages) + chunked_dialogs = await get_chunked_dialogs( chunker_strategy=chunker_strategy, end_user_id=end_user_id, - messages=messages, + content=content, # 修复:使用 content 参数而不是 messages ref_id=ref_id, config_id=config_id, ) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index e475bef0..8a0d5a39 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -36,6 +36,7 @@ from app.services.memory_config_service import MemoryConfigService from 
app.services.memory_konwledges_server import ( write_rag, ) +from langchain_core.messages import AIMessage from langchain_core.messages import HumanMessage from pydantic import BaseModel, Field from sqlalchemy import func @@ -57,7 +58,6 @@ class MemoryAgentService: def writer_messages_deal(self, messages, start_time, end_user_id, config_id, message, context): duration = time.time() - start_time - if str(messages) == 'success': logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 @@ -266,7 +266,7 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: + async def write_memory(self, end_user_id: str, messages: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id @@ -319,53 +319,85 @@ class MemoryAgentService: raise ValueError(error_msg) - try: - if storage_type == "rag": - # For RAG storage, convert messages to single string - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - result = await write_rag(end_user_id, message_text, user_rag_memory_id) - return result - else: - async with make_write_graph() as graph: - config = {"configurable": {"thread_id": end_user_id}} - # Convert structured messages to LangChain messages - langchain_messages = [] - for msg in messages: - if msg['role'] == 'user': - langchain_messages.append(HumanMessage(content=msg['content'])) - elif msg['role'] == 'assistant': - from langchain_core.messages import AIMessage - langchain_messages.append(AIMessage(content=msg['content'])) - - # 初始状态 - 包含所有必要字段 - initial_state = { - "messages": langchain_messages, - "end_user_id": end_user_id, - "memory_config": 
memory_config - } + async with make_write_graph() as graph: + config = {"configurable": {"thread_id": end_user_id}} + # Convert structured messages to LangChain messages + langchain_messages = [] + for msg in messages: + if msg['role'] == 'user': + langchain_messages.append(HumanMessage(content=msg['content'])) + elif msg['role'] == 'assistant': + langchain_messages.append(AIMessage(content=msg['content'])) - # 获取节点更新信息 - async for update_event in graph.astream( - initial_state, - stream_mode="updates", - config=config - ): - for node_name, node_data in update_event.items(): - if 'save_neo4j' == node_name: - massages = node_data - massagesstatus = massages.get('write_result')['status'] - contents = massages.get('write_result') - # Convert messages back to string for logging - message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) - except Exception as e: - # Ensure proper error handling and logging - error_msg = f"Write operation failed: {str(e)}" - logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) - raise ValueError(error_msg) + # 初始状态 - 包含所有必要字段 + initial_state = { + "messages": langchain_messages, + "end_user_id": end_user_id, + "memory_config": memory_config + } + + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", + config=config + ): + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + print(massages) + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + # Convert messages back to string for logging + message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + return 
self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + + # try: + # if storage_type == "rag": + # # For RAG storage, convert messages to single string + # message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + # result = await write_rag(end_user_id, message_text, user_rag_memory_id) + # return result + # else: + # async with make_write_graph() as graph: + # config = {"configurable": {"thread_id": end_user_id}} + # # Convert structured messages to LangChain messages + # langchain_messages = [] + # for msg in messages: + # if msg['role'] == 'user': + # langchain_messages.append(HumanMessage(content=msg['content'])) + # elif msg['role'] == 'assistant': + # langchain_messages.append(AIMessage(content=msg['content'])) + # + # # 初始状态 - 包含所有必要字段 + # initial_state = { + # "messages": langchain_messages, + # "end_user_id": end_user_id, + # "memory_config": memory_config + # } + # + # # 获取节点更新信息 + # async for update_event in graph.astream( + # initial_state, + # stream_mode="updates", + # config=config + # ): + # for node_name, node_data in update_event.items(): + # if 'save_neo4j' == node_name: + # massages = node_data + # massagesstatus = massages.get('write_result')['status'] + # contents = massages.get('write_result') + # # Convert messages back to string for logging + # message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + # return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + # except Exception as e: + # # Ensure proper error handling and logging + # error_msg = f"Write operation failed: {str(e)}" + # logger.error(error_msg) + # if audit_logger: + # duration = time.time() - start_time + # audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + # raise ValueError(error_msg)