perf(memory): add detailed performance logging and optimize batch access recording
- Add [PERF] prefixed logging throughout hybrid search pipeline for better performance visibility - Break down latency metrics with separate timing for config loading, embedder initialization, and rerank computation - Format latency breakdown as JSON in performance summary logs - Optimize batch_record_access to process node access records in parallel using asyncio.gather instead of sequential processing - Add performance timing instrumentation for forgetting config loading and rerank computation stages - Reorganize imports in access_history_manager for consistency - Improve observability of search performance bottlenecks through structured logging
This commit is contained in:
@@ -842,7 +842,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
logger.info("Starting keyword search...")
|
||||
logger.info("[PERF] Starting keyword search...")
|
||||
keyword_start = time.time()
|
||||
keyword_task = asyncio.create_task(
|
||||
search_graph(
|
||||
@@ -856,7 +856,7 @@ async def run_hybrid_search(
|
||||
|
||||
if search_type in ["embedding", "hybrid"]:
|
||||
# Embedding-based search
|
||||
logger.info("Starting embedding search...")
|
||||
logger.info("[PERF] Starting embedding search...")
|
||||
embedding_start = time.time()
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
@@ -872,13 +872,13 @@ async def run_hybrid_search(
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"Config loading took {config_load_time:.4f}s")
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"Embedder init took {embedder_init_time:.4f}s")
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
@@ -895,7 +895,7 @@ async def run_hybrid_search(
|
||||
keyword_results = await keyword_task
|
||||
keyword_latency = time.time() - keyword_start
|
||||
latency_metrics["keyword_search_latency"] = round(keyword_latency, 4)
|
||||
logger.info(f"Keyword search completed in {keyword_latency:.4f}s")
|
||||
logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s")
|
||||
if search_type == "keyword":
|
||||
results = keyword_results
|
||||
else:
|
||||
@@ -905,7 +905,7 @@ async def run_hybrid_search(
|
||||
embedding_results = await embedding_task
|
||||
embedding_latency = time.time() - embedding_start
|
||||
latency_metrics["embedding_search_latency"] = round(embedding_latency, 4)
|
||||
logger.info(f"Embedding search completed in {embedding_latency:.4f}s")
|
||||
logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s")
|
||||
if search_type == "embedding":
|
||||
results = embedding_results
|
||||
else:
|
||||
@@ -922,17 +922,21 @@ async def run_hybrid_search(
|
||||
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
logger.info("Using two-stage reranking with ACTR activation")
|
||||
logger.info("[PERF] Using two-stage reranking with ACTR activation")
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
config_start = time.time()
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
config_time = time.time() - config_start
|
||||
logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s")
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
rerank_compute_start = time.time()
|
||||
reranked_results = rerank_with_activation(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
@@ -941,10 +945,12 @@ async def run_hybrid_search(
|
||||
forgetting_config=forgetting_cfg,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
rerank_compute_time = time.time() - rerank_compute_start
|
||||
logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s")
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
|
||||
logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s")
|
||||
|
||||
# Optional: apply reranker placeholder if enabled via config
|
||||
reranked_results = apply_reranker_placeholder(reranked_results, query_text)
|
||||
@@ -985,8 +991,10 @@ async def run_hybrid_search(
|
||||
else:
|
||||
results["latency_metrics"] = latency_metrics
|
||||
|
||||
logger.info(f"Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"Latency breakdown: {latency_metrics}")
|
||||
logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====")
|
||||
logger.info(f"[PERF] Total search completed in {total_latency:.4f}s")
|
||||
logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}")
|
||||
logger.info(f"[PERF] =========================================")
|
||||
|
||||
# Sanitize results: drop large/unused fields
|
||||
_remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs
|
||||
|
||||
@@ -8,14 +8,16 @@ Classes:
|
||||
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -188,30 +190,43 @@ class AccessHistoryManager:
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
import time
|
||||
batch_start = time.time()
|
||||
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
# PERFORMANCE FIX: Process all nodes in parallel instead of sequentially
|
||||
tasks = []
|
||||
for node_id in node_ids:
|
||||
task = self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Execute all tasks in parallel
|
||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Collect successful results and count failures
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id in node_ids:
|
||||
try:
|
||||
updated_node = await self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
results.append(updated_node)
|
||||
except Exception as e:
|
||||
for node_id, result in zip(node_ids, task_results):
|
||||
if isinstance(result, Exception):
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}"
|
||||
)
|
||||
else:
|
||||
results.append(result)
|
||||
|
||||
batch_duration = time.time() - batch_start
|
||||
logger.info(
|
||||
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}"
|
||||
f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}, 耗时 {batch_duration:.4f}s"
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -531,7 +546,10 @@ class AccessHistoryManager:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
"""
|
||||
access_history = node_data.get('access_history') or []
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
# Handle None importance_score - default to 0.5
|
||||
importance_score = node_data.get('importance_score')
|
||||
if importance_score is None:
|
||||
importance_score = 0.5
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
|
||||
@@ -456,23 +456,36 @@ class MemoryAgentService:
|
||||
client = MultiServerMCPClient(mcp_config)
|
||||
|
||||
async with client.session('data_flow') as session:
|
||||
session_start = time.time()
|
||||
logger.debug("Connected to MCP Server: data_flow")
|
||||
|
||||
tools_start = time.time()
|
||||
tools = await load_mcp_tools(session)
|
||||
tools_time = time.time() - tools_start
|
||||
logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s")
|
||||
|
||||
outputs = []
|
||||
intermediate_outputs = []
|
||||
seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates
|
||||
|
||||
# Pass memory_config to the graph workflow
|
||||
graph_start = time.time()
|
||||
async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph:
|
||||
graph_init_time = time.time() - graph_start
|
||||
logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s")
|
||||
|
||||
start = time.time()
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
workflow_errors = [] # Track errors from workflow
|
||||
|
||||
|
||||
event_count = 0
|
||||
async for event in graph.astream(
|
||||
{"messages": history, "memory_config": memory_config, "errors": []},
|
||||
stream_mode="values",
|
||||
config=config
|
||||
):
|
||||
event_count += 1
|
||||
event_start = time.time()
|
||||
messages = event.get('messages')
|
||||
# Capture any errors from the state
|
||||
if event.get('errors'):
|
||||
@@ -525,9 +538,15 @@ class MemoryAgentService:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to extract intermediate output: {e}")
|
||||
|
||||
event_time = time.time() - event_start
|
||||
logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s")
|
||||
|
||||
workflow_duration = time.time() - start
|
||||
logger.info(f"Read graph workflow completed in {workflow_duration}s")
|
||||
session_duration = time.time() - session_start
|
||||
logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s")
|
||||
logger.info(f"[PERF] Total session duration: {session_duration:.4f}s")
|
||||
logger.info(f"[PERF] Total events processed: {event_count}")
|
||||
# Extract final answer
|
||||
final_answer = ""
|
||||
for messages in outputs:
|
||||
@@ -1186,8 +1205,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
||||
ValueError: 当终端用户不存在或应用未发布时
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Getting connected config for end_user: {end_user_id}")
|
||||
@@ -1266,8 +1285,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session)
|
||||
对于查询失败的用户,value 包含 error 字段
|
||||
"""
|
||||
from app.models.app_release_model import AppRelease
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.data_config_model import DataConfig
|
||||
from app.models.end_user_model import EndUser
|
||||
from sqlalchemy import select
|
||||
|
||||
logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")
|
||||
|
||||
Reference in New Issue
Block a user