delete benchmark-test (#204)

* Refactor: Move evaluation folder to redbear-mem-benchmark submodule

* [changes]Restore .gitmodules
This commit is contained in:
乐力齐
2026-01-26 20:30:07 +08:00
committed by GitHub
parent e2c67d0c5b
commit c3ea3b751b
19 changed files with 1 additions and 9110 deletions

View File

@@ -1,224 +0,0 @@
# ============================================================================
# 基准测试统一配置文件示例
# ============================================================================
# 复制此文件为 .env.evaluation 并根据需要修改
# 支持的基准测试LoCoMo、LongMemEval、MemSciQA
# ============================================================================
# ============================================================================
# 通用配置(所有基准测试共用)
# ============================================================================
# ----------------------------------------------------------------------------
# Neo4j 配置
# ----------------------------------------------------------------------------
# 默认 Group ID建议各基准测试使用独立的 group
EVAL_GROUP_ID=benchmark_default
# ----------------------------------------------------------------------------
# 模型配置(必需)
# ----------------------------------------------------------------------------
# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID
#
# 如何获取模型 ID
# 1. 查询数据库SELECT id, model_name FROM models WHERE is_active = true;
# 2. 或通过系统管理界面查看
# 3. 确保模型可用且配置正确
# LLM 模型 ID必填
EVAL_LLM_ID=your_llm_model_id_here
# Embedding 模型 ID必填
EVAL_EMBEDDING_ID=your_embedding_model_id_here
# ----------------------------------------------------------------------------
# 检索参数
# ----------------------------------------------------------------------------
# 检索类型: "keyword", "embedding", "hybrid"
EVAL_SEARCH_TYPE=hybrid
# 检索结果数量限制(默认值)
EVAL_SEARCH_LIMIT=12
# 上下文最大字符数(默认值)
EVAL_MAX_CONTEXT_CHARS=8000
# ----------------------------------------------------------------------------
# LLM 参数
# ----------------------------------------------------------------------------
# LLM 温度参数0.0 = 确定性输出)
EVAL_LLM_TEMPERATURE=0.0
# LLM 最大生成 token 数
EVAL_LLM_MAX_TOKENS=32
# LLM 超时时间(秒)
EVAL_LLM_TIMEOUT=10.0
# LLM 最大重试次数
EVAL_LLM_MAX_RETRIES=1
# ----------------------------------------------------------------------------
# 数据处理参数
# ----------------------------------------------------------------------------
# Chunker 策略
EVAL_CHUNKER_STRATEGY=RecursiveChunker
# 是否在导入前清空现有数据
EVAL_RESET_ON_INGEST=true
# 是否保存详细日志
EVAL_SAVE_DETAILED_LOGS=true
# ============================================================================
# LoCoMo 基准测试专用配置
# ============================================================================
# 数据集locomo10.json
# 运行python locomo_benchmark.py --sample_size 20
# ----------------------------------------------------------------------------
# Group IDLoCoMo 专用)
LOCOMO_GROUP_ID=locomo_benchmark
# 测试样本数量
# 建议值20快速测试、100中等测试、1986完整测试
LOCOMO_SAMPLE_SIZE=20
# 检索结果数量限制
LOCOMO_SEARCH_LIMIT=12
# 上下文最大字符数
LOCOMO_CONTEXT_CHAR_BUDGET=8000
# 导入的对话数量
LOCOMO_MAX_DIALOGUES=1
# 跳过数据摄入true=跳过false=摄入)
# 首次运行设置为 false后续运行可设置为 true 以节省时间
LOCOMO_SKIP_INGEST=false
# 结果保存目录
LOCOMO_OUTPUT_DIR=locomo/results
# ============================================================================
# LongMemEval 基准测试专用配置
# ============================================================================
# 数据集longmemeval_oracle_zh.json
# 运行python longmemeval_benchmark.py --sample_size 3
# 特点:支持时间推理问题的增强检索
# ----------------------------------------------------------------------------
# Group IDLongMemEval 专用)
LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3
# 测试样本数量(<=0 表示全部样本)
LONGMEMEVAL_SAMPLE_SIZE=3
# 起始样本索引
LONGMEMEVAL_START_INDEX=0
# 检索结果数量限制
LONGMEMEVAL_SEARCH_LIMIT=8
# 上下文最大字符数
LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000
# LLM 最大生成 token 数
LONGMEMEVAL_LLM_MAX_TOKENS=16
# 每条样本最多摄入的上下文段数
LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2
# 是否保存分块结果
LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true
# 自定义分块输出路径(留空使用默认)
LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH=
# 摄入前是否清空组数据
LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false
# 是否跳过摄入,仅检索评估
LONGMEMEVAL_SKIP_INGEST=false
# 结果保存目录
LONGMEMEVAL_OUTPUT_DIR=longmemeval/results
# ============================================================================
# MemSciQA 基准测试专用配置
# ============================================================================
# 数据集msc_self_instruct.jsonl
# 运行python memsciqa_benchmark.py --sample_size 1
# 特点:对话记忆检索评估
# ----------------------------------------------------------------------------
# Group IDMemSciQA 专用,独立数据集)
MEMSCIQA_GROUP_ID=memsciqa_benchmark
# 测试样本数量
MEMSCIQA_SAMPLE_SIZE=1 # 0或者-1标识测试数据集中的所有样本
# 检索结果数量限制
MEMSCIQA_SEARCH_LIMIT=8
# 上下文最大字符数
MEMSCIQA_CONTEXT_CHAR_BUDGET=4000
# LLM 最大生成 token 数
MEMSCIQA_LLM_MAX_TOKENS=64
# 跳过数据摄入true=跳过false=摄入)
# 首次运行设置为 false后续运行可设置为 true 以节省时间
MEMSCIQA_SKIP_INGEST=false
# 结果保存目录(相对于 memsciqa 脚本所在目录)
# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/
MEMSCIQA_OUTPUT_DIR=results
# ============================================================================
# 高级配置(可选)
# ============================================================================
# BM25 权重用于混合检索0.0-1.0
EVAL_RERANK_ALPHA=0.6
# 是否使用遗忘重排序
EVAL_USE_FORGETTING_RERANK=false
# 是否使用 LLM 重排序
EVAL_USE_LLM_RERANK=false
# 连接重置间隔(每 N 个问题重置一次)
EVAL_RESET_INTERVAL=5
# 性能阈值(低于此值触发重置)
EVAL_PERFORMANCE_THRESHOLD=0.6
# ============================================================================
# 快速配置指南
# ============================================================================
# 1. 复制此文件为 .env.evaluation
# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID
# 3. 根据需要修改各基准测试的专用配置
# 4. 运行测试:
# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20
# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all
# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10
# 配置优先级:
# 命令行参数 > 特定配置(如 LOCOMO_*> 通用配置EVAL_*> 代码默认值
# ============================================================================
# 执行LoCoMo测试
# 只摄入前5条消息评估3个问题最小测试
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
#
# 如果数据已经摄入,跳过摄入阶段直接测试
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
# 执行longmemeval测试
# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest
# 执行memsciqa测试
# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1

View File

@@ -1,13 +0,0 @@
# 忽略实际的评估配置文件(包含敏感信息)
.env.evaluation
# 保留示例文件
!.env.evaluation.example
# 忽略测试结果文件
*/results/*.json
*/results/*.log
# 忽略数据集文件(文件过大,不应提交到 Git
dataset/*.json
dataset/*.jsonl

View File

@@ -1 +0,0 @@
"""Evaluation package with dataset-specific pipelines and a unified runner."""

View File

@@ -1,748 +0,0 @@
# 1.数据集下载地址
Locomo10.json : https://github.com/snap-research/locomo/tree/main/data
LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下
# 2.配置说明
文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明
**实际配置文件**api\app\core\memory\evaluation\.env.evaluation
```python
# 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行
python -m app.core.memory.evaluation.locomo.locomo_benchmark
```
**检查neo4j指定的grou_id是否摄入数据**
```python
# 1. 进入交互模式
python -m app.core.memory.evaluation.check_enduser_data
# 2. 选择 "1" 检查指定 group
# 3. 输入 group_id例如: locomo_benchmark
# 4. 选择是否显示详细统计 (y/n)
```
# 3.locomo
### (1)locomo执行命令
```python
# 首先进入api目录
cd api
# 只摄入前5条消息评估3个问题最小测试
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
```
### (2)locomo结果说明
#### 结果示例
```json
{
"dataset": "locomo",
"sample_size": 0,
"timestamp": "2026-01-26T11:24:28.239156",
"params": {
"group_id": "locomo_benchmark",
"search_type": "hybrid",
"search_limit": 12,
"context_char_budget": 8000,
"llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4",
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2"
},
"overall_metrics": {
"f1": 0.0,
"bleu1": 0.0,
"jaccard": 0.0,
"locomo_f1": 0.0
},
"by_category": {},
"latency": {
"search": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
},
"llm": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
}
},
"context_stats": {
"avg_retrieved_docs": 0.0,
"avg_context_chars": 0.0,
"avg_context_tokens": 0.0
},
"samples": []
}
```
#### 参数详解
##### 1. 核心评估指标 (overall_metrics)
**🎯 关键进步指标:**
- **`f1`** (F1 Score): 精确率和召回率的调和平均值
- 范围0.0 - 1.0
- **越高越好**,衡量检索和生成答案的准确性
- 这是最重要的综合性能指标
- 优秀标准:> 0.85
- **`bleu1`** (BLEU-1): 单词级别的匹配度
- 范围0.0 - 1.0
- **越高越好**,衡量生成答案与标准答案的词汇重叠度
- 关注词汇层面的准确性
- **`jaccard`** (Jaccard 相似度): 集合相似度
- 范围0.0 - 1.0
- **越高越好**,衡量答案集合的相似性
- 计算公式:交集大小 / 并集大小
- **`locomo_f1`**: Locomo 特定的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,针对 Locomo 数据集优化的评估指标
- 考虑了长对话记忆的特殊性
##### 2. 性能指标 (latency)
**⚡ 关键效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均延迟
- `p50`: 中位数延迟50%的请求在此时间内完成)
- `p95`: 95分位数延迟95%的请求在此时间内完成)
- `iqr`: 四分位距Q3-Q1衡量稳定性
- **越低越好**,衡量记忆检索速度
- 优秀标准p95 < 2000ms
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距(越小越稳定)
- **越低越好**,衡量答案生成速度
- 优秀标准p95 < 3000ms
##### 3. 上下文统计 (context_stats)
**📊 资源效率指标:**
- **`avg_retrieved_docs`**: 平均检索文档数
- 反映检索策略的广度
- 需要平衡:太少可能信息不足,太多增加噪音和延迟
- 建议范围8-15 个文档
- **`avg_context_chars`**: 平均上下文字符数
- 反映检索内容的总量
- 应在满足准确性前提下尽量精简
- 受 `context_char_budget` 参数限制
- **`avg_context_tokens`**: 平均上下文 token 数
- **越低越好**(在保持准确性前提下)
- 直接影响 API 调用成本和推理速度
- 成本效益比 = f1 / avg_context_tokens
##### 4. 分类统计 (by_category)
- 按问题类型分类的性能指标
- 帮助识别系统在不同场景下的强弱项
- 可针对性优化特定类型的问题
#### 系统进步衡量标准
**一级指标(最重要):**
- `f1` 和 `locomo_f1` 提升 → 核心能力提升
- 目标f1 > 0.85
**二级指标(重要):**
- `latency.p95` 降低 → 用户体验提升
- 目标search.p95 < 2000ms, llm.p95 < 3000ms
**三级指标(辅助):**
- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化
- `iqr` 降低 → 性能稳定性提升
# 4.longmemeval
支持时间推理问题的增强检索
### (1)执行命令
```python
# 首先进入api目录
cd api
# 不带参数运行 - 使用环境变量
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark
# 命令行参数覆盖环境变量
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest
```
### (2)结果说明
#### 结果示例
```json
{
"dataset": "longmemeval",
"items": 1,
"accuracy_by_type": {
"single-session-user": 1.0
},
"f1_by_type": {
"single-session-user": 1.0
},
"jaccard_by_type": {
"single-session-user": 1.0
},
"samples": [
{
"question": "What degree did I graduate with?",
"prediction": "Business Administration",
"answer": "Business Administration",
"question_type": "single-session-user",
"is_temporal": false,
"question_id": "e47becba",
"options": [],
"context_count": 13,
"context_chars": 1268,
"retrieved_dialogue_count": 0,
"retrieved_statement_count": 12,
"metrics": {
"exact_match": true,
"f1": 1.0,
"jaccard": 1.0
},
"timing": {
"search_ms": 1483.100175857544,
"llm_ms": 995.8682060241699
}
}
],
"latency": {
"search": {
"mean": 1483.100175857544,
"p50": 1483.100175857544,
"p95": 1483.100175857544,
"iqr": 0.0
},
"llm": {
"mean": 995.8682060241699,
"p50": 995.8682060241699,
"p95": 995.8682060241699,
"iqr": 0.0
}
},
"context": {
"avg_tokens": 204.0,
"avg_chars": 1268,
"count_avg": 13
},
"params": {
"group_id": "longmemeval_zh_bak_3",
"search_limit": 8,
"context_char_budget": 4000,
"search_type": "hybrid",
"llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f",
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2",
"sample_size": 1,
"start_index": 0
},
"timestamp": "2026-01-24T21:36:10.818308",
"metric_summary": {
"score_accuracy": 100.0,
"latency_median_s": 2.478968381881714,
"latency_iqr_s": 0.0,
"avg_context_tokens_k": 0.204
},
"diagnostics": {
"duplicate_previews_top": [],
"unique_preview_count": 1
}
}
```
#### 参数详解
##### 1. 核心评估指标
**🎯 关键进步指标:**
- **`accuracy_by_type`**: 按问题类型分类的准确率
- 范围0.0 - 1.0
- **越高越好**1.0 表示 100% 准确
- 问题类型包括:
- `single-session-user`: 单会话用户信息
- `single-session-event`: 单会话事件信息
- `multi-session-user`: 多会话用户信息
- `multi-session-event`: 多会话事件信息
- 可以识别系统在不同场景下的强弱项
- **`f1_by_type`**: 按问题类型的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,综合评估精确率和召回率
- 比单纯的准确率更全面
- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度
- 范围0.0 - 1.0
- **越高越好**,衡量答案集合匹配度
- 对于集合类答案特别有用
##### 2. 样本级指标 (samples)
**详细诊断指标:**
- **`metrics.exact_match`**: 精确匹配(布尔值)
- **true 越多越好**,最严格的评估标准
- 要求预测答案与标准答案完全一致
- **`metrics.f1`**: 单个样本的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,衡量单个问题的回答质量
- **`is_temporal`**: 是否为时间推理问题
- 布尔值,标识问题是否涉及时间推理
- 时间推理问题通常更具挑战性
- **`context_count`**: 检索到的上下文数量
- 反映检索策略的有效性
- 建议范围8-15 个上下文片段
- **`retrieved_dialogue_count`**: 检索到的对话数
- **`retrieved_statement_count`**: 检索到的陈述数
- 这两个指标帮助理解检索的内容类型分布
- 可用于优化检索策略
- **`timing.search_ms`**: 单个问题的检索延迟(毫秒)
- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒)
- **越低越好**,反映单次查询的响应速度
##### 3. 汇总指标 (metric_summary)
**📊 关键 KPI**
- **`score_accuracy`**: 总体准确率百分比
- 范围0.0 - 100.0
- **越高越好**,最直观的性能指标
- 优秀标准:> 90.0
- **`latency_median_s`**: 中位延迟(秒)
- **越低越好**,反映真实响应速度
- 优秀标准:< 3.0 秒
- **`latency_iqr_s`**: 延迟四分位距(秒)
- **越低越好**,反映性能稳定性
- 越小说明响应时间越稳定
- **`avg_context_tokens_k`**: 平均上下文 token 数(千)
- **越低越好**(在保持准确性前提下)
- 直接影响 API 调用成本
- 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000)
##### 4. 上下文统计 (context)
- **`avg_tokens`**: 平均 token 数
- **`avg_chars`**: 平均字符数
- **`count_avg`**: 平均上下文片段数
- 这些指标反映检索内容的规模
- 需要在准确性和效率之间平衡
##### 5. 性能指标 (latency)
**⚡ 效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均延迟
- `p50`: 中位数延迟
- `p95`: 95分位数延迟
- `iqr`: 四分位距
- **越低越好**,衡量记忆检索速度
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距
- **越低越好**,衡量答案生成速度
##### 6. 诊断信息 (diagnostics)
- **`duplicate_previews_top`**: 重复预览统计
- 列出出现频率最高的重复内容
- 帮助发现检索冗余问题
- 应该尽量减少重复
- **`unique_preview_count`**: 唯一预览数量
- 反映检索多样性
- **越高越好**,说明检索到的内容更丰富
#### 系统进步衡量标准
**一级指标(最重要):**
- `score_accuracy` 提升 → 核心能力提升
- 目标:> 90.0%
- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升
**二级指标(重要):**
- `latency_median_s` 降低 → 用户体验提升
- 目标:< 3.0 秒
- `exact_match` 比例提升 → 精确度提升
**三级指标(辅助):**
- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化
- `unique_preview_count` 提升 → 检索多样性提升
- `latency_iqr_s` 降低 → 性能稳定性提升
**特殊关注:**
- 时间推理问题(`is_temporal: true`)的准确率
- 多会话问题的准确率(通常更具挑战性)
# 5.memsciqa
对话记忆检索评估
### (1)执行命令
```python
# 首先进入api目录
cd api
# 不带参数运行 - 使用环境变量
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark
# 命令行参数覆盖环境变量
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest
```
### (2)结果说明
#### 结果示例
```json
{
"dataset": "memsciqa",
"items": 1,
"metrics": {
"accuracy": 0.0,
"f1": 0.0,
"bleu1": 0.0,
"jaccard": 0.0
},
"latency": {
"search": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
},
"llm": {
"mean": 3067.7285194396973,
"p50": 3067.7285194396973,
"p95": 3067.7285194396973,
"iqr": 0.0
}
},
"avg_context_tokens": 4.0
}
```
#### 参数详解
##### 1. 核心评估指标 (metrics)
**🎯 关键进步指标:**
- **`accuracy`**: 准确率
- 范围0.0 - 1.0
- **越高越好**,最直接的性能指标
- 衡量系统回答正确的问题比例
- 优秀标准:> 0.85
- **`f1`**: F1 分数
- 范围0.0 - 1.0
- **越高越好**,平衡精确率和召回率
- 计算公式2 * (precision * recall) / (precision + recall)
- 比单纯的准确率更全面,特别适合不平衡数据集
- **`bleu1`**: BLEU-1 分数
- 范围0.0 - 1.0
- **越高越好**,衡量词汇级别的匹配度
- 关注生成答案与标准答案的单词重叠
- 源自机器翻译评估,适用于自然语言生成
- **`jaccard`**: Jaccard 相似度
- 范围0.0 - 1.0
- **越高越好**,衡量集合相似性
- 计算公式:|A ∩ B| / |A B|
- 对于多答案或集合类问题特别有用
##### 2. 性能指标 (latency)
**⚡ 效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均检索延迟
- `p50`: 中位数延迟50%的请求在此时间内完成)
- `p95`: 95分位数延迟95%的请求在此时间内完成)
- `iqr`: 四分位距Q3-Q1衡量稳定性
- **越低越好**,衡量记忆检索效率
- 优秀标准p95 < 2000ms
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距(越小越稳定)
- **越低越好**,衡量答案生成速度
- 优秀标准p95 < 3000ms
- 注意LLM 延迟通常占总延迟的大部分
##### 3. 资源指标
- **`avg_context_tokens`**: 平均上下文 token 数
- **越低越好**(在保持准确性前提下)
- 直接影响:
- API 调用成本(按 token 计费)
- 推理速度token 越多越慢)
- 上下文窗口占用
- 成本效益比 = accuracy / avg_context_tokens
- 建议范围:根据模型上下文窗口和成本预算调整
##### 4. 数据集特点
- **`items`**: 评估的问题数量
- 样本量越大,评估结果越可靠
- 建议至少 100 个样本以获得稳定的评估结果
- **对话记忆特性**
- MemSciQA 专注于对话历史中的记忆检索
- 评估系统从多轮对话中提取和回忆信息的能力
- 模拟真实的对话场景
#### 系统进步衡量标准
**一级指标(最重要):**
- `accuracy` 提升 → 核心能力提升
- 目标:> 0.85
- `f1` 提升 → 综合性能提升
- 目标:> 0.80
**二级指标(重要):**
- `latency.p95` 降低 → 用户体验提升
- search.p95 目标:< 2000ms
- llm.p95 目标:< 3000ms
- `iqr` 降低 → 性能稳定性提升
**三级指标(辅助):**
- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化
- `bleu1` 和 `jaccard` 提升 → 答案质量提升
**综合评估:**
- 成本效益比 = accuracy / avg_context_tokens
- 该比值越高,说明系统在相同成本下性能越好
- 总延迟 = search.p95 + llm.p95
- 应控制在 5 秒以内以保证良好的用户体验
#### 优化建议
**提升准确性:**
- 优化检索算法(调整 hybrid search 参数)
- 改进 embedding 模型质量
- 增加检索上下文数量(`search_limit`
- 优化 prompt 工程
**提升效率:**
- 减少不必要的检索文档
- 使用更快的 LLM 模型或量化版本
- 实施缓存策略(相似问题复用结果)
- 优化数据库索引
**平衡性能:**
- 监控 accuracy vs latency 的权衡
- 监控 accuracy vs cost (tokens) 的权衡
- 根据业务需求调整优先级
---
# 6. 三个基准测试对比总结
## 6.1 测试特点对比
| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 |
|---------|------------|-----------|---------|
| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 |
| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 |
| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 |
## 6.2 核心指标对比
### 准确性指标
| 指标 | Locomo | LongMemEval | MemSciQA | 说明 |
|-----|--------|-------------|----------|------|
| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 |
| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 |
| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 |
| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 |
| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 |
### 性能指标
所有三个测试都包含:
- **检索延迟** (search latency): mean, p50, p95, iqr
- **LLM 延迟** (llm latency): mean, p50, p95, iqr
- **上下文统计**: token 数、字符数、文档数
## 6.3 关键进步指标优先级
### 🥇 一级指标(必须关注)
1. **准确性指标**
- Locomo: `f1`, `locomo_f1`
- LongMemEval: `score_accuracy`, `accuracy_by_type`
- MemSciQA: `accuracy`, `f1`
- **目标**: > 85% 或 > 0.85
2. **综合性能**
- 所有测试的 F1 分数应保持一致性
- 不同类型问题的准确率应均衡
### 🥈 二级指标(重要)
3. **响应延迟**
- `latency.p95` (95分位数延迟)
- **目标**:
- search.p95 < 2000ms
- llm.p95 < 3000ms
- 总延迟 < 5000ms
4. **性能稳定性**
- `iqr` (四分位距)
- **目标**: 越小越好,说明性能稳定
### 🥉 三级指标(优化)
5. **成本效率**
- `avg_context_tokens`
- **目标**: 在保持准确性前提下最小化
- 成本效益比 = accuracy / avg_context_tokens
6. **检索质量**
- `avg_retrieved_docs` 的合理性
- `unique_preview_count` (LongMemEval)
- 检索内容的多样性和相关性
## 6.4 系统优化路径
### 阶段一:提升准确性(优先级最高)
**目标**: 所有测试的准确率 > 85%
**优化方向**:
1. 改进 embedding 模型质量
2. 优化检索算法hybrid search 参数)
3. 增加检索上下文数量(`search_limit`
4. 优化 prompt 工程
5. 改进记忆存储结构
**监控指标**:
- Locomo: `f1`, `locomo_f1`
- LongMemEval: `score_accuracy`, `exact_match` 比例
- MemSciQA: `accuracy`, `f1`
### 阶段二:优化性能(准确性达标后)
**目标**: p95 延迟 < 5 秒,性能稳定
**优化方向**:
1. 优化数据库索引和查询
2. 实施缓存策略
3. 使用更快的 LLM 模型
4. 并行化检索和推理
5. 减少不必要的检索
**监控指标**:
- `latency.p50`, `latency.p95`
- `iqr` (稳定性)
- 各阶段耗时分布
### 阶段三:降低成本(性能达标后)
**目标**: 在保持准确性和性能前提下,最小化成本
**优化方向**:
1. 精简检索上下文
2. 优化 context 选择策略
3. 使用更小的 LLM 模型
4. 实施智能缓存
5. 批处理优化
**监控指标**:
- `avg_context_tokens`
- 成本效益比 = accuracy / avg_context_tokens
- API 调用成本
## 6.5 评估最佳实践
### 测试执行建议
1. **初始测试**: 使用小样本快速验证
```bash
--sample_size 10
```
2. **完整评估**: 使用足够大的样本量
```bash
--sample_size 100 # 或更多
```
3. **增量测试**: 数据已摄入时跳过摄入阶段
```bash
--skip_ingest
```
4. **参数调优**: 系统性地调整参数并记录结果
- 调整 `search_limit`: 4, 8, 12, 16
- 调整 `context_char_budget`: 2000, 4000, 8000
- 尝试不同的 `search_type`: vector, keyword, hybrid
### 结果分析建议
1. **横向对比**: 比较三个测试的结果,识别系统的强弱项
2. **纵向对比**: 跟踪同一测试在不同版本的表现
3. **分类分析**: 关注不同问题类型的性能差异
4. **异常诊断**: 分析失败案例,找出根本原因
### 持续监控
建议建立监控仪表板,跟踪:
- 核心指标趋势(准确率、延迟)
- 成本效益比趋势
- 不同问题类型的性能分布
- 异常样本和失败模式
## 6.6 性能基准参考
### 优秀水平Production Ready
- **准确性**: accuracy/f1 > 0.90
- **延迟**: p95 < 3 秒
- **稳定性**: iqr < 500ms
- **成本效益**: accuracy/tokens > 0.0001
### 良好水平Acceptable
- **准确性**: accuracy/f1 > 0.85
- **延迟**: p95 < 5 秒
- **稳定性**: iqr < 1000ms
- **成本效益**: accuracy/tokens > 0.00005
### 需要改进Below Target
- **准确性**: accuracy/f1 < 0.85
- **延迟**: p95 > 5 秒
- **稳定性**: iqr > 1000ms
- **成本效益**: accuracy/tokens < 0.00005
---
**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。

View File

@@ -1,371 +0,0 @@
"""
交互式 Neo4j End User 数据检查工具
用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。
使用方法:
python check_group_data.py
python check_group_data.py --group-id locomo_benchmark
python check_group_data.py --group-id memsciqa_benchmark --detailed
"""
import asyncio
import argparse
import os
from pathlib import Path
from typing import Dict, Any
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}\n")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def check_group_exists(end_user_id: str) -> Dict[str, Any]:
"""
检查指定 end_user_id 是否存在数据
Args:
end_user_id: 要检查的 end_user ID
Returns:
包含统计信息的字典
"""
connector = Neo4jConnector()
try:
# 查询该 end_user 的节点总数
query_total = """
MATCH (n {end_user_id: $end_user_id})
RETURN count(n) as total_nodes
"""
result_total = await connector.execute_query(query_total, end_user_id=end_user_id)
total_nodes = result_total[0]["total_nodes"] if result_total else 0
# 查询各类型节点的数量
query_by_type = """
MATCH (n {end_user_id: $end_user_id})
RETURN labels(n) as labels, count(n) as count
ORDER BY count DESC
"""
result_by_type = await connector.execute_query(query_by_type, end_user_id=end_user_id)
# 查询关系数量
query_relationships = """
MATCH (n {end_user_id: $end_user_id})-[r]-()
RETURN count(DISTINCT r) as total_relationships
"""
result_rel = await connector.execute_query(query_relationships, end_user_id=end_user_id)
total_relationships = result_rel[0]["total_relationships"] if result_rel else 0
return {
"exists": total_nodes > 0,
"total_nodes": total_nodes,
"total_relationships": total_relationships,
"nodes_by_type": result_by_type
}
finally:
await connector.close()
async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]:
"""
获取详细的统计信息
Args:
end_user_id: 要检查的 end_user ID
Returns:
详细统计信息字典
"""
connector = Neo4jConnector()
try:
stats = {}
# Chunk 节点统计
query_chunks = """
MATCH (c:Chunk {end_user_id: $end_user_id})
RETURN count(c) as count,
avg(size(c.content)) as avg_content_length
"""
result_chunks = await connector.execute_query(query_chunks, end_user_id=end_user_id)
if result_chunks and result_chunks[0]["count"] > 0:
stats["chunks"] = {
"count": result_chunks[0]["count"],
"avg_content_length": int(result_chunks[0]["avg_content_length"]) if result_chunks[0]["avg_content_length"] else 0
}
# Statement 节点统计
query_statements = """
MATCH (s:Statement {end_user_id: $end_user_id})
RETURN count(s) as count
"""
result_statements = await connector.execute_query(query_statements, end_user_id=end_user_id)
if result_statements and result_statements[0]["count"] > 0:
stats["statements"] = {
"count": result_statements[0]["count"]
}
# Entity 节点统计
query_entities = """
MATCH (e:Entity {end_user_id: $end_user_id})
RETURN count(e) as count,
count(DISTINCT e.entity_type) as unique_types
"""
result_entities = await connector.execute_query(query_entities, end_user_id=end_user_id)
if result_entities and result_entities[0]["count"] > 0:
stats["entities"] = {
"count": result_entities[0]["count"],
"unique_types": result_entities[0]["unique_types"]
}
# Dialogue 节点统计
query_dialogues = """
MATCH (d:Dialogue {end_user_id: $end_user_id})
RETURN count(d) as count
"""
result_dialogues = await connector.execute_query(query_dialogues, end_user_id=end_user_id)
if result_dialogues and result_dialogues[0]["count"] > 0:
stats["dialogues"] = {
"count": result_dialogues[0]["count"]
}
# Summary 节点统计
query_summaries = """
MATCH (s:Summary {end_user_id: $end_user_id})
RETURN count(s) as count
"""
result_summaries = await connector.execute_query(query_summaries, end_user_id=end_user_id)
if result_summaries and result_summaries[0]["count"] > 0:
stats["summaries"] = {
"count": result_summaries[0]["count"]
}
return stats
finally:
await connector.close()
async def list_all_end_users() -> list:
"""
列出数据库中所有的 end_user_id
Returns:
end_user_id 列表及其节点数量
"""
connector = Neo4jConnector()
try:
query = """
MATCH (n)
WHERE n.end_user_id IS NOT NULL
RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count
ORDER BY node_count DESC
"""
results = await connector.execute_query(query)
return results
finally:
await connector.close()
def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None):
"""
打印查询结果
Args:
end_user_id: End User ID
stats: 基本统计信息
detailed_stats: 详细统计信息(可选)
"""
print(f"\n{'='*60}")
print(f"📊 End User ID: {end_user_id}")
print(f"{'='*60}\n")
if not stats["exists"]:
print("❌ 该 end_user_id 不存在数据")
print("\n💡 提示: 请先运行基准测试以摄入数据")
return
print(f"✅ 该 end_user_id 存在数据\n")
print(f"📈 基本统计:")
print(f" 总节点数: {stats['total_nodes']}")
print(f" 总关系数: {stats['total_relationships']}")
if stats["nodes_by_type"]:
print(f"\n📋 节点类型分布:")
for item in stats["nodes_by_type"]:
labels = ", ".join(item["labels"])
count = item["count"]
print(f" {labels}: {count}")
if detailed_stats:
print(f"\n🔍 详细统计:")
if "chunks" in detailed_stats:
print(f" Chunks: {detailed_stats['chunks']['count']}")
print(f" 平均内容长度: {detailed_stats['chunks']['avg_content_length']} 字符")
if "statements" in detailed_stats:
print(f" Statements: {detailed_stats['statements']['count']}")
if "entities" in detailed_stats:
print(f" Entities: {detailed_stats['entities']['count']}")
print(f" 唯一类型数: {detailed_stats['entities']['unique_types']}")
if "dialogues" in detailed_stats:
print(f" Dialogues: {detailed_stats['dialogues']['count']}")
if "summaries" in detailed_stats:
print(f" Summaries: {detailed_stats['summaries']['count']}")
print(f"\n{'='*60}\n")
async def interactive_mode():
"""
交互式模式
"""
print("\n" + "="*60)
print("🔍 Neo4j End User 数据检查工具 - 交互模式")
print("="*60 + "\n")
while True:
print("\n请选择操作:")
print(" 1. 检查指定 end_user_id")
print(" 2. 列出所有 end_user_id")
print(" 3. 退出")
choice = input("\n请输入选项 (1-3): ").strip()
if choice == "1":
end_user_id = input("\n请输入 end_user_id: ").strip()
if not end_user_id:
print("❌ end_user_id 不能为空")
continue
detailed = input("是否显示详细统计? (y/n, 默认 n): ").strip().lower() == 'y'
print("\n🔄 正在查询...")
stats = await check_group_exists(end_user_id)
detailed_stats = None
if detailed and stats["exists"]:
detailed_stats = await get_detailed_stats(end_user_id)
print_results(end_user_id, stats, detailed_stats)
elif choice == "2":
print("\n🔄 正在查询所有 end_user_id...")
end_users = await list_all_end_users()
if not end_users:
print("\n❌ 数据库中没有任何 end_user 数据")
else:
print(f"\n{'='*60}")
print(f"📋 数据库中的所有 End User ID")
print(f"{'='*60}\n")
for idx, end_user in enumerate(end_users, 1):
print(f" {idx}. {end_user['end_user_id']}")
print(f" 节点数: {end_user['node_count']}")
print(f"\n{'='*60}\n")
elif choice == "3":
print("\n👋 再见!")
break
else:
print("\n❌ 无效的选项,请重新选择")
async def main():
"""
主函数
"""
parser = argparse.ArgumentParser(
description="检查 Neo4j 中指定 end_user_id 的数据情况",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 交互模式
python check_group_data.py
# 检查指定 end_user
python check_group_data.py --end-user-id locomo_benchmark
# 检查并显示详细统计
python check_group_data.py --end-user-id memsciqa_benchmark --detailed
# 列出所有 end_user
python check_group_data.py --list-all
"""
)
parser.add_argument(
"--end-user-id",
type=str,
help="要检查的 end_user ID"
)
parser.add_argument(
"--detailed",
action="store_true",
help="显示详细统计信息"
)
parser.add_argument(
"--list-all",
action="store_true",
help="列出所有 end_user_id"
)
args = parser.parse_args()
# 如果没有提供任何参数,进入交互模式
if not args.end_user_id and not args.list_all:
await interactive_mode()
return
# 列出所有 end_user
if args.list_all:
print("\n🔄 正在查询所有 end_user_id...")
end_users = await list_all_end_users()
if not end_users:
print("\n❌ 数据库中没有任何 end_user 数据")
else:
print(f"\n{'='*60}")
print(f"📋 数据库中的所有 End User ID")
print(f"{'='*60}\n")
for idx, end_user in enumerate(end_users, 1):
print(f" {idx}. {end_user['end_user_id']}")
print(f" 节点数: {end_user['node_count']}")
print(f"\n{'='*60}\n")
return
# 检查指定 end_user
if args.end_user_id:
print(f"\n🔄 正在查询 end_user_id: {args.end_user_id}...")
stats = await check_group_exists(args.end_user_id)
detailed_stats = None
if args.detailed and stats["exists"]:
print("🔄 正在获取详细统计...")
detailed_stats = await get_detailed_stats(args.end_user_id)
print_results(args.end_user_id, stats, detailed_stats)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,100 +0,0 @@
import math
import re
from typing import List, Dict
# 评估指标的实现
def _normalize(text: str) -> List[str]:
"""Lowercase, strip punctuation, and split into tokens."""
text = text.lower().strip()
# Python's re doesn't support \p classes; use a simple non-word filter
text = re.sub(r"[^\w\s]", " ", text)
tokens = [t for t in text.split() if t]
return tokens
def exact_match(pred: str, ref: str) -> float:
return float(_normalize(pred) == _normalize(ref))
def jaccard(pred: str, ref: str) -> float:
p = set(_normalize(pred))
r = set(_normalize(ref))
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
def f1_score(pred: str, ref: str) -> float:
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
"""Unigram BLEU (BLEU-1) with clipping and brevity penalty."""
p_tokens = _normalize(pred)
r_tokens = _normalize(ref)
if not p_tokens:
return 0.0
# Clipped count
r_counts: Dict[str, int] = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts: Dict[str, int] = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
# Brevity penalty
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def percentile(values: List[float], p: float) -> float:
if not values:
return 0.0
vals = sorted(values)
k = (len(vals) - 1) * p
f = math.floor(k)
c = math.ceil(k)
if f == c:
return vals[int(k)]
return vals[f] + (k - f) * (vals[c] - vals[f])
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
"""Return basic latency stats: mean, p50, p95, iqr (p75-p25)."""
if not latencies_ms:
return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}
p25 = percentile(latencies_ms, 0.25)
p50 = percentile(latencies_ms, 0.50)
p75 = percentile(latencies_ms, 0.75)
p95 = percentile(latencies_ms, 0.95)
mean = sum(latencies_ms) / max(len(latencies_ms), 1)
return {"mean": mean, "p50": p50, "p95": p95, "iqr": p75 - p25}
def avg_context_tokens(contexts: List[str]) -> float:
if not contexts:
return 0.0
return sum(len(_normalize(c)) for c in contexts) / len(contexts)

View File

@@ -1,62 +0,0 @@
"""
Dialogue search queries for evaluation purposes.
This file contains Cypher queries for searching dialogues, entities, and chunks.
Placed in evaluation directory to avoid circular imports with src modules.
"""
# 应该是neo4j browser的cypher语句需要修改文件名
# Entity search queries
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:ExtractedEntity)
WHERE e.name = $name
RETURN e
"""
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:ExtractedEntity)
WHERE e.name CONTAINS $name
RETURN e
"""
# Chunk search queries
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""
# Dialogue search queries
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""

View File

@@ -1,444 +0,0 @@
import os
import asyncio
import json
from typing import List, Dict, Any, Optional
from datetime import datetime
from uuid import UUID
import re
from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent / "app" / "core" / "memory" / "evaluation" / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
# 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
# Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
# Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
async def ingest_contexts_via_full_pipeline(
contexts: List[str],
end_user_id: str,
chunker_strategy: str | None = None,
embedding_name: str | None = None,
save_chunk_output: bool = False,
save_chunk_output_path: str | None = None,
reset_group: bool = False,
) -> bool:
"""
使用新的 ExtractionOrchestrator 运行完整的提取流水线
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
This function uses the new ExtractionOrchestrator architecture for better maintainability.
Args:
contexts: List of dialogue texts, each containing lines like "role: message".
end_user_id: Group ID to assign to generated DialogData and graph nodes.
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
reset_group: If True, clear existing data for this group before ingestion.
Returns:
True if data saved successfully, False otherwise.
"""
chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID")
# Check if we should reset from environment variable if not explicitly set
if not reset_group:
reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes")
# Step 0: Reset group if requested
if reset_group:
print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...")
try:
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
try:
# 删除该 end_user 的所有节点和关系
query = """
MATCH (n {end_user_id: $end_user_id})
DETACH DELETE n
"""
await connector.execute_query(query, end_user_id=end_user_id)
print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空")
finally:
await connector.close()
except Exception as e:
print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}")
# 继续执行,不中断摄入流程
# Step 1: Initialize LLM client
llm_client = None
try:
# 使用评估配置中的 LLM ID
llm_id = os.getenv("EVAL_LLM_ID")
if not llm_id:
print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation")
return False
from app.db import get_db
db = next(get_db())
try:
llm_client = get_llm_client(llm_id, db)
finally:
db.close()
except Exception as e:
print(f"[Ingestion] LLM client unavailable: {e}")
return False
# Step 2: Parse contexts and create DialogData with chunks
print(f"[Ingestion] Parsing {len(contexts)} contexts...")
chunker = DialogueChunker(chunker_strategy)
dialog_data_list: List[DialogData] = []
for idx, ctx in enumerate(contexts):
messages: List[ConversationMessage] = []
# Improved parsing: capture multi-line message blocks, normalize roles
pattern = r"^\s*(用户|AI|assistant|user)\s*[:]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[:]|\Z)"
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
if matches:
for m in matches:
raw_role = m.group(1).strip()
content = m.group(2).strip()
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=content))
else:
# Fallback: line-by-line parsing
for raw in ctx.split("\n"):
line = raw.strip()
if not line:
continue
m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)', line)
if m:
role = m.group(1).strip()
msg = m.group(2).strip()
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
messages.append(ConversationMessage(role=norm_role, msg=msg))
else:
# Final fallback: treat as user message
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
messages.append(ConversationMessage(role=default_role, msg=line))
context_model = ConversationContext(msgs=messages)
dialog = DialogData(
context=context_model,
ref_id=f"pipeline_item_{idx}",
end_user_id=end_user_id,
user_id="default_user",
apply_id="default_application",
)
# Generate chunks
dialog.chunks = await chunker.process_dialogue(dialog)
dialog_data_list.append(dialog)
if not dialog_data_list:
print("[Ingestion] No dialogs to process.")
return False
print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks")
# Step 3: Optionally save chunking outputs for debugging
if save_chunk_output:
try:
def _serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
from app.core.config import settings
settings.ensure_memory_output_dir()
default_path = settings.get_memory_output_path("chunker_test_output.txt")
out_path = save_chunk_output_path or default_path
combined_output = [dd.model_dump() for dd in dialog_data_list]
with open(out_path, "w", encoding="utf-8") as f:
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
print(f"[Ingestion] Saved chunking results to: {out_path}")
except Exception as e:
print(f"[Ingestion] Failed to save chunking results: {e}")
# Step 4: Initialize embedder client
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.db import get_db
try:
db = next(get_db())
try:
embedder_config_dict = get_embedder_config(embedding_name, db)
embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
finally:
db.close()
except Exception as e:
print(f"[Ingestion] Failed to initialize embedder client: {e}")
return False
# Step 5: Initialize Neo4j connector
connector = Neo4jConnector()
# Step 6: 构建 MemoryConfig从环境变量直接构建不依赖数据库
print("[Ingestion] 构建 MemoryConfig from environment variables...")
from app.schemas.memory_config_schema import MemoryConfig
try:
# 从环境变量获取配置参数
llm_id = os.getenv("EVAL_LLM_ID")
embedding_id = os.getenv("EVAL_EMBEDDING_ID")
chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
if not llm_id or not embedding_id:
print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation")
print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID")
await connector.close()
return False
# 从数据库获取模型信息(仅用于显示名称)
from app.db import get_db
db = next(get_db())
try:
from sqlalchemy import text
# 获取 LLM 模型信息(从 model_configs 表)
llm_result = db.execute(
text("SELECT name FROM model_configs WHERE id = :id"),
{"id": llm_id}
).fetchone()
llm_model_name = llm_result[0] if llm_result else "Unknown LLM"
# 获取 Embedding 模型信息(从 model_configs 表)
emb_result = db.execute(
text("SELECT name FROM model_configs WHERE id = :id"),
{"id": embedding_id}
).fetchone()
embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding"
except Exception as e:
# 如果查询失败,使用默认名称
print(f"[Ingestion] Warning: Failed to query model names from database: {e}")
llm_model_name = f"LLM ({llm_id[:8]}...)"
embedding_model_name = f"Embedding ({embedding_id[:8]}...)"
finally:
db.close()
# 构建 MemoryConfig 对象(使用最小必需配置)
from uuid import uuid4
memory_config = MemoryConfig(
config_id=0, # 评估环境不需要真实的 config_id
config_name="evaluation_config",
workspace_id=uuid4(), # 临时 workspace_id
workspace_name="evaluation_workspace",
tenant_id=uuid4(), # 临时 tenant_id
llm_model_id=UUID(llm_id),
llm_model_name=llm_model_name,
embedding_model_id=UUID(embedding_id),
embedding_model_name=embedding_model_name,
storage_type="neo4j",
chunker_strategy=chunker_strategy_env,
reflexion_enabled=False,
reflexion_iteration_period=3,
reflexion_range="partial",
reflexion_baseline="TIME",
loaded_at=datetime.now(),
# 可选字段使用默认值
rerank_model_id=None,
rerank_model_name=None,
llm_params={},
embedding_params={},
config_version="2.0",
)
print(f"[Ingestion] ✅ 构建 MemoryConfig 成功")
print(f"[Ingestion] LLM: {llm_model_name}")
print(f"[Ingestion] Embedding: {embedding_model_name}")
print(f"[Ingestion] Chunker: {chunker_strategy_env}")
except Exception as e:
print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}")
print(f"[Ingestion] Please check:")
print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation")
print(f"[Ingestion] 2. Model IDs exist in the models table")
print(f"[Ingestion] 3. Database connection is working")
await connector.close()
return False
# Step 7: Initialize and run ExtractionOrchestrator
print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...")
from app.services.memory_config_service import MemoryConfigService
config = MemoryConfigService.get_pipeline_config(memory_config)
orchestrator = ExtractionOrchestrator(
llm_client=llm_client,
embedder_client=embedder_client,
connector=connector,
config=config,
embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id
)
try:
# Run the complete extraction pipeline
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
# Handle different return formats:
# - Pilot mode: 7 values (without dedup_details)
# - Normal mode: 8 values (with dedup_details at the end)
if len(result) == 8:
# Normal mode: includes dedup_details
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
_, # dedup_details - not needed here
) = result
elif len(result) == 7:
# Pilot mode or older version: no dedup_details
(
dialogue_nodes,
chunk_nodes,
statement_nodes,
entity_nodes,
statement_chunk_edges,
statement_entity_edges,
entity_entity_edges,
) = result
else:
raise ValueError(f"Unexpected number of return values: {len(result)}")
print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities")
except ValueError as e:
# If unpacking fails, provide helpful error message
print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}")
print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}")
if hasattr(result, '__len__') and len(result) > 0:
print(f"[Ingestion] First element type: {type(result[0])}")
await connector.close()
return False
except Exception as e:
print(f"[Ingestion] Extraction pipeline failed: {e}")
import traceback
traceback.print_exc()
await connector.close()
return False
# Step 7: Generate memory summaries
print("[Ingestion] Generating memory summaries...")
try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation,
)
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list,
llm_client=llm_client,
embedder_client=embedder_client
)
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
except Exception as e:
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
summaries = []
# Step 8: Save to Neo4j
print("[Ingestion] Saving to Neo4j...")
try:
success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes,
chunk_nodes=chunk_nodes,
statement_nodes=statement_nodes,
entity_nodes=entity_nodes,
entity_edges=entity_entity_edges,
statement_chunk_edges=statement_chunk_edges,
statement_entity_edges=statement_entity_edges,
connector=connector
)
# Save memory summaries separately
if summaries:
try:
await add_memory_summary_nodes(summaries, connector)
await add_memory_summary_statement_edges(summaries, connector)
print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j")
except Exception as e:
print(f"[Ingestion] Warning: Failed to save summary nodes: {e}")
await connector.close()
if success:
print("[Ingestion] Successfully saved all data to Neo4j!")
else:
print("[Ingestion] Failed to save data to Neo4j")
return success
except Exception as e:
print(f"[Ingestion] Failed to save data to Neo4j: {e}")
await connector.close()
return False
async def handle_context_processing(args):
"""Handle context-based processing from command line arguments."""
contexts = []
if args.contexts:
contexts.extend(args.contexts)
if args.context_file:
try:
with open(args.context_file, 'r', encoding='utf-8') as f:
contexts.extend(line.strip() for line in f if line.strip())
except Exception as e:
print(f"Error reading context file: {e}")
return False
if not contexts:
print("No contexts provided for processing.")
return False
return await main_from_contexts(contexts, args.context_end_user_id)
async def main_from_contexts(contexts: List[str], end_user_id: str):
"""Run the pipeline from provided dialogue contexts instead of test data."""
print("=== Running pipeline from provided contexts ===")
success = await ingest_contexts_via_full_pipeline(
contexts=contexts,
end_user_id=end_user_id,
chunker_strategy=SELECTED_CHUNKER_STRATEGY,
embedding_name=SELECTED_EMBEDDING_ID,
save_chunk_output=True
)
if success:
print("Successfully processed and saved contexts to Neo4j!")
else:
print("Failed to process contexts.")
return success

View File

@@ -1,770 +0,0 @@
"""
LoCoMo Benchmark Script
This module provides the main entry point for running LoCoMo benchmark evaluations.
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
in a clean, maintainable way.
Usage:
python locomo_benchmark.py --sample_size 20 --search_type hybrid
"""
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any, Optional
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import (
f1_score,
bleu1,
jaccard,
latency_stats,
avg_context_tokens
)
from app.core.memory.evaluation.locomo.locomo_metrics import (
locomo_f1_score,
locomo_multi_f1,
get_category_name
)
from app.core.memory.evaluation.locomo.locomo_utils import (
load_locomo_data,
extract_conversations,
resolve_temporal_references,
select_and_format_information,
retrieve_relevant_information,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
# Get configuration from environment variables
PROJECT_ROOT = str(Path(__file__).resolve().parents[5]) # api directory
SELECTED_EMBEDDING_ID = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
SELECTED_end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark")
SELECTED_LLM_ID = os.getenv("EVAL_LLM_ID", "2c9b0782-7a85-4740-ba84-4baf77f256c4")
# ============================================================================
# Step 1: Data Loading
# ============================================================================
def step_load_data(data_path: str, sample_size: int) -> List[Dict[str, Any]]:
"""
Load QA pairs from LoCoMo dataset.
Args:
data_path: Path to locomo10.json file
sample_size: Number of QA pairs to load (0 for all)
Returns:
List of QA items from the first conversation
"""
print("📂 Loading LoCoMo data...")
# Load the dataset
qa_items = load_locomo_data(data_path, sample_size)
print(f"✅ Loaded {len(qa_items)} QA pairs from first conversation\n")
return qa_items
# ============================================================================
# Step 2: Data Ingestion
# ============================================================================
async def ingest_conversations_if_needed(
conversations: List[str],
end_user_id: str,
reset: bool = False
) -> bool:
"""
Ingest conversations into Neo4j database.
Args:
conversations: List of conversation strings (already formatted)
end_user_id: Database end_user ID
reset: Whether to reset the group before ingestion
Returns:
True if successful, False otherwise
"""
try:
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
# Conversations are already formatted as strings, use them directly
await ingest_contexts_via_full_pipeline(conversations, end_user_id)
return True
except Exception as e:
print(f"⚠️ Ingestion error: {e}")
import traceback
traceback.print_exc()
return False
async def step_ingest_data(
data_path: str,
end_user_id: str,
skip_ingest: bool,
reset_group: bool,
max_messages: Optional[int] = None
) -> bool:
"""
Ingest conversations into Neo4j database if needed.
Args:
data_path: Path to locomo10.json file
end_user_id: Database end_user ID
skip_ingest: Whether to skip ingestion
reset_group: Whether to reset the group before ingestion
max_messages: Maximum messages per dialogue to ingest (for testing)
Returns:
True if ingestion succeeded or was skipped, False otherwise
"""
if skip_ingest:
print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
print(f" End User ID: {end_user_id}\n")
else:
print("💾 Checking database ingestion...")
try:
# Extract conversations with optional message limit
conversations = extract_conversations(
data_path,
max_dialogues=1,
max_messages_per_dialogue=max_messages
)
print(f"📝 Extracted {len(conversations)} conversations")
# Always ingest for now (ingestion check not implemented)
print(f"🔄 Ingesting conversations into end_user '{end_user_id}'...")
success = await ingest_conversations_if_needed(
conversations=conversations,
end_user_id=end_user_id,
reset=reset_group
)
if success:
print("✅ Ingestion completed successfully\n")
else:
print("⚠️ Ingestion may have failed, continuing anyway\n")
except Exception as e:
print(f"❌ Ingestion failed: {e}")
import traceback
traceback.print_exc()
print("⚠️ Continuing with evaluation (database may be empty)\n")
return True
# ============================================================================
# Step 3: Initialize Clients
# ============================================================================
def step_initialize_clients(llm_id: str, embedding_id: str):
"""
Initialize Neo4j connector, LLM client, and embedder.
Args:
llm_id: LLM model ID
embedding_id: Embedding model ID
Returns:
Tuple of (connector, llm_client, embedder)
"""
print("🔧 Initializing clients...")
connector = Neo4jConnector()
# Get database session
from app.db import get_db
db = next(get_db())
try:
llm_client = get_llm_client(llm_id, db)
cfg_dict = get_embedder_config(embedding_id, db)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
finally:
db.close()
print("✅ Clients initialized\n")
return connector, llm_client, embedder
# ============================================================================
# Step 4: Process Questions
# ============================================================================
async def step_process_all_questions(
qa_items: List[Dict[str, Any]],
end_user_id: str,
search_type: str,
search_limit: int,
context_char_budget: int,
connector: Neo4jConnector,
embedder: OpenAIEmbedderClient,
llm_client: Any
) -> List[Dict[str, Any]]:
"""Process all QA items: retrieve, generate, and calculate metrics."""
print(f"🔍 Processing {len(qa_items)} questions...")
print(f"{'='*60}\n")
samples: List[Dict[str, Any]] = []
anchor_date = datetime(2023, 5, 8)
for idx, item in enumerate(qa_items, 1):
question = item.get("question", "")
ground_truth = item.get("answer", "")
category = get_category_name(item)
ground_truth_str = str(ground_truth) if ground_truth is not None else ""
print(f"[{idx}/{len(qa_items)}] Category: {category}")
print(f"❓ Question: {question}")
print(f"✅ Ground Truth: {ground_truth_str}")
# Retrieve
t_search_start = time.time()
try:
retrieved_info = await retrieve_relevant_information(
question=question,
end_user_id=end_user_id,
search_type=search_type,
search_limit=search_limit,
connector=connector,
embedder=embedder
)
search_latency = (time.time() - t_search_start) * 1000
print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
except Exception as e:
print(f"❌ Retrieval failed: {e}")
retrieved_info = []
search_latency = 0.0
# Format context
context_text = select_and_format_information(
retrieved_info=retrieved_info,
question=question,
max_chars=context_char_budget
)
context_text = resolve_temporal_references(context_text, anchor_date)
if context_text:
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
else:
context_text = "No relevant context found."
print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")
# Generate answer
messages = [
{
"role": "system",
"content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)
},
{
"role": "user",
"content": f"Question: {question}\n\nContext:\n{context_text}"
}
]
t_llm_start = time.time()
try:
response = await llm_client.chat(messages=messages)
llm_latency = (time.time() - t_llm_start) * 1000
if hasattr(response, 'content'):
prediction = response.content.strip()
elif isinstance(response, dict):
prediction = response["choices"][0]["message"]["content"].strip()
else:
prediction = "Unknown"
print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
except Exception as e:
print(f"❌ LLM failed: {e}")
prediction = "Unknown"
llm_latency = 0.0
# Calculate metrics
f1_val = f1_score(prediction, ground_truth_str)
bleu1_val = bleu1(prediction, ground_truth_str)
jaccard_val = jaccard(prediction, ground_truth_str)
if item.get("category") == 1:
locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
else:
locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)
print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
print()
samples.append({
"question": question,
"ground_truth": ground_truth_str,
"prediction": prediction,
"category": category,
"metrics": {
"f1": f1_val,
"bleu1": bleu1_val,
"jaccard": jaccard_val,
"locomo_f1": locomo_f1_val
},
"retrieval": {
"num_docs": len(retrieved_info),
"context_length": len(context_text)
},
"context_tokens": len(context_text.split()),
"timing": {
"search_ms": search_latency,
"llm_ms": llm_latency
}
})
return samples
# ============================================================================
# Step 5: Aggregate Results
# ============================================================================
def step_aggregate_results(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Aggregate metrics from all samples."""
print(f"\n{'='*60}")
print("📊 Aggregating Results")
print(f"{'='*60}\n")
if not samples:
return {
"overall_metrics": {},
"by_category": {},
"latency": {},
"context_stats": {}
}
# Extract metrics
f1_scores = [s["metrics"]["f1"] for s in samples]
bleu1_scores = [s["metrics"]["bleu1"] for s in samples]
jaccard_scores = [s["metrics"]["jaccard"] for s in samples]
locomo_f1_scores = [s["metrics"]["locomo_f1"] for s in samples]
# Extract timing
latencies_search = [s["timing"]["search_ms"] for s in samples]
latencies_llm = [s["timing"]["llm_ms"] for s in samples]
# Extract context stats
context_counts = [s["retrieval"]["num_docs"] for s in samples]
context_chars = [s["retrieval"]["context_length"] for s in samples]
context_tokens = [s["context_tokens"] for s in samples]
# Overall metrics
overall_metrics = {
"f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
"bleu1": sum(bleu1_scores) / len(bleu1_scores) if bleu1_scores else 0.0,
"jaccard": sum(jaccard_scores) / len(jaccard_scores) if jaccard_scores else 0.0,
"locomo_f1": sum(locomo_f1_scores) / len(locomo_f1_scores) if locomo_f1_scores else 0.0
}
# Per-category metrics
category_data: Dict[str, Dict[str, List[float]]] = {}
for sample in samples:
cat = sample["category"]
if cat not in category_data:
category_data[cat] = {
"f1": [],
"bleu1": [],
"jaccard": [],
"locomo_f1": []
}
category_data[cat]["f1"].append(sample["metrics"]["f1"])
category_data[cat]["bleu1"].append(sample["metrics"]["bleu1"])
category_data[cat]["jaccard"].append(sample["metrics"]["jaccard"])
category_data[cat]["locomo_f1"].append(sample["metrics"]["locomo_f1"])
by_category: Dict[str, Dict[str, Any]] = {}
for cat, metrics_lists in category_data.items():
by_category[cat] = {
"count": len(metrics_lists["f1"]),
"f1": sum(metrics_lists["f1"]) / len(metrics_lists["f1"]),
"bleu1": sum(metrics_lists["bleu1"]) / len(metrics_lists["bleu1"]),
"jaccard": sum(metrics_lists["jaccard"]) / len(metrics_lists["jaccard"]),
"locomo_f1": sum(metrics_lists["locomo_f1"]) / len(metrics_lists["locomo_f1"])
}
# Latency statistics
latency = {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm)
}
# Context statistics
context_stats = {
"avg_retrieved_docs": sum(context_counts) / len(context_counts) if context_counts else 0.0,
"avg_context_chars": sum(context_chars) / len(context_chars) if context_chars else 0.0,
"avg_context_tokens": sum(context_tokens) / len(context_tokens) if context_tokens else 0.0
}
return {
"overall_metrics": overall_metrics,
"by_category": by_category,
"latency": latency,
"context_stats": context_stats
}
# ============================================================================
# Step 6: Result Saving
# ============================================================================
def step_save_results(
result: Dict[str, Any],
output_dir: Optional[str]
) -> str:
"""
Save evaluation results to JSON file.
Args:
result: Complete result dictionary
output_dir: Directory to save results (uses default if None)
Returns:
Path to saved file
"""
if output_dir is None:
# Use absolute path to ensure results are saved in the correct location
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / "results"
else:
# Convert to Path object
output_dir = Path(output_dir)
# If relative path, make it relative to script directory
if not output_dir.is_absolute():
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / output_dir
# Create directory if it doesn't exist
output_dir.mkdir(parents=True, exist_ok=True)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_dir / f"locomo_{timestamp_str}.json"
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ Results saved to: {output_path}\n")
return str(output_path)
except Exception as e:
print(f"❌ Failed to save results: {e}")
print("📊 Printing results to console instead:\n")
print(json.dumps(result, ensure_ascii=False, indent=2))
return ""
# ============================================================================
# Main Orchestration Function
# ============================================================================
async def run_locomo_benchmark(
sample_size: int = 20,
end_user_id: Optional[str] = None,
search_type: str = "hybrid",
search_limit: int = 12,
context_char_budget: int = 8000,
reset_group: bool = False,
skip_ingest: bool = False,
output_dir: Optional[str] = None,
max_ingest_messages: Optional[int] = None
) -> Dict[str, Any]:
"""
Run LoCoMo benchmark evaluation.
This function orchestrates the complete evaluation pipeline by calling
well-defined step functions:
1. Load LoCoMo dataset (only QA pairs from first conversation)
2. Ingest conversations into database (unless skip_ingest=True)
3. Initialize clients (Neo4j, LLM, Embedder)
4. Process all questions (retrieve, generate, calculate metrics)
5. Aggregate results
6. Save results to file
Note: By default, only the first conversation is ingested into the database,
and only QA pairs from that conversation are evaluated. This ensures that
all questions have corresponding memory in the database for retrieval.
Args:
sample_size: Number of QA pairs to evaluate (from first conversation)
end_user_id: Database end_user ID for retrieval (uses default if None)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max documents to retrieve per query
context_char_budget: Max characters for context
reset_group: Whether to clear and re-ingest data
skip_ingest: If True, skip data ingestion and use existing data in Neo4j
output_dir: Directory to save results (uses default if None)
max_ingest_messages: Max messages per dialogue to ingest (for testing, None = all)
Returns:
Dictionary with evaluation results including metrics, timing, and samples
"""
# Use default end_user_id if not provided
# 优先级:命令行参数 > LOCOMO_END_USER_ID > EVAL_END_USER_ID > 默认值
if end_user_id is None:
end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark")
# Get model IDs from config
llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
# Determine data path
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
data_path = dataset_dir / "locomo10.json"
if not os.path.exists(data_path):
raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 locomo10.json 放置在: {dataset_dir}"
)
# Print configuration
print(f"\n{'='*60}")
print("🚀 Starting LoCoMo Benchmark Evaluation")
print(f"{'='*60}")
print("📊 Configuration:")
print(f" Sample size: {sample_size}")
print(f" End User ID: {end_user_id}")
print(f" Search type: {search_type}")
print(f" Search limit: {search_limit}")
print(f" Context budget: {context_char_budget} chars")
print(f" Data path: {data_path}")
if max_ingest_messages:
print(f" Max ingest messages: {max_ingest_messages} (testing mode)")
print(f"{'='*60}\n")
# Step 1: Load LoCoMo data (加载数据)
try:
qa_items = step_load_data(data_path, sample_size)
except Exception as e:
print(f"❌ Failed to load data: {e}")
return {
"error": f"Data loading failed: {e}",
"timestamp": datetime.now().isoformat()
}
# Step 2: Ingest data if needed数据摄入
await step_ingest_data(data_path, end_user_id, skip_ingest, reset_group, max_ingest_messages)
# Step 3: Initialize clients (初始化客户端)
connector, llm_client, embedder = step_initialize_clients(llm_id, embedding_id)
# Step 4: Process all questions (处理所有问题)
try:
samples = await step_process_all_questions(
qa_items=qa_items,
end_user_id=end_user_id,
search_type=search_type,
search_limit=search_limit,
context_char_budget=context_char_budget,
connector=connector,
embedder=embedder,
llm_client=llm_client
)
finally:
await connector.close()
# Step 5: Aggregate results (聚合答案)
aggregated = step_aggregate_results(samples)
# Build final result dictionary
result = {
"dataset": "locomo",
"sample_size": len(qa_items),
"timestamp": datetime.now().isoformat(),
"params": {
"end_user_id": end_user_id,
"search_type": search_type,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_id": llm_id,
"embedding_id": embedding_id
},
"overall_metrics": aggregated["overall_metrics"],
"by_category": aggregated["by_category"],
"latency": aggregated["latency"],
"context_stats": aggregated["context_stats"],
"samples": samples
}
# Step 6: Save results (保存结果)
step_save_results(result, output_dir)
return result
def main():
"""
Parse command-line arguments and run benchmark.
This function provides a CLI interface for running LoCoMo benchmarks
with configurable parameters.
Configuration priority: Command-line args > Environment variables > Code defaults
"""
# Load environment variables first
load_dotenv()
# Get defaults from environment variables
env_sample_size = os.getenv("LOCOMO_SAMPLE_SIZE")
env_search_limit = os.getenv("LOCOMO_SEARCH_LIMIT")
env_context_budget = os.getenv("LOCOMO_CONTEXT_CHAR_BUDGET")
env_output_dir = os.getenv("LOCOMO_OUTPUT_DIR")
env_skip_ingest = os.getenv("LOCOMO_SKIP_INGEST", "false").lower() in ("true", "1", "yes")
# Convert to appropriate types with fallback to code defaults
default_sample_size = int(env_sample_size) if env_sample_size else 20
default_search_limit = int(env_search_limit) if env_search_limit else 12
default_context_budget = int(env_context_budget) if env_context_budget else 8000
default_output_dir = env_output_dir if env_output_dir else None
parser = argparse.ArgumentParser(
description="Run LoCoMo benchmark evaluation",
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--sample_size",
type=int,
default=default_sample_size,
help=f"Number of QA pairs to evaluate (env: LOCOMO_SAMPLE_SIZE={env_sample_size or 'not set'}, 0 for all)"
)
parser.add_argument(
"--end_user_id",
type=str,
default=None,
help="Database end user ID for retrieval (uses LOCOMO_END_USER_ID or EVAL_END_USER_ID if not specified)"
)
parser.add_argument(
"--search_type",
type=str,
default="hybrid",
choices=["keyword", "embedding", "hybrid"],
help="Search strategy to use"
)
parser.add_argument(
"--search_limit",
type=int,
default=default_search_limit,
help=f"Maximum number of documents to retrieve per query (env: LOCOMO_SEARCH_LIMIT={env_search_limit or 'not set'})"
)
parser.add_argument(
"--context_char_budget",
type=int,
default=default_context_budget,
help=f"Maximum characters for context (env: LOCOMO_CONTEXT_CHAR_BUDGET={env_context_budget or 'not set'})"
)
parser.add_argument(
"--reset_group",
action="store_true",
help="Clear and re-ingest data (not implemented)"
)
parser.add_argument(
"--skip_ingest",
action="store_true",
default=env_skip_ingest,
help=f"Skip data ingestion and use existing data in Neo4j (env: LOCOMO_SKIP_INGEST={os.getenv('LOCOMO_SKIP_INGEST', 'false')})"
)
parser.add_argument(
"--output_dir",
type=str,
default=default_output_dir,
help=f"Directory to save results (env: LOCOMO_OUTPUT_DIR={env_output_dir or 'not set'})"
)
parser.add_argument(
"--max_ingest_messages",
type=int,
default=None,
help="Maximum messages per dialogue to ingest (for testing, default: all messages)"
)
args = parser.parse_args()
# Run benchmark
result = asyncio.run(run_locomo_benchmark(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_type=args.search_type,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
reset_group=args.reset_group,
skip_ingest=args.skip_ingest,
output_dir=args.output_dir,
max_ingest_messages=args.max_ingest_messages
))
# Print summary
print(f"\n{'='*60}")
# Check if there was an error
if 'error' in result:
print("❌ Benchmark Failed!")
print(f"{'='*60}")
print(f"Error: {result['error']}")
return
print("🎉 Benchmark Complete!")
print(f"{'='*60}")
print("📊 Final Results:")
print(f" Sample size: {result.get('sample_size', 0)}")
print(f" F1: {result['overall_metrics']['f1']:.3f}")
print(f" BLEU-1: {result['overall_metrics']['bleu1']:.3f}")
print(f" Jaccard: {result['overall_metrics']['jaccard']:.3f}")
print(f" LoCoMo F1: {result['overall_metrics']['locomo_f1']:.3f}")
if result.get('context_stats'):
print("\n📈 Context Statistics:")
print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")
if result.get('latency'):
print("\n⏱️ Latency Statistics:")
print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
f"P50: {result['latency']['search']['p50']:.1f}ms, "
f"P95: {result['latency']['search']['p95']:.1f}ms")
print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
f"P50: {result['latency']['llm']['p50']:.1f}ms, "
f"P95: {result['latency']['llm']['p95']:.1f}ms")
if result.get('by_category'):
print("\n📂 Results by Category:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" Count: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
print(f" Jaccard: {metrics['jaccard']:.3f}")
print(f"\n{'='*60}\n")
if __name__ == "__main__":
main()

View File

@@ -1,225 +0,0 @@
"""
LoCoMo-specific metric calculations.
This module provides clean, simplified implementations of metrics used for
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
"""
import re
from typing import Dict, Any
def normalize_text(text: str) -> str:
"""
Normalize text for LoCoMo evaluation.
Normalization steps:
- Convert to lowercase
- Remove commas
- Remove stop words (a, an, the, and)
- Remove punctuation
- Normalize whitespace
Args:
text: Input text to normalize
Returns:
Normalized text string with consistent formatting
Examples:
>>> normalize_text("The cat, and the dog")
'cat dog'
>>> normalize_text("Hello, World!")
'hello world'
"""
# Ensure input is a string
text = str(text) if text is not None else ""
# Convert to lowercase
text = text.lower()
# Remove commas
text = re.sub(r"[\,]", " ", text)
# Remove stop words
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
# Remove punctuation (keep only word characters and whitespace)
text = re.sub(r"[^\w\s]", " ", text)
# Normalize whitespace (collapse multiple spaces to single space)
text = " ".join(text.split())
return text
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for single-answer questions.
Uses token-level precision and recall based on normalized text.
Treats tokens as sets (no duplicate counting).
Args:
prediction: Model's predicted answer
ground_truth: Correct answer
Returns:
F1 score between 0.0 and 1.0
Examples:
>>> locomo_f1_score("Paris", "Paris")
1.0
>>> locomo_f1_score("The cat", "cat")
1.0
>>> locomo_f1_score("dog", "cat")
0.0
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Normalize and tokenize
pred_tokens = normalize_text(pred_str).split()
truth_tokens = normalize_text(truth_str).split()
# Handle empty cases
if not pred_tokens or not truth_tokens:
return 0.0
# Convert to sets for comparison
pred_set = set(pred_tokens)
truth_set = set(truth_tokens)
# Calculate true positives (intersection)
true_positives = len(pred_set & truth_set)
# Calculate precision and recall
precision = true_positives / len(pred_set) if pred_set else 0.0
recall = true_positives / len(truth_set) if truth_set else 0.0
# Calculate F1 score
if precision + recall == 0:
return 0.0
f1 = 2 * precision * recall / (precision + recall)
return f1
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
"""
Calculate LoCoMo F1 score for multi-answer questions.
Handles comma-separated answers by:
1. Splitting both prediction and ground truth by commas
2. For each ground truth answer, finding the best matching prediction
3. Averaging the F1 scores across all ground truth answers
Args:
prediction: Model's predicted answer (may contain multiple comma-separated answers)
ground_truth: Correct answer (may contain multiple comma-separated answers)
Returns:
Average F1 score across all ground truth answers (0.0 to 1.0)
Examples:
>>> locomo_multi_f1("Paris, London", "Paris, London")
1.0
>>> locomo_multi_f1("Paris", "Paris, London")
0.5
>>> locomo_multi_f1("Paris, Berlin", "Paris, London")
0.5
"""
# Ensure inputs are strings
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
# Split by commas and strip whitespace
predictions = [p.strip() for p in pred_str.split(',') if p.strip()]
ground_truths = [g.strip() for g in truth_str.split(',') if g.strip()]
# Handle empty cases
if not predictions or not ground_truths:
return 0.0
# For each ground truth, find the best matching prediction
f1_scores = []
for gt in ground_truths:
# Calculate F1 with each prediction and take the maximum
best_f1 = max(locomo_f1_score(pred, gt) for pred in predictions)
f1_scores.append(best_f1)
# Return average F1 across all ground truths
return sum(f1_scores) / len(f1_scores)
def get_category_name(item: Dict[str, Any]) -> str:
"""
Extract and normalize category name from QA item.
Handles both numeric categories (1-4) and string categories with various formats.
Supports multiple field names: "cat", "category", "type".
Category mapping:
- 1 or "multi-hop" -> "Multi-Hop"
- 2 or "temporal" -> "Temporal"
- 3 or "open domain" -> "Open Domain"
- 4 or "single-hop" -> "Single-Hop"
Args:
item: QA item dictionary containing category information
Returns:
Standardized category name or "unknown" if not found
Examples:
>>> get_category_name({"category": 1})
'Multi-Hop'
>>> get_category_name({"cat": "temporal"})
'Temporal'
>>> get_category_name({"type": "Single-Hop"})
'Single-Hop'
"""
# Numeric category mapping
CATEGORY_MAP = {
1: "Multi-Hop",
2: "Temporal",
3: "Open Domain",
4: "Single-Hop",
}
# String category aliases (case-insensitive)
TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
# Try "cat" field first (string category)
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return TYPE_ALIASES.get(lower, name)
# Try "category" field (can be int or string)
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP.get(cat_num, "unknown")
elif isinstance(cat_num, str) and cat_num.strip():
lower = cat_num.strip().lower()
return TYPE_ALIASES.get(lower, cat_num.strip())
# Try "type" field as fallback
cat_type = item.get("type")
if isinstance(cat_type, str) and cat_type.strip():
lower = cat_type.strip().lower()
return TYPE_ALIASES.get(lower, cat_type.strip())
return "unknown"

View File

@@ -1,864 +0,0 @@
# file name: check_neo4j_connection_fixed.py
import asyncio
import os
import sys
import json
import time
import math
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any
from pathlib import Path
from dotenv import load_dotenv
# Load main .env
load_dotenv()
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
# Get group_id from config
group_id = os.getenv("EVAL_GROUP_ID", "locomo_test")
print(f"✅ 使用配置的 group_id: {group_id}")
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
def _loc_normalize(text: str) -> str:
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text)
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 尝试从 metrics.py 导入基础指标
try:
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}")
# 回退到本地实现
def f1_score(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens and not r_tokens:
return 1.0
if not p_tokens or not r_tokens:
return 0.0
p_set = set(p_tokens)
r_set = set(r_tokens)
tp = len(p_set & r_set)
precision = tp / len(p_set) if p_set else 0.0
recall = tp / len(r_set) if r_set else 0.0
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p_tokens = _loc_normalize(pred_str).split()
r_tokens = _loc_normalize(ref_str).split()
if not p_tokens:
return 0.0
r_counts = {}
for t in r_tokens:
r_counts[t] = r_counts.get(t, 0) + 1
clipped = 0
p_counts = {}
for t in p_tokens:
p_counts[t] = p_counts.get(t, 0) + 1
for t, c in p_counts.items():
clipped += min(c, r_counts.get(t, 0))
precision = clipped / max(len(p_tokens), 1)
ref_len = len(r_tokens)
pred_len = len(p_tokens)
if pred_len > ref_len or pred_len == 0:
bp = 1.0
else:
bp = math.exp(1 - ref_len / max(pred_len, 1))
return bp * precision
def jaccard(pred: str, ref: str) -> float:
pred_str = str(pred) if pred is not None else ""
ref_str = str(ref) if ref is not None else ""
p = set(_loc_normalize(pred_str).split())
r = set(_loc_normalize(ref_str).split())
if not p and not r:
return 1.0
if not p or not r:
return 0.0
return len(p & r) / len(p | r)
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
try:
from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
# 回退到本地实现 LoCoMo 特定函数
def _resolve_relative_times(text: str, anchor: datetime) -> str:
t = str(text) if text is not None else ""
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
p_tokens = _loc_normalize(prediction).split()
g_tokens = _loc_normalize(ground_truth).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
predictions = [p.strip() for p in str(prediction).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(ground_truth).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
"""根据问题复杂度和进度动态调整检索参数"""
# 分析问题复杂度
word_count = len(question.split())
has_temporal = any(word in question.lower() for word in ['when', 'date', 'time', 'ago'])
has_multi_hop = any(word in question.lower() for word in ['and', 'both', 'also', 'while'])
# 根据进度调整 - 后期问题可能需要更精确的检索
progress_factor = question_index / total_questions
base_limit = 12
if has_temporal and has_multi_hop:
base_limit = 20
elif word_count > 8:
base_limit = 16
# 随着测试进行,逐渐收紧检索范围
adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))
# 动态调整最大字符数
max_chars = 8000 + 4000 * (1 - progress_factor)
return {
"limit": adjusted_limit,
"max_chars": int(max_chars)
}
class EnhancedEvaluationMonitor:
def __init__(self, reset_interval=5, performance_threshold=0.6):
self.question_count = 0
self.reset_interval = reset_interval
self.performance_threshold = performance_threshold
self.consecutive_low_scores = 0
self.performance_history = []
self.recent_f1_scores = []
def should_reset_connections(self, current_f1=None):
"""基于计数和性能双重判断"""
# 定期重置
if self.question_count % self.reset_interval == 0:
return True
# 性能驱动的重置
if current_f1 is not None and current_f1 < self.performance_threshold:
self.consecutive_low_scores += 1
if self.consecutive_low_scores >= 2: # 连续2个低分就重置
print("🚨 连续低分,触发紧急重置")
self.consecutive_low_scores = 0
return True
else:
self.consecutive_low_scores = 0
return False
def record_performance(self, question_index, metrics, context_length, retrieved_docs):
"""记录性能指标,检测衰减"""
self.performance_history.append({
'index': question_index,
'metrics': metrics,
'context_length': context_length,
'retrieved_docs': retrieved_docs,
'timestamp': time.time()
})
# 记录最近的F1分数
self.recent_f1_scores.append(metrics['f1'])
if len(self.recent_f1_scores) > 5:
self.recent_f1_scores.pop(0)
def get_recent_performance(self):
"""获取近期平均性能"""
if not self.recent_f1_scores:
return 0.5
return sum(self.recent_f1_scores) / len(self.recent_f1_scores)
def get_performance_trend(self):
"""分析性能趋势"""
if len(self.performance_history) < 2:
return "stable"
recent_metrics = [item['metrics']['f1'] for item in self.performance_history[-5:]]
earlier_metrics = [item['metrics']['f1'] for item in self.performance_history[-10:-5]]
if len(recent_metrics) < 2 or len(earlier_metrics) < 2:
return "stable"
recent_avg = sum(recent_metrics) / len(recent_metrics)
earlier_avg = sum(earlier_metrics) / len(earlier_metrics)
if recent_avg < earlier_avg * 0.8:
return "degrading"
elif recent_avg > earlier_avg * 1.1:
return "improving"
else:
return "stable"
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
"""基于问题复杂度和近期性能动态调整检索参数"""
# 基础参数
base_params = get_dynamic_search_params(question, question_index, total_questions)
# 性能自适应调整
if recent_performance < 0.5: # 近期表现差
# 增加检索范围,尝试获取更多上下文
base_params["limit"] = min(base_params["limit"] + 5, 25)
base_params["max_chars"] = min(base_params["max_chars"] + 2000, 12000)
print(f"📈 性能自适应:增加检索范围 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
elif recent_performance > 0.8: # 近期表现好
# 收紧检索,提高精度
base_params["limit"] = max(base_params["limit"] - 2, 8)
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 6000)
print(f"🎯 性能自适应:提高检索精度 (limit={base_params['limit']}, max_chars={base_params['max_chars']})")
# 中间阶段特殊处理
mid_sequence_factor = abs(question_index / total_questions - 0.5)
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用更精确的检索策略")
base_params["limit"] = max(base_params["limit"] - 2, 10) # 减少数量,提高质量
base_params["max_chars"] = max(base_params["max_chars"] - 1000, 7000)
return base_params
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
"""考虑问题序列位置的智能选择"""
if not contexts:
return ""
# 在序列中间阶段使用更严格的筛选
mid_sequence_factor = abs(question_index / total_questions - 0.5) # 距离中心的距离
if mid_sequence_factor < 0.2: # 在中间30%的问题
print("🎯 中间阶段:使用严格上下文筛选")
# 提取问题关键词
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
# 只保留高度相关的上下文
filtered_contexts = []
for context in contexts:
context_lower = context.lower()
relevance_score = sum(3 if word in context_lower else 0 for word in question_words)
# 额外加分给包含数字、日期的上下文(对事实性问题更重要)
if any(char.isdigit() for char in context):
relevance_score += 2
# 提高阈值:只有得分>=3的上下文才保留
if relevance_score >= 3:
filtered_contexts.append(context)
else:
print(f" - 过滤低分上下文: 得分={relevance_score}")
contexts = filtered_contexts
print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")
# 使用原有的智能选择逻辑
return smart_context_selection(contexts, question, max_chars)
async def run_enhanced_evaluation():
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
from dotenv import load_dotenv
from uuid import UUID
from datetime import datetime
from dataclasses import dataclass
# 修正导入路径:使用 app.core.memory.src 前缀
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.schemas.memory_config_schema import MemoryConfig
from app.services.memory_config_service import MemoryConfigService
# Get model IDs from config
llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
# 加载数据 - 使用统一的 dataset 目录
data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json")
if not os.path.exists(data_path):
raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/"
)
print(f"✅ 找到数据文件: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
qa_items = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
# 测试多少个问题 - 可通过环境变量设置
sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20"))
items = qa_items[:sample_size]
print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)")
# 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
# 获取数据库会话并初始化 LLM 客户端
from app.db import get_db
db = next(get_db())
try:
llm = get_llm_client(llm_id, db)
# 初始化embedder
cfg_dict = get_embedder_config(embedding_id, db)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 🔧 创建 MemoryConfig 对象用于搜索
# 方案1如果有配置ID从数据库加载
config_id = os.getenv("EVAL_CONFIG_ID")
if config_id:
print(f"📋 从数据库加载配置 ID: {config_id}")
memory_config_service = MemoryConfigService(db)
memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test")
else:
# 方案2创建临时配置对象用于测试
print(f"📋 创建临时测试配置")
from uuid import UUID
from datetime import datetime
# 将字符串 ID 转换为 UUID
try:
embedding_uuid = UUID(embedding_id)
llm_uuid = UUID(llm_id)
except ValueError as e:
raise ValueError(f"无效的 UUID 格式: {e}")
memory_config = MemoryConfig(
config_id=1, # 临时 ID
config_name="locomo_test_config",
workspace_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 workspace
workspace_name="test_workspace",
tenant_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 tenant
embedding_model_id=embedding_uuid,
embedding_model_name="test_embedding",
llm_model_id=llm_uuid,
llm_model_name="test_llm",
storage_type="neo4j",
chunker_strategy="RecursiveChunker",
reflexion_enabled=False,
reflexion_iteration_period=3,
reflexion_range="partial",
reflexion_baseline="Time",
loaded_at=datetime.now()
)
print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}")
# 初始化连接器
connector = Neo4jConnector()
# 初始化结果字典
results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
total_f1 = 0.0
total_bleu1 = 0.0
total_jaccard = 0.0
total_loc_f1 = 0.0
total_context_length = 0
total_retrieved_docs = 0
category_stats = {}
try:
for i, item in enumerate(items):
monitor.question_count += 1
# 获取近期性能用于重置判断
recent_performance = monitor.get_recent_performance()
# 增强的重置判断
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
if should_reset and i > 0:
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
await connector.close()
connector = Neo4jConnector() # 创建新连接
print("✅ 连接重置完成")
q = item.get("question", "")
ref = item.get("answer", "")
ref_str = str(ref) if ref is not None else ""
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
print(f"✅ 真实答案: {ref_str}")
# 分类别统计
category = "Unknown"
if item.get("category") == 1:
category = "Multi-Hop"
elif item.get("category") == 2:
category = "Temporal"
elif item.get("category") == 3:
category = "Open Domain"
elif item.get("category") == 4:
category = "Single-Hop"
# 增强的检索参数
search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
search_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")
# 使用项目标准的混合检索方法
t0 = time.time()
contexts_all = []
try:
# 使用旧版本的搜索服务(重构前的版本)
from app.core.memory.src.search import run_hybrid_search
print(f"🔀 使用混合搜索服务(旧版本)...")
print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid")
print(f"📍 查询文本: {q}")
search_results = await run_hybrid_search(
query_text=q,
search_type="hybrid",
end_user_id="locomo_sk",
limit=20,
include=["statements", "chunks", "entities", "summaries"],
output_path=None,
memory_config=memory_config, # 🔧 添加必需的 memory_config 参数
rerank_alpha=0.6, # BM25权重
use_forgetting_rerank=False,
use_llm_rerank=False
)
# 处理搜索结果 - 旧版本返回包含 reranked_results 的结构
# 对于 hybrid 搜索,使用 reranked_results
if "reranked_results" in search_results:
reranked = search_results["reranked_results"]
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
# 单一搜索类型的结果
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
print(f"📊 有效上下文数量: {len(contexts_all)}")
except Exception as e:
print(f"❌ 检索失败: {e}")
import traceback
print(f"详细错误信息:\n{traceback.format_exc()}")
contexts_all = []
t1 = time.time()
search_time = (t1 - t0) * 1000
# 增强的上下文选择
context_text = ""
if contexts_all:
# 使用增强的上下文选择
context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览(不只是第一条)
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
# 🔍 调试:检查答案是否在上下文中
if ref_str and ref_str.strip():
answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# LLM 回答
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
except Exception as e:
print(f"❌ LLM 生成失败: {e}")
pred = "Unknown"
t3 = time.time()
llm_time = (t3 - t2) * 1000
# 计算指标 - 使用导入的指标函数
f1_val = f1_score(pred, ref_str)
bleu1_val = bleu1(pred, ref_str)
jaccard_val = jaccard(pred, ref_str)
loc_f1_val = loc_f1_score(pred, ref_str)
print(f"🤖 LLM 回答: {pred}")
print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")
# 更新统计
total_f1 += f1_val
total_bleu1 += bleu1_val
total_jaccard += jaccard_val
total_loc_f1 += loc_f1_val
total_context_length += len(context_text)
total_retrieved_docs += len(contexts_all)
if category not in category_stats:
category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}
category_stats[category]["count"] += 1
category_stats[category]["f1_sum"] += f1_val
category_stats[category]["b1_sum"] += bleu1_val
category_stats[category]["j_sum"] += jaccard_val
category_stats[category]["loc_f1_sum"] += loc_f1_val
# 记录性能指标
metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
monitor.record_performance(i, metrics, len(context_text), len(contexts_all))
# 保存结果
question_result = {
"question": q,
"ground_truth": ref_str,
"prediction": pred,
"category": category,
"metrics": metrics,
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": max_chars,
"recent_performance": recent_performance
},
"timing": {
"search_ms": search_time,
"llm_ms": llm_time
}
}
results["questions"].append(question_result)
print("="*60)
except Exception as e:
print(f"❌ 评估过程中发生错误: {e}")
# 即使出错,也返回已有的结果
import traceback
traceback.print_exc()
finally:
await connector.close()
finally:
db.close() # 关闭数据库会话
# 计算总体指标
n = len(items)
if n > 0:
results["overall_metrics"] = {
"f1": total_f1 / n,
"b1": total_bleu1 / n,
"j": total_jaccard / n,
"loc_f1": total_loc_f1 / n
}
for category, stats in category_stats.items():
count = stats["count"]
results["category_metrics"][category] = {
"count": count,
"f1": stats["f1_sum"] / count,
"bleu1": stats["b1_sum"] / count,
"jaccard": stats["j_sum"] / count,
"loc_f1": stats["loc_f1_sum"] / count
}
results["retrieval_stats"]["avg_context_length"] = total_context_length / n
results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n
# 分析性能趋势
results["performance_trend"] = monitor.get_performance_trend()
results["reset_interval"] = monitor.reset_interval
results["total_questions_processed"] = monitor.question_count
return results
if __name__ == "__main__":
print("🚀 运行增强版完整评估(解决中间性能衰减问题)...")
print("📋 增强特性:")
print(" - 双重重置策略:定期重置 + 性能驱动重置")
print(" - 动态检索参数:基于近期性能自适应调整")
print(" - 中间阶段严格筛选:提高上下文质量要求")
print(" - 连续性能监控:实时检测性能衰减")
result = asyncio.run(run_enhanced_evaluation())
print("\n📊 最终评估结果:")
print("总体指标:")
print(f" F1: {result['overall_metrics']['f1']:.4f}")
print(f" BLEU-1: {result['overall_metrics']['b1']:.4f}")
print(f" Jaccard: {result['overall_metrics']['j']:.4f}")
print(f" LoCoMo F1: {result['overall_metrics']['loc_f1']:.4f}")
print("\n分类别指标:")
for category, metrics in result['category_metrics'].items():
print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")
print("\n检索统计:")
stats = result['retrieval_stats']
print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")
print(f"\n性能趋势: {result['performance_trend']}")
print(f"重置间隔: 每{result['reset_interval']}个问题")
print(f"处理问题总数: {result['total_questions_processed']}")
print(f"增强策略: {'启用' if result.get('enhanced_strategy', False) else '未启用'}")
# 保存结果到指定目录
# 使用代码文件所在目录的绝对路径
current_file_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_file_dir, "results")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
with open(output_file, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n详细结果已保存到: {output_file}")

View File

@@ -1,687 +0,0 @@
"""
LoCoMo Utilities Module
This module provides helper functions for the LoCoMo benchmark evaluation:
- Data loading from JSON files
- Conversation extraction for ingestion
- Temporal reference resolution
- Context selection and formatting
- Retrieval wrapper functions
- Ingestion wrapper functions
"""
import os
import json
import re
from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
def load_locomo_data(
data_path: str,
sample_size: int,
conversation_index: int = 0
) -> List[Dict[str, Any]]:
"""
Load LoCoMo dataset from JSON file.
The LoCoMo dataset structure is a list of conversation objects, where each
object contains a "qa" list of question-answer pairs.
Args:
data_path: Path to locomo10.json file
sample_size: Number of QA pairs to load (limits total QA items returned)
conversation_index: Which conversation to load QA pairs from (default: 0 for first)
Returns:
List of QA item dictionaries, each containing:
- question: str
- answer: str
- category: int (1-4)
- evidence: List[str]
Raises:
FileNotFoundError: If data_path does not exist
json.JSONDecodeError: If file is not valid JSON
IndexError: If conversation_index is out of range
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo data structure: list of objects, each with a "qa" list
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
# Only load QA pairs from the specified conversation
if conversation_index < len(raw):
entry = raw[conversation_index]
if isinstance(entry, dict) and "qa" in entry:
qa_items.extend(entry.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has {len(raw)} conversations."
)
else:
# Fallback: single object with qa list
if conversation_index == 0:
qa_items.extend(raw.get("qa", []))
else:
raise IndexError(
f"Conversation index {conversation_index} out of range. "
f"Dataset has only 1 conversation."
)
# Return only the requested sample size
return qa_items[:sample_size]
def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]:
"""
Extract conversation texts from LoCoMo data for ingestion.
This function extracts the raw conversation dialogues from the LoCoMo dataset
so they can be ingested into the memory system. Each conversation is formatted
as a multi-line string with "role: message" format.
Args:
data_path: Path to locomo10.json file
max_dialogues: Maximum number of dialogues to extract (default: 1)
max_messages_per_dialogue: Maximum messages per dialogue (default: None = all messages)
Returns:
List of conversation strings formatted for ingestion.
Each string contains multiple lines in format "role: message"
Example output:
[
"User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
"User: I love hiking.\\nAI: Where do you like to hike?\\n..."
]
"""
if not os.path.exists(data_path):
raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# Ensure we have a list of entries
entries = raw if isinstance(raw, list) else [raw]
contents: List[str] = []
for i, entry in enumerate(entries[:max_dialogues]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
if not isinstance(conv, dict):
continue
lines: List[str] = []
# Collect all session_* messages
for key, val in sorted(conv.items()):
if isinstance(val, list) and key.startswith("session_"):
for msg in val:
if not isinstance(msg, dict):
continue
role = msg.get("speaker") or "User"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
# Limit messages if specified
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
break
# Break outer loop if we've reached the message limit
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
break
if lines:
contents.append("\n".join(lines))
return contents
# 时间解析:将相对时间表达转换为绝对日期
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
"""
Resolve relative temporal references to absolute dates.
This function converts relative time expressions (like "today", "yesterday",
"3 days ago") into absolute ISO date strings based on an anchor date.
Supported patterns:
- today, yesterday, tomorrow
- X days ago, in X days
- last week, next week
Args:
text: Text containing temporal references
anchor_date: Reference date for resolution (datetime object)
Returns:
Text with temporal references replaced by ISO dates (YYYY-MM-DD format)
Example:
>>> anchor = datetime(2023, 5, 8)
>>> resolve_temporal_references("I saw him yesterday", anchor)
"I saw him 2023-05-07"
"""
# Ensure input is a string
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(
r"\btoday\b",
anchor_date.date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\byesterday\b",
(anchor_date - timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\btomorrow\b",
(anchor_date + timedelta(days=1)).date().isoformat(),
t,
flags=re.IGNORECASE
)
# X days ago
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date - timedelta(days=n)).date().isoformat()
# in X days
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor_date + timedelta(days=n)).date().isoformat()
t = re.sub(
r"\b(\d+)\s+days?\s+ago\b",
_ago_repl,
t,
flags=re.IGNORECASE
)
t = re.sub(
r"\bin\s+(\d+)\s+days?\b",
_in_repl,
t,
flags=re.IGNORECASE
)
# last week / next week (approximate as 7 days)
t = re.sub(
r"\blast\s+week\b",
(anchor_date - timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
# 中文支持
t = re.sub(
r"\bnext\s+week\b",
(anchor_date + timedelta(days=7)).date().isoformat(),
t,
flags=re.IGNORECASE
)
return t
def select_and_format_information(
retrieved_info: List[str],
question: str,
max_chars: int = 8000
) -> str:
"""
Intelligently select and format most relevant retrieved information for LLM prompt.
This function scores each piece of retrieved information based on keyword matching
with the question, then selects the highest-scoring pieces up to the character limit.
Scoring criteria:
- Keyword matches (higher weight for multiple occurrences)
- Context length (moderate length preferred)
- Position (earlier contexts get bonus points)
Args:
retrieved_info: List of retrieved information strings (chunks, statements, entities)
question: Question being answered
max_chars: Maximum total characters to include in final prompt
Returns:
Formatted string combining the most relevant information for LLM prompt.
Contexts are separated by double newlines.
Example:
>>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
>>> question = "Where did Alice go?"
>>> select_and_format_information(contexts, question, max_chars=100)
"Alice went to Paris\\n\\nAlice visited the Eiffel Tower"
"""
if not retrieved_info:
return ""
# Extract question keywords (filter out stop words and short words)
question_lower = question.lower()
stop_words = {
'what', 'when', 'where', 'who', 'why', 'how',
'did', 'do', 'does', 'is', 'are', 'was', 'were',
'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {
word for word in question_words
if word not in stop_words and len(word) > 2
}
# Score each context
scored_contexts = []
for i, context in enumerate(retrieved_info):
context_lower = context.lower()
score = 0
# Keyword matching score
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# Multiple occurrences increase score
score += context_lower.count(word) * 2
# Length score (prefer moderate length)
context_len = len(context)
if 100 < context_len < 2000:
score += 5
elif context_len >= 2000:
score += 2
# Position bonus (earlier contexts often more relevant)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# Sort by score (descending)
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# Select contexts up to character limit
selected = []
total_chars = 0
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
else:
# Try to include high-scoring context by truncating
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# Find lines with keywords
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines and len('\n'.join(relevant_lines)) > 100:
truncated = '\n'.join(relevant_lines)
selected.append(truncated + "\n[Content truncated...]")
total_chars += len(truncated)
break
return "\n\n".join(selected)
# 记忆系统核心能力:写入与读取
async def ingest_conversations_if_needed(
conversations: List[str],
end_user_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
end_user_id: Target end_user ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
end_user_id=end_user_id,
save_chunk_output=True,
reset_group=reset
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False
async def retrieve_relevant_information(
question: str,
end_user_id: str,
search_type: str,
search_limit: int,
connector: Any,
embedder: Any
) -> List[str]:
"""
Retrieve relevant information from memory graph for a question.
This function searches the Neo4j memory graph (populated during ingestion) and
returns relevant chunks, statements, and entity information that might help
answer the question.
The function supports three search types:
- "keyword": Full-text search using Cypher queries
- "embedding": Vector similarity search using embeddings
- "hybrid": Combination of keyword and embedding search with reranking
Args:
question: Question to search for
end_user_id: Database group ID (identifies which conversation memory to search)
search_type: "keyword", "embedding", or "hybrid"
search_limit: Max memory pieces to retrieve
connector: Neo4j connector instance
embedder: Embedder client instance
Returns:
List of text strings (chunks, statements, entity summaries) from memory graph.
Each string represents a piece of retrieved information.
Raises:
Exception: If search fails (caught and returns empty list)
"""
from app.repositories.neo4j.graph_search import (
search_graph,
search_graph_by_embedding
)
from app.core.memory.src.search import run_hybrid_search
contexts_all: List[str] = []
try:
if search_type == "embedding":
# Embedding-based search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context from chunks
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add summaries
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities (limit to 3 to avoid noise)
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# Keyword-based search
search_results = await search_graph(
connector=connector,
q=question,
end_user_id=end_user_id,
limit=search_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
# Build context from dialogues
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
# Add statements
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# Add entity names
if entities:
entity_names = [
str(e.get("name", "")).strip()
for e in entities[:5]
if e.get("name")
]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# Hybrid search with fallback to embedding
try:
search_results = await run_hybrid_search(
query_text=question,
search_type=search_type,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
)
# Handle flat structure (new API format)
if search_results and isinstance(search_results, dict):
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Check if we got results
if not (chunks or statements or entities or summaries):
# Try nested structure (backward compatibility)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
raise ValueError("Hybrid search returned empty results")
else:
raise ValueError("Hybrid search returned empty results")
except Exception as e:
# Fallback to embedding search
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# Build context (same for both hybrid and fallback)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# Add top entities
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = (
sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
if scored else entities[:3]
)
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(
f"EntitySummary: {name}"
f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
)
if summary_lines:
contexts_all.append("\n".join(summary_lines))
except Exception as e:
# Return empty list on error
contexts_all = []
return contexts_all
async def ingest_conversations_if_needed(
conversations: List[str],
end_user_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
end_user_id: Target group ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
end_user_id=end_user_id,
save_chunk_output=True
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False

View File

@@ -1,874 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime, timedelta
from typing import List, Dict, Any
import statistics
import re
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
def _loc_normalize(text: str) -> str:
import re
# 确保输入是字符串
text = str(text) if text is not None else ""
text = text.lower()
text = re.sub(r"[\,]", " ", text) # 去掉逗号
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
text = re.sub(r"[^\w\s]", " ", text)
text = " ".join(text.split())
return text
# 追加相对时间归一化为绝对日期有限支持today/yesterday/tomorrow/X days ago/in X days/last week/next week
def _resolve_relative_times(text: str, anchor: datetime) -> str:
import re
# 确保输入是字符串
t = str(text) if text is not None else ""
# today / yesterday / tomorrow
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
# X days ago / in X days
def _ago_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor - timedelta(days=n)).date().isoformat()
def _in_repl(m: re.Match[str]) -> str:
n = int(m.group(1))
return (anchor + timedelta(days=n)).date().isoformat()
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
# last week / next week以7天近似
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
return t
def loc_f1_score(prediction: str, ground_truth: str) -> float:
# 单答案 F1按词集合计算近似原始实现去除词干依赖
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
p_tokens = _loc_normalize(pred_str).split()
g_tokens = _loc_normalize(truth_str).split()
if not p_tokens or not g_tokens:
return 0.0
p = set(p_tokens)
g = set(g_tokens)
tp = len(p & g)
precision = tp / len(p) if p else 0.0
recall = tp / len(g) if g else 0.0
return (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
# 多答案 F1prediction 与 ground_truth 以逗号分隔,逐一匹配取最大,再对多个 GT 取平均
# 确保输入是字符串
pred_str = str(prediction) if prediction is not None else ""
truth_str = str(ground_truth) if ground_truth is not None else ""
predictions = [p.strip() for p in str(pred_str).split(',') if p.strip()]
ground_truths = [g.strip() for g in str(truth_str).split(',') if g.strip()]
if not predictions or not ground_truths:
return 0.0
def _f1(a: str, b: str) -> float:
return loc_f1_score(a, b)
vals = []
for gt in ground_truths:
vals.append(max(_f1(pred, gt) for pred in predictions))
return sum(vals) / len(vals)
# 标准化 LoCoMo 类别名:支持数字 category 与字符串 cat/type
CATEGORY_MAP_NUM_TO_NAME = {
4: "Single-Hop",
1: "Multi-Hop",
3: "Open Domain",
2: "Temporal",
}
_TYPE_ALIASES = {
"single-hop": "Single-Hop",
"singlehop": "Single-Hop",
"single hop": "Single-Hop",
"multi-hop": "Multi-Hop",
"multihop": "Multi-Hop",
"multi hop": "Multi-Hop",
"open domain": "Open Domain",
"opendomain": "Open Domain",
"temporal": "Temporal",
}
def get_category_label(item: Dict[str, Any]) -> str:
# 1) 直接用字符串 cat
cat = item.get("cat")
if isinstance(cat, str) and cat.strip():
name = cat.strip()
lower = name.lower()
return _TYPE_ALIASES.get(lower, name)
# 2) 数字 category 转名称
cat_num = item.get("category")
if isinstance(cat_num, int):
return CATEGORY_MAP_NUM_TO_NAME.get(cat_num, "unknown")
# 3) 备用 type 字段
t = item.get("type")
if isinstance(t, str) and t.strip():
lower = t.strip().lower()
return _TYPE_ALIASES.get(lower, t.strip())
return "unknown"
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
"""基于问题关键词智能选择上下文"""
if not contexts:
return ""
# 提取问题关键词(只保留有意义的词)
question_lower = question.lower()
stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
question_words = set(re.findall(r'\b\w+\b', question_lower))
question_words = {word for word in question_words if word not in stop_words and len(word) > 2}
print(f"🔍 问题关键词: {question_words}")
# 给每个上下文打分
scored_contexts = []
for i, context in enumerate(contexts):
context_lower = context.lower()
score = 0
# 关键词匹配得分
keyword_matches = 0
for word in question_words:
if word in context_lower:
keyword_matches += 1
# 关键词出现次数越多,得分越高
score += context_lower.count(word) * 2
# 上下文长度得分(适中的长度更好)
context_len = len(context)
if 100 < context_len < 2000: # 理想长度范围
score += 5
elif context_len >= 2000: # 太长可能包含无关信息
score += 2
# 如果是前几个上下文,给予额外分数(通常相关性更高)
if i < 3:
score += 3
scored_contexts.append((score, context, keyword_matches))
# 按得分排序
scored_contexts.sort(key=lambda x: x[0], reverse=True)
# 选择高得分的上下文,直到达到字符限制
selected = []
total_chars = 0
selected_count = 0
print("📊 上下文相关性分析:")
for score, context, matches in scored_contexts[:5]: # 只显示前5个
print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")
for score, context, matches in scored_contexts:
if total_chars + len(context) <= max_chars:
selected.append(context)
total_chars += len(context)
selected_count += 1
else:
# 如果这个上下文得分很高但放不下,尝试截取
if score > 10 and total_chars < max_chars - 500:
remaining = max_chars - total_chars
# 找到包含关键词的部分
lines = context.split('\n')
relevant_lines = []
current_chars = 0
for line in lines:
line_lower = line.lower()
line_relevance = any(word in line_lower for word in question_words)
if line_relevance and current_chars < remaining - 100:
relevant_lines.append(line)
current_chars += len(line)
if relevant_lines:
truncated = '\n'.join(relevant_lines)
if len(truncated) > 100: # 确保有足够内容
selected.append(truncated + "\n[相关内容截断...]")
total_chars += len(truncated)
selected_count += 1
break # 不再尝试添加更多上下文
result = "\n\n".join(selected)
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
return result
def get_search_params_by_category(category: str):
"""根据问题类别调整检索参数"""
params_map = {
"Multi-Hop": {"limit": 20, "max_chars": 15000},
"Temporal": {"limit": 16, "max_chars": 10000},
"Open Domain": {"limit": 24, "max_chars": 18000},
"Single-Hop": {"limit": 12, "max_chars": 8000},
}
return params_map.get(category, {"limit": 16, "max_chars": 12000})
async def run_locomo_eval(
sample_size: int = 1,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000, # 保持默认值不变
llm_temperature: float = 0.0,
llm_max_tokens: int = 32,
search_type: str = "hybrid", # 保持默认值不变
output_path: str | None = None,
skip_ingest_if_exists: bool = True,
llm_timeout: float = 10.0,
llm_max_retries: int = 1
) -> Dict[str, Any]:
# 函数内部使用三路检索逻辑,但保持参数签名不变
end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path):
raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 locomo10.json 放置在: {dataset_dir}"
)
with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f)
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
qa_items: List[Dict[str, Any]] = []
if isinstance(raw, list):
for entry in raw:
qa_items.extend(entry.get("qa", []))
else:
qa_items.extend(raw.get("qa", []))
items: List[Dict[str, Any]] = qa_items[:sample_size]
# === 保持原来的数据摄入逻辑 ===
entries = raw if isinstance(raw, list) else [raw]
# 只摄入前1条对话保持原样
max_dialogues_to_ingest = 1
contents: List[str] = []
print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest}")
for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
if not isinstance(entry, dict):
continue
conv = entry.get("conversation", {})
sample_id = entry.get("sample_id", f"unknown_{i}")
print(f"🔍 处理对话 {i+1}: {sample_id}")
lines: List[str] = []
if isinstance(conv, dict):
# 收集所有 session_* 的消息
session_count = 0
for key, val in conv.items():
if isinstance(val, list) and key.startswith("session_"):
session_count += 1
for msg in val:
role = msg.get("speaker") or "用户"
text = msg.get("text") or ""
text = str(text).strip()
if not text:
continue
lines.append(f"{role}: {text}")
print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")
if not lines:
print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
continue
contents.append("\n".join(lines))
print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")
# 选择要评测的QA对从所有对话中选取
indexed_items: List[tuple[int, Dict[str, Any]]] = []
if isinstance(raw, list):
for e_idx, entry in enumerate(raw):
for qa in entry.get("qa", []):
indexed_items.append((e_idx, qa))
else:
for qa in raw.get("qa", []):
indexed_items.append((0, qa))
# 这里使用sample_size来限制评测的QA数量
selected = indexed_items[:sample_size]
items: List[Dict[str, Any]] = [qa for _, qa in selected]
print(f"🎯 将评测 {len(items)} 个QA对数据库中只包含 {len(contents)} 个对话")
# === 修改结束 ===
connector = Neo4jConnector()
# 关键修复:强制重新摄入纯净的对话数据
print("🔄 强制重新摄入纯净的对话数据...")
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端
llm_client = get_llm_client(llm_id)
# 初始化embedder用于直接调用
cfg_dict = get_embedder_config(embedding_id)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# connector initialized above
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 上下文诊断收集
per_query_context_counts: List[int] = []
per_query_context_avg_tokens: List[float] = []
per_query_context_chars: List[int] = []
per_query_context_tokens_total: List[int] = []
# 详细样本调试信息
samples: List[Dict[str, Any]] = []
# 通用指标
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
# 参考 LoCoMo 评测的类别专用 F1multi-hop 使用多答案 F1
loc_f1s: List[float] = []
# Per-category aggregation
cat_counts: Dict[str, int] = {}
cat_f1s: Dict[str, List[float]] = {}
cat_b1s: Dict[str, List[float]] = {}
cat_jss: Dict[str, List[float]] = {}
cat_loc_f1s: Dict[str, List[float]] = {}
try:
for item in items:
q = item.get("question", "")
ref = item.get("answer", "")
# 确保答案是字符串
ref_str = str(ref) if ref is not None else ""
cat = get_category_label(item)
print(f"\n=== 处理问题: {q} ===")
# 根据类别调整检索参数
search_params = get_search_params_by_category(cat)
adjusted_limit = search_params["limit"]
max_chars = search_params["max_chars"]
print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")
# 改进的检索逻辑使用三路检索statements, dialogues, entities
t0 = time.time()
contexts_all: List[str] = []
search_results = None # 保存完整的检索结果
try:
if search_type == "embedding":
# 直接调用嵌入检索,包含三路数据
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], # 修复:使用正确的类型
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
# 构建上下文:优先使用 chunks、statements 和 summaries
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体避免噪声
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
elif search_type == "keyword":
# 直接调用关键词检索
search_results = await search_graph(
connector=connector,
q=q,
end_user_id=end_user_id,
limit=adjusted_limit
)
dialogs = search_results.get("dialogues", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
print(f"🔤 关键词检索找到 {len(dialogs)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")
# 构建上下文
for d in dialogs:
content = str(d.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
# 实体处理(关键词检索的实体可能没有分数)
if entities:
entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
if entity_names:
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid
# 使用旧版本的混合检索(重构前)
print("🔀 使用混合检索(旧版本)...")
try:
search_results = await run_hybrid_search(
query_text=q,
search_type=search_type,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
output_path=None,
rerank_alpha=0.6,
use_forgetting_rerank=False,
use_llm_rerank=False
)
# 处理旧版本的返回结构(包含 reranked_results
if search_results and isinstance(search_results, dict):
# 对于 hybrid 搜索,使用 reranked_results
if "reranked_results" in search_results:
reranked = search_results["reranked_results"]
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
# 单一搜索类型的结果
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# 检查是否有有效结果
if chunks or statements or entities or summaries:
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
else:
# 如果顶层没有结果,尝试旧的嵌套结构(向后兼容)
reranked = search_results.get("reranked_results", {})
if reranked and isinstance(reranked, dict):
chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
print(f"✅ 混合检索成功使用旧格式reranked结果: {len(chunks)} chunks, {len(statements)} 陈述")
else:
raise ValueError("混合检索返回空结果")
else:
raise ValueError("混合检索返回空结果")
except Exception as e:
print(f"❌ 混合检索失败: {e},回退到嵌入检索")
search_results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=q,
end_user_id=end_user_id,
limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"],
)
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")
# 🎯 统一处理:构建上下文(所有检索类型共用)
for c in chunks:
content = str(c.get("content", "")).strip()
if content:
contexts_all.append(content)
for s in statements:
stmt_text = str(s.get("statement", "")).strip()
if stmt_text:
contexts_all.append(stmt_text)
for sm in summaries:
summary_text = str(sm.get("summary", "")).strip()
if summary_text:
contexts_all.append(summary_text)
# 实体摘要最多加入前3个高分实体
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
# 关键修复:过滤掉包含当前问题答案的上下文
filtered_contexts = []
for context in contexts_all:
content = str(context)
# 排除包含当前问题标准答案的上下文
if ref_str and ref_str.strip() and ref_str.strip() in content:
print("🚫 过滤掉包含标准答案的上下文")
continue
filtered_contexts.append(context)
print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
contexts_all = filtered_contexts
# 输出完整的检索结果信息
print("🔍 检索结果详情:")
if search_results:
output_data = {
"statements": [
{
"statement": s.get("statement", "")[:200] + "..." if len(s.get("statement", "")) > 200 else s.get("statement", ""),
"score": s.get("score", 0.0)
}
for s in (statements[:2] if 'statements' in locals() else [])
],
"dialogues": [
{
"uuid": d.get("uuid", ""),
"end_user_id": d.get("end_user_id", ""),
"content": d.get("content", "")[:200] + "..." if len(d.get("content", "")) > 200 else d.get("content", ""),
"score": d.get("score", 0.0)
}
for d in (dialogs[:2] if 'dialogs' in locals() else [])
],
"entities": [
{
"name": e.get("name", ""),
"entity_type": e.get("entity_type", ""),
"score": e.get("score", 0.0)
}
for e in (entities[:2] if 'entities' in locals() else [])
]
}
print(json.dumps(output_data, ensure_ascii=False, indent=2))
else:
print(" 无检索结果")
except Exception as e:
print(f"{search_type}检索失败: {e}")
contexts_all = []
search_results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 使用智能上下文选择
context_text = ""
if contexts_all:
context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)
# 如果智能选择后仍然过长,进行最终保护性截断
if len(context_text) > max_chars:
print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
context_text = context_text[:max_chars] + "\n\n[最终截断...]"
# 时间解析
anchor_date = datetime(2023, 5, 8) # 使用固定日期确保一致性
context_text = _resolve_relative_times(context_text, anchor_date)
context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text
print(f"📝 最终上下文长度: {len(context_text)} 字符")
# 显示不同上下文的预览
print("🔍 上下文预览:")
for j, context in enumerate(contexts_all[:3]): # 显示前3个上下文
preview = context[:150].replace('\n', ' ')
print(f" 上下文{j+1}: {preview}...")
else:
print("❌ 没有检索到有效上下文")
context_text = "No relevant context found."
# 记录上下文诊断信息
per_query_context_counts.append(len(contexts_all))
per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
per_query_context_chars.append(len(context_text))
per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))
# LLM 提示词
messages = [
{"role": "system", "content": (
"You are a precise QA assistant. Answer following these rules:\n"
"1) Extract the EXACT information mentioned in the context\n"
"2) For time questions: calculate actual dates from relative times\n"
"3) Return ONLY the answer text in simplest form\n"
"4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
"5) If no clear answer found, respond with 'Unknown'"
)},
{"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
]
t2 = time.time()
# 使用异步调用
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
# 兼容不同的响应格式
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
# 计算指标(确保使用字符串)
f1_val = common_f1(str(pred), ref_str)
b1_val = bleu1(str(pred), ref_str)
j_val = jaccard(str(pred), ref_str)
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
# Accumulate by category
cat_counts[cat] = cat_counts.get(cat, 0) + 1
cat_f1s.setdefault(cat, []).append(f1_val)
cat_b1s.setdefault(cat, []).append(b1_val)
cat_jss.setdefault(cat, []).append(j_val)
# LoCoMo 专用 F1multi-hop(1) 使用多答案 F1其它(2/3/4)使用单答案 F1
if item.get("category") in [2, 3, 4]:
loc_val = loc_f1_score(str(pred), ref_str)
elif item.get("category") in [1]:
loc_val = loc_multi_f1(str(pred), ref_str)
else:
loc_val = loc_f1_score(str(pred), ref_str)
loc_f1s.append(loc_val)
cat_loc_f1s.setdefault(cat, []).append(loc_val)
# 保存完整的检索结果信息
samples.append({
"question": q,
"answer": ref_str,
"category": cat,
"prediction": pred,
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val,
"loc_f1": loc_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": adjusted_limit,
"max_chars": max_chars
},
"timing": {
"search_ms": (t1 - t0) * 1000,
"llm_ms": (t3 - t2) * 1000
}
})
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {ref_str}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")
# Compute per-category averages and dispersion (std, iqr)
def _percentile(sorted_vals: List[float], p: float) -> float:
if not sorted_vals:
return 0.0
if len(sorted_vals) == 1:
return sorted_vals[0]
k = (len(sorted_vals) - 1) * p
f = int(k)
c = f + 1 if f + 1 < len(sorted_vals) else f
if f == c:
return sorted_vals[f]
return sorted_vals[f] + (sorted_vals[c] - sorted_vals[f]) * (k - f)
by_category: Dict[str, Dict[str, float | int]] = {}
for c in cat_counts:
f_list = cat_f1s.get(c, [])
b_list = cat_b1s.get(c, [])
j_list = cat_jss.get(c, [])
lf_list = cat_loc_f1s.get(c, [])
j_sorted = sorted(j_list)
j_std = statistics.stdev(j_list) if len(j_list) > 1 else 0.0
j_q75 = _percentile(j_sorted, 0.75)
j_q25 = _percentile(j_sorted, 0.25)
by_category[c] = {
"count": cat_counts[c],
"f1": (sum(f_list) / max(len(f_list), 1)) if f_list else 0.0,
"b1": (sum(b_list) / max(len(b_list), 1)) if b_list else 0.0,
"j": (sum(j_list) / max(len(j_list), 1)) if j_list else 0.0,
"j_std": j_std,
"j_iqr": (j_q75 - j_q25) if j_list else 0.0,
# 参考 LoCoMo 评测的类别专用 F1
"loc_f1": (sum(lf_list) / max(len(lf_list), 1)) if lf_list else 0.0,
}
# 累加命中cum accuracy by category与 evaluation_stats.py 输出形式相仿
cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}
result = {
"dataset": "locomo",
"items": len(items),
"metrics": {
"f1": sum(f1s) / max(len(f1s), 1),
"b1": sum(b1s) / max(len(b1s), 1),
"j": sum(jss) / max(len(jss), 1),
# LoCoMo 类别专用 F1 的总体
"loc_f1": sum(loc_f1s) / max(len(loc_f1s), 1),
},
"by_category": by_category,
"category_counts": cat_counts,
"cum_accuracy_by_category": cum_accuracy_by_category,
"context": {
"avg_tokens": (sum(per_query_context_avg_tokens) / max(len(per_query_context_avg_tokens), 1)) if per_query_context_avg_tokens else 0.0,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": (sum(per_query_context_tokens_total) / max(len(per_query_context_tokens_total), 1)) if per_query_context_tokens_total else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"search_type": search_type,
"llm_id": llm_id,
"retrieval_embedding_id": embedding_id,
"chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"),
"skip_ingest_if_exists": skip_ingest_if_exists,
"llm_timeout": llm_timeout,
"llm_max_retries": llm_max_retries,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens
},
"timestamp": datetime.now().isoformat()
}
if output_path:
try:
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"❌ 保存结果失败: {e}")
return result
finally:
await connector.close()
def main():
parser = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
parser.add_argument("--sample_size", type=int, default=1, help="Number of samples to evaluate")
parser.add_argument("--end_user_id", type=str, default=None, help="Group ID for retrieval")
parser.add_argument("--search_limit", type=int, default=8, help="Search limit per query")
parser.add_argument("--context_char_budget", type=int, default=12000, help="Max characters for context")
parser.add_argument("--llm_temperature", type=float, default=0.0, help="LLM temperature")
parser.add_argument("--llm_max_tokens", type=int, default=32, help="LLM max tokens")
parser.add_argument("--search_type", type=str, default="embedding", choices=["keyword", "embedding", "hybrid"], help="Search type")
parser.add_argument("--output_path", type=str, default=None, help="Output path for results")
parser.add_argument("--skip_ingest_if_exists", action="store_true", help="Skip ingest if group exists")
parser.add_argument("--llm_timeout", type=float, default=10.0, help="LLM timeout in seconds")
parser.add_argument("--llm_max_retries", type=int, default=1, help="LLM max retries")
args = parser.parse_args()
load_dotenv()
result = asyncio.run(run_locomo_eval(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
output_path=args.output_path,
skip_ingest_if_exists=args.skip_ingest_if_exists,
llm_timeout=args.llm_timeout,
llm_max_retries=args.llm_max_retries
))
print("\n" + "="*50)
print("📊 最终评测结果:")
print(f" 样本数量: {result['items']}")
print(f" F1: {result['metrics']['f1']:.3f}")
print(f" BLEU-1: {result['metrics']['b1']:.3f}")
print(f" Jaccard: {result['metrics']['j']:.3f}")
print(f" LoCoMo F1: {result['metrics']['loc_f1']:.3f}")
print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
print(f" 平均检索延迟: {result['latency']['search']['mean']:.1f}ms")
print(f" 平均LLM延迟: {result['latency']['llm']['mean']:.1f}ms")
if result['by_category']:
print("\n📈 按类别细分:")
for cat, metrics in result['by_category'].items():
print(f" {cat}:")
print(f" 样本数: {metrics['count']}")
print(f" F1: {metrics['f1']:.3f}")
print(f" LoCoMo F1: {metrics['loc_f1']:.3f}")
print(f" Jaccard: {metrics['j']:.3f}{metrics['j_std']:.3f}, IQR={metrics['j_iqr']:.3f})")
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

View File

@@ -1,559 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any
import re
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。
参考 evaluation/memsciqa/evaluate_qa.py 的实现,避免路径导入带来的不稳定。
"""
if not contexts:
return ""
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
"""提取问题中的关键词(简单英文分词,去停用词,长度>=3"""
ql = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but','of','to','in','on','for','with','from','that','this'
}
words = re.findall(r"\b[\w-]+\b", ql)
kws = [w for w in words if w not in stop_words and len(w) >= 3]
# 去重保序
seen = set()
uniq = []
for w in kws:
if w not in seen:
uniq.append(w)
seen.add(w)
if len(uniq) >= max_keywords:
break
return uniq
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
"""对上下文进行简单相关性打分,仅用于控制台可视化。
评分: score = match_count*200 + min(len(text), 100000)/100
"""
results = []
for ctx in contexts:
tl = (ctx or "").lower()
match_count = sum(1 for k in keywords if k in tl)
length = len(ctx)
score = match_count * 200 + min(length, 100000) / 100.0
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
return results[:max(top_n, 0)]
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
if not os.path.exists(data_path):
raise FileNotFoundError(f"未找到数据集: {data_path}")
items: List[Dict[str, Any]] = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
items.append(json.loads(line))
except Exception:
# 跳过坏行但不中断
continue
return items
async def run_memsciqa_test(
sample_size: int = 3,
end_user_id: str | None = None,
search_limit: int = 8,
context_char_budget: int = 4000,
llm_temperature: float = 0.0,
llm_max_tokens: int = 64,
search_type: str = "embedding",
data_path: str | None = None,
start_index: int = 0,
verbose: bool = True,
) -> Dict[str, Any]:
"""memsciqa 增强测试脚本:结合 evaluate_qa 的三路检索与智能上下文选择。
- 支持从指定索引开始与评估全部样本sample_size<=0
- 支持在摄入前重置组(清空图)与跳过摄入
- 支持 keyword / embedding / hybrid 三种检索
"""
# 默认使用指定的 memsci 组 ID
end_user_id = end_user_id or "group_memsci"
# 数据路径解析
if not data_path:
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
data_path = str(dataset_dir / "msc_self_instruct.jsonl")
if not os.path.exists(data_path):
raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
)
# 加载数据
all_items = load_dataset_memsciqa(data_path)
if sample_size is None or sample_size <= 0:
items = all_items[start_index:]
else:
items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入
llm = get_llm_client(os.getenv("EVAL_LLM_ID"))
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector()
embedder = None
if search_type in ("embedding", "hybrid"):
cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 评估循环
latencies_llm: List[float] = []
latencies_search: List[float] = []
# 存储完整上下文文本用于统计
contexts_used: List[str] = []
per_query_context_chars: List[int] = []
per_query_context_counts: List[int] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
samples: List[Dict[str, Any]] = []
total_items = len(items)
for idx, item in enumerate(items):
if verbose:
print(f"\n🧪 评估样本: {idx+1}/{total_items}")
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 检索:使用与 evaluate_qa.py 相同的 run_hybrid_search
t0 = time.time()
results = None
try:
if search_type in ("embedding", "hybrid"):
# 使用嵌入检索(与 qwen_search_eval 对齐)
results = await search_graph_by_embedding(
connector=connector,
embedder_client=embedder,
query_text=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
elif search_type == "keyword":
# 关键词检索(直接调用 graph_search
results = await search_graph(
connector=connector,
q=question,
end_user_id=end_user_id,
limit=search_limit,
include=["chunks", "statements", "entities", "summaries"], # 使用 chunks 而不是 dialogues
)
except Exception:
results = None
t1 = time.time()
search_ms = (t1 - t0) * 1000
latencies_search.append(search_ms)
# 构建上下文:与 evaluate_qa.py 完全一致的逻辑
contexts_all: List[str] = []
retrieved_counts: Dict[str, int] = {}
if results:
# 处理 hybrid 搜索结果
if search_type == "hybrid":
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重
seen_dialog = set()
dialogues = []
for d in all_dialogs:
key = (str(d.get("uuid", "")), str(d.get("content", "")))
if key not in seen_dialog:
dialogues.append(d)
seen_dialog.add(key)
seen_stmt = set()
statements = []
for s in all_statements:
key = str(s.get("statement", ""))
if key not in seen_stmt:
statements.append(s)
seen_stmt.add(key)
seen_ent = set()
entities = []
for e in all_entities:
key = str(e.get("name", ""))
if key not in seen_ent:
entities.append(e)
seen_ent.add(key)
else:
# embedding 或 keyword 单独搜索
dialogues = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
retrieved_counts = {
"dialogues": len(dialogues),
"statements": len(statements),
"entities": len(entities),
}
# 构建上下文文本
for d in dialogues:
text = str(d.get("content", "")).strip()
if text:
contexts_all.append(text)
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
# 实体摘要
if entities:
scored = [e for e in entities if e.get("score") is not None]
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
if top_entities:
summary_lines = []
for e in top_entities:
name = str(e.get("name", "")).strip()
etype = str(e.get("entity_type", "")).strip()
score = e.get("score")
if name:
meta = []
if etype:
meta.append(f"type={etype}")
if isinstance(score, (int, float)):
meta.append(f"score={score:.3f}")
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
if verbose:
if retrieved_counts:
print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
print(f"📊 有效上下文数量: {len(contexts_all)}")
q_keywords = extract_question_keywords(question, max_keywords=8)
if q_keywords:
print(f"🔍 问题关键词: {set(q_keywords)}")
if contexts_all:
analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
if analysis:
print("📊 上下文相关性分析:")
for a in analysis:
print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")
# 打印检索到的上下文预览,便于定位为何为 Unknown
print("🔎 上下文预览最多前10条每条截断展示:")
for i, ctx in enumerate(contexts_all[:10]):
preview = str(ctx).replace("\n", " ")
if len(preview) > 300:
preview = preview[:300] + "..."
print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")
# 标注参考答案是否出现在任一上下文中
ref_lower = (str(reference) or "").lower()
if ref_lower:
hits = []
for i, ctx in enumerate(contexts_all):
if ref_lower in str(ctx).lower():
hits.append(i+1)
print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text)
per_query_context_chars.append(len(context_text))
per_query_context_counts.append(len(contexts_all))
if verbose:
selected_count = (context_text.count("\n\n") + 1) if context_text else 0
print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
# 展示拼接后的上下文片段,便于核查是否包含答案
concat_preview = context_text.replace("\n", " ")
if len(concat_preview) > 600:
concat_preview = concat_preview[:600] + "..."
print(f"🧵 拼接上下文预览: {concat_preview}")
messages = [
{
"role": "system",
"content": (
"You are a QA assistant. Answer in English. Follow these guidelines:\n"
"1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
"2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
"3) Keep your answer brief and to the point;\n"
"4) Do not add explanations or additional text beyond the answer."
),
},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
try:
# 使用异步调用
resp = await llm.chat(messages=messages)
# 更健壮的响应解析处理不同的LLM响应格式
if hasattr(resp, 'content'):
pred = resp.content.strip()
elif isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
pred = resp["choices"][0]["message"]["content"].strip()
elif isinstance(resp, dict) and "content" in resp:
pred = resp["content"].strip()
elif isinstance(resp, str):
pred = resp.strip()
else:
pred = "Unknown"
print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
# 检查预测是否为"Unknown"或空,如果是则检查上下文是否真的没有答案
if pred.lower() in ["unknown", ""]:
# 如果参考答案在上下文中存在但LLM返回Unknown可能是提示词问题
ref_lower = (str(reference) or "").lower()
if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
print("⚠️ 参考答案在上下文中存在但LLM返回Unknown检查提示词")
except Exception as e:
# 更详细的错误处理
pred = "Unknown"
print(f"⚠️ LLM调用异常: {e}")
t3 = time.time()
llm_ms = (t3 - t2) * 1000
latencies_llm.append(llm_ms)
exact = exact_match(pred, reference)
correct_flags.append(exact)
f1_val = f1_score(str(pred), str(reference))
b1_val = bleu1(str(pred), str(reference))
j_val = jaccard(str(pred), str(reference))
f1s.append(f1_val)
b1s.append(b1_val)
jss.append(j_val)
if verbose:
print(f"🤖 LLM 回答: {pred}")
print(f"✅ 正确答案: {reference}")
print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")
# 对齐 locomo/qwen_search_eval.py 的样本输出结构
samples.append({
"question": str(question),
"answer": str(reference),
"prediction": str(pred),
"metrics": {
"f1": f1_val,
"b1": b1_val,
"j": j_val
},
"retrieval": {
"retrieved_documents": len(contexts_all),
"context_length": len(context_text),
"search_limit": search_limit,
"max_chars": context_char_budget
},
"timing": {
"search_ms": search_ms,
"llm_ms": llm_ms
}
})
# 计算总体指标与聚合
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"context": {
"avg_tokens": ctx_avg_tokens,
"avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
"count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
"avg_memory_tokens": 0.0
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"samples": samples,
"params": {
"end_user_id": end_user_id,
"search_limit": search_limit,
"context_char_budget": context_char_budget,
"llm_temperature": llm_temperature,
"llm_max_tokens": llm_max_tokens,
"search_type": search_type,
"start_index": start_index,
"llm_id": os.getenv("EVAL_LLM_ID"),
"retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID")
},
"timestamp": datetime.now().isoformat(),
}
try:
await connector.close()
except Exception:
pass
return result
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
parser.add_argument("--sample-size", type=int, default=10, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci")
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型hybrid 等同于 embedding")
parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 data/msc_self_instruct.jsonl")
parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径JSON")
parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
args = parser.parse_args()
sample_size = 0 if args.all else args.sample_size
verbose_flag = False if args.quiet else args.verbose
result = asyncio.run(
run_memsciqa_test(
sample_size=sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
data_path=args.data_path,
start_index=args.start_index,
verbose=verbose_flag,
)
)
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果保存
out_path = args.output
if not out_path:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, "results")
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
try:
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 结果已保存: {out_path}")
except Exception as e:
print(f"⚠️ 结果保存失败: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,369 +0,0 @@
import argparse
import asyncio
import json
import os
import time
from datetime import datetime
from typing import List, Dict, Any
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
"""基于问题关键词对上下文进行评分选择,并在预算内拼接文本。"""
if not contexts:
return ""
import re
# 提取问题关键词(移除停用词)
question_lower = (question or "").lower()
stop_words = {
'what','when','where','who','why','how','did','do','does','is','are','was','were',
'the','a','an','and','or','but'
}
question_words = set(re.findall(r"\b\w+\b", question_lower))
question_words = {w for w in question_words if w not in stop_words and len(w) > 2}
# 评分
scored = []
for i, ctx in enumerate(contexts):
ctx_lower = (ctx or "").lower()
score = 0
matches = 0
for w in question_words:
if w in ctx_lower:
matches += 1
score += ctx_lower.count(w) * 2
length = len(ctx)
if 100 < length < 2000:
score += 5
elif length >= 2000:
score += 2
if i < 3:
score += 3
scored.append((score, ctx, matches))
scored.sort(key=lambda x: x[0], reverse=True)
# 选择直到达到字符限制,必要时截断包含关键词的段落
selected: List[str] = []
total = 0
for score, ctx, _ in scored:
if total + len(ctx) <= max_chars:
selected.append(ctx)
total += len(ctx)
else:
if score > 10 and total < max_chars - 200:
remaining = max_chars - total
lines = ctx.split('\n')
rel_lines: List[str] = []
cur = 0
for line in lines:
l = line.lower()
if any(w in l for w in question_words) and cur < remaining - 50:
rel_lines.append(line)
cur += len(line)
if rel_lines:
truncated = '\n'.join(rel_lines)
if len(truncated) > 50:
selected.append(truncated + "\n[相关内容截断...]")
total += len(truncated)
break
return "\n\n".join(selected)
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
"""Compose a text context from `dialog` list in msc_self_instruct item."""
parts: List[str] = []
for turn in dialog_obj.get("dialog", []):
speaker = turn.get("speaker", "")
text = turn.get("text", "")
if text:
parts.append(f"{speaker}: {text}")
return "\n".join(parts)
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Combine dialogues from embedding and keyword searches (embedding first)."""
if results is None:
return []
emb = []
kw = []
if isinstance(results.get("embedding_search"), dict):
emb = results.get("embedding_search", {}).get("dialogues", []) or []
elif isinstance(results.get("dialogues"), list):
emb = results.get("dialogues", []) or []
if isinstance(results.get("keyword_search"), dict):
kw = results.get("keyword_search", {}).get("dialogues", []) or []
seen = set()
merged: List[Dict[str, Any]] = []
for d in emb:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
for d in kw:
k = (str(d.get("uuid", "")), str(d.get("content", "")))
if k not in seen:
merged.append(d)
seen.add(k)
return merged
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
end_user_id = end_user_id or SELECTED_GROUP_ID
# Load data
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
data_path = dataset_dir / "msc_self_instruct.jsonl"
if not os.path.exists(data_path):
raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
)
with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines()
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用)
from app.db import get_db
db = next(get_db())
try:
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
finally:
db.close()
# Evaluate each item
connector = Neo4jConnector()
latencies_llm: List[float] = []
latencies_search: List[float] = []
contexts_used: List[str] = []
correct_flags: List[float] = []
f1s: List[float] = []
b1s: List[float] = []
jss: List[float] = []
try:
for item in items:
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 检索:对齐 locomo 的三路检索dialogues/statements/entities
t0 = time.time()
try:
results = await run_hybrid_search(
query_text=question,
search_type=search_type,
end_user_id=end_user_id,
limit=search_limit,
include=["dialogues", "statements", "entities"],
output_path=None,
)
except Exception:
results = None
t1 = time.time()
latencies_search.append((t1 - t0) * 1000)
# 构建上下文:包含对话、陈述和实体摘要,并智能选择
contexts_all: List[str] = []
if results:
if search_type == "hybrid":
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重与限制
seen_texts = set()
for d in all_dialogs:
text = str(d.get("content", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
for s in all_statements:
text = str(s.get("statement", "")).strip()
if text and text not in seen_texts:
contexts_all.append(text)
seen_texts.add(text)
if len(contexts_all) >= search_limit:
break
# 实体摘要最多3个
names = []
merged_entities = all_entities[:]
for e in merged_entities:
name = str(e.get("name", "")).strip()
if name and name not in names:
names.append(name)
if len(names) >= 3:
break
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
else:
dialogs = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
for d in dialogs:
text = str(d.get("content", "")).strip()
if text:
contexts_all.append(text)
for s in statements:
text = str(s.get("statement", "")).strip()
if text:
contexts_all.append(text)
names = [str(e.get("name", "")).strip() for e in entities[:3] if e.get("name")]
if names:
contexts_all.append("EntitySummary: " + ", ".join(names))
# 智能选择并截断到预算
context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
if not context_text:
context_text = "No relevant context found."
contexts_used.append(context_text[:200])
# Call LLM (使用异步调用)
messages = [
{"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
{"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
]
t2 = time.time()
resp = await llm_client.chat(messages=messages)
t3 = time.time()
latencies_llm.append((t3 - t2) * 1000)
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference)))
# Aggregate metrics
acc = sum(correct_flags) / max(len(correct_flags), 1)
ctx_avg_tokens = avg_context_tokens(contexts_used)
result = {
"dataset": "memsciqa",
"items": len(items),
"metrics": {
"accuracy": acc,
# Placeholders for extensibility
"f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
"bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
"jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
},
"latency": {
"search": latency_stats(latencies_search),
"llm": latency_stats(latencies_llm),
},
"avg_context_tokens": ctx_avg_tokens,
}
return result
finally:
await connector.close()
def main():
# Load environment variables first
load_dotenv()
# Get defaults from environment variables
env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE")
env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT")
env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET")
env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS")
env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes")
env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR")
# Convert to appropriate types with fallback to code defaults
default_sample_size = int(env_sample_size) if env_sample_size else 1
default_search_limit = int(env_search_limit) if env_search_limit else 8
default_context_budget = int(env_context_budget) if env_context_budget else 4000
default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64
default_output_dir = env_output_dir if env_output_dir else None
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id默认使用环境变量")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens,
help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})")
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest,
help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})")
parser.add_argument("--output-dir", type=str, default=default_output_dir,
help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})")
args = parser.parse_args()
result = asyncio.run(
run_memsciqa_eval(
sample_size=args.sample_size,
end_user_id=args.end_user_id,
search_limit=args.search_limit,
context_char_budget=args.context_char_budget,
llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type,
skip_ingest=args.skip_ingest,
)
)
# Print results to console
print(json.dumps(result, ensure_ascii=False, indent=2))
# Save results to file
output_dir = args.output_dir
if output_dir is None:
# Use absolute path to ensure results are saved in the correct location
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / "results"
elif not Path(output_dir).is_absolute():
# If relative path, make it relative to this script's directory
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / output_dir
else:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_dir / f"memsciqa_{timestamp_str}.json"
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"\n❌ 保存结果失败: {e}")
if __name__ == "__main__":
main()

View File

@@ -1,147 +0,0 @@
import argparse
import asyncio
import json
import os
from typing import Any, Dict
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
async def run(
dataset: str,
sample_size: int,
reset_group: bool,
end_user_id: str | None,
judge_model: str | None = None,
search_limit: int | None = None,
context_char_budget: int | None = None,
llm_temperature: float | None = None,
llm_max_tokens: int | None = None,
search_type: str | None = None,
start_index: int | None = None,
max_contexts_per_item: int | None = None,
) -> Dict[str, Any]:
# Use environment variable with fallback chain if not provided
if end_user_id is None:
end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default")
if reset_group:
connector = Neo4jConnector()
try:
await connector.delete_group(end_user_id)
finally:
await connector.close()
if dataset == "locomo":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_locomo_eval(**kwargs)
if dataset == "memsciqa":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
return await run_memsciqa_eval(**kwargs)
if dataset == "longmemeval":
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
if search_limit is not None:
kwargs["search_limit"] = search_limit
if context_char_budget is not None:
kwargs["context_char_budget"] = context_char_budget
if llm_temperature is not None:
kwargs["llm_temperature"] = llm_temperature
if llm_max_tokens is not None:
kwargs["llm_max_tokens"] = llm_max_tokens
if search_type is not None:
kwargs["search_type"] = search_type
if start_index is not None:
kwargs["start_index"] = start_index
if max_contexts_per_item is not None:
kwargs["max_contexts_per_item"] = max_contexts_per_item
return await run_longmemeval_test(**kwargs)
raise ValueError(f"未知数据集: {dataset}")
def main():
load_dotenv()
parser = argparse.ArgumentParser(description="统一评估入口memsciqa / longmemeval / locomo")
parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json")
parser.add_argument("--judge-model", type=str, default=None, help="可选longmemeval 判别式评测模型名")
parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens不提供则使用各脚本默认")
parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
# 仅透传到 longmemeval其他数据集忽略
parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval起始样本索引不提供则用脚本默认")
parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval每条样本摄入的上下文数量上限不提供则用脚本默认")
parser.add_argument("--output", type=str, default=None, help="可选将评估结果保存到指定文件路径JSON不提供时默认保存到 evaluation/<dataset>/results 目录")
args = parser.parse_args()
result = asyncio.run(run(
args.dataset,
args.sample_size,
args.reset_group,
args.end_user_id,
args.judge_model,
args.search_limit,
args.context_char_budget,
args.llm_temperature,
args.llm_max_tokens,
args.search_type,
args.start_index,
args.max_contexts_per_item,
))
print(json.dumps(result, ensure_ascii=False, indent=2))
# 结果输出逻辑保持不变
if args.output:
out_path = args.output
else:
eval_dir = os.path.dirname(os.path.abspath(__file__))
dataset_results_dir = os.path.join(eval_dir, args.dataset, "results")
out_filename = f"{args.dataset}_{args.sample_size}.json"
out_path = os.path.join(dataset_results_dir, out_filename)
out_dir = os.path.dirname(out_path)
if out_dir and not os.path.exists(out_dir):
os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n结果已保存到: {out_path}")
if __name__ == "__main__":
main()