perf(memory): add detailed performance logging and optimize batch access recording

- Add [PERF] prefixed logging throughout hybrid search pipeline for better performance visibility
- Break down latency metrics with separate timing for config loading, embedder initialization, and rerank computation
- Format latency breakdown as JSON in performance summary logs
- Optimize batch_record_access to record node accesses concurrently with asyncio.gather(return_exceptions=True) instead of awaiting them one by one; per-node failures are still logged and counted
- Add performance timing instrumentation for forgetting config loading and rerank computation stages
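- Reorganize imports in access_history_manager and in the end-user connected-config helpers (get_end_user_connected_config, get_end_users_connected_configs_batch) for consistency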
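- Improve observability of search performance bottlenecks through structured logging
- Treat an explicit None importance_score as the default 0.5 when building access-history update data
- Add [PERF] timing in MemoryAgentService for MCP tool loading, graph initialization, per-event processing, and total session duration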
This commit is contained in:
Ke Sun
2026-01-14 12:02:10 +08:00
parent 78bb9315b7
commit a6e1898e1b
3 changed files with 76 additions and 31 deletions

View File

@@ -842,7 +842,7 @@ async def run_hybrid_search(
if search_type in ["keyword", "hybrid"]:
# Keyword-based search
logger.info("Starting keyword search...")
logger.info("[PERF] Starting keyword search...")
keyword_start = time.time()
keyword_task = asyncio.create_task(
search_graph(
@@ -856,7 +856,7 @@ async def run_hybrid_search(
if search_type in ["embedding", "hybrid"]:
# Embedding-based search
logger.info("Starting embedding search...")
logger.info("[PERF] Starting embedding search...")
embedding_start = time.time()
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
@@ -872,13 +872,13 @@ async def run_hybrid_search(
type="llm"
)
config_load_time = time.time() - config_load_start
logger.info(f"Config loading took {config_load_time:.4f}s")
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
# Init embedder
embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task(
search_graph_by_embedding(
@@ -895,7 +895,7 @@ async def run_hybrid_search(
keyword_results = await keyword_task
keyword_latency = time.time() - keyword_start
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
if search_type == "keyword":
results = keyword_results
else:
@@ -905,7 +905,7 @@ async def run_hybrid_search(
embedding_results = await embedding_task
embedding_latency = time.time() - embedding_start
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
if search_type == "embedding":
results = embedding_results
else:
@@ -922,17 +922,21 @@ async def run_hybrid_search(
# Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time()
logger.info("Using two-stage reranking with ACTR activation")
logger.info("[PERF] Using two-stage reranking with ACTR activation")
# 加载遗忘引擎配置
config_start = time.time()
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
config_time = time.time() - config_start
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
rerank_compute_start = time.time()
reranked_results = rerank_with_activation(
keyword_results=keyword_results,
embedding_results=embedding_results,
@@ -941,10 +945,12 @@ async def run_hybrid_search(
forgetting_config=forgetting_cfg,
activation_boost_factor=activation_boost_factor,
)
rerank_compute_time = time.time() - rerank_compute_start
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
# Optional: apply reranker placeholder if enabled via config
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
@@ -985,8 +991,10 @@ async def run_hybrid_search(
else:
results["latency_metrics"] = latency_metrics
logger.info(f"Total search completed in {total_latency:.4f}s")
logger.info(f"Latency breakdown: {latency_metrics}")
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
logger.info(f"[PERF] =========================================")
# Sanitize results: drop large/unused fields
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs

View File

@@ -8,14 +8,16 @@ Classes:
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
"""
import asyncio
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
@@ -188,30 +190,43 @@ class AccessHistoryManager:
Returns:
List[Dict[str, Any]]: 成功更新的节点列表
"""
import time
batch_start = time.time()
if current_time is None:
current_time = datetime.now()
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
tasks = []
for node_id in node_ids:
task = self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
tasks.append(task)
# Execute all tasks in parallel
task_results = await asyncio.gather(*tasks, return_exceptions=True)
# Collect successful results and count failures
results = []
failed_count = 0
for node_id in node_ids:
try:
updated_node = await self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
results.append(updated_node)
except Exception as e:
for node_id, result in zip(node_ids, task_results):
if isinstance(result, Exception):
failed_count += 1
logger.warning(
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
)
else:
results.append(result)
batch_duration = time.time() - batch_start
logger.info(
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}"
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
)
return results
@@ -531,7 +546,10 @@ class AccessHistoryManager:
Dict[str, Any]: 更新数据,包含所有需要更新的字段
"""
access_history = node_data.get('access_history') or []
importance_score = node_data.get('importance_score', 0.5)
# Handle None importance_score - default to 0.5
importance_score = node_data.get('importance_score')
if importance_score is None:
importance_score = 0.5
# 追加新的访问时间
new_access_history = access_history + [current_time_iso]

View File

@@ -456,23 +456,36 @@ class MemoryAgentService:
client = MultiServerMCPClient(mcp_config)
async with client.session('data_flow') as session:
session_start = time.time()
logger.debug("Connected to MCP Server: data_flow")
tools_start = time.time()
tools = await load_mcp_tools(session)
tools_time = time.time() - tools_start
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
outputs = []
intermediate_outputs = []
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
# Pass memory_config to the graph workflow
graph_start = time.time()
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
graph_init_time = time.time() - graph_start
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
start = time.time()
config = {"configurable": {"thread_id": group_id}}
workflow_errors = [] # Track errors from workflow
event_count = 0
async for event in graph.astream(
{"messages": history, "memory_config": memory_config, "errors": []},
stream_mode="values",
config=config
):
event_count += 1
event_start = time.time()
messages = event.get('messages')
# Capture any errors from the state
if event.get('errors'):
@@ -525,9 +538,15 @@ class MemoryAgentService:
pass
except Exception as e:
logger.debug(f"Failed to extract intermediate output: {e}")
event_time = time.time() - event_start
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
workflow_duration = time.time() - start
logger.info(f"Read graph workflow completed in {workflow_duration}s")
session_duration = time.time() - session_start
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
logger.info(f"[PERF] Total events processed: {event_count}")
# Extract final answer
final_answer = ""
for messages in outputs:
@@ -1186,8 +1205,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
ValueError: 当终端用户不存在或应用未发布时
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Getting connected config for end_user: {end_user_id}")
@@ -1266,8 +1285,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
对于查询失败的用户,value 包含 error 字段
"""
from app.models.app_release_model import AppRelease
from app.models.end_user_model import EndUser
from app.models.data_config_model import DataConfig
from app.models.end_user_model import EndUser
from sqlalchemy import select
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")