From abbd92b74c999300060a1e6db70c4015f393beda Mon Sep 17 00:00:00 2001 From: lanceyq <1982376970@qq.com> Date: Thu, 2 Apr 2026 14:19:27 +0800 Subject: [PATCH] Interface performance optimization, using only one function --- .../controllers/memory_storage_controller.py | 7 ++- api/app/services/memory_storage_service.py | 27 -------- api/app/tasks.py | 61 ++++++------------- 3 files changed, 23 insertions(+), 72 deletions(-) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index d8b39325..76eed50f 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -26,7 +26,7 @@ from app.services.memory_storage_service import ( analytics_hot_memory_tags, analytics_recent_activity_stats, kb_type_distribution, - search_all, + search_all_batch, search_chunk, search_detials, search_dialogue, @@ -409,7 +409,10 @@ async def search_all_num( ) -> dict: api_logger.info(f"Search all requested for end_user_id: {end_user_id}") try: - result = await search_all(end_user_id) + if not end_user_id: + return success(data={"total": 0}, msg="查询成功") + batch_result = await search_all_batch([end_user_id]) + result = {"total": batch_result.get(end_user_id, 0)} return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"Search all failed: {str(e)}") diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index fe0f3c32..132370b6 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -613,33 +613,6 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]: return data -async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]: - """查询用户的记忆总量(简化版本,只返回total) - - Args: - end_user_id: 用户ID - - Returns: - Dict[str, Any]: {"total": int} - """ - if not end_user_id: - return {"total": 0} - - result = await _neo4j_connector.execute_query( - MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, - end_user_ids=[end_user_id], - ) - - # 从批量查询结果中提取该用户的total - total = 0 - for row in result: - if row["user_id"] == end_user_id: - total = row["total"] - break - - return {"total": total} - - async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]: """统一知识库类型分布接口。 diff --git a/api/app/tasks.py b/api/app/tasks.py index c15aaeeb..5d29488a 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1324,7 +1324,7 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: from app.models.app_model import App from app.models.end_user_model import EndUser from app.repositories.memory_increment_repository import write_memory_increment - from app.services.memory_storage_service import search_all + from app.services.memory_storage_service import search_all_batch with get_db_context() as db: try: @@ -1358,27 +1358,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]: EndUser.workspace_id == workspace_id ).distinct().all() - # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] + # 3. 批量查询所有宿主的记忆总量 + end_user_id_list = [str(eid) for (eid,) in end_users] + batch_result = await search_all_batch(end_user_id_list) - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) + total_num = sum(batch_result.values()) + end_user_details = [ + {"end_user_id": uid, "total": batch_result.get(uid, 0)} + for uid in end_user_id_list + ] # 4. 写入数据库 memory_increment = write_memory_increment( @@ -1441,7 +1429,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: from app.models.end_user_model import EndUser from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment - from app.services.memory_storage_service import search_all + from app.services.memory_storage_service import search_all_batch with get_db_context() as db: try: @@ -1499,28 +1487,15 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: EndUser.workspace_id == workspace_id ).distinct().all() - # 3. 遍历所有end_user,查询每个宿主的记忆总量并累加 - total_num = 0 - end_user_details = [] + # 3. 批量查询所有宿主的记忆总量 + end_user_id_list = [str(eid) for (eid,) in end_users] + batch_result = await search_all_batch(end_user_id_list) - for (end_user_id,) in end_users: - try: - # 调用 search_all 接口查询该宿主的总量 - result = await search_all(str(end_user_id)) - user_total = result.get("total", 0) - total_num += user_total - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": user_total - }) - except Exception as e: - # 记录单个用户查询失败,但继续处理其他用户 - logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}") - end_user_details.append({ - "end_user_id": str(end_user_id), - "total": 0, - "error": str(e) - }) + total_num = sum(batch_result.values()) + end_user_details = [ + {"end_user_id": uid, "total": batch_result.get(uid, 0)} + for uid in end_user_id_list + ] # 4. 写入数据库 memory_increment = write_memory_increment(