diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index e05daf4a..f0756764 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,7 +9,7 @@ import os import re import time import uuid -from threading import Lock + from typing import Any, AsyncGenerator, Dict, List, Optional import redis @@ -51,9 +51,7 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - def __init__(self): - self.user_locks: Dict[str, Lock] = {} - self.locks_lock = Lock() + def writer_messages_deal(self,messages,start_time,group_id,config_id,message): messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') @@ -83,12 +81,7 @@ class MemoryAgentService: raise ValueError(f"写入失败: {messages}") - def get_group_lock(self, group_id: str) -> Lock: - """Get lock for specific group to prevent concurrent processing""" - with self.locks_lock: - if group_id not in self.user_locks: - self.user_locks[group_id] = Lock() - return self.user_locks[group_id] + def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" @@ -417,241 +410,236 @@ class MemoryAgentService: except ImportError: audit_logger = None - # Get group lock to prevent concurrent processing - group_lock = self.get_group_lock(group_id) + try: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="MemoryAgentService" + ) + logger.info(f"Configuration loaded successfully: {memory_config.config_name}") + except ConfigurationError as e: + error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" + logger.error(error_msg) - with group_lock: - # Step 1: Load configuration from database only - try: - config_service = MemoryConfigService(db) - memory_config = config_service.load_memory_config( + # Log failed operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", config_id=config_id, - service_name="MemoryAgentService" + group_id=group_id, + success=False, + duration=duration, + error=error_msg ) - logger.info(f"Configuration loaded successfully: {memory_config.config_name}") - except ConfigurationError as e: - error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" - logger.error(error_msg) - # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_msg - ) + raise ValueError(error_msg) - raise ValueError(error_msg) + # Step 2: Prepare history + history.append({"role": "user", "content": message}) + logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") - # Step 2: Prepare history - history.append({"role": "user", "content": message}) - logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") + # Step 3: Initialize MCP client and execute read workflow + mcp_config = get_mcp_server_config() + client = MultiServerMCPClient(mcp_config) - # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) + async with client.session('data_flow') as session: + session_start = time.time() + logger.debug("Connected to MCP Server: data_flow") - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates + tools_start = time.time() + tools = await load_mcp_tools(session) + tools_time = time.time() - tools_start + logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() - config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", - config=config - ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + outputs = [] + intermediate_outputs = [] + seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # Pass memory_config to the graph workflow + graph_start = time.time() + async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: + graph_init_time = time.time() - graph_start + logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + start = time.time() + config = {"configurable": {"thread_id": group_id}} + workflow_errors = [] # Track errors from workflow - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] + event_count = 0 + async for event in graph.astream( + {"messages": history, "memory_config": memory_config, "errors": []}, + stream_mode="values", + config=config + ): + event_count += 1 + event_start = time.time() + messages = event.get('messages') + # Capture any errors from the state + if event.get('errors'): + workflow_errors.extend(event.get('errors', [])) + + for msg in messages: + msg_content = msg.content + msg_role = msg.__class__.__name__.lower().replace("message", "") + outputs.append({ + "role": msg_role, + "content": msg_content + }) + + # Extract intermediate outputs + if hasattr(msg, 'content'): + try: + # Handle MCP content format: [{'type': 'text', 'text': '...'}] + content_to_parse = msg_content + if isinstance(msg_content, list): + for block in msg_content: + if isinstance(block, dict) and block.get('type') == 'text': + content_to_parse = block.get('text', '') + break + else: + continue # No text block found + + # Try to parse content as JSON + if isinstance(content_to_parse, str): + try: + parsed = json.loads(content_to_parse) + if isinstance(parsed, dict): + # Check for single intermediate output + if '_intermediate' in parsed: + intermediate_data = parsed['_intermediate'] + output_key = self._create_intermediate_key(intermediate_data) + + if output_key not in seen_intermediates: + seen_intermediates.add(output_key) + intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + + # Check for multiple intermediate outputs (from Retrieve) + if '_intermediates' in parsed: + for intermediate_data in parsed['_intermediates']: output_key = self._create_intermediate_key(intermediate_data) if output_key not in seen_intermediates: seen_intermediates.add(output_key) intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + except (json.JSONDecodeError, ValueError): + pass + except Exception as e: + logger.debug(f"Failed to extract intermediate output: {e}") - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + event_time = time.time() - event_start + logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") - - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + workflow_duration = time.time() - start + session_duration = time.time() - session_start + logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") + logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") + logger.info(f"[PERF] Total events processed: {event_count}") + # Extract final answer + final_answer = "" + for messages in outputs: + if messages['role'] == 'tool': + message = messages['content'] - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + # Handle MCP content format: [{'type': 'text', 'text': '...'}] + if isinstance(message, list): + # Extract text from MCP content blocks + for block in message: + if isinstance(block, dict) and block.get('type') == 'text': + message = block.get('text', '') + break + else: + continue # No text block found - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found + try: + parsed = json.loads(message) if isinstance(message, str) else message + if isinstance(parsed, dict): + if parsed.get('status') == 'success': + summary_result = parsed.get('summary_result') + if summary_result: + final_answer = summary_result + except (json.JSONDecodeError, ValueError): + pass - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass + # 记录成功的操作 + total_duration = time.time() - start_time - # 记录成功的操作 - total_duration = time.time() - start_time + # Check for workflow errors + if workflow_errors: + error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) + logger.warning(f"Read workflow completed with errors: {error_details}") - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") - - if audit_logger: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } - ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: + if audit_logger: audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, - success=True, + success=False, duration=total_duration, + error=error_details, details={ "search_switch": search_switch, "history_length": len(history), "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) + "has_answer": bool(final_answer), + "errors": workflow_errors } ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") + + # Raise error if no answer was produced + if not final_answer: + raise ValueError(f"Read workflow failed: {error_details}") + + if audit_logger and not workflow_errors: + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=total_duration, + details={ + "search_switch": search_switch, + "history_length": len(history), + "intermediate_outputs_count": len(intermediate_outputs), + "has_answer": bool(final_answer) + } + ) + retrieved_content=[] + repo = ShortTermMemoryRepository(db) + if str(search_switch)!="2": + for intermediate in intermediate_outputs: + print(intermediate) + intermediate_type=intermediate['type'] + if intermediate_type=="search_result": + query=intermediate['query'] + raw_results=intermediate['raw_results'] + reranked_results=raw_results.get('reranked_results',[]) + try: + statements=[statement['statement'] for statement in reranked_results.get('statements', [])] + except Exception: + statements=[] + statements=list(set(statements)) + retrieved_content.append({query:statements}) + if retrieved_content==[]: + retrieved_content='' + if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] + # 使用 upsert 方法 + repo.upsert( + end_user_id=end_user_id, # 确保这个变量在作用域内 + messages=ori_message, + aimessages=final_answer, + retrieved_content=retrieved_content, + search_switch=str(search_switch) + ) + print("写入成功") - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - + return { + "answer": final_answer, + "intermediate_outputs": intermediate_outputs + } + def _create_intermediate_key(self, output: Dict) -> str: """ Create a unique key for an intermediate output to detect duplicates.