diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 11df8166..ae2b9cfa 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -842,7 +842,7 @@ async def run_hybrid_search( if search_type in ["keyword", "hybrid"]: # Keyword-based search - logger.info("Starting keyword search...") + logger.info("[PERF] Starting keyword search...") keyword_start = time.time() keyword_task = asyncio.create_task( search_graph( @@ -856,7 +856,7 @@ async def run_hybrid_search( if search_type in ["embedding", "hybrid"]: # Embedding-based search - logger.info("Starting embedding search...") + logger.info("[PERF] Starting embedding search...") embedding_start = time.time() # 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig @@ -872,13 +872,13 @@ async def run_hybrid_search( type="llm" ) config_load_time = time.time() - config_load_start - logger.info(f"Config loading took {config_load_time:.4f}s") + logger.info(f"[PERF] Config loading took {config_load_time:.4f}s") # Init embedder embedder_init_start = time.time() embedder = OpenAIEmbedderClient(model_config=rb_config) embedder_init_time = time.time() - embedder_init_start - logger.info(f"Embedder init took {embedder_init_time:.4f}s") + logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") embedding_task = asyncio.create_task( search_graph_by_embedding( @@ -895,7 +895,7 @@ async def run_hybrid_search( keyword_results = await keyword_task keyword_latency = time.time() - keyword_start latency_metrics["keyword_search_latency"] = round(keyword_latency, 4) - logger.info(f"Keyword search completed in {keyword_latency:.4f}s") + logger.info(f"[PERF] Keyword search completed in {keyword_latency:.4f}s") if search_type == "keyword": results = keyword_results else: @@ -905,7 +905,7 @@ async def run_hybrid_search( embedding_results = await embedding_task embedding_latency = time.time() - embedding_start latency_metrics["embedding_search_latency"] = round(embedding_latency, 4) - logger.info(f"Embedding search completed in {embedding_latency:.4f}s") + logger.info(f"[PERF] Embedding search completed in {embedding_latency:.4f}s") if search_type == "embedding": results = embedding_results else: @@ -922,17 +922,21 @@ async def run_hybrid_search( # Apply two-stage reranking with ACTR activation calculation rerank_start = time.time() - logger.info("Using two-stage reranking with ACTR activation") + logger.info("[PERF] Using two-stage reranking with ACTR activation") # 加载遗忘引擎配置 + config_start = time.time() try: pc = get_pipeline_config(memory_config) forgetting_cfg = pc.forgetting_engine except Exception as e: logger.debug(f"Failed to load forgetting config, using defaults: {e}") forgetting_cfg = ForgettingEngineConfig() + config_time = time.time() - config_start + logger.info(f"[PERF] Forgetting config loading took {config_time:.4f}s") # 统一使用激活度重排序(两阶段:检索 + ACTR计算) + rerank_compute_start = time.time() reranked_results = rerank_with_activation( keyword_results=keyword_results, embedding_results=embedding_results, @@ -941,10 +945,12 @@ async def run_hybrid_search( forgetting_config=forgetting_cfg, activation_boost_factor=activation_boost_factor, ) + rerank_compute_time = time.time() - rerank_compute_start + logger.info(f"[PERF] Rerank computation took {rerank_compute_time:.4f}s") rerank_latency = time.time() - rerank_start latency_metrics["reranking_latency"] = round(rerank_latency, 4) - logger.info(f"Reranking completed in {rerank_latency:.4f}s") + logger.info(f"[PERF] Total reranking completed in {rerank_latency:.4f}s") # Optional: apply reranker placeholder if enabled via config reranked_results = apply_reranker_placeholder(reranked_results, query_text) @@ -985,8 +991,10 @@ async def run_hybrid_search( else: results["latency_metrics"] = latency_metrics - logger.info(f"Total search completed in {total_latency:.4f}s") - logger.info(f"Latency breakdown: {latency_metrics}") + logger.info(f"[PERF] ===== SEARCH PERFORMANCE SUMMARY =====") + logger.info(f"[PERF] Total search completed in {total_latency:.4f}s") + logger.info(f"[PERF] Latency breakdown: {json.dumps(latency_metrics, indent=2)}") + logger.info(f"[PERF] =========================================") # Sanitize results: drop large/unused fields _remove_keys_recursive(results, ["name_embedding"]) # drop entity name embeddings from outputs diff --git a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py index 1a2e3cbc..5722769a 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py +++ b/api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py @@ -8,14 +8,16 @@ Classes: AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查 """ +import asyncio import logging -from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from enum import Enum +from typing import Any, Dict, List, Optional, Tuple +from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( + ACTRCalculator, +) from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator - logger = logging.getLogger(__name__) @@ -188,30 +190,43 @@ class AccessHistoryManager: Returns: List[Dict[str, Any]]: 成功更新的节点列表 """ + import time + batch_start = time.time() + if current_time is None: current_time = datetime.now() + # PERFORMANCE FIX: Process all nodes in parallel instead of sequentially + tasks = [] + for node_id in node_ids: + task = self.record_access( + node_id=node_id, + node_label=node_label, + group_id=group_id, + current_time=current_time + ) + tasks.append(task) + + # Execute all tasks in parallel + task_results = await asyncio.gather(*tasks, return_exceptions=True) + + # Collect successful results and count failures results = [] failed_count = 0 - for node_id in node_ids: - try: - updated_node = await self.record_access( - node_id=node_id, - node_label=node_label, - group_id=group_id, - current_time=current_time - ) - results.append(updated_node) - except Exception as e: + for node_id, result in zip(node_ids, task_results): + if isinstance(result, Exception): failed_count += 1 logger.warning( - f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}" + f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(result)}" ) + else: + results.append(result) + batch_duration = time.time() - batch_start logger.info( - f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " - f"失败 {failed_count}" + f"[PERF] 批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, " + f"失败 {failed_count}, 耗时 {batch_duration:.4f}s" ) return results @@ -531,7 +546,10 @@ class AccessHistoryManager: Dict[str, Any]: 更新数据,包含所有需要更新的字段 """ access_history = node_data.get('access_history') or [] - importance_score = node_data.get('importance_score', 0.5) + # Handle None importance_score - default to 0.5 + importance_score = node_data.get('importance_score') + if importance_score is None: + importance_score = 0.5 # 追加新的访问时间 new_access_history = access_history + [current_time_iso] diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 10f53ed7..2d78d796 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -456,23 +456,36 @@ class MemoryAgentService: client = MultiServerMCPClient(mcp_config) async with client.session('data_flow') as session: + session_start = time.time() logger.debug("Connected to MCP Server: data_flow") + + tools_start = time.time() tools = await load_mcp_tools(session) + tools_time = time.time() - tools_start + logger.info(f"[PERF] MCP tools loading took: {tools_time:.4f}s") + outputs = [] intermediate_outputs = [] seen_intermediates = set() # Track seen intermediate outputs to avoid duplicates # Pass memory_config to the graph workflow + graph_start = time.time() async with make_read_graph(group_id, tools, search_switch, group_id, group_id, memory_config=memory_config, storage_type=storage_type, user_rag_memory_id=user_rag_memory_id) as graph: + graph_init_time = time.time() - graph_start + logger.info(f"[PERF] Graph initialization took: {graph_init_time:.4f}s") + start = time.time() config = {"configurable": {"thread_id": group_id}} workflow_errors = [] # Track errors from workflow - + + event_count = 0 async for event in graph.astream( {"messages": history, "memory_config": memory_config, "errors": []}, stream_mode="values", config=config ): + event_count += 1 + event_start = time.time() messages = event.get('messages') # Capture any errors from the state if event.get('errors'): @@ -525,9 +538,15 @@ class MemoryAgentService: pass except Exception as e: logger.debug(f"Failed to extract intermediate output: {e}") + + event_time = time.time() - event_start + logger.info(f"[PERF] Event {event_count} processing took: {event_time:.4f}s") workflow_duration = time.time() - start - logger.info(f"Read graph workflow completed in {workflow_duration}s") + session_duration = time.time() - session_start + logger.info(f"[PERF] Read graph workflow completed in {workflow_duration}s") + logger.info(f"[PERF] Total session duration: {session_duration:.4f}s") + logger.info(f"[PERF] Total events processed: {event_count}") # Extract final answer final_answer = "" for messages in outputs: @@ -1186,8 +1205,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ValueError: 当终端用户不存在或应用未发布时 """ from app.models.app_release_model import AppRelease - from app.models.end_user_model import EndUser from app.models.data_config_model import DataConfig + from app.models.end_user_model import EndUser from sqlalchemy import select logger.info(f"Getting connected config for end_user: {end_user_id}") @@ -1266,8 +1285,8 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) 对于查询失败的用户,value 包含 error 字段 """ from app.models.app_release_model import AppRelease - from app.models.end_user_model import EndUser from app.models.data_config_model import DataConfig + from app.models.end_user_model import EndUser from sqlalchemy import select logger.info(f"Batch getting connected configs for {len(end_user_ids)} end users")