diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index b14a06af..b96f4bde 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -353,15 +353,13 @@ async def get_workspace_total_memory_count( "details": [] } - # 2. 对每个 host_id 调用 search_all 获取 total + # 2. 使用 search_all_batch 批量查询所有宿主的记忆数量 from app.services import memory_storage_service - total_count = 0 - details = [] - # 如果提供了 end_user_id,只查询该用户 if end_user_id: - search_result = await memory_storage_service.search_all(end_user_id=end_user_id) + batch_result = await memory_storage_service.search_all_batch([end_user_id]) + count = batch_result.get(end_user_id, 0) # 查询用户名称 from app.repositories.end_user_repository import EndUserRepository repo = EndUserRepository(db) @@ -369,42 +367,31 @@ async def get_workspace_total_memory_count( user_name = end_user.other_name if end_user else None return { - "total_memory_count": search_result.get("total", 0), + "total_memory_count": count, "host_count": 1, "details": [{ "end_user_id": end_user_id, - "count": search_result.get("total", 0), + "count": count, "name": user_name }] } - for host in hosts: - try: - end_user_id_str = str(host.id) - - search_result = await memory_storage_service.search_all( - end_user_id=end_user_id_str - ) - - host_total = search_result.get("total", 0) - total_count += host_total - - details.append({ - "end_user_id": end_user_id_str, - "count": host_total, - "name": host.other_name # 使用 other_name 字段 - }) - - business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}") - - except Exception as e: - business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}") - # 失败的 host 记为 0 - details.append({ - "end_user_id": str(host.id), - "count": 0, - "name": host.other_name # 使用 other_name 字段 - }) + # 批量查询所有宿主记忆数量(一次 Neo4j 查询) + end_user_ids = [str(host.id) for host in hosts] + batch_result = await memory_storage_service.search_all_batch(end_user_ids) + + # 构建 host name 映射 + host_name_map = {str(host.id): host.other_name for host in hosts} + + total_count = sum(batch_result.values()) + details = [ + { + "end_user_id": uid, + "count": batch_result.get(uid, 0), + "name": host_name_map.get(uid) + } + for uid in end_user_ids + ] result = { "total_memory_count": total_count, diff --git a/api/app/utils/performance_timer.py b/api/app/utils/performance_timer.py index 6b0ec5d6..04e52fb1 100644 --- a/api/app/utils/performance_timer.py +++ b/api/app/utils/performance_timer.py @@ -6,13 +6,13 @@ """ import time -from contextlib import contextmanager +from contextlib import asynccontextmanager, contextmanager from app.core.logging_config import get_api_logger # 获取API专用日志器 api_logger = get_api_logger() - +# 同步的上下文管理器,使用@contextmanager修饰 @contextmanager def timer(label: str, user_count: int = 0): """上下文管理器:用于测量代码块执行时间 @@ -35,3 +35,27 @@ def timer(label: str, user_count: int = 0): elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 extra_info = f", 用户数: {user_count}" if user_count > 0 else "" api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}") + +# 异步的上下文管理器,使用@asynccontextmanager装饰 +@asynccontextmanager +async def async_timer(label: str, user_count: int = 0): + """异步上下文管理器:用于测量包含 await 的异步代码块执行时间 + + Args: + label: 统计标签,用于标识被测量的代码块 + user_count: 用户数,可选参数,用于记录处理的用户数量 + + Usage: + async with async_timer("获取用户列表"): + users = await get_users() + + async with async_timer("批量处理", user_count=len(user_ids)): + await process_users(user_ids) + """ + start = time.perf_counter() + try: + yield + finally: + elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒 + extra_info = f", 用户数: {user_count}" if user_count > 0 else "" + api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")