changes:(services) Modify the query method for user memory to batch processing.
This commit is contained in:
@@ -353,15 +353,13 @@ async def get_workspace_total_memory_count(
|
|||||||
"details": []
|
"details": []
|
||||||
}
|
}
|
||||||
|
|
||||||
# 2. 对每个 host_id 调用 search_all 获取 total
|
# 2. 使用 search_all_batch 批量查询所有宿主的记忆数量
|
||||||
from app.services import memory_storage_service
|
from app.services import memory_storage_service
|
||||||
|
|
||||||
total_count = 0
|
|
||||||
details = []
|
|
||||||
|
|
||||||
# 如果提供了 end_user_id,只查询该用户
|
# 如果提供了 end_user_id,只查询该用户
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
|
batch_result = await memory_storage_service.search_all_batch([end_user_id])
|
||||||
|
count = batch_result.get(end_user_id, 0)
|
||||||
# 查询用户名称
|
# 查询用户名称
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
@@ -369,42 +367,31 @@ async def get_workspace_total_memory_count(
|
|||||||
user_name = end_user.other_name if end_user else None
|
user_name = end_user.other_name if end_user else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"total_memory_count": search_result.get("total", 0),
|
"total_memory_count": count,
|
||||||
"host_count": 1,
|
"host_count": 1,
|
||||||
"details": [{
|
"details": [{
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"count": search_result.get("total", 0),
|
"count": count,
|
||||||
"name": user_name
|
"name": user_name
|
||||||
}]
|
}]
|
||||||
}
|
}
|
||||||
|
|
||||||
for host in hosts:
|
# 批量查询所有宿主记忆数量(一次 Neo4j 查询)
|
||||||
try:
|
end_user_ids = [str(host.id) for host in hosts]
|
||||||
end_user_id_str = str(host.id)
|
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||||
|
|
||||||
search_result = await memory_storage_service.search_all(
|
# 构建 host name 映射
|
||||||
end_user_id=end_user_id_str
|
host_name_map = {str(host.id): host.other_name for host in hosts}
|
||||||
)
|
|
||||||
|
total_count = sum(batch_result.values())
|
||||||
host_total = search_result.get("total", 0)
|
details = [
|
||||||
total_count += host_total
|
{
|
||||||
|
"end_user_id": uid,
|
||||||
details.append({
|
"count": batch_result.get(uid, 0),
|
||||||
"end_user_id": end_user_id_str,
|
"name": host_name_map.get(uid)
|
||||||
"count": host_total,
|
}
|
||||||
"name": host.other_name # 使用 other_name 字段
|
for uid in end_user_ids
|
||||||
})
|
]
|
||||||
|
|
||||||
business_logger.debug(f"EndUser {end_user_id_str} ({host.other_name}) 记忆数: {host_total}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
|
|
||||||
# 失败的 host 记为 0
|
|
||||||
details.append({
|
|
||||||
"end_user_id": str(host.id),
|
|
||||||
"count": 0,
|
|
||||||
"name": host.other_name # 使用 other_name 字段
|
|
||||||
})
|
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"total_memory_count": total_count,
|
"total_memory_count": total_count,
|
||||||
|
|||||||
@@ -6,13 +6,13 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import asynccontextmanager, contextmanager
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
# 同步的上下文管理器,使用@contextmanager修饰
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def timer(label: str, user_count: int = 0):
|
def timer(label: str, user_count: int = 0):
|
||||||
"""上下文管理器:用于测量代码块执行时间
|
"""上下文管理器:用于测量代码块执行时间
|
||||||
@@ -35,3 +35,27 @@ def timer(label: str, user_count: int = 0):
|
|||||||
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
|
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
|
||||||
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
|
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
|
||||||
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
|
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
|
||||||
|
|
||||||
|
# 异步的上下文管理器,使用@asynccontextmanager装饰
|
||||||
|
@asynccontextmanager
|
||||||
|
async def async_timer(label: str, user_count: int = 0):
|
||||||
|
"""异步上下文管理器:用于测量包含 await 的异步代码块执行时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label: 统计标签,用于标识被测量的代码块
|
||||||
|
user_count: 用户数,可选参数,用于记录处理的用户数量
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
async with async_timer("获取用户列表"):
|
||||||
|
users = await get_users()
|
||||||
|
|
||||||
|
async with async_timer("批量处理", user_count=len(user_ids)):
|
||||||
|
await process_users(user_ids)
|
||||||
|
"""
|
||||||
|
start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
|
||||||
|
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
|
||||||
|
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
|
||||||
|
|||||||
Reference in New Issue
Block a user