diff --git a/api/app/core/memory/evaluation/__init__.py b/api/app/core/memory/evaluation/__init__.py deleted file mode 100644 index e9d6aa6c..00000000 --- a/api/app/core/memory/evaluation/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Evaluation package with dataset-specific pipelines and a unified runner.""" diff --git a/api/app/core/memory/evaluation/benchmark.md b/api/app/core/memory/evaluation/benchmark.md deleted file mode 100644 index 2853b22b..00000000 --- a/api/app/core/memory/evaluation/benchmark.md +++ /dev/null @@ -1,30 +0,0 @@ -⏬数据集下载地址: - Locomo10.json:https://github.com/snap-research/locomo/tree/main/data - LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned - msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct - 上方数据集下载好后全部放入app/core/memory/data文件夹中 - -全流程基准测试运行: - locomo: - python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32 - LongMemEval: - python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group - memsciqa: - python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci - -单独检索评估运行命令: - python -m app.core.memory.evaluation.locomo.locomo_test - python -m app.core.memory.evaluation.longmemeval.test_eval - python -m app.core.memory.evaluation.memsciqa.memsciqa-test - 需要先在项目中修改需要检测评估的group_id。 - -参数及解释: - ● --dataset longmemeval - 指定数据集 - ● --sample-size 10 - 评估10个样本 - ● --start-index 0 - 从第0个样本开始 - ● --group-id longmemeval_zh_bak_2 - 使用指定的组ID - ● --search-limit 8 - 检索限制8条 - ● --context-char-budget 4000 - 上下文字符预算4000 - ● --search-type hybrid - 使用混合检索 - ● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文 - ● --reset-group - 运行前清空组数据 \ No newline at end of file diff --git a/api/app/core/memory/evaluation/common/metrics.py b/api/app/core/memory/evaluation/common/metrics.py deleted file mode 100644 index acc27fb9..00000000 --- a/api/app/core/memory/evaluation/common/metrics.py +++ /dev/null @@ -1,100 +0,0 @@ -import math -import re -from typing import List, Dict - - -def _normalize(text: str) -> List[str]: - """Lowercase, strip punctuation, and split into tokens.""" - text = text.lower().strip() - # Python's re doesn't support \p classes; use a simple non-word filter - text = re.sub(r"[^\w\s]", " ", text) - tokens = [t for t in text.split() if t] - return tokens - - -def exact_match(pred: str, ref: str) -> float: - return float(_normalize(pred) == _normalize(ref)) - - -def jaccard(pred: str, ref: str) -> float: - p = set(_normalize(pred)) - r = set(_normalize(ref)) - if not p and not r: - return 1.0 - if not p or not r: - return 0.0 - return len(p & r) / len(p | r) - - -def f1_score(pred: str, ref: str) -> float: - p_tokens = _normalize(pred) - r_tokens = _normalize(ref) - if not p_tokens and not r_tokens: - return 1.0 - if not p_tokens or not r_tokens: - return 0.0 - p_set = set(p_tokens) - r_set = set(r_tokens) - tp = len(p_set & r_set) - precision = tp / len(p_set) if p_set else 0.0 - recall = tp / len(r_set) if r_set else 0.0 - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) - - -def bleu1(pred: str, ref: str) -> float: - """Unigram BLEU (BLEU-1) with clipping and brevity penalty.""" - p_tokens = _normalize(pred) - r_tokens = _normalize(ref) - if not p_tokens: - return 0.0 - # Clipped count - r_counts: Dict[str, int] = {} - for t in r_tokens: - r_counts[t] = r_counts.get(t, 0) + 1 - clipped = 0 - p_counts: Dict[str, int] = {} - for t in p_tokens: - p_counts[t] = p_counts.get(t, 0) + 1 - for t, c in p_counts.items(): - clipped += min(c, r_counts.get(t, 0)) - precision = clipped / max(len(p_tokens), 1) - # Brevity penalty - ref_len = len(r_tokens) - pred_len = len(p_tokens) - if pred_len > ref_len or pred_len == 0: - bp = 1.0 - else: - bp = math.exp(1 - ref_len / max(pred_len, 1)) - return bp * precision - - -def percentile(values: List[float], p: float) -> float: - if not values: - return 0.0 - vals = sorted(values) - k = (len(vals) - 1) * p - f = math.floor(k) - c = math.ceil(k) - if f == c: - return vals[int(k)] - return vals[f] + (k - f) * (vals[c] - vals[f]) - - -def latency_stats(latencies_ms: List[float]) -> Dict[str, float]: - """Return basic latency stats: mean, p50, p95, iqr (p75-p25).""" - if not latencies_ms: - return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0} - p25 = percentile(latencies_ms, 0.25) - p50 = percentile(latencies_ms, 0.50) - p75 = percentile(latencies_ms, 0.75) - p95 = percentile(latencies_ms, 0.95) - mean = sum(latencies_ms) / max(len(latencies_ms), 1) - return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25} - - -def avg_context_tokens(contexts: List[str]) -> float: - if not contexts: - return 0.0 - return sum(len(_normalize(c)) for c in contexts) / len(contexts) diff --git a/api/app/core/memory/evaluation/dialogue_queries.py b/api/app/core/memory/evaluation/dialogue_queries.py deleted file mode 100644 index 25abe64e..00000000 --- a/api/app/core/memory/evaluation/dialogue_queries.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Dialogue search queries for evaluation purposes. -This file contains Cypher queries for searching dialogues, entities, and chunks. -Placed in evaluation directory to avoid circular imports with src modules. -""" - -# Entity search queries -SEARCH_ENTITIES_BY_NAME = """ -MATCH (e:Entity) -WHERE e.name = $name -RETURN e -""" - -SEARCH_ENTITIES_BY_NAME_FALLBACK = """ -MATCH (e:Entity) -WHERE e.name CONTAINS $name -RETURN e -""" - -# Chunk search queries -SEARCH_CHUNKS_BY_CONTENT = """ -MATCH (c:Chunk) -WHERE c.content CONTAINS $content -RETURN c -""" - -# Dialogue search queries -SEARCH_DIALOGUE_BY_DIALOG_ID = """ -MATCH (d:Dialogue) -WHERE d.dialog_id = $dialog_id -RETURN d -""" - -SEARCH_DIALOGUES_BY_CONTENT = """ -MATCH (d:Dialogue) -WHERE d.content CONTAINS $q -RETURN d -""" - -DIALOGUE_EMBEDDING_SEARCH = """ -WITH $embedding AS q -MATCH (d:Dialogue) -WHERE d.dialog_embedding IS NOT NULL - AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id) -WITH d, q, d.dialog_embedding AS v -WITH d, - reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot, - sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm, - sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm -WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score -WHERE score > $threshold -RETURN d.id AS dialog_id, - d.end_user_id AS end_user_id, - d.content AS content, - d.created_at AS created_at, - d.expired_at AS expired_at, - score -ORDER BY score DESC -LIMIT $limit -""" diff --git a/api/app/core/memory/evaluation/extraction_utils.py b/api/app/core/memory/evaluation/extraction_utils.py deleted file mode 100644 index 9e70bc28..00000000 --- a/api/app/core/memory/evaluation/extraction_utils.py +++ /dev/null @@ -1,341 +0,0 @@ -import asyncio -import json -import os -import re -from datetime import datetime -from typing import Any, Dict, List, Optional - -from app.core.memory.llm_tools.openai_client import LLMClient -from app.core.memory.models.message_models import ( - ConversationContext, - ConversationMessage, - DialogData, -) - -# 使用新的模块化架构 -from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( - ExtractionOrchestrator, -) -from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import ( - DialogueChunker, -) -from app.core.memory.utils.config.definitions import ( - SELECTED_CHUNKER_STRATEGY, - SELECTED_EMBEDDING_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context - -# Import from database module -from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -# Cypher queries for evaluation -# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py - - -async def ingest_contexts_via_full_pipeline( - contexts: List[str], - end_user_id: str, - chunker_strategy: str | None = None, - embedding_name: str | None = None, - save_chunk_output: bool = False, - save_chunk_output_path: str | None = None, -) -> bool: - """DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator - - Run the full extraction pipeline on provided dialogue contexts and save to Neo4j. - This function mirrors the steps in main(), but starts from raw text contexts. - Args: - contexts: List of dialogue texts, each containing lines like "role: message". - end_user_id: Group ID to assign to generated DialogData and graph nodes. - chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY. - embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. - save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. - save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt. - Returns: - True if data saved successfully, False otherwise. - """ - chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY - embedding_name = embedding_name or SELECTED_EMBEDDING_ID - - # Initialize llm client with graceful fallback - llm_client = None - llm_available = True - try: - from app.core.memory.utils.config import definitions as config_defs - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) - except Exception as e: - print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}") - llm_available = False - - # Step A: Build DialogData list from contexts with robust parsing - chunker = DialogueChunker(chunker_strategy) - dialog_data_list: List[DialogData] = [] - - for idx, ctx in enumerate(contexts): - messages: List[ConversationMessage] = [] - - # Improved parsing: capture multi-line message blocks, normalize roles - pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)" - matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL)) - - if matches: - for m in matches: - raw_role = m.group(1).strip() - content = m.group(2).strip() - norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户" - messages.append(ConversationMessage(role=norm_role, msg=content)) - else: - # Fallback: line-by-line parsing - for raw in ctx.split("\n"): - line = raw.strip() - if not line: - continue - m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line) - if m: - role = m.group(1).strip() - msg = m.group(2).strip() - norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户" - messages.append(ConversationMessage(role=norm_role, msg=msg)) - else: - # Final fallback: treat as user message - default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户" - messages.append(ConversationMessage(role=default_role, msg=line)) - - context_model = ConversationContext(msgs=messages) - dialog = DialogData( - context=context_model, - ref_id=f"pipeline_item_{idx}", - end_user_id=end_user_id, - user_id="default_user", - apply_id="default_application", - ) - # Generate chunks - dialog.chunks = await chunker.process_dialogue(dialog) - dialog_data_list.append(dialog) - - if not dialog_data_list: - print("No dialogs to process for ingestion.") - return False - - # Optionally save chunking outputs for debugging - if save_chunk_output: - try: - def _serialize_datetime(obj): - if isinstance(obj, datetime): - return obj.isoformat() - raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") - - from app.core.config import settings - settings.ensure_memory_output_dir() - default_path = settings.get_memory_output_path("chunker_test_output.txt") - out_path = save_chunk_output_path or default_path - - combined_output = [dd.model_dump() for dd in dialog_data_list] - with open(out_path, "w", encoding="utf-8") as f: - json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime) - print(f"Saved chunking results to: {out_path}") - except Exception as e: - print(f"Failed to save chunking results: {e}") - - # Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线 - if not llm_available: - print("[Ingestion] Skipping extraction pipeline (no LLM).") - return False - - # 初始化 embedder 客户端 - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.models.base import RedBearModelConfig - from app.services.memory_config_service import MemoryConfigService - - try: - with get_db_context() as db: - embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) - embedder_config = RedBearModelConfig(**embedder_config_dict) - embedder_client = OpenAIEmbedderClient(embedder_config) - except Exception as e: - print(f"[Ingestion] Failed to initialize embedder client: {e}") - print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).") - return False - - connector = Neo4jConnector() - - # 初始化并运行 ExtractionOrchestrator - from app.core.memory.utils.config.config_utils import get_pipeline_config - config = get_pipeline_config() - - orchestrator = ExtractionOrchestrator( - llm_client=llm_client, - embedder_client=embedder_client, - connector=connector, - config=config, - ) - - # 创建一个包装的 orchestrator 来修复时间提取器的输出 - # 保存原始的 _assign_extracted_data 方法 - original_assign = orchestrator._assign_extracted_data - - def clean_temporal_value(value): - """清理 temporal_validity 字段的值,将无效值转换为 None""" - if value is None: - return None - if isinstance(value, str): - # 处理字符串形式的 'null', 'None', 空字符串等 - if value.lower() in ('null', 'none', '') or value.strip() == '': - return None - return value - - async def patched_assign_extracted_data(*args, **kwargs): - """包装方法:在赋值后清理 temporal_validity 中的无效字符串""" - result = await original_assign(*args, **kwargs) - - # 清理返回的 dialog_data_list 中的 temporal_validity - for dialog in result: - if hasattr(dialog, 'chunks') and dialog.chunks: - for chunk in dialog.chunks: - if hasattr(chunk, 'statements') and chunk.statements: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - tv = statement.temporal_validity - # 清理 valid_at 和 invalid_at - if hasattr(tv, 'valid_at'): - tv.valid_at = clean_temporal_value(tv.valid_at) - if hasattr(tv, 'invalid_at'): - tv.invalid_at = clean_temporal_value(tv.invalid_at) - return result - - # 替换方法 - orchestrator._assign_extracted_data = patched_assign_extracted_data - - # 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理 - original_create = orchestrator._create_nodes_and_edges - - async def patched_create_nodes_and_edges(dialog_data_list_arg): - """包装方法:在创建节点前再次清理 temporal_validity""" - # 最后一次清理,确保万无一失 - for dialog in dialog_data_list_arg: - if hasattr(dialog, 'chunks') and dialog.chunks: - for chunk in dialog.chunks: - if hasattr(chunk, 'statements') and chunk.statements: - for statement in chunk.statements: - if hasattr(statement, 'temporal_validity') and statement.temporal_validity: - tv = statement.temporal_validity - if hasattr(tv, 'valid_at'): - tv.valid_at = clean_temporal_value(tv.valid_at) - if hasattr(tv, 'invalid_at'): - tv.invalid_at = clean_temporal_value(tv.invalid_at) - - return await original_create(dialog_data_list_arg) - - orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges - - # 运行完整的提取流水线 - # orchestrator.run 返回 7 个元素的元组 - result = await orchestrator.run(dialog_data_list, is_pilot_run=False) - ( - dialogue_nodes, - chunk_nodes, - statement_nodes, - entity_nodes, - statement_chunk_edges, - statement_entity_edges, - entity_entity_edges, - ) = result - - # statement_chunk_edges 已经由 orchestrator 创建,无需重复创建 - - # Step G: 生成记忆摘要 - print("[Ingestion] Generating memory summaries...") - try: - from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( - memory_summary_generation, - ) - from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges - from app.repositories.neo4j.add_nodes import add_memory_summary_nodes - - summaries = await memory_summary_generation( - chunked_dialogs=dialog_data_list, - llm_client=llm_client, - embedder_client=embedder_client - ) - print(f"[Ingestion] Generated {len(summaries)} memory summaries") - except Exception as e: - print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}") - summaries = [] - - # Step H: Save to Neo4j - try: - success = await save_dialog_and_statements_to_neo4j( - dialogue_nodes=dialogue_nodes, - chunk_nodes=chunk_nodes, - statement_nodes=statement_nodes, - entity_nodes=entity_nodes, - entity_edges=entity_entity_edges, - statement_chunk_edges=statement_chunk_edges, - statement_entity_edges=statement_entity_edges, - connector=connector - ) - - # Save memory summaries separately - if summaries: - try: - await add_memory_summary_nodes(summaries, connector) - await add_memory_summary_statement_edges(summaries, connector) - print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j") - except Exception as e: - print(f"Warning: Failed to save summary nodes: {e}") - - await connector.close() - if success: - print("Successfully saved extracted data to Neo4j!") - else: - print("Failed to save data to Neo4j") - return success - except Exception as e: - print(f"Failed to save data to Neo4j: {e}") - return False - - -async def handle_context_processing(args): - """Handle context-based processing from command line arguments.""" - contexts = [] - - if args.contexts: - contexts.extend(args.contexts) - - if args.context_file: - try: - with open(args.context_file, 'r', encoding='utf-8') as f: - contexts.extend(line.strip() for line in f if line.strip()) - except Exception as e: - print(f"Error reading context file: {e}") - return False - - if not contexts: - print("No contexts provided for processing.") - return False - - return await main_from_contexts(contexts, args.context_end_user_id) - - -async def main_from_contexts(contexts: List[str], end_user_id: str): - """Run the pipeline from provided dialogue contexts instead of test data.""" - print("=== Running pipeline from provided contexts ===") - - success = await ingest_contexts_via_full_pipeline( - contexts=contexts, - end_user_id=end_user_id, - chunker_strategy=SELECTED_CHUNKER_STRATEGY, - embedding_name=SELECTED_EMBEDDING_ID, - save_chunk_output=True - ) - - if success: - print("Successfully processed and saved contexts to Neo4j!") - else: - print("Failed to process contexts.") - - return success diff --git a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py b/api/app/core/memory/evaluation/locomo/locomo_benchmark.py deleted file mode 100644 index 1c70c28e..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_benchmark.py +++ /dev/null @@ -1,575 +0,0 @@ -""" -LoCoMo Benchmark Script - -This module provides the main entry point for running LoCoMo benchmark evaluations. -It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation -in a clean, maintainable way. - -Usage: - python locomo_benchmark.py --sample_size 20 --search_type hybrid -""" - -import argparse -import asyncio -import json -import os -import time -from datetime import datetime -from typing import Any, Dict, List, Optional - -try: - from dotenv import load_dotenv -except ImportError: - def load_dotenv(): - pass - -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - bleu1, - f1_score, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.locomo.locomo_metrics import ( - get_category_name, - locomo_f1_score, - locomo_multi_f1, -) -from app.core.memory.evaluation.locomo.locomo_utils import ( - extract_conversations, - ingest_conversations_if_needed, - load_locomo_data, - resolve_temporal_references, - retrieve_relevant_information, - select_and_format_information, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_end_user_id, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - - -async def run_locomo_benchmark( - sample_size: int = 20, - end_user_id: Optional[str] = None, - search_type: str = "hybrid", - search_limit: int = 12, - context_char_budget: int = 8000, - reset_group: bool = False, - skip_ingest: bool = False, - output_dir: Optional[str] = None -) -> Dict[str, Any]: - """ - Run LoCoMo benchmark evaluation. - - This function orchestrates the complete evaluation pipeline: - 1. Load LoCoMo dataset (only QA pairs from first conversation) - 2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True) - 3. For each question: - - Retrieve relevant information - - Generate answer using LLM - - Calculate metrics - 4. Aggregate results and save to file - - Note: By default, only the first conversation is ingested into the database, - and only QA pairs from that conversation are evaluated. This ensures that - all questions have corresponding memory in the database for retrieval. - - Args: - sample_size: Number of QA pairs to evaluate (from first conversation) - end_user_id: Database group ID for retrieval (uses default if None) - search_type: "keyword", "embedding", or "hybrid" - search_limit: Max documents to retrieve per query - context_char_budget: Max characters for context - reset_group: Whether to clear and re-ingest data (not implemented) - skip_ingest: If True, skip data ingestion and use existing data in Neo4j - output_dir: Directory to save results (uses default if None) - - Returns: - Dictionary with evaluation results including metrics, timing, and samples - """ - # Use default end_user_id if not provided - end_user_id = end_user_id or SELECTED_end_user_id - - # Determine data path - data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") - if not os.path.exists(data_path): - # Fallback to current directory - data_path = os.path.join(os.getcwd(), "data", "locomo10.json") - - print(f"\n{'='*60}") - print("🚀 Starting LoCoMo Benchmark Evaluation") - print(f"{'='*60}") - print("📊 Configuration:") - print(f" Sample size: {sample_size}") - print(f" Group ID: {end_user_id}") - print(f" Search type: {search_type}") - print(f" Search limit: {search_limit}") - print(f" Context budget: {context_char_budget} chars") - print(f" Data path: {data_path}") - print(f"{'='*60}\n") - - # Step 1: Load LoCoMo data - print("📂 Loading LoCoMo dataset...") - try: - # Only load QA pairs from the first conversation (index 0) - # since we only ingest the first conversation into the database - qa_items = load_locomo_data(data_path, sample_size, conversation_index=0) - print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n") - except Exception as e: - print(f"❌ Failed to load data: {e}") - return { - "error": f"Data loading failed: {e}", - "timestamp": datetime.now().isoformat() - } - - # Step 2: Extract conversations and ingest if needed - if skip_ingest: - print("⏭️ Skipping data ingestion (using existing data in Neo4j)") - print(f" Group ID: {end_user_id}\n") - else: - print("💾 Checking database ingestion...") - try: - conversations = extract_conversations(data_path, max_dialogues=1) - print(f"📝 Extracted {len(conversations)} conversations") - - # Always ingest for now (ingestion check not implemented) - print(f"🔄 Ingesting conversations into group '{end_user_id}'...") - success = await ingest_conversations_if_needed( - conversations=conversations, - end_user_id=end_user_id, - reset=reset_group - ) - - if success: - print("✅ Ingestion completed successfully\n") - else: - print("⚠️ Ingestion may have failed, continuing anyway\n") - - except Exception as e: - print(f"❌ Ingestion failed: {e}") - print("⚠️ Continuing with evaluation (database may be empty)\n") - - # Step 3: Initialize clients - print("🔧 Initializing clients...") - connector = Neo4jConnector() - - # Initialize LLM client with database context - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) - - # Initialize embedder - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - print("✅ Clients initialized\n") - - # Step 4: Process questions - print(f"🔍 Processing {len(qa_items)} questions...") - print(f"{'='*60}\n") - - # Tracking variables - latencies_search: List[float] = [] - latencies_llm: List[float] = [] - context_counts: List[int] = [] - context_chars: List[int] = [] - context_tokens: List[int] = [] - - # Metric lists - f1_scores: List[float] = [] - bleu1_scores: List[float] = [] - jaccard_scores: List[float] = [] - locomo_f1_scores: List[float] = [] - - # Per-category tracking - category_counts: Dict[str, int] = {} - category_f1: Dict[str, List[float]] = {} - category_bleu1: Dict[str, List[float]] = {} - category_jaccard: Dict[str, List[float]] = {} - category_locomo_f1: Dict[str, List[float]] = {} - - # Detailed samples - samples: List[Dict[str, Any]] = [] - - # Fixed anchor date for temporal resolution - anchor_date = datetime(2023, 5, 8) - - try: - for idx, item in enumerate(qa_items, 1): - question = item.get("question", "") - ground_truth = item.get("answer", "") - category = get_category_name(item) - - # Ensure ground truth is a string - ground_truth_str = str(ground_truth) if ground_truth is not None else "" - - print(f"[{idx}/{len(qa_items)}] Category: {category}") - print(f"❓ Question: {question}") - print(f"✅ Ground Truth: {ground_truth_str}") - - # Step 4a: Retrieve relevant information - t_search_start = time.time() - try: - retrieved_info = await retrieve_relevant_information( - question=question, - end_user_id=end_user_id, - search_type=search_type, - search_limit=search_limit, - connector=connector, - embedder=embedder - ) - t_search_end = time.time() - search_latency = (t_search_end - t_search_start) * 1000 - latencies_search.append(search_latency) - - print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)") - - except Exception as e: - print(f"❌ Retrieval failed: {e}") - retrieved_info = [] - search_latency = 0.0 - latencies_search.append(search_latency) - - # Step 4b: Select and format context - context_text = select_and_format_information( - retrieved_info=retrieved_info, - question=question, - max_chars=context_char_budget - ) - - # Resolve temporal references - context_text = resolve_temporal_references(context_text, anchor_date) - - # Add reference date to context - if context_text: - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}" - else: - context_text = "No relevant context found." - - # Track context statistics - context_counts.append(len(retrieved_info)) - context_chars.append(len(context_text)) - context_tokens.append(len(context_text.split())) - - print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs") - - # Step 4c: Generate answer with LLM - messages = [ - { - "role": "system", - "content": ( - "You are a precise QA assistant. Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - ) - }, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}" - } - ] - - t_llm_start = time.time() - try: - response = await llm_client.chat(messages=messages) - t_llm_end = time.time() - llm_latency = (t_llm_end - t_llm_start) * 1000 - latencies_llm.append(llm_latency) - - # Extract prediction from response - if hasattr(response, 'content'): - prediction = response.content.strip() - elif isinstance(response, dict): - prediction = response["choices"][0]["message"]["content"].strip() - else: - prediction = "Unknown" - - print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)") - - except Exception as e: - print(f"❌ LLM failed: {e}") - prediction = "Unknown" - llm_latency = 0.0 - latencies_llm.append(llm_latency) - - # Step 4d: Calculate metrics - f1_val = f1_score(prediction, ground_truth_str) - bleu1_val = bleu1(prediction, ground_truth_str) - jaccard_val = jaccard(prediction, ground_truth_str) - - # LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop) - if item.get("category") == 1: - locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str) - else: - locomo_f1_val = locomo_f1_score(prediction, ground_truth_str) - - # Accumulate metrics - f1_scores.append(f1_val) - bleu1_scores.append(bleu1_val) - jaccard_scores.append(jaccard_val) - locomo_f1_scores.append(locomo_f1_val) - - # Track by category - category_counts[category] = category_counts.get(category, 0) + 1 - category_f1.setdefault(category, []).append(f1_val) - category_bleu1.setdefault(category, []).append(bleu1_val) - category_jaccard.setdefault(category, []).append(jaccard_val) - category_locomo_f1.setdefault(category, []).append(locomo_f1_val) - - print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, " - f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}") - print() - - # Save sample details - samples.append({ - "question": question, - "ground_truth": ground_truth_str, - "prediction": prediction, - "category": category, - "metrics": { - "f1": f1_val, - "bleu1": bleu1_val, - "jaccard": jaccard_val, - "locomo_f1": locomo_f1_val - }, - "retrieval": { - "num_docs": len(retrieved_info), - "context_length": len(context_text) - }, - "timing": { - "search_ms": search_latency, - "llm_ms": llm_latency - } - }) - - finally: - # Close connector - await connector.close() - - # Step 5: Aggregate results - print(f"\n{'='*60}") - print("📊 Aggregating Results") - print(f"{'='*60}\n") - - # Overall metrics - overall_metrics = { - "f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0, - "bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0, - "jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0, - "locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0 - } - - # Per-category metrics - by_category: Dict[str, Dict[str, Any]] = {} - for cat in category_counts: - f1_list = category_f1.get(cat, []) - b1_list = category_bleu1.get(cat, []) - j_list = category_jaccard.get(cat, []) - lf_list = category_locomo_f1.get(cat, []) - - by_category[cat] = { - "count": category_counts[cat], - "f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0, - "bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0, - "jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0, - "locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0 - } - - # Latency statistics - latency = { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm) - } - - # Context statistics - context_stats = { - "avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0, - "avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0, - "avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0 - } - - # Build result dictionary - result = { - "dataset": "locomo", - "sample_size": len(qa_items), - "timestamp": datetime.now().isoformat(), - "params": { - "end_user_id": end_user_id, - "search_type": search_type, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID - }, - "overall_metrics": overall_metrics, - "by_category": by_category, - "latency": latency, - "context_stats": context_stats, - "samples": samples - } - - # Step 6: Save results - if output_dir is None: - output_dir = os.path.join( - os.path.dirname(__file__), - "results" - ) - - os.makedirs(output_dir, exist_ok=True) - - # Generate timestamped filename - timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S") - output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json") - - try: - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"✅ Results saved to: {output_path}\n") - except Exception as e: - print(f"❌ Failed to save results: {e}") - print("📊 Printing results to console instead:\n") - print(json.dumps(result, ensure_ascii=False, indent=2)) - - return result - - -def main(): - """ - Parse command-line arguments and run benchmark. - - This function provides a CLI interface for running LoCoMo benchmarks - with configurable parameters. - """ - parser = argparse.ArgumentParser( - description="Run LoCoMo benchmark evaluation", - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--sample_size", - type=int, - default=20, - help="Number of QA pairs to evaluate" - ) - parser.add_argument( - "--end_user_id", - type=str, - default=None, - help="Database group ID for retrieval (uses default if not specified)" - ) - parser.add_argument( - "--search_type", - type=str, - default="hybrid", - choices=["keyword", "embedding", "hybrid"], - help="Search strategy to use" - ) - parser.add_argument( - "--search_limit", - type=int, - default=12, - help="Maximum number of documents to retrieve per query" - ) - parser.add_argument( - "--context_char_budget", - type=int, - default=8000, - help="Maximum characters for context" - ) - parser.add_argument( - "--reset_group", - action="store_true", - help="Clear and re-ingest data (not implemented)" - ) - parser.add_argument( - "--skip_ingest", - action="store_true", - help="Skip data ingestion and use existing data in Neo4j" - ) - parser.add_argument( - "--output_dir", - type=str, - default=None, - help="Directory to save results (uses default if not specified)" - ) - - args = parser.parse_args() - - # Load environment variables - load_dotenv() - - # Run benchmark - result = asyncio.run(run_locomo_benchmark( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_type=args.search_type, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - reset_group=args.reset_group, - skip_ingest=args.skip_ingest, - output_dir=args.output_dir - )) - - # Print summary - print(f"\n{'='*60}") - - # Check if there was an error - if 'error' in result: - print("❌ Benchmark Failed!") - print(f"{'='*60}") - print(f"Error: {result['error']}") - return - - print("🎉 Benchmark Complete!") - print(f"{'='*60}") - print("📊 Final Results:") - print(f" Sample size: {result.get('sample_size', 0)}") - print(f" F1: {result['overall_metrics']['f1']:.3f}") - print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}") - print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}") - print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}") - - if result.get('context_stats'): - print("\n📈 Context Statistics:") - print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}") - print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}") - print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}") - - if result.get('latency'): - print("\n⏱️ Latency Statistics:") - print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, " - f"P50: {result['latency']['search']['p50']:.1f}ms, " - f"P95: {result['latency']['search']['p95']:.1f}ms") - print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, " - f"P50: {result['latency']['llm']['p50']:.1f}ms, " - f"P95: {result['latency']['llm']['p95']:.1f}ms") - - if result.get('by_category'): - print("\n📂 Results by Category:") - for cat, metrics in result['by_category'].items(): - print(f" {cat}:") - print(f" Count: {metrics['count']}") - print(f" F1: {metrics['f1']:.3f}") - print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}") - print(f" Jaccard: {metrics['jaccard']:.3f}") - - print(f"\n{'='*60}\n") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/locomo/locomo_metrics.py b/api/app/core/memory/evaluation/locomo/locomo_metrics.py deleted file mode 100644 index 20d5f2b5..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_metrics.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -LoCoMo-specific metric calculations. - -This module provides clean, simplified implementations of metrics used for -LoCoMo benchmark evaluation, including text normalization and F1 score variants. -""" - -import re -from typing import Dict, Any - - -def normalize_text(text: str) -> str: - """ - Normalize text for LoCoMo evaluation. - - Normalization steps: - - Convert to lowercase - - Remove commas - - Remove stop words (a, an, the, and) - - Remove punctuation - - Normalize whitespace - - Args: - text: Input text to normalize - - Returns: - Normalized text string with consistent formatting - - Examples: - >>> normalize_text("The cat, and the dog") - 'cat dog' - >>> normalize_text("Hello, World!") - 'hello world' - """ - # Ensure input is a string - text = str(text) if text is not None else "" - - # Convert to lowercase - text = text.lower() - - # Remove commas - text = re.sub(r"[\,]", " ", text) - - # Remove stop words - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - - # Remove punctuation (keep only word characters and whitespace) - text = re.sub(r"[^\w\s]", " ", text) - - # Normalize whitespace (collapse multiple spaces to single space) - text = " ".join(text.split()) - - return text - - -def locomo_f1_score(prediction: str, ground_truth: str) -> float: - """ - Calculate LoCoMo F1 score for single-answer questions. - - Uses token-level precision and recall based on normalized text. - Treats tokens as sets (no duplicate counting). - - Args: - prediction: Model's predicted answer - ground_truth: Correct answer - - Returns: - F1 score between 0.0 and 1.0 - - Examples: - >>> locomo_f1_score("Paris", "Paris") - 1.0 - >>> locomo_f1_score("The cat", "cat") - 1.0 - >>> locomo_f1_score("dog", "cat") - 0.0 - """ - # Ensure inputs are strings - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - # Normalize and tokenize - pred_tokens = normalize_text(pred_str).split() - truth_tokens = normalize_text(truth_str).split() - - # Handle empty cases - if not pred_tokens or not truth_tokens: - return 0.0 - - # Convert to sets for comparison - pred_set = set(pred_tokens) - truth_set = set(truth_tokens) - - # Calculate true positives (intersection) - true_positives = len(pred_set & truth_set) - - # Calculate precision and recall - precision = true_positives / len(pred_set) if pred_set else 0.0 - recall = true_positives / len(truth_set) if truth_set else 0.0 - - # Calculate F1 score - if precision + recall == 0: - return 0.0 - - f1 = 2 * precision * recall / (precision + recall) - return f1 - - -def locomo_multi_f1(prediction: str, ground_truth: str) -> float: - """ - Calculate LoCoMo F1 score for multi-answer questions. - - Handles comma-separated answers by: - 1. Splitting both prediction and ground truth by commas - 2. For each ground truth answer, finding the best matching prediction - 3. Averaging the F1 scores across all ground truth answers - - Args: - prediction: Model's predicted answer (may contain multiple comma-separated answers) - ground_truth: Correct answer (may contain multiple comma-separated answers) - - Returns: - Average F1 score across all ground truth answers (0.0 to 1.0) - - Examples: - >>> locomo_multi_f1("Paris, London", "Paris, London") - 1.0 - >>> locomo_multi_f1("Paris", "Paris, London") - 0.5 - >>> locomo_multi_f1("Paris, Berlin", "Paris, London") - 0.5 - """ - # Ensure inputs are strings - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - # Split by commas and strip whitespace - predictions = [p.strip() for p in pred_str.split(',') if p.strip()] - ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()] - - # Handle empty cases - if not predictions or not ground_truths: - return 0.0 - - # For each ground truth, find the best matching prediction - f1_scores = [] - for gt in ground_truths: - # Calculate F1 with each prediction and take the maximum - best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions) - f1_scores.append(best_f1) - - # Return average F1 across all ground truths - return sum(f1_scores) / len(f1_scores) - - -def get_category_name(item: Dict[str, Any]) -> str: - """ - Extract and normalize category name from QA item. - - Handles both numeric categories (1-4) and string categories with various formats. - Supports multiple field names: "cat", "category", "type". - - Category mapping: - - 1 or "multi-hop" -> "Multi-Hop" - - 2 or "temporal" -> "Temporal" - - 3 or "open domain" -> "Open Domain" - - 4 or "single-hop" -> "Single-Hop" - - Args: - item: QA item dictionary containing category information - - Returns: - Standardized category name or "unknown" if not found - - Examples: - >>> get_category_name({"category": 1}) - 'Multi-Hop' - >>> get_category_name({"cat": "temporal"}) - 'Temporal' - >>> get_category_name({"type": "Single-Hop"}) - 'Single-Hop' - """ - # Numeric category mapping - CATEGORY_MAP = { - 1: "Multi-Hop", - 2: "Temporal", - 3: "Open Domain", - 4: "Single-Hop", - } - - # String category aliases (case-insensitive) - TYPE_ALIASES = { - "single-hop": "Single-Hop", - "singlehop": "Single-Hop", - "single hop": "Single-Hop", - "multi-hop": "Multi-Hop", - "multihop": "Multi-Hop", - "multi hop": "Multi-Hop", - "open domain": "Open Domain", - "opendomain": "Open Domain", - "temporal": "Temporal", - } - - # Try "cat" field first (string category) - cat = item.get("cat") - if isinstance(cat, str) and cat.strip(): - name = cat.strip() - lower = name.lower() - return TYPE_ALIASES.get(lower, name) - - # Try "category" field (can be int or string) - cat_num = item.get("category") - if isinstance(cat_num, int): - return CATEGORY_MAP.get(cat_num, "unknown") - elif isinstance(cat_num, str) and cat_num.strip(): - lower = cat_num.strip().lower() - return TYPE_ALIASES.get(lower, cat_num.strip()) - - # Try "type" field as fallback - cat_type = item.get("type") - if isinstance(cat_type, str) and cat_type.strip(): - lower = cat_type.strip().lower() - return TYPE_ALIASES.get(lower, cat_type.strip()) - - return "unknown" diff --git a/api/app/core/memory/evaluation/locomo/locomo_test.py b/api/app/core/memory/evaluation/locomo/locomo_test.py deleted file mode 100644 index 01c45123..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_test.py +++ /dev/null @@ -1,811 +0,0 @@ -# file name: check_neo4j_connection_fixed.py -import asyncio -import json -import math -import os -import re -import sys -import time -from datetime import datetime, timedelta -from typing import Any, Dict, List -from pathlib import Path - -from dotenv import load_dotenv - -# 1 -# 添加项目根目录到路径 -current_dir = Path(__file__).resolve().parent -project_root = str(current_dir.parent) -if project_root not in sys.path: - sys.path.insert(0, project_root) -# 关键:将 src 目录置于最前,确保从当前仓库加载模块 -src_dir = os.path.join(project_root, "src") -if src_dir not in sys.path: - sys.path.insert(0, src_dir) - -load_dotenv() - -# 首先定义 _loc_normalize 函数,因为其他函数依赖它 -def _loc_normalize(text: str) -> str: - text = str(text) if text is not None else "" - text = text.lower() - text = re.sub(r"[\,]", " ", text) - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - text = re.sub(r"[^\w\s]", " ", text) - text = " ".join(text.split()) - return text - -# 尝试从 metrics.py 导入基础指标 -try: - from common.metrics import bleu1, f1_score, jaccard - print("✅ 从 metrics.py 导入基础指标成功") -except ImportError as e: - print(f"❌ 从 metrics.py 导入失败: {e}") - # 回退到本地实现 - def f1_score(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - r_tokens = _loc_normalize(ref_str).split() - if not p_tokens and not r_tokens: - return 1.0 - if not p_tokens or not r_tokens: - return 0.0 - p_set = set(p_tokens) - r_set = set(r_tokens) - tp = len(p_set & r_set) - precision = tp / len(p_set) if p_set else 0.0 - recall = tp / len(r_set) if r_set else 0.0 - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) - - def bleu1(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - r_tokens = _loc_normalize(ref_str).split() - if not p_tokens: - return 0.0 - - r_counts = {} - for t in r_tokens: - r_counts[t] = r_counts.get(t, 0) + 1 - - clipped = 0 - p_counts = {} - for t in p_tokens: - p_counts[t] = p_counts.get(t, 0) + 1 - - for t, c in p_counts.items(): - clipped += min(c, r_counts.get(t, 0)) - - precision = clipped / max(len(p_tokens), 1) - ref_len = len(r_tokens) - pred_len = len(p_tokens) - - if pred_len > ref_len or pred_len == 0: - bp = 1.0 - else: - bp = math.exp(1 - ref_len / max(pred_len, 1)) - - return bp * precision - - def jaccard(pred: str, ref: str) -> float: - pred_str = str(pred) if pred is not None else "" - ref_str = str(ref) if ref is not None else "" - - p = set(_loc_normalize(pred_str).split()) - r = set(_loc_normalize(ref_str).split()) - if not p and not r: - return 1.0 - if not p or not r: - return 0.0 - return len(p & r) / len(p | r) - -# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标 -try: - # 添加 evaluation 目录路径 - evaluation_dir = os.path.join(project_root, "evaluation") - if evaluation_dir not in sys.path: - sys.path.insert(0, evaluation_dir) - - # 尝试从不同位置导入 - try: - from locomo.qwen_search_eval import ( - _resolve_relative_times, - loc_f1_score, - loc_multi_f1, - ) - print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功") - except ImportError: - from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1 - print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功") - -except ImportError as e: - print(f"❌ 从 qwen_search_eval.py 导入失败: {e}") - # 回退到本地实现 LoCoMo 特定函数 - def _resolve_relative_times(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - return t - - def loc_f1_score(prediction: str, ground_truth: str) -> float: - p_tokens = _loc_normalize(prediction).split() - g_tokens = _loc_normalize(ground_truth).split() - if not p_tokens or not g_tokens: - return 0.0 - p = set(p_tokens) - g = set(g_tokens) - tp = len(p & g) - precision = tp / len(p) if p else 0.0 - recall = tp / len(g) if g else 0.0 - return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 - - def loc_multi_f1(prediction: str, ground_truth: str) -> float: - predictions = [p.strip() for p in str(prediction).split(',') if p.strip()] - ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()] - if not predictions or not ground_truths: - return 0.0 - def _f1(a: str, b: str) -> float: - return loc_f1_score(a, b) - vals = [] - for gt in ground_truths: - vals.append(max(_f1(pred, gt) for pred in predictions)) - return sum(vals) / len(vals) - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str: - """基于问题关键词智能选择上下文""" - if not contexts: - return "" - - # 提取问题关键词(只保留有意义的词) - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - print(f"🔍 问题关键词: {question_words}") - - # 给每个上下文打分 - scored_contexts = [] - for i, context in enumerate(contexts): - context_lower = context.lower() - score = 0 - - # 关键词匹配得分 - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # 关键词出现次数越多,得分越高 - score += context_lower.count(word) * 2 - - # 上下文长度得分(适中的长度更好) - context_len = len(context) - if 100 < context_len < 2000: # 理想长度范围 - score += 5 - elif context_len >= 2000: # 太长可能包含无关信息 - score += 2 - - # 如果是前几个上下文,给予额外分数(通常相关性更高) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # 按得分排序 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # 选择高得分的上下文,直到达到字符限制 - selected = [] - total_chars = 0 - selected_count = 0 - - print("📊 上下文相关性分析:") - for score, context, matches in scored_contexts[:5]: # 只显示前5个 - print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}") - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - selected_count += 1 - else: - # 如果这个上下文得分很高但放不下,尝试截取 - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # 找到包含关键词的部分 - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines: - truncated = '\n'.join(relevant_lines) - if len(truncated) > 100: # 确保有足够内容 - selected.append(truncated + "\n[相关内容截断...]") - total_chars += len(truncated) - selected_count += 1 - break # 不再尝试添加更多上下文 - - result = "\n\n".join(selected) - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符") - return result - - -def get_dynamic_search_params(question: str, question_index: int, total_questions: int): - """根据问题复杂度和进度动态调整检索参数""" - - # 分析问题复杂度 - word_count = len(question.split()) - has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago']) - has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while']) - - # 根据进度调整 - 后期问题可能需要更精确的检索 - progress_factor = question_index / total_questions - - base_limit = 12 - if has_temporal and has_multi_hop: - base_limit = 20 - elif word_count > 8: - base_limit = 16 - - # 随着测试进行,逐渐收紧检索范围 - adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3))) - - # 动态调整最大字符数 - max_chars = 8000 + 4000 * (1 - progress_factor) - - return { - "limit": adjusted_limit, - "max_chars": int(max_chars) - } - - -class EnhancedEvaluationMonitor: - def __init__(self, reset_interval=5, performance_threshold=0.6): - self.question_count = 0 - self.reset_interval = reset_interval - self.performance_threshold = performance_threshold - self.consecutive_low_scores = 0 - self.performance_history = [] - self.recent_f1_scores = [] - - def should_reset_connections(self, current_f1=None): - """基于计数和性能双重判断""" - # 定期重置 - if self.question_count % self.reset_interval == 0: - return True - - # 性能驱动的重置 - if current_f1 is not None and current_f1 < self.performance_threshold: - self.consecutive_low_scores += 1 - if self.consecutive_low_scores >= 2: # 连续2个低分就重置 - print("🚨 连续低分,触发紧急重置") - self.consecutive_low_scores = 0 - return True - else: - self.consecutive_low_scores = 0 - - return False - - def record_performance(self, question_index, metrics, context_length, retrieved_docs): - """记录性能指标,检测衰减""" - self.performance_history.append({ - 'index': question_index, - 'metrics': metrics, - 'context_length': context_length, - 'retrieved_docs': retrieved_docs, - 'timestamp': time.time() - }) - - # 记录最近的F1分数 - self.recent_f1_scores.append(metrics['f1']) - if len(self.recent_f1_scores) > 5: - self.recent_f1_scores.pop(0) - - def get_recent_performance(self): - """获取近期平均性能""" - if not self.recent_f1_scores: - return 0.5 - return sum(self.recent_f1_scores) / len(self.recent_f1_scores) - - def get_performance_trend(self): - """分析性能趋势""" - if len(self.performance_history) < 2: - return "stable" - - recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]] - earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]] - - if len(recent_metrics) < 2 or len(earlier_metrics) < 2: - return "stable" - - recent_avg = sum(recent_metrics) / len(recent_metrics) - earlier_avg = sum(earlier_metrics) / len(earlier_metrics) - - if recent_avg < earlier_avg * 0.8: - return "degrading" - elif recent_avg > earlier_avg * 1.1: - return "improving" - else: - return "stable" - - -def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float): - """基于问题复杂度和近期性能动态调整检索参数""" - - # 基础参数 - base_params = get_dynamic_search_params(question, question_index, total_questions) - - # 性能自适应调整 - if recent_performance < 0.5: # 近期表现差 - # 增加检索范围,尝试获取更多上下文 - base_params["limit"] = min(base_params["limit"] + 5, 25) - base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000) - print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})") - - elif recent_performance > 0.8: # 近期表现好 - # 收紧检索,提高精度 - base_params["limit"] = max(base_params["limit"] - 2, 8) - base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000) - print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})") - - # 中间阶段特殊处理 - mid_sequence_factor = abs(question_index / total_questions - 0.5) - if mid_sequence_factor < 0.2: # 在中间30%的问题 - print("🎯 中间阶段:使用更精确的检索策略") - base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量 - base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000) - - return base_params - - -def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str: - """考虑问题序列位置的智能选择""" - - if not contexts: - return "" - - # 在序列中间阶段使用更严格的筛选 - mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离 - - if mid_sequence_factor < 0.2: # 在中间30%的问题 - print("🎯 中间阶段:使用严格上下文筛选") - - # 提取问题关键词 - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - # 只保留高度相关的上下文 - filtered_contexts = [] - for context in contexts: - context_lower = context.lower() - relevance_score = sum(3 if word in context_lower else 0 for word in question_words) - - # 额外加分给包含数字、日期的上下文(对事实性问题更重要) - if any(char.isdigit() for char in context): - relevance_score += 2 - - # 提高阈值:只有得分>=3的上下文才保留 - if relevance_score >= 3: - filtered_contexts.append(context) - else: - print(f" - 过滤低分上下文: 得分={relevance_score}") - - contexts = filtered_contexts - print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文") - - # 使用原有的智能选择逻辑 - return smart_context_selection(contexts, question, max_chars) - - -async def run_enhanced_evaluation(): - """使用增强方法进行完整评估 - 解决中间性能衰减问题""" - try: - from dotenv import load_dotenv - except Exception: - def load_dotenv(): - return None - - # 修正导入路径:使用 app.core.memory.src 前缀 - from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient - from app.core.memory.utils.config.definitions import ( - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, - ) - from app.core.memory.utils.llm.llm_utils import MemoryClientFactory - from app.core.models.base import RedBearModelConfig - from app.db import get_db_context - from app.repositories.neo4j.graph_search import search_graph_by_embedding - from app.repositories.neo4j.neo4j_connector import Neo4jConnector - from app.services.memory_config_service import MemoryConfigService - - # 加载数据 - # 获取项目根目录 - current_file = os.path.abspath(__file__) - evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录 - memory_dir = os.path.dirname(evaluation_dir) # memory目录 - data_path = os.path.join(memory_dir, "data", "locomo10.json") - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - qa_items = [] - if isinstance(raw, list): - for entry in raw: - qa_items.extend(entry.get("qa", [])) - else: - qa_items.extend(raw.get("qa", [])) - - items = qa_items[:20] # 测试多少个问题 - - # 初始化增强监控器 - monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) - - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm = factory.get_llm_client(SELECTED_LLM_ID) - - # 初始化embedder - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 初始化连接器 - connector = Neo4jConnector() - - # 初始化结果字典 - results = { - "questions": [], - "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0}, - "category_metrics": {}, - "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0}, - "performance_trend": "stable", - "timestamp": datetime.now().isoformat(), - "enhanced_strategy": True - } - - total_f1 = 0.0 - total_bleu1 = 0.0 - total_jaccard = 0.0 - total_loc_f1 = 0.0 - total_context_length = 0 - total_retrieved_docs = 0 - category_stats = {} - - try: - for i, item in enumerate(items): - monitor.question_count += 1 - - # 获取近期性能用于重置判断 - recent_performance = monitor.get_recent_performance() - - # 增强的重置判断 - should_reset = monitor.should_reset_connections(current_f1=recent_performance) - if should_reset and i > 0: - print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...") - await connector.close() - connector = Neo4jConnector() # 创建新连接 - print("✅ 连接重置完成") - - q = item.get("question", "") - ref = item.get("answer", "") - ref_str = str(ref) if ref is not None else "" - - print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}") - print(f"✅ 真实答案: {ref_str}") - - # 分类别统计 - category = "Unknown" - if item.get("category") == 1: - category = "Multi-Hop" - elif item.get("category") == 2: - category = "Temporal" - elif item.get("category") == 3: - category = "Open Domain" - elif item.get("category") == 4: - category = "Single-Hop" - - # 增强的检索参数 - search_params = get_enhanced_search_params(q, i, len(items), recent_performance) - search_limit = search_params["limit"] - max_chars = search_params["max_chars"] - - print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}") - - # 使用项目标准的混合检索方法 - t0 = time.time() - contexts_all = [] - - try: - # 使用统一的搜索服务 - from app.core.memory.storage_services.search import run_hybrid_search - - print("🔀 使用混合搜索服务...") - - search_results = await run_hybrid_search( - query_text=q, - search_type="hybrid", - end_user_id="locomo_sk", - limit=20, - include=["statements", "chunks", "entities", "summaries"], - alpha=0.6, # BM25权重 - embedding_id=SELECTED_EMBEDDING_ID - ) - - # 处理搜索结果 - 新的搜索服务返回统一的结构 - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") - - # 构建上下文:优先使用 chunks、statements 和 summaries - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体,避免噪声 - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - print(f"📊 有效上下文数量: {len(contexts_all)}") - except Exception as e: - print(f"❌ 检索失败: {e}") - contexts_all = [] - - t1 = time.time() - search_time = (t1 - t0) * 1000 - - # 增强的上下文选择 - context_text = "" - if contexts_all: - # 使用增强的上下文选择 - context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars) - - # 如果智能选择后仍然过长,进行最终保护性截断 - if len(context_text) > max_chars: - print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断") - context_text = context_text[:max_chars] + "\n\n[最终截断...]" - - # 时间解析 - anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性 - context_text = _resolve_relative_times(context_text, anchor_date) - - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text - - print(f"📝 最终上下文长度: {len(context_text)} 字符") - - # 显示不同上下文的预览(不只是第一条) - print("🔍 上下文预览:") - for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文 - preview = context[:150].replace('\n', ' ') - print(f" 上下文{j+1}: {preview}...") - - # 🔍 调试:检查答案是否在上下文中 - if ref_str and ref_str.strip(): - answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all) - print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}") - - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # LLM 回答 - messages = [ - {"role": "system", "content": ( - "You are a precise QA assistant. Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - )}, - {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - try: - # 使用异步调用 - resp = await llm.chat(messages=messages) - # 兼容不同的响应格式 - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - except Exception as e: - print(f"❌ LLM 生成失败: {e}") - pred = "Unknown" - t3 = time.time() - llm_time = (t3 - t2) * 1000 - - # 计算指标 - 使用导入的指标函数 - f1_val = f1_score(pred, ref_str) - bleu1_val = bleu1(pred, ref_str) - jaccard_val = jaccard(pred, ref_str) - loc_f1_val = loc_f1_score(pred, ref_str) - - print(f"🤖 LLM 回答: {pred}") - print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}") - print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms") - - # 更新统计 - total_f1 += f1_val - total_bleu1 += bleu1_val - total_jaccard += jaccard_val - total_loc_f1 += loc_f1_val - total_context_length += len(context_text) - total_retrieved_docs += len(contexts_all) - - if category not in category_stats: - category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0} - - category_stats[category]["count"] += 1 - category_stats[category]["f1_sum"] += f1_val - category_stats[category]["b1_sum"] += bleu1_val - category_stats[category]["j_sum"] += jaccard_val - category_stats[category]["loc_f1_sum"] += loc_f1_val - - # 记录性能指标 - metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val} - monitor.record_performance(i, metrics, len(context_text), len(contexts_all)) - - # 保存结果 - question_result = { - "question": q, - "ground_truth": ref_str, - "prediction": pred, - "category": category, - "metrics": metrics, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": search_limit, - "max_chars": max_chars, - "recent_performance": recent_performance - }, - "timing": { - "search_ms": search_time, - "llm_ms": llm_time - } - } - - results["questions"].append(question_result) - - print("="*60) - - except Exception as e: - print(f"❌ 评估过程中发生错误: {e}") - # 即使出错,也返回已有的结果 - import traceback - traceback.print_exc() - - finally: - await connector.close() - - # 计算总体指标 - n = len(items) - if n > 0: - results["overall_metrics"] = { - "f1": total_f1 / n, - "b1": total_bleu1 / n, - "j": total_jaccard / n, - "loc_f1": total_loc_f1 / n - } - - for category, stats in category_stats.items(): - count = stats["count"] - results["category_metrics"][category] = { - "count": count, - "f1": stats["f1_sum"] / count, - "bleu1": stats["b1_sum"] / count, - "jaccard": stats["j_sum"] / count, - "loc_f1": stats["loc_f1_sum"] / count - } - - results["retrieval_stats"]["avg_context_length"] = total_context_length / n - results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n - - # 分析性能趋势 - results["performance_trend"] = monitor.get_performance_trend() - results["reset_interval"] = monitor.reset_interval - results["total_questions_processed"] = monitor.question_count - - return results - - -if __name__ == "__main__": - print("🚀 运行增强版完整评估(解决中间性能衰减问题)...") - print("📋 增强特性:") - print(" - 双重重置策略:定期重置 + 性能驱动重置") - print(" - 动态检索参数:基于近期性能自适应调整") - print(" - 中间阶段严格筛选:提高上下文质量要求") - print(" - 连续性能监控:实时检测性能衰减") - - result = asyncio.run(run_enhanced_evaluation()) - - print("\n📊 最终评估结果:") - print("总体指标:") - print(f" F1: {result['overall_metrics']['f1']:.4f}") - print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}") - print(f" Jaccard: {result['overall_metrics']['j']:.4f}") - print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}") - - print("\n分类别指标:") - for category, metrics in result['category_metrics'].items(): - print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})") - - print("\n检索统计:") - stats = result['retrieval_stats'] - print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符") - print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}") - - print(f"\n性能趋势: {result['performance_trend']}") - print(f"重置间隔: 每{result['reset_interval']}个问题") - print(f"处理问题总数: {result['total_questions_processed']}") - print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}") - - - # 保存结果到指定目录 - # 使用代码文件所在目录的绝对路径 - current_file_dir = os.path.dirname(os.path.abspath(__file__)) - output_dir = os.path.join(current_file_dir, "results") - os.makedirs(output_dir, exist_ok=True) - output_file = os.path.join(output_dir, "enhanced_evaluation_results.json") - with open(output_file, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n详细结果已保存到: {output_file}") diff --git a/api/app/core/memory/evaluation/locomo/locomo_utils.py b/api/app/core/memory/evaluation/locomo/locomo_utils.py deleted file mode 100644 index d3b74947..00000000 --- a/api/app/core/memory/evaluation/locomo/locomo_utils.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -LoCoMo Utilities Module - -This module provides helper functions for the LoCoMo benchmark evaluation: -- Data loading from JSON files -- Conversation extraction for ingestion -- Temporal reference resolution -- Context selection and formatting -- Retrieval wrapper functions -- Ingestion wrapper functions -""" - -import os -import json -import re -from datetime import datetime, timedelta -from typing import List, Dict, Any, Optional - -from app.core.memory.utils.definitions import PROJECT_ROOT -from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline - - -def load_locomo_data( - data_path: str, - sample_size: int, - conversation_index: int = 0 -) -> List[Dict[str, Any]]: - """ - Load LoCoMo dataset from JSON file. - - The LoCoMo dataset structure is a list of conversation objects, where each - object contains a "qa" list of question-answer pairs. - - Args: - data_path: Path to locomo10.json file - sample_size: Number of QA pairs to load (limits total QA items returned) - conversation_index: Which conversation to load QA pairs from (default: 0 for first) - - Returns: - List of QA item dictionaries, each containing: - - question: str - - answer: str - - category: int (1-4) - - evidence: List[str] - - Raises: - FileNotFoundError: If data_path does not exist - json.JSONDecodeError: If file is not valid JSON - IndexError: If conversation_index is out of range - """ - if not os.path.exists(data_path): - raise FileNotFoundError(f"LoCoMo data file not found: {data_path}") - - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - # LoCoMo data structure: list of objects, each with a "qa" list - qa_items: List[Dict[str, Any]] = [] - - if isinstance(raw, list): - # Only load QA pairs from the specified conversation - if conversation_index < len(raw): - entry = raw[conversation_index] - if isinstance(entry, dict) and "qa" in entry: - qa_items.extend(entry.get("qa", [])) - else: - raise IndexError( - f"Conversation index {conversation_index} out of range. " - f"Dataset has {len(raw)} conversations." - ) - else: - # Fallback: single object with qa list - if conversation_index == 0: - qa_items.extend(raw.get("qa", [])) - else: - raise IndexError( - f"Conversation index {conversation_index} out of range. " - f"Dataset has only 1 conversation." - ) - - # Return only the requested sample size - return qa_items[:sample_size] - - -def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]: - """ - Extract conversation texts from LoCoMo data for ingestion. - - This function extracts the raw conversation dialogues from the LoCoMo dataset - so they can be ingested into the memory system. Each conversation is formatted - as a multi-line string with "role: message" format. - - Args: - data_path: Path to locomo10.json file - max_dialogues: Maximum number of dialogues to extract (default: 1) - - Returns: - List of conversation strings formatted for ingestion. - Each string contains multiple lines in format "role: message" - - Example output: - [ - "User: I went to the store yesterday.\\nAI: What did you buy?\\n...", - "User: I love hiking.\\nAI: Where do you like to hike?\\n..." - ] - """ - if not os.path.exists(data_path): - raise FileNotFoundError(f"LoCoMo data file not found: {data_path}") - - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - - # Ensure we have a list of entries - entries = raw if isinstance(raw, list) else [raw] - - contents: List[str] = [] - - for i, entry in enumerate(entries[:max_dialogues]): - if not isinstance(entry, dict): - continue - - conv = entry.get("conversation", {}) - - if not isinstance(conv, dict): - continue - - lines: List[str] = [] - - # Collect all session_* messages - for key, val in sorted(conv.items()): - if isinstance(val, list) and key.startswith("session_"): - for msg in val: - if not isinstance(msg, dict): - continue - - role = msg.get("speaker") or "User" - text = msg.get("text") or "" - text = str(text).strip() - - if not text: - continue - - lines.append(f"{role}: {text}") - - if lines: - contents.append("\n".join(lines)) - - return contents - - -def resolve_temporal_references(text: str, anchor_date: datetime) -> str: - """ - Resolve relative temporal references to absolute dates. - - This function converts relative time expressions (like "today", "yesterday", - "3 days ago") into absolute ISO date strings based on an anchor date. - - Supported patterns: - - today, yesterday, tomorrow - - X days ago, in X days - - last week, next week - - Args: - text: Text containing temporal references - anchor_date: Reference date for resolution (datetime object) - - Returns: - Text with temporal references replaced by ISO dates (YYYY-MM-DD format) - - Example: - >>> anchor = datetime(2023, 5, 8) - >>> resolve_temporal_references("I saw him yesterday", anchor) - "I saw him 2023-05-07" - """ - # Ensure input is a string - t = str(text) if text is not None else "" - - # today / yesterday / tomorrow - t = re.sub( - r"\btoday\b", - anchor_date.date().isoformat(), - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\byesterday\b", - (anchor_date - timedelta(days=1)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\btomorrow\b", - (anchor_date + timedelta(days=1)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - - # X days ago - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor_date - timedelta(days=n)).date().isoformat() - - # in X days - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor_date + timedelta(days=n)).date().isoformat() - - t = re.sub( - r"\b(\d+)\s+days?\s+ago\b", - _ago_repl, - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\bin\s+(\d+)\s+days?\b", - _in_repl, - t, - flags=re.IGNORECASE - ) - - # last week / next week (approximate as 7 days) - t = re.sub( - r"\blast\s+week\b", - (anchor_date - timedelta(days=7)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - t = re.sub( - r"\bnext\s+week\b", - (anchor_date + timedelta(days=7)).date().isoformat(), - t, - flags=re.IGNORECASE - ) - - return t - - -def select_and_format_information( - retrieved_info: List[str], - question: str, - max_chars: int = 8000 -) -> str: - """ - Intelligently select and format most relevant retrieved information for LLM prompt. - - This function scores each piece of retrieved information based on keyword matching - with the question, then selects the highest-scoring pieces up to the character limit. - - Scoring criteria: - - Keyword matches (higher weight for multiple occurrences) - - Context length (moderate length preferred) - - Position (earlier contexts get bonus points) - - Args: - retrieved_info: List of retrieved information strings (chunks, statements, entities) - question: Question being answered - max_chars: Maximum total characters to include in final prompt - - Returns: - Formatted string combining the most relevant information for LLM prompt. - Contexts are separated by double newlines. - - Example: - >>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"] - >>> question = "Where did Alice go?" - >>> select_and_format_information(contexts, question, max_chars=100) - "Alice went to Paris\\n\\nAlice visited the Eiffel Tower" - """ - if not retrieved_info: - return "" - - # Extract question keywords (filter out stop words and short words) - question_lower = question.lower() - stop_words = { - 'what', 'when', 'where', 'who', 'why', 'how', - 'did', 'do', 'does', 'is', 'are', 'was', 'were', - 'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at' - } - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = { - word for word in question_words - if word not in stop_words and len(word) > 2 - } - - # Score each context - scored_contexts = [] - for i, context in enumerate(retrieved_info): - context_lower = context.lower() - score = 0 - - # Keyword matching score - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # Multiple occurrences increase score - score += context_lower.count(word) * 2 - - # Length score (prefer moderate length) - context_len = len(context) - if 100 < context_len < 2000: - score += 5 - elif context_len >= 2000: - score += 2 - - # Position bonus (earlier contexts often more relevant) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # Sort by score (descending) - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # Select contexts up to character limit - selected = [] - total_chars = 0 - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - else: - # Try to include high-scoring context by truncating - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # Find lines with keywords - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines and len('\n'.join(relevant_lines)) > 100: - truncated = '\n'.join(relevant_lines) - selected.append(truncated + "\n[Content truncated...]") - total_chars += len(truncated) - break - - return "\n\n".join(selected) - - -async def retrieve_relevant_information( - question: str, - end_user_id: str, - search_type: str, - search_limit: int, - connector: Any, - embedder: Any -) -> List[str]: - """ - Retrieve relevant information from memory graph for a question. - - This function searches the Neo4j memory graph (populated during ingestion) and - returns relevant chunks, statements, and entity information that might help - answer the question. - - The function supports three search types: - - "keyword": Full-text search using Cypher queries - - "embedding": Vector similarity search using embeddings - - "hybrid": Combination of keyword and embedding search with reranking - - Args: - question: Question to search for - end_user_id: Database group ID (identifies which conversation memory to search) - search_type: "keyword", "embedding", or "hybrid" - search_limit: Max memory pieces to retrieve - connector: Neo4j connector instance - embedder: Embedder client instance - - Returns: - List of text strings (chunks, statements, entity summaries) from memory graph. - Each string represents a piece of retrieved information. - - Raises: - Exception: If search fails (caught and returns empty list) - """ - from app.repositories.neo4j.graph_search import ( - search_graph, - search_graph_by_embedding - ) - from app.core.memory.storage_services.search import run_hybrid_search - - contexts_all: List[str] = [] - - try: - if search_type == "embedding": - # Embedding-based search - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Build context from chunks - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - # Add statements - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - # Add summaries - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # Add top entities (limit to 3 to avoid noise) - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = ( - sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] - if scored else entities[:3] - ) - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append( - f"EntitySummary: {name}" - f"{(' [' + '; '.join(meta) + ']') if meta else ''}" - ) - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - # Keyword-based search - search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit - ) - - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - # Build context from dialogues - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - - # Add statements - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - # Add entity names - if entities: - entity_names = [ - str(e.get("name", "")).strip() - for e in entities[:5] - if e.get("name") - ] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid - # Hybrid search with fallback to embedding - try: - search_results = await run_hybrid_search( - query_text=question, - search_type=search_type, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - output_path=None, - ) - - # Handle flat structure (new API format) - if search_results and isinstance(search_results, dict): - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Check if we got results - if not (chunks or statements or entities or summaries): - # Try nested structure (backward compatibility) - reranked = search_results.get("reranked_results", {}) - if reranked and isinstance(reranked, dict): - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - else: - raise ValueError("Hybrid search returned empty results") - else: - raise ValueError("Hybrid search returned empty results") - - except Exception as e: - # Fallback to embedding search - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # Build context (same for both hybrid and fallback) - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # Add top entities - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = ( - sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] - if scored else entities[:3] - ) - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append( - f"EntitySummary: {name}" - f"{(' [' + '; '.join(meta) + ']') if meta else ''}" - ) - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - except Exception as e: - # Return empty list on error - contexts_all = [] - - return contexts_all - - -async def ingest_conversations_if_needed( - conversations: List[str], - end_user_id: str, - reset: bool = False -) -> bool: - """ - Wrapper for conversation ingestion using external extraction pipeline. - - This function populates the Neo4j database with processed conversation data - (chunks, statements, entities) so that the retrieval system has memory to search. - - The ingestion process: - 1. Parses conversation text into dialogue messages - 2. Chunks the dialogues into semantic units - 3. Extracts statements and entities using LLM - 4. Generates embeddings for all content - 5. Stores everything in Neo4j graph database - - Args: - conversations: List of raw conversation texts from LoCoMo dataset - Example: ["User: I went to Paris. AI: When was that?", ...] - end_user_id: Target group ID for database storage - reset: Whether to clear existing data first (not implemented in wrapper) - - Returns: - True if successful, False otherwise - - Note: - The external function uses "contexts" to mean "conversation texts". - This runs the full extraction pipeline: chunking → entity extraction → - statement extraction → embedding → Neo4j storage. - """ - try: - success = await ingest_contexts_via_full_pipeline( - contexts=conversations, - end_user_id=end_user_id, - save_chunk_output=True - ) - return success - except Exception as e: - print(f"[Ingestion] Failed to ingest conversations: {e}") - return False diff --git a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py b/api/app/core/memory/evaluation/locomo/qwen_search_eval.py deleted file mode 100644 index 6a5caa0c..00000000 --- a/api/app/core/memory/evaluation/locomo/qwen_search_eval.py +++ /dev/null @@ -1,878 +0,0 @@ -import argparse -import asyncio -import json -import os -import statistics -import time -from datetime import datetime, timedelta -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -import re - -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - bleu1, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.storage_services.search import run_hybrid_search -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - - -# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) -def _loc_normalize(text: str) -> str: - import re - # 确保输入是字符串 - text = str(text) if text is not None else "" - text = text.lower() - text = re.sub(r"[\,]", " ", text) # 去掉逗号 - text = re.sub(r"\b(a|an|the|and)\b", " ", text) - text = re.sub(r"[^\w\s]", " ", text) - text = " ".join(text.split()) - return text - -# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week) -def _resolve_relative_times(text: str, anchor: datetime) -> str: - import re - # 确保输入是字符串 - t = str(text) if text is not None else "" - # today / yesterday / tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - # X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - # last week / next week(以7天近似) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - return t - -def loc_f1_score(prediction: str, ground_truth: str) -> float: - # 单答案 F1:按词集合计算(近似原始实现,去除词干依赖) - # 确保输入是字符串 - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - p_tokens = _loc_normalize(pred_str).split() - g_tokens = _loc_normalize(truth_str).split() - if not p_tokens or not g_tokens: - return 0.0 - p = set(p_tokens) - g = set(g_tokens) - tp = len(p & g) - precision = tp / len(p) if p else 0.0 - recall = tp / len(g) if g else 0.0 - return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 - -def loc_multi_f1(prediction: str, ground_truth: str) -> float: - # 多答案 F1:prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均 - # 确保输入是字符串 - pred_str = str(prediction) if prediction is not None else "" - truth_str = str(ground_truth) if ground_truth is not None else "" - - predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()] - ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()] - if not predictions or not ground_truths: - return 0.0 - def _f1(a: str, b: str) -> float: - return loc_f1_score(a, b) - vals = [] - for gt in ground_truths: - vals.append(max(_f1(pred, gt) for pred in predictions)) - return sum(vals) / len(vals) - -# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type -CATEGORY_MAP_NUM_TO_NAME = { - 4: "Single-Hop", - 1: "Multi-Hop", - 3: "Open Domain", - 2: "Temporal", -} - -_TYPE_ALIASES = { - "single-hop": "Single-Hop", - "singlehop": "Single-Hop", - "single hop": "Single-Hop", - "multi-hop": "Multi-Hop", - "multihop": "Multi-Hop", - "multi hop": "Multi-Hop", - "open domain": "Open Domain", - "opendomain": "Open Domain", - "temporal": "Temporal", -} - -def get_category_label(item: Dict[str, Any]) -> str: - # 1) 直接用字符串 cat - cat = item.get("cat") - if isinstance(cat, str) and cat.strip(): - name = cat.strip() - lower = name.lower() - return _TYPE_ALIASES.get(lower, name) - # 2) 数字 category 转名称 - cat_num = item.get("category") - if isinstance(cat_num, int): - return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown") - # 3) 备用 type 字段 - t = item.get("type") - if isinstance(t, str) and t.strip(): - lower = t.strip().lower() - return _TYPE_ALIASES.get(lower, t.strip()) - return "unknown" - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str: - """基于问题关键词智能选择上下文""" - if not contexts: - return "" - - # 提取问题关键词(只保留有意义的词) - question_lower = question.lower() - stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'} - question_words = set(re.findall(r'\b\w+\b', question_lower)) - question_words = {word for word in question_words if word not in stop_words and len(word) > 2} - - print(f"🔍 问题关键词: {question_words}") - - # 给每个上下文打分 - scored_contexts = [] - for i, context in enumerate(contexts): - context_lower = context.lower() - score = 0 - - # 关键词匹配得分 - keyword_matches = 0 - for word in question_words: - if word in context_lower: - keyword_matches += 1 - # 关键词出现次数越多,得分越高 - score += context_lower.count(word) * 2 - - # 上下文长度得分(适中的长度更好) - context_len = len(context) - if 100 < context_len < 2000: # 理想长度范围 - score += 5 - elif context_len >= 2000: # 太长可能包含无关信息 - score += 2 - - # 如果是前几个上下文,给予额外分数(通常相关性更高) - if i < 3: - score += 3 - - scored_contexts.append((score, context, keyword_matches)) - - # 按得分排序 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - - # 选择高得分的上下文,直到达到字符限制 - selected = [] - total_chars = 0 - selected_count = 0 - - print("📊 上下文相关性分析:") - for score, context, matches in scored_contexts[:5]: # 只显示前5个 - print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}") - - for score, context, matches in scored_contexts: - if total_chars + len(context) <= max_chars: - selected.append(context) - total_chars += len(context) - selected_count += 1 - else: - # 如果这个上下文得分很高但放不下,尝试截取 - if score > 10 and total_chars < max_chars - 500: - remaining = max_chars - total_chars - # 找到包含关键词的部分 - lines = context.split('\n') - relevant_lines = [] - current_chars = 0 - - for line in lines: - line_lower = line.lower() - line_relevance = any(word in line_lower for word in question_words) - - if line_relevance and current_chars < remaining - 100: - relevant_lines.append(line) - current_chars += len(line) - - if relevant_lines: - truncated = '\n'.join(relevant_lines) - if len(truncated) > 100: # 确保有足够内容 - selected.append(truncated + "\n[相关内容截断...]") - total_chars += len(truncated) - selected_count += 1 - break # 不再尝试添加更多上下文 - - result = "\n\n".join(selected) - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符") - return result - - -def get_search_params_by_category(category: str): - """根据问题类别调整检索参数""" - params_map = { - "Multi-Hop": {"limit": 20, "max_chars": 15000}, - "Temporal": {"limit": 16, "max_chars": 10000}, - "Open Domain": {"limit": 24, "max_chars": 18000}, - "Single-Hop": {"limit": 12, "max_chars": 8000}, - } - return params_map.get(category, {"limit": 16, "max_chars": 12000}) - - -async def run_locomo_eval( - sample_size: int = 1, - end_user_id: str | None = None, - search_limit: int = 8, - context_char_budget: int = 4000, # 保持默认值不变 - llm_temperature: float = 0.0, - llm_max_tokens: int = 32, - search_type: str = "hybrid", # 保持默认值不变 - output_path: str | None = None, - skip_ingest_if_exists: bool = True, - llm_timeout: float = 10.0, - llm_max_retries: int = 1 -) -> Dict[str, Any]: - - # 函数内部使用三路检索逻辑,但保持参数签名不变 - end_user_id = end_user_id or SELECTED_end_user_id - data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") - if not os.path.exists(data_path): - data_path = os.path.join(os.getcwd(), "data", "locomo10.json") - with open(data_path, "r", encoding="utf-8") as f: - raw = json.load(f) - # LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表 - qa_items: List[Dict[str, Any]] = [] - if isinstance(raw, list): - for entry in raw: - qa_items.extend(entry.get("qa", [])) - else: - qa_items.extend(raw.get("qa", [])) - items: List[Dict[str, Any]] = qa_items[:sample_size] - - # === 保持原来的数据摄入逻辑 === - entries = raw if isinstance(raw, list) else [raw] - - # 只摄入前1条对话(保持原样) - max_dialogues_to_ingest = 1 - contents: List[str] = [] - print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条") - - for i, entry in enumerate(entries[:max_dialogues_to_ingest]): - if not isinstance(entry, dict): - continue - - conv = entry.get("conversation", {}) - sample_id = entry.get("sample_id", f"unknown_{i}") - - print(f"🔍 处理对话 {i+1}: {sample_id}") - - lines: List[str] = [] - if isinstance(conv, dict): - # 收集所有 session_* 的消息 - session_count = 0 - for key, val in conv.items(): - if isinstance(val, list) and key.startswith("session_"): - session_count += 1 - for msg in val: - role = msg.get("speaker") or "用户" - text = msg.get("text") or "" - text = str(text).strip() - if not text: - continue - lines.append(f"{role}: {text}") - - print(f" - 包含 {session_count} 个session, {len(lines)} 条消息") - - if not lines: - print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入") - continue - - contents.append("\n".join(lines)) - - print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容") - - # 选择要评测的QA对(从所有对话中选取) - indexed_items: List[tuple[int, Dict[str, Any]]] = [] - if isinstance(raw, list): - for e_idx, entry in enumerate(raw): - for qa in entry.get("qa", []): - indexed_items.append((e_idx, qa)) - else: - for qa in raw.get("qa", []): - indexed_items.append((0, qa)) - - # 这里使用sample_size来限制评测的QA数量 - selected = indexed_items[:sample_size] - items: List[Dict[str, Any]] = [qa for _, qa in selected] - - print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话") - # === 修改结束 === - - connector = Neo4jConnector() - - # 关键修复:强制重新摄入纯净的对话数据 - print("🔄 强制重新摄入纯净的对话数据...") - await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) - - # 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) - # 初始化embedder用于直接调用 - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # connector initialized above - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - # 上下文诊断收集 - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - per_query_context_tokens_total: List[int] = [] - # 详细样本调试信息 - samples: List[Dict[str, Any]] = [] - # 通用指标 - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - # 参考 LoCoMo 评测的类别专用 F1(multi-hop 使用多答案 F1) - loc_f1s: List[float] = [] - # Per-category aggregation - cat_counts: Dict[str, int] = {} - cat_f1s: Dict[str, List[float]] = {} - cat_b1s: Dict[str, List[float]] = {} - cat_jss: Dict[str, List[float]] = {} - cat_loc_f1s: Dict[str, List[float]] = {} - try: - for item in items: - q = item.get("question", "") - ref = item.get("answer", "") - # 确保答案是字符串 - ref_str = str(ref) if ref is not None else "" - cat = get_category_label(item) - - print(f"\n=== 处理问题: {q} ===") - - # 根据类别调整检索参数 - search_params = get_search_params_by_category(cat) - adjusted_limit = search_params["limit"] - max_chars = search_params["max_chars"] - - print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}") - - # 改进的检索逻辑:使用三路检索(statements, dialogues, entities) - t0 = time.time() - contexts_all: List[str] = [] - search_results = None # 保存完整的检索结果 - - try: - if search_type == "embedding": - # 直接调用嵌入检索,包含三路数据 - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=q, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型 - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") - - # 构建上下文:优先使用 chunks、statements 和 summaries - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体,避免噪声 - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - # 直接调用关键词检索 - search_results = await search_graph( - connector=connector, - q=q, - end_user_id=end_user_id, - limit=adjusted_limit - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体") - - # 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体处理(关键词检索的实体可能没有分数) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid - # 🎯 关键修复:混合检索使用更严格的回退机制 - print("🔀 使用混合检索(带回退机制)...") - try: - search_results = await run_hybrid_search( - query_text=q, - search_type=search_type, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", "summaries"], - output_path=None, - ) - - # 🎯 关键修复:正确处理混合检索的扁平结构 - # 新的API返回扁平结构,直接从顶层获取结果 - if search_results and isinstance(search_results, dict): - # 新API返回扁平结构:直接从顶层获取 - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - # 检查是否有有效结果 - if chunks or statements or entities or summaries: - print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要") - else: - # 如果顶层没有结果,尝试旧的嵌套结构(向后兼容) - reranked = search_results.get("reranked_results", {}) - if reranked and isinstance(reranked, dict): - chunks = reranked.get("chunks", []) - statements = reranked.get("statements", []) - entities = reranked.get("entities", []) - summaries = reranked.get("summaries", []) - print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述") - else: - raise ValueError("混合检索返回空结果") - else: - raise ValueError("混合检索返回空结果") - - except Exception as e: - print(f"❌ 混合检索失败: {e},回退到嵌入检索") - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=q, - end_user_id=end_user_id, - limit=adjusted_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述") - - # 🎯 统一处理:构建上下文(所有检索类型共用) - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - - # 实体摘要:最多加入前3个高分实体 - if entities: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - # 关键修复:过滤掉包含当前问题答案的上下文 - filtered_contexts = [] - for context in contexts_all: - content = str(context) - # 排除包含当前问题标准答案的上下文 - if ref_str and ref_str.strip() and ref_str.strip() in content: - print("🚫 过滤掉包含标准答案的上下文") - continue - filtered_contexts.append(context) - - print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)") - contexts_all = filtered_contexts - - # 输出完整的检索结果信息 - print("🔍 检索结果详情:") - if search_results: - output_data = { - "statements": [ - { - "statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""), - "score": s.get("score", 0.0) - } - for s in (statements[:2] if 'statements' in locals() else []) - ], - "dialogues": [ - { - "uuid": d.get("uuid", ""), - "end_user_id": d.get("end_user_id", ""), - "content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""), - "score": d.get("score", 0.0) - } - for d in (dialogs[:2] if 'dialogs' in locals() else []) - ], - "entities": [ - { - "name": e.get("name", ""), - "entity_type": e.get("entity_type", ""), - "score": e.get("score", 0.0) - } - for e in (entities[:2] if 'entities' in locals() else []) - ] - } - print(json.dumps(output_data, ensure_ascii=False, indent=2)) - else: - print(" 无检索结果") - - except Exception as e: - print(f"❌ {search_type}检索失败: {e}") - contexts_all = [] - search_results = None - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 使用智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, q, max_chars=max_chars) - - # 如果智能选择后仍然过长,进行最终保护性截断 - if len(context_text) > max_chars: - print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断") - context_text = context_text[:max_chars] + "\n\n[最终截断...]" - - # 时间解析 - anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性 - context_text = _resolve_relative_times(context_text, anchor_date) - - context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text - - print(f"📝 最终上下文长度: {len(context_text)} 字符") - - # 显示不同上下文的预览 - print("🔍 上下文预览:") - for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文 - preview = context[:150].replace('\n', ' ') - print(f" 上下文{j+1}: {preview}...") - - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - per_query_context_tokens_total.append(len(_loc_normalize(context_text).split())) - - # LLM 提示词 - messages = [ - {"role": "system", "content": ( - "You are a precise QA assistant. Answer following these rules:\n" - "1) Extract the EXACT information mentioned in the context\n" - "2) For time questions: calculate actual dates from relative times\n" - "3) Return ONLY the answer text in simplest form\n" - "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n" - "5) If no clear answer found, respond with 'Unknown'" - )}, - {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 计算指标(确保使用字符串) - f1_val = common_f1(str(pred), ref_str) - b1_val = bleu1(str(pred), ref_str) - j_val = jaccard(str(pred), ref_str) - - f1s.append(f1_val) - b1s.append(b1_val) - jss.append(j_val) - - # Accumulate by category - cat_counts[cat] = cat_counts.get(cat, 0) + 1 - cat_f1s.setdefault(cat, []).append(f1_val) - cat_b1s.setdefault(cat, []).append(b1_val) - cat_jss.setdefault(cat, []).append(j_val) - - # LoCoMo 专用 F1:multi-hop(1) 使用多答案 F1,其它(2/3/4)使用单答案 F1 - if item.get("category") in [2, 3, 4]: - loc_val = loc_f1_score(str(pred), ref_str) - elif item.get("category") in [1]: - loc_val = loc_multi_f1(str(pred), ref_str) - else: - loc_val = loc_f1_score(str(pred), ref_str) - loc_f1s.append(loc_val) - cat_loc_f1s.setdefault(cat, []).append(loc_val) - - # 保存完整的检索结果信息 - samples.append({ - "question": q, - "answer": ref_str, - "category": cat, - "prediction": pred, - "metrics": { - "f1": f1_val, - "b1": b1_val, - "j": j_val, - "loc_f1": loc_val - }, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": adjusted_limit, - "max_chars": max_chars - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {ref_str}") - print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}") - - # Compute per-category averages and dispersion (std, iqr) - def _percentile(sorted_vals: List[float], p: float) -> float: - if not sorted_vals: - return 0.0 - if len(sorted_vals) == 1: - return sorted_vals[0] - k = (len(sorted_vals) - 1) * p - f = int(k) - c = f + 1 if f + 1 < len(sorted_vals) else f - if f == c: - return sorted_vals[f] - return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f) - - by_category: Dict[str, Dict[str, float | int]] = {} - for c in cat_counts: - f_list = cat_f1s.get(c, []) - b_list = cat_b1s.get(c, []) - j_list = cat_jss.get(c, []) - lf_list = cat_loc_f1s.get(c, []) - j_sorted = sorted(j_list) - j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0 - j_q75 = _percentile(j_sorted, 0.75) - j_q25 = _percentile(j_sorted, 0.25) - by_category[c] = { - "count": cat_counts[c], - "f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0, - "b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0, - "j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0, - "j_std": j_std, - "j_iqr": (j_q75 - j_q25) if j_list else 0.0, - # 参考 LoCoMo 评测的类别专用 F1 - "loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0, - } - - # 累加命中(cum accuracy by category):与 evaluation_stats.py 输出形式相仿 - cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts} - - result = { - "dataset": "locomo", - "items": len(items), - "metrics": { - "f1": sum(f1s) / max(len(f1s), 1), - "b1": sum(b1s) / max(len(b1s), 1), - "j": sum(jss) / max(len(jss), 1), - # LoCoMo 类别专用 F1 的总体 - "loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1), - }, - "by_category": by_category, - "category_counts": cat_counts, - "cum_accuracy_by_category": cum_accuracy_by_category, - "context": { - "avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0, - "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0, - "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0, - "avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0, - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "samples": samples, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "retrieval_embedding_id": SELECTED_EMBEDDING_ID, - "skip_ingest_if_exists": skip_ingest_if_exists, - "llm_timeout": llm_timeout, - "llm_max_retries": llm_max_retries, - "llm_temperature": llm_temperature, - "llm_max_tokens": llm_max_tokens - }, - "timestamp": datetime.now().isoformat() - } - if output_path: - try: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"✅ 结果已保存到: {output_path}") - except Exception as e: - print(f"❌ 保存结果失败: {e}") - return result - finally: - await connector.close() - - -def main(): - parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search") - parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate") - parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval") - parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query") - parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context") - parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature") - parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens") - parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type") - parser.add_argument("--output_path", type=str, default=None, help="Output path for results") - parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists") - parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds") - parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries") - args = parser.parse_args() - - load_dotenv() - - result = asyncio.run(run_locomo_eval( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - output_path=args.output_path, - skip_ingest_if_exists=args.skip_ingest_if_exists, - llm_timeout=args.llm_timeout, - llm_max_retries=args.llm_max_retries - )) - - print("\n" + "="*50) - print("📊 最终评测结果:") - print(f" 样本数量: {result['items']}") - print(f" F1: {result['metrics']['f1']:.3f}") - print(f" BLEU-1: {result['metrics']['b1']:.3f}") - print(f" Jaccard: {result['metrics']['j']:.3f}") - print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}") - print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符") - print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms") - print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms") - - if result['by_category']: - print("\n📈 按类别细分:") - for cat, metrics in result['by_category'].items(): - print(f" {cat}:") - print(f" 样本数: {metrics['count']}") - print(f" F1: {metrics['f1']:.3f}") - print(f" LoCoMo F1: {metrics['loc_f1']:.3f}") - print(f" Jaccard: {metrics['j']:.3f} (±{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py b/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py deleted file mode 100644 index 8710a504..00000000 --- a/api/app/core/memory/evaluation/longmemeval/qwen_search_eval.py +++ /dev/null @@ -1,1364 +0,0 @@ -import argparse -import asyncio -import json -import os -import re -import statistics -import time -from datetime import datetime, timedelta -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -# 确保可以找到 src 及项目根路径 -import sys -from pathlib import Path - -_THIS_DIR = Path(__file__).resolve().parent -_PROJECT_ROOT = str(_THIS_DIR.parents[2]) -_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") -for _p in (_SRC_DIR, _PROJECT_ROOT): - if _p not in sys.path: - sys.path.insert(0, _p) - -# 与现有评估脚本保持一致的导入方式 -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - -try: - # 优先从 extraction_utils1 导入 - from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, # type: ignore - ) -except Exception: - ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.services.memory_config_service import MemoryConfigService - -try: - from app.core.memory.evaluation.common.metrics import exact_match -except Exception: - # 兜底:简单的大小写不敏感比较 - def exact_match(pred: str, ref: str) -> bool: - return str(pred).strip().lower() == str(ref).strip().lower() - - -def load_dataset_any(path: str) -> List[Dict[str, Any]]: - """健壮地加载数据集(兼容 list 或多段 JSON)。""" - with open(path, "r", encoding="utf-8") as f: - s = f.read().strip() - try: - obj = json.loads(s) - if isinstance(obj, list): - return obj - elif isinstance(obj, dict): - return [obj] - except json.JSONDecodeError: - pass - dec = json.JSONDecoder() - idx = 0 - items: List[Dict[str, Any]] = [] - while idx < len(s): - while idx < len(s) and s[idx].isspace(): - idx += 1 - if idx >= len(s): - break - try: - obj, end = dec.raw_decode(s, idx) - if isinstance(obj, list): - for it in obj: - if isinstance(it, dict): - items.append(it) - elif isinstance(obj, dict): - items.append(obj) - idx = end - except json.JSONDecodeError: - nl = s.find("\n", idx) - if nl == -1: - break - idx = nl + 1 - return items - - -def is_chinese_text(s: str) -> bool: - return bool(re.search(r"[\u4e00-\u9fff]", s or "")) - - -def build_context_from_sessions(item: Dict[str, Any]) -> List[str]: - """从数据项的 haystack_sessions 构建上下文片段。 - - 优先返回包含 has_answer 的消息 - - 其次返回拼接后的整段会话 - """ - contexts: List[str] = [] - sessions = item.get("haystack_sessions", []) or item.get("sessions", []) - for session in sessions: - parts: List[str] = [] - if isinstance(session, list): - for msg in session: - role = msg.get("role", "") - content = msg.get("content", "") or msg.get("text", "") - if content: - parts.append(f"{role}: {content}" if role else str(content)) - if msg.get("has_answer", False): - contexts.append(f"{role}: {content}" if role else str(content)) - elif isinstance(session, dict): - role = session.get("role", "") - content = session.get("content", "") or session.get("text", "") - if content: - parts.append(f"{role}: {content}" if role else str(content)) - if session.get("has_answer", False): - contexts.append(f"{role}: {content}" if role else str(content)) - if parts: - contexts.append("\n".join(parts)) - # 兜底:存在单字段上下文 - if not contexts: - single_ctx = item.get("context") or item.get("dialogue") or item.get("conversation") - if isinstance(single_ctx, str) and single_ctx.strip(): - contexts.append(single_ctx.strip()) - return contexts - - -def extract_candidate_options(question: str) -> List[str]: - """从问题中提取候选选项(A-or-B 类问题)。""" - q = (question or "").strip() - options: List[str] = [] - - # 1) 引号包裹的片段 - for pat in [r"'([^']+)'", r'\"([^\"]+)\"', r'“([^”]+)”', r'‘([^’]+)’']: - for m in re.findall(pat, q): - val = (m or "").strip() - if val: - options.append(val) - - # 2) or/还是/或者 连接词 - if len(options) < 2: - pats = [ - r"([^,;,;]+?)\s+or\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+还是\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+或者\s+([^,;,;\?\.!.。!]+)", - ] - for pat in pats: - matches = list(re.finditer(pat, q, flags=re.IGNORECASE)) - if matches: - m = matches[-1] - cand1 = m.group(1).strip().strip("??.,,;; ") - cand2 = m.group(2).strip().strip("??.,,;; ") - options.extend([cand1, cand2]) - break - - # 去重 - seen = set() - uniq: List[str] = [] - for o in options: - o2 = o.strip() - key = o2.lower() if not is_chinese_text(o2) else o2 - if o2 and key not in seen: - uniq.append(o2) - seen.add(key) - return uniq - - -def extract_time_entities(text: str) -> List[Dict[str, Any]]: - """增强时间实体提取,专门用于时间推理问题""" - time_entities = [] - - # 日期模式 - date_patterns = [ - (r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', 'date'), # YYYY-MM-DD - (r'\b(\d{1,2})月(\d{1,2})日\b', 'date'), # 中文日期 - (r'\b(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份 - (r'\b(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份缩写 - ] - - # 时间间隔模式 - duration_patterns = [ - (r'(\d+)\s*天', 'days'), - (r'(\d+)\s*周', 'weeks'), - (r'(\d+)\s*个月', 'months'), - (r'(\d+)\s*年', 'years'), - (r'(\d+)\s*days?', 'days'), - (r'(\d+)\s*weeks?', 'weeks'), - (r'(\d+)\s*months?', 'months'), - (r'(\d+)\s*years?', 'years'), - ] - - # 事件时间关系模式 - temporal_relation_patterns = [ - (r'(之前|以前|前)\s*(\d+)\s*天', 'days_before'), - (r'(之后|以后|后)\s*(\d+)\s*天', 'days_after'), - (r'(\d+)\s*天\s*(之前|以前|前)', 'days_before'), - (r'(\d+)\s*天\s*(之后|以后|后)', 'days_after'), - (r'(\d+)\s*days?\s*(before|ago)', 'days_before'), - (r'(\d+)\s*days?\s*(after|later)', 'days_after'), - ] - - # 提取日期 - for pattern, entity_type in date_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间间隔 - for pattern, entity_type in duration_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间关系 - for pattern, entity_type in temporal_relation_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(2)) if match.groups() >= 2 else int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - return time_entities - - -def calculate_time_difference(date1: str, date2: str) -> int: - """计算两个日期之间的天数差""" - try: - # 解析日期格式 - def parse_date(date_str: str) -> datetime: - # 尝试多种日期格式 - formats = [ - '%Y-%m-%d', - '%m月%d日', - '%B %d, %Y', - '%b %d, %Y', - '%Y年%m月%d日' - ] - - for fmt in formats: - try: - return datetime.strptime(date_str, fmt) - except ValueError: - continue - - # 如果都无法解析,返回当前日期 - return datetime.now() - - d1 = parse_date(date1) - d2 = parse_date(date2) - - # 计算天数差(绝对值) - return abs((d2 - d1).days) - except Exception: - return -1 # 表示计算失败 - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """增强版上下文选择:特别优化时间推理问题的处理""" - if not contexts: - return "" - - # 检测是否为时间推理问题 - is_temporal_question = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 提取时间实体从问题中 - question_time_entities = extract_time_entities(question) - - # 英文关键词(去停用词) - question_lower = question.lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','many','which','first' - } - eng_words = [w for w in set(re.findall(r'\b\w+\b', question_lower)) - if w not in stop_words and len(w) > 2] - - # 中文片段与候选选项 - cn_tokens = generate_query_keywords_cn(question) - options = extract_candidate_options(question) - - # 时间推理问题的特殊处理 - if is_temporal_question: - # 为时间问题添加时间相关关键词 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'days', 'first', '先后'] - eng_words = [w for w in eng_words if w not in ['days', 'first']] # 避免重复 - cn_tokens.extend([kw for kw in time_keywords if kw not in cn_tokens]) - - # 限制关键词数量,优先时间相关 - tokens = time_keywords[:2] + cn_tokens[:2] + eng_words[:1] + options[:1] - else: - # 常规问题处理 - tokens = cn_tokens[:3] + options[:2] + eng_words[:1] - - # 去重 - seen = set() - final_tokens: List[str] = [] - for t in tokens: - t2 = t.strip() - if t2 and t2 not in seen: - final_tokens.append(t2) - seen.add(t2) - - scored_contexts: List[tuple[float, str]] = [] - - # 时间推理问题的权重映射 - temporal_weight_map = { - "天": 2.0, "日": 2.0, "月": 1.8, "年": 1.8, "days": 2.0, - "before": 1.5, "after": 1.5, "first": 1.5, "先后": 1.5 - } - - # 常规问题的权重映射 - normal_weight_map = { - "问题": 2.0, "故障": 2.0, "异常": 1.8, "不正常": 1.8, "坏了": 1.8, - "系统": 1.3, "GPS": 1.5, "保养": 1.4, "设备": 1.2, "模块": 1.2, "功能": 1.1 - } - - weight_map = temporal_weight_map if is_temporal_question else normal_weight_map - - for i, context in enumerate(contexts): - context_str = str(context) - lines = re.split(r'[\r\n]+', context_str) - hit_lines: List[str] = [] - kw_hits: float = 0.0 - time_entity_count = 0 - - for line in lines: - ln = line.strip() - if not ln: - continue - - has_keyword = False - # 关键词匹配 - for tok in final_tokens: - if tok and tok in ln: - w = weight_map.get(tok, 1.0) - kw_hits += ln.count(tok) * w - has_keyword = True - - # 时间实体检测(特别针对时间推理问题) - if is_temporal_question: - time_entities = extract_time_entities(ln) - time_entity_count += len(time_entities) - if time_entities: - has_keyword = True - - if has_keyword: - # 对于时间推理问题,保留包含时间信息的完整行 - hit_lines.append(ln) - - snippet = "\n".join(hit_lines) if hit_lines else context_str.strip() - - # 限制单段长度,但对时间推理问题稍微放宽限制 - max_snippet_len = 600 if is_temporal_question else 500 - if len(snippet) > max_snippet_len: - snippet = snippet[:max_snippet_len] - - # 评分逻辑 - has_number = 1 if re.search(r'\d', snippet) else 0 - has_date = 1 if (re.search(r'\b\d{4}-\d{1,2}-\d{1,2}\b', snippet) or - re.search(r'\d{1,2}月\d{1,2}日', snippet)) else 0 - - # 时间推理问题的特殊评分 - if is_temporal_question: - time_bonus = time_entity_count * 2.0 # 时间实体奖励 - temporal_coherence = 3 if (has_date and time_entity_count >= 2) else 0 - else: - time_bonus = 0 - temporal_coherence = 0 - - length_bonus = 5 if 50 < len(snippet) < 1000 else (2 if len(snippet) >= 1000 else 0) - pos_bonus = 3 if i < 3 else 0 - - score = (kw_hits * 0.8 + (has_number + has_date) * 1.5 + - length_bonus + pos_bonus + time_bonus + temporal_coherence) - - scored_contexts.append((score, snippet)) - - # 选择累计至总字符预算 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - selected: List[str] = [] - total_chars = 0 - - for score, snippet in scored_contexts: - if total_chars + len(snippet) <= max_chars: - selected.append(snippet) - total_chars += len(snippet) - else: - if not selected and len(snippet) > max_chars: - selected.append(snippet[:max_chars]) - break - - final_context = "\n\n".join(selected) - - # 对于时间推理问题,添加时间计算提示 - if is_temporal_question and question_time_entities: - time_prompt = "\n\n[时间推理提示:请仔细分析上述上下文中的日期和时间关系,计算时间间隔或确定事件顺序]" - if total_chars + len(time_prompt) <= max_chars: - final_context += time_prompt - - return final_context - - -# 中文关键词提取(短语级,含数词/日期/常见领域词) -def _extract_cn_tokens(text: str) -> List[str]: - if not text: - return [] - t = str(text) - # 去掉常见功能词(粗略,不依赖分词库) - stop_words = [ - "我","我们","你","他","她","它","这","那","哪","一个","一次","一些","什么","怎么","是否","吗","呢", - "很","更","最","已经","正在","将要","马上","尽快","最近","关于","有关","以及","并且","或者","还是", - "因为","所以","如果","但是","而且","然后","之后","之前","同时","另外","并","但","却","被","把","让","给", - "和","与","跟","及","还有","就","都","在","对","对于","的","了","着","过","到","于","从","以","为","向","至","是" - ] - for sw in stop_words: - t = t.replace(sw, " ") - # 去标点 - t = re.sub(r"[,。!?、;:,.!?;:\"'()()[]\[\]\-—…·]", " ", t) - # 基础中文片段(>=2) - base = re.findall(r"[\u4e00-\u9fff]{2,}", t) - # 特殊组合:第X次XXXX - specials = re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", text) - # 领域词(简单词典) - # 日期与数字 - dates = re.findall(r"\d{4}年\d{1,2}月\d{1,2}日|\d{1,2}月\d{1,2}日|\d{4}-\d{1,2}-\d{1,2}", text) - numbers = re.findall(r"\b\d+\b", text) - - tokens: List[str] = specials + base + dates + numbers - - generic = {"建议","推荐","帮助","提升","技能","有效","团队","参与度","喜欢","开始"} - tokens: List[str] = specials + base + dates + numbers - uniq: List[str] = [] - seen = set() - for tok in tokens: - tok2 = tok.strip() - if len(tok2) < 2 or len(tok2) > 6: - continue - if tok2 in generic: - continue - if tok2 not in seen: - uniq.append(tok2) - seen.add(tok2) - # 排除常见疑问型短语 - blacklist_exact = {"是什么","多少","多少天","哪个","哪些","之间","先","后","之前","之后"} - uniq2: List[str] = [u for u in uniq if u not in blacklist_exact] - return uniq2[:12] - - -# 面向检索的中文关键词生成:强调"短语、核心名词、问题/故障" -def generate_query_keywords_cn(question: str) -> List[str]: - if not question: - return [] - raw = _extract_cn_tokens(question) - core: List[str] = [] - seen = set() - - def push(x: str): - x2 = x.strip() - if not x2: - return - if 2 <= len(x2) <= 6 and x2 not in seen: - core.append(x2) - seen.add(x2) - - # 检测时间推理问题 - is_temporal = any(keyword in question for keyword in ['天', '日', 'before', 'after', 'first', '先后', '间隔']) - if is_temporal: - push("天") - push("日") - push("先后") - - # 明确优先的核心词 - if "新车" in question: - push("新车") - # 第X次保养/维修 - specials = re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", question) - for s in specials: - if "保养" in s or "维修" in s: - push(s) - if "保养" in question: - push("保养") - # 问题/故障类词,如题含"问题"则扩展同义词 - if "问题" in question: - for w in ["问题","故障","异常","不正常"]: - push(w) - - # 补充:从原始片段筛更短的名词短语(过滤疑问型词) - blacklist = {"是什么","多少","哪个","还是","或者","之间","先","后","之前","之后"} - for tok in raw: - if tok in blacklist: - continue - push(tok) - - # 限制数量,避免过长列表影响检索稳定性 - return core[:4] # 稍微增加限制 - - -# 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - try: - for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) - if rows: - results.extend(rows) - except Exception: - pass - - # 按 name 去重 - deduped: List[Dict[str, Any]] = [] - seen = set() - for r in results: - k = str(r.get("name", "")) - if k and k not in seen: - deduped.append(r) - seen.add(k) - return deduped - - -# 通过对话/陈述中的entity_ids反查实体名称 -_FETCH_ENTITIES_BY_IDS = """ -MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -""" - -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: - if not ids: - return [] - try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) - return rows or [] - except Exception: - return [] - - -# 增强的时间实体检索 -_TIME_ENTITY_SEARCH = """ -MATCH (e:ExtractedEntity) -WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -LIMIT $limit -""" - -async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: - """专门搜索时间相关的实体""" - try: - date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" - rows = await connector.execute_query(_TIME_ENTITY_SEARCH, - date_pattern=date_pattern, - end_user_id=end_user_id, - limit=limit) - return rows or [] - except Exception: - return [] - - -# 中英相对时间解析:today/昨天/上周/3天后 等简单归一化为日期 -def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - # 英文 today/yesterday/tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - # 英文 X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - - # 中文 今天/昨天/明天 - t = re.sub(r"今天", anchor.date().isoformat(), t) - t = re.sub(r"昨日|昨天", (anchor - timedelta(days=1)).date().isoformat(), t) - t = re.sub(r"明天", (anchor + timedelta(days=1)).date().isoformat(), t) - # 中文 X天前 / X天后 - t = re.sub(r"(\d+)天前", lambda m: (anchor - timedelta(days=int(m.group(1)))).date().isoformat(), t) - t = re.sub(r"(\d+)天后", lambda m: (anchor + timedelta(days=int(m.group(1)))).date().isoformat(), t) - # 中文 上周 / 下周(近似7天) - t = re.sub(r"上周", (anchor - timedelta(days=7)).date().isoformat(), t) - t = re.sub(r"下周", (anchor + timedelta(days=7)).date().isoformat(), t) - # 中文 月日(无年份)补全年份 - def _md_repl(m: re.Match[str]) -> str: - mon = int(m.group(1)); day = int(m.group(2)) - return f"{anchor.year}-{mon:02d}-{day:02d}" - t = re.sub(r"(\d{1,2})月(\d{1,2})日", _md_repl, t) - return t - - -async def run_longmemeval_test( - sample_size: int = 3, - end_user_id: str = "longmemeval_zh_bak_3", - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 16, - search_type: str = "hybrid", - data_path: str | None = None, - start_index: int = 0, - max_contexts_per_item: int = 2, - save_chunk_output: bool = True, - save_chunk_output_path: str | None = None, - reset_group_before_ingest: bool = False, - skip_ingest: bool = False, -) -> Dict[str, Any]: - """LongMemEval 评估测试:增强时间推理能力""" - - # 数据路径 - if not data_path: - # 固定使用中文数据集:data/longmemeval_oracle_zh.json - zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") - zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") - if os.path.exists(zh_proj): - data_path = zh_proj - elif os.path.exists(zh_cwd): - data_path = zh_cwd - else: - raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。") - - qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) - # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 - if sample_size is None or sample_size <= 0: - items = qa_list[start_index:] - else: - items = qa_list[start_index:start_index + sample_size] - - # 可选:摄入上下文(默认启用) - if not skip_ingest: - # 选择上下文并限量 - contexts: List[str] = [] - for it in items: - built = build_context_from_sessions(it) - full_transcripts = [c for c in built if "\n" in c] - evidence_msgs = [c for c in built if "\n" not in c] - selected: List[str] = [] - take_e = min(len(evidence_msgs), max_contexts_per_item) - selected.extend(evidence_msgs[:take_e]) - remain = max_contexts_per_item - len(selected) - if remain > 0 and full_transcripts: - selected.extend(full_transcripts[:remain]) - if not selected and built: - selected.append(built[0]) - contexts.extend(selected) - - print(f"📥 摄入 {len(contexts)} 个上下文到数据库") - if reset_group_before_ingest and end_user_id: - try: - _tmp_conn = Neo4jConnector() - await _tmp_conn.delete_group(end_user_id) - print(f"🧹 已清空组 {end_user_id} 的历史图数据") - except Exception as _e: - print(f"⚠️ 清空组数据失败(忽略继续): {end_user_id} - {_e}") - finally: - try: - await _tmp_conn.close() - except Exception: - pass - _ingest_fn = ingest_contexts_via_full_pipeline - if _ingest_fn is None: - print("⚠️ 摄入函数不可用,已跳过摄入。请确认 PYTHONPATH 包含 'src' 或从项目根运行。") - else: - await _ingest_fn( - contexts, - end_user_id, - save_chunk_output=save_chunk_output, - save_chunk_output_path=save_chunk_output_path, - ) - - # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) - connector = Neo4jConnector() - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 指标收集 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - - type_correct: Dict[str, List[float]] = {} - type_f1: Dict[str, List[float]] = {} - type_jacc: Dict[str, List[float]] = {} - - samples: List[Dict[str, Any]] = [] - # 统计重复的上下文预览(跨样本),便于诊断"相同上下文"问题 - preview_counter: Dict[str, int] = {} - - try: - for item in items: - question = item.get("question", "") - reference = item.get("answer", "") - qtype = item.get("question_type") or item.get("type", "unknown") - - print(f"\n=== 处理问题: {question} ===") - - # 检测问题类型 - is_temporal = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 检索 - t0 = time.time() - contexts_all: List[str] = [] - dialogs, statements, entities = [], [], [] - - try: - if search_type == "embedding": - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - - # for sm in summaries: - # summary_text = str(sm.get("summary", "")).strip() - # if summary_text: - # contexts_all.append(summary_text) - - # 实体摘要(最多3个) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - chunks = search_results.get("chunks", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - summaries = search_results.get("summaries", []) - - for c in chunks: - content = str(c.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - for sm in summaries: - summary_text = str(sm.get("summary", "")).strip() - if summary_text: - contexts_all.append(summary_text) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid(增强版:特别优化时间推理问题) - emb_chunks, emb_statements, emb_entities, emb_summaries, emb_dialogs = [], [], [], [], [] - kw_dialogs, kw_statements, kw_entities = [], [], [] - - # 1) 嵌入检索 - try: - emb_res = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], - ) - if isinstance(emb_res, dict): - emb_chunks = emb_res.get("chunks", []) or [] - emb_statements = emb_res.get("statements", []) or [] - emb_entities = emb_res.get("entities", []) or [] - emb_summaries = emb_res.get("summaries", []) or [] - emb_dialogs = emb_res.get("dialogues", []) or [] - except Exception as e: - print(f"⚠️ 嵌入检索失败,将继续进行关键词检索: {e}") - - # 2) 关键词检索(增强版) - try: - kw_res = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - if isinstance(kw_res, dict): - kw_dialogs = kw_res.get("dialogues", []) or [] - kw_statements = kw_res.get("statements", []) or [] - kw_entities = kw_res.get("entities", []) or [] - - # 时间推理问题的特殊处理 - if is_temporal: - # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) - if time_entities: - kw_entities.extend(time_entities) - # 添加时间相关关键词检索 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'first'] - for tk in time_keywords: - try: - time_res = await search_graph( - connector=connector, - q=tk, - end_user_id=end_user_id, - limit=2, - ) - if isinstance(time_res, dict): - kw_dialogs.extend(time_res.get("dialogues", []) or []) - kw_statements.extend(time_res.get("statements", []) or []) - except Exception: - pass - - # 中文关键词拆分后做别名匹配 - cn_tokens = _extract_cn_tokens(question) - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) - if alias_entities: - kw_entities.extend(alias_entities) - - # 从对话/陈述中的 entity_ids 反查实体 - ids = [] - try: - for d in kw_dialogs: - ids.extend(d.get("entity_ids", []) or []) - for s in kw_statements: - ids.extend(s.get("entity_ids", []) or []) - except Exception: - pass - if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) - if id_entities: - kw_entities.extend(id_entities) - - # 多关键词检索 - try: - eng_words = [w for w in set(re.findall(r"\b\w+\b", question.lower())) if len(w) > 2] - kw_list = generate_query_keywords_cn(question)[:3] + eng_words[:1] - for kw in kw_list: - if not kw: - continue - sub_res = await search_graph( - connector=connector, - q=str(kw), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(sub_res, dict): - kw_dialogs.extend(sub_res.get("dialogues", []) or []) - kw_statements.extend(sub_res.get("statements", []) or []) - kw_entities.extend(sub_res.get("entities", []) or []) - except Exception: - pass - - # 选项参与关键词检索 - try: - opt_list = extract_candidate_options(question)[:2] - for opt in opt_list: - if not opt: - continue - opt_res = await search_graph( - connector=connector, - q=str(opt), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(opt_res, dict): - kw_dialogs.extend(opt_res.get("dialogues", []) or []) - kw_statements.extend(opt_res.get("statements", []) or []) - kw_entities.extend(opt_res.get("entities", []) or []) - except Exception: - pass - except Exception as e: - print(f"❌ 关键词检索失败: {e}") - - # 3) 合并、排序并去重 - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - def dedup(items: List[Dict[str, Any]], key_field: str = "uuid") -> List[Dict[str, Any]]: - seen = set() - out = [] - for it in items: - key = str(it.get(key_field, "")) + str(it.get("content", "") + str(it.get("statement", ""))) - if key not in seen: - out.append(it) - seen.add(key) - return out - - # 时间推理问题优先排序包含时间信息的文档 - if is_temporal: - def temporal_score(item: Dict[str, Any]) -> float: - base_score = float(item.get("score", 0.0)) - content = str(item.get("content", "") + str(item.get("statement", ""))) - time_entities = extract_time_entities(content) - time_bonus = len(time_entities) * 0.5 - return base_score + time_bonus - - dialogs = dedup(sorted(all_dialogs, key=temporal_score, reverse=True)) - statements = dedup(sorted(all_statements, key=temporal_score, reverse=True)) - else: - dialogs = dedup(sorted(all_dialogs, key=lambda d: float(d.get("score", 0.0)), reverse=True)) - statements = dedup(sorted(all_statements, key=lambda s: float(s.get("score", 0.0)), reverse=True)) - - entities = dedup(all_entities, key_field="name") - - # 4) 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要 - try: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - except Exception: - pass - - # 全局回退 - if not contexts_all and search_type in ("embedding", "hybrid"): - try: - print("🔁 检索为空,回退到关键词检索...") - kw_fallback = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=max(search_limit, 5), - ) - fb_dialogs = kw_fallback.get("dialogues", []) or [] - fb_statements = kw_fallback.get("statements", []) or [] - fb_entities = kw_fallback.get("entities", []) or [] - - for d in fb_dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in fb_statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if fb_entities: - entity_names = [str(e.get("name", "")).strip() for e in fb_entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - dialogs = fb_dialogs if fb_dialogs else dialogs - statements = fb_statements if fb_statements else statements - entities = fb_entities if fb_entities else entities - print(f"↩️ 回退到关键词检索: {len(fb_dialogs)} 对话, {len(fb_statements)} 条陈述, {len(fb_entities)} 个实体") - except Exception as fe: - print(f"❌ 关键词回退失败: {fe}") - - ent_count = len(entities) if isinstance(entities, list) else 0 - print(f"✅ {search_type}检索成功: {len(dialogs)} 对话, {len(statements)} 条陈述, {ent_count} 个实体") - if is_temporal: - print("⏰ 检测为时间推理问题,已启用时间优化检索") - - except Exception as e: - print(f"❌ {search_type}检索失败: {e}") - contexts_all = [] - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) - # 相对时间解析 - try: - context_text = _resolve_relative_times_cn_en(context_text, anchor=datetime.now()) - except Exception: - pass - # 诊断信息 - try: - cn_diag = generate_query_keywords_cn(question)[:3] - opts = extract_candidate_options(question)[:2] - qlw = [w for w in set(re.findall(r'\b\w+\b', question.lower())) if len(w) > 2][:1] - diag_tokens: List[str] = [] - for t in cn_diag + opts + qlw: - if t and t not in diag_tokens: - diag_tokens.append(t) - print(f"🔍 关键词/选项: {', '.join(diag_tokens)}") - preview = context_text[:200].replace('\n', ' ') - print(f"🔎 上下文预览: {preview}...") - key_preview = preview.strip() - if key_preview: - preview_counter[key_preview] = preview_counter.get(key_preview, 0) + 1 - except Exception: - pass - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - - # LLM 推理(增强时间推理提示) - options = extract_candidate_options(question) - if len(options) >= 2: - opt_lines = "\n".join(f"- {o}" for o in options) - # 时间推理问题的特殊提示 - if is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Pay special attention to date sequences and time intervals." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. Return ONLY one string: exactly one option from the provided candidates. " - "If the context is insufficient, respond with 'Unknown'. If the context expresses a synonym or paraphrase of a candidate, return the closest candidate. " - "Do not include explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": ( - f"Question: {question}\n\nCandidates:\n{opt_lines}\n\nContext:\n{context_text}\n\nReturn EXACTLY one candidate string (or 'Unknown')." - ), - }, - ] - else: - # 时间推理问题的特殊提示 - if is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "If the context contains the answer, return a concise answer phrase focusing on temporal information. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. If the context contains the answer, return a concise answer phrase. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}\n\nReturn ONLY the answer (or 'Unknown').", - }, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred_raw = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 选项题输出规范化 - pred = pred_raw - if len(options) >= 2 and not pred_raw.lower().startswith("unknown"): - def _basic_norm(s: str) -> str: - s = s.lower().strip() - return re.sub(r"[^\w\s]", " ", s) - def _jaccard(a: str, b: str) -> float: - ta = set(t for t in _basic_norm(a).split() if t) - tb = set(t for t in _basic_norm(b).split() if t) - if not ta and not tb: - return 1.0 - if not ta or not tb: - return 0.0 - return len(ta & tb) / len(ta | tb) - best = None - best_score = -1.0 - for o in options: - score = _jaccard(pred_raw, o) - if score > best_score: - best = o - best_score = score - if best is not None and best_score > 0.0: - pred = best - - # 指标 - flag = exact_match(pred, reference) - f1_val = common_f1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - - type_correct.setdefault(qtype, []).append(flag) - type_f1.setdefault(qtype, []).append(f1_val) - type_jacc.setdefault(qtype, []).append(j_val) - - samples.append({ - "question": question, - "prediction": pred, - "answer": reference, - "question_type": qtype, - "is_temporal": is_temporal, - "question_id": item.get("question_id"), - "options": options, - "context_count": len(contexts_all), - "context_chars": len(context_text), - "retrieved_dialogue_count": len(dialogs), - "retrieved_statement_count": len(statements), - "metrics": { - "exact_match": bool(flag), - "f1": f1_val, - "jaccard": j_val - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - Exact Match: {flag}, F1: {f1_val:.3f}, Jaccard: {j_val:.3f}") - - # 聚合结果 - type_acc = {t: (sum(v) / max(len(v), 1)) for t, v in type_correct.items()} - f1_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_f1.items()} - jacc_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_jacc.items()} - - result = { - "dataset": "longmemeval", - "items": len(items), - "accuracy_by_type": type_acc, - "f1_by_type": f1_by_type, - "jaccard_by_type": jacc_by_type, - "samples": samples, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "context": { - "avg_tokens": statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0, - "avg_chars": statistics.mean(per_query_context_chars) if per_query_context_chars else 0.0, - "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, - }, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID, - "sample_size": sample_size, - "start_index": start_index, - }, - "timestamp": datetime.now().isoformat() - } - - # 计算汇总指标 - try: - total_items = max(len(samples), 1) - correct_count = sum(1 for s in samples if s.get("metrics", {}).get("exact_match")) - score_accuracy = (correct_count / total_items) * 100.0 - - total_latencies_ms = [] - for s in samples: - t = s.get("timing", {}) - total_latencies_ms.append(float(t.get("search_ms", 0.0)) + float(t.get("llm_ms", 0.0))) - total_lat_stats = latency_stats(total_latencies_ms) if total_latencies_ms else {"p50": 0.0, "iqr": 0.0} - latency_median_s = total_lat_stats.get("p50", 0.0) / 1000.0 - latency_iqr_s = total_lat_stats.get("iqr", 0.0) / 1000.0 - - avg_ctx_tokens = statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0 - avg_ctx_tokens_k = avg_ctx_tokens / 1000.0 - - result["metric_summary"] = { - "score_accuracy": score_accuracy, - "latency_median_s": latency_median_s, - "latency_iqr_s": latency_iqr_s, - "avg_context_tokens_k": avg_ctx_tokens_k, - } - except Exception: - result["metric_summary"] = { - "score_accuracy": 0.0, - "latency_median_s": 0.0, - "latency_iqr_s": 0.0, - "avg_context_tokens_k": 0.0, - } - - # 诊断信息 - try: - dups = sorted([(k, c) for k, c in preview_counter.items() if c > 1], key=lambda x: -x[1])[:5] - result["diagnostics"] = { - "duplicate_previews_top": [{"count": c, "preview": k[:120]} for k, c in dups], - "unique_preview_count": len(preview_counter), - } - except Exception: - pass - - return result - - finally: - await connector.close() - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="LongMemEval 评估测试脚本(增强时间推理版)") - parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=16, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="hybrid", choices=["embedding","keyword","hybrid"], help="检索类型") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径") - parser.add_argument("--max-contexts-per-item", type=int, default=2, help="每条样本最多摄入的上下文段数") - parser.add_argument("--no-save-chunk-output", action="store_true", help="不保存分块结果(默认保存)") - parser.add_argument("--save-chunk-output-path", type=str, default=None, help="自定义分块输出路径") - parser.add_argument("--reset-group-before-ingest", action="store_true", help="摄入前清空该 Group 在图数据库中的历史数据") - parser.add_argument("--skip-ingest", action="store_true", help="跳过摄入,仅检索评估") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - result = asyncio.run( - run_longmemeval_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - max_contexts_per_item=args.max_contexts_per_item, - save_chunk_output=(not args.no_save_chunk_output), - save_chunk_output_path=args.save_chunk_output_path, - reset_group_before_ingest=args.reset_group_before_ingest, - skip_ingest=args.skip_ingest, - ) - ) - - # 打印结果 - print("\n" + "="*50) - print("📊 LongMemEval 测试结果:") - print(f" 样本数量: {result['items']}") - - if result['accuracy_by_type']: - print("\n📈 按问题类型细分:") - for qtype, acc in result['accuracy_by_type'].items(): - print(f" {qtype}:") - print(f" Score (Accuracy): {acc:.3f}") - - print(f"\n📊 指标总览:") - ms = result.get('metric_summary', {}) - print(f" Score (Accuracy): {ms.get('score_accuracy', 0.0):.1f}%") - print(f" Latency (s): median {ms.get('latency_median_s', 0.0):.3f}s") - print(f" Latency IQR (s): {ms.get('latency_iqr_s', 0.0):.3f}s") - print(f" Avg Context Tokens (k): {ms.get('avg_context_tokens_k', 0.0):.3f}k") - - print(f"\n⏱️ 细分性能指标:") - print(f" 检索延迟(均值): {result['latency']['search']['mean']:.1f}ms") - print(f" LLM延迟(均值): {result['latency']['llm']['mean']:.1f}ms") - print(f" 上下文长度(均值): {result['context']['avg_chars']:.0f} 字符") - - - # 保存结果到文件 - try: - out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results") - os.makedirs(out_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/longmemeval/test_eval.py b/api/app/core/memory/evaluation/longmemeval/test_eval.py deleted file mode 100644 index 67bd6ec2..00000000 --- a/api/app/core/memory/evaluation/longmemeval/test_eval.py +++ /dev/null @@ -1,1330 +0,0 @@ -import argparse -import asyncio -import json -import os -import re -import statistics -import time -from datetime import datetime, timedelta -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -# 与现有评估脚本保持一致的导入方式 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - jaccard, - latency_stats, -) -from app.core.memory.evaluation.common.metrics import f1_score as common_f1 -from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - -try: - from app.core.memory.evaluation.common.metrics import exact_match -except Exception: - # 兜底:简单的大小写不敏感比较 - def exact_match(pred: str, ref: str) -> bool: - return str(pred).strip().lower() == str(ref).strip().lower() - - -def load_dataset_any(path: str) -> List[Dict[str, Any]]: - """健壮地加载数据集(兼容 list 或多段 JSON)。""" - with open(path, "r", encoding="utf-8") as f: - s = f.read().strip() - try: - obj = json.loads(s) - if isinstance(obj, list): - return obj - elif isinstance(obj, dict): - return [obj] - except json.JSONDecodeError: - pass - dec = json.JSONDecoder() - idx = 0 - items: List[Dict[str, Any]] = [] - while idx < len(s): - while idx < len(s) and s[idx].isspace(): - idx += 1 - if idx >= len(s): - break - try: - obj, end = dec.raw_decode(s, idx) - if isinstance(obj, list): - for it in obj: - if isinstance(it, dict): - items.append(it) - elif isinstance(obj, dict): - items.append(obj) - idx = end - except json.JSONDecodeError: - nl = s.find("\n", idx) - if nl == -1: - break - idx = nl + 1 - return items - - -def is_chinese_text(s: str) -> bool: - return bool(re.search(r"[\u4e00-\u9fff]", s or "")) - - -def extract_candidate_options(question: str) -> List[str]: - """从问题中提取候选选项(A-or-B 类问题)。""" - q = (question or "").strip() - options: List[str] = [] - - # 1) 引号包裹的片段 - for pat in [r"'([^']+)'", r'\"([^\"]+)\"', r'“([^”]+)”', r'‘([^’]+)’']: - for m in re.findall(pat, q): - val = (m or "").strip() - if val: - options.append(val) - - # 2) or/还是/或者 连接词 - if len(options) < 2: - pats = [ - r"([^,;,;]+?)\s+or\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+还是\s+([^,;,;\?\.!.。!]+)", - r"([^,;,;]+?)\s+或者\s+([^,;,;\?\.!.。!]+)", - ] - for pat in pats: - matches = list(re.finditer(pat, q, flags=re.IGNORECASE)) - if matches: - m = matches[-1] - cand1 = m.group(1).strip().strip("??.,,;; ") - cand2 = m.group(2).strip().strip("??.,,;; ") - options.extend([cand1, cand2]) - break - - # 去重 - seen = set() - uniq: List[str] = [] - for o in options: - o2 = o.strip() - key = o2.lower() if not is_chinese_text(o2) else o2 - if o2 and key not in seen: - uniq.append(o2) - seen.add(key) - return uniq - - -def extract_time_entities(text: str) -> List[Dict[str, Any]]: - """增强时间实体提取,专门用于时间推理问题""" - time_entities = [] - - # 日期模式 - date_patterns = [ - (r'\b(\d{4})-(\d{1,2})-(\d{1,2})\b', 'date'), # YYYY-MM-DD - (r'\b(\d{1,2})月(\d{1,2})日\b', 'date'), # 中文日期 - (r'\b(January|February|March|April|May|June|July|August|September|October|November|December)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份 - (r'\b(Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\s+(\d{1,2}),?\s+(\d{4})?', 'date'), # 英文月份缩写 - ] - - # 时间间隔模式 - duration_patterns = [ - (r'(\d+)\s*天', 'days'), - (r'(\d+)\s*周', 'weeks'), - (r'(\d+)\s*个月', 'months'), - (r'(\d+)\s*年', 'years'), - (r'(\d+)\s*days?', 'days'), - (r'(\d+)\s*weeks?', 'weeks'), - (r'(\d+)\s*months?', 'months'), - (r'(\d+)\s*years?', 'years'), - ] - - # 事件时间关系模式 - temporal_relation_patterns = [ - (r'(之前|以前|前)\s*(\d+)\s*天', 'days_before'), - (r'(之后|以后|后)\s*(\d+)\s*天', 'days_after'), - (r'(\d+)\s*天\s*(之前|以前|前)', 'days_before'), - (r'(\d+)\s*天\s*(之后|以后|后)', 'days_after'), - (r'(\d+)\s*days?\s*(before|ago)', 'days_before'), - (r'(\d+)\s*days?\s*(after|later)', 'days_after'), - ] - - # 提取日期 - for pattern, entity_type in date_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间间隔 - for pattern, entity_type in duration_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - # 提取时间关系 - for pattern, entity_type in temporal_relation_patterns: - matches = re.finditer(pattern, text, re.IGNORECASE) - for match in matches: - time_entities.append({ - 'text': match.group(), - 'type': entity_type, - 'value': int(match.group(2)) if match.groups() >= 2 else int(match.group(1)), - 'start': match.start(), - 'end': match.end() - }) - - return time_entities - - -def calculate_time_difference(date1: str, date2: str) -> int: - """计算两个日期之间的天数差""" - try: - # 解析日期格式 - def parse_date(date_str: str) -> datetime: - # 尝试多种日期格式 - formats = [ - '%Y-%m-%d', - '%m月%d日', - '%B %d, %Y', - '%b %d, %Y', - '%Y年%m月%d日' - ] - - for fmt in formats: - try: - return datetime.strptime(date_str, fmt) - except ValueError: - continue - - # 如果都无法解析,返回当前日期 - return datetime.now() - - d1 = parse_date(date1) - d2 = parse_date(date2) - - # 计算天数差(绝对值) - return abs((d2 - d1).days) - except Exception: - return -1 # 表示计算失败 - - -def _extract_cn_tokens(text: str) -> List[str]: - """中文关键词提取(短语级,含数词/日期/常见领域词)""" - if not text: - return [] - t = str(text) - # 去掉常见功能词(粗略,不依赖分词库) - stop_words = [ - "我","我们","你","他","她","它","这","那","哪","一个","一次","一些","什么","怎么","是否","吗","呢", - "很","更","最","已经","正在","将要","马上","尽快","最近","关于","有关","以及","并且","或者","还是", - "因为","所以","如果","但是","而且","然后","之后","之前","同时","另外","并","但","却","被","把","让","给", - "和","与","跟","及","还有","就","都","在","对","对于","的","了","着","过","到","于","从","以","为","向","至","是" - ] - for sw in stop_words: - t = t.replace(sw, " ") - # 去标点 - t = re.sub(r"[,。!?、;:,.!?;:\"'()()[]\[\]\-—…·]", " ", t) - # 基础中文片段(>=2) - base = re.findall(r"[\u4e00-\u9fff]{2,}", t) - # 特殊组合:第X次XXXX - specials = re.findall(r"第[一二三四五六七八九十]+次[\u4e00-\u9fff]{2,6}", text) - # 日期与数字 - dates = re.findall(r"\d{4}年\d{1,2}月\d{1,2}日|\d{1,2}月\d{1,2}日|\d{4}-\d{1,2}-\d{1,2}", text) - numbers = re.findall(r"\b\d+\b", text) - - generic = {"建议","推荐","帮助","提升","技能","有效","团队","参与度","喜欢","开始"} - tokens: List[str] = specials + base + dates + numbers - uniq: List[str] = [] - seen = set() - for tok in tokens: - tok2 = tok.strip() - if len(tok2) < 2 or len(tok2) > 6: - continue - if tok2 in generic: - continue - if tok2 not in seen: - uniq.append(tok2) - seen.add(tok2) - # 排除常见疑问型短语 - blacklist_exact = {"是什么","多少","多少天","哪个","哪些","之间","先","后","之前","之后"} - uniq2: List[str] = [u for u in uniq if u not in blacklist_exact] - return uniq2[:12] - - -def generate_query_keywords_cn(question: str) -> List[str]: - """增强版关键词提取,特别关注技术术语和专有名词""" - if not question: - return [] - - # 提取专有名词(带引号的内容) - quoted_terms = re.findall(r'["""]([^"""]+)["""]', question) - - # 提取技术术语(中英文混合) - tech_terms = re.findall(r'[A-Z][a-zA-Z]+\s+[A-Z][a-zA-Z]+|[A-Za-z]+[\u4e00-\u9fff]+|[\u4e00-\u9fff]+[A-Za-z]+', question) - - # 提取核心名词短语 - core_nouns = re.findall(r'[\u4e00-\u9fff]{2,5}系统|[\u4e00-\u9fff]{2,5}管理|[\u4e00-\u9fff]{2,5}分析|[\u4e00-\u9fff]{2,5}工作坊|[\u4e00-\u9fff]{2,5}研讨会', question) - - # 基础中文片段 - base_tokens = _extract_cn_tokens(question) - - # 特定领域关键词增强 - domain_keywords = [] - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统", "系统运行"]): - domain_keywords.extend(["GPS", "导航系统", "定位", "系统故障", "功能异常"]) - # 活动相关 - if any(term in question for term in ["工作坊", "研讨会", "网络研讨会", "活动"]): - domain_keywords.extend(["工作坊", "研讨会", "参加", "参与", "活动"]) - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个", "之前", "首先"]): - domain_keywords.extend(["先", "后", "之前", "之后", "第一次", "首先"]) - # 设备相关 - if any(term in question for term in ["设备", "手机", "电脑", "笔记本电脑"]): - domain_keywords.extend(["设备", "手机", "电脑", "笔记本电脑", "购买"]) - - # 合并并去重 - all_tokens = quoted_terms + tech_terms + core_nouns + base_tokens + domain_keywords - seen = set() - final_tokens = [] - - for token in all_tokens: - token = token.strip() - if len(token) >= 2 and token not in seen: - final_tokens.append(token) - seen.add(token) - - return final_tokens[:8] - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """增强版上下文选择:特别优化技术术语和精确匹配""" - if not contexts: - return "" - - # 检测是否为时间推理问题 - is_temporal_question = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 提取时间实体从问题中 - question_time_entities = extract_time_entities(question) - - # 提取关键技术实体 - key_entities = [] - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统", "系统运行"]): - key_entities.extend(["GPS", "导航", "定位", "系统", "功能", "问题", "故障"]) - # 活动相关 - if any(term in question for term in ["工作坊", "研讨会", "网络研讨会", "活动"]): - key_entities.extend(["工作坊", "研讨会", "参加", "参与", "活动", "时间"]) - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个", "之前", "首先"]): - key_entities.extend(["先", "后", "之前", "之后", "第一次", "首先"]) - - # 英文关键词(去停用词) - question_lower = question.lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','many','which','first' - } - eng_words = [w for w in set(re.findall(r'\b\w+\b', question_lower)) - if w not in stop_words and len(w) > 2] - - # 中文片段与候选选项 - cn_tokens = generate_query_keywords_cn(question) - options = extract_candidate_options(question) - - # 时间推理问题的特殊处理 - if is_temporal_question: - # 为时间问题添加时间相关关键词 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'days', 'first', '先后'] - eng_words = [w for w in eng_words if w not in ['days', 'first']] # 避免重复 - cn_tokens.extend([kw for kw in time_keywords if kw not in cn_tokens]) - - # 限制关键词数量,优先时间相关 - tokens = time_keywords[:2] + key_entities[:3] + cn_tokens[:2] + eng_words[:1] + options[:1] - else: - # 常规问题处理,优先关键技术实体 - tokens = key_entities[:4] + cn_tokens[:3] + options[:2] + eng_words[:1] - - # 去重 - seen = set() - final_tokens: List[str] = [] - for t in tokens: - t2 = t.strip() - if t2 and t2 not in seen: - final_tokens.append(t2) - seen.add(t2) - - scored_contexts: List[tuple[float, str]] = [] - - # 关键技术实体权重映射 - key_entity_weights = { - "GPS": 3.0, "导航": 2.5, "系统": 2.0, "功能": 2.0, "问题": 2.0, "故障": 2.5, - "工作坊": 2.5, "研讨会": 2.5, "参加": 2.0, "参与": 2.0, - "先": 2.0, "后": 2.0, "之前": 2.0, "之后": 2.0, "第一次": 2.5 - } - - # 时间推理问题的权重映射 - temporal_weight_map = { - "天": 2.0, "日": 2.0, "月": 1.8, "年": 1.8, "days": 2.0, - "before": 1.5, "after": 1.5, "first": 1.5, "先后": 1.5 - } - - # 常规问题的权重映射 - normal_weight_map = { - "问题": 2.0, "故障": 2.0, "异常": 1.8, "不正常": 1.8, "坏了": 1.8, - "系统": 1.3, "GPS": 1.5, "保养": 1.4, "设备": 1.2, "模块": 1.2, "功能": 1.1 - } - - # 合并权重映射 - weight_map = {**normal_weight_map, **temporal_weight_map, **key_entity_weights} - - for i, context in enumerate(contexts): - context_str = str(context) - lines = re.split(r'[\r\n]+', context_str) - hit_lines: List[str] = [] - kw_hits: float = 0.0 - time_entity_count = 0 - key_entity_hits = 0 - - for line in lines: - ln = line.strip() - if not ln: - continue - - has_keyword = False - # 关键词匹配 - for tok in final_tokens: - if tok and tok in ln: - w = weight_map.get(tok, 1.0) - hit_count = ln.count(tok) - kw_hits += hit_count * w - # 关键技术实体额外奖励 - if tok in key_entity_weights: - key_entity_hits += hit_count - has_keyword = True - - # 时间实体检测(特别针对时间推理问题) - if is_temporal_question: - time_entities = extract_time_entities(ln) - time_entity_count += len(time_entities) - if time_entities: - has_keyword = True - - # 精确匹配奖励(完整问题关键词出现在上下文中) - for q_word in question.split(): - if len(q_word) > 3 and q_word in ln: - kw_hits += 0.5 # 精确匹配奖励 - - if has_keyword: - # 对于包含关键信息的行,保留完整行 - hit_lines.append(ln) - - snippet = "\n".join(hit_lines) if hit_lines else context_str.strip() - - # 限制单段长度,但对包含关键信息的上下文稍微放宽限制 - max_snippet_len = 600 if (key_entity_hits > 0 or time_entity_count > 0) else 500 - if len(snippet) > max_snippet_len: - snippet = snippet[:max_snippet_len] - - # 评分逻辑 - has_number = 1 if re.search(r'\d', snippet) else 0 - has_date = 1 if (re.search(r'\b\d{4}-\d{1,2}-\d{1,2}\b', snippet) or - re.search(r'\d{1,2}月\d{1,2}日', snippet)) else 0 - - # 关键技术实体奖励 - key_entity_bonus = key_entity_hits * 1.0 - - # 时间推理问题的特殊评分 - if is_temporal_question: - time_bonus = time_entity_count * 2.0 # 时间实体奖励 - temporal_coherence = 3 if (has_date and time_entity_count >= 2) else 0 - else: - time_bonus = 0 - temporal_coherence = 0 - - length_bonus = 5 if 50 < len(snippet) < 1000 else (2 if len(snippet) >= 1000 else 0) - pos_bonus = 3 if i < 3 else 0 - - score = (kw_hits * 0.8 + (has_number + has_date) * 1.5 + - length_bonus + pos_bonus + time_bonus + temporal_coherence + key_entity_bonus) - - scored_contexts.append((score, snippet)) - - # 选择累计至总字符预算 - scored_contexts.sort(key=lambda x: x[0], reverse=True) - selected: List[str] = [] - total_chars = 0 - - for score, snippet in scored_contexts: - if total_chars + len(snippet) <= max_chars: - selected.append(snippet) - total_chars += len(snippet) - else: - if not selected and len(snippet) > max_chars: - selected.append(snippet[:max_chars]) - break - - final_context = "\n\n".join(selected) - - # 对于时间推理问题,添加时间计算提示 - if is_temporal_question and question_time_entities: - time_prompt = "\n\n[时间推理提示:请仔细分析上述上下文中的日期和时间关系,计算时间间隔或确定事件顺序]" - if total_chars + len(time_prompt) <= max_chars: - final_context += time_prompt - - return final_context - - -# 通过别名匹配进行实体关键词检索(多token合并) -async def _search_entities_by_aliases(connector: Neo4jConnector, tokens: List[str], end_user_id: str | None, limit: int) -> List[Dict[str, Any]]: - results: List[Dict[str, Any]] = [] - try: - for tok in tokens: - rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q=tok, end_user_id=end_user_id, limit=limit) - if rows: - results.extend(rows) - except Exception: - pass - - # 按 name 去重 - deduped: List[Dict[str, Any]] = [] - seen = set() - for r in results: - k = str(r.get("name", "")) - if k and k not in seen: - deduped.append(r) - seen.add(k) - return deduped - - -# 通过对话/陈述中的entity_ids反查实体名称 -_FETCH_ENTITIES_BY_IDS = """ -MATCH (e:ExtractedEntity) -WHERE e.id IN $ids AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -""" - -async def _fetch_entities_by_ids(connector: Neo4jConnector, ids: List[str], end_user_id: str | None) -> List[Dict[str, Any]]: - if not ids: - return [] - try: - rows = await connector.execute_query(_FETCH_ENTITIES_BY_IDS, ids=list({i for i in ids if i}), end_user_id=end_user_id) - return rows or [] - except Exception: - return [] - - -# 增强的时间实体检索 -_TIME_ENTITY_SEARCH = """ -MATCH (e:ExtractedEntity) -WHERE e.entity_type CONTAINS "TIME" OR e.entity_type CONTAINS "DATE" OR e.name =~ $date_pattern -AND ($end_user_id IS NULL OR e.end_user_id = $end_user_id) -RETURN e.id AS id, e.name AS name, e.end_user_id AS end_user_id, e.entity_type AS entity_type -LIMIT $limit -""" - -async def _search_time_entities(connector: Neo4jConnector, end_user_id: str | None, limit: int = 5) -> List[Dict[str, Any]]: - """专门搜索时间相关的实体""" - try: - date_pattern = r".*\d{4}.*|.*\d{1,2}月\d{1,2}日.*" - rows = await connector.execute_query(_TIME_ENTITY_SEARCH, - date_pattern=date_pattern, - end_user_id=end_user_id, - limit=limit) - return rows or [] - except Exception: - return [] - - -# 技术术语专门检索 -async def _search_tech_terms(connector: Neo4jConnector, question: str, end_user_id: str | None, limit: int = 3) -> List[Dict[str, Any]]: - """专门搜索技术术语相关的实体""" - tech_entities = [] - try: - # GPS相关 - if any(term in question for term in ["GPS", "导航", "定位系统"]): - gps_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="GPS", end_user_id=end_user_id, limit=limit) - if gps_rows: - tech_entities.extend(gps_rows) - - # 活动相关 - if any(term in question for term in ["工作坊", "研讨会", "网络研讨会"]): - workshop_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="工作坊", end_user_id=end_user_id, limit=limit) - if workshop_rows: - tech_entities.extend(workshop_rows) - - # 时间顺序相关 - if any(term in question for term in ["先", "后", "第一个"]): - time_rows = await connector.execute_query(SEARCH_ENTITIES_BY_NAME, q="第一次", end_user_id=end_user_id, limit=limit) - if time_rows: - tech_entities.extend(time_rows) - - except Exception: - pass - - return tech_entities - - -# 中英相对时间解析:today/昨天/上周/3天后 等简单归一化为日期 -def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str: - t = str(text) if text is not None else "" - # 英文 today/yesterday/tomorrow - t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE) - - # 英文 X days ago / in X days - def _ago_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor - timedelta(days=n)).date().isoformat() - def _in_repl(m: re.Match[str]) -> str: - n = int(m.group(1)) - return (anchor + timedelta(days=n)).date().isoformat() - t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE) - t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE) - - # 中文 今天/昨天/明天 - t = re.sub(r"今天", anchor.date().isoformat(), t) - t = re.sub(r"昨日|昨天", (anchor - timedelta(days=1)).date().isoformat(), t) - t = re.sub(r"明天", (anchor + timedelta(days=1)).date().isoformat(), t) - # 中文 X天前 / X天后 - t = re.sub(r"(\d+)天前", lambda m: (anchor - timedelta(days=int(m.group(1)))).date().isoformat(), t) - t = re.sub(r"(\d+)天后", lambda m: (anchor + timedelta(days=int(m.group(1)))).date().isoformat(), t) - # 中文 上周 / 下周(近似7天) - t = re.sub(r"上周", (anchor - timedelta(days=7)).date().isoformat(), t) - t = re.sub(r"下周", (anchor + timedelta(days=7)).date().isoformat(), t) - # 中文 月日(无年份)补全年份 - def _md_repl(m: re.Match[str]) -> str: - mon = int(m.group(1)); day = int(m.group(2)) - return f"{anchor.year}-{mon:02d}-{day:02d}" - t = re.sub(r"(\d{1,2})月(\d{1,2})日", _md_repl, t) - return t - - -async def run_longmemeval_test( - sample_size: int = 3, - end_user_id: str = "longmemeval_zh_bak_2", - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 16, - search_type: str = "hybrid", - data_path: str | None = None, - start_index: int = 0, -) -> Dict[str, Any]: - """LongMemEval 评估测试:增强技术术语检索能力""" - - # 数据路径 - if not data_path: - # 固定使用中文数据集:data/longmemeval_oracle_zh.json - zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") - zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") - if os.path.exists(zh_proj): - data_path = zh_proj - elif os.path.exists(zh_cwd): - data_path = zh_cwd - else: - raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。") - - qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) - # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 - if sample_size is None or sample_size <= 0: - items = qa_list[start_index:] - else: - items = qa_list[start_index:start_index + sample_size] - - # 初始化组件 - 使用异步LLM客户端 - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) - connector = Neo4jConnector() - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 指标收集 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - per_query_context_counts: List[int] = [] - per_query_context_avg_tokens: List[float] = [] - per_query_context_chars: List[int] = [] - - type_correct: Dict[str, List[float]] = {} - type_f1: Dict[str, List[float]] = {} - type_jacc: Dict[str, List[float]] = {} - - samples: List[Dict[str, Any]] = [] - # 统计重复的上下文预览(跨样本),便于诊断"相同上下文"问题 - preview_counter: Dict[str, int] = {} - - try: - for item in items: - question = item.get("question", "") - reference = item.get("answer", "") - qtype = item.get("question_type") or item.get("type", "unknown") - - print(f"\n=== 处理问题: {question} ===") - - # 检测问题类型 - is_temporal = any(keyword in question.lower() for keyword in - ['days', 'day', 'before', 'after', 'first', '先后', '顺序', '间隔', '多久', '多少天']) - - # 检索 - t0 = time.time() - contexts_all: List[str] = [] - dialogs, statements, entities = [], [], [] - - try: - if search_type == "embedding": - search_results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要(最多3个) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - elif search_type == "keyword": - search_results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - dialogs = search_results.get("dialogues", []) - statements = search_results.get("statements", []) - entities = search_results.get("entities", []) - - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if entities: - entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - else: # hybrid(增强版:特别优化技术术语检索) - emb_dialogs, emb_statements, emb_entities = [], [], [] - kw_dialogs, kw_statements, kw_entities = [], [], [] - - # 1) 嵌入检索 - try: - emb_res = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - ) - if isinstance(emb_res, dict): - emb_dialogs = emb_res.get("dialogues", []) or [] - emb_statements = emb_res.get("statements", []) or [] - emb_entities = emb_res.get("entities", []) or [] - except Exception as e: - print(f"⚠️ 嵌入检索失败,将继续进行关键词检索: {e}") - - # 2) 关键词检索(增强版) - try: - kw_res = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - ) - if isinstance(kw_res, dict): - kw_dialogs = kw_res.get("dialogues", []) or [] - kw_statements = kw_res.get("statements", []) or [] - kw_entities = kw_res.get("entities", []) or [] - - # 技术术语专门检索 - tech_entities = await _search_tech_terms(connector, question, end_user_id, search_limit//2) - if tech_entities: - kw_entities.extend(tech_entities) - - # 时间推理问题的特殊处理 - if is_temporal: - # 专门搜索时间实体 - time_entities = await _search_time_entities(connector, end_user_id, search_limit//2) - if time_entities: - kw_entities.extend(time_entities) - # 添加时间相关关键词检索 - time_keywords = ['天', '日', '月', '年', 'before', 'after', 'first'] - for tk in time_keywords: - try: - time_res = await search_graph( - connector=connector, - q=tk, - end_user_id=end_user_id, - limit=2, - ) - if isinstance(time_res, dict): - kw_dialogs.extend(time_res.get("dialogues", []) or []) - kw_statements.extend(time_res.get("statements", []) or []) - except Exception: - pass - - # 中文关键词拆分后做别名匹配 - cn_tokens = generate_query_keywords_cn(question) # 使用增强版关键词提取 - alias_entities = await _search_entities_by_aliases(connector, cn_tokens, end_user_id, search_limit) - if alias_entities: - kw_entities.extend(alias_entities) - - # 从对话/陈述中的 entity_ids 反查实体 - ids = [] - try: - for d in kw_dialogs: - ids.extend(d.get("entity_ids", []) or []) - for s in kw_statements: - ids.extend(s.get("entity_ids", []) or []) - except Exception: - pass - if ids: - id_entities = await _fetch_entities_by_ids(connector, ids, end_user_id) - if id_entities: - kw_entities.extend(id_entities) - - # 多关键词检索(使用增强版关键词) - try: - eng_words = [w for w in set(re.findall(r"\b\w+\b", question.lower())) if len(w) > 2] - kw_list = generate_query_keywords_cn(question)[:4] # 使用更多关键词 - for kw in kw_list: - if not kw: - continue - sub_res = await search_graph( - connector=connector, - q=str(kw), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(sub_res, dict): - kw_dialogs.extend(sub_res.get("dialogues", []) or []) - kw_statements.extend(sub_res.get("statements", []) or []) - kw_entities.extend(sub_res.get("entities", []) or []) - except Exception: - pass - - # 选项参与关键词检索 - try: - opt_list = extract_candidate_options(question)[:2] - for opt in opt_list: - if not opt: - continue - opt_res = await search_graph( - connector=connector, - q=str(opt), - end_user_id=end_user_id, - limit=max(3, search_limit // 2), - ) - if isinstance(opt_res, dict): - kw_dialogs.extend(opt_res.get("dialogues", []) or []) - kw_statements.extend(opt_res.get("statements", []) or []) - kw_entities.extend(opt_res.get("entities", []) or []) - except Exception: - pass - except Exception as e: - print(f"❌ 关键词检索失败: {e}") - - # 3) 合并、排序并去重 - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - def dedup(items: List[Dict[str, Any]], key_field: str = "uuid") -> List[Dict[str, Any]]: - seen = set() - out = [] - for it in items: - key = str(it.get(key_field, "")) + str(it.get("content", "") + str(it.get("statement", ""))) - if key not in seen: - out.append(it) - seen.add(key) - return out - - # 关键技术实体优先排序 - def enhanced_score(item: Dict[str, Any]) -> float: - score_val = item.get("score", 0.0) - base_score = float(score_val) if score_val is not None else 0.0 - content = str(item.get("content", "") + str(item.get("statement", ""))) - - # 关键技术实体奖励 - key_entities = [] - if any(term in question for term in ["GPS", "导航", "系统"]): - key_entities.extend(["GPS", "导航", "系统", "功能"]) - if any(term in question for term in ["工作坊", "研讨会", "活动"]): - key_entities.extend(["工作坊", "研讨会", "参加"]) - - key_bonus = 0 - for key_ent in key_entities: - if key_ent in content: - key_bonus += 1.0 - - # 时间实体奖励 - time_bonus = 0 - if is_temporal: - time_entities = extract_time_entities(content) - time_bonus = len(time_entities) * 0.5 - - return base_score + key_bonus + time_bonus - - dialogs = dedup(sorted(all_dialogs, key=enhanced_score, reverse=True)) - statements = dedup(sorted(all_statements, key=enhanced_score, reverse=True)) - entities = dedup(all_entities, key_field="name") - - # 4) 构建上下文 - for d in dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - # 实体摘要 - try: - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - except Exception: - pass - - # 全局回退 - if not contexts_all and search_type in ("embedding", "hybrid"): - try: - print("🔁 检索为空,回退到关键词检索...") - kw_fallback = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=max(search_limit, 5), - ) - fb_dialogs = kw_fallback.get("dialogues", []) or [] - fb_statements = kw_fallback.get("statements", []) or [] - fb_entities = kw_fallback.get("entities", []) or [] - - for d in fb_dialogs: - content = str(d.get("content", "")).strip() - if content: - contexts_all.append(content) - for s in fb_statements: - stmt_text = str(s.get("statement", "")).strip() - if stmt_text: - contexts_all.append(stmt_text) - if fb_entities: - entity_names = [str(e.get("name", "")).strip() for e in fb_entities[:5] if e.get("name")] - if entity_names: - contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") - - dialogs = fb_dialogs if fb_dialogs else dialogs - statements = fb_statements if fb_statements else statements - entities = fb_entities if fb_entities else entities - print(f"↩️ 回退到关键词检索: {len(fb_dialogs)} 对话, {len(fb_statements)} 条陈述, {len(fb_entities)} 个实体") - except Exception as fe: - print(f"❌ 关键词回退失败: {fe}") - - ent_count = len(entities) if isinstance(entities, list) else 0 - print(f"✅ {search_type}检索成功: {len(dialogs)} 对话, {len(statements)} 条陈述, {ent_count} 个实体") - if is_temporal: - print("⏰ 检测为时间推理问题,已启用时间优化检索") - - except Exception as e: - print(f"❌ {search_type}检索失败: {e}") - contexts_all = [] - - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 智能上下文选择 - context_text = "" - if contexts_all: - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) - # 相对时间解析 - try: - context_text = _resolve_relative_times_cn_en(context_text, anchor=datetime.now()) - except Exception: - pass - # 诊断信息 - try: - cn_diag = generate_query_keywords_cn(question)[:4] # 显示更多关键词 - opts = extract_candidate_options(question)[:2] - qlw = [w for w in set(re.findall(r'\b\w+\b', question.lower())) if len(w) > 2][:1] - diag_tokens: List[str] = [] - for t in cn_diag + opts + qlw: - if t and t not in diag_tokens: - diag_tokens.append(t) - print(f"🔍 关键词/选项: {', '.join(diag_tokens)}") - preview = context_text[:200].replace('\n', ' ') - print(f"🔎 上下文预览: {preview}...") - key_preview = preview.strip() - if key_preview: - preview_counter[key_preview] = preview_counter.get(key_preview, 0) + 1 - except Exception: - pass - else: - print("❌ 没有检索到有效上下文") - context_text = "No relevant context found." - - # 记录上下文诊断信息 - per_query_context_counts.append(len(contexts_all)) - per_query_context_avg_tokens.append(avg_context_tokens([context_text])) - per_query_context_chars.append(len(context_text)) - - # LLM 推理(增强技术术语提示) - options = extract_candidate_options(question) - if len(options) >= 2: - opt_lines = "\n".join(f"- {o}" for o in options) - # 技术术语问题的特殊提示 - if any(term in question for term in ["GPS", "系统", "功能", "工作坊", "研讨会"]): - system_prompt = ( - "You are a QA assistant specializing in technical and activity-related questions. " - "Pay special attention to technical terms like GPS, systems, functions, workshops, and seminars. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Focus on matching technical details and activity sequences accurately." - ) - elif is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "Return ONLY one string: exactly one option from the provided candidates. If the context is insufficient, respond with 'Unknown'. " - "Pay special attention to date sequences and time intervals." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. Return ONLY one string: exactly one option from the provided candidates. " - "If the context is insufficient, respond with 'Unknown'. If the context expresses a synonym or paraphrase of a candidate, return the closest candidate. " - "Do not include explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": ( - f"Question: {question}\n\nCandidates:\n{opt_lines}\n\nContext:\n{context_text}\n\nReturn EXACTLY one candidate string (or 'Unknown')." - ), - }, - ] - else: - # 技术术语问题的特殊提示 - if any(term in question for term in ["GPS", "系统", "功能", "工作坊", "研讨会"]): - system_prompt = ( - "You are a QA assistant specializing in technical and activity-related questions. " - "Pay special attention to technical terms like GPS, systems, functions, workshops, and seminars. " - "If the context contains the answer, return a concise answer phrase focusing on technical details. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - elif is_temporal: - system_prompt = ( - "You are a QA assistant specializing in temporal reasoning. Analyze the dates and time relationships in the context carefully. " - "If the context contains the answer, return a concise answer phrase focusing on temporal information. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - else: - system_prompt = ( - "You are a QA assistant. Respond in the same language as the question. If the context contains the answer, return a concise answer phrase. " - "If the answer cannot be determined from the context, respond with 'Unknown'. Return ONLY the final answer string, no explanations." - ) - - messages = [ - {"role": "system", "content": system_prompt}, - { - "role": "user", - "content": f"Question: {question}\n\nContext:\n{context_text}\n\nReturn ONLY the answer (or 'Unknown').", - }, - ] - - t2 = time.time() - # 使用异步调用 - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - - # 兼容不同的响应格式 - pred_raw = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown") - - # 选项题输出规范化 - pred = pred_raw - if len(options) >= 2 and not pred_raw.lower().startswith("unknown"): - def _basic_norm(s: str) -> str: - s = s.lower().strip() - return re.sub(r"[^\w\s]", " ", s) - def _jaccard(a: str, b: str) -> float: - ta = set(t for t in _basic_norm(a).split() if t) - tb = set(t for t in _basic_norm(b).split() if t) - if not ta and not tb: - return 1.0 - if not ta or not tb: - return 0.0 - return len(ta & tb) / len(ta | tb) - best = None - best_score = -1.0 - for o in options: - score = _jaccard(pred_raw, o) - if score > best_score: - best = o - best_score = score - if best is not None and best_score > 0.0: - pred = best - - # 指标 - flag = exact_match(pred, reference) - f1_val = common_f1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - - type_correct.setdefault(qtype, []).append(flag) - type_f1.setdefault(qtype, []).append(f1_val) - type_jacc.setdefault(qtype, []).append(j_val) - - samples.append({ - "question": question, - "prediction": pred, - "answer": reference, - "question_type": qtype, - "is_temporal": is_temporal, - "question_id": item.get("question_id"), - "options": options, - "context_count": len(contexts_all), - "context_chars": len(context_text), - "retrieved_dialogue_count": len(dialogs), - "retrieved_statement_count": len(statements), - "metrics": { - "exact_match": bool(flag), - "f1": f1_val, - "jaccard": j_val - }, - "timing": { - "search_ms": (t1 - t0) * 1000, - "llm_ms": (t3 - t2) * 1000 - } - }) - - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - Exact Match: {flag}, F1: {f1_val:.3f}, Jaccard: {j_val:.3f}") - - # 聚合结果 - type_acc = {t: (sum(v) / max(len(v), 1)) for t, v in type_correct.items()} - f1_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_f1.items()} - jacc_by_type = {t: (sum(v) / max(len(v), 1)) for t, v in type_jacc.items()} - - result = { - "dataset": "longmemeval", - "items": len(items), - "accuracy_by_type": type_acc, - "f1_by_type": f1_by_type, - "jaccard_by_type": jacc_by_type, - "samples": samples, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "context": { - "avg_tokens": statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0, - "avg_chars": statistics.mean(per_query_context_chars) if per_query_context_chars else 0.0, - "count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0, - }, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "search_type": search_type, - "llm_id": SELECTED_LLM_ID, - "embedding_id": SELECTED_EMBEDDING_ID, - "sample_size": sample_size, - "start_index": start_index, - }, - "timestamp": datetime.now().isoformat() - } - - # 计算汇总指标 - try: - total_items = max(len(samples), 1) - correct_count = sum(1 for s in samples if s.get("metrics", {}).get("exact_match")) - score_accuracy = (correct_count / total_items) * 100.0 - - total_latencies_ms = [] - for s in samples: - t = s.get("timing", {}) - total_latencies_ms.append(float(t.get("search_ms", 0.0)) + float(t.get("llm_ms", 0.0))) - total_lat_stats = latency_stats(total_latencies_ms) if total_latencies_ms else {"p50": 0.0, "iqr": 0.0} - latency_median_s = total_lat_stats.get("p50", 0.0) / 1000.0 - latency_iqr_s = total_lat_stats.get("iqr", 0.0) / 1000.0 - - avg_ctx_tokens = statistics.mean(per_query_context_avg_tokens) if per_query_context_avg_tokens else 0.0 - avg_ctx_tokens_k = avg_ctx_tokens / 1000.0 - - result["metric_summary"] = { - "score_accuracy": score_accuracy, - "latency_median_s": latency_median_s, - "latency_iqr_s": latency_iqr_s, - "avg_context_tokens_k": avg_ctx_tokens_k, - } - except Exception: - result["metric_summary"] = { - "score_accuracy": 0.0, - "latency_median_s": 0.0, - "latency_iqr_s": 0.0, - "avg_context_tokens_k": 0.0, - } - - # 诊断信息 - try: - dups = sorted([(k, c) for k, c in preview_counter.items() if c > 1], key=lambda x: -x[1])[:5] - result["diagnostics"] = { - "duplicate_previews_top": [{"count": c, "preview": k[:120]} for k, c in dups], - "unique_preview_count": len(preview_counter), - } - except Exception: - pass - - return result - - finally: - await connector.close() - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="LongMemEval 评估测试脚本(增强技术术语检索版)") - parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=16, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="hybrid", choices=["embedding","keyword","hybrid"], help="检索类型") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - result = asyncio.run( - run_longmemeval_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - ) - ) - - # 打印结果 - print("\n" + "="*50) - print("📊 LongMemEval 测试结果:") - print(f" 样本数量: {result['items']}") - - if result['accuracy_by_type']: - print("\n📈 按问题类型细分:") - for qtype, acc in result['accuracy_by_type'].items(): - print(f" {qtype}:") - print(f" Score (Accuracy): {acc:.3f}") - - print(f"\n📊 指标总览:") - ms = result.get('metric_summary', {}) - print(f" Score (Accuracy): {ms.get('score_accuracy', 0.0):.1f}%") - print(f" Latency (s): median {ms.get('latency_median_s', 0.0):.3f}s") - print(f" Latency IQR (s): {ms.get('latency_iqr_s', 0.0):.3f}s") - print(f" Avg Context Tokens (k): {ms.get('avg_context_tokens_k', 0.0):.3f}k") - - print(f"\n⏱️ 细分性能指标:") - print(f" 检索延迟(均值): {result['latency']['search']['mean']:.1f}ms") - print(f" LLM延迟(均值): {result['latency']['llm']['mean']:.1f}ms") - print(f" 上下文长度(均值): {result['context']['avg_chars']:.0f} 字符") - - - # 保存结果到文件 - try: - out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results") - os.makedirs(out_dir, exist_ok=True) - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py b/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py deleted file mode 100644 index 869fdb60..00000000 --- a/api/app/core/memory/evaluation/memsciqa/evaluate_qa.py +++ /dev/null @@ -1,324 +0,0 @@ -import argparse -import asyncio -import json -import os -import time -from datetime import datetime -from typing import TYPE_CHECKING, Any, Dict, List - -if TYPE_CHECKING: - from app.schemas.memory_config_schema import MemoryConfig - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - exact_match, - latency_stats, -) -from app.core.memory.evaluation.extraction_utils import ( - ingest_contexts_via_full_pipeline, -) -from app.core.memory.storage_services.search import run_hybrid_search -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.db import get_db_context -from app.repositories.neo4j.neo4j_connector import Neo4jConnector - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """基于问题关键词对上下文进行评分选择,并在预算内拼接文本。""" - if not contexts: - return "" - import re - # 提取问题关键词(移除停用词) - question_lower = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but' - } - question_words = set(re.findall(r"\b\w+\b", question_lower)) - question_words = {w for w in question_words if w not in stop_words and len(w) > 2} - - # 评分 - scored = [] - for i, ctx in enumerate(contexts): - ctx_lower = (ctx or "").lower() - score = 0 - matches = 0 - for w in question_words: - if w in ctx_lower: - matches += 1 - score += ctx_lower.count(w) * 2 - length = len(ctx) - if 100 < length < 2000: - score += 5 - elif length >= 2000: - score += 2 - if i < 3: - score += 3 - scored.append((score, ctx, matches)) - - scored.sort(key=lambda x: x[0], reverse=True) - - # 选择直到达到字符限制,必要时截断包含关键词的段落 - selected: List[str] = [] - total = 0 - for score, ctx, _ in scored: - if total + len(ctx) <= max_chars: - selected.append(ctx) - total += len(ctx) - else: - if score > 10 and total < max_chars - 200: - remaining = max_chars - total - lines = ctx.split('\n') - rel_lines: List[str] = [] - cur = 0 - for line in lines: - l = line.lower() - if any(w in l for w in question_words) and cur < remaining - 50: - rel_lines.append(line) - cur += len(line) - if rel_lines: - truncated = '\n'.join(rel_lines) - if len(truncated) > 50: - selected.append(truncated + "\n[相关内容截断...]") - total += len(truncated) - break - return "\n\n".join(selected) - - -def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str: - """Compose a text context from `dialog` list in msc_self_instruct item.""" - parts: List[str] = [] - for turn in dialog_obj.get("dialog", []): - speaker = turn.get("speaker", "") - text = turn.get("text", "") - if text: - parts.append(f"{speaker}: {text}") - return "\n".join(parts) - - -def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]: - """Combine dialogues from embedding and keyword searches (embedding first).""" - if results is None: - return [] - emb = [] - kw = [] - if isinstance(results.get("embedding_search"), dict): - emb = results.get("embedding_search", {}).get("dialogues", []) or [] - elif isinstance(results.get("dialogues"), list): - emb = results.get("dialogues", []) or [] - if isinstance(results.get("keyword_search"), dict): - kw = results.get("keyword_search", {}).get("dialogues", []) or [] - seen = set() - merged: List[Dict[str, Any]] = [] - for d in emb: - k = (str(d.get("uuid", "")), str(d.get("content", ""))) - if k not in seen: - merged.append(d) - seen.add(k) - for d in kw: - k = (str(d.get("uuid", "")), str(d.get("content", ""))) - if k not in seen: - merged.append(d) - seen.add(k) - return merged - - -async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: - end_user_id = end_user_id or SELECTED_GROUP_ID - # Load data - data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") - if not os.path.exists(data_path): - data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]] - # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 - # 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 - contexts: List[str] = [build_context_from_dialog(item) for item in items] - await ingest_contexts_via_full_pipeline(contexts, end_user_id) - - # LLM client (使用异步调用) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm_client = factory.get_llm_client(SELECTED_LLM_ID) - - # Evaluate each item - connector = Neo4jConnector() - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - contexts_used: List[str] = [] - correct_flags: List[float] = [] - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - try: - for item in items: - question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") - reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") - # 检索:对齐 locomo 的三路检索(dialogues/statements/entities) - t0 = time.time() - try: - results = await run_hybrid_search( - query_text=question, - search_type=search_type, - end_user_id=end_user_id, - limit=search_limit, - include=["dialogues", "statements", "entities"], - output_path=None, - memory_config=memory_config, - ) - except Exception: - results = None - t1 = time.time() - latencies_search.append((t1 - t0) * 1000) - - # 构建上下文:包含对话、陈述和实体摘要,并智能选择 - contexts_all: List[str] = [] - if results: - if search_type == "hybrid": - emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {} - kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {} - emb_dialogs = emb.get("dialogues", []) - emb_statements = emb.get("statements", []) - emb_entities = emb.get("entities", []) - kw_dialogs = kw.get("dialogues", []) - kw_statements = kw.get("statements", []) - kw_entities = kw.get("entities", []) - all_dialogs = emb_dialogs + kw_dialogs - all_statements = emb_statements + kw_statements - all_entities = emb_entities + kw_entities - - # 简单去重与限制 - seen_texts = set() - for d in all_dialogs: - text = str(d.get("content", "")).strip() - if text and text not in seen_texts: - contexts_all.append(text) - seen_texts.add(text) - if len(contexts_all) >= search_limit: - break - for s in all_statements: - text = str(s.get("statement", "")).strip() - if text and text not in seen_texts: - contexts_all.append(text) - seen_texts.add(text) - if len(contexts_all) >= search_limit: - break - # 实体摘要(最多3个) - names = [] - merged_entities = all_entities[:] - for e in merged_entities: - name = str(e.get("name", "")).strip() - if name and name not in names: - names.append(name) - if len(names) >= 3: - break - if names: - contexts_all.append("EntitySummary: " + ", ".join(names)) - else: - dialogs = results.get("dialogues", []) - statements = results.get("statements", []) - entities = results.get("entities", []) - for d in dialogs: - text = str(d.get("content", "")).strip() - if text: - contexts_all.append(text) - for s in statements: - text = str(s.get("statement", "")).strip() - if text: - contexts_all.append(text) - names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")] - if names: - contexts_all.append("EntitySummary: " + ", ".join(names)) - - # 智能选择并截断到预算 - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else "" - if not context_text: - context_text = "No relevant context found." - contexts_used.append(context_text[:200]) - - # Call LLM (使用异步调用) - messages = [ - {"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."}, - {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"}, - ] - t2 = time.time() - resp = await llm_client.chat(messages=messages) - t3 = time.time() - latencies_llm.append((t3 - t2) * 1000) - pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip()) - # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference - correct_flags.append(exact_match(pred, reference)) - from app.core.memory.evaluation.common.metrics import ( - bleu1, - f1_score, - jaccard, - ) - f1s.append(f1_score(str(pred), str(reference))) - b1s.append(bleu1(str(pred), str(reference))) - jss.append(jaccard(str(pred), str(reference))) - - # Aggregate metrics - acc = sum(correct_flags) / max(len(correct_flags), 1) - ctx_avg_tokens = avg_context_tokens(contexts_used) - result = { - "dataset": "memsciqa", - "items": len(items), - "metrics": { - "accuracy": acc, - # Placeholders for extensibility - "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0, - "bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0, - "jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0, - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "avg_context_tokens": ctx_avg_tokens, - } - return result - finally: - await connector.close() - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") - parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") - parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") - parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度") - parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型") - args = parser.parse_args() - - result = asyncio.run( - run_memsciqa_eval( - sample_size=args.sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - ) - ) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py b/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py deleted file mode 100644 index 3023020a..00000000 --- a/api/app/core/memory/evaluation/memsciqa/memsciqa-test.py +++ /dev/null @@ -1,577 +0,0 @@ -import argparse -import asyncio -import json -import os -import re -import time -from datetime import datetime -from typing import Any, Dict, List - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -# 路径与模块导入保持与现有评估脚本一致 -import sys -from pathlib import Path - -_THIS_DIR = Path(__file__).resolve().parent -_PROJECT_ROOT = str(_THIS_DIR.parents[1]) -_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") -for _p in (_SRC_DIR, _PROJECT_ROOT): - if _p not in sys.path: - sys.path.insert(0, _p) - -# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1 -from app.core.memory.evaluation.common.metrics import ( - avg_context_tokens, - exact_match, - latency_stats, -) -from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient -from app.core.memory.utils.config.definitions import ( - PROJECT_ROOT, - SELECTED_EMBEDDING_ID, - SELECTED_GROUP_ID, - SELECTED_LLM_ID, -) -from app.core.memory.utils.llm.llm_utils import MemoryClientFactory -from app.core.models.base import RedBearModelConfig -from app.db import get_db_context -from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.services.memory_config_service import MemoryConfigService - -try: - from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard -except Exception: - # 兜底:简单实现(必要时) - def f1_score(pred: str, ref: str) -> float: - ps = pred.lower().split() - rs = ref.lower().split() - if not ps or not rs: - return 0.0 - tp = len(set(ps) & set(rs)) - if tp == 0: - return 0.0 - precision = tp / len(ps) - recall = tp / len(rs) - if precision + recall == 0: - return 0.0 - return 2 * precision * recall / (precision + recall) - - def bleu1(pred: str, ref: str) -> float: - ps = pred.lower().split() - rs = ref.lower().split() - if not ps or not rs: - return 0.0 - overlap = len([w for w in ps if w in rs]) - return overlap / max(len(ps), 1) - - def jaccard(pred: str, ref: str) -> float: - ps = set(pred.lower().split()) - rs = set(ref.lower().split()) - union = len(ps | rs) - if union == 0: - return 0.0 - return len(ps & rs) / union - - -def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: - """基于问题关键词对上下文进行评分选择,并在预算内拼接文本。 - - 参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。 - """ - if not contexts: - return "" - question_lower = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but' - } - question_words = set(re.findall(r"\b\w+\b", question_lower)) - question_words = {w for w in question_words if w not in stop_words and len(w) > 2} - - scored = [] - for i, ctx in enumerate(contexts): - ctx_lower = (ctx or "").lower() - score = 0 - matches = 0 - for w in question_words: - if w in ctx_lower: - matches += 1 - score += ctx_lower.count(w) * 2 - length = len(ctx) - if 100 < length < 2000: - score += 5 - elif length >= 2000: - score += 2 - if i < 3: - score += 3 - scored.append((score, ctx, matches)) - - scored.sort(key=lambda x: x[0], reverse=True) - - selected: List[str] = [] - total = 0 - for score, ctx, _ in scored: - if total + len(ctx) <= max_chars: - selected.append(ctx) - total += len(ctx) - else: - if score > 10 and total < max_chars - 200: - remaining = max_chars - total - lines = ctx.split('\n') - rel_lines: List[str] = [] - cur = 0 - for line in lines: - l = line.lower() - if any(w in l for w in question_words) and cur < remaining - 50: - rel_lines.append(line) - cur += len(line) - if rel_lines: - truncated = '\n'.join(rel_lines) - if len(truncated) > 50: - selected.append(truncated + "\n[相关内容截断...]") - total += len(truncated) - break - return "\n\n".join(selected) - - -def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]: - """提取问题中的关键词(简单英文分词,去停用词,长度>=3)。""" - ql = (question or "").lower() - stop_words = { - 'what','when','where','who','why','how','did','do','does','is','are','was','were', - 'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this' - } - words = re.findall(r"\b[\w-]+\b", ql) - kws = [w for w in words if w not in stop_words and len(w) >= 3] - # 去重保序 - seen = set() - uniq = [] - for w in kws: - if w not in seen: - uniq.append(w) - seen.add(w) - if len(uniq) >= max_keywords: - break - return uniq - - -def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]: - """对上下文进行简单相关性打分,仅用于控制台可视化。 - - 评分: score = match_count*200 + min(len(text), 100000)/100 - """ - results = [] - for ctx in contexts: - tl = (ctx or "").lower() - match_count = sum(1 for k in keywords if k in tl) - length = len(ctx) - score = match_count * 200 + min(length, 100000) / 100.0 - results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length}) - results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True) - return results[:max(top_n, 0)] - - -# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py - - -def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]: - if not os.path.exists(data_path): - raise FileNotFoundError(f"未找到数据集: {data_path}") - items: List[Dict[str, Any]] = [] - with open(data_path, "r", encoding="utf-8") as f: - for line in f: - line = line.strip() - if not line: - continue - try: - items.append(json.loads(line)) - except Exception: - # 跳过坏行但不中断 - continue - return items - - -async def run_memsciqa_test( - sample_size: int = 3, - end_user_id: str | None = None, - search_limit: int = 8, - context_char_budget: int = 4000, - llm_temperature: float = 0.0, - llm_max_tokens: int = 64, - search_type: str = "embedding", - data_path: str | None = None, - start_index: int = 0, - verbose: bool = True, -) -> Dict[str, Any]: - """memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。 - - - 支持从指定索引开始与评估全部样本(sample_size<=0) - - 支持在摄入前重置组(清空图)与跳过摄入 - - 支持 keyword / embedding / hybrid 三种检索 - """ - - # 默认使用指定的 memsci 组 ID - end_user_id = end_user_id or "group_memsci" - - # 数据路径解析(项目根与当前工作目录兜底) - if not data_path: - proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") - cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") - if os.path.exists(proj_path): - data_path = proj_path - elif os.path.exists(cwd_path): - data_path = cwd_path - else: - raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。") - - # 加载数据 - all_items = load_dataset_memsciqa(data_path) - if sample_size is None or sample_size <= 0: - items = all_items[start_index:] - else: - items = all_items[start_index:start_index + sample_size] - - # 初始化 LLM(纯测试:不进行摄入) - with get_db_context() as db: - factory = MemoryClientFactory(db) - llm = factory.get_llm_client(SELECTED_LLM_ID) - - # 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test) - connector = Neo4jConnector() - embedder = None - if search_type in ("embedding", "hybrid"): - with get_db_context() as db: - config_service = MemoryConfigService(db) - cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID) - embedder = OpenAIEmbedderClient( - model_config=RedBearModelConfig.model_validate(cfg_dict) - ) - - # 评估循环 - latencies_llm: List[float] = [] - latencies_search: List[float] = [] - # 存储完整上下文文本用于统计 - contexts_used: List[str] = [] - per_query_context_chars: List[int] = [] - per_query_context_counts: List[int] = [] - correct_flags: List[float] = [] - f1s: List[float] = [] - b1s: List[float] = [] - jss: List[float] = [] - samples: List[Dict[str, Any]] = [] - - total_items = len(items) - for idx, item in enumerate(items): - if verbose: - print(f"\n🧪 评估样本: {idx+1}/{total_items}") - question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") - reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") - - # 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py) - t0 = time.time() - results = None - try: - if search_type in ("embedding", "hybrid"): - # 使用嵌入检索(与 qwen_search_eval 对齐) - results = await search_graph_by_embedding( - connector=connector, - embedder_client=embedder, - query_text=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues - ) - elif search_type == "keyword": - # 关键词检索(直接调用 graph_search) - results = await search_graph( - connector=connector, - q=question, - end_user_id=end_user_id, - limit=search_limit, - include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues - ) - except Exception: - results = None - t1 = time.time() - search_ms = (t1 - t0) * 1000 - latencies_search.append(search_ms) - - # 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py) - contexts_all: List[str] = [] - retrieved_counts: Dict[str, int] = {} - if results: - chunks = results.get("chunks", []) - statements = results.get("statements", []) - entities = results.get("entities", []) - summaries = results.get("summaries", []) - retrieved_counts = { - "chunks": len(chunks), - "statements": len(statements), - "entities": len(entities), - "summaries": len(summaries), - } - # 优先使用 chunks - for c in chunks: - text = str(c.get("content", "")).strip() - if text: - contexts_all.append(text) - # 然后是 statements - for s in statements: - text = str(s.get("statement", "")).strip() - if text: - contexts_all.append(text) - # 然后是 summaries - for sm in summaries: - text = str(sm.get("summary", "")).strip() - if text: - contexts_all.append(text) - # 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py) - scored = [e for e in entities if e.get("score") is not None] - top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] - if top_entities: - summary_lines = [] - for e in top_entities: - name = str(e.get("name", "")).strip() - etype = str(e.get("entity_type", "")).strip() - score = e.get("score") - if name: - meta = [] - if etype: - meta.append(f"type={etype}") - if isinstance(score, (int, float)): - meta.append(f"score={score:.3f}") - summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}") - if summary_lines: - contexts_all.append("\n".join(summary_lines)) - - if verbose: - if retrieved_counts: - print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要") - print(f"📊 有效上下文数量: {len(contexts_all)}") - q_keywords = extract_question_keywords(question, max_keywords=8) - if q_keywords: - print(f"🔍 问题关键词: {set(q_keywords)}") - if contexts_all: - analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5) - if analysis: - print("📊 上下文相关性分析:") - for a in analysis: - print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}") - # 打印检索到的上下文预览,便于定位为何为 Unknown - print("🔎 上下文预览(最多前10条,每条截断展示):") - for i, ctx in enumerate(contexts_all[:10]): - preview = str(ctx).replace("\n", " ") - if len(preview) > 300: - preview = preview[:300] + "..." - print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}") - # 标注参考答案是否出现在任一上下文中 - ref_lower = (str(reference) or "").lower() - if ref_lower: - hits = [] - for i, ctx in enumerate(contexts_all): - if ref_lower in str(ctx).lower(): - hits.append(i+1) - print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else "")) - - context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else "" - if not context_text: - context_text = "No relevant context found." - contexts_used.append(context_text) - per_query_context_chars.append(len(context_text)) - per_query_context_counts.append(len(contexts_all)) - - if verbose: - selected_count = (context_text.count("\n\n") + 1) if context_text else 0 - print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符") - # 展示拼接后的上下文片段,便于核查是否包含答案 - concat_preview = context_text.replace("\n", " ") - if len(concat_preview) > 600: - concat_preview = concat_preview[:600] + "..." - print(f"🧵 拼接上下文预览: {concat_preview}") - - messages = [ - { - "role": "system", - "content": ( - "You are a QA assistant. Answer in English. Follow these guidelines:\n" - "1) If the context contains information to answer the question, provide a concise answer based on the context;\n" - "2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n" - "3) Keep your answer brief and to the point;\n" - "4) Do not add explanations or additional text beyond the answer." - ), - }, - {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"}, - ] - - t2 = time.time() - try: - # 使用异步调用 - resp = await llm.chat(messages=messages) - # 更健壮的响应解析,处理不同的LLM响应格式 - if hasattr(resp, 'content'): - pred = resp.content.strip() - elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0: - pred = resp["choices"][0]["message"]["content"].strip() - elif isinstance(resp, dict) and "content" in resp: - pred = resp["content"].strip() - elif isinstance(resp, str): - pred = resp.strip() - else: - pred = "Unknown" - print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}") - - # 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案 - if pred.lower() in ["unknown", ""]: - # 如果参考答案在上下文中存在,但LLM返回Unknown,可能是提示词问题 - ref_lower = (str(reference) or "").lower() - if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all): - print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词") - except Exception as e: - # 更详细的错误处理 - pred = "Unknown" - print(f"⚠️ LLM调用异常: {e}") - t3 = time.time() - llm_ms = (t3 - t2) * 1000 - latencies_llm.append(llm_ms) - - exact = exact_match(pred, reference) - correct_flags.append(exact) - f1_val = f1_score(str(pred), str(reference)) - b1_val = bleu1(str(pred), str(reference)) - j_val = jaccard(str(pred), str(reference)) - f1s.append(f1_val) - b1s.append(b1_val) - jss.append(j_val) - - if verbose: - print(f"🤖 LLM 回答: {pred}") - print(f"✅ 正确答案: {reference}") - print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}") - print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms") - - # 对齐 locomo/qwen_search_eval.py 的样本输出结构 - samples.append({ - "question": str(question), - "answer": str(reference), - "prediction": str(pred), - "metrics": { - "f1": f1_val, - "b1": b1_val, - "j": j_val - }, - "retrieval": { - "retrieved_documents": len(contexts_all), - "context_length": len(context_text), - "search_limit": search_limit, - "max_chars": context_char_budget - }, - "timing": { - "search_ms": search_ms, - "llm_ms": llm_ms - } - }) - - # 计算总体指标与聚合 - acc = sum(correct_flags) / max(len(correct_flags), 1) - ctx_avg_tokens = avg_context_tokens(contexts_used) - result = { - "dataset": "memsciqa", - "items": len(items), - "metrics": { - "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0, - "b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0, - "j": (sum(jss) / max(len(jss), 1)) if jss else 0.0, - }, - "context": { - "avg_tokens": ctx_avg_tokens, - "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0, - "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0, - "avg_memory_tokens": 0.0 - }, - "latency": { - "search": latency_stats(latencies_search), - "llm": latency_stats(latencies_llm), - }, - "samples": samples, - "params": { - "end_user_id": end_user_id, - "search_limit": search_limit, - "context_char_budget": context_char_budget, - "llm_temperature": llm_temperature, - "llm_max_tokens": llm_max_tokens, - "search_type": search_type, - "start_index": start_index, - "llm_id": SELECTED_LLM_ID, - "retrieval_embedding_id": SELECTED_EMBEDDING_ID - }, - "timestamp": datetime.now().isoformat(), - } - try: - await connector.close() - except Exception: - pass - return result - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)") - parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)") - parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)") - parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") - parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)") - parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") - parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") - parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") - parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token") - parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)") - parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl)") - parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)") - parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)") - parser.add_argument("--quiet", action="store_true", help="关闭过程日志") - args = parser.parse_args() - - sample_size = 0 if args.all else args.sample_size - - verbose_flag = False if args.quiet else args.verbose - result = asyncio.run( - run_memsciqa_test( - sample_size=sample_size, - end_user_id=args.end_user_id, - search_limit=args.search_limit, - context_char_budget=args.context_char_budget, - llm_temperature=args.llm_temperature, - llm_max_tokens=args.llm_max_tokens, - search_type=args.search_type, - data_path=args.data_path, - start_index=args.start_index, - verbose=verbose_flag, - ) - ) - - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 结果保存 - out_path = args.output - if not out_path: - eval_dir = os.path.dirname(os.path.abspath(__file__)) - dataset_results_dir = os.path.join(eval_dir, "results") - ts = datetime.now().strftime("%Y%m%d_%H%M%S") - out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json") - try: - os.makedirs(os.path.dirname(out_path), exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n💾 结果已保存: {out_path}") - except Exception as e: - print(f"⚠️ 结果保存失败: {e}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/evaluation/run_eval.py b/api/app/core/memory/evaluation/run_eval.py deleted file mode 100644 index c5aacb2f..00000000 --- a/api/app/core/memory/evaluation/run_eval.py +++ /dev/null @@ -1,150 +0,0 @@ -import argparse -import asyncio -import json -import os -import sys -from typing import Any, Dict - -# Add src directory to Python path for proper imports when running from evaluation directory -sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src')) - -try: - from dotenv import load_dotenv -except Exception: - def load_dotenv(): - return None - -from app.repositories.neo4j.neo4j_connector import Neo4jConnector -from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT - -from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval -from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test -from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval - - -async def run( - dataset: str, - sample_size: int, - reset_group: bool, - end_user_id: str | None, - judge_model: str | None = None, - search_limit: int | None = None, - context_char_budget: int | None = None, - llm_temperature: float | None = None, - llm_max_tokens: int | None = None, - search_type: str | None = None, - start_index: int | None = None, - max_contexts_per_item: int | None = None, -) -> Dict[str, Any]: - # 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 - end_user_id = end_user_id or SELECTED_GROUP_ID - - if reset_group: - connector = Neo4jConnector() - try: - await connector.delete_group(end_user_id) - finally: - await connector.close() - - if dataset == "locomo": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - return await run_locomo_eval(**kwargs) - - if dataset == "memsciqa": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - return await run_memsciqa_eval(**kwargs) - - if dataset == "longmemeval": - kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id} - if search_limit is not None: - kwargs["search_limit"] = search_limit - if context_char_budget is not None: - kwargs["context_char_budget"] = context_char_budget - if llm_temperature is not None: - kwargs["llm_temperature"] = llm_temperature - if llm_max_tokens is not None: - kwargs["llm_max_tokens"] = llm_max_tokens - if search_type is not None: - kwargs["search_type"] = search_type - if start_index is not None: - kwargs["start_index"] = start_index - if max_contexts_per_item is not None: - kwargs["max_contexts_per_item"] = max_contexts_per_item - return await run_longmemeval_test(**kwargs) - raise ValueError(f"未知数据集: {dataset}") - - -def main(): - load_dotenv() - parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo") - parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True) - parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通") - parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据") - parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json") - parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名") - parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)") - parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)") - parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)") - parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)") - parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)") - # 仅透传到 longmemeval;其他数据集忽略 - parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)") - parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)") - parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation//results 目录") - args = parser.parse_args() - - result = asyncio.run(run( - args.dataset, - args.sample_size, - args.reset_group, - args.end_user_id, - args.judge_model, - args.search_limit, - args.context_char_budget, - args.llm_temperature, - args.llm_max_tokens, - args.search_type, - args.start_index, - args.max_contexts_per_item, - )) - print(json.dumps(result, ensure_ascii=False, indent=2)) - - # 结果输出逻辑保持不变 - if args.output: - out_path = args.output - else: - eval_dir = os.path.dirname(os.path.abspath(__file__)) - dataset_results_dir = os.path.join(eval_dir, args.dataset, "results") - out_filename = f"{args.dataset}_{args.sample_size}.json" - out_path = os.path.join(dataset_results_dir, out_filename) - - out_dir = os.path.dirname(out_path) - if out_dir and not os.path.exists(out_dir): - os.makedirs(out_dir, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - json.dump(result, f, ensure_ascii=False, indent=2) - print(f"\n结果已保存到: {out_path}") - - -if __name__ == "__main__": - main() diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index 8c69c7cf..7b7e854b 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1064,13 +1064,16 @@ class ExtractionOrchestrator: if statement.triplet_extraction_info: triplet_info = statement.triplet_extraction_info - # 创建实体索引到ID的映射 + # 创建实体索引到ID的映射(支持多种索引方式) entity_idx_to_id = {} # 创建实体节点 for entity_idx, entity in enumerate(triplet_info.entities): - # 映射实体索引到实体ID + # 映射实体索引到实体ID(使用多个键以提高容错性) + # 1. 使用实体自己的 entity_idx entity_idx_to_id[entity.entity_idx] = entity.id + # 2. 使用枚举索引(从0开始) + entity_idx_to_id[entity_idx] = entity.id if entity.id not in entity_id_set: entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') @@ -1149,9 +1152,18 @@ class ExtractionOrchestrator: relationship_result ) else: - logger.warning( - f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " - f"object_id={triplet.object_id}, statement_id={statement.id}" + # 改进的警告信息,包含更多调试信息 + missing_subject = "subject" if not subject_entity_id else "" + missing_object = "object" if not object_entity_id else "" + missing_both = " and " if (not subject_entity_id and not object_entity_id) else "" + + logger.debug( + f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: " + f"subject_id={triplet.subject_id} ({triplet.subject_name}), " + f"object_id={triplet.object_id} ({triplet.object_name}), " + f"predicate={triplet.predicate}, " + f"statement_id={statement.id}, " + f"available_indices={sorted(entity_idx_to_id.keys())}" ) logger.info( diff --git a/api/app/models/agent_app_config_model.py b/api/app/models/agent_app_config_model.py index 0a7a5935..96752c8e 100644 --- a/api/app/models/agent_app_config_model.py +++ b/api/app/models/agent_app_config_model.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship from app.base.type import PydanticType from app.db import Base -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters class AgentConfig(Base): diff --git a/api/app/models/multi_agent_model.py b/api/app/models/multi_agent_model.py index 544ddb27..400c05ad 100644 --- a/api/app/models/multi_agent_model.py +++ b/api/app/models/multi_agent_model.py @@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship from app.base.type import PydanticType from app.db import Base -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters class OrchestrationMode(StrEnum): diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index c0d72cdd..8fba2929 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -4,7 +4,7 @@ import datetime from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, ConfigDict, field_serializer -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters # ==================== 子 Agent 配置 ==================== diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 3971aab7..87fdb22c 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -5,7 +5,7 @@ import uuid from typing import Dict, Any, List, Optional, Tuple from sqlalchemy.orm import Session -from app.schemas import ModelParameters +from app.schemas.app_schema import ModelParameters from app.services.conversation_state_manager import ConversationStateManager from app.models import ModelConfig, AgentConfig from app.core.logging_config import get_business_logger diff --git a/api/app/utils/app_config_utils.py b/api/app/utils/app_config_utils.py index ae41d8bf..06549989 100644 --- a/api/app/utils/app_config_utils.py +++ b/api/app/utils/app_config_utils.py @@ -57,7 +57,7 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]: if data is None: return None - from app.schemas import ModelParameters + from app.schemas.app_schema import ModelParameters if isinstance(data, ModelParameters): return data diff --git a/api/migrations/versions/325b759cd66b_2026011240.py b/api/migrations/versions/325b759cd66b_2026011240.py index 763b0289..3d7443a8 100644 --- a/api/migrations/versions/325b759cd66b_2026011240.py +++ b/api/migrations/versions/325b759cd66b_2026011240.py @@ -31,6 +31,7 @@ def upgrade() -> None: op.execute("UPDATE memory_config SET config_id = apply_id::uuid") op.alter_column('memory_config', 'config_id', nullable=False) op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id']) + op.execute("ALTER TABLE memory_config ALTER COLUMN config_id_old DROP DEFAULT") op.execute("DROP SEQUENCE IF EXISTS data_config_config_id_seq") diff --git a/redbear-mem-benchmark b/redbear-mem-benchmark index d9a00be6..558c023d 160000 --- a/redbear-mem-benchmark +++ b/redbear-mem-benchmark @@ -1 +1 @@ -Subproject commit d9a00be62d974c0ad071c27e86f878b921c675b6 +Subproject commit 558c023dadb5327a05561b22d8fb363c6ee2be29 diff --git a/web/src/components/CustomSelect/index.tsx b/web/src/components/CustomSelect/index.tsx index 1887d635..6153a76d 100644 --- a/web/src/components/CustomSelect/index.tsx +++ b/web/src/components/CustomSelect/index.tsx @@ -15,7 +15,7 @@ interface ApiResponse { interface CustomSelectProps extends Omit { url: string; params?: Record; - valueKey?: string; + valueKey?: string | string[]; labelKey?: string; placeholder?: string; hasAll?: boolean; @@ -66,11 +66,18 @@ const CustomSelect: FC = ({ {...props} > {hasAll && {allTitle || t('common.all')}} - {displayOptions.map((option) => ( - - {String(option[labelKey])} - - ))} + {displayOptions.map((option) => { + const getValue = () => { + if (typeof valueKey === 'string') return option[valueKey]; + return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined; + }; + const value = getValue(); + return ( + + {String(option[labelKey])} + + ); + })} ); }; diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 77e90440..97a622d1 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[], placeholder={t('common.pleaseSelect')} url={url} hasAll={false} - valueKey='config_id' + valueKey={['config_id_old', 'config_id']} labelKey="config_name" /> @@ -126,12 +126,14 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config let allTools = Array.isArray(response.tools) ? response.tools : [] + const memoryContent = response.memory?.memory_content + const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent form.setFieldsValue({ ...response, tools: allTools, memory: { ...response.memory, - memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined + memory_content: convertedMemoryContent } }) setData({ diff --git a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx index abf56b18..70b17a11 100644 --- a/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/Knowledge/KnowledgeConfigModal.tsx @@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef { if (values?.retrieve_type) { const fieldsToReset = Object.keys(values).filter(key => - key !== 'kb_id' && key !== 'retrieve_type' + key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k' ) as (keyof KnowledgeConfigForm)[]; form.resetFields(fieldsToReset); } diff --git a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx index 77ca21a2..196ce8e3 100644 --- a/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx +++ b/web/src/views/Workflow/components/Properties/Knowledge/KnowledgeConfigModal.tsx @@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef { if (values?.retrieve_type) { const fieldsToReset = Object.keys(values).filter(key => - key !== 'kb_id' && key !== 'retrieve_type' + key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k' ) as (keyof KnowledgeConfigForm)[]; form.resetFields(fieldsToReset); } @@ -108,6 +108,7 @@ const KnowledgeConfigModal = forwardRef {/* Top K */} @@ -116,13 +117,12 @@ const KnowledgeConfigModal = forwardRef form.setFieldValue('top_k', value)} + // onChange={(value) => form.setFieldValue('top_k', value)} /> {/* 语义相似度阈值 similarity_threshold */} diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index e250e184..f528b9df 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [ config_id: { type: 'customSelect', url: memoryConfigListUrl, - valueKey: 'config_id', + valueKey: ['config_id_old', 'config_id'], labelKey: 'config_name' }, search_switch: { @@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [ config_id: { type: 'customSelect', url: memoryConfigListUrl, - valueKey: 'config_id', + valueKey: ['config_id_old', 'config_id'], labelKey: 'config_name' } } @@ -284,7 +284,7 @@ export const nodeLibrary: NodeLibrary[] = [ config: { input: { type: 'variableList', - filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'], + filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor'], filterVariableNames: ['message'] }, parallel: { diff --git a/web/src/views/Workflow/types.ts b/web/src/views/Workflow/types.ts index 909c30e4..31d1f512 100644 --- a/web/src/views/Workflow/types.ts +++ b/web/src/views/Workflow/types.ts @@ -14,7 +14,7 @@ export interface NodeConfig { url?: string; params?: { [key: string]: unknown; } - valueKey?: string; + valueKey?: string | string[]; labelKey?: string; defaultValue?: any;