diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index f0756764..21407a33 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -9,30 +9,28 @@ import os import re import time import uuid - from typing import Any, AsyncGenerator, Dict, List, Optional import redis +from langchain_core.messages import HumanMessage + from app.core.config import settings from app.core.logging_config import get_config_logger, get_logger from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph from app.core.memory.agent.logger_file.log_streamer import LogStreamer -from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config +from app.core.memory.agent.utils.messages_tools import merge_multiple_search_results, reorder_output_results from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_config_schema import ConfigurationError from app.services.memory_config_service import MemoryConfigService from app.services.memory_konwledges_server import ( write_rag, ) -from langchain_mcp_adapters.client import MultiServerMCPClient -from langchain_mcp_adapters.tools import load_mcp_tools from pydantic import BaseModel, Field from sqlalchemy import func from sqlalchemy.orm import Session @@ -50,21 +48,17 @@ _neo4j_connector = Neo4jConnector() class MemoryAgentService: """Service for memory agent operations""" - - - def writer_messages_deal(self,messages,start_time,group_id,config_id,message): - messages = str(messages).replace("'", '"').replace('\\n', '').replace('\n', '').replace('\\', '') - countext = re.findall(r'"status": "(.*?)",', messages)[0] + def writer_messages_deal(self, messages, start_time, group_id, config_id, message, context): duration = time.time() - start_time - if countext == 'success': + if str(messages) == 'success': logger.info(f"Write operation successful for group {group_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=True, duration=duration, details={"message_length": len(message)}) - return countext + return context else: logger.warning(f"Write operation failed for group {group_id}") @@ -80,9 +74,9 @@ class MemoryAgentService: ) raise ValueError(f"写入失败: {messages}") - - + + def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -119,15 +113,15 @@ class MemoryAgentService: return True return False - + async def get_health_status(self) -> Dict: """ Get latest health status from Redis cache - + Returns health status information written by Celery periodic task """ logger.info("Checking health status") - + client = redis.Redis( host=settings.REDIS_HOST, port=settings.REDIS_PORT, @@ -135,34 +129,51 @@ class MemoryAgentService: password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None ) payload = client.hgetall("memsci:health:read_service") or {} - + if payload: # decode bytes to str decoded = {k.decode("utf-8"): v.decode("utf-8") for k, v in payload.items()} status = decoded.get("status", "unknown") else: status = "unknown" - + + # Add database connection pool status + try: + from app.db import get_pool_status + pool_status = get_pool_status() + logger.info(f"Database pool status: {pool_status}") + + # Check if pool usage is too high + if pool_status.get("usage_percent", 0) > 80: + logger.warning(f"High database pool usage: {pool_status['usage_percent']}%") + status = "warning" + + except Exception as e: + logger.error(f"Failed to get pool status: {e}") + pool_status = {"error": str(e)} + logger.info(f"Health status: {status}") - return {"status": status} + return { + "status": status, + "database_pool": pool_status + } def get_log_content(self) -> str: """ Read and return agent service log file content - - Returns cleaned log content using the same cleaning logic as transmission mode + + Returns cleaned log content using the same cleaning logic as transmission mode Returns cleaned log content using the same cleaning logic as transmission mode """ logger.info("Reading log file") - # Use project root directory for logs - # Get the project root (redbear-mem directory) + current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + summer = '' with open(log_path, "r", encoding="utf-8") as infile: @@ -176,83 +187,83 @@ class MemoryAgentService: logger.info(f"Log content retrieved, size: {len(summer)} bytes") return summer - + async def stream_log_content(self) -> AsyncGenerator[str, None]: """ Stream log content in real-time using Server-Sent Events (SSE) - + This method establishes a streaming connection and transmits log entries as they are written to the log file. It uses the LogStreamer to watch the file and yields SSE-formatted messages. - + Yields: SSE-formatted strings with the following event types: - log: Contains log content and timestamp - keepalive: Periodic keepalive messages to maintain connection - error: Error information if streaming fails - done: Indicates streaming has completed - + Raises: FileNotFoundError: If log file doesn't exist at stream start Exception: For other unexpected errors during streaming """ logger.info("Starting log content streaming") - + # Get log file path - use project root directory current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = os.path.dirname(app_dir) # redbear-mem directory log_path = os.path.join(project_root, "logs", "agent_service.log") - + # Check if file exists before starting stream if not os.path.exists(log_path): logger.error(f"Log file not found: {log_path}") # Send error event in SSE format yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件不存在', 'error': f'File not found: {log_path}'})}\n\n" return - + streamer = None try: # Initialize LogStreamer with keepalive interval from settings (default 300 seconds) keepalive_interval = getattr(settings, 'LOG_STREAM_KEEPALIVE_INTERVAL', 300) streamer = LogStreamer(log_path, keepalive_interval=keepalive_interval) - + logger.info(f"LogStreamer initialized for {log_path}") - + # Stream log content using read_existing_and_stream to get all existing content first async for message in streamer.read_existing_and_stream(): event_type = message.get("event") data = message.get("data") - + # Format as SSE message # SSE format: "event: \ndata: \n\n" sse_message = f"event: {event_type}\ndata: {json.dumps(data)}\n\n" - + logger.debug(f"Streaming event: {event_type}") yield sse_message - + # If error or done event, stop streaming if event_type in ["error", "done"]: logger.info(f"Stream ended with event: {event_type}") break - + except FileNotFoundError as e: logger.error(f"Log file not found during streaming: {e}") yield f"event: error\ndata: {json.dumps({'code': 4006, 'message': '日志文件在流式传输期间变得不可用', 'error': str(e)})}\n\n" - + except Exception as e: logger.error(f"Unexpected error during log streaming: {e}", exc_info=True) yield f"event: error\ndata: {json.dumps({'code': 8001, 'message': '流式传输期间发生错误', 'error': str(e)})}\n\n" - + finally: # Resource cleanup logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - + async def write_memory(self, group_id: str, message: str, config_id: Optional[str], db: Session, storage_type: str, user_rag_memory_id: str) -> str: """ Process write operation with config_id - + Args: group_id: Group identifier (also used as end_user_id) message: Message to write @@ -260,10 +271,10 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Write operation result status - + Raises: ValueError: If config loading fails or write operation fails """ @@ -279,7 +290,7 @@ class MemoryAgentService: raise # Re-raise our specific error logger.error(f"Failed to get connected config for end_user {group_id}: {e}") raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}") - + import time start_time = time.time() @@ -294,61 +305,49 @@ class MemoryAgentService: except ConfigurationError as e: error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" logger.error(error_msg) - + # Log failed operation if audit_logger: duration = time.time() - start_time audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) - + raise ValueError(error_msg) - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - if storage_type == "rag": - result = await write_rag(group_id, message, user_rag_memory_id) - return result - else: - async with client.session("data_flow") as session: - logger.debug("Connected to MCP Server: data_flow") - tools = await load_mcp_tools(session) - workflow_errors = [] # Track errors from workflow - - # Pass memory_config to the graph workflow - async with make_write_graph(group_id, tools, group_id, group_id, memory_config=memory_config) as graph: - logger.debug("Write graph created successfully") + try: + if storage_type == "rag": + result = await write_rag(group_id, message, user_rag_memory_id) + return result + else: + async with make_write_graph() as graph: config = {"configurable": {"thread_id": group_id}} + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, + "memory_config": memory_config} - async for event in graph.astream( - {"messages": message, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 获取节点更新信息 + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.error(f"Write workflow failed with errors: {error_details}") - - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - group_id=group_id, - success=False, - duration=duration, - error=error_details - ) - - raise ValueError(f"Write workflow failed: {error_details}") - - return self.writer_messages_deal(messages, start_time, group_id, config_id, message) - + for node_name, node_data in update_event.items(): + if 'save_neo4j' == node_name: + massages = node_data + massagesstatus = massages.get('write_result')['status'] + contents = massages.get('write_result') + return self.writer_messages_deal(massagesstatus, start_time, group_id, config_id, message, contents) + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Write operation failed: {str(e)}" + logger.error(error_msg) + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, group_id=group_id, success=False, duration=duration, error=error_msg) + raise ValueError(error_msg) + + + + async def read_memory( self, group_id: str, @@ -362,12 +361,12 @@ class MemoryAgentService: ) -> Dict: """ Process read operation with config_id - + search_switch values: - "0": Requires verification - "1": No verification, direct split - "2": Direct answer based on context - + Args: group_id: Group identifier (also used as end_user_id) message: User message @@ -377,18 +376,17 @@ class MemoryAgentService: db: SQLAlchemy database session storage_type: Storage type (neo4j or rag) user_rag_memory_id: User RAG memory ID - + Returns: Dict with 'answer' and 'intermediate_outputs' keys - + Raises: ValueError: If config loading fails """ import time start_time = time.time() - ori_message=message - end_user_id=group_id + # Resolve config_id if None using end_user's connected config if config_id is None: try: @@ -410,6 +408,7 @@ class MemoryAgentService: except ImportError: audit_logger = None + try: config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( @@ -440,326 +439,126 @@ class MemoryAgentService: logger.debug(f"Group ID:{group_id}, Message:{message}, History:{history}, Config ID:{config_id}") # Step 3: Initialize MCP client and execute read workflow - mcp_config = get_mcp_server_config() - client = MultiServerMCPClient(mcp_config) - - async with client.session('data_flow') as session: - session_start = time.time() - logger.debug("Connected to MCP Server: data_flow") - - tools_start = time.time() - tools = await load_mcp_tools(session) - tools_time = time.time() - tools_start - logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") - - outputs = [] - intermediate_outputs = [] - seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates - - # Pass memory_config to the graph workflow - graph_start = time.time() - async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: - graph_init_time = time.time() - graph_start - logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") - - start = time.time() + try: + async with make_read_graph() as graph: config = {"configurable": {"thread_id": group_id}} - workflow_errors = [] # Track errors from workflow - - event_count = 0 - async for event in graph.astream( - {"messages": history, "memory_config": memory_config, "errors": []}, - stream_mode="values", + # 初始状态 - 包含所有必要字段 + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "group_id": group_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} + # 获取节点更新信息 + _intermediate_outputs = [] + summary = '' + async for update_event in graph.astream( + initial_state, + stream_mode="updates", config=config ): - event_count += 1 - event_start = time.time() - messages = event.get('messages') - # Capture any errors from the state - if event.get('errors'): - workflow_errors.extend(event.get('errors', [])) + for node_name, node_data in update_event.items(): + print(f"处理节点: {node_name}") - for msg in messages: - msg_content = msg.content - msg_role = msg.__class__.__name__.lower().replace("message", "") - outputs.append({ - "role": msg_role, - "content": msg_content - }) + # 处理不同Summary节点的返回结构 + if 'Summary' in node_name: + if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: + summary = node_data['InputSummary']['summary_result'] + elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']: + summary = node_data['RetrieveSummary']['summary_result'] + elif 'summary' in node_data and 'summary_result' in node_data['summary']: + summary = node_data['summary']['summary_result'] + elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']: + summary = node_data['SummaryFails']['summary_result'] - # Extract intermediate outputs - if hasattr(msg, 'content'): - try: - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - content_to_parse = msg_content - if isinstance(msg_content, list): - for block in msg_content: - if isinstance(block, dict) and block.get('type') == 'text': - content_to_parse = block.get('text', '') - break - else: - continue # No text block found + spit_data = node_data.get('spit_data', {}).get('_intermediate', None) + if spit_data and spit_data != [] and spit_data != {}: + _intermediate_outputs.append(spit_data) - # Try to parse content as JSON - if isinstance(content_to_parse, str): - try: - parsed = json.loads(content_to_parse) - if isinstance(parsed, dict): - # Check for single intermediate output - if '_intermediate' in parsed: - intermediate_data = parsed['_intermediate'] - output_key = self._create_intermediate_key(intermediate_data) + # Problem_Extension 节点 + problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) + if problem_extension and problem_extension != [] and problem_extension != {}: + _intermediate_outputs.append(problem_extension) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) + # Retrieve 节点 + retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) + if retrieve_node and retrieve_node != [] and retrieve_node != {}: + _intermediate_outputs.extend(retrieve_node) - # Check for multiple intermediate outputs (from Retrieve) - if '_intermediates' in parsed: - for intermediate_data in parsed['_intermediates']: - output_key = self._create_intermediate_key(intermediate_data) + # Verify 节点 + verify_n = node_data.get('verify', {}).get('_intermediate', None) + if verify_n and verify_n != [] and verify_n != {}: + _intermediate_outputs.append(verify_n) - if output_key not in seen_intermediates: - seen_intermediates.add(output_key) - intermediate_outputs.append(self._format_intermediate_output(intermediate_data)) - except (json.JSONDecodeError, ValueError): - pass - except Exception as e: - logger.debug(f"Failed to extract intermediate output: {e}") + # Summary 节点 + summary_n = node_data.get('summary', {}).get('_intermediate', None) + if summary_n and summary_n != [] and summary_n != {}: + _intermediate_outputs.append(summary_n) - event_time = time.time() - event_start - logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") + _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}] - workflow_duration = time.time() - start - session_duration = time.time() - session_start - logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") - logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") - logger.info(f"[PERF] Total events processed: {event_count}") - # Extract final answer - final_answer = "" - for messages in outputs: - if messages['role'] == 'tool': - message = messages['content'] + optimized_outputs = merge_multiple_search_results(_intermediate_outputs) + result = reorder_output_results(optimized_outputs) - # Handle MCP content format: [{'type': 'text', 'text': '...'}] - if isinstance(message, list): - # Extract text from MCP content blocks - for block in message: - if isinstance(block, dict) and block.get('type') == 'text': - message = block.get('text', '') - break - else: - continue # No text block found - - try: - parsed = json.loads(message) if isinstance(message, str) else message - if isinstance(parsed, dict): - if parsed.get('status') == 'success': - summary_result = parsed.get('summary_result') - if summary_result: - final_answer = summary_result - except (json.JSONDecodeError, ValueError): - pass - - # 记录成功的操作 - total_duration = time.time() - start_time - - # Check for workflow errors - if workflow_errors: - error_details = "; ".join([f"{e['tool']}: {e['error']}" for e in workflow_errors]) - logger.warning(f"Read workflow completed with errors: {error_details}") + # Log successful operation + if audit_logger: + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + group_id=group_id, + success=True, + duration=duration + ) + return { + "answer": summary, + "intermediate_outputs": result + } + except Exception as e: + # Ensure proper error handling and logging + error_msg = f"Read operation failed: {str(e)}" + logger.error(error_msg) if audit_logger: + duration = time.time() - start_time audit_logger.log_operation( operation="READ", config_id=config_id, group_id=group_id, success=False, - duration=total_duration, - error=error_details, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer), - "errors": workflow_errors - } + duration=duration, + error=error_msg ) - - # Raise error if no answer was produced - if not final_answer: - raise ValueError(f"Read workflow failed: {error_details}") - - if audit_logger and not workflow_errors: - audit_logger.log_operation( - operation="READ", - config_id=config_id, - group_id=group_id, - success=True, - duration=total_duration, - details={ - "search_switch": search_switch, - "history_length": len(history), - "intermediate_outputs_count": len(intermediate_outputs), - "has_answer": bool(final_answer) - } - ) - retrieved_content=[] - repo = ShortTermMemoryRepository(db) - if str(search_switch)!="2": - for intermediate in intermediate_outputs: - print(intermediate) - intermediate_type=intermediate['type'] - if intermediate_type=="search_result": - query=intermediate['query'] - raw_results=intermediate['raw_results'] - reranked_results=raw_results.get('reranked_results',[]) - try: - statements=[statement['statement'] for statement in reranked_results.get('statements', [])] - except Exception: - statements=[] - statements=list(set(statements)) - retrieved_content.append({query:statements}) - if retrieved_content==[]: - retrieved_content='' - if '信息不足,无法回答。' != str(final_answer) and str(search_switch).strip() != "2":#and retrieved_content!=[] - # 使用 upsert 方法 - repo.upsert( - end_user_id=end_user_id, # 确保这个变量在作用域内 - messages=ori_message, - aimessages=final_answer, - retrieved_content=retrieved_content, - search_switch=str(search_switch) - ) - print("写入成功") + raise ValueError(error_msg) - return { - "answer": final_answer, - "intermediate_outputs": intermediate_outputs - } - - def _create_intermediate_key(self, output: Dict) -> str: - """ - Create a unique key for an intermediate output to detect duplicates. - - Args: - output: Intermediate output dictionary - - Returns: - Unique string key for this output - """ - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - # Use type + original query as key - return f"split:{output.get('original_query', '')}" - elif output_type == 'problem_extension': - # Use type + original query as key - return f"extension:{output.get('original_query', '')}" - elif output_type == 'search_result': - # Use type + query + index as key - return f"search:{output.get('query', '')}:{output.get('index', 0)}" - elif output_type == 'retrieval_summary': - # Use type + query as key - return f"summary:{output.get('query', '')}" - elif output_type == 'verification': - # Use type + query as key - return f"verification:{output.get('query', '')}" - elif output_type == 'input_summary': - # Use type + query as key - return f"input_summary:{output.get('query', '')}" - else: - # Fallback: use JSON representation - import json - return json.dumps(output, sort_keys=True) - - def _format_intermediate_output(self, output: Dict) -> Dict: - """Format intermediate output for frontend display.""" - output_type = output.get('type', 'unknown') - - if output_type == 'problem_split': - return { - 'type': 'problem_split', - 'title': '问题拆分', - 'data': output.get('data', []), - 'original_query': output.get('original_query', '') - } - elif output_type == 'problem_extension': - return { - 'type': 'problem_extension', - 'title': '问题扩展', - 'data': output.get('data', {}), - 'original_query': output.get('original_query', '') - } - elif output_type == 'search_result': - return { - 'type': 'search_result', - 'title': f'检索结果 ({output.get("index", 0)}/{output.get("total", 0)})', - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results', ''), - 'index': output.get('index', 0), - 'total': output.get('total', 0) - } - elif output_type == 'retrieval_summary': - return { - 'type': 'retrieval_summary', - 'title': '检索总结', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - elif output_type == 'verification': - return { - 'type': 'verification', - 'title': '数据验证', - 'result': output.get('result', 'unknown'), - 'reason': output.get('reason', ''), - 'query': output.get('query', ''), - 'verified_count': output.get('verified_count', 0) - } - elif output_type == 'input_summary': - return { - 'type': 'input_summary', - 'title': '快速答案', - 'summary': output.get('summary', ''), - 'query': output.get('query', ''), - 'raw_results': output.get('raw_results'), - - } - else: - return output - async def classify_message_type(self, message: str, config_id: int, db: Session) -> Dict: """ Determine the type of user message (read or write) Updated to eliminate global variables in favor of explicit parameters. - + Args: message: User message to classify config_id: Configuration ID to load LLM model from database db: Database session - + Returns: Type classification result """ logger.info("Classifying message type") - + # Load configuration to get LLM model ID config_service = MemoryConfigService(db) memory_config = config_service.load_memory_config( config_id=config_id, service_name="MemoryAgentService" ) - + status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status - + # ==================== 新增的三个接口方法 ==================== - + async def get_knowledge_type_stats( self, end_user_id: Optional[str] = None, @@ -772,13 +571,13 @@ class MemoryAgentService: 1. PostgreSQL 中的知识库类型:General, Web, Third-party, Folder(根据 workspace_id 过滤) 2. Neo4j 中的 memory 类型(仅统计 Chunk 数量,根据 end_user_id/group_id 过滤) 3. total: 所有类型的总和 - + 参数: - end_user_id: 用户组ID(可选,未提供时 memory 统计为 0) - only_active: 是否仅统计有效记录 - current_workspace_id: 当前工作空间ID(可选,未提供时知识库统计为 0) - db: 数据库会话 - + 返回格式: { "General": count, @@ -790,18 +589,18 @@ class MemoryAgentService: } """ result = {} - + # 1. 统计 PostgreSQL 中的知识库类型 try: if db is None: from app.db import get_db db_gen = get_db() db = next(db_gen) - + # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 - + # 如果提供了 workspace_id,则按 workspace_id 过滤 if current_workspace_id: # 构建查询条件 @@ -809,47 +608,48 @@ class MemoryAgentService: Knowledge.type, func.count(Knowledge.id).label('count') ).filter(Knowledge.workspace_id == current_workspace_id) - + # 检查 Knowledge 模型是否有 status 字段 if only_active and hasattr(Knowledge, 'status'): query = query.filter(Knowledge.status == 1) - + # 按类型分组 type_counts = query.group_by(Knowledge.type).all() - + # 只填充标准类型的统计值,忽略其他类型 valid_types = {kb_type.value for kb_type in KnowledgeType} for type_name, count in type_counts: if type_name in valid_types: result[type_name] = count - + logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}") else: # 没有提供 workspace_id,所有知识库类型返回 0 logger.info("未提供 workspace_id,知识库类型统计全部为 0") - + except Exception as e: logger.error(f"知识库类型统计失败: {e}") raise Exception(f"知识库类型统计失败: {e}") - + # 2. 统计 Neo4j 中的 memory 总量(统计当前空间下所有宿主的 Chunk 总数) try: if current_workspace_id: # 获取当前空间下的所有宿主 from app.repositories import app_repository, end_user_repository from app.schemas.app_schema import App as AppSchema - + from app.schemas.end_user_schema import EndUser as EndUserSchema + # 查询应用并转换为 Pydantic 模型 apps_orm = app_repository.get_apps_by_workspace_id(db, current_workspace_id) apps = [AppSchema.model_validate(h) for h in apps_orm] app_ids = [app.id for app in apps] - + # 获取所有宿主 end_users = [] for app_id in app_ids: end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id) end_users.extend(h for h in end_user_orm_list) - + # 统计所有宿主的 Chunk 总数 total_chunks = 0 for end_user in end_users: @@ -864,27 +664,27 @@ class MemoryAgentService: chunk_count = neo4j_result[0]["Count"] if neo4j_result else 0 total_chunks += chunk_count logger.debug(f"EndUser {end_user_id_str} Chunk数量: {chunk_count}") - + result["memory"] = total_chunks logger.info(f"Neo4j memory统计成功: 总Chunk数={total_chunks}, 宿主数={len(end_users)}") else: # 没有 workspace_id 时,返回 0 result["memory"] = 0 logger.info("未提供 workspace_id,memory 统计为 0") - + except Exception as e: logger.error(f"Neo4j memory统计失败: {e}", exc_info=True) # 如果 Neo4j 查询失败,memory 设为 0 result["memory"] = 0 - + # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + result.get("Folder", 0) ) - + return result @@ -895,11 +695,11 @@ class MemoryAgentService: ) -> List[Dict[str, Any]]: """ 获取指定用户的热门记忆标签 - + 参数: - end_user_id: 用户ID(可选),对应Neo4j中的group_id字段 - limit: 返回标签数量限制 - + 返回格式: [ {"name": "标签名", "frequency": 频次}, @@ -928,13 +728,13 @@ class MemoryAgentService: 1. 用户名字(直接使用 end_user_name) 2. 用户标签(从摘要中用LLM总结3个标签) 3. 热门记忆标签(从hot_memory_tags获取前4个) - + 参数: - end_user_id: 用户ID(可选) - current_user_id: 当前登录用户的ID(保留参数) - llm_id: LLM模型ID(用于生成标签,可选,如果不提供则跳过标签生成) - db: 数据库会话(可选) - + 返回格式: { "name": "用户名", @@ -947,13 +747,13 @@ class MemoryAgentService: } """ result = {} - + # 1. 根据 end_user_id 获取 end_user_name try: if end_user_id and db: from app.repositories import end_user_repository from app.schemas.end_user_schema import EndUser as EndUserSchema - + end_user_orm = end_user_repository.get_end_user_by_id(db, end_user_id) if end_user_orm: end_user = EndUserSchema.model_validate(end_user_orm) @@ -965,14 +765,14 @@ class MemoryAgentService: except Exception as e: logger.error(f"Failed to get end_user_name: {e}") end_user_name = "默认用户" - + result["name"] = end_user_name logger.debug(f"The end_user is: {end_user_name}") - + # 2. 使用LLM从语句和实体中提取标签 try: connector = Neo4jConnector() - + # 查询该用户的语句 query = ( "MATCH (s:Statement) " @@ -982,7 +782,7 @@ class MemoryAgentService: ) rows = await connector.execute_query(query, group_id=end_user_id) statements = [r.get("statement", "") for r in rows if r.get("statement")] - + # 查询该用户的热门实体 entity_query = ( "MATCH (e:ExtractedEntity) " @@ -992,9 +792,9 @@ class MemoryAgentService: ) entity_rows = await connector.execute_query(entity_query, group_id=end_user_id) entities = [f"{r['name']} ({r['frequency']})" for r in entity_rows] - + await connector.close() - + if not statements or not llm_id: result["tags"] = [] if not llm_id and statements: @@ -1003,16 +803,16 @@ class MemoryAgentService: # 构建摘要文本 summary_text = f"用户语句样本:{' | '.join(statements[:20])}\n核心实体:{', '.join(entities)}" logger.debug(f"User data found: {len(statements)} statements, {len(entities)} entities") - + # 使用LLM提取标签 with get_db_context() as db: factory = MemoryClientFactory(db) llm_client = factory.get_llm_client(llm_id) - + # 定义标签提取的结构 class UserTags(BaseModel): tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") - + messages = [ { "role": "system", @@ -1023,20 +823,20 @@ class MemoryAgentService: "content": f"请从以下用户信息中提取3个标签:\n\n{summary_text}" } ] - + user_tags = await llm_client.response_structured( messages=messages, response_model=UserTags ) - + result["tags"] = user_tags.tags logger.debug(f"Extracted tags: {user_tags.tags}") - + except Exception as e: # 如果提取失败,使用默认值 logger.error(f"Failed to extract user tags: {e}") result["tags"] = [] - + try: # 3. 获取热门记忆标签(前4个) connector = Neo4jConnector() @@ -1049,18 +849,18 @@ class MemoryAgentService: "ORDER BY frequency DESC LIMIT 4" ) hot_tag_rows = await connector.execute_query( - hot_tag_query, - group_id=end_user_id, + hot_tag_query, + group_id=end_user_id, names_to_exclude=names_to_exclude ) await connector.close() - + result["hot_tags"] = [{"name": r["name"], "frequency": r["frequency"]} for r in hot_tag_rows] logger.debug(f"Hot tags found: {len(result['hot_tags'])} tags") except Exception as e: logger.error(f"Failed to get hot tags: {e}") result["hot_tags"] = [] - + return result async def stream_log_content(self) -> AsyncGenerator[str, None]: @@ -1135,79 +935,40 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic -# async def get_api_docs(self, file_path: Optional[str] = None) -> Dict[str, Any]: -# """ -# Parse and return API documentation - -# Args: -# file_path: Optional path to API docs file. If None, uses default path. - -# Returns: -# Dict containing parsed API documentation or error information -# """ -# try: -# target = file_path or get_default_docs_path() - -# if not os.path.isfile(target): -# return { -# "success": False, -# "msg": "API文档文件不存在", -# "error_code": "DOC_NOT_FOUND", -# "data": {"path": target} -# } - -# data = parse_api_docs(target) -# return { -# "success": True, -# "msg": "解析成功", -# "data": data -# } -# except Exception as e: -# logger.error(f"Failed to parse API docs: {e}") -# return { -# "success": False, -# "msg": "解析失败", -# "error_code": "DOC_PARSE_ERROR", -# "data": {"error": str(e)} -# } - - def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, Any]: """ 获取终端用户关联的记忆配置 - + 通过以下流程获取配置: 1. 根据 end_user_id 获取用户的 app_id 2. 获取该应用的最新发布版本 3. 从发布版本的 config 字段中提取 memory_config_id - 4. 根据 memory_config_id 查询配置名称 - + Args: end_user_id: 终端用户ID db: 数据库会话 - + Returns: - 包含 memory_config_id、config_name 和相关信息的字典 - + 包含 memory_config_id 和相关信息的字典 + Raises: ValueError: 当终端用户不存在或应用未发布时 """ from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig from app.models.end_user_model import EndUser from sqlalchemy import select - + logger.info(f"Getting connected config for end_user: {end_user_id}") - + # 1. 获取 end_user 及其 app_id end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: logger.warning(f"End user not found: {end_user_id}") raise ValueError(f"终端用户不存在: {end_user_id}") - + app_id = end_user.app_id logger.debug(f"Found end_user app_id: {app_id}") - + # 2. 获取该应用的最新发布版本 stmt = ( select(AppRelease) @@ -1215,170 +976,25 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An .order_by(AppRelease.version.desc()) ) latest_release = db.scalars(stmt).first() - + if not latest_release: logger.warning(f"No active release found for app: {app_id}") raise ValueError(f"应用未发布: {app_id}") - + logger.debug(f"Found latest release: version={latest_release.version}, id={latest_release.id}") - + # 3. 从 config 中提取 memory_config_id config = latest_release.config or {} memory_obj = config.get('memory', {}) memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - - # 4. 根据 memory_config_id 查询配置名称 - config_name = None - if memory_config_id: - try: - # memory_config_id 可能是整数或字符串,需要转换 - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - data_config = db.query(DataConfig).filter(DataConfig.config_id == config_id).first() - if data_config: - config_name = data_config.config_name - logger.debug(f"Found config_name: {config_name} for config_id: {config_id}") - else: - logger.warning(f"DataConfig not found for config_id: {config_id}") - except (ValueError, TypeError) as e: - logger.warning(f"Invalid memory_config_id format: {memory_config_id}, error: {str(e)}") - + result = { "end_user_id": str(end_user_id), "app_id": str(app_id), "release_id": str(latest_release.id), "release_version": latest_release.version, - "memory_config_id": memory_config_id, - "memory_config_name": config_name + "memory_config_id": memory_config_id } - - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, config_name={config_name}") - return result - -def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) -> Dict[str, Dict[str, Any]]: - """ - 批量获取多个终端用户关联的记忆配置 - - 通过优化的查询减少数据库往返次数: - 1. 一次性查询所有 end_user 及其 app_id - 2. 批量查询所有相关的 app_release - 3. 批量查询所有相关的 data_config - - Args: - end_user_ids: 终端用户ID列表 - db: 数据库会话 - - Returns: - 字典,key 为 end_user_id,value 为配置信息字典 - 对于查询失败的用户,value 包含 error 字段 - """ - from app.models.app_release_model import AppRelease - from app.models.data_config_model import DataConfig - from app.models.end_user_model import EndUser - from sqlalchemy import select - - logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users") - - result = {} - - # 1. 批量查询所有 end_user 及其 app_id - end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - - # 构建 end_user_id -> end_user 的映射 - end_user_map = {str(user.id): user for user in end_users} - - # 记录不存在的用户 - for user_id in end_user_ids: - if user_id not in end_user_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"终端用户不存在: {user_id}" - } - - if not end_users: - logger.warning("No valid end users found") - return result - - # 2. 批量查询所有相关应用的最新发布版本 - app_ids = [user.app_id for user in end_users] - - # 使用子查询找到每个 app 的最新版本 - from sqlalchemy import and_ - - # 查询所有相关的活跃发布版本 - releases = db.query(AppRelease).filter( - and_( - AppRelease.app_id.in_(app_ids), - AppRelease.is_active.is_(True) - ) - ).order_by(AppRelease.app_id, AppRelease.version.desc()).all() - - # 构建 app_id -> latest_release 的映射(每个 app 只保留最新版本) - app_release_map = {} - for release in releases: - app_id_str = str(release.app_id) - if app_id_str not in app_release_map: - app_release_map[app_id_str] = release - - # 3. 收集所有 memory_config_id - memory_config_ids = [] - for release in app_release_map.values(): - config = release.config or {} - memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - memory_config_ids.append(config_id) - except (ValueError, TypeError): - pass - - # 4. 批量查询所有 data_config - config_name_map = {} - if memory_config_ids: - data_configs = db.query(DataConfig).filter( - DataConfig.config_id.in_(memory_config_ids) - ).all() - config_name_map = {config.config_id: config.config_name for config in data_configs} - - # 5. 组装结果 - for user in end_users: - user_id = str(user.id) - app_id = str(user.app_id) - - # 检查是否有发布版本 - if app_id not in app_release_map: - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": None, - "memory_config_name": None, - "error": f"应用未发布: {app_id}" - } - continue - - release = app_release_map[app_id] - - # 提取 memory_config_id - config = release.config or {} - memory_obj = config.get('memory', {}) - memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None - - # 获取 config_name - config_name = None - if memory_config_id: - try: - config_id = int(memory_config_id) if isinstance(memory_config_id, str) else memory_config_id - config_name = config_name_map.get(config_id) - except (ValueError, TypeError): - pass - - result[user_id] = { - "end_user_id": user_id, - "memory_config_id": memory_config_id, - "memory_config_name": config_name - } - - logger.info(f"Successfully retrieved batch configs: total={len(result)}, with_config={sum(1 for v in result.values() if v.get('memory_config_id'))}") + logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") return result \ No newline at end of file