Merge branch 'refs/heads/develop' into fix/memory_bug_fix

This commit is contained in:
lixinyue
2026-01-27 10:27:17 +08:00
29 changed files with 49 additions and 7419 deletions

View File

@@ -1 +0,0 @@
"""Evaluation package with dataset-specific pipelines and a unified runner."""

View File

@@ -1,30 +0,0 @@
⏬数据集下载地址:
Locomo10.jsonhttps://github.com/snap-research/locomo/tree/main/data
LongMemEval_oracle.jsonhttps://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
上方数据集下载好后全部放入app/core/memory/data文件夹中
全流程基准测试运行:
locomo
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
LongMemEval
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
memsciqa
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
单独检索评估运行命令:
python -m app.core.memory.evaluation.locomo.locomo_test
python -m app.core.memory.evaluation.longmemeval.test_eval
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
需要先在项目中修改需要检测评估的group_id。
参数及解释:
● --dataset longmemeval - 指定数据集
● --sample-size 10 - 评估10个样本
● --start-index 0 - 从第0个样本开始
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
● --search-limit 8 - 检索限制8条
● --context-char-budget 4000 - 上下文字符预算4000
● --search-type hybrid - 使用混合检索
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
● --reset-group - 运行前清空组数据

View File

@@ -1,100 +0,0 @@
import math
import re
from typing import List, Dict
def _normalize(text: str) -> List[str]:
"""Lowercase, strip punctuation, and split into tokens."""
text = text.lower().strip()
# Python's re doesn't support \p classes; use a simple non-word filter
text = re.sub(r"[^\w\s]", " ", text)
tokens = [t for t in text.split() if t]
return tokens
def exact_match(pred: str, ref: str) -> float:
return float(_normalize(pred) == _normalize(ref))
def jaccard(pred: str, ref: str) -> float:
p = set(_normalize(pred))
r = set(_normalize(ref))
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
def f1_score(pred: str, ref: str) -> float:
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens:
return 0.0
# Clipped count
r_counts: Dict[str, int] = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts: Dict[str, int] = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
# Brevity penalty
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def percentile(values: List[float], p: float) -> float:
if not values:
return 0.0
vals = sorted(values)
k = (len(vals) - 1) * p
f = math.floor(k)
c = math.ceil(k)
if f == c:
return vals[int(k)]
return vals[f] + (k - f) * (vals[c] - vals[f])
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
if not latencies_ms:
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
p25 = percentile(latencies_ms, 0.25)
p50 = percentile(latencies_ms, 0.50)
p75 = percentile(latencies_ms, 0.75)
p95 = percentile(latencies_ms, 0.95)
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
def avg_context_tokens(contexts: List[str]) -> float:
if not contexts:
return 0.0
return sum(len(_normalize(c)) for c in contexts) / len(contexts)

View File

