Interface performance optimization, using only one function

This commit is contained in:
lanceyq
2026-04-02 14:19:27 +08:00
parent 960ee9f2df
commit abbd92b74c
3 changed files with 23 additions and 72 deletions

View File

@@ -26,7 +26,7 @@ from app.services.memory_storage_service import (
analytics_hot_memory_tags,
analytics_recent_activity_stats,
kb_type_distribution,
search_all,
search_all_batch,
search_chunk,
search_detials,
search_dialogue,
@@ -409,7 +409,10 @@ async def search_all_num(
) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try:
result = await search_all(end_user_id)
if not end_user_id:
return success(data={"total": 0}, msg="查询成功")
batch_result = await search_all_batch([end_user_id])
result = {"total": batch_result.get(end_user_id, 0)}
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search all failed: {str(e)}")

View File

@@ -613,33 +613,6 @@ async def search_entity(end_user_id: Optional[str] = None) -> Dict[str, Any]:
return data
async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""查询用户的记忆总量简化版本只返回total
Args:
end_user_id: 用户ID
Returns:
Dict[str, Any]: {"total": int}
"""
if not end_user_id:
return {"total": 0}
result = await _neo4j_connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
# 从批量查询结果中提取该用户的total
total = 0
for row in result:
if row["user_id"] == end_user_id:
total = row["total"]
break
return {"total": total}
async def kb_type_distribution(end_user_id: Optional[str] = None) -> Dict[str, Any]:
"""统一知识库类型分布接口。

View File

@@ -1324,7 +1324,7 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
from app.models.app_model import App
from app.models.end_user_model import EndUser
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
from app.services.memory_storage_service import search_all_batch
with get_db_context() as db:
try:
@@ -1358,27 +1358,15 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
EndUser.workspace_id == workspace_id
).distinct().all()
# 3. 遍历所有end_user查询每个宿主的记忆总量并累加
total_num = 0
end_user_details = []
# 3. 批量查询所有宿主的记忆总量
end_user_id_list = [str(eid) for (eid,) in end_users]
batch_result = await search_all_batch(end_user_id_list)
for (end_user_id,) in end_users:
try:
# 调用 search_all 接口查询该宿主的总量
result = await search_all(str(end_user_id))
user_total = result.get("total", 0)
total_num += user_total
end_user_details.append({
"end_user_id": str(end_user_id),
"total": user_total
})
except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户
end_user_details.append({
"end_user_id": str(end_user_id),
"total": 0,
"error": str(e)
})
total_num = sum(batch_result.values())
end_user_details = [
{"end_user_id": uid, "total": batch_result.get(uid, 0)}
for uid in end_user_id_list
]
# 4. 写入数据库
memory_increment = write_memory_increment(
@@ -1441,7 +1429,7 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
from app.models.end_user_model import EndUser
from app.models.workspace_model import Workspace
from app.repositories.memory_increment_repository import write_memory_increment
from app.services.memory_storage_service import search_all
from app.services.memory_storage_service import search_all_batch
with get_db_context() as db:
try:
@@ -1499,28 +1487,15 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
EndUser.workspace_id == workspace_id
).distinct().all()
# 3. 遍历所有end_user查询每个宿主的记忆总量并累加
total_num = 0
end_user_details = []
# 3. 批量查询所有宿主的记忆总量
end_user_id_list = [str(eid) for (eid,) in end_users]
batch_result = await search_all_batch(end_user_id_list)
for (end_user_id,) in end_users:
try:
# 调用 search_all 接口查询该宿主的总量
result = await search_all(str(end_user_id))
user_total = result.get("total", 0)
total_num += user_total
end_user_details.append({
"end_user_id": str(end_user_id),
"total": user_total
})
except Exception as e:
# 记录单个用户查询失败,但继续处理其他用户
logger.warning(f"查询用户 {end_user_id} 记忆失败: {str(e)}")
end_user_details.append({
"end_user_id": str(end_user_id),
"total": 0,
"error": str(e)
})
total_num = sum(batch_result.values())
end_user_details = [
{"end_user_id": uid, "total": batch_result.get(uid, 0)}
for uid in end_user_id_list
]
# 4. 写入数据库
memory_increment = write_memory_increment(