@@ -1,60 +0,0 @@
"""
Dialogue search queries for evaluation purposes.
This file contains Cypher queries for searching dialogues, entities, and chunks.
Placed in evaluation directory to avoid circular imports with src modules.
"""
# Entity search queries
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:Entity)
WHERE e.name = $name
RETURN e
"""
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:Entity)
WHERE e.name CONTAINS $name
RETURN e
"""
# Chunk search queries
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""
# Dialogue search queries
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -1,341 +0,0 @@
import asyncio
import json
import os
import re
from datetime import datetime
from typing import Any, Dict, List, Optional
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.models.message_models import (
ConversationContext,
ConversationMessage,
DialogData,
)
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
end_user_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
save_chunk_output_path: str | None = None,
) -> bool:
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
This function mirrors the steps in main(), but starts from raw text contexts.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
Returns:
True if data saved successfully, False otherwise.
"""
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
# Initialize llm client with graceful fallback
llm_client = None
llm_available = True
try:
from app.core.memory.utils.config import definitions as config_defs
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
llm_available = False
# Step A: Build DialogData list from contexts with robust parsing
chunker = DialogueChunker(chunker_strategy)
dialog_data_list: List[DialogData] = []
for idx, ctx in enumerate(contexts):
messages: List[ConversationMessage] = []
# Improved parsing: capture multi-line message blocks, normalize roles
pattern = r"^\s*(用户|AI|assistant|user)\s*[:]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[:]|\Z)"
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
if matches:
for m in matches:
raw_role = m.group(1).strip()
content = m.group(2).strip()
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=content))
else:
# Fallback: line-by-line parsing
for raw in ctx.split("\n"):
line = raw.strip()
if not line:
continue
m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)$', line)
if m:
role = m.group(1).strip()
msg = m.group(2).strip()
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=msg))
else:
# Final fallback: treat as user message
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
messages.append(ConversationMessage(role=default_role, msg=line))
context_model = ConversationContext(msgs=messages)
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
end_user_id=end_user_id,
user_id="default_user",
apply_id="default_application",
)
# Generate chunks
dialog.chunks = await chunker.process_dialogue(dialog)
dialog_data_list.append(dialog)
if not dialog_data_list:
print("No dialogs to process for ingestion.")
return False
# Optionally save chunking outputs for debugging
if save_chunk_output:
try:
def _serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
from app.core.config import settings
settings.ensure_memory_output_dir()
default_path = settings.get_memory_output_path("chunker_test_output.txt")
out_path = save_chunk_output_path or default_path
combined_output = [dd.model_dump() for dd in dialog_data_list]
with open(out_path, "w", encoding="utf-8") as f:
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
print(f"Saved chunking results to: {out_path}")
except Exception as e:
print(f"Failed to save chunking results: {e}")
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
if not llm_available:
print("[Ingestion] Skipping extraction pipeline (no LLM).")
return False
# 初始化 embedder 客户端
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService
try:
with get_db_context() as db:
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
except Exception as e:
print(f"[Ingestion] Failed to initialize embedder client: {e}")
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
return False
connector = Neo4jConnector()
# 初始化并运行 ExtractionOrchestrator
from app.core.memory.utils.config.config_utils import get_pipeline_config
config = get_pipeline_config()
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=connector,
config=config,
)
# 创建一个包装的 orchestrator 来修复时间提取器的输出
# 保存原始的 _assign_extracted_data 方法
original_assign = orchestrator._assign_extracted_data
def clean_temporal_value(value):
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
if value is None:
return None
if isinstance(value, str):
# 处理字符串形式的 'null', 'None', 空字符串等
if value.lower() in ('null', 'none', '') or value.strip() == '':
return None
return value
async def patched_assign_extracted_data(*args, **kwargs):
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
result = await original_assign(*args, **kwargs)
# 清理返回的 dialog_data_list 中的 temporal_validity
for dialog in result:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
# 清理 valid_at 和 invalid_at
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return result
# 替换方法
orchestrator._assign_extracted_data = patched_assign_extracted_data
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
original_create = orchestrator._create_nodes_and_edges
async def patched_create_nodes_and_edges(dialog_data_list_arg):
"""包装方法:在创建节点前再次清理 temporal_validity"""
# 最后一次清理,确保万无一失
for dialog in dialog_data_list_arg:
if hasattr(dialog, 'chunks') and dialog.chunks:
for chunk in dialog.chunks:
if hasattr(chunk, 'statements') and chunk.statements:
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return await original_create(dialog_data_list_arg)
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
# 运行完整的提取流水线
# orchestrator.run 返回 7 个元素的元组
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
) = result
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
# Step G: 生成记忆摘要
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
summaries = []
# Step H: Save to Neo4j
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
entity_edges=entity_entity_edges,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
connector=connector
)
# Save memory summaries separately
if summaries:
try:
await add_memory_summary_nodes(summaries, connector)
await add_memory_summary_statement_edges(summaries, connector)
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
except Exception as e:
print(f"Warning: Failed to save summary nodes: {e}")
await connector.close()
if success:
print("Successfully saved extracted data to Neo4j!")
else:
print("Failed to save data to Neo4j")
return success
except Exception as e:
print(f"Failed to save data to Neo4j: {e}")
return False
async def handle_context_processing(args):
"""Handle context-based processing from command line arguments."""
contexts = []
if args.contexts:
contexts.extend(args.contexts)
if args.context_file:
try:
with open(args.context_file, 'r', encoding='utf-8') as f:
contexts.extend(line.strip() for line in f if line.strip())
except Exception as e:
print(f"Error reading context file: {e}")
return False
if not contexts:
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True
)
if success:
print("Successfully processed and saved contexts to Neo4j!")
else:
print("Failed to process contexts.")
return success

View File

@@ -1,575 +0,0 @@
"""
LoCoMo Benchmark Script
This module provides the main entry point for running LoCoMo benchmark evaluations.
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
in a clean, maintainable way.
Usage:
python locomo_benchmark.py --sample_size 20 --search_type hybrid
"""
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
try:
from dotenv import load_dotenv
except ImportError:
def load_dotenv():
pass
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
f1_score,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
get_category_name,
locomo_f1_score,
locomo_multi_f1,
)
from app.core.memory.evaluation.locomo.locomo_utils import (
extract_conversations,
ingest_conversations_if_needed,
load_locomo_data,
resolve_temporal_references,
retrieve_relevant_information,
select_and_format_information,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_end_user_id,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
async def run_locomo_benchmark(
sample_size: int = 20,
end_user_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
reset_group: bool = False,
skip_ingest: bool = False,
output_dir: Optional[str] = None
) -> Dict[str, Any]:
"""
Run LoCoMo benchmark evaluation.
This function orchestrates the complete evaluation pipeline:
1. Load LoCoMo dataset (only QA pairs from first conversation)
2. Check/ingest conversations into database (only first conversation, unless skip_ingest=True)
3. For each question:
- Retrieve relevant information
- Generate answer using LLM
- Calculate metrics
4. Aggregate results and save to file
Note: By default, only the first conversation is ingested into the database,
and only QA pairs from that conversation are evaluated. This ensures that
all questions have corresponding memory in the database for retrieval.
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
end_user_id: Database group ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
reset_group: Whether to clear and re-ingest data (not implemented)
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
output_dir: Directory to save results (uses default if None)
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default end_user_id if not provided
end_user_id = end_user_id or SELECTED_end_user_id
# Determine data path
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
# Fallback to current directory
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
print(f"\n{'='*60}")
print("🚀 Starting LoCoMo Benchmark Evaluation")
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" Group ID: {end_user_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
print(f" Data path: {data_path}")
print(f"{'='*60}\n")
# Step 1: Load LoCoMo data
print("📂 Loading LoCoMo dataset...")
try:
# Only load QA pairs from the first conversation (index 0)
# since we only ingest the first conversation into the database
qa_items = load_locomo_data(data_path, sample_size, conversation_index=0)
print(f"✅ Loaded {len(qa_items)} QA pairs from conversation 0\n")
except Exception as e:
print(f"❌ Failed to load data: {e}")
return {
"error": f"Data loading failed: {e}",
"timestamp": datetime.now().isoformat()
}
# Step 2: Extract conversations and ingest if needed
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" Group ID: {end_user_id}\n")
else:
print("💾 Checking database ingestion...")
try:
conversations = extract_conversations(data_path, max_dialogues=1)
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into group '{end_user_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
end_user_id=end_user_id,
reset=reset_group
)
if success:
print("✅ Ingestion completed successfully\n")
else:
print("⚠️ Ingestion may have failed, continuing anyway\n")
except Exception as e:
print(f"❌ Ingestion failed: {e}")
print("⚠️ Continuing with evaluation (database may be empty)\n")
# Step 3: Initialize clients
print("🔧 Initializing clients...")
connector = Neo4jConnector()
# Initialize LLM client with database context
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Initialize embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
print("✅ Clients initialized\n")
# Step 4: Process questions
print(f"🔍 Processing {len(qa_items)} questions...")
print(f"{'='*60}\n")
# Tracking variables
latencies_search: List[float] = []
latencies_llm: List[float] = []
context_counts: List[int] = []
context_chars: List[int] = []
context_tokens: List[int] = []
# Metric lists
f1_scores: List[float] = []
bleu1_scores: List[float] = []
jaccard_scores: List[float] = []
locomo_f1_scores: List[float] = []
# Per-category tracking
category_counts: Dict[str, int] = {}
category_f1: Dict[str, List[float]] = {}
category_bleu1: Dict[str, List[float]] = {}
category_jaccard: Dict[str, List[float]] = {}
category_locomo_f1: Dict[str, List[float]] = {}
# Detailed samples
samples: List[Dict[str, Any]] = []
# Fixed anchor date for temporal resolution
anchor_date = datetime(2023, 5, 8)
try:
for idx, item in enumerate(qa_items, 1):
question = item.get("question", "")
ground_truth = item.get("answer", "")
category = get_category_name(item)
# Ensure ground truth is a string
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
print(f"[{idx}/{len(qa_items)}] Category: {category}")
print(f"❓ Question: {question}")
print(f"✅ Ground Truth: {ground_truth_str}")
# Step 4a: Retrieve relevant information
t_search_start = time.time()
try:
retrieved_info = await retrieve_relevant_information(
question=question,
end_user_id=end_user_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
embedder=embedder
)
t_search_end = time.time()
search_latency = (t_search_end - t_search_start) * 1000
latencies_search.append(search_latency)
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
except Exception as e:
print(f"❌ Retrieval failed: {e}")
retrieved_info = []
search_latency = 0.0
latencies_search.append(search_latency)
# Step 4b: Select and format context
context_text = select_and_format_information(
retrieved_info=retrieved_info,
question=question,
max_chars=context_char_budget
)
# Resolve temporal references
context_text = resolve_temporal_references(context_text, anchor_date)
# Add reference date to context
if context_text:
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
else:
context_text = "No relevant context found."
# Track context statistics
context_counts.append(len(retrieved_info))
context_chars.append(len(context_text))
context_tokens.append(len(context_text.split()))
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
# Step 4c: Generate answer with LLM
messages = [
{
"role": "system",
"content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)
},
{
"role": "user",
"content": f"Question: {question}\n\nContext:\n{context_text}"
}
]
t_llm_start = time.time()
try:
response = await llm_client.chat(messages=messages)
t_llm_end = time.time()
llm_latency = (t_llm_end - t_llm_start) * 1000
latencies_llm.append(llm_latency)
# Extract prediction from response
if hasattr(response, 'content'):
prediction = response.content.strip()
elif isinstance(response, dict):
prediction = response["choices"][0]["message"]["content"].strip()
else:
prediction = "Unknown"
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
except Exception as e:
print(f"❌ LLM failed: {e}")
prediction = "Unknown"
llm_latency = 0.0
latencies_llm.append(llm_latency)
# Step 4d: Calculate metrics
f1_val = f1_score(prediction, ground_truth_str)
bleu1_val = bleu1(prediction, ground_truth_str)
jaccard_val = jaccard(prediction, ground_truth_str)
# LoCoMo-specific F1: use multi-answer for category 1 (Multi-Hop)
if item.get("category") == 1:
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
else:
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
# Accumulate metrics
f1_scores.append(f1_val)
bleu1_scores.append(bleu1_val)
jaccard_scores.append(jaccard_val)
locomo_f1_scores.append(locomo_f1_val)
# Track by category
category_counts[category] = category_counts.get(category, 0) + 1
category_f1.setdefault(category, []).append(f1_val)
category_bleu1.setdefault(category, []).append(bleu1_val)
category_jaccard.setdefault(category, []).append(jaccard_val)
category_locomo_f1.setdefault(category, []).append(locomo_f1_val)
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
print()
# Save sample details
samples.append({
"question": question,
"ground_truth": ground_truth_str,
"prediction": prediction,
"category": category,
"metrics": {
"f1": f1_val,
"bleu1": bleu1_val,
"jaccard": jaccard_val,
"locomo_f1": locomo_f1_val
},
"retrieval": {
"num_docs": len(retrieved_info),
"context_length": len(context_text)
},
"timing": {
"search_ms": search_latency,
"llm_ms": llm_latency
}
})
finally:
# Close connector
await connector.close()
# Step 5: Aggregate results
print(f"\n{'='*60}")
print("📊 Aggregating Results")
print(f"{'='*60}\n")
# Overall metrics
overall_metrics = {
"f1": sum(f1_scores) / max(len(f1_scores), 1) if f1_scores else 0.0,
"bleu1": sum(bleu1_scores) / max(len(bleu1_scores), 1) if bleu1_scores else 0.0,
"jaccard": sum(jaccard_scores) / max(len(jaccard_scores), 1) if jaccard_scores else 0.0,
"locomo_f1": sum(locomo_f1_scores) / max(len(locomo_f1_scores), 1) if locomo_f1_scores else 0.0
}
# Per-category metrics
by_category: Dict[str, Dict[str, Any]] = {}
for cat in category_counts:
f1_list = category_f1.get(cat, [])
b1_list = category_bleu1.get(cat, [])
j_list = category_jaccard.get(cat, [])
lf_list = category_locomo_f1.get(cat, [])
by_category[cat] = {
"count": category_counts[cat],
"f1": sum(f1_list) / max(len(f1_list), 1) if f1_list else 0.0,
"bleu1": sum(b1_list) / max(len(b1_list), 1) if b1_list else 0.0,
"jaccard": sum(j_list) / max(len(j_list), 1) if j_list else 0.0,
"locomo_f1": sum(lf_list) / max(len(lf_list), 1) if lf_list else 0.0
}
# Latency statistics
latency = {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm)
}
# Context statistics
context_stats = {
"avg_retrieved_docs": sum(context_counts) / max(len(context_counts), 1) if context_counts else 0.0,
"avg_context_chars": sum(context_chars) / max(len(context_chars), 1) if context_chars else 0.0,
"avg_context_tokens": sum(context_tokens) / max(len(context_tokens), 1) if context_tokens else 0.0
}
# Build result dictionary
result = {
"dataset": "locomo",
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"end_user_id": end_user_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_id": SELECTED_LLM_ID,
"embedding_id": SELECTED_EMBEDDING_ID
},
"overall_metrics": overall_metrics,
"by_category": by_category,
"latency": latency,
"context_stats": context_stats,
"samples": samples
}
# Step 6: Save results
if output_dir is None:
output_dir = os.path.join(
os.path.dirname(__file__),
"results"
)
os.makedirs(output_dir, exist_ok=True)
# Generate timestamped filename
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(output_dir, f"locomo_{timestamp_str}.json")
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ Results saved to: {output_path}\n")
except Exception as e:
print(f"❌ Failed to save results: {e}")
print("📊 Printing results to console instead:\n")
print(json.dumps(result, ensure_ascii=False, indent=2))
return result
def main():
"""
Parse command-line arguments and run benchmark.
This function provides a CLI interface for running LoCoMo benchmarks
with configurable parameters.
"""
parser = argparse.ArgumentParser(
description="Run LoCoMo benchmark evaluation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sample_size",
type=int,
default=20,
help="Number of QA pairs to evaluate"
)
parser.add_argument(
"--end_user_id",
type=str,
default=None,
help="Database group ID for retrieval (uses default if not specified)"
)
parser.add_argument(
"--search_type",
type=str,
default="hybrid",
choices=["keyword", "embedding", "hybrid"],
help="Search strategy to use"
)
parser.add_argument(
"--search_limit",
type=int,
default=12,
help="Maximum number of documents to retrieve per query"
)
parser.add_argument(
"--context_char_budget",
type=int,
default=8000,
help="Maximum characters for context"
)
parser.add_argument(
"--reset_group",
action="store_true",
help="Clear and re-ingest data (not implemented)"
)
parser.add_argument(
"--skip_ingest",
action="store_true",
help="Skip data ingestion and use existing data in Neo4j"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Directory to save results (uses default if not specified)"
)
args = parser.parse_args()
# Load environment variables
load_dotenv()
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
reset_group=args.reset_group,
skip_ingest=args.skip_ingest,
output_dir=args.output_dir
))
# Print summary
print(f"\n{'='*60}")
# Check if there was an error
if 'error' in result:
print("❌ Benchmark Failed!")
print(f"{'='*60}")
print(f"Error: {result['error']}")
return
print("🎉 Benchmark Complete!")
print(f"{'='*60}")
print("📊 Final Results:")
print(f" Sample size: {result.get('sample_size', 0)}")
print(f" F1: {result['overall_metrics']['f1']:.3f}")
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
if result.get('context_stats'):
print("\n📈 Context Statistics:")
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
if result.get('latency'):
print("\n⏱️ Latency Statistics:")
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
f"P50: {result['latency']['search']['p50']:.1f}ms, "
f"P95: {result['latency']['search']['p95']:.1f}ms")
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
f"P95: {result['latency']['llm']['p95']:.1f}ms")
if result.get('by_category'):
print("\n📂 Results by Category:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" Count: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
print(f" Jaccard: {metrics['jaccard']:.3f}")
print(f"\n{'='*60}\n")
if __name__ == "__main__":
main()

View File

@@ -1,225 +0,0 @@
"""
LoCoMo-specific metric calculations.
This module provides clean, simplified implementations of metrics used for
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
"""
import re
from typing import Dict, Any
def normalize_text(text: str) -> str:
"""
Normalize text for LoCoMo evaluation.
Normalization steps:
- Convert to lowercase
- Remove commas
- Remove stop words (a, an, the, and)
- Remove punctuation
- Normalize whitespace
Args:
text: Input text to normalize
Returns:
Normalized text string with consistent formatting
Examples:
>>> normalize_text("The cat, and the dog")
'cat dog'
>>> normalize_text("Hello, World!")
'hello world'
"""
# Ensure input is a string
text = str(text) if text is not None else ""
# Convert to lowercase
text = text.lower()
# Remove commas
text = re.sub(r"[\,]", " ", text)
# Remove stop words
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
# Remove punctuation (keep only word characters and whitespace)
text = re.sub(r"[^\w\s]", " ", text)
# Normalize whitespace (collapse multiple spaces to single space)
text = " ".join(text.split())
return text
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for single-answer questions.
Uses token-level precision and recall based on normalized text.
Treats tokens as sets (no duplicate counting).
Args:
prediction: Model's predicted answer
ground_truth: Correct answer
Returns:
F1 score between 0.0 and 1.0
Examples:
>>> locomo_f1_score("Paris", "Paris")
1.0
>>> locomo_f1_score("The cat", "cat")
1.0
>>> locomo_f1_score("dog", "cat")
0.0
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Normalize and tokenize
pred_tokens = normalize_text(pred_str).split()
truth_tokens = normalize_text(truth_str).split()
# Handle empty cases
if not pred_tokens or not truth_tokens:
return 0.0
# Convert to sets for comparison
pred_set = set(pred_tokens)
truth_set = set(truth_tokens)
# Calculate true positives (intersection)
true_positives = len(pred_set & truth_set)
# Calculate precision and recall
precision = true_positives / len(pred_set) if pred_set else 0.0
recall = true_positives / len(truth_set) if truth_set else 0.0
# Calculate F1 score
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for multi-answer questions.
Handles comma-separated answers by:
1. Splitting both prediction and ground truth by commas
2. For each ground truth answer, finding the best matching prediction
3. Averaging the F1 scores across all ground truth answers
Args:
prediction: Model's predicted answer (may contain multiple comma-separated answers)
ground_truth: Correct answer (may contain multiple comma-separated answers)
Returns:
Average F1 score across all ground truth answers (0.0 to 1.0)
Examples:
>>> locomo_multi_f1("Paris, London", "Paris, London")
1.0
>>> locomo_multi_f1("Paris", "Paris, London")
0.5
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
0.5
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Split by commas and strip whitespace
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
# Handle empty cases
if not predictions or not ground_truths:
return 0.0
# For each ground truth, find the best matching prediction
f1_scores = []
for gt in ground_truths:
# Calculate F1 with each prediction and take the maximum
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
f1_scores.append(best_f1)
# Return average F1 across all ground truths
return sum(f1_scores) / len(f1_scores)
def get_category_name(item: Dict[str, Any]) -> str:
"""
Extract and normalize category name from QA item.
Handles both numeric categories (1-4) and string categories with various formats.
Supports multiple field names: "cat", "category", "type".
Category mapping:
- 1 or "multi-hop" -> "Multi-Hop"
- 2 or "temporal" -> "Temporal"
- 3 or "open domain" -> "Open Domain"
- 4 or "single-hop" -> "Single-Hop"
Args:
item: QA item dictionary containing category information
Returns:
Standardized category name or "unknown" if not found
Examples:
>>> get_category_name({"category": 1})
'Multi-Hop'
>>> get_category_name({"cat": "temporal"})
'Temporal'
>>> get_category_name({"type": "Single-Hop"})
'Single-Hop'
"""
# Numeric category mapping
CATEGORY_MAP = {
1: "Multi-Hop",
2: "Temporal",
3: "Open Domain",
4: "Single-Hop",
}
# String category aliases (case-insensitive)
TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
# Try "cat" field first (string category)
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return TYPE_ALIASES.get(lower, name)
# Try "category" field (can be int or string)
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP.get(cat_num, "unknown")
elif isinstance(cat_num, str) and cat_num.strip():
lower = cat_num.strip().lower()
return TYPE_ALIASES.get(lower, cat_num.strip())
# Try "type" field as fallback
cat_type = item.get("type")
if isinstance(cat_type, str) and cat_type.strip():
lower = cat_type.strip().lower()
return TYPE_ALIASES.get(lower, cat_type.strip())
return "unknown"

View File

@@ -1,811 +0,0 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import json
import math
import os
import re
import sys
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
from pathlib import Path
from dotenv import load_dotenv
# 1
# 添加项目根目录到路径
current_dir = Path(__file__).resolve().parent
project_root = str(current_dir.parent)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
load_dotenv()
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
def _loc_normalize(text: str) -> str:
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text)
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 尝试从 metrics.py 导入基础指标
try:
from common.metrics import bleu1, f1_score, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
# 回退到本地实现
def f1_score(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens:
return 0.0
r_counts = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def jaccard(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p = set(_loc_normalize(pred_str).split())
r = set(_loc_normalize(ref_str).split())
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
try:
# 添加 evaluation 目录路径
evaluation_dir = os.path.join(project_root, "evaluation")
if evaluation_dir not in sys.path:
sys.path.insert(0, evaluation_dir)
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
# 回退到本地实现 LoCoMo 特定函数
def _resolve_relative_times(text: str, anchor: datetime) -> str:
t = str(text) if text is not None else ""
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
p_tokens = _loc_normalize(prediction).split()
g_tokens = _loc_normalize(ground_truth).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
"""根据问题复杂度和进度动态调整检索参数"""
# 分析问题复杂度
word_count = len(question.split())
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
# 根据进度调整 - 后期问题可能需要更精确的检索
progress_factor = question_index / total_questions
base_limit = 12
if has_temporal and has_multi_hop:
base_limit = 20
elif word_count > 8:
base_limit = 16
# 随着测试进行,逐渐收紧检索范围
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
# 动态调整最大字符数
max_chars = 8000 + 4000 * (1 - progress_factor)
return {
"limit": adjusted_limit,
"max_chars": int(max_chars)
}
class EnhancedEvaluationMonitor:
def __init__(self, reset_interval=5, performance_threshold=0.6):
self.question_count = 0
self.reset_interval = reset_interval
self.performance_threshold = performance_threshold
self.consecutive_low_scores = 0
self.performance_history = []
self.recent_f1_scores = []
def should_reset_connections(self, current_f1=None):
"""基于计数和性能双重判断"""
# 定期重置
if self.question_count % self.reset_interval == 0:
return True
# 性能驱动的重置
if current_f1 is not None and current_f1 < self.performance_threshold:
self.consecutive_low_scores += 1
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
print("🚨 连续低分,触发紧急重置")
self.consecutive_low_scores = 0
return True
else:
self.consecutive_low_scores = 0
return False
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
"""记录性能指标,检测衰减"""
self.performance_history.append({
'index': question_index,
'metrics': metrics,
'context_length': context_length,
'retrieved_docs': retrieved_docs,
'timestamp': time.time()
})
# 记录最近的F1分数
self.recent_f1_scores.append(metrics['f1'])
if len(self.recent_f1_scores) > 5:
self.recent_f1_scores.pop(0)
def get_recent_performance(self):
"""获取近期平均性能"""
if not self.recent_f1_scores:
return 0.5
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
def get_performance_trend(self):
"""分析性能趋势"""
if len(self.performance_history) < 2:
return "stable"
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
return "stable"
recent_avg = sum(recent_metrics) / len(recent_metrics)
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
if recent_avg < earlier_avg * 0.8:
return "degrading"
elif recent_avg > earlier_avg * 1.1:
return "improving"
else:
return "stable"
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
"""基于问题复杂度和近期性能动态调整检索参数"""
# 基础参数
base_params = get_dynamic_search_params(question, question_index, total_questions)
# 性能自适应调整
if recent_performance < 0.5: # 近期表现差
# 增加检索范围,尝试获取更多上下文
base_params["limit"] = min(base_params["limit"] + 5, 25)
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
elif recent_performance > 0.8: # 近期表现好
# 收紧检索,提高精度
base_params["limit"] = max(base_params["limit"] - 2, 8)
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
# 中间阶段特殊处理
mid_sequence_factor = abs(question_index / total_questions - 0.5)
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用更精确的检索策略")
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
return base_params
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
"""考虑问题序列位置的智能选择"""
if not contexts:
return ""
# 在序列中间阶段使用更严格的筛选
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用严格上下文筛选")
# 提取问题关键词
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
# 只保留高度相关的上下文
filtered_contexts = []
for context in contexts:
context_lower = context.lower()
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
if any(char.isdigit() for char in context):
relevance_score += 2
# 提高阈值:只有得分>=3的上下文才保留
if relevance_score >= 3:
filtered_contexts.append(context)
else:
print(f" - 过滤低分上下文: 得分={relevance_score}")
contexts = filtered_contexts
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
# 使用原有的智能选择逻辑
return smart_context_selection(contexts, question, max_chars)
async def run_enhanced_evaluation():
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 修正导入路径:使用 app.core.memory.src 前缀
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 加载数据
# 获取项目根目录
current_file = os.path.abspath(__file__)
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
memory_dir = os.path.dirname(evaluation_dir) # memory目录
data_path = os.path.join(memory_dir, "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
qa_items = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items = qa_items[:20] # 测试多少个问题
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 初始化连接器
connector = Neo4jConnector()
# 初始化结果字典
results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
total_f1 = 0.0
total_bleu1 = 0.0
total_jaccard = 0.0
total_loc_f1 = 0.0
total_context_length = 0
total_retrieved_docs = 0
category_stats = {}
try:
for i, item in enumerate(items):
monitor.question_count += 1
# 获取近期性能用于重置判断
recent_performance = monitor.get_recent_performance()
# 增强的重置判断
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
if should_reset and i > 0:
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
await connector.close()
connector = Neo4jConnector() # 创建新连接
print("✅ 连接重置完成")
q = item.get("question", "")
ref = item.get("answer", "")
ref_str = str(ref) if ref is not None else ""
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
print(f"✅ 真实答案: {ref_str}")
# 分类别统计
category = "Unknown"
if item.get("category") == 1:
category = "Multi-Hop"
elif item.get("category") == 2:
category = "Temporal"
elif item.get("category") == 3:
category = "Open Domain"
elif item.get("category") == 4:
category = "Single-Hop"
# 增强的检索参数
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
search_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
# 使用项目标准的混合检索方法
t0 = time.time()
contexts_all = []
try:
# 使用统一的搜索服务
from app.core.memory.storage_services.search import run_hybrid_search
print("🔀 使用混合搜索服务...")
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
end_user_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重
embedding_id=SELECTED_EMBEDDING_ID
)
# 处理搜索结果 - 新的搜索服务返回统一的结构
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
print(f"📊 有效上下文数量: {len(contexts_all)}")
except Exception as e:
print(f"❌ 检索失败: {e}")
contexts_all = []
t1 = time.time()
search_time = (t1 - t0) * 1000
# 增强的上下文选择
context_text = ""
if contexts_all:
# 使用增强的上下文选择
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览(不只是第一条)
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
# 🔍 调试:检查答案是否在上下文中
if ref_str and ref_str.strip():
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# LLM 回答
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
except Exception as e:
print(f"❌ LLM 生成失败: {e}")
pred = "Unknown"
t3 = time.time()
llm_time = (t3 - t2) * 1000
# 计算指标 - 使用导入的指标函数
f1_val = f1_score(pred, ref_str)
bleu1_val = bleu1(pred, ref_str)
jaccard_val = jaccard(pred, ref_str)
loc_f1_val = loc_f1_score(pred, ref_str)
print(f"🤖 LLM 回答: {pred}")
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
# 更新统计
total_f1 += f1_val
total_bleu1 += bleu1_val
total_jaccard += jaccard_val
total_loc_f1 += loc_f1_val
total_context_length += len(context_text)
total_retrieved_docs += len(contexts_all)
if category not in category_stats:
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
category_stats[category]["count"] += 1
category_stats[category]["f1_sum"] += f1_val
category_stats[category]["b1_sum"] += bleu1_val
category_stats[category]["j_sum"] += jaccard_val
category_stats[category]["loc_f1_sum"] += loc_f1_val
# 记录性能指标
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
# 保存结果
question_result = {
"question": q,
"ground_truth": ref_str,
"prediction": pred,
"category": category,
"metrics": metrics,
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": max_chars,
"recent_performance": recent_performance
},
"timing": {
"search_ms": search_time,
"llm_ms": llm_time
}
}
results["questions"].append(question_result)
print("="*60)
except Exception as e:
print(f"❌ 评估过程中发生错误: {e}")
# 即使出错,也返回已有的结果
import traceback
traceback.print_exc()
finally:
await connector.close()
# 计算总体指标
n = len(items)
if n > 0:
results["overall_metrics"] = {
"f1": total_f1 / n,
"b1": total_bleu1 / n,
"j": total_jaccard / n,
"loc_f1": total_loc_f1 / n
}
for category, stats in category_stats.items():
count = stats["count"]
results["category_metrics"][category] = {
"count": count,
"f1": stats["f1_sum"] / count,
"bleu1": stats["b1_sum"] / count,
"jaccard": stats["j_sum"] / count,
"loc_f1": stats["loc_f1_sum"] / count
}
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
# 分析性能趋势
results["performance_trend"] = monitor.get_performance_trend()
results["reset_interval"] = monitor.reset_interval
results["total_questions_processed"] = monitor.question_count
return results
if __name__ == "__main__":
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
print("📋 增强特性:")
print(" - 双重重置策略:定期重置 + 性能驱动重置")
print(" - 动态检索参数:基于近期性能自适应调整")
print(" - 中间阶段严格筛选:提高上下文质量要求")
print(" - 连续性能监控:实时检测性能衰减")
result = asyncio.run(run_enhanced_evaluation())
print("\n📊 最终评估结果:")
print("总体指标:")
print(f" F1: {result['overall_metrics']['f1']:.4f}")
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
print("\n分类别指标:")
for category, metrics in result['category_metrics'].items():
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
print("\n检索统计:")
stats = result['retrieval_stats']
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
print(f"\n性能趋势: {result['performance_trend']}")
print(f"重置间隔: 每{result['reset_interval']}个问题")
print(f"处理问题总数: {result['total_questions_processed']}")
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
# 保存结果到指定目录
# 使用代码文件所在目录的绝对路径
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_file_dir, "results")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已保存到: {output_file}")

View File

@@ -1,626 +0,0 @@
"""
LoCoMo Utilities Module
This module provides helper functions for the LoCoMo benchmark evaluation:
- Data loading from JSON files
- Conversation extraction for ingestion
- Temporal reference resolution
- Context selection and formatting
- Retrieval wrapper functions
- Ingestion wrapper functions
"""
import os
import json
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from app.core.memory.utils.definitions import PROJECT_ROOT
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
def load_locomo_data(
data_path: str,
sample_size: int,
conversation_index: int = 0
) -> List[Dict[str, Any]]:
"""
Load LoCoMo dataset from JSON file.
The LoCoMo dataset structure is a list of conversation objects, where each
object contains a "qa" list of question-answer pairs.
Args:
data_path: Path to locomo10.json file
sample_size: Number of QA pairs to load (limits total QA items returned)
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
Returns:
List of QA item dictionaries, each containing:
- question: str
- answer: str
- category: int (1-4)
- evidence: List[str]
Raises:
FileNotFoundError: If data_path does not exist
json.JSONDecodeError: If file is not valid JSON
IndexError: If conversation_index is out of range
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo data structure: list of objects, each with a "qa" list
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
# Only load QA pairs from the specified conversation
if conversation_index < len(raw):
entry = raw[conversation_index]
if isinstance(entry, dict) and "qa" in entry:
qa_items.extend(entry.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has {len(raw)} conversations."
)
else:
# Fallback: single object with qa list
if conversation_index == 0:
qa_items.extend(raw.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has only 1 conversation."
)
# Return only the requested sample size
return qa_items[:sample_size]
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
"""
Extract conversation texts from LoCoMo data for ingestion.
This function extracts the raw conversation dialogues from the LoCoMo dataset
so they can be ingested into the memory system. Each conversation is formatted
as a multi-line string with "role: message" format.
Args:
data_path: Path to locomo10.json file
max_dialogues: Maximum number of dialogues to extract (default: 1)
Returns:
List of conversation strings formatted for ingestion.
Each string contains multiple lines in format "role: message"
Example output:
[
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
]
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# Ensure we have a list of entries
entries = raw if isinstance(raw, list) else [raw]
contents: List[str] = []
for i, entry in enumerate(entries[:max_dialogues]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
if not isinstance(conv, dict):
continue
lines: List[str] = []
# Collect all session_* messages
for key, val in sorted(conv.items()):
if isinstance(val, list) and key.startswith("session_"):
for msg in val:
if not isinstance(msg, dict):
continue
role = msg.get("speaker") or "User"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
if lines:
contents.append("\n".join(lines))
return contents
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
"""
Resolve relative temporal references to absolute dates.
This function converts relative time expressions (like "today", "yesterday",
"3 days ago") into absolute ISO date strings based on an anchor date.
Supported patterns:
- today, yesterday, tomorrow
- X days ago, in X days
- last week, next week
Args:
text: Text containing temporal references
anchor_date: Reference date for resolution (datetime object)
Returns:
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
Example:
>>> anchor = datetime(2023, 5, 8)
>>> resolve_temporal_references("I saw him yesterday", anchor)
"I saw him 2023-05-07"
"""
# Ensure input is a string
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(
r"\btoday\b",
anchor_date.date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\byesterday\b",
(anchor_date - timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\btomorrow\b",
(anchor_date + timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
# X days ago
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date - timedelta(days=n)).date().isoformat()
# in X days
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date + timedelta(days=n)).date().isoformat()
t = re.sub(
r"\b(\d+)\s+days?\s+ago\b",
_ago_repl,
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bin\s+(\d+)\s+days?\b",
_in_repl,
t,
flags=re.IGNORECASE
)
# last week / next week (approximate as 7 days)
t = re.sub(
r"\blast\s+week\b",
(anchor_date - timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bnext\s+week\b",
(anchor_date + timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
return t
def select_and_format_information(
retrieved_info: List[str],
question: str,
max_chars: int = 8000
) -> str:
"""
Intelligently select and format most relevant retrieved information for LLM prompt.
This function scores each piece of retrieved information based on keyword matching
with the question, then selects the highest-scoring pieces up to the character limit.
Scoring criteria:
- Keyword matches (higher weight for multiple occurrences)
- Context length (moderate length preferred)
- Position (earlier contexts get bonus points)
Args:
retrieved_info: List of retrieved information strings (chunks, statements, entities)
question: Question being answered
max_chars: Maximum total characters to include in final prompt
Returns:
Formatted string combining the most relevant information for LLM prompt.
Contexts are separated by double newlines.
Example:
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
>>> question = "Where did Alice go?"
>>> select_and_format_information(contexts, question, max_chars=100)
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
"""
if not retrieved_info:
return ""
# Extract question keywords (filter out stop words and short words)
question_lower = question.lower()
stop_words = {
'what', 'when', 'where', 'who', 'why', 'how',
'did', 'do', 'does', 'is', 'are', 'was', 'were',
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {
word for word in question_words
if word not in stop_words and len(word) > 2
}
# Score each context
scored_contexts = []
for i, context in enumerate(retrieved_info):
context_lower = context.lower()
score = 0
# Keyword matching score
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# Multiple occurrences increase score
score += context_lower.count(word) * 2
# Length score (prefer moderate length)
context_len = len(context)
if 100 < context_len < 2000:
score += 5
elif context_len >= 2000:
score += 2
# Position bonus (earlier contexts often more relevant)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# Sort by score (descending)
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# Select contexts up to character limit
selected = []
total_chars = 0
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
else:
# Try to include high-scoring context by truncating
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# Find lines with keywords
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
truncated = '\n'.join(relevant_lines)
selected.append(truncated + "\n[Content truncated...]")
total_chars += len(truncated)
break
return "\n\n".join(selected)
async def retrieve_relevant_information(
question: str,
end_user_id: str,
search_type: str,
search_limit: int,
connector: Any,
embedder: Any
) -> List[str]:
"""
Retrieve relevant information from memory graph for a question.
This function searches the Neo4j memory graph (populated during ingestion) and
returns relevant chunks, statements, and entity information that might help
answer the question.
The function supports three search types:
- "keyword": Full-text search using Cypher queries
- "embedding": Vector similarity search using embeddings
- "hybrid": Combination of keyword and embedding search with reranking
Args:
question: Question to search for
end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
embedder: Embedder client instance
Returns:
List of text strings (chunks, statements, entity summaries) from memory graph.
Each string represents a piece of retrieved information.
Raises:
Exception: If search fails (caught and returns empty list)
"""
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_embedding
)
from app.core.memory.storage_services.search import run_hybrid_search
contexts_all: List[str] = []
try:
if search_type == "embedding":
# Embedding-based search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context from chunks
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add summaries
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities (limit to 3 to avoid noise)
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# Keyword-based search
search_results = await search_graph(
connector=connector,
q=question,
end_user_id=end_user_id,
limit=search_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
# Build context from dialogues
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add entity names
if entities:
entity_names = [
str(e.get("name", "")).strip()
for e in entities[:5]
if e.get("name")
]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# Hybrid search with fallback to embedding
try:
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# Handle flat structure (new API format)
if search_results and isinstance(search_results, dict):
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Check if we got results
if not (chunks or statements or entities or summaries):
# Try nested structure (backward compatibility)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
raise ValueError("Hybrid search returned empty results")
else:
raise ValueError("Hybrid search returned empty results")
except Exception as e:
# Fallback to embedding search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context (same for both hybrid and fallback)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
except Exception as e:
# Return empty list on error
contexts_all = []
return contexts_all
async def ingest_conversations_if_needed(
conversations: List[str],
end_user_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
end_user_id=end_user_id,
save_chunk_output=True
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False

View File

@@ -1,878 +0,0 @@
import argparse
import asyncio
import json
import os
import statistics
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
import re
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
def _loc_normalize(text: str) -> str:
import re
# 确保输入是字符串
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text) # 去掉逗号
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 追加相对时间归一化为绝对日期有限支持today/yesterday/tomorrow/X days ago/in X days/last week/next week
def _resolve_relative_times(text: str, anchor: datetime) -> str:
import re
# 确保输入是字符串
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
# X days ago / in X days
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
# last week / next week以7天近似
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
# 单答案 F1按词集合计算近似原始实现去除词干依赖
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
p_tokens = _loc_normalize(pred_str).split()
g_tokens = _loc_normalize(truth_str).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
# 多答案 F1prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
CATEGORY_MAP_NUM_TO_NAME = {
4: "Single-Hop",
1: "Multi-Hop",
3: "Open Domain",
2: "Temporal",
}
_TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
def get_category_label(item: Dict[str, Any]) -> str:
# 1) 直接用字符串 cat
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return _TYPE_ALIASES.get(lower, name)
# 2) 数字 category 转名称
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
# 3) 备用 type 字段
t = item.get("type")
if isinstance(t, str) and t.strip():
lower = t.strip().lower()
return _TYPE_ALIASES.get(lower, t.strip())
return "unknown"
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_search_params_by_category(category: str):
"""根据问题类别调整检索参数"""
params_map = {
"Multi-Hop": {"limit": 20, "max_chars": 15000},
"Temporal": {"limit": 16, "max_chars": 10000},
"Open Domain": {"limit": 24, "max_chars": 18000},
"Single-Hop": {"limit": 12, "max_chars": 8000},
}
return params_map.get(category, {"limit": 16, "max_chars": 12000})
async def run_locomo_eval(
sample_size: int = 1,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
llm_max_tokens: int = 32,
search_type: str = "hybrid", # 保持默认值不变
output_path: str | None = None,
skip_ingest_if_exists: bool = True,
llm_timeout: float = 10.0,
llm_max_retries: int = 1
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items: List[Dict[str, Any]] = qa_items[:sample_size]
# === 保持原来的数据摄入逻辑 ===
entries = raw if isinstance(raw, list) else [raw]
# 只摄入前1条对话保持原样
max_dialogues_to_ingest = 1
contents: List[str] = []
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest}")
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
sample_id = entry.get("sample_id", f"unknown_{i}")
print(f"🔍 处理对话 {i+1}: {sample_id}")
lines: List[str] = []
if isinstance(conv, dict):
# 收集所有 session_* 的消息
session_count = 0
for key, val in conv.items():
if isinstance(val, list) and key.startswith("session_"):
session_count += 1
for msg in val:
role = msg.get("speaker") or "用户"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
if not lines:
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
continue
contents.append("\n".join(lines))
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
# 选择要评测的QA对从所有对话中选取
indexed_items: List[tuple[int, Dict[str, Any]]] = []
if isinstance(raw, list):
for e_idx, entry in enumerate(raw):
for qa in entry.get("qa", []):
indexed_items.append((e_idx, qa))
else:
for qa in raw.get("qa", []):
indexed_items.append((0, qa))
# 这里使用sample_size来限制评测的QA数量
selected = indexed_items[:sample_size]
items: List[Dict[str, Any]] = [qa for _, qa in selected]
print(f"🎯 将评测 {len(items)} 个QA对数据库中只包含 {len(contents)} 个对话")
# === 修改结束 ===
connector = Neo4jConnector()
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# connector initialized above
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 上下文诊断收集
per_query_context_counts: List[int] = []
per_query_context_avg_tokens: List[float] = []
per_query_context_chars: List[int] = []
per_query_context_tokens_total: List[int] = []
# 详细样本调试信息
samples: List[Dict[str, Any]] = []
# 通用指标
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
# 参考 LoCoMo 评测的类别专用 F1multi-hop 使用多答案 F1
loc_f1s: List[float] = []
# Per-category aggregation
cat_counts: Dict[str, int] = {}
cat_f1s: Dict[str, List[float]] = {}
cat_b1s: Dict[str, List[float]] = {}
cat_jss: Dict[str, List[float]] = {}
cat_loc_f1s: Dict[str, List[float]] = {}
try:
for item in items:
q = item.get("question", "")
ref = item.get("answer", "")
# 确保答案是字符串
ref_str = str(ref) if ref is not None else ""
cat = get_category_label(item)
print(f"\n=== 处理问题: {q} ===")
# 根据类别调整检索参数
search_params = get_search_params_by_category(cat)
adjusted_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
# 改进的检索逻辑使用三路检索statements, dialogues, entities
t0 = time.time()
contexts_all: List[str] = []
search_results = None # 保存完整的检索结果
try:
if search_type == "embedding":
# 直接调用嵌入检索,包含三路数据
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# 直接调用关键词检索
search_results = await search_graph(
connector=connector,
q=q,
end_user_id=end_user_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
# 构建上下文
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# 实体处理(关键词检索的实体可能没有分数)
if entities:
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# 🎯 关键修复:混合检索使用更严格的回退机制
print("🔀 使用混合检索(带回退机制)...")
try:
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# 🎯 关键修复:正确处理混合检索的扁平结构
# 新的API返回扁平结构直接从顶层获取结果
if search_results and isinstance(search_results, dict):
# 新API返回扁平结构直接从顶层获取
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# 检查是否有有效结果
if chunks or statements or entities or summaries:
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
else:
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
print(f"✅ 混合检索成功使用旧格式reranked结果: {len(chunks)} chunks, {len(statements)} 陈述")
else:
raise ValueError("混合检索返回空结果")
else:
raise ValueError("混合检索返回空结果")
except Exception as e:
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
# 🎯 统一处理:构建上下文(所有检索类型共用)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
# 关键修复:过滤掉包含当前问题答案的上下文
filtered_contexts = []
for context in contexts_all:
content = str(context)
# 排除包含当前问题标准答案的上下文
if ref_str and ref_str.strip() and ref_str.strip() in content:
print("🚫 过滤掉包含标准答案的上下文")
continue
filtered_contexts.append(context)
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
contexts_all = filtered_contexts
# 输出完整的检索结果信息
print("🔍 检索结果详情:")
if search_results:
output_data = {
"statements": [
{
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
"score": s.get("score", 0.0)
}
for s in (statements[:2] if 'statements' in locals() else [])
],
"dialogues": [
{
"uuid": d.get("uuid", ""),
"end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
for d in (dialogs[:2] if 'dialogs' in locals() else [])
],
"entities": [
{
"name": e.get("name", ""),
"entity_type": e.get("entity_type", ""),
"score": e.get("score", 0.0)
}
for e in (entities[:2] if 'entities' in locals() else [])
]
}
print(json.dumps(output_data, ensure_ascii=False, indent=2))
else:
print(" 无检索结果")
except Exception as e:
print(f"{search_type}检索失败: {e}")
contexts_all = []
search_results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 使用智能上下文选择
context_text = ""
if contexts_all:
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# 记录上下文诊断信息
per_query_context_counts.append(len(contexts_all))
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
per_query_context_chars.append(len(context_text))
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
# LLM 提示词
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
# 使用异步调用
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
# 计算指标(确保使用字符串)
f1_val = common_f1(str(pred), ref_str)
b1_val = bleu1(str(pred), ref_str)
j_val = jaccard(str(pred), ref_str)
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
# Accumulate by category
cat_counts[cat] = cat_counts.get(cat, 0) + 1
cat_f1s.setdefault(cat, []).append(f1_val)
cat_b1s.setdefault(cat, []).append(b1_val)
cat_jss.setdefault(cat, []).append(j_val)
# LoCoMo 专用 F1multi-hop(1) 使用多答案 F1其它(2/3/4)使用单答案 F1
if item.get("category") in [2, 3, 4]:
loc_val = loc_f1_score(str(pred), ref_str)
elif item.get("category") in [1]:
loc_val = loc_multi_f1(str(pred), ref_str)
else:
loc_val = loc_f1_score(str(pred), ref_str)
loc_f1s.append(loc_val)
cat_loc_f1s.setdefault(cat, []).append(loc_val)
# 保存完整的检索结果信息
samples.append({
"question": q,
"answer": ref_str,
"category": cat,
"prediction": pred,
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val,
"loc_f1": loc_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": adjusted_limit,
"max_chars": max_chars
},
"timing": {
"search_ms": (t1 - t0) * 1000,
"llm_ms": (t3 - t2) * 1000
}
})
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {ref_str}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
# Compute per-category averages and dispersion (std, iqr)
def _percentile(sorted_vals: List[float], p: float) -> float:
if not sorted_vals:
return 0.0
if len(sorted_vals) == 1:
return sorted_vals[0]
k = (len(sorted_vals) - 1) * p
f = int(k)
c = f + 1 if f + 1 < len(sorted_vals) else f
if f == c:
return sorted_vals[f]
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
by_category: Dict[str, Dict[str, float | int]] = {}
for c in cat_counts:
f_list = cat_f1s.get(c, [])
b_list = cat_b1s.get(c, [])
j_list = cat_jss.get(c, [])
lf_list = cat_loc_f1s.get(c, [])
j_sorted = sorted(j_list)
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
j_q75 = _percentile(j_sorted, 0.75)
j_q25 = _percentile(j_sorted, 0.25)
by_category[c] = {
"count": cat_counts[c],
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
"j_std": j_std,
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
# 参考 LoCoMo 评测的类别专用 F1
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
}
# 累加命中cum accuracy by category与 evaluation_stats.py 输出形式相仿
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
result = {
"dataset": "locomo",
"items": len(items),
"metrics": {
"f1": sum(f1s) / max(len(f1s), 1),
"b1": sum(b1s) / max(len(b1s), 1),
"j": sum(jss) / max(len(jss), 1),
# LoCoMo 类别专用 F1 的总体
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
},
"by_category": by_category,
"category_counts": cat_counts,
"cum_accuracy_by_category": cum_accuracy_by_category,
"context": {
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
"skip_ingest_if_exists": skip_ingest_if_exists,
"llm_timeout": llm_timeout,
"llm_max_retries": llm_max_retries,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens
},
"timestamp": datetime.now().isoformat()
}
if output_path:
try:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"❌ 保存结果失败: {e}")
return result
finally:
await connector.close()
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
args = parser.parse_args()
load_dotenv()
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
output_path=args.output_path,
skip_ingest_if_exists=args.skip_ingest_if_exists,
llm_timeout=args.llm_timeout,
llm_max_retries=args.llm_max_retries
))
print("\n" + "="*50)
print("📊 最终评测结果:")
print(f" 样本数量: {result['items']}")
print(f" F1: {result['metrics']['f1']:.3f}")
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
print(f" Jaccard: {result['metrics']['j']:.3f}")
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
if result['by_category']:
print("\n📈 按类别细分:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" 样本数: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
print(f" Jaccard: {metrics['j']:.3f}{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,324 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
if not contexts:
return ""
import re
# 提取问题关键词(移除停用词)
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
# 评分
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
# 选择直到达到字符限制,必要时截断包含关键词的段落
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
"""Compose a text context from `dialog` list in msc_self_instruct item."""
parts: List[str] = []
for turn in dialog_obj.get("dialog", []):
speaker = turn.get("speaker", "")
text = turn.get("text", "")
if text:
parts.append(f"{speaker}: {text}")
return "\n".join(parts)
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Combine dialogues from embedding and keyword searches (embedding first)."""
if results is None:
return []
emb = []
kw = []
if isinstance(results.get("embedding_search"), dict):
emb = results.get("embedding_search", {}).get("dialogues", []) or []
elif isinstance(results.get("dialogues"), list):
emb = results.get("dialogues", []) or []
if isinstance(results.get("keyword_search"), dict):
kw = results.get("keyword_search", {}).get("dialogues", []) or []
seen = set()
merged: List[Dict[str, Any]] = []
for d in emb:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
for d in kw:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
return merged
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
end_user_id = end_user_id or SELECTED_GROUP_ID
# Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用)
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# Evaluate each item
connector = Neo4jConnector()
latencies_llm: List[float] = []
latencies_search: List[float] = []
contexts_used: List[str] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
try:
for item in items:
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 检索:对齐 locomo 的三路检索dialogues/statements/entities
t0 = time.time()
try:
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
memory_config=memory_config,
)
except Exception:
results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
contexts_all: List[str] = []
if results:
if search_type == "hybrid":
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重与限制
seen_texts = set()
for d in all_dialogs:
text = str(d.get("content", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
for s in all_statements:
text = str(s.get("statement", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
# 实体摘要最多3个
names = []
merged_entities = all_entities[:]
for e in merged_entities:
name = str(e.get("name", "")).strip()
if name and name not in names:
names.append(name)
if len(names) >= 3:
break
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
else:
dialogs = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
for d in dialogs:
text = str(d.get("content", "")).strip()
if text:
contexts_all.append(text)
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
# 智能选择并截断到预算
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text[:200])
# Call LLM (使用异步调用)
messages = [
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import (
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))
# Aggregate metrics
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"accuracy": acc,
# Placeholders for extensibility
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"avg_context_tokens": ctx_avg_tokens,
}
return result
finally:
await connector.close()
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
args = parser.parse_args()
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
if __name__ == "__main__":
main()

View File

@@ -1,577 +0,0 @@
import argparse
import asyncio
import json
import os
import re
import time
from datetime import datetime
from typing import Any, Dict, List
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 路径与模块导入保持与现有评估脚本一致
import sys
from pathlib import Path
_THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = str(_THIS_DIR.parents[1])
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in (_SRC_DIR, _PROJECT_ROOT):
if _p not in sys.path:
sys.path.insert(0, _p)
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService
try:
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
tp = len(set(ps) & set(rs))
if tp == 0:
return 0.0
precision = tp / len(ps)
recall = tp / len(rs)
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
overlap = len([w for w in ps if w in rs])
return overlap / max(len(ps), 1)
def jaccard(pred: str, ref: str) -> float:
ps = set(pred.lower().split())
rs = set(ref.lower().split())
union = len(ps | rs)
if union == 0:
return 0.0
return len(ps & rs) / union
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
"""
if not contexts:
return ""
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3"""
ql = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
}
words = re.findall(r"\b[\w-]+\b", ql)
kws = [w for w in words if w not in stop_words and len(w) >= 3]
# 去重保序
seen = set()
uniq = []
for w in kws:
if w not in seen:
uniq.append(w)
seen.add(w)
if len(uniq) >= max_keywords:
break
return uniq
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
"""对上下文进行简单相关性打分,仅用于控制台可视化。
评分: score = match_count*200 + min(len(text), 100000)/100
"""
results = []
for ctx in contexts:
tl = (ctx or "").lower()
match_count = sum(1 for k in keywords if k in tl)
length = len(ctx)
score = match_count * 200 + min(length, 100000) / 100.0
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
return results[:max(top_n, 0)]
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
if not os.path.exists(data_path):
raise FileNotFoundError(f"未找到数据集: {data_path}")
items: List[Dict[str, Any]] = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
items.append(json.loads(line))
except Exception:
# 跳过坏行但不中断
continue
return items
async def run_memsciqa_test(
sample_size: int = 3,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
llm_max_tokens: int = 64,
search_type: str = "embedding",
data_path: str | None = None,
start_index: int = 0,
verbose: bool = True,
) -> Dict[str, Any]:
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
- 支持从指定索引开始与评估全部样本sample_size<=0
- 支持在摄入前重置组(清空图)与跳过摄入
- 支持 keyword / embedding / hybrid 三种检索
"""
# 默认使用指定的 memsci 组 ID
end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底)
if not data_path:
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
if os.path.exists(proj_path):
data_path = proj_path
elif os.path.exists(cwd_path):
data_path = cwd_path
else:
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl请确保其存在于项目根目录或当前工作目录的 data 目录下。")
# 加载数据
all_items = load_dataset_memsciqa(data_path)
if sample_size is None or sample_size <= 0:
items = all_items[start_index:]
else:
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 评估循环
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 存储完整上下文文本用于统计
contexts_used: List[str] = []
per_query_context_chars: List[int] = []
per_query_context_counts: List[int] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
samples: List[Dict[str, Any]] = []
total_items = len(items)
for idx, item in enumerate(items):
if verbose:
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 三路检索chunks/statements/entities/summaries对齐 qwen_search_eval.py
t0 = time.time()
results = None
try:
if search_type in ("embedding", "hybrid"):
# 使用嵌入检索(与 qwen_search_eval 对齐)
results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
elif search_type == "keyword":
# 关键词检索(直接调用 graph_search
results = await search_graph(
connector=connector,
q=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
except Exception:
results = None
t1 = time.time()
search_ms = (t1 - t0) * 1000
latencies_search.append(search_ms)
# 构建上下文:包含 chunks、陈述、摘要和实体对齐 qwen_search_eval.py
contexts_all: List[str] = []
retrieved_counts: Dict[str, int] = {}
if results:
chunks = results.get("chunks", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
summaries = results.get("summaries", [])
retrieved_counts = {
"chunks": len(chunks),
"statements": len(statements),
"entities": len(entities),
"summaries": len(summaries),
}
# 优先使用 chunks
for c in chunks:
text = str(c.get("content", "")).strip()
if text:
contexts_all.append(text)
# 然后是 statements
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
# 然后是 summaries
for sm in summaries:
text = str(sm.get("summary", "")).strip()
if text:
contexts_all.append(text)
# 实体摘要最多加入前3个高分实体对齐 qwen_search_eval.py
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
if verbose:
if retrieved_counts:
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
print(f"📊 有效上下文数量: {len(contexts_all)}")
q_keywords = extract_question_keywords(question, max_keywords=8)
if q_keywords:
print(f"🔍 问题关键词: {set(q_keywords)}")
if contexts_all:
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
if analysis:
print("📊 上下文相关性分析:")
for a in analysis:
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
# 打印检索到的上下文预览,便于定位为何为 Unknown
print("🔎 上下文预览最多前10条每条截断展示:")
for i, ctx in enumerate(contexts_all[:10]):
preview = str(ctx).replace("\n", " ")
if len(preview) > 300:
preview = preview[:300] + "..."
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
# 标注参考答案是否出现在任一上下文中
ref_lower = (str(reference) or "").lower()
if ref_lower:
hits = []
for i, ctx in enumerate(contexts_all):
if ref_lower in str(ctx).lower():
hits.append(i+1)
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text)
per_query_context_chars.append(len(context_text))
per_query_context_counts.append(len(contexts_all))
if verbose:
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
# 展示拼接后的上下文片段,便于核查是否包含答案
concat_preview = context_text.replace("\n", " ")
if len(concat_preview) > 600:
concat_preview = concat_preview[:600] + "..."
print(f"🧵 拼接上下文预览: {concat_preview}")
messages = [
{
"role": "system",
"content": (
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
"3) Keep your answer brief and to the point;\n"
"4) Do not add explanations or additional text beyond the answer."
),
},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 更健壮的响应解析处理不同的LLM响应格式
if hasattr(resp, 'content'):
pred = resp.content.strip()
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
pred = resp["choices"][0]["message"]["content"].strip()
elif isinstance(resp, dict) and "content" in resp:
pred = resp["content"].strip()
elif isinstance(resp, str):
pred = resp.strip()
else:
pred = "Unknown"
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
if pred.lower() in ["unknown", ""]:
# 如果参考答案在上下文中存在但LLM返回Unknown可能是提示词问题
ref_lower = (str(reference) or "").lower()
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown检查提示词")
except Exception as e:
# 更详细的错误处理
pred = "Unknown"
print(f"⚠️ LLM调用异常: {e}")
t3 = time.time()
llm_ms = (t3 - t2) * 1000
latencies_llm.append(llm_ms)
exact = exact_match(pred, reference)
correct_flags.append(exact)
f1_val = f1_score(str(pred), str(reference))
b1_val = bleu1(str(pred), str(reference))
j_val = jaccard(str(pred), str(reference))
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
if verbose:
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {reference}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
samples.append({
"question": str(question),
"answer": str(reference),
"prediction": str(pred),
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": context_char_budget
},
"timing": {
"search_ms": search_ms,
"llm_ms": llm_ms
}
})
# 计算总体指标与聚合
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"context": {
"avg_tokens": ctx_avg_tokens,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": 0.0
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens,
"search_type": search_type,
"start_index": start_index,
"llm_id": SELECTED_LLM_ID,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
},
"timestamp": datetime.now().isoformat(),
}
try:
await connector.close()
except Exception:
pass
return result
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci")
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型hybrid 等同于 embedding")
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl")
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径JSON")
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
args = parser.parse_args()
sample_size = 0 if args.all else args.sample_size
verbose_flag = False if args.quiet else args.verbose
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
data_path=args.data_path,
start_index=args.start_index,
verbose=verbose_flag,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果保存
out_path = args.output
if not out_path:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, "results")
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
try:
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 结果已保存: {out_path}")
except Exception as e:
print(f"⚠️ 结果保存失败: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,150 +0,0 @@
import argparse
import asyncio
import json
import os
import sys
from typing import Any, Dict
# Add src directory to Python path for proper imports when running from evaluation directory
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
async def run(
dataset: str,
sample_size: int,
reset_group: bool,
end_user_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
llm_temperature: float | None = None,
llm_max_tokens: int | None = None,
search_type: str | None = None,
start_index: int | None = None,
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
end_user_id = end_user_id or SELECTED_GROUP_ID
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(end_user_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
if start_index is not None:
kwargs["start_index"] = start_index
if max_contexts_per_item is not None:
kwargs["max_contexts_per_item"] = max_contexts_per_item
return await run_longmemeval_test(**kwargs)
raise ValueError(f"未知数据集: {dataset}")
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens不提供则使用各脚本默认")
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
# 仅透传到 longmemeval其他数据集忽略
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval起始样本索引不提供则用脚本默认")
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval每条样本摄入的上下文数量上限不提供则用脚本默认")
parser.add_argument("--output", type=str, default=None, help="可选将评估结果保存到指定文件路径JSON不提供时默认保存到 evaluation/<dataset>/results 目录")
args = parser.parse_args()
result = asyncio.run(run(
args.dataset,
args.sample_size,
args.reset_group,
args.end_user_id,
args.judge_model,
args.search_limit,
args.context_char_budget,
args.llm_temperature,
args.llm_max_tokens,
args.search_type,
args.start_index,
args.max_contexts_per_item,
))
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果输出逻辑保持不变
if args.output:
out_path = args.output
else:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
out_filename = f"{args.dataset}_{args.sample_size}.json"
out_path = os.path.join(dataset_results_dir, out_filename)
out_dir = os.path.dirname(out_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n结果已保存到: {out_path}")
if __name__ == "__main__":
main()

View File

@@ -1064,13 +1064,16 @@ class ExtractionOrchestrator:
if statement.triplet_extraction_info:
triplet_info = statement.triplet_extraction_info
# 创建实体索引到ID的映射
# 创建实体索引到ID的映射(支持多种索引方式)
entity_idx_to_id = {}
# 创建实体节点
for entity_idx, entity in enumerate(triplet_info.entities):
# 映射实体索引到实体ID
# 映射实体索引到实体ID(使用多个键以提高容错性)
# 1. 使用实体自己的 entity_idx
entity_idx_to_id[entity.entity_idx] = entity.id
# 2. 使用枚举索引从0开始
entity_idx_to_id[entity_idx] = entity.id
if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
@@ -1149,9 +1152,18 @@ class ExtractionOrchestrator:
relationship_result
)
else:
logger.warning(
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
f"object_id={triplet.object_id}, statement_id={statement.id}"
# 改进的警告信息,包含更多调试信息
missing_subject = "subject" if not subject_entity_id else ""
missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
f"object_id={triplet.object_id} ({triplet.object_name}), "
f"predicate={triplet.predicate}, "
f"statement_id={statement.id}, "
f"available_indices={sorted(entity_idx_to_id.keys())}"
)
logger.info(

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType
from app.db import Base
from app.schemas import ModelParameters
from app.schemas.app_schema import ModelParameters
class AgentConfig(Base):

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType
from app.db import Base
from app.schemas import ModelParameters
from app.schemas.app_schema import ModelParameters
class OrchestrationMode(StrEnum):

View File

@@ -4,7 +4,7 @@ import datetime
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer
from app.schemas import ModelParameters
from app.schemas.app_schema import ModelParameters
# ==================== 子 Agent 配置 ====================

View File

@@ -5,7 +5,7 @@ import uuid
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session
from app.schemas import ModelParameters
from app.schemas.app_schema import ModelParameters
from app.services.conversation_state_manager import ConversationStateManager
from app.models import ModelConfig, AgentConfig
from app.core.logging_config import get_business_logger

View File

@@ -57,7 +57,7 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]:
if data is None:
return None
from app.schemas import ModelParameters
from app.schemas.app_schema import ModelParameters
if isinstance(data, ModelParameters):
return data

View File

@@ -31,6 +31,7 @@ def upgrade() -> None:
op.execute("UPDATE memory_config SET config_id = apply_id::uuid")
op.alter_column('memory_config', 'config_id', nullable=False)
op.create_primary_key('memory_config_pkey', 'memory_config', ['config_id'])
op.execute("ALTER TABLE memory_config ALTER COLUMN config_id_old DROP DEFAULT")
op.execute("DROP SEQUENCE IF EXISTS data_config_config_id_seq")

View File

@@ -15,7 +15,7 @@ interface ApiResponse<T> {
interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
url: string;
params?: Record<string, unknown>;
valueKey?: string;
valueKey?: string | string[];
labelKey?: string;
placeholder?: string;
hasAll?: boolean;
@@ -66,11 +66,18 @@ const CustomSelect: FC<CustomSelectProps> = ({
{...props}
>
{hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>}
{displayOptions.map((option) => (
<Select.Option key={option[valueKey]} value={option[valueKey]}>
{String(option[labelKey])}
</Select.Option>
))}
{displayOptions.map((option) => {
const getValue = () => {
if (typeof valueKey === 'string') return option[valueKey];
return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined;
};
const value = getValue();
return (
<Select.Option key={value} value={value}>
{String(option[labelKey])}
</Select.Option>
);
})}
</Select>
);
};

View File

@@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[],
placeholder={t('common.pleaseSelect')}
url={url}
hasAll={false}
valueKey='config_id'
valueKey={['config_id_old', 'config_id']}
labelKey="config_name"
/>
</Form.Item>
@@ -126,12 +126,14 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
getApplicationConfig(id as string).then(res => {
const response = res as Config
let allTools = Array.isArray(response.tools) ? response.tools : []
const memoryContent = response.memory?.memory_content
const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
form.setFieldsValue({
...response,
tools: allTools,
memory: {
...response.memory,
memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined
memory_content: convertedMemoryContent
}
})
setData({

View File

@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
useEffect(() => {
if (values?.retrieve_type) {
const fieldsToReset = Object.keys(values).filter(key =>
key !== 'kb_id' && key !== 'retrieve_type'
key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
) as (keyof KnowledgeConfigForm)[];
form.resetFields(fieldsToReset);
}

View File

@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
useEffect(() => {
if (values?.retrieve_type) {
const fieldsToReset = Object.keys(values).filter(key =>
key !== 'kb_id' && key !== 'retrieve_type'
key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
) as (keyof KnowledgeConfigForm)[];
form.resetFields(fieldsToReset);
}
@@ -108,6 +108,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
label: t(`application.${key}`),
value: key,
}))}
// onChange={handleChange}
/>
</FormItem>
{/* Top K */}
@@ -116,13 +117,12 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
label={t('application.top_k')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
extra={t('application.top_k_desc')}
initialValue={5}
>
<InputNumber
style={{ width: '100%' }}
min={1}
max={20}
onChange={(value) => form.setFieldValue('top_k', value)}
// onChange={(value) => form.setFieldValue('top_k', value)}
/>
</FormItem>
{/* 语义相似度阈值 similarity_threshold */}

View File

@@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [
config_id: {
type: 'customSelect',
url: memoryConfigListUrl,
valueKey: 'config_id',
valueKey: ['config_id_old', 'config_id'],
labelKey: 'config_name'
},
search_switch: {
@@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [
config_id: {
type: 'customSelect',
url: memoryConfigListUrl,
valueKey: 'config_id',
valueKey: ['config_id_old', 'config_id'],
labelKey: 'config_name'
}
}
@@ -284,7 +284,7 @@ export const nodeLibrary: NodeLibrary[] = [
config: {
input: {
type: 'variableList',
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'],
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor'],
filterVariableNames: ['message']
},
parallel: {

View File

@@ -14,7 +14,7 @@ export interface NodeConfig {
url?: string;
params?: { [key: string]: unknown; }
valueKey?: string;
valueKey?: string | string[];
labelKey?: string;
defaultValue?: any;