delete benchmark-test (#204)
* Refactor: Move evaluation folder to redbear-mem-benchmark submodule * [changes]Restore .gitmodules
This commit is contained in:
@@ -1,224 +0,0 @@
|
||||
# ============================================================================
|
||||
# 基准测试统一配置文件示例
|
||||
# ============================================================================
|
||||
# 复制此文件为 .env.evaluation 并根据需要修改
|
||||
# 支持的基准测试:LoCoMo、LongMemEval、MemSciQA
|
||||
# ============================================================================
|
||||
|
||||
# ============================================================================
|
||||
# 通用配置(所有基准测试共用)
|
||||
# ============================================================================
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Neo4j 配置
|
||||
# ----------------------------------------------------------------------------
|
||||
# 默认 Group ID(建议各基准测试使用独立的 group)
|
||||
EVAL_GROUP_ID=benchmark_default
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 模型配置(必需)
|
||||
# ----------------------------------------------------------------------------
|
||||
# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID
|
||||
#
|
||||
# 如何获取模型 ID:
|
||||
# 1. 查询数据库:SELECT id, model_name FROM models WHERE is_active = true;
|
||||
# 2. 或通过系统管理界面查看
|
||||
# 3. 确保模型可用且配置正确
|
||||
|
||||
# LLM 模型 ID(必填)
|
||||
EVAL_LLM_ID=your_llm_model_id_here
|
||||
|
||||
# Embedding 模型 ID(必填)
|
||||
EVAL_EMBEDDING_ID=your_embedding_model_id_here
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 检索参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# 检索类型: "keyword", "embedding", "hybrid"
|
||||
EVAL_SEARCH_TYPE=hybrid
|
||||
|
||||
# 检索结果数量限制(默认值)
|
||||
EVAL_SEARCH_LIMIT=12
|
||||
|
||||
# 上下文最大字符数(默认值)
|
||||
EVAL_MAX_CONTEXT_CHARS=8000
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# LLM 参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# LLM 温度参数(0.0 = 确定性输出)
|
||||
EVAL_LLM_TEMPERATURE=0.0
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
EVAL_LLM_MAX_TOKENS=32
|
||||
|
||||
# LLM 超时时间(秒)
|
||||
EVAL_LLM_TIMEOUT=10.0
|
||||
|
||||
# LLM 最大重试次数
|
||||
EVAL_LLM_MAX_RETRIES=1
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 数据处理参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# Chunker 策略
|
||||
EVAL_CHUNKER_STRATEGY=RecursiveChunker
|
||||
|
||||
# 是否在导入前清空现有数据
|
||||
EVAL_RESET_ON_INGEST=true
|
||||
|
||||
# 是否保存详细日志
|
||||
EVAL_SAVE_DETAILED_LOGS=true
|
||||
|
||||
# ============================================================================
|
||||
# LoCoMo 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:locomo10.json
|
||||
# 运行:python locomo_benchmark.py --sample_size 20
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(LoCoMo 专用)
|
||||
LOCOMO_GROUP_ID=locomo_benchmark
|
||||
|
||||
# 测试样本数量
|
||||
# 建议值:20(快速测试)、100(中等测试)、1986(完整测试)
|
||||
LOCOMO_SAMPLE_SIZE=20
|
||||
|
||||
# 检索结果数量限制
|
||||
LOCOMO_SEARCH_LIMIT=12
|
||||
|
||||
# 上下文最大字符数
|
||||
LOCOMO_CONTEXT_CHAR_BUDGET=8000
|
||||
|
||||
# 导入的对话数量
|
||||
LOCOMO_MAX_DIALOGUES=1
|
||||
|
||||
# 跳过数据摄入(true=跳过,false=摄入)
|
||||
# 首次运行设置为 false,后续运行可设置为 true 以节省时间
|
||||
LOCOMO_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录
|
||||
LOCOMO_OUTPUT_DIR=locomo/results
|
||||
|
||||
# ============================================================================
|
||||
# LongMemEval 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:longmemeval_oracle_zh.json
|
||||
# 运行:python longmemeval_benchmark.py --sample_size 3
|
||||
# 特点:支持时间推理问题的增强检索
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(LongMemEval 专用)
|
||||
LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3
|
||||
|
||||
# 测试样本数量(<=0 表示全部样本)
|
||||
LONGMEMEVAL_SAMPLE_SIZE=3
|
||||
|
||||
# 起始样本索引
|
||||
LONGMEMEVAL_START_INDEX=0
|
||||
|
||||
# 检索结果数量限制
|
||||
LONGMEMEVAL_SEARCH_LIMIT=8
|
||||
|
||||
# 上下文最大字符数
|
||||
LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
LONGMEMEVAL_LLM_MAX_TOKENS=16
|
||||
|
||||
# 每条样本最多摄入的上下文段数
|
||||
LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2
|
||||
|
||||
# 是否保存分块结果
|
||||
LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true
|
||||
|
||||
# 自定义分块输出路径(留空使用默认)
|
||||
LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH=
|
||||
|
||||
# 摄入前是否清空组数据
|
||||
LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false
|
||||
|
||||
# 是否跳过摄入,仅检索评估
|
||||
LONGMEMEVAL_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录
|
||||
LONGMEMEVAL_OUTPUT_DIR=longmemeval/results
|
||||
|
||||
# ============================================================================
|
||||
# MemSciQA 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:msc_self_instruct.jsonl
|
||||
# 运行:python memsciqa_benchmark.py --sample_size 1
|
||||
# 特点:对话记忆检索评估
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(MemSciQA 专用,独立数据集)
|
||||
MEMSCIQA_GROUP_ID=memsciqa_benchmark
|
||||
|
||||
# 测试样本数量
|
||||
# 0 或 -1 表示使用测试数据集中的全部样本
MEMSCIQA_SAMPLE_SIZE=1
|
||||
|
||||
# 检索结果数量限制
|
||||
MEMSCIQA_SEARCH_LIMIT=8
|
||||
|
||||
# 上下文最大字符数
|
||||
MEMSCIQA_CONTEXT_CHAR_BUDGET=4000
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
MEMSCIQA_LLM_MAX_TOKENS=64
|
||||
|
||||
# 跳过数据摄入(true=跳过,false=摄入)
|
||||
# 首次运行设置为 false,后续运行可设置为 true 以节省时间
|
||||
MEMSCIQA_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录(相对于 memsciqa 脚本所在目录)
|
||||
# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/
|
||||
MEMSCIQA_OUTPUT_DIR=results
|
||||
|
||||
# ============================================================================
|
||||
# 高级配置(可选)
|
||||
# ============================================================================
|
||||
|
||||
# BM25 权重(用于混合检索,0.0-1.0)
|
||||
EVAL_RERANK_ALPHA=0.6
|
||||
|
||||
# 是否使用遗忘重排序
|
||||
EVAL_USE_FORGETTING_RERANK=false
|
||||
|
||||
# 是否使用 LLM 重排序
|
||||
EVAL_USE_LLM_RERANK=false
|
||||
|
||||
# 连接重置间隔(每 N 个问题重置一次)
|
||||
EVAL_RESET_INTERVAL=5
|
||||
|
||||
# 性能阈值(低于此值触发重置)
|
||||
EVAL_PERFORMANCE_THRESHOLD=0.6
|
||||
|
||||
# ============================================================================
|
||||
# 快速配置指南
|
||||
# ============================================================================
|
||||
# 1. 复制此文件为 .env.evaluation
|
||||
# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID
|
||||
# 3. 根据需要修改各基准测试的专用配置
|
||||
# 4. 运行测试:
|
||||
# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20
|
||||
# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all
|
||||
# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10
|
||||
# 配置优先级:
|
||||
# 命令行参数 > 特定配置(如 LOCOMO_*)> 通用配置(EVAL_*)> 代码默认值
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# 执行LoCoMo测试
|
||||
# 只摄入前5条消息,评估3个问题(最小测试)
|
||||
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
|
||||
#
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试
|
||||
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
|
||||
|
||||
|
||||
# 执行longmemeval测试
|
||||
# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest
|
||||
|
||||
# 执行memsciqa测试
|
||||
# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1
|
||||
13
api/app/core/memory/evaluation/.gitignore
vendored
13
api/app/core/memory/evaluation/.gitignore
vendored
@@ -1,13 +0,0 @@
|
||||
# 忽略实际的评估配置文件(包含敏感信息)
|
||||
.env.evaluation
|
||||
|
||||
# 保留示例文件
|
||||
!.env.evaluation.example
|
||||
|
||||
# 忽略测试结果文件
|
||||
*/results/*.json
|
||||
*/results/*.log
|
||||
|
||||
# 忽略数据集文件(文件过大,不应提交到 Git)
|
||||
dataset/*.json
|
||||
dataset/*.jsonl
|
||||
@@ -1 +0,0 @@
|
||||
"""Evaluation package with dataset-specific pipelines and a unified runner."""
|
||||
@@ -1,748 +0,0 @@
|
||||
# 1.数据集下载地址
|
||||
locomo10.json : https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
|
||||
数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下
|
||||
# 2.配置说明
|
||||
文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明
|
||||
**实际配置文件**:api\app\core\memory\evaluation\.env.evaluation
|
||||
```bash
|
||||
# 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark
|
||||
```
|
||||
**检查 Neo4j 中指定的 group_id 是否已摄入数据**
|
||||
```bash
|
||||
# 1. 进入交互模式
|
||||
python -m app.core.memory.evaluation.check_enduser_data
|
||||
|
||||
# 2. 选择 "1" 检查指定 group
|
||||
# 3. 输入 group_id,例如: locomo_benchmark
|
||||
# 4. 选择是否显示详细统计 (y/n)
|
||||
```
|
||||
# 3.locomo
|
||||
|
||||
### (1)locomo执行命令
|
||||
```bash
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 只摄入前5条消息,评估3个问题(最小测试)
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
|
||||
```
|
||||
### (2)locomo结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "locomo",
|
||||
"sample_size": 0,
|
||||
"timestamp": "2026-01-26T11:24:28.239156",
|
||||
"params": {
|
||||
"group_id": "locomo_benchmark",
|
||||
"search_type": "hybrid",
|
||||
"search_limit": 12,
|
||||
"context_char_budget": 8000,
|
||||
"llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4",
|
||||
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2"
|
||||
},
|
||||
"overall_metrics": {
|
||||
"f1": 0.0,
|
||||
"bleu1": 0.0,
|
||||
"jaccard": 0.0,
|
||||
"locomo_f1": 0.0
|
||||
},
|
||||
"by_category": {},
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"context_stats": {
|
||||
"avg_retrieved_docs": 0.0,
|
||||
"avg_context_chars": 0.0,
|
||||
"avg_context_tokens": 0.0
|
||||
},
|
||||
"samples": []
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标 (overall_metrics)
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`f1`** (F1 Score): 精确率和召回率的调和平均值
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量检索和生成答案的准确性
|
||||
- 这是最重要的综合性能指标
|
||||
- 优秀标准:> 0.85
|
||||
|
||||
- **`bleu1`** (BLEU-1): 单词级别的匹配度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量生成答案与标准答案的词汇重叠度
|
||||
- 关注词汇层面的准确性
|
||||
|
||||
- **`jaccard`** (Jaccard 相似度): 集合相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量答案集合的相似性
|
||||
- 计算公式:交集大小 / 并集大小
|
||||
|
||||
- **`locomo_f1`**: Locomo 特定的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,针对 Locomo 数据集优化的评估指标
|
||||
- 考虑了长对话记忆的特殊性
|
||||
|
||||
##### 2. 性能指标 (latency)
|
||||
|
||||
**⚡ 关键效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均延迟
|
||||
- `p50`: 中位数延迟(50%的请求在此时间内完成)
|
||||
- `p95`: 95分位数延迟(95%的请求在此时间内完成)
|
||||
- `iqr`: 四分位距(Q3-Q1,衡量稳定性)
|
||||
- **越低越好**,衡量记忆检索速度
|
||||
- 优秀标准:p95 < 2000ms
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距(越小越稳定)
|
||||
- **越低越好**,衡量答案生成速度
|
||||
- 优秀标准:p95 < 3000ms
|
||||
|
||||
##### 3. 上下文统计 (context_stats)
|
||||
|
||||
**📊 资源效率指标:**
|
||||
|
||||
- **`avg_retrieved_docs`**: 平均检索文档数
|
||||
- 反映检索策略的广度
|
||||
- 需要平衡:太少可能信息不足,太多增加噪音和延迟
|
||||
- 建议范围:8-15 个文档
|
||||
|
||||
- **`avg_context_chars`**: 平均上下文字符数
|
||||
- 反映检索内容的总量
|
||||
- 应在满足准确性前提下尽量精简
|
||||
- 受 `context_char_budget` 参数限制
|
||||
|
||||
- **`avg_context_tokens`**: 平均上下文 token 数
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响 API 调用成本和推理速度
|
||||
- 成本效益比 = f1 / avg_context_tokens
|
||||
|
||||
##### 4. 分类统计 (by_category)
|
||||
|
||||
- 按问题类型分类的性能指标
|
||||
- 帮助识别系统在不同场景下的强弱项
|
||||
- 可针对性优化特定类型的问题
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `f1` 和 `locomo_f1` 提升 → 核心能力提升
|
||||
- 目标:f1 > 0.85
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency.p95` 降低 → 用户体验提升
|
||||
- 目标:search.p95 < 2000ms, llm.p95 < 3000ms
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化
|
||||
- `iqr` 降低 → 性能稳定性提升
|
||||
# 4.longmemeval
|
||||
支持时间推理问题的增强检索
|
||||
### (1)执行命令
|
||||
```bash
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 不带参数运行 - 使用环境变量
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark
|
||||
|
||||
# 命令行参数覆盖环境变量
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest
|
||||
```
|
||||
### (2)结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "longmemeval",
|
||||
"items": 1,
|
||||
"accuracy_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"f1_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"jaccard_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"samples": [
|
||||
{
|
||||
"question": "What degree did I graduate with?",
|
||||
"prediction": "Business Administration",
|
||||
"answer": "Business Administration",
|
||||
"question_type": "single-session-user",
|
||||
"is_temporal": false,
|
||||
"question_id": "e47becba",
|
||||
"options": [],
|
||||
"context_count": 13,
|
||||
"context_chars": 1268,
|
||||
"retrieved_dialogue_count": 0,
|
||||
"retrieved_statement_count": 12,
|
||||
"metrics": {
|
||||
"exact_match": true,
|
||||
"f1": 1.0,
|
||||
"jaccard": 1.0
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": 1483.100175857544,
|
||||
"llm_ms": 995.8682060241699
|
||||
}
|
||||
}
|
||||
],
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 1483.100175857544,
|
||||
"p50": 1483.100175857544,
|
||||
"p95": 1483.100175857544,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 995.8682060241699,
|
||||
"p50": 995.8682060241699,
|
||||
"p95": 995.8682060241699,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"context": {
|
||||
"avg_tokens": 204.0,
|
||||
"avg_chars": 1268,
|
||||
"count_avg": 13
|
||||
},
|
||||
"params": {
|
||||
"group_id": "longmemeval_zh_bak_3",
|
||||
"search_limit": 8,
|
||||
"context_char_budget": 4000,
|
||||
"search_type": "hybrid",
|
||||
"llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f",
|
||||
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2",
|
||||
"sample_size": 1,
|
||||
"start_index": 0
|
||||
},
|
||||
"timestamp": "2026-01-24T21:36:10.818308",
|
||||
"metric_summary": {
|
||||
"score_accuracy": 100.0,
|
||||
"latency_median_s": 2.478968381881714,
|
||||
"latency_iqr_s": 0.0,
|
||||
"avg_context_tokens_k": 0.204
|
||||
},
|
||||
"diagnostics": {
|
||||
"duplicate_previews_top": [],
|
||||
"unique_preview_count": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`accuracy_by_type`**: 按问题类型分类的准确率
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,1.0 表示 100% 准确
|
||||
- 问题类型包括:
|
||||
- `single-session-user`: 单会话用户信息
|
||||
- `single-session-event`: 单会话事件信息
|
||||
- `multi-session-user`: 多会话用户信息
|
||||
- `multi-session-event`: 多会话事件信息
|
||||
- 可以识别系统在不同场景下的强弱项
|
||||
|
||||
- **`f1_by_type`**: 按问题类型的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,综合评估精确率和召回率
|
||||
- 比单纯的准确率更全面
|
||||
|
||||
- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量答案集合匹配度
|
||||
- 对于集合类答案特别有用
|
||||
|
||||
##### 2. 样本级指标 (samples)
|
||||
|
||||
**详细诊断指标:**
|
||||
|
||||
- **`metrics.exact_match`**: 精确匹配(布尔值)
|
||||
- **true 越多越好**,最严格的评估标准
|
||||
- 要求预测答案与标准答案完全一致
|
||||
|
||||
- **`metrics.f1`**: 单个样本的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量单个问题的回答质量
|
||||
|
||||
- **`is_temporal`**: 是否为时间推理问题
|
||||
- 布尔值,标识问题是否涉及时间推理
|
||||
- 时间推理问题通常更具挑战性
|
||||
|
||||
- **`context_count`**: 检索到的上下文数量
|
||||
- 反映检索策略的有效性
|
||||
- 建议范围:8-15 个上下文片段
|
||||
|
||||
- **`retrieved_dialogue_count`**: 检索到的对话数
|
||||
- **`retrieved_statement_count`**: 检索到的陈述数
|
||||
- 这两个指标帮助理解检索的内容类型分布
|
||||
- 可用于优化检索策略
|
||||
|
||||
- **`timing.search_ms`**: 单个问题的检索延迟(毫秒)
|
||||
- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒)
|
||||
- **越低越好**,反映单次查询的响应速度
|
||||
|
||||
##### 3. 汇总指标 (metric_summary)
|
||||
|
||||
**📊 关键 KPI:**
|
||||
|
||||
- **`score_accuracy`**: 总体准确率百分比
|
||||
- 范围:0.0 - 100.0
|
||||
- **越高越好**,最直观的性能指标
|
||||
- 优秀标准:> 90.0
|
||||
|
||||
- **`latency_median_s`**: 中位延迟(秒)
|
||||
- **越低越好**,反映真实响应速度
|
||||
- 优秀标准:< 3.0 秒
|
||||
|
||||
- **`latency_iqr_s`**: 延迟四分位距(秒)
|
||||
- **越低越好**,反映性能稳定性
|
||||
- 越小说明响应时间越稳定
|
||||
|
||||
- **`avg_context_tokens_k`**: 平均上下文 token 数(千)
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响 API 调用成本
|
||||
- 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000)
|
||||
|
||||
##### 4. 上下文统计 (context)
|
||||
|
||||
- **`avg_tokens`**: 平均 token 数
|
||||
- **`avg_chars`**: 平均字符数
|
||||
- **`count_avg`**: 平均上下文片段数
|
||||
- 这些指标反映检索内容的规模
|
||||
- 需要在准确性和效率之间平衡
|
||||
|
||||
##### 5. 性能指标 (latency)
|
||||
|
||||
**⚡ 效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均延迟
|
||||
- `p50`: 中位数延迟
|
||||
- `p95`: 95分位数延迟
|
||||
- `iqr`: 四分位距
|
||||
- **越低越好**,衡量记忆检索速度
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距
|
||||
- **越低越好**,衡量答案生成速度
|
||||
|
||||
##### 6. 诊断信息 (diagnostics)
|
||||
|
||||
- **`duplicate_previews_top`**: 重复预览统计
|
||||
- 列出出现频率最高的重复内容
|
||||
- 帮助发现检索冗余问题
|
||||
- 应该尽量减少重复
|
||||
|
||||
- **`unique_preview_count`**: 唯一预览数量
|
||||
- 反映检索多样性
|
||||
- **越高越好**,说明检索到的内容更丰富
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `score_accuracy` 提升 → 核心能力提升
|
||||
- 目标:> 90.0%
|
||||
- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency_median_s` 降低 → 用户体验提升
|
||||
- 目标:< 3.0 秒
|
||||
- `exact_match` 比例提升 → 精确度提升
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化
|
||||
- `unique_preview_count` 提升 → 检索多样性提升
|
||||
- `latency_iqr_s` 降低 → 性能稳定性提升
|
||||
|
||||
**特殊关注:**
|
||||
- 时间推理问题(`is_temporal: true`)的准确率
|
||||
- 多会话问题的准确率(通常更具挑战性)
|
||||
# 5.memsciqa
|
||||
对话记忆检索评估
|
||||
### (1)执行命令
|
||||
```bash
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 不带参数运行 - 使用环境变量
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark
|
||||
|
||||
# 命令行参数覆盖环境变量
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest
|
||||
```
|
||||
### (2)结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "memsciqa",
|
||||
"items": 1,
|
||||
"metrics": {
|
||||
"accuracy": 0.0,
|
||||
"f1": 0.0,
|
||||
"bleu1": 0.0,
|
||||
"jaccard": 0.0
|
||||
},
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 3067.7285194396973,
|
||||
"p50": 3067.7285194396973,
|
||||
"p95": 3067.7285194396973,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"avg_context_tokens": 4.0
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标 (metrics)
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`accuracy`**: 准确率
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,最直接的性能指标
|
||||
- 衡量系统回答正确的问题比例
|
||||
- 优秀标准:> 0.85
|
||||
|
||||
- **`f1`**: F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,平衡精确率和召回率
|
||||
- 计算公式:2 * (precision * recall) / (precision + recall)
|
||||
- 比单纯的准确率更全面,特别适合不平衡数据集
|
||||
|
||||
- **`bleu1`**: BLEU-1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量词汇级别的匹配度
|
||||
- 关注生成答案与标准答案的单词重叠
|
||||
- 源自机器翻译评估,适用于自然语言生成
|
||||
|
||||
- **`jaccard`**: Jaccard 相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量集合相似性
|
||||
- 计算公式:|A ∩ B| / |A ∪ B|
|
||||
- 对于多答案或集合类问题特别有用
|
||||
|
||||
##### 2. 性能指标 (latency)
|
||||
|
||||
**⚡ 效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均检索延迟
|
||||
- `p50`: 中位数延迟(50%的请求在此时间内完成)
|
||||
- `p95`: 95分位数延迟(95%的请求在此时间内完成)
|
||||
- `iqr`: 四分位距(Q3-Q1,衡量稳定性)
|
||||
- **越低越好**,衡量记忆检索效率
|
||||
- 优秀标准:p95 < 2000ms
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距(越小越稳定)
|
||||
- **越低越好**,衡量答案生成速度
|
||||
- 优秀标准:p95 < 3000ms
|
||||
- 注意:LLM 延迟通常占总延迟的大部分
|
||||
|
||||
##### 3. 资源指标
|
||||
|
||||
- **`avg_context_tokens`**: 平均上下文 token 数
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响:
|
||||
- API 调用成本(按 token 计费)
|
||||
- 推理速度(token 越多越慢)
|
||||
- 上下文窗口占用
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- 建议范围:根据模型上下文窗口和成本预算调整
|
||||
|
||||
##### 4. 数据集特点
|
||||
|
||||
- **`items`**: 评估的问题数量
|
||||
- 样本量越大,评估结果越可靠
|
||||
- 建议至少 100 个样本以获得稳定的评估结果
|
||||
|
||||
- **对话记忆特性**:
|
||||
- MemSciQA 专注于对话历史中的记忆检索
|
||||
- 评估系统从多轮对话中提取和回忆信息的能力
|
||||
- 模拟真实的对话场景
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `accuracy` 提升 → 核心能力提升
|
||||
- 目标:> 0.85
|
||||
- `f1` 提升 → 综合性能提升
|
||||
- 目标:> 0.80
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency.p95` 降低 → 用户体验提升
|
||||
- search.p95 目标:< 2000ms
|
||||
- llm.p95 目标:< 3000ms
|
||||
- `iqr` 降低 → 性能稳定性提升
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化
|
||||
- `bleu1` 和 `jaccard` 提升 → 答案质量提升
|
||||
|
||||
**综合评估:**
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- 该比值越高,说明系统在相同成本下性能越好
|
||||
- 总延迟 = search.p95 + llm.p95
|
||||
- 应控制在 5 秒以内以保证良好的用户体验
|
||||
|
||||
#### 优化建议
|
||||
|
||||
**提升准确性:**
|
||||
- 优化检索算法(调整 hybrid search 参数)
|
||||
- 改进 embedding 模型质量
|
||||
- 增加检索上下文数量(`search_limit`)
|
||||
- 优化 prompt 工程
|
||||
|
||||
**提升效率:**
|
||||
- 减少不必要的检索文档
|
||||
- 使用更快的 LLM 模型或量化版本
|
||||
- 实施缓存策略(相似问题复用结果)
|
||||
- 优化数据库索引
|
||||
|
||||
**平衡性能:**
|
||||
- 监控 accuracy vs latency 的权衡
|
||||
- 监控 accuracy vs cost (tokens) 的权衡
|
||||
- 根据业务需求调整优先级
|
||||
|
||||
|
||||
---
|
||||
|
||||
# 6. 三个基准测试对比总结
|
||||
|
||||
## 6.1 测试特点对比
|
||||
|
||||
| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 |
|
||||
|---------|------------|-----------|---------|
|
||||
| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 |
|
||||
| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 |
|
||||
| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 |
|
||||
|
||||
## 6.2 核心指标对比
|
||||
|
||||
### 准确性指标
|
||||
|
||||
| 指标 | Locomo | LongMemEval | MemSciQA | 说明 |
|
||||
|-----|--------|-------------|----------|------|
|
||||
| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 |
|
||||
| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 |
|
||||
| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 |
|
||||
| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 |
|
||||
| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 |
|
||||
|
||||
### 性能指标
|
||||
|
||||
所有三个测试都包含:
|
||||
- **检索延迟** (search latency): mean, p50, p95, iqr
|
||||
- **LLM 延迟** (llm latency): mean, p50, p95, iqr
|
||||
- **上下文统计**: token 数、字符数、文档数
|
||||
|
||||
## 6.3 关键进步指标优先级
|
||||
|
||||
### 🥇 一级指标(必须关注)
|
||||
|
||||
1. **准确性指标**
|
||||
- Locomo: `f1`, `locomo_f1`
|
||||
- LongMemEval: `score_accuracy`, `accuracy_by_type`
|
||||
- MemSciQA: `accuracy`, `f1`
|
||||
- **目标**: > 85% 或 > 0.85
|
||||
|
||||
2. **综合性能**
|
||||
- 所有测试的 F1 分数应保持一致性
|
||||
- 不同类型问题的准确率应均衡
|
||||
|
||||
### 🥈 二级指标(重要)
|
||||
|
||||
3. **响应延迟**
|
||||
- `latency.p95` (95分位数延迟)
|
||||
- **目标**:
|
||||
- search.p95 < 2000ms
|
||||
- llm.p95 < 3000ms
|
||||
- 总延迟 < 5000ms
|
||||
|
||||
4. **性能稳定性**
|
||||
- `iqr` (四分位距)
|
||||
- **目标**: 越小越好,说明性能稳定
|
||||
|
||||
### 🥉 三级指标(优化)
|
||||
|
||||
5. **成本效率**
|
||||
- `avg_context_tokens`
|
||||
- **目标**: 在保持准确性前提下最小化
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
|
||||
6. **检索质量**
|
||||
- `avg_retrieved_docs` 的合理性
|
||||
- `unique_preview_count` (LongMemEval)
|
||||
- 检索内容的多样性和相关性
|
||||
|
||||
## 6.4 系统优化路径
|
||||
|
||||
### 阶段一:提升准确性(优先级最高)
|
||||
|
||||
**目标**: 所有测试的准确率 > 85%
|
||||
|
||||
**优化方向**:
|
||||
1. 改进 embedding 模型质量
|
||||
2. 优化检索算法(hybrid search 参数)
|
||||
3. 增加检索上下文数量(`search_limit`)
|
||||
4. 优化 prompt 工程
|
||||
5. 改进记忆存储结构
|
||||
|
||||
**监控指标**:
|
||||
- Locomo: `f1`, `locomo_f1`
|
||||
- LongMemEval: `score_accuracy`, `exact_match` 比例
|
||||
- MemSciQA: `accuracy`, `f1`
|
||||
|
||||
### 阶段二:优化性能(准确性达标后)
|
||||
|
||||
**目标**: p95 延迟 < 5 秒,性能稳定
|
||||
|
||||
**优化方向**:
|
||||
1. 优化数据库索引和查询
|
||||
2. 实施缓存策略
|
||||
3. 使用更快的 LLM 模型
|
||||
4. 并行化检索和推理
|
||||
5. 减少不必要的检索
|
||||
|
||||
**监控指标**:
|
||||
- `latency.p50`, `latency.p95`
|
||||
- `iqr` (稳定性)
|
||||
- 各阶段耗时分布
|
||||
|
||||
### 阶段三:降低成本(性能达标后)
|
||||
|
||||
**目标**: 在保持准确性和性能前提下,最小化成本
|
||||
|
||||
**优化方向**:
|
||||
1. 精简检索上下文
|
||||
2. 优化 context 选择策略
|
||||
3. 使用更小的 LLM 模型
|
||||
4. 实施智能缓存
|
||||
5. 批处理优化
|
||||
|
||||
**监控指标**:
|
||||
- `avg_context_tokens`
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- API 调用成本
|
||||
|
||||
## 6.5 评估最佳实践
|
||||
|
||||
### 测试执行建议
|
||||
|
||||
1. **初始测试**: 使用小样本快速验证
|
||||
```bash
|
||||
--sample_size 10
|
||||
```
|
||||
|
||||
2. **完整评估**: 使用足够大的样本量
|
||||
```bash
|
||||
--sample_size 100 # 或更多
|
||||
```
|
||||
|
||||
3. **增量测试**: 数据已摄入时跳过摄入阶段
|
||||
```bash
|
||||
--skip_ingest
|
||||
```
|
||||
|
||||
4. **参数调优**: 系统性地调整参数并记录结果
|
||||
- 调整 `search_limit`: 4, 8, 12, 16
|
||||
- 调整 `context_char_budget`: 2000, 4000, 8000
|
||||
- 尝试不同的 `search_type`: keyword, embedding, hybrid
|
||||
|
||||
### 结果分析建议
|
||||
|
||||
1. **横向对比**: 比较三个测试的结果,识别系统的强弱项
|
||||
2. **纵向对比**: 跟踪同一测试在不同版本的表现
|
||||
3. **分类分析**: 关注不同问题类型的性能差异
|
||||
4. **异常诊断**: 分析失败案例,找出根本原因
|
||||
|
||||
### 持续监控
|
||||
|
||||
建议建立监控仪表板,跟踪:
|
||||
- 核心指标趋势(准确率、延迟)
|
||||
- 成本效益比趋势
|
||||
- 不同问题类型的性能分布
|
||||
- 异常样本和失败模式
|
||||
|
||||
## 6.6 性能基准参考
|
||||
|
||||
### 优秀水平(Production Ready)
|
||||
|
||||
- **准确性**: accuracy/f1 > 0.90
|
||||
- **延迟**: p95 < 3 秒
|
||||
- **稳定性**: iqr < 500ms
|
||||
- **成本效益**: accuracy/tokens > 0.0001
|
||||
|
||||
### 良好水平(Acceptable)
|
||||
|
||||
- **准确性**: accuracy/f1 > 0.85
|
||||
- **延迟**: p95 < 5 秒
|
||||
- **稳定性**: iqr < 1000ms
|
||||
- **成本效益**: accuracy/tokens > 0.00005
|
||||
|
||||
### 需要改进(Below Target)
|
||||
|
||||
- **准确性**: accuracy/f1 < 0.85
|
||||
- **延迟**: p95 > 5 秒
|
||||
- **稳定性**: iqr > 1000ms
|
||||
- **成本效益**: accuracy/tokens < 0.00005
|
||||
|
||||
---
|
||||
|
||||
**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。
|
||||
@@ -1,371 +0,0 @@
|
||||
"""
|
||||
交互式 Neo4j End User 数据检查工具
|
||||
|
||||
用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。
|
||||
|
||||
使用方法:
|
||||
python check_group_data.py
|
||||
python check_group_data.py --group-id locomo_benchmark
|
||||
python check_group_data.py --group-id memsciqa_benchmark --detailed
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
# Load the evaluation settings stored next to this script, when that file exists.
# NOTE(review): this runs before the Neo4j connector import below, presumably so
# the connector picks up these values -- confirm before reordering.
eval_config_path = Path(__file__).resolve().with_name(".env.evaluation")
if eval_config_path.exists():
    # override=True: values from the file replace variables that are already set.
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}\n")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def check_group_exists(end_user_id: str) -> Dict[str, Any]:
    """Collect basic Neo4j statistics for one ``end_user_id``.

    Three queries are issued in order: the total node count, the node count
    per label set, and the number of distinct relationships attached to the
    matched nodes. The connector is closed on every exit path.

    Args:
        end_user_id: Value matched against the ``end_user_id`` node property.

    Returns:
        A dict with ``exists`` (True when at least one node matched),
        ``total_nodes``, ``total_relationships`` and ``nodes_by_type`` (the
        rows of the per-label query exactly as the connector returned them).
    """
    neo4j = Neo4jConnector()

    try:
        # Total number of nodes carrying this end_user_id.
        query_total = """
        MATCH (n {end_user_id: $end_user_id})
        RETURN count(n) as total_nodes
        """
        node_rows = await neo4j.execute_query(query_total, end_user_id=end_user_id)
        node_total = 0
        if node_rows:
            node_total = node_rows[0]["total_nodes"]

        # Node counts grouped by label set, largest group first.
        query_by_type = """
        MATCH (n {end_user_id: $end_user_id})
        RETURN labels(n) as labels, count(n) as count
        ORDER BY count DESC
        """
        label_rows = await neo4j.execute_query(query_by_type, end_user_id=end_user_id)

        # Distinct relationships that touch any of the matched nodes.
        query_relationships = """
        MATCH (n {end_user_id: $end_user_id})-[r]-()
        RETURN count(DISTINCT r) as total_relationships
        """
        rel_rows = await neo4j.execute_query(query_relationships, end_user_id=end_user_id)
        rel_total = 0
        if rel_rows:
            rel_total = rel_rows[0]["total_relationships"]

        summary: Dict[str, Any] = {}
        summary["exists"] = node_total > 0
        summary["total_nodes"] = node_total
        summary["total_relationships"] = rel_total
        summary["nodes_by_type"] = label_rows
        return summary

    finally:
        await neo4j.close()
|
||||
|
||||
|
||||
async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]:
    """Gather per-label statistics for one ``end_user_id``.

    Chunk, Statement, Entity, Dialogue and Summary nodes are queried in that
    order. A section is added to the result only when its query reports a
    count greater than zero. The connector is closed on every exit path.

    Args:
        end_user_id: Value matched against the ``end_user_id`` node property.

    Returns:
        A dict that may contain the keys ``chunks``, ``statements``,
        ``entities``, ``dialogues`` and ``summaries``. Each maps to a dict
        holding ``count``; ``chunks`` also holds ``avg_content_length`` and
        ``entities`` also holds ``unique_types``.
    """
    neo4j = Neo4jConnector()

    try:
        collected: Dict[str, Any] = {}

        # Chunk nodes: count plus average content size.
        query_chunks = """
        MATCH (c:Chunk {end_user_id: $end_user_id})
        RETURN count(c) as count,
               avg(size(c.content)) as avg_content_length
        """
        chunk_rows = await neo4j.execute_query(query_chunks, end_user_id=end_user_id)
        if chunk_rows and chunk_rows[0]["count"] > 0:
            chunk_row = chunk_rows[0]
            avg_length = chunk_row["avg_content_length"]
            # A missing or zero average is reported as 0; otherwise truncate to int.
            if avg_length:
                avg_length = int(avg_length)
            else:
                avg_length = 0
            collected["chunks"] = {
                "count": chunk_row["count"],
                "avg_content_length": avg_length,
            }

        # Statement nodes.
        query_statements = """
        MATCH (s:Statement {end_user_id: $end_user_id})
        RETURN count(s) as count
        """
        statement_rows = await neo4j.execute_query(query_statements, end_user_id=end_user_id)
        if statement_rows and statement_rows[0]["count"] > 0:
            collected["statements"] = {"count": statement_rows[0]["count"]}

        # Entity nodes: count plus number of distinct entity_type values.
        query_entities = """
        MATCH (e:Entity {end_user_id: $end_user_id})
        RETURN count(e) as count,
               count(DISTINCT e.entity_type) as unique_types
        """
        entity_rows = await neo4j.execute_query(query_entities, end_user_id=end_user_id)
        if entity_rows and entity_rows[0]["count"] > 0:
            entity_row = entity_rows[0]
            collected["entities"] = {
                "count": entity_row["count"],
                "unique_types": entity_row["unique_types"],
            }

        # Dialogue nodes.
        query_dialogues = """
        MATCH (d:Dialogue {end_user_id: $end_user_id})
        RETURN count(d) as count
        """
        dialogue_rows = await neo4j.execute_query(query_dialogues, end_user_id=end_user_id)
        if dialogue_rows and dialogue_rows[0]["count"] > 0:
            collected["dialogues"] = {"count": dialogue_rows[0]["count"]}

        # Summary nodes.
        query_summaries = """
        MATCH (s:Summary {end_user_id: $end_user_id})
        RETURN count(s) as count
        """
        summary_rows = await neo4j.execute_query(query_summaries, end_user_id=end_user_id)
        if summary_rows and summary_rows[0]["count"] > 0:
            collected["summaries"] = {"count": summary_rows[0]["count"]}

        return collected

    finally:
        await neo4j.close()
|
||||
|
||||
|
||||
async def list_all_end_users() -> list:
    """
    List every end_user_id present in the database.

    Returns:
        The rows returned by the connector: one per distinct end_user_id,
        carrying ``end_user_id`` and ``node_count``, ordered by node count
        in descending order.
    """
    # Only nodes that carry an end_user_id property are counted.
    cypher = """
        MATCH (n)
        WHERE n.end_user_id IS NOT NULL
        RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count
        ORDER BY node_count DESC
        """
    connector = Neo4jConnector()
    try:
        return await connector.execute_query(cypher)
    finally:
        # Always release the connection, whether or not the query succeeded.
        await connector.close()
|
||||
|
||||
|
||||
def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None):
    """
    Print the query results for one end user to stdout.

    Args:
        end_user_id: End User ID shown in the report header.
        stats: Basic statistics; ``exists`` is always read, the remaining
            keys (``total_nodes``, ``total_relationships``,
            ``nodes_by_type``) only when ``exists`` is truthy.
        detailed_stats: Optional per-node-type statistics; only the sections
            present in the mapping are printed.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print(f"📊 End User ID: {end_user_id}")
    print(f"{rule}\n")

    # Nothing stored for this user: print the hint and stop early.
    if not stats["exists"]:
        print("❌ 该 end_user_id 不存在数据")
        print("\n💡 提示: 请先运行基准测试以摄入数据")
        return

    print("✅ 该 end_user_id 存在数据\n")
    print("📈 基本统计:")
    print(f" 总节点数: {stats['total_nodes']}")
    print(f" 总关系数: {stats['total_relationships']}")

    type_rows = stats["nodes_by_type"]
    if type_rows:
        print("\n📋 节点类型分布:")
        for row in type_rows:
            joined_labels = ", ".join(row["labels"])
            node_count = row["count"]
            print(f" {joined_labels}: {node_count}")

    if detailed_stats:
        print("\n🔍 详细统计:")
        # Sections are printed in this fixed order; two of them carry one
        # extra line after the count.
        sections = (
            ("chunks", "Chunks"),
            ("statements", "Statements"),
            ("entities", "Entities"),
            ("dialogues", "Dialogues"),
            ("summaries", "Summaries"),
        )
        for key, title in sections:
            if key not in detailed_stats:
                continue
            section = detailed_stats[key]
            print(f" {title}: {section['count']} 个")
            if key == "chunks":
                print(f" 平均内容长度: {section['avg_content_length']} 字符")
            elif key == "entities":
                print(f" 唯一类型数: {section['unique_types']}")

    print(f"\n{rule}\n")
|
||||
|
||||
|
||||
async def interactive_mode():
    """
    Run the interactive menu loop.

    Repeatedly prompts on stdin for one of three actions (inspect a single
    end_user_id, list all end_user_ids, quit) until the user chooses to quit.
    """
    print("\n" + "="*60)
    print("🔍 Neo4j End User 数据检查工具 - 交互模式")
    print("="*60 + "\n")

    rule = "=" * 60
    while True:
        print("\n请选择操作:")
        print(" 1. 检查指定 end_user_id")
        print(" 2. 列出所有 end_user_id")
        print(" 3. 退出")

        selection = input("\n请输入选项 (1-3): ").strip()

        if selection == "3":
            print("\n👋 再见!")
            break

        if selection == "1":
            end_user_id = input("\n请输入 end_user_id: ").strip()
            if not end_user_id:
                print("❌ end_user_id 不能为空")
                continue

            answer = input("是否显示详细统计? (y/n, 默认 n): ").strip().lower()
            want_details = answer == 'y'

            print("\n🔄 正在查询...")
            stats = await check_group_exists(end_user_id)

            # Detailed statistics are only fetched when the user has data.
            detailed_stats = None
            if want_details and stats["exists"]:
                detailed_stats = await get_detailed_stats(end_user_id)

            print_results(end_user_id, stats, detailed_stats)
            continue

        if selection == "2":
            print("\n🔄 正在查询所有 end_user_id...")
            end_users = await list_all_end_users()

            if not end_users:
                print("\n❌ 数据库中没有任何 end_user 数据")
                continue

            print(f"\n{rule}")
            print("📋 数据库中的所有 End User ID")
            print(f"{rule}\n")

            for position, entry in enumerate(end_users, 1):
                print(f" {position}. {entry['end_user_id']}")
                print(f" 节点数: {entry['node_count']}")

            print(f"\n{rule}\n")
            continue

        print("\n❌ 无效的选项,请重新选择")
|
||||
|
||||
|
||||
async def main():
    """
    Command-line entry point.

    Parses the CLI flags and dispatches to one of three modes: list all
    end_user_ids (``--list-all``), inspect one end_user_id
    (``--end-user-id``, optionally with ``--detailed``), or the interactive
    menu when neither of those two flags is given.
    """
    usage_examples = """
示例:
# 交互模式
python check_group_data.py

# 检查指定 end_user
python check_group_data.py --end-user-id locomo_benchmark

# 检查并显示详细统计
python check_group_data.py --end-user-id memsciqa_benchmark --detailed

# 列出所有 end_user
python check_group_data.py --list-all
"""
    parser = argparse.ArgumentParser(
        description="检查 Neo4j 中指定 end_user_id 的数据情况",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    parser.add_argument("--end-user-id", type=str, help="要检查的 end_user ID")
    parser.add_argument("--detailed", action="store_true", help="显示详细统计信息")
    parser.add_argument("--list-all", action="store_true", help="列出所有 end_user_id")

    args = parser.parse_args()

    # --list-all takes precedence over --end-user-id when both are given.
    if args.list_all:
        print("\n🔄 正在查询所有 end_user_id...")
        end_users = await list_all_end_users()

        if not end_users:
            print("\n❌ 数据库中没有任何 end_user 数据")
            return

        rule = "=" * 60
        print(f"\n{rule}")
        print("📋 数据库中的所有 End User ID")
        print(f"{rule}\n")

        for position, entry in enumerate(end_users, 1):
            print(f" {position}. {entry['end_user_id']}")
            print(f" 节点数: {entry['node_count']}")

        print(f"\n{rule}\n")
        return

    # Neither --list-all nor --end-user-id: fall back to the interactive menu.
    if not args.end_user_id:
        await interactive_mode()
        return

    # Inspect the requested end user.
    print(f"\n🔄 正在查询 end_user_id: {args.end_user_id}...")
    stats = await check_group_exists(args.end_user_id)

    detailed_stats = None
    if args.detailed and stats["exists"]:
        print("🔄 正在获取详细统计...")
        detailed_stats = await get_detailed_stats(args.end_user_id)

    print_results(args.end_user_id, stats, detailed_stats)
|
||||
|
||||
|
||||
# Script entry point: run the async CLI only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -1,100 +0,0 @@
|
||||
import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
# 评估指标的实现
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
# Python's re doesn't support \p classes; use a simple non-word filter
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
tokens = [t for t in text.split() if t]
|
||||
return tokens
|
||||
|
||||
|
||||
def exact_match(pred: str, ref: str) -> float:
    """Return 1.0 when both strings normalize to the same token list, else 0.0."""
    same_tokens = _normalize(pred) == _normalize(ref)
    return 1.0 if same_tokens else 0.0
|
||||
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
    """Return the Jaccard similarity of the two normalized token sets.

    Two empty token sets count as a perfect match (1.0); exactly one empty
    set yields 0.0.
    """
    pred_vocab = set(_normalize(pred))
    ref_vocab = set(_normalize(ref))
    if not (pred_vocab or ref_vocab):
        return 1.0
    if not (pred_vocab and ref_vocab):
        return 0.0
    shared = pred_vocab & ref_vocab
    combined = pred_vocab | ref_vocab
    return len(shared) / len(combined)
|
||||
|
||||
|
||||
def f1_score(pred: str, ref: str) -> float:
    """Return the F1 score computed over the *sets* of normalized tokens.

    Duplicate tokens are ignored because both sides are reduced to sets.
    Two empty inputs score 1.0; exactly one empty input scores 0.0.
    """
    pred_tokens = _normalize(pred)
    ref_tokens = _normalize(ref)
    if not pred_tokens and not ref_tokens:
        return 1.0
    if not pred_tokens or not ref_tokens:
        return 0.0

    # Both token lists are non-empty here, so both sets are non-empty too.
    pred_vocab = set(pred_tokens)
    ref_vocab = set(ref_tokens)
    overlap = len(pred_vocab & ref_vocab)
    precision = overlap / len(pred_vocab)
    recall = overlap / len(ref_vocab)

    denominator = precision + recall
    if denominator == 0:
        return 0.0
    return 2 * precision * recall / denominator
|
||||
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
    """Return unigram BLEU (BLEU-1) with clipped counts and brevity penalty."""
    hyp_tokens = _normalize(pred)
    ref_tokens = _normalize(ref)
    if not hyp_tokens:
        return 0.0

    # Token frequencies on both sides.
    ref_freq: Dict[str, int] = {}
    for token in ref_tokens:
        ref_freq[token] = ref_freq.get(token, 0) + 1
    hyp_freq: Dict[str, int] = {}
    for token in hyp_tokens:
        hyp_freq[token] = hyp_freq.get(token, 0) + 1

    # Clipping: a token is credited at most as often as it occurs in the reference.
    clipped = sum(min(count, ref_freq.get(token, 0)) for token, count in hyp_freq.items())

    hyp_len = len(hyp_tokens)  # > 0 thanks to the guard above
    ref_len = len(ref_tokens)
    precision = clipped / hyp_len

    # Brevity penalty: only predictions that are not longer than the
    # reference are penalized.
    if hyp_len > ref_len:
        penalty = 1.0
    else:
        penalty = math.exp(1 - ref_len / hyp_len)
    return penalty * precision
|
||||
|
||||
|
||||
def percentile(values: List[float], p: float) -> float:
    """Return the *p*-quantile of *values* using linear interpolation.

    Args:
        values: Samples in any order; an empty list yields 0.0.
        p: Quantile as a fraction (e.g. 0.95), applied to the index range
            ``0 .. len(values) - 1``.
    """
    if not values:
        return 0.0
    ordered = sorted(values)
    rank = (len(ordered) - 1) * p
    lower = math.floor(rank)
    upper = math.ceil(rank)
    if lower == upper:
        # The rank falls exactly on an element; no interpolation needed.
        return ordered[int(rank)]
    fraction = rank - lower
    return ordered[lower] + fraction * (ordered[upper] - ordered[lower])
|
||||
|
||||
|
||||
def latency_stats(latencies_ms: List[float]) -> Dict[str, float]:
    """Summarize latencies as mean, p50, p95 and iqr (p75 - p25).

    An empty input yields all-zero statistics.
    """
    if not latencies_ms:
        return {"mean": 0.0, "p50": 0.0, "p95": 0.0, "iqr": 0.0}

    q25, q50, q75, q95 = (
        percentile(latencies_ms, fraction) for fraction in (0.25, 0.50, 0.75, 0.95)
    )
    # The list is non-empty here, so the division is safe.
    average = sum(latencies_ms) / len(latencies_ms)
    return {"mean": average, "p50": q50, "p95": q95, "iqr": q75 - q25}
|
||||
|
||||
|
||||
def avg_context_tokens(contexts: List[str]) -> float:
    """Return the mean number of normalized tokens per context (0.0 if empty)."""
    if not contexts:
        return 0.0
    total_tokens = 0
    for context in contexts:
        total_tokens += len(_normalize(context))
    return total_tokens / len(contexts)
|
||||
@@ -1,62 +0,0 @@
|
||||
"""
|
||||
Dialogue search queries for evaluation purposes.
|
||||
This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# NOTE: these appear to be Cypher statements meant for the Neo4j Browser; this file should be renamed accordingly.
|
||||
|
||||
# Entity search queries
# Exact-name lookup of ExtractedEntity nodes; takes the $name parameter.
SEARCH_ENTITIES_BY_NAME = """
MATCH (e:ExtractedEntity)
WHERE e.name = $name
RETURN e
"""

# Substring variant of the lookup above (CONTAINS instead of equality).
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:ExtractedEntity)
WHERE e.name CONTAINS $name
RETURN e
"""

# Chunk search queries
# Substring match on Chunk.content; takes the $content parameter.
SEARCH_CHUNKS_BY_CONTENT = """
MATCH (c:Chunk)
WHERE c.content CONTAINS $content
RETURN c
"""

# Dialogue search queries
# Lookup of Dialogue nodes by their dialog_id property; takes $dialog_id.
SEARCH_DIALOGUE_BY_DIALOG_ID = """
MATCH (d:Dialogue)
WHERE d.dialog_id = $dialog_id
RETURN d
"""

# Substring match on Dialogue.content; note the parameter is named $q here.
SEARCH_DIALOGUES_BY_CONTENT = """
MATCH (d:Dialogue)
WHERE d.content CONTAINS $q
RETURN d
"""

# Cosine-similarity search over Dialogue.dialog_embedding, computed in plain
# Cypher with reduce(): dot product divided by the product of both norms
# (0.0 when either norm is zero). Parameters: $embedding (query vector),
# $end_user_id (NULL disables the filter), $threshold (exclusive lower bound
# on the score) and $limit. Results are ordered by descending score.
# NOTE(review): the dot product iterates over size(q) only, so it presumably
# assumes both vectors have the same length - confirm.
# NOTE(review): this query returns d.id AS dialog_id, whereas
# SEARCH_DIALOGUE_BY_DIALOG_ID filters on d.dialog_id - confirm both
# properties exist on Dialogue nodes.
DIALOGUE_EMBEDDING_SEARCH = """
WITH $embedding AS q
MATCH (d:Dialogue)
WHERE d.dialog_embedding IS NOT NULL
AND ($end_user_id IS NULL OR d.end_user_id = $end_user_id)
WITH d, q, d.dialog_embedding AS v
WITH d,
reduce(dot = 0.0, i IN range(0, size(q)-1) | dot + toFloat(q[i]) * toFloat(v[i])) AS dot,
sqrt(reduce(qs = 0.0, i IN range(0, size(q)-1) | qs + toFloat(q[i]) * toFloat(q[i]))) AS qnorm,
sqrt(reduce(vs = 0.0, i IN range(0, size(v)-1) | vs + toFloat(v[i]) * toFloat(v[i]))) AS vnorm
WITH d, CASE WHEN qnorm = 0 OR vnorm = 0 THEN 0.0 ELSE dot / (qnorm * vnorm) END AS score
WHERE score > $threshold
RETURN d.id AS dialog_id,
d.end_user_id AS end_user_id,
d.content AS content,
d.created_at AS created_at,
d.expired_at AS expired_at,
score
ORDER BY score DESC
LIMIT $limit
"""
|
||||
@@ -1,444 +0,0 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config.
# This module is imported as ``app.core.memory.evaluation.extraction_utils``,
# so ``.env.evaluation`` is expected to sit next to this file. The previous
# lookup appended "app/core/memory/evaluation" to this file's own directory,
# which only resolves when the module lives at the API root; that location is
# kept as a fallback so existing setups keep working.
_eval_dir = Path(__file__).resolve().parent
_eval_config_candidates = (
    _eval_dir / ".env.evaluation",
    _eval_dir / "app" / "core" / "memory" / "evaluation" / ".env.evaluation",
)
eval_config_path = next(
    (candidate for candidate in _eval_config_candidates if candidate.exists()),
    _eval_config_candidates[0],
)
if eval_config_path.exists():
    # override=True: values from the evaluation file win over the process environment.
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
|
||||
|
||||
async def ingest_contexts_via_full_pipeline(
|
||||
contexts: List[str],
|
||||
end_user_id: str,
|
||||
chunker_strategy: str | None = None,
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
reset_group: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
使用新的 ExtractionOrchestrator 运行完整的提取流水线
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function uses the new ExtractionOrchestrator architecture for better maintainability.
|
||||
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
end_user_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
chunker_strategy: Optional chunker strategy; defaults to SELECTED_CHUNKER_STRATEGY.
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
reset_group: If True, clear existing data for this group before ingestion.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
|
||||
embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID")
|
||||
|
||||
# Check if we should reset from environment variable if not explicitly set
|
||||
if not reset_group:
|
||||
reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
# Step 0: Reset group if requested
|
||||
if reset_group:
|
||||
print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...")
|
||||
try:
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 删除该 end_user 的所有节点和关系
|
||||
query = """
|
||||
MATCH (n {end_user_id: $end_user_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
await connector.execute_query(query, end_user_id=end_user_id)
|
||||
print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空")
|
||||
finally:
|
||||
await connector.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}")
|
||||
# 继续执行,不中断摄入流程
|
||||
|
||||
# Step 1: Initialize LLM client
|
||||
llm_client = None
|
||||
try:
|
||||
# 使用评估配置中的 LLM ID
|
||||
llm_id = os.getenv("EVAL_LLM_ID")
|
||||
if not llm_id:
|
||||
print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation")
|
||||
return False
|
||||
|
||||
from app.db import get_db
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
llm_client = get_llm_client(llm_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable: {e}")
|
||||
return False
|
||||
|
||||
# Step 2: Parse contexts and create DialogData with chunks
|
||||
print(f"[Ingestion] Parsing {len(contexts)} contexts...")
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
for idx, ctx in enumerate(contexts):
|
||||
messages: List[ConversationMessage] = []
|
||||
|
||||
# Improved parsing: capture multi-line message blocks, normalize roles
|
||||
pattern = r"^\s*(用户|AI|assistant|user)\s*[::]\s*(.+?)(?=\n\s*(?:用户|AI|assistant|user)\s*[::]|\Z)"
|
||||
matches = list(re.finditer(pattern, ctx, flags=re.MULTILINE | re.DOTALL))
|
||||
|
||||
if matches:
|
||||
for m in matches:
|
||||
raw_role = m.group(1).strip()
|
||||
content = m.group(2).strip()
|
||||
norm_role = "AI" if raw_role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=content))
|
||||
else:
|
||||
# Fallback: line-by-line parsing
|
||||
for raw in ctx.split("\n"):
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
norm_role = "AI" if role.lower() in ("ai", "assistant") else "用户"
|
||||
messages.append(ConversationMessage(role=norm_role, msg=msg))
|
||||
else:
|
||||
# Final fallback: treat as user message
|
||||
default_role = "AI" if re.match(r'^\s*(assistant|AI)\b', line, flags=re.IGNORECASE) else "用户"
|
||||
messages.append(ConversationMessage(role=default_role, msg=line))
|
||||
|
||||
context_model = ConversationContext(msgs=messages)
|
||||
dialog = DialogData(
|
||||
context=context_model,
|
||||
ref_id=f"pipeline_item_{idx}",
|
||||
end_user_id=end_user_id,
|
||||
user_id="default_user",
|
||||
apply_id="default_application",
|
||||
)
|
||||
# Generate chunks
|
||||
dialog.chunks = await chunker.process_dialogue(dialog)
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("[Ingestion] No dialogs to process.")
|
||||
return False
|
||||
|
||||
print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks")
|
||||
|
||||
# Step 3: Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
default_path = settings.get_memory_output_path("chunker_test_output.txt")
|
||||
out_path = save_chunk_output_path or default_path
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"[Ingestion] Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to save chunking results: {e}")
|
||||
|
||||
# Step 4: Initialize embedder client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.db import get_db
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name, db)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
return False
|
||||
|
||||
# Step 5: Initialize Neo4j connector
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# Step 6: 构建 MemoryConfig(从环境变量直接构建,不依赖数据库)
|
||||
print("[Ingestion] 构建 MemoryConfig from environment variables...")
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
# 从环境变量获取配置参数
|
||||
llm_id = os.getenv("EVAL_LLM_ID")
|
||||
embedding_id = os.getenv("EVAL_EMBEDDING_ID")
|
||||
chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
|
||||
|
||||
if not llm_id or not embedding_id:
|
||||
print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation")
|
||||
print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# 从数据库获取模型信息(仅用于显示名称)
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
# 获取 LLM 模型信息(从 model_configs 表)
|
||||
llm_result = db.execute(
|
||||
text("SELECT name FROM model_configs WHERE id = :id"),
|
||||
{"id": llm_id}
|
||||
).fetchone()
|
||||
llm_model_name = llm_result[0] if llm_result else "Unknown LLM"
|
||||
|
||||
# 获取 Embedding 模型信息(从 model_configs 表)
|
||||
emb_result = db.execute(
|
||||
text("SELECT name FROM model_configs WHERE id = :id"),
|
||||
{"id": embedding_id}
|
||||
).fetchone()
|
||||
embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding"
|
||||
except Exception as e:
|
||||
# 如果查询失败,使用默认名称
|
||||
print(f"[Ingestion] Warning: Failed to query model names from database: {e}")
|
||||
llm_model_name = f"LLM ({llm_id[:8]}...)"
|
||||
embedding_model_name = f"Embedding ({embedding_id[:8]}...)"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 构建 MemoryConfig 对象(使用最小必需配置)
|
||||
from uuid import uuid4
|
||||
memory_config = MemoryConfig(
|
||||
config_id=0, # 评估环境不需要真实的 config_id
|
||||
config_name="evaluation_config",
|
||||
workspace_id=uuid4(), # 临时 workspace_id
|
||||
workspace_name="evaluation_workspace",
|
||||
tenant_id=uuid4(), # 临时 tenant_id
|
||||
llm_model_id=UUID(llm_id),
|
||||
llm_model_name=llm_model_name,
|
||||
embedding_model_id=UUID(embedding_id),
|
||||
embedding_model_name=embedding_model_name,
|
||||
storage_type="neo4j",
|
||||
chunker_strategy=chunker_strategy_env,
|
||||
reflexion_enabled=False,
|
||||
reflexion_iteration_period=3,
|
||||
reflexion_range="partial",
|
||||
reflexion_baseline="TIME",
|
||||
loaded_at=datetime.now(),
|
||||
# 可选字段使用默认值
|
||||
rerank_model_id=None,
|
||||
rerank_model_name=None,
|
||||
llm_params={},
|
||||
embedding_params={},
|
||||
config_version="2.0",
|
||||
)
|
||||
|
||||
print(f"[Ingestion] ✅ 构建 MemoryConfig 成功")
|
||||
print(f"[Ingestion] LLM: {llm_model_name}")
|
||||
print(f"[Ingestion] Embedding: {embedding_model_name}")
|
||||
print(f"[Ingestion] Chunker: {chunker_strategy_env}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}")
|
||||
print(f"[Ingestion] Please check:")
|
||||
print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation")
|
||||
print(f"[Ingestion] 2. Model IDs exist in the models table")
|
||||
print(f"[Ingestion] 3. Database connection is working")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# Step 7: Initialize and run ExtractionOrchestrator
|
||||
print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...")
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
config = MemoryConfigService.get_pipeline_config(memory_config)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Run the complete extraction pipeline
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
|
||||
# Handle different return formats:
|
||||
# - Pilot mode: 7 values (without dedup_details)
|
||||
# - Normal mode: 8 values (with dedup_details at the end)
|
||||
if len(result) == 8:
|
||||
# Normal mode: includes dedup_details
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
_, # dedup_details - not needed here
|
||||
) = result
|
||||
elif len(result) == 7:
|
||||
# Pilot mode or older version: no dedup_details
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
else:
|
||||
raise ValueError(f"Unexpected number of return values: {len(result)}")
|
||||
|
||||
print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities")
|
||||
|
||||
except ValueError as e:
|
||||
# If unpacking fails, provide helpful error message
|
||||
print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}")
|
||||
print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}")
|
||||
if hasattr(result, '__len__') and len(result) > 0:
|
||||
print(f"[Ingestion] First element type: {type(result[0])}")
|
||||
await connector.close()
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Extraction pipeline failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# Step 7: Generate memory summaries
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client
|
||||
)
|
||||
print(f"[Ingestion] Generated {len(summaries)} memory summaries")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step 8: Save to Neo4j
|
||||
print("[Ingestion] Saving to Neo4j...")
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
chunk_nodes=chunk_nodes,
|
||||
statement_nodes=statement_nodes,
|
||||
entity_nodes=entity_nodes,
|
||||
entity_edges=entity_entity_edges,
|
||||
statement_chunk_edges=statement_chunk_edges,
|
||||
statement_entity_edges=statement_entity_edges,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
# Save memory summaries separately
|
||||
if summaries:
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
|
||||
if success:
|
||||
print("[Ingestion] Successfully saved all data to Neo4j!")
|
||||
else:
|
||||
print("[Ingestion] Failed to save data to Neo4j")
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to save data to Neo4j: {e}")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
|
||||
async def handle_context_processing(args):
    """Collect dialogue contexts from CLI arguments and run the pipeline on them.

    Contexts come from ``args.contexts`` and/or from the non-blank lines of
    the file named by ``args.context_file``. Returns False when the file
    cannot be read or when no context was supplied; otherwise returns the
    result of ``main_from_contexts``.
    """
    collected = list(args.contexts) if args.contexts else []

    if args.context_file:
        try:
            with open(args.context_file, 'r', encoding='utf-8') as handle:
                for raw_line in handle:
                    # Blank lines are skipped; everything else is one context.
                    if raw_line.strip():
                        collected.append(raw_line.strip())
        except Exception as e:
            print(f"Error reading context file: {e}")
            return False

    if collected:
        return await main_from_contexts(collected, args.context_end_user_id)

    print("No contexts provided for processing.")
    return False
|
||||
|
||||
|
||||
async def main_from_contexts(contexts: List[str], end_user_id: str):
    """Run the pipeline from provided dialogue contexts instead of test data.

    Args:
        contexts: Dialogue texts to ingest.
        end_user_id: End user ID the ingested data is stored under.

    Returns:
        The success flag returned by ``ingest_contexts_via_full_pipeline``.
    """
    print("=== Running pipeline from provided contexts ===")

    # The chunker strategy and embedding ID are left unset so the pipeline
    # resolves them from EVAL_CHUNKER_STRATEGY / EVAL_EMBEDDING_ID. The
    # SELECTED_* constants that used to be passed here are not defined in
    # this module and raised NameError on every call.
    success = await ingest_contexts_via_full_pipeline(
        contexts=contexts,
        end_user_id=end_user_id,
        chunker_strategy=None,
        embedding_name=None,
        save_chunk_output=True,
    )

    if success:
        print("Successfully processed and saved contexts to Neo4j!")
    else:
        print("Failed to process contexts.")

    return success
|
||||
@@ -1,770 +0,0 @@
|
||||
"""
|
||||
LoCoMo Benchmark Script
|
||||
|
||||
This module provides the main entry point for running LoCoMo benchmark evaluations.
|
||||
It orchestrates data loading, ingestion, retrieval, LLM inference, and metric calculation
|
||||
in a clean, maintainable way.
|
||||
|
||||
Usage:
|
||||
python locomo_benchmark.py --sample_size 20 --search_type hybrid
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
# The file is looked up two directory levels above this module's file
# (i.e. in the parent of the directory containing this script).
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
    # override=True: values from the file replace variables already set in the environment.
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
f1_score,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
avg_context_tokens
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_metrics import (
|
||||
locomo_f1_score,
|
||||
locomo_multi_f1,
|
||||
get_category_name
|
||||
)
|
||||
from app.core.memory.evaluation.locomo.locomo_utils import (
|
||||
load_locomo_data,
|
||||
extract_conversations,
|
||||
resolve_temporal_references,
|
||||
select_and_format_information,
|
||||
retrieve_relevant_information,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# Get configuration from environment variables
# Each ID falls back to a hard-coded default when its variable is unset.
PROJECT_ROOT = str(Path(__file__).resolve().parents[5])  # api directory
SELECTED_EMBEDDING_ID = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
# LOCOMO_END_USER_ID wins when set and non-empty; otherwise EVAL_END_USER_ID, then "locomo_benchmark".
SELECTED_end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark")
SELECTED_LLM_ID = os.getenv("EVAL_LLM_ID", "2c9b0782-7a85-4740-ba84-4baf77f256c4")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 1: Data Loading
|
||||
# ============================================================================
|
||||
|
||||
def step_load_data(data_path: str, sample_size: int) -> List[Dict[str, Any]]:
    """
    Load QA pairs from the LoCoMo dataset and report how many were read.

    Args:
        data_path: Path to the locomo10.json file.
        sample_size: Number of QA pairs to load (0 for all).

    Returns:
        The QA items returned by ``load_locomo_data``.
    """
    print("📂 Loading LoCoMo data...")

    loaded_items = load_locomo_data(data_path, sample_size)
    loaded_count = len(loaded_items)

    print(f"✅ Loaded {loaded_count} QA pairs from first conversation\n")
    return loaded_items
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 2: Data Ingestion
|
||||
# ============================================================================
|
||||
|
||||
async def ingest_conversations_if_needed(
    conversations: List[str],
    end_user_id: str,
    reset: bool = False
) -> bool:
    """Run the full ingestion pipeline over already-formatted conversations.

    Despite the name, no "already ingested" check is performed: every call
    hands the conversations to ``ingest_contexts_via_full_pipeline``.

    Args:
        conversations: Conversation strings, already formatted for ingestion.
        end_user_id: Database end_user ID the data is ingested under.
        reset: Accepted for interface compatibility only. Resetting is not
            implemented in this function; a notice is printed when it is
            requested so the flag is not silently ignored.

    Returns:
        True when the pipeline call completes, False when it raises.
    """
    if reset:
        print("⚠️ Reset requested, but reset is not implemented; ingesting without reset")

    try:
        # Kept as a function-local import, as in the original implementation.
        from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline

        # Conversations are already formatted strings; pass them through as-is.
        await ingest_contexts_via_full_pipeline(conversations, end_user_id)
    except Exception as e:
        # Deliberately best-effort: report the failure and let the caller
        # decide whether to continue.
        import traceback

        print(f"⚠️ Ingestion error: {e}")
        traceback.print_exc()
        return False

    return True
|
||||
|
||||
|
||||
async def step_ingest_data(
    data_path: str,
    end_user_id: str,
    skip_ingest: bool,
    reset_group: bool,
    max_messages: Optional[int] = None
) -> bool:
    """Ingest the first LoCoMo dialogue into the database unless skipped.

    Failures never raise: they are reported on stdout (with a traceback) so
    the evaluation can continue, and are signalled through the return value.

    Args:
        data_path: Path to locomo10.json file.
        end_user_id: Database end_user ID.
        skip_ingest: When True, nothing is ingested and existing data is used.
        reset_group: Forwarded as ``reset`` to ``ingest_conversations_if_needed``.
        max_messages: Maximum messages per dialogue to ingest (for testing).

    Returns:
        True if ingestion succeeded or was skipped, False if extraction or
        ingestion failed.
    """
    if skip_ingest:
        print("⏭️ Skipping data ingestion (using existing data in Neo4j)")
        print(f" End User ID: {end_user_id}\n")
        return True

    print("💾 Checking database ingestion...")
    try:
        # Only the first dialogue is extracted, optionally truncated.
        conversations = extract_conversations(
            data_path,
            max_dialogues=1,
            max_messages_per_dialogue=max_messages
        )
        print(f"📝 Extracted {len(conversations)} conversations")

        # Always ingest for now (ingestion check not implemented)
        print(f"🔄 Ingesting conversations into end_user '{end_user_id}'...")
        success = await ingest_conversations_if_needed(
            conversations=conversations,
            end_user_id=end_user_id,
            reset=reset_group
        )
    except Exception as e:
        # Best-effort step: report and carry on, but tell the caller it failed.
        import traceback

        print(f"❌ Ingestion failed: {e}")
        traceback.print_exc()
        print("⚠️ Continuing with evaluation (database may be empty)\n")
        return False

    if success:
        print("✅ Ingestion completed successfully\n")
    else:
        print("⚠️ Ingestion may have failed, continuing anyway\n")

    return bool(success)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 3: Initialize Clients
|
||||
# ============================================================================
|
||||
|
||||
def step_initialize_clients(llm_id: str, embedding_id: str):
    """Build the three clients the evaluation loop depends on.

    Args:
        llm_id: LLM model ID handed to ``get_llm_client``.
        embedding_id: Embedding model ID handed to ``get_embedder_config``.

    Returns:
        Tuple of (connector, llm_client, embedder).
    """
    print("🔧 Initializing clients...")

    neo4j_connector = Neo4jConnector()

    from app.db import get_db

    # The database session is only needed while the model configs are looked
    # up; it is closed again before returning, even on failure.
    session = next(get_db())
    try:
        chat_client = get_llm_client(llm_id, session)
        embedder_settings = get_embedder_config(embedding_id, session)
        validated_settings = RedBearModelConfig.model_validate(embedder_settings)
        embedding_client = OpenAIEmbedderClient(model_config=validated_settings)
    finally:
        session.close()

    print("✅ Clients initialized\n")
    return neo4j_connector, chat_client, embedding_client
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 4: Process Questions
|
||||
# ============================================================================
|
||||
|
||||
async def step_process_all_questions(
    qa_items: List[Dict[str, Any]],
    end_user_id: str,
    search_type: str,
    search_limit: int,
    context_char_budget: int,
    connector: Neo4jConnector,
    embedder: OpenAIEmbedderClient,
    llm_client: Any
) -> List[Dict[str, Any]]:
    """Answer every QA item from retrieved memory and score each answer.

    For each item the function retrieves context, formats it within the
    character budget, asks the LLM for an answer, and computes F1, BLEU-1,
    Jaccard and LoCoMo F1 against the ground truth. A retrieval or LLM
    failure does not abort the run: the item is still scored, using an empty
    retrieval result or the prediction "Unknown", and a latency of 0.0 ms.

    Args:
        qa_items: QA dictionaries; "question", "answer" and "category" are read.
        end_user_id: Database end_user ID to retrieve from.
        search_type: Search strategy forwarded to retrieval.
        search_limit: Maximum number of documents to retrieve per question.
        context_char_budget: Maximum characters of formatted context.
        connector: Neo4j connector forwarded to retrieval.
        embedder: Embedder client forwarded to retrieval.
        llm_client: Client exposing an awaitable ``chat(messages=...)``.

    Returns:
        One record per QA item with the keys "question", "ground_truth",
        "prediction", "category", "metrics", "retrieval", "context_tokens"
        and "timing".
    """
    total = len(qa_items)
    print(f"🔍 Processing {total} questions...")
    print(f"{'='*60}\n")

    samples: List[Dict[str, Any]] = []
    # Fixed reference date used when resolving relative time expressions.
    anchor_date = datetime(2023, 5, 8)

    # Loop-invariant: the system prompt is the same for every question.
    system_prompt = (
        "You are a precise QA assistant. Answer following these rules:\n"
        "1) Extract the EXACT information mentioned in the context\n"
        "2) For time questions: calculate actual dates from relative times\n"
        "3) Return ONLY the answer text in simplest form\n"
        "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
        "5) If no clear answer found, respond with 'Unknown'"
    )

    for idx, item in enumerate(qa_items, 1):
        question = item.get("question", "")
        ground_truth = item.get("answer", "")
        category = get_category_name(item)
        ground_truth_str = str(ground_truth) if ground_truth is not None else ""

        print(f"[{idx}/{total}] Category: {category}")
        print(f"❓ Question: {question}")
        print(f"✅ Ground Truth: {ground_truth_str}")

        # Retrieve. perf_counter() is monotonic, so the measured latency
        # cannot be skewed by system clock adjustments.
        t_search_start = time.perf_counter()
        try:
            retrieved_info = await retrieve_relevant_information(
                question=question,
                end_user_id=end_user_id,
                search_type=search_type,
                search_limit=search_limit,
                connector=connector,
                embedder=embedder
            )
            search_latency = (time.perf_counter() - t_search_start) * 1000
            print(f"🔍 Retrieved {len(retrieved_info)} documents ({search_latency:.1f}ms)")
        except Exception as e:
            # Best-effort: score the question against an empty context.
            print(f"❌ Retrieval failed: {e}")
            retrieved_info = []
            search_latency = 0.0

        # Format context
        context_text = select_and_format_information(
            retrieved_info=retrieved_info,
            question=question,
            max_chars=context_char_budget
        )
        context_text = resolve_temporal_references(context_text, anchor_date)
        if context_text:
            context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n{context_text}"
        else:
            context_text = "No relevant context found."

        print(f"📝 Context: {len(context_text)} chars, {len(retrieved_info)} docs")

        # Generate answer
        messages = [
            {"role": "system", "content": system_prompt},
            {
                "role": "user",
                "content": f"Question: {question}\n\nContext:\n{context_text}"
            }
        ]

        t_llm_start = time.perf_counter()
        try:
            response = await llm_client.chat(messages=messages)
            llm_latency = (time.perf_counter() - t_llm_start) * 1000
            # Accept either an object with .content or an OpenAI-style dict.
            if hasattr(response, 'content'):
                prediction = response.content.strip()
            elif isinstance(response, dict):
                prediction = response["choices"][0]["message"]["content"].strip()
            else:
                prediction = "Unknown"
            print(f"🤖 Prediction: {prediction} ({llm_latency:.1f}ms)")
        except Exception as e:
            # Best-effort: a failed call is scored as an "Unknown" answer.
            print(f"❌ LLM failed: {e}")
            prediction = "Unknown"
            llm_latency = 0.0

        # Calculate metrics
        f1_val = f1_score(prediction, ground_truth_str)
        bleu1_val = bleu1(prediction, ground_truth_str)
        jaccard_val = jaccard(prediction, ground_truth_str)
        # Category 1 is scored with the multi-answer (comma-separated) variant.
        if item.get("category") == 1:
            locomo_f1_val = locomo_multi_f1(prediction, ground_truth_str)
        else:
            locomo_f1_val = locomo_f1_score(prediction, ground_truth_str)

        print(f"📊 Metrics - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, "
              f"Jaccard: {jaccard_val:.3f}, LoCoMo F1: {locomo_f1_val:.3f}")
        print()

        samples.append({
            "question": question,
            "ground_truth": ground_truth_str,
            "prediction": prediction,
            "category": category,
            "metrics": {
                "f1": f1_val,
                "bleu1": bleu1_val,
                "jaccard": jaccard_val,
                "locomo_f1": locomo_f1_val
            },
            "retrieval": {
                "num_docs": len(retrieved_info),
                "context_length": len(context_text)
            },
            # Whitespace-separated word count, used as a rough token count.
            "context_tokens": len(context_text.split()),
            "timing": {
                "search_ms": search_latency,
                "llm_ms": llm_latency
            }
        })

    return samples
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 5: Aggregate Results
|
||||
# ============================================================================
|
||||
|
||||
def step_aggregate_results(samples: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Summarize per-question records into benchmark-level statistics.

    Args:
        samples: Records holding "metrics", "timing", "retrieval",
            "context_tokens" and "category" entries.

    Returns:
        Dictionary with "overall_metrics", "by_category", "latency" and
        "context_stats"; all four are empty dicts when ``samples`` is empty.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("📊 Aggregating Results")
    print(f"{banner}\n")

    if not samples:
        return {
            "overall_metrics": {},
            "by_category": {},
            "latency": {},
            "context_stats": {}
        }

    metric_names = ("f1", "bleu1", "jaccard", "locomo_f1")

    def average(values):
        # Arithmetic mean; 0.0 for an empty list.
        return sum(values) / len(values) if values else 0.0

    # Collect the raw columns first.
    metric_columns = {
        name: [record["metrics"][name] for record in samples]
        for name in metric_names
    }
    search_times = [record["timing"]["search_ms"] for record in samples]
    llm_times = [record["timing"]["llm_ms"] for record in samples]
    doc_counts = [record["retrieval"]["num_docs"] for record in samples]
    char_counts = [record["retrieval"]["context_length"] for record in samples]
    token_counts = [record["context_tokens"] for record in samples]

    # Overall metrics
    overall_metrics = {name: average(metric_columns[name]) for name in metric_names}

    # Per-category metrics, grouped in order of first appearance.
    grouped: Dict[str, Dict[str, List[float]]] = {}
    for record in samples:
        category = record["category"]
        if category not in grouped:
            grouped[category] = {name: [] for name in metric_names}
        for name in metric_names:
            grouped[category][name].append(record["metrics"][name])

    by_category: Dict[str, Dict[str, Any]] = {}
    for category, columns in grouped.items():
        summary: Dict[str, Any] = {"count": len(columns["f1"])}
        for name in metric_names:
            summary[name] = sum(columns[name]) / len(columns[name])
        by_category[category] = summary

    # Latency statistics
    latency = {
        "search": latency_stats(search_times),
        "llm": latency_stats(llm_times)
    }

    # Context statistics
    context_stats = {
        "avg_retrieved_docs": average(doc_counts),
        "avg_context_chars": average(char_counts),
        "avg_context_tokens": average(token_counts)
    }

    return {
        "overall_metrics": overall_metrics,
        "by_category": by_category,
        "latency": latency,
        "context_stats": context_stats
    }
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step 6: Result Saving
|
||||
# ============================================================================
|
||||
|
||||
def step_save_results(
    result: Dict[str, Any],
    output_dir: Optional[str]
) -> str:
    """Write the evaluation result to a timestamped JSON file.

    Saving is best-effort: if the directory cannot be created or the file
    cannot be written, the result is printed to stdout instead so a finished
    run is not lost.

    Args:
        result: Complete result dictionary.
        output_dir: Directory to save into. None selects a "results" folder
            next to this script; a relative path is resolved against this
            script's directory rather than the current working directory.

    Returns:
        Path to the saved file, or "" when saving failed.
    """
    script_dir = Path(__file__).resolve().parent
    if output_dir is None:
        target_dir = script_dir / "results"
    else:
        target_dir = Path(output_dir)
        if not target_dir.is_absolute():
            target_dir = script_dir / target_dir

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = target_dir / f"locomo_{timestamp_str}.json"

    try:
        # Directory creation sits inside the try so that an uncreatable
        # location also falls back to console output instead of crashing.
        target_dir.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
    except Exception as e:
        # Broad on purpose: this is the last step of a long run, and any
        # failure here should still leave the results visible.
        print(f"❌ Failed to save results: {e}")
        print("📊 Printing results to console instead:\n")
        print(json.dumps(result, ensure_ascii=False, indent=2))
        return ""

    print(f"✅ Results saved to: {output_path}\n")
    return str(output_path)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main Orchestration Function
|
||||
# ============================================================================
|
||||
|
||||
|
||||
async def run_locomo_benchmark(
    sample_size: int = 20,
    end_user_id: Optional[str] = None,
    search_type: str = "hybrid",
    search_limit: int = 12,
    context_char_budget: int = 8000,
    reset_group: bool = False,
    skip_ingest: bool = False,
    output_dir: Optional[str] = None,
    max_ingest_messages: Optional[int] = None
) -> Dict[str, Any]:
    """Run the end-to-end LoCoMo benchmark and return its results.

    The pipeline is a sequence of step functions:
        1. Load QA pairs from the dataset file.
        2. Ingest conversations (unless ``skip_ingest`` is True).
        3. Initialize the Neo4j connector, LLM client and embedder.
        4. Retrieve, answer and score every question.
        5. Aggregate the per-question records.
        6. Save the result to a JSON file.

    Ingestion is limited to the first conversation and the QA pairs are
    taken from that same conversation, so every question has matching
    memory available for retrieval.

    Args:
        sample_size: Number of QA pairs to evaluate (from first conversation).
        end_user_id: Database end_user ID for retrieval; when None it is
            taken from LOCOMO_END_USER_ID, then EVAL_END_USER_ID, then the
            built-in default.
        search_type: "keyword", "embedding", or "hybrid".
        search_limit: Max documents to retrieve per query.
        context_char_budget: Max characters for context.
        reset_group: Forwarded to the ingestion step.
        skip_ingest: If True, skip ingestion and use existing data in Neo4j.
        output_dir: Directory to save results (uses default if None).
        max_ingest_messages: Max messages per dialogue to ingest (None = all).

    Returns:
        Result dictionary with parameters, aggregated metrics, latency,
        context statistics and per-question samples. When the dataset cannot
        be loaded, a dictionary with only "error" and "timestamp" is returned.

    Raises:
        FileNotFoundError: If locomo10.json is missing from the dataset
            directory.
    """
    # Precedence: explicit argument > LOCOMO_END_USER_ID > EVAL_END_USER_ID > default.
    if end_user_id is None:
        end_user_id = os.getenv("LOCOMO_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "locomo_benchmark")

    # Model IDs come from the environment, with hard-coded fallbacks.
    llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
    embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")

    # The dataset lives in a "dataset" folder one level above this script.
    dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
    data_path = dataset_dir / "locomo10.json"
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"数据集文件不存在: {data_path}\n"
            f"请将 locomo10.json 放置在: {dataset_dir}"
        )

    # Echo the effective configuration.
    banner = "=" * 60
    print(f"\n{banner}")
    print("🚀 Starting LoCoMo Benchmark Evaluation")
    print(f"{banner}")
    print("📊 Configuration:")
    print(f" Sample size: {sample_size}")
    print(f" End User ID: {end_user_id}")
    print(f" Search type: {search_type}")
    print(f" Search limit: {search_limit}")
    print(f" Context budget: {context_char_budget} chars")
    print(f" Data path: {data_path}")
    if max_ingest_messages:
        print(f" Max ingest messages: {max_ingest_messages} (testing mode)")
    print(f"{banner}\n")

    # Step 1: load the QA pairs; a failure ends the run with an error record.
    try:
        qa_items = step_load_data(data_path, sample_size)
    except Exception as e:
        print(f"❌ Failed to load data: {e}")
        return {
            "error": f"Data loading failed: {e}",
            "timestamp": datetime.now().isoformat()
        }

    # Step 2: ingest data if needed.
    await step_ingest_data(data_path, end_user_id, skip_ingest, reset_group, max_ingest_messages)

    # Step 3: initialize clients.
    connector, llm_client, embedder = step_initialize_clients(llm_id, embedding_id)

    # Step 4: process all questions; the connector is closed whatever happens.
    try:
        samples = await step_process_all_questions(
            qa_items=qa_items,
            end_user_id=end_user_id,
            search_type=search_type,
            search_limit=search_limit,
            context_char_budget=context_char_budget,
            connector=connector,
            embedder=embedder,
            llm_client=llm_client
        )
    finally:
        await connector.close()

    # Step 5: aggregate results.
    aggregated = step_aggregate_results(samples)

    run_params = {
        "end_user_id": end_user_id,
        "search_type": search_type,
        "search_limit": search_limit,
        "context_char_budget": context_char_budget,
        "llm_id": llm_id,
        "embedding_id": embedding_id
    }
    result = {
        "dataset": "locomo",
        "sample_size": len(qa_items),
        "timestamp": datetime.now().isoformat(),
        "params": run_params,
        "overall_metrics": aggregated["overall_metrics"],
        "by_category": aggregated["by_category"],
        "latency": aggregated["latency"],
        "context_stats": aggregated["context_stats"],
        "samples": samples
    }

    # Step 6: save results.
    step_save_results(result, output_dir)

    return result
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments, run the benchmark and print a summary.

    Configuration priority: Command-line args > Environment variables > Code defaults.
    The LOCOMO_* environment variables only supply argument defaults, and are
    read after ``load_dotenv()`` so values from a .env file are honoured.

    When the run evaluated no QA pairs, the summary reports that no metrics
    are available instead of failing on the empty metrics dictionary.
    """
    # Load environment variables first
    load_dotenv()

    # Raw environment values; None when a variable is not set.
    env_sample_size = os.getenv("LOCOMO_SAMPLE_SIZE")
    env_search_limit = os.getenv("LOCOMO_SEARCH_LIMIT")
    env_context_budget = os.getenv("LOCOMO_CONTEXT_CHAR_BUDGET")
    env_output_dir = os.getenv("LOCOMO_OUTPUT_DIR")
    env_skip_ingest = os.getenv("LOCOMO_SKIP_INGEST", "false").lower() in ("true", "1", "yes")

    # Convert to appropriate types with fallback to code defaults
    default_sample_size = int(env_sample_size) if env_sample_size else 20
    default_search_limit = int(env_search_limit) if env_search_limit else 12
    default_context_budget = int(env_context_budget) if env_context_budget else 8000
    default_output_dir = env_output_dir if env_output_dir else None

    parser = argparse.ArgumentParser(
        description="Run LoCoMo benchmark evaluation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--sample_size",
        type=int,
        default=default_sample_size,
        help=f"Number of QA pairs to evaluate (env: LOCOMO_SAMPLE_SIZE={env_sample_size or 'not set'}, 0 for all)"
    )
    parser.add_argument(
        "--end_user_id",
        type=str,
        default=None,
        help="Database end user ID for retrieval (uses LOCOMO_END_USER_ID or EVAL_END_USER_ID if not specified)"
    )
    parser.add_argument(
        "--search_type",
        type=str,
        default="hybrid",
        choices=["keyword", "embedding", "hybrid"],
        help="Search strategy to use"
    )
    parser.add_argument(
        "--search_limit",
        type=int,
        default=default_search_limit,
        help=f"Maximum number of documents to retrieve per query (env: LOCOMO_SEARCH_LIMIT={env_search_limit or 'not set'})"
    )
    parser.add_argument(
        "--context_char_budget",
        type=int,
        default=default_context_budget,
        help=f"Maximum characters for context (env: LOCOMO_CONTEXT_CHAR_BUDGET={env_context_budget or 'not set'})"
    )
    parser.add_argument(
        "--reset_group",
        action="store_true",
        help="Clear and re-ingest data (not implemented)"
    )
    # store_true with an environment-derived default: the flag can switch
    # skipping on from the command line, but cannot switch it back off.
    parser.add_argument(
        "--skip_ingest",
        action="store_true",
        default=env_skip_ingest,
        help=f"Skip data ingestion and use existing data in Neo4j (env: LOCOMO_SKIP_INGEST={os.getenv('LOCOMO_SKIP_INGEST', 'false')})"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=default_output_dir,
        help=f"Directory to save results (env: LOCOMO_OUTPUT_DIR={env_output_dir or 'not set'})"
    )
    parser.add_argument(
        "--max_ingest_messages",
        type=int,
        default=None,
        help="Maximum messages per dialogue to ingest (for testing, default: all messages)"
    )

    args = parser.parse_args()

    # Run benchmark
    result = asyncio.run(run_locomo_benchmark(
        sample_size=args.sample_size,
        end_user_id=args.end_user_id,
        search_type=args.search_type,
        search_limit=args.search_limit,
        context_char_budget=args.context_char_budget,
        reset_group=args.reset_group,
        skip_ingest=args.skip_ingest,
        output_dir=args.output_dir,
        max_ingest_messages=args.max_ingest_messages
    ))

    # Print summary
    print(f"\n{'='*60}")

    # Check if there was an error
    if 'error' in result:
        print("❌ Benchmark Failed!")
        print(f"{'='*60}")
        print(f"Error: {result['error']}")
        return

    print("🎉 Benchmark Complete!")
    print(f"{'='*60}")
    print("📊 Final Results:")
    print(f" Sample size: {result.get('sample_size', 0)}")

    # overall_metrics is an empty dict when no QA pairs were evaluated;
    # indexing it directly would raise KeyError at the very end of the run.
    overall = result.get('overall_metrics') or {}
    if overall:
        print(f" F1: {overall['f1']:.3f}")
        print(f" BLEU-1: {overall['bleu1']:.3f}")
        print(f" Jaccard: {overall['jaccard']:.3f}")
        print(f" LoCoMo F1: {overall['locomo_f1']:.3f}")
    else:
        print(" No metrics available (no QA pairs were evaluated)")

    if result.get('context_stats'):
        print("\n📈 Context Statistics:")
        print(f" Avg retrieved docs: {result['context_stats']['avg_retrieved_docs']:.1f}")
        print(f" Avg context chars: {result['context_stats']['avg_context_chars']:.0f}")
        print(f" Avg context tokens: {result['context_stats']['avg_context_tokens']:.0f}")

    if result.get('latency'):
        print("\n⏱️ Latency Statistics:")
        print(f" Search - Mean: {result['latency']['search']['mean']:.1f}ms, "
              f"P50: {result['latency']['search']['p50']:.1f}ms, "
              f"P95: {result['latency']['search']['p95']:.1f}ms")
        print(f" LLM - Mean: {result['latency']['llm']['mean']:.1f}ms, "
              f"P50: {result['latency']['llm']['p50']:.1f}ms, "
              f"P95: {result['latency']['llm']['p95']:.1f}ms")

    if result.get('by_category'):
        print("\n📂 Results by Category:")
        for cat, metrics in result['by_category'].items():
            print(f" {cat}:")
            print(f" Count: {metrics['count']}")
            print(f" F1: {metrics['f1']:.3f}")
            print(f" LoCoMo F1: {metrics['locomo_f1']:.3f}")
            print(f" Jaccard: {metrics['jaccard']:.3f}")

    print(f"\n{'='*60}\n")


if __name__ == "__main__":
    main()
|
||||
@@ -1,225 +0,0 @@
|
||||
"""
|
||||
LoCoMo-specific metric calculations.
|
||||
|
||||
This module provides clean, simplified implementations of metrics used for
|
||||
LoCoMo benchmark evaluation, including text normalization and F1 score variants.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any
|
||||
|
||||
|
||||
def normalize_text(text: str) -> str:
    """Reduce text to the canonical form used for LoCoMo scoring.

    The input is lower-cased, then commas, the stop words "a", "an", "the"
    and "and", and every character that is neither a word character nor
    whitespace are each replaced by a space. Finally runs of whitespace are
    collapsed and the ends trimmed.

    Args:
        text: Value to normalize; None is treated as an empty string and
            other non-string values are converted with ``str``.

    Returns:
        The normalized string.

    Examples:
        >>> normalize_text("The cat, and the dog")
        'cat dog'
        >>> normalize_text("Hello, World!")
        'hello world'
    """
    cleaned = "" if text is None else str(text)
    cleaned = cleaned.lower()

    # Order matters: stop words are removed while punctuation still marks
    # the word boundaries, and only then is the punctuation dropped.
    for pattern in (r"[\,]", r"\b(a|an|the|and)\b", r"[^\w\s]"):
        cleaned = re.sub(pattern, " ", cleaned)

    return " ".join(cleaned.split())
|
||||
|
||||
|
||||
def locomo_f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between a predicted and a reference answer.

    Both answers are passed through ``normalize_text`` and split on
    whitespace; the resulting tokens are compared as sets, so repeated
    tokens count once.

    Args:
        prediction: Model's predicted answer (None is treated as empty).
        ground_truth: Correct answer (None is treated as empty).

    Returns:
        F1 score between 0.0 and 1.0; 0.0 when either side has no tokens or
        the two sides share none.

    Examples:
        >>> locomo_f1_score("Paris", "Paris")
        1.0
        >>> locomo_f1_score("The cat", "cat")
        1.0
        >>> locomo_f1_score("dog", "cat")
        0.0
    """
    predicted_words = set(
        normalize_text("" if prediction is None else str(prediction)).split()
    )
    reference_words = set(
        normalize_text("" if ground_truth is None else str(ground_truth)).split()
    )

    # Nothing to compare on at least one side.
    if not predicted_words or not reference_words:
        return 0.0

    shared = len(predicted_words & reference_words)
    # No overlap means precision and recall are both zero.
    if shared == 0:
        return 0.0

    precision = shared / len(predicted_words)
    recall = shared / len(reference_words)
    return 2 * precision * recall / (precision + recall)
|
||||
|
||||
|
||||
def locomo_multi_f1(prediction: str, ground_truth: str) -> float:
    """LoCoMo F1 for answers that list several comma-separated items.

    Both strings are split on commas. Every reference item is scored against
    each predicted item with ``locomo_f1_score`` and keeps its best score;
    the result is the mean of those best scores.

    Args:
        prediction: Model's predicted answer (may contain multiple
            comma-separated answers).
        ground_truth: Correct answer (may contain multiple comma-separated
            answers).

    Returns:
        Average best-match F1 over the reference items (0.0 to 1.0); 0.0
        when either side has no non-empty item.

    Examples:
        >>> locomo_multi_f1("Paris, London", "Paris, London")
        1.0
        >>> locomo_multi_f1("Paris", "Paris, London")
        0.5
        >>> locomo_multi_f1("Paris, Berlin", "Paris, London")
        0.5
    """
    def split_items(raw):
        # Comma-separated parts, trimmed, with blank parts dropped.
        as_text = "" if raw is None else str(raw)
        return [part.strip() for part in as_text.split(',') if part.strip()]

    candidates = split_items(prediction)
    references = split_items(ground_truth)

    if not (candidates and references):
        return 0.0

    best_per_reference = [
        max(locomo_f1_score(candidate, reference) for candidate in candidates)
        for reference in references
    ]
    return sum(best_per_reference) / len(best_per_reference)
|
||||
|
||||
|
||||
def get_category_name(item: Dict[str, Any]) -> str:
    """Extract and normalize the category name of a QA item.

    Fields are consulted in this order, and the first usable one wins:
        1. "cat" - used when it is a non-blank string.
        2. "category" - an int is mapped through the numeric table (unmapped
           numbers give "unknown"); a non-blank string is treated as a name.
        3. "type" - used when it is a non-blank string.

    Category names are matched case-insensitively, and hyphens, underscores
    and repeated spaces are treated alike, so "open-domain", "Open Domain"
    and "open_domain" all give "Open Domain". A name that matches no alias
    is returned stripped but otherwise unchanged.

    Category mapping:
        - 1 or "multi-hop" -> "Multi-Hop"
        - 2 or "temporal" -> "Temporal"
        - 3 or "open domain" -> "Open Domain"
        - 4 or "single-hop" -> "Single-Hop"

    Args:
        item: QA item dictionary containing category information.

    Returns:
        Standardized category name, or "unknown" if none is found.

    Examples:
        >>> get_category_name({"category": 1})
        'Multi-Hop'
        >>> get_category_name({"cat": "temporal"})
        'Temporal'
        >>> get_category_name({"type": "Single-Hop"})
        'Single-Hop'
        >>> get_category_name({"category": "open-domain"})
        'Open Domain'
    """
    numeric_names = {
        1: "Multi-Hop",
        2: "Temporal",
        3: "Open Domain",
        4: "Single-Hop",
    }

    # Keys are in normalized form (lower case, single spaces, no "-" or "_"),
    # so each entry covers the hyphenated, underscored and spaced spellings.
    aliases = {
        "single hop": "Single-Hop",
        "singlehop": "Single-Hop",
        "multi hop": "Multi-Hop",
        "multihop": "Multi-Hop",
        "open domain": "Open Domain",
        "opendomain": "Open Domain",
        "temporal": "Temporal",
    }

    def canonical(raw: str) -> str:
        # Map a free-form name to its standard spelling, if it has one.
        name = raw.strip()
        key = " ".join(name.lower().replace("-", " ").replace("_", " ").split())
        return aliases.get(key, name)

    def is_named(value: Any) -> bool:
        return isinstance(value, str) and bool(value.strip())

    # "cat" field first (string category)
    cat = item.get("cat")
    if is_named(cat):
        return canonical(cat)

    # "category" field (can be int or string)
    category = item.get("category")
    if isinstance(category, int):
        return numeric_names.get(category, "unknown")
    if is_named(category):
        return canonical(category)

    # "type" field as fallback
    cat_type = item.get("type")
    if is_named(cat_type):
        return canonical(cat_type)

    return "unknown"
|
||||
@@ -1,864 +0,0 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load the main .env first; the evaluation file below is allowed to override it.
load_dotenv()

# Overlay benchmark settings from ".env.evaluation", looked up in the parent
# of this file's directory. Values from that file replace already-set ones.
eval_config_path = Path(__file__).resolve().parents[1] / ".env.evaluation"
if eval_config_path.exists():
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}")

# Group ID for this run; "locomo_test" is used when EVAL_GROUP_ID is unset.
group_id = os.environ.get("EVAL_GROUP_ID", "locomo_test")
print(f"✅ 使用配置的 group_id: {group_id}")
|
||||
|
||||
# _loc_normalize is defined first because the metric helpers below depend on it.
def _loc_normalize(text: str) -> str:
    """Normalize text for token-overlap metrics.

    The input is lower-cased, the words "a", "an", "the" and "and" are
    removed, every character that is neither a word character nor
    whitespace becomes a space, and whitespace runs collapse to one space.
    ``None`` is treated as the empty string; any other value goes through
    ``str()`` first.
    """
    lowered = ("" if text is None else str(text)).lower()
    # Commas are blanked before the article pass, in the same order as before.
    without_commas = re.sub(r"[\,]", " ", lowered)
    without_articles = re.sub(r"\b(a|an|the|and)\b", " ", without_commas)
    without_punct = re.sub(r"[^\w\s]", " ", without_articles)
    return " ".join(without_punct.split())
|
||||
|
||||
# Prefer the shared metric implementations from metrics.py.
try:
    from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
    print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e:
    print(f"❌ 从 metrics.py 导入失败: {e}")

    # Local fallbacks, defined only when the shared module cannot be imported.
    def f1_score(pred: str, ref: str) -> float:
        """Return the F1 score over the sets of normalized tokens.

        Two empty inputs score 1.0; exactly one empty input scores 0.0.
        """
        pred_vocab = set(_loc_normalize("" if pred is None else str(pred)).split())
        ref_vocab = set(_loc_normalize("" if ref is None else str(ref)).split())
        if not pred_vocab and not ref_vocab:
            return 1.0
        if not (pred_vocab and ref_vocab):
            return 0.0
        overlap = len(pred_vocab & ref_vocab)
        precision = overlap / len(pred_vocab)
        recall = overlap / len(ref_vocab)
        if precision + recall == 0:
            return 0.0
        return 2 * precision * recall / (precision + recall)

    def bleu1(pred: str, ref: str) -> float:
        """Return clipped unigram precision multiplied by a brevity penalty.

        An empty prediction scores 0.0. The penalty is 1.0 when the
        prediction is longer than the reference, otherwise
        ``exp(1 - ref_len / pred_len)``.
        """
        hyp_tokens = _loc_normalize("" if pred is None else str(pred)).split()
        gold_tokens = _loc_normalize("" if ref is None else str(ref)).split()
        if not hyp_tokens:
            return 0.0

        gold_counts = {}
        for token in gold_tokens:
            gold_counts[token] = gold_counts.get(token, 0) + 1
        hyp_counts = {}
        for token in hyp_tokens:
            hyp_counts[token] = hyp_counts.get(token, 0) + 1

        # Each predicted token is credited at most as often as it occurs in the reference.
        clipped = sum(
            min(count, gold_counts.get(token, 0))
            for token, count in hyp_counts.items()
        )
        hyp_len = len(hyp_tokens)
        gold_len = len(gold_tokens)
        precision = clipped / max(hyp_len, 1)

        # hyp_len is at least 1 here because of the early return above.
        if hyp_len > gold_len:
            brevity_penalty = 1.0
        else:
            brevity_penalty = math.exp(1 - gold_len / max(hyp_len, 1))
        return brevity_penalty * precision

    def jaccard(pred: str, ref: str) -> float:
        """Return |intersection| / |union| of the normalized token sets.

        Two empty inputs score 1.0; exactly one empty input scores 0.0.
        """
        pred_vocab = set(_loc_normalize("" if pred is None else str(pred)).split())
        ref_vocab = set(_loc_normalize("" if ref is None else str(ref)).split())
        if not pred_vocab and not ref_vocab:
            return 1.0
        if not (pred_vocab and ref_vocab):
            return 0.0
        return len(pred_vocab & ref_vocab) / len(pred_vocab | ref_vocab)
|
||||
|
||||
# Prefer the LoCoMo-specific metrics from qwen_search_eval.py.
try:
    from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
    print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e:
    print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")

    # Local fallbacks for the LoCoMo-specific helpers.
    def _resolve_relative_times(text: str, anchor: datetime) -> str:
        """Replace simple English relative-time phrases with ISO dates.

        Handles "today", "yesterday", "tomorrow", "<N> days ago",
        "in <N> days", "last week" and "next week" (case-insensitive),
        each measured from ``anchor``. ``None`` is treated as "".
        """
        def _iso(day_offset: int) -> str:
            return (anchor + timedelta(days=day_offset)).date().isoformat()

        resolved = "" if text is None else str(text)
        # Substitutions run in the same order as the original implementation.
        for pattern, offset in (
            (r"\btoday\b", 0),
            (r"\byesterday\b", -1),
            (r"\btomorrow\b", 1),
        ):
            resolved = re.sub(pattern, _iso(offset), resolved, flags=re.IGNORECASE)

        resolved = re.sub(
            r"\b(\d+)\s+days\s+ago\b",
            lambda m: _iso(-int(m.group(1))),
            resolved,
            flags=re.IGNORECASE,
        )
        resolved = re.sub(
            r"\bin\s+(\d+)\s+days\b",
            lambda m: _iso(int(m.group(1))),
            resolved,
            flags=re.IGNORECASE,
        )

        for pattern, offset in (
            (r"\blast\s+week\b", -7),
            (r"\bnext\s+week\b", 7),
        ):
            resolved = re.sub(pattern, _iso(offset), resolved, flags=re.IGNORECASE)
        return resolved

    def loc_f1_score(prediction: str, ground_truth: str) -> float:
        """Return set-based token F1; 0.0 when either side has no tokens."""
        pred_tokens = _loc_normalize(prediction).split()
        truth_tokens = _loc_normalize(ground_truth).split()
        if not (pred_tokens and truth_tokens):
            return 0.0
        pred_vocab = set(pred_tokens)
        truth_vocab = set(truth_tokens)
        overlap = len(pred_vocab & truth_vocab)
        precision = overlap / len(pred_vocab)
        recall = overlap / len(truth_vocab)
        if precision + recall > 0:
            return 2 * precision * recall / (precision + recall)
        return 0.0

    def loc_multi_f1(prediction: str, ground_truth: str) -> float:
        """Score comma-separated answers.

        Both inputs are split on commas. For each ground-truth part the
        best ``loc_f1_score`` against any predicted part is taken, and the
        mean of those best scores is returned (0.0 if either side is empty).
        """
        predictions = [part.strip() for part in str(prediction).split(',') if part.strip()]
        ground_truths = [part.strip() for part in str(ground_truth).split(',') if part.strip()]
        if not predictions or not ground_truths:
            return 0.0
        best_per_truth = [
            max(loc_f1_score(candidate, truth) for candidate in predictions)
            for truth in ground_truths
        ]
        return sum(best_per_truth) / len(best_per_truth)
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 8000) -> str:
    """Pick the contexts most relevant to the question within a character budget.

    Args:
        contexts: Candidate context strings, in retrieval order.
        question: The question whose keywords drive the scoring.
        max_chars: Budget for the summed lengths of the selected contexts.
            The "\n\n" joiners and the truncation marker are not counted.

    Returns:
        The selected contexts joined by blank lines, or "" when ``contexts``
        is empty.
    """
    if not contexts:
        return ""

    # Keywords: question words longer than two characters that are not stop words.
    stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
    keywords = set(re.findall(r'\b\w+\b', question.lower()))
    keywords = {word for word in keywords if word not in stop_words and len(word) > 2}

    print(f"🔍 问题关键词: {keywords}")

    # Score every context.
    ranked = []
    for position, ctx in enumerate(contexts):
        haystack = ctx.lower()
        hits = [word for word in keywords if word in haystack]
        # Two points per (substring) occurrence of each matched keyword.
        relevance = sum(haystack.count(word) * 2 for word in hits)

        # Length bonus: medium-sized contexts are preferred over very long ones.
        size = len(ctx)
        if 100 < size < 2000:
            relevance += 5
        elif size >= 2000:
            relevance += 2

        # The first three contexts get a bonus for their retrieval rank.
        if position < 3:
            relevance += 3

        ranked.append((relevance, ctx, len(hits)))

    # Stable sort: ties keep their retrieval order.
    ranked.sort(key=lambda entry: entry[0], reverse=True)

    print("📊 上下文相关性分析:")
    for relevance, ctx, hit_count in ranked[:5]:  # only the top five are logged
        print(f" - 得分: {relevance}, 关键词匹配: {hit_count}, 长度: {len(ctx)}")

    selected = []
    total_chars = 0
    for relevance, ctx, _hit_count in ranked:
        if total_chars + len(ctx) <= max_chars:
            selected.append(ctx)
            total_chars += len(ctx)
            continue

        # This context does not fit. If it scored high and there is room left,
        # keep only its lines that mention a keyword.
        if relevance > 10 and total_chars < max_chars - 500:
            budget = max_chars - total_chars
            kept_lines = []
            used = 0
            for line in ctx.split('\n'):
                line_lower = line.lower()
                mentions_keyword = any(word in line_lower for word in keywords)
                if mentions_keyword and used < budget - 100:
                    kept_lines.append(line)
                    used += len(line)

            if kept_lines:
                snippet = '\n'.join(kept_lines)
                if len(snippet) > 100:  # skip snippets too short to be useful
                    selected.append(snippet + "\n[相关内容截断...]")
                    total_chars += len(snippet)

        # NOTE(review): indentation was lost in the source this was recovered
        # from; the break is assumed to end selection at the first context
        # that does not fit. Confirm against the original file.
        break

    result = "\n\n".join(selected)
    print(f"✅ 智能选择: {len(selected)}个上下文, 总长度: {total_chars}字符")
    return result
|
||||
|
||||
|
||||
def get_dynamic_search_params(question: str, question_index: int, total_questions: int):
    """Derive retrieval parameters from question complexity and run progress.

    Args:
        question: The question text.
        question_index: Zero-based position of the question in the run.
        total_questions: Number of questions in the run. A value of 0 is
            treated as "no progress" instead of raising ZeroDivisionError.

    Returns:
        dict with integer "limit" (at least 8) and integer "max_chars"
        (8000 plus up to 4000, shrinking as the run progresses).
    """
    lowered = question.lower()
    word_count = len(question.split())
    # Substring checks, so e.g. "ago" also matches inside longer words.
    has_temporal = any(marker in lowered for marker in ('when', 'date', 'time', 'ago'))
    has_multi_hop = any(marker in lowered for marker in ('and', 'both', 'also', 'while'))

    # Fraction of the run already completed; guarded against an empty run.
    progress_factor = question_index / total_questions if total_questions else 0.0

    base_limit = 12
    if has_temporal and has_multi_hop:
        base_limit = 20
    elif word_count > 8:
        base_limit = 16

    # Tighten the retrieval limit by up to 30% as the run progresses.
    adjusted_limit = max(8, int(base_limit * (1 - progress_factor * 0.3)))

    # The character budget shrinks from 12000 towards 8000 over the run.
    max_chars = 8000 + 4000 * (1 - progress_factor)

    return {
        "limit": adjusted_limit,
        "max_chars": int(max_chars),
    }
|
||||
|
||||
|
||||
class EnhancedEvaluationMonitor:
    """Track per-question performance and decide when to reset connections."""

    def __init__(self, reset_interval=5, performance_threshold=0.6):
        self.reset_interval = reset_interval
        self.performance_threshold = performance_threshold
        self.question_count = 0
        self.consecutive_low_scores = 0
        self.performance_history = []
        self.recent_f1_scores = []

    def should_reset_connections(self, current_f1=None):
        """Return True on every ``reset_interval``-th question or after two
        consecutive calls with a score below the threshold."""
        # Periodic reset; the low-score counter is left untouched on this path.
        if self.question_count % self.reset_interval == 0:
            return True

        is_low = current_f1 is not None and current_f1 < self.performance_threshold
        if not is_low:
            self.consecutive_low_scores = 0
            return False

        self.consecutive_low_scores += 1
        if self.consecutive_low_scores < 2:
            return False

        # Second low score in a row: reset and start counting again.
        print("🚨 连续低分,触发紧急重置")
        self.consecutive_low_scores = 0
        return True

    def record_performance(self, question_index, metrics, context_length, retrieved_docs):
        """Append one history entry and keep the last five F1 scores."""
        entry = {
            'index': question_index,
            'metrics': metrics,
            'context_length': context_length,
            'retrieved_docs': retrieved_docs,
            'timestamp': time.time(),
        }
        self.performance_history.append(entry)

        self.recent_f1_scores.append(metrics['f1'])
        if len(self.recent_f1_scores) > 5:
            del self.recent_f1_scores[0]

    def get_recent_performance(self):
        """Return the mean of the stored recent F1 scores, or 0.5 if none."""
        scores = self.recent_f1_scores
        return sum(scores) / len(scores) if scores else 0.5

    def get_performance_trend(self):
        """Compare the last five F1 scores with the five before them.

        Returns "degrading" when the recent mean is below 80% of the earlier
        mean, "improving" when it is above 110%, and "stable" otherwise or
        when either window holds fewer than two scores.
        """
        history = self.performance_history
        if len(history) < 2:
            return "stable"

        recent = [entry['metrics']['f1'] for entry in history[-5:]]
        earlier = [entry['metrics']['f1'] for entry in history[-10:-5]]
        if min(len(recent), len(earlier)) < 2:
            return "stable"

        recent_avg = sum(recent) / len(recent)
        earlier_avg = sum(earlier) / len(earlier)
        if recent_avg < earlier_avg * 0.8:
            return "degrading"
        if recent_avg > earlier_avg * 1.1:
            return "improving"
        return "stable"
|
||||
|
||||
|
||||
def get_enhanced_search_params(question: str, question_index: int, total_questions: int, recent_performance: float):
    """Adjust the dynamic retrieval parameters using recent performance.

    Starts from ``get_dynamic_search_params`` and then:
    - widens retrieval when ``recent_performance`` is below 0.5,
    - tightens it when ``recent_performance`` is above 0.8,
    - applies a further adjustment when the question lies within 0.2 of the
      midpoint of the run.

    Returns:
        dict with "limit" and "max_chars".
    """
    params = get_dynamic_search_params(question, question_index, total_questions)

    if recent_performance < 0.5:
        # Recent answers were poor: retrieve more and allow a larger context.
        params.update(
            limit=min(params["limit"] + 5, 25),
            max_chars=min(params["max_chars"] + 2000, 12000),
        )
        print(f"📈 性能自适应:增加检索范围 (limit={params['limit']}, max_chars={params['max_chars']})")
    elif recent_performance > 0.8:
        # Recent answers were good: retrieve less to favour precision.
        params.update(
            limit=max(params["limit"] - 2, 8),
            max_chars=max(params["max_chars"] - 1000, 6000),
        )
        print(f"🎯 性能自适应:提高检索精度 (limit={params['limit']}, max_chars={params['max_chars']})")

    # Mid-run handling: distance of the current position from the midpoint.
    distance_from_middle = abs(question_index / total_questions - 0.5)
    if distance_from_middle < 0.2:
        print("🎯 中间阶段:使用更精确的检索策略")
        params["limit"] = max(params["limit"] - 2, 10)
        params["max_chars"] = max(params["max_chars"] - 1000, 7000)

    return params
|
||||
|
||||
|
||||
def enhanced_context_selection(contexts: List[str], question: str, question_index: int, total_questions: int, max_chars: int = 8000) -> str:
    """Select contexts, filtering more strictly for questions near mid-run.

    When the question lies within 0.2 of the midpoint of the run, contexts
    are kept only if their relevance score is at least 3 (3 points per
    question keyword found, plus 2 if the context contains a digit). The
    remaining contexts are handed to ``smart_context_selection``.

    Returns:
        The selected context text, or "" when ``contexts`` is empty.
    """
    if not contexts:
        return ""

    distance_from_middle = abs(question_index / total_questions - 0.5)
    if distance_from_middle < 0.2:
        print("🎯 中间阶段:使用严格上下文筛选")

        # Keywords: question words longer than two characters that are not stop words.
        stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
        keywords = set(re.findall(r'\b\w+\b', question.lower()))
        keywords = {word for word in keywords if word not in stop_words and len(word) > 2}

        survivors = []
        for candidate in contexts:
            haystack = candidate.lower()
            relevance = 3 * sum(1 for word in keywords if word in haystack)

            # Contexts containing digits (numbers, dates) get a small bonus.
            if any(char.isdigit() for char in candidate):
                relevance += 2

            if relevance < 3:
                print(f" - 过滤低分上下文: 得分={relevance}")
                continue
            survivors.append(candidate)

        contexts = survivors
        print(f"🔍 严格筛选后保留 {len(contexts)} 个上下文")

    # Hand over to the keyword-based selection for the final choice.
    return smart_context_selection(contexts, question, max_chars)
|
||||
|
||||
|
||||
async def run_enhanced_evaluation():
    """Run the enhanced LoCoMo QA evaluation and return the collected results.

    For each sampled QA item the function retrieves context through
    ``run_hybrid_search``, narrows it with ``enhanced_context_selection``,
    asks the LLM for an answer and scores it with F1, BLEU-1, Jaccard and
    the LoCoMo F1 variant. The Neo4j connector is closed and re-created
    whenever ``EnhancedEvaluationMonitor.should_reset_connections`` says so.

    Environment variables read: ``EVAL_LLM_ID``, ``EVAL_EMBEDDING_ID``,
    ``EVAL_CONFIG_ID`` (optional) and ``LOCOMO_SAMPLE_SIZE`` (default 20).

    Returns:
        dict with keys "questions", "overall_metrics", "category_metrics",
        "retrieval_stats", "performance_trend", "timestamp",
        "enhanced_strategy", "reset_interval" and
        "total_questions_processed". Averages are taken over the questions
        that were actually scored, so a run that aborts part-way still
        reports meaningful partial results.

    Raises:
        FileNotFoundError: If ``dataset/locomo10.json`` is missing.
        ValueError: If a model ID is not a valid UUID when the temporary
            config is built.
    """
    from uuid import UUID
    from datetime import datetime

    # Project imports are resolved at call time, as in the original.
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector
    from app.repositories.neo4j.graph_search import search_graph_by_embedding  # kept from the original; not used below
    from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
    from app.core.models.base import RedBearModelConfig
    from app.core.memory.utils.llm.llm_utils import get_llm_client
    from app.core.memory.utils.config.config_utils import get_embedder_config
    from app.schemas.memory_config_schema import MemoryConfig
    from app.services.memory_config_service import MemoryConfigService

    # Model IDs from the environment, with the original hard-coded fallbacks.
    llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
    embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")

    # The dataset lives in the shared "dataset" directory one level up.
    data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"数据集文件不存在: {data_path}\n"
            f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/"
        )
    print(f"✅ 找到数据文件: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # Flatten the QA pairs of every conversation into one list.
    qa_items = []
    if isinstance(raw, list):
        for entry in raw:
            qa_items.extend(entry.get("qa", []))
    else:
        qa_items.extend(raw.get("qa", []))

    # Number of questions to evaluate, configurable via the environment.
    sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20"))
    items = qa_items[:sample_size]
    print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)")

    monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)

    # Numeric LoCoMo category codes and their display names.
    category_names = {1: "Multi-Hop", 2: "Temporal", 3: "Open Domain", 4: "Single-Hop"}

    # Database session used to resolve the LLM and embedder configuration.
    from app.db import get_db
    db = next(get_db())

    try:
        llm = get_llm_client(llm_id, db)

        cfg_dict = get_embedder_config(embedding_id, db)
        # Constructed as in the original, although nothing below reads it.
        embedder = OpenAIEmbedderClient(
            model_config=RedBearModelConfig.model_validate(cfg_dict)
        )

        # MemoryConfig for the search call: loaded from the database when a
        # config ID is given, otherwise a temporary stand-in is built.
        config_id = os.getenv("EVAL_CONFIG_ID")
        if config_id:
            print(f"📋 从数据库加载配置 ID: {config_id}")
            memory_config_service = MemoryConfigService(db)
            memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test")
        else:
            print("📋 创建临时测试配置")
            try:
                embedding_uuid = UUID(embedding_id)
                llm_uuid = UUID(llm_id)
            except ValueError as e:
                raise ValueError(f"无效的 UUID 格式: {e}") from e

            memory_config = MemoryConfig(
                config_id=1,  # placeholder ID
                config_name="locomo_test_config",
                workspace_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder workspace
                workspace_name="test_workspace",
                tenant_id=UUID("00000000-0000-0000-0000-000000000000"),  # placeholder tenant
                embedding_model_id=embedding_uuid,
                embedding_model_name="test_embedding",
                llm_model_id=llm_uuid,
                llm_model_name="test_llm",
                storage_type="neo4j",
                chunker_strategy="RecursiveChunker",
                reflexion_enabled=False,
                reflexion_iteration_period=3,
                reflexion_range="partial",
                reflexion_baseline="Time",
                loaded_at=datetime.now(),
            )

        print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}")

        connector = Neo4jConnector()

        results = {
            "questions": [],
            "overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
            "category_metrics": {},
            "retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
            "performance_trend": "stable",
            "timestamp": datetime.now().isoformat(),
            "enhanced_strategy": True,
        }

        total_f1 = 0.0
        total_bleu1 = 0.0
        total_jaccard = 0.0
        total_loc_f1 = 0.0
        total_context_length = 0
        total_retrieved_docs = 0
        scored_questions = 0  # questions whose metrics went into the totals
        category_stats = {}

        try:
            for i, item in enumerate(items):
                monitor.question_count += 1

                # Recent performance drives both the reset decision and the search parameters.
                recent_performance = monitor.get_recent_performance()

                should_reset = monitor.should_reset_connections(current_f1=recent_performance)
                if should_reset and i > 0:
                    print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
                    await connector.close()
                    connector = Neo4jConnector()  # fresh connection
                    print("✅ 连接重置完成")

                q = item.get("question", "")
                ref = item.get("answer", "")
                ref_str = str(ref) if ref is not None else ""

                print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
                print(f"✅ 真实答案: {ref_str}")

                # Map the numeric category code to its name; anything else is "Unknown".
                raw_category = item.get("category")
                if isinstance(raw_category, (int, float)):
                    category = category_names.get(raw_category, "Unknown")
                else:
                    category = "Unknown"

                search_params = get_enhanced_search_params(q, i, len(items), recent_performance)
                search_limit = search_params["limit"]
                max_chars = search_params["max_chars"]

                print(f"🏷️ 类别: {category}, 检索参数: limit={search_limit}, max_chars={max_chars}")

                # Retrieval through the project's hybrid search.
                t0 = time.time()
                contexts_all = []

                try:
                    # Imported here so that an import failure is reported per
                    # question by the handler below, as in the original.
                    from app.core.memory.src.search import run_hybrid_search

                    print("🔀 使用混合搜索服务(旧版本)...")
                    # NOTE(review): this log prints group_id, but the query
                    # below uses end_user_id="locomo_sk"; and limit is fixed at
                    # 20 although search_limit is computed above and recorded in
                    # the results. Behavior is left unchanged; confirm intent.
                    print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid")
                    print(f"📍 查询文本: {q}")

                    search_results = await run_hybrid_search(
                        query_text=q,
                        search_type="hybrid",
                        end_user_id="locomo_sk",
                        limit=20,
                        include=["statements", "chunks", "entities", "summaries"],
                        output_path=None,
                        memory_config=memory_config,
                        rerank_alpha=0.6,  # BM25 weight, per the original comment
                        use_forgetting_rerank=False,
                        use_llm_rerank=False,
                    )

                    # Hybrid searches nest their output under "reranked_results";
                    # otherwise the result lists sit at the top level.
                    if "reranked_results" in search_results:
                        result_source = search_results["reranked_results"]
                    else:
                        result_source = search_results
                    chunks = result_source.get("chunks", [])
                    statements = result_source.get("statements", [])
                    entities = result_source.get("entities", [])
                    summaries = result_source.get("summaries", [])

                    print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")

                    # Build the context list from chunks, statements and summaries, in that order.
                    for c in chunks:
                        content = str(c.get("content", "")).strip()
                        if content:
                            contexts_all.append(content)

                    for s in statements:
                        stmt_text = str(s.get("statement", "")).strip()
                        if stmt_text:
                            contexts_all.append(stmt_text)

                    for sm in summaries:
                        summary_text = str(sm.get("summary", "")).strip()
                        if summary_text:
                            contexts_all.append(summary_text)

                    # Entity summary: at most the three best-scored entities, to limit noise.
                    scored = [e for e in entities if e.get("score") is not None]
                    top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
                    if top_entities:
                        summary_lines = []
                        for e in top_entities:
                            name = str(e.get("name", "")).strip()
                            etype = str(e.get("entity_type", "")).strip()
                            score = e.get("score")
                            if name:
                                meta = []
                                if etype:
                                    meta.append(f"type={etype}")
                                if isinstance(score, (int, float)):
                                    meta.append(f"score={score:.3f}")
                                summary_lines.append(f"EntitySummary: {name}{(' [' + ' '.join(meta) + ']') if meta else ''}")
                        if summary_lines:
                            contexts_all.append("\n".join(summary_lines))

                    print(f"📊 有效上下文数量: {len(contexts_all)}")
                except Exception as e:
                    # Best effort: a failed retrieval is logged and the
                    # question is answered without context.
                    print(f"❌ 检索失败: {e}")
                    import traceback
                    print(f"详细错误信息:\n{traceback.format_exc()}")
                    contexts_all = []

                t1 = time.time()
                search_time = (t1 - t0) * 1000

                context_text = ""
                if contexts_all:
                    context_text = enhanced_context_selection(contexts_all, q, i, len(items), max_chars=max_chars)

                    # Safety net: the selection does not count separators, so
                    # the text can still exceed the budget.
                    if len(context_text) > max_chars:
                        print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
                        context_text = context_text[:max_chars] + "\n\n[最终截断...]"

                    # Resolve relative time phrases against a fixed anchor date for reproducibility.
                    anchor_date = datetime(2023, 5, 8)
                    context_text = _resolve_relative_times(context_text, anchor_date)
                    context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text

                    print(f"📝 最终上下文长度: {len(context_text)} 字符")

                    # Preview the first three retrieved contexts.
                    print("🔍 上下文预览:")
                    for j, context in enumerate(contexts_all[:3]):
                        preview = context[:150].replace('\n', ' ')
                        print(f" 上下文{j+1}: {preview}...")

                    # Debug aid: is the reference answer present verbatim in any retrieved context?
                    if ref_str and ref_str.strip():
                        answer_found = any(ref_str.lower() in ctx.lower() for ctx in contexts_all)
                        print(f"🔍 调试:答案 '{ref_str}' 是否在检索到的上下文中? {'✅ 是' if answer_found else '❌ 否'}")
                else:
                    print("❌ 没有检索到有效上下文")
                    context_text = "No relevant context found."

                # Ask the LLM.
                messages = [
                    {"role": "system", "content": (
                        "You are a precise QA assistant. Answer following these rules:\n"
                        "1) Extract the EXACT information mentioned in the context\n"
                        "2) For time questions: calculate actual dates from relative times\n"
                        "3) Return ONLY the answer text in simplest form\n"
                        "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
                        "5) If no clear answer found, respond with 'Unknown'"
                    )},
                    {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
                ]

                t2 = time.time()
                try:
                    resp = await llm.chat(messages=messages)
                    # Accept either an object with .content or an OpenAI-style dict.
                    pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else "Unknown")
                except Exception as e:
                    # Best effort: a failed generation is scored as "Unknown".
                    print(f"❌ LLM 生成失败: {e}")
                    pred = "Unknown"
                t3 = time.time()
                llm_time = (t3 - t2) * 1000

                # Score the answer.
                f1_val = f1_score(pred, ref_str)
                bleu1_val = bleu1(pred, ref_str)
                jaccard_val = jaccard(pred, ref_str)
                loc_f1_val = loc_f1_score(pred, ref_str)

                print(f"🤖 LLM 回答: {pred}")
                print(f"📈 指标 - F1: {f1_val:.3f}, BLEU-1: {bleu1_val:.3f}, Jaccard: {jaccard_val:.3f}, LoCoMo F1: {loc_f1_val:.3f}")
                print(f"⏱️ 时间 - 检索: {search_time:.1f}ms, LLM: {llm_time:.1f}ms")

                # Update the running totals.
                total_f1 += f1_val
                total_bleu1 += bleu1_val
                total_jaccard += jaccard_val
                total_loc_f1 += loc_f1_val
                total_context_length += len(context_text)
                total_retrieved_docs += len(contexts_all)
                scored_questions += 1

                if category not in category_stats:
                    category_stats[category] = {"count": 0, "f1_sum": 0.0, "b1_sum": 0.0, "j_sum": 0.0, "loc_f1_sum": 0.0}

                category_stats[category]["count"] += 1
                category_stats[category]["f1_sum"] += f1_val
                category_stats[category]["b1_sum"] += bleu1_val
                category_stats[category]["j_sum"] += jaccard_val
                category_stats[category]["loc_f1_sum"] += loc_f1_val

                metrics = {"f1": f1_val, "bleu1": bleu1_val, "jaccard": jaccard_val, "loc_f1": loc_f1_val}
                monitor.record_performance(i, metrics, len(context_text), len(contexts_all))

                question_result = {
                    "question": q,
                    "ground_truth": ref_str,
                    "prediction": pred,
                    "category": category,
                    "metrics": metrics,
                    "retrieval": {
                        "retrieved_documents": len(contexts_all),
                        "context_length": len(context_text),
                        "search_limit": search_limit,
                        "max_chars": max_chars,
                        "recent_performance": recent_performance,
                    },
                    "timing": {
                        "search_ms": search_time,
                        "llm_ms": llm_time,
                    },
                }
                results["questions"].append(question_result)

                print("="*60)

        except Exception as e:
            # Keep whatever was collected so far; the averages below are
            # taken over the questions that were scored.
            print(f"❌ 评估过程中发生错误: {e}")
            import traceback
            traceback.print_exc()

        finally:
            await connector.close()

    finally:
        db.close()  # always release the database session

    # Overall metrics, averaged over the questions actually scored. The
    # original divided by len(items), which understated every average
    # whenever the loop above stopped early.
    n = scored_questions
    if n > 0:
        results["overall_metrics"] = {
            "f1": total_f1 / n,
            "b1": total_bleu1 / n,
            "j": total_jaccard / n,
            "loc_f1": total_loc_f1 / n,
        }

        for category, stats in category_stats.items():
            count = stats["count"]
            results["category_metrics"][category] = {
                "count": count,
                "f1": stats["f1_sum"] / count,
                "bleu1": stats["b1_sum"] / count,
                "jaccard": stats["j_sum"] / count,
                "loc_f1": stats["loc_f1_sum"] / count,
            }

        results["retrieval_stats"]["avg_context_length"] = total_context_length / n
        results["retrieval_stats"]["avg_retrieved_docs"] = total_retrieved_docs / n

    results["performance_trend"] = monitor.get_performance_trend()
    results["reset_interval"] = monitor.reset_interval
    results["total_questions_processed"] = monitor.question_count

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Banner describing what this run does.
    for banner_line in (
        "🚀 运行增强版完整评估(解决中间性能衰减问题)...",
        "📋 增强特性:",
        " - 双重重置策略:定期重置 + 性能驱动重置",
        " - 动态检索参数:基于近期性能自适应调整",
        " - 中间阶段严格筛选:提高上下文质量要求",
        " - 连续性能监控:实时检测性能衰减",
    ):
        print(banner_line)

    result = asyncio.run(run_enhanced_evaluation())

    # Overall metrics.
    overall = result['overall_metrics']
    print("\n📊 最终评估结果:")
    print("总体指标:")
    for label, key in (("F1", "f1"), ("BLEU-1", "b1"), ("Jaccard", "j"), ("LoCoMo F1", "loc_f1")):
        print(f" {label}: {overall[key]:.4f}")

    # Per-category metrics.
    print("\n分类别指标:")
    for category, metrics in result['category_metrics'].items():
        print(f" {category}: F1={metrics['f1']:.4f}, BLEU-1={metrics['bleu1']:.4f}, Jaccard={metrics['jaccard']:.4f}, LoCoMo F1={metrics['loc_f1']:.4f} (样本数: {metrics['count']})")

    # Retrieval statistics.
    stats = result['retrieval_stats']
    print("\n检索统计:")
    print(f" 平均上下文长度: {stats['avg_context_length']:.0f} 字符")
    print(f" 平均检索文档数: {stats['avg_retrieved_docs']:.1f}")

    print(f"\n性能趋势: {result['performance_trend']}")
    print(f"重置间隔: 每{result['reset_interval']}个问题")
    print(f"处理问题总数: {result['total_questions_processed']}")
    strategy_label = '启用' if result.get('enhanced_strategy', False) else '未启用'
    print(f"增强策略: {strategy_label}")

    # Write the full results to a "results" directory next to this file.
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(current_file_dir, "results")
    os.makedirs(output_dir, exist_ok=True)
    output_file = os.path.join(output_dir, "enhanced_evaluation_results.json")
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"\n详细结果已保存到: {output_file}")
|
||||
@@ -1,687 +0,0 @@
|
||||
"""
|
||||
LoCoMo Utilities Module
|
||||
|
||||
This module provides helper functions for the LoCoMo benchmark evaluation:
|
||||
- Data loading from JSON files
|
||||
- Conversation extraction for ingestion
|
||||
- Temporal reference resolution
|
||||
- Context selection and formatting
|
||||
- Retrieval wrapper functions
|
||||
- Ingestion wrapper functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
def load_locomo_data(
    data_path: str,
    sample_size: int,
    conversation_index: int = 0
) -> List[Dict[str, Any]]:
    """
    Read the QA pairs of one conversation from a LoCoMo JSON file.

    The file is expected to hold either a list of conversation objects or a
    single conversation object; each object may carry a "qa" list.

    Args:
        data_path: Path to the LoCoMo JSON file.
        sample_size: Upper bound on the number of QA items returned
            (applied as a slice, ``items[:sample_size]``).
        conversation_index: Index of the conversation whose QA pairs are
            loaded. For a single-object file only index 0 is accepted.

    Returns:
        The first ``sample_size`` entries of the selected conversation's
        "qa" list. An empty list is returned when the selected entry is not
        a dict or has no "qa" key.

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
        IndexError: If ``conversation_index`` is beyond the dataset.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as handle:
        dataset = json.load(handle)

    collected: List[Dict[str, Any]] = []

    # Single-object layout: the whole file is conversation 0.
    if not isinstance(dataset, list):
        if conversation_index != 0:
            raise IndexError(
                f"Conversation index {conversation_index} out of range. "
                "Dataset has only 1 conversation."
            )
        collected.extend(dataset.get("qa", []))
        return collected[:sample_size]

    # List layout: pick exactly one conversation.
    if not conversation_index < len(dataset):
        raise IndexError(
            f"Conversation index {conversation_index} out of range. "
            f"Dataset has {len(dataset)} conversations."
        )

    conversation = dataset[conversation_index]
    if isinstance(conversation, dict) and "qa" in conversation:
        collected.extend(conversation["qa"])

    return collected[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]:
    """
    Extract conversation texts from a LoCoMo JSON file for ingestion.

    For each of the first ``max_dialogues`` entries, every list stored under a
    ``session_*`` key of the entry's "conversation" dict is flattened into
    "speaker: text" lines. Sessions are visited in numeric order
    (session_1, session_2, ..., session_10), so the dialogue keeps its
    chronological order even with ten or more sessions.

    Args:
        data_path: Path to the LoCoMo JSON file.
        max_dialogues: Maximum number of dataset entries to read (default: 1).
        max_messages_per_dialogue: Maximum number of message lines kept per
            dialogue. ``None`` (or 0) means no limit.

    Returns:
        One string per dialogue that produced at least one line; lines are
        joined with newlines. Messages without text are dropped and a missing
        speaker is rendered as "User".

    Raises:
        FileNotFoundError: If ``data_path`` does not exist.
        json.JSONDecodeError: If the file is not valid JSON.

    Example output:
        [
            "User: I went to the store yesterday.\\nAI: What did you buy?\\n...",
            "User: I love hiking.\\nAI: Where do you like to hike?\\n..."
        ]
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"LoCoMo data file not found: {data_path}")

    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # Accept both a list of entries and a single entry object.
    entries = raw if isinstance(raw, list) else [raw]

    def _session_order(key: str):
        # Plain string sorting puts "session_10" before "session_2"; order by
        # the numeric index instead and keep unnumbered keys last.
        numbered = re.match(r"session_(\d+)", key)
        if numbered:
            return (0, int(numbered.group(1)), key)
        return (1, 0, key)

    contents: List[str] = []

    for entry in entries[:max_dialogues]:
        if not isinstance(entry, dict):
            continue

        conv = entry.get("conversation", {})
        if not isinstance(conv, dict):
            continue

        # Only list-valued session_* keys hold messages; sibling keys such as
        # session timestamps are not lists and are ignored.
        session_keys = sorted(
            (
                key for key, val in conv.items()
                if isinstance(val, list) and key.startswith("session_")
            ),
            key=_session_order,
        )

        lines: List[str] = []
        for key in session_keys:
            for msg in conv[key]:
                if not isinstance(msg, dict):
                    continue

                role = msg.get("speaker") or "User"
                text = str(msg.get("text") or "").strip()
                if not text:
                    continue

                lines.append(f"{role}: {text}")

                # Stop reading this session once the per-dialogue cap is hit.
                if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
                    break

            # Propagate the cap to the session loop as well.
            if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
                break

        if lines:
            contents.append("\n".join(lines))

    return contents
|
||||
|
||||
# 时间解析:将相对时间表达转换为绝对日期
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
    """
    Replace relative English time expressions with absolute ISO dates.

    Supported expressions (case-insensitive), applied in this order:
        - today, yesterday, tomorrow
        - "N day(s) ago", "in N day(s)"
        - last week, next week (approximated as 7 days)

    Args:
        text: Text that may contain relative time expressions. ``None`` is
            treated as an empty string; other values are passed through
            ``str()``.
        anchor_date: Reference date the offsets are applied to.

    Returns:
        The text with every supported expression replaced by a
        ``YYYY-MM-DD`` date. An expression whose offset cannot be represented
        as a date (for example an absurdly large day count) is left unchanged
        instead of raising.

    Example:
        >>> anchor = datetime(2023, 5, 8)
        >>> resolve_temporal_references("I saw him yesterday", anchor)
        'I saw him 2023-05-07'
    """
    resolved = str(text) if text is not None else ""

    # Each rule maps a pattern to a function returning the day offset for a match.
    rules = (
        (r"\btoday\b", lambda m: 0),
        (r"\byesterday\b", lambda m: -1),
        (r"\btomorrow\b", lambda m: 1),
        (r"\b(\d+)\s+days?\s+ago\b", lambda m: -int(m.group(1))),
        (r"\bin\s+(\d+)\s+days?\b", lambda m: int(m.group(1))),
        (r"\blast\s+week\b", lambda m: -7),
        (r"\bnext\s+week\b", lambda m: 7),
    )

    def _to_date(match, offset_of) -> str:
        try:
            shifted = anchor_date + timedelta(days=offset_of(match))
        except (OverflowError, ValueError):
            # Offset outside the representable date range: keep the original wording.
            return match.group(0)
        return shifted.date().isoformat()

    for pattern, offset_of in rules:
        resolved = re.sub(
            pattern,
            lambda match, offset_of=offset_of: _to_date(match, offset_of),
            resolved,
            flags=re.IGNORECASE
        )

    return resolved
|
||||
|
||||
|
||||
def select_and_format_information(
    retrieved_info: List[str],
    question: str,
    max_chars: int = 8000
) -> str:
    """
    Select the retrieved snippets most relevant to a question and join them.

    Each snippet is scored, snippets are taken in descending score order
    while they fit into ``max_chars``, and the first snippet that does not
    fit may be reduced to its keyword-bearing lines.

    Scoring:
        - +2 for every occurrence of a question keyword (substring match on
          the lower-cased snippet; keywords are question words longer than
          two characters that are not stop words)
        - +5 for a length between 101 and 1999 characters, +2 for 2000 or more
        - +3 for the first three snippets of ``retrieved_info``

    Args:
        retrieved_info: Retrieved text snippets.
        question: Question being answered.
        max_chars: Character budget for the selected snippet text. The
            "\\n\\n" separators and the truncation marker are not counted.

    Returns:
        The selected snippets joined by blank lines, or "" when
        ``retrieved_info`` is empty. A truncated snippet ends with
        "[Content truncated...]".

    Example:
        >>> contexts = ["Alice went to Paris", "Bob likes pizza", "Alice visited the Eiffel Tower"]
        >>> select_and_format_information(contexts, "Where did Alice go?", max_chars=100)
        'Alice went to Paris\\n\\nAlice visited the Eiffel Tower\\n\\nBob likes pizza'
    """
    if not retrieved_info:
        return ""

    stop_words = {
        'what', 'when', 'where', 'who', 'why', 'how',
        'did', 'do', 'does', 'is', 'are', 'was', 'were',
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at'
    }
    question_words = {
        word
        for word in re.findall(r'\b\w+\b', question.lower())
        if word not in stop_words and len(word) > 2
    }

    scored_contexts = []
    for position, context in enumerate(retrieved_info):
        context_lower = context.lower()

        # Keyword score: two points per occurrence.
        score = 2 * sum(context_lower.count(word) for word in question_words)

        # Length score: moderate snippets are preferred.
        context_len = len(context)
        if 100 < context_len < 2000:
            score += 5
        elif context_len >= 2000:
            score += 2

        # Position bonus for the first three retrieved snippets.
        if position < 3:
            score += 3

        scored_contexts.append((score, context))

    # Stable sort: equal scores keep their retrieval order.
    scored_contexts.sort(key=lambda pair: pair[0], reverse=True)

    selected = []
    total_chars = 0

    for score, context in scored_contexts:
        if total_chars + len(context) <= max_chars:
            selected.append(context)
            total_chars += len(context)
            continue

        # The snippet does not fit. If it scored well and enough room is
        # left, keep only its keyword-bearing lines.
        if score > 10 and total_chars < max_chars - 500:
            # Reserve 100 characters so the excerpt plus marker stays within budget.
            budget = max_chars - total_chars - 100
            relevant_lines = []
            used = 0

            for line in context.split('\n'):
                if not any(word in line.lower() for word in question_words):
                    continue
                # Count the joining newline for every line after the first.
                needed = len(line) + (1 if relevant_lines else 0)
                # Only accept a line that fits; a single long line used to
                # push the excerpt past max_chars.
                if used + needed <= budget:
                    relevant_lines.append(line)
                    used += needed

            truncated = '\n'.join(relevant_lines)
            if len(truncated) > 100:
                selected.append(truncated + "\n[Content truncated...]")
                total_chars += len(truncated)

        # NOTE(review): the scan ends at the first snippet that does not fit;
        # confirm this matches the intended loop structure.
        break

    return "\n\n".join(selected)
|
||||
|
||||
# 记忆系统核心能力:写入与读取
|
||||
async def ingest_conversations_if_needed(
    conversations: List[str],
    end_user_id: str,
    reset: bool = False
) -> bool:
    """
    Ingest conversation texts through the external extraction pipeline.

    This is a thin wrapper around ``ingest_contexts_via_full_pipeline``. The
    pipeline itself is not visible here; according to the original notes it
    chunks the dialogues, extracts statements and entities, embeds the
    content and stores it in Neo4j.

    Args:
        conversations: Raw conversation texts, passed to the pipeline as
            ``contexts``.
            Example: ["User: I went to Paris. AI: When was that?", ...]
        end_user_id: Target end-user ID the data is stored under.
        reset: Forwarded to the pipeline as ``reset_group``.

    Returns:
        Whatever the pipeline returns on success; ``False`` when any
        exception is raised (the error is printed, not re-raised).

    Note:
        This module defines ``ingest_conversations_if_needed`` a second time
        further down; at import time the later definition replaces this one.
    """
    pipeline_args = {
        "contexts": conversations,
        "end_user_id": end_user_id,
        "save_chunk_output": True,
        "reset_group": reset,
    }
    try:
        return await ingest_contexts_via_full_pipeline(**pipeline_args)
    except Exception as exc:
        # Best-effort wrapper: report the failure and signal it via the return value.
        print(f"[Ingestion] Failed to ingest conversations: {exc}")
        return False
|
||||
|
||||
def _texts_from_records(records: Any, field: str) -> List[str]:
    """Return the non-empty, stripped ``field`` value of every record."""
    texts: List[str] = []
    for record in records:
        value = str(record.get(field, "")).strip()
        if value:
            texts.append(value)
    return texts


def _entity_summary_block(entities: Any, top_n: int = 3) -> Optional[str]:
    """Format the ``top_n`` best-scored entities as "EntitySummary: ..." lines."""
    if not entities:
        return None

    # Prefer entities that carry a score; otherwise fall back to the first ones.
    scored = [e for e in entities if e.get("score") is not None]
    top_entities = (
        sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:top_n]
        if scored else entities[:top_n]
    )

    summary_lines = []
    for entity in top_entities:
        name = str(entity.get("name", "")).strip()
        if not name:
            continue
        etype = str(entity.get("entity_type", "")).strip()
        score = entity.get("score")
        meta = []
        if etype:
            meta.append(f"type={etype}")
        if isinstance(score, (int, float)):
            meta.append(f"score={score:.3f}")
        summary_lines.append(
            f"EntitySummary: {name}"
            f"{(' [' + '; '.join(meta) + ']') if meta else ''}"
        )

    return "\n".join(summary_lines) if summary_lines else None


def _contexts_from_graph_results(results: Dict[str, Any]) -> List[str]:
    """Flatten chunks, statements, summaries and top entities into text snippets."""
    contexts = _texts_from_records(results.get("chunks", []), "content")
    contexts += _texts_from_records(results.get("statements", []), "statement")
    contexts += _texts_from_records(results.get("summaries", []), "summary")
    entity_block = _entity_summary_block(results.get("entities", []))
    if entity_block:
        contexts.append(entity_block)
    return contexts


async def retrieve_relevant_information(
    question: str,
    end_user_id: str,
    search_type: str,
    search_limit: int,
    connector: Any,
    embedder: Any
) -> List[str]:
    """
    Retrieve text snippets from the memory graph that may answer a question.

    Dispatch by ``search_type``:
        - "embedding": ``search_graph_by_embedding``
        - "keyword": ``search_graph``
        - anything else (normally "hybrid"): ``run_hybrid_search``; if that
          raises or returns nothing usable, falls back to
          ``search_graph_by_embedding``

    Embedding and hybrid results contribute, in order: chunk contents,
    statements, summaries, and one "EntitySummary" block for up to three
    entities (highest score first). Keyword results contribute dialogue
    contents, statements, and one "EntitySummary" line naming up to five
    entities. A hybrid result may be flat or nested under
    "reranked_results"; both shapes are accepted.

    Args:
        question: Question to search for.
        end_user_id: Identifier passed to the search functions to scope the search.
        search_type: "keyword", "embedding", or "hybrid".
        search_limit: Result limit passed to the search functions.
        connector: Connector object passed to the graph search functions.
        embedder: Embedder client passed to the embedding search.

    Returns:
        List of snippet strings. An empty list is returned when the search
        fails; the failure is printed rather than raised.
    """
    from app.repositories.neo4j.graph_search import (
        search_graph,
        search_graph_by_embedding
    )
    from app.core.memory.src.search import run_hybrid_search

    include_kinds = ("chunks", "statements", "entities", "summaries")

    async def _embedding_search() -> Dict[str, Any]:
        return await search_graph_by_embedding(
            connector=connector,
            embedder_client=embedder,
            query_text=question,
            end_user_id=end_user_id,
            limit=search_limit,
            include=list(include_kinds),
        )

    try:
        if search_type == "embedding":
            return _contexts_from_graph_results(await _embedding_search())

        if search_type == "keyword":
            search_results = await search_graph(
                connector=connector,
                q=question,
                end_user_id=end_user_id,
                limit=search_limit
            )
            contexts = _texts_from_records(search_results.get("dialogues", []), "content")
            contexts += _texts_from_records(search_results.get("statements", []), "statement")

            entities = search_results.get("entities", [])
            if entities:
                entity_names = [
                    str(e.get("name", "")).strip()
                    for e in entities[:5]
                    if e.get("name")
                ]
                if entity_names:
                    contexts.append(f"EntitySummary: {', '.join(entity_names)}")
            return contexts

        # Hybrid search, with embedding search as the fallback.
        try:
            search_results = await run_hybrid_search(
                query_text=question,
                search_type=search_type,
                end_user_id=end_user_id,
                limit=search_limit,
                include=list(include_kinds),
                output_path=None,
            )
            if not (search_results and isinstance(search_results, dict)):
                raise ValueError("Hybrid search returned empty results")

            graph_results = search_results
            # Flat structure is tried first; an empty flat result may still
            # carry data under "reranked_results" (older response shape).
            if not any(search_results.get(kind, []) for kind in include_kinds):
                reranked = search_results.get("reranked_results", {})
                if not (reranked and isinstance(reranked, dict)):
                    raise ValueError("Hybrid search returned empty results")
                graph_results = reranked
        except Exception as hybrid_error:
            # Deliberate best-effort fallback, reported so that a broken
            # hybrid path does not go unnoticed.
            print(f"[Retrieval] Hybrid search failed, falling back to embedding search: {hybrid_error}")
            graph_results = await _embedding_search()

        return _contexts_from_graph_results(graph_results)

    except Exception as e:
        # Callers rely on an empty list on failure; report the cause instead
        # of swallowing it silently.
        print(f"[Retrieval] Failed to retrieve information: {e}")
        return []
|
||||
|
||||
|
||||
async def ingest_conversations_if_needed(
    conversations: List[str],
    end_user_id: str,
    reset: bool = False
) -> bool:
    """
    Ingest conversation texts through the external extraction pipeline.

    Thin wrapper around ``ingest_contexts_via_full_pipeline``; the pipeline
    itself is not visible here. According to the original notes it runs
    chunking, entity and statement extraction, embedding and Neo4j storage.

    Args:
        conversations: Raw conversation texts, passed to the pipeline as
            ``contexts``.
            Example: ["User: I went to Paris. AI: When was that?", ...]
        end_user_id: Target group ID the data is stored under.
        reset: Accepted for interface compatibility but not used by this
            definition.

    Returns:
        Whatever the pipeline returns on success; ``False`` when any
        exception is raised (the error is printed, not re-raised).

    NOTE(review): this is the second definition of
    ``ingest_conversations_if_needed`` in this module and therefore the one
    in effect. Unlike the earlier definition it does not forward ``reset``
    to the pipeline; confirm which behaviour is intended.
    """
    try:
        outcome = await ingest_contexts_via_full_pipeline(
            contexts=conversations,
            end_user_id=end_user_id,
            save_chunk_output=True
        )
    except Exception as exc:
        # Best-effort wrapper: report the failure and signal it via the return value.
        print(f"[Ingestion] Failed to ingest conversations: {exc}")
        return False
    return outcome
|
||||
@@ -1,874 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
import re
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
def _loc_normalize(text: str) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
text = str(text) if text is not None else ""
|
||||
text = text.lower()
|
||||
text = re.sub(r"[\,]", " ", text) # 去掉逗号
|
||||
text = re.sub(r"\b(a|an|the|and)\b", " ", text)
|
||||
text = re.sub(r"[^\w\s]", " ", text)
|
||||
text = " ".join(text.split())
|
||||
return text
|
||||
|
||||
# 追加:相对时间归一化为绝对日期(有限支持:today/yesterday/tomorrow/X days ago/in X days/last week/next week)
|
||||
def _resolve_relative_times(text: str, anchor: datetime) -> str:
|
||||
import re
|
||||
# 确保输入是字符串
|
||||
t = str(text) if text is not None else ""
|
||||
# today / yesterday / tomorrow
|
||||
t = re.sub(r"\btoday\b", anchor.date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\byesterday\b", (anchor - timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\btomorrow\b", (anchor + timedelta(days=1)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
# X days ago / in X days
|
||||
def _ago_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor - timedelta(days=n)).date().isoformat()
|
||||
def _in_repl(m: re.Match[str]) -> str:
|
||||
n = int(m.group(1))
|
||||
return (anchor + timedelta(days=n)).date().isoformat()
|
||||
t = re.sub(r"\b(\d+)\s+days\s+ago\b", _ago_repl, t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bin\s+(\d+)\s+days\b", _in_repl, t, flags=re.IGNORECASE)
|
||||
# last week / next week(以7天近似)
|
||||
t = re.sub(r"\blast\s+week\b", (anchor - timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
t = re.sub(r"\bnext\s+week\b", (anchor + timedelta(days=7)).date().isoformat(), t, flags=re.IGNORECASE)
|
||||
return t
|
||||
|
||||
def loc_f1_score(prediction: str, ground_truth: str) -> float:
    """
    Single-answer F1 over the sets of normalized tokens.

    Both inputs are normalized with ``_loc_normalize`` and compared as token
    sets (no stemming). Returns 0.0 when either side has no tokens or the
    sets do not overlap.
    """
    pred_text = "" if prediction is None else str(prediction)
    gold_text = "" if ground_truth is None else str(ground_truth)

    pred_vocab = set(_loc_normalize(pred_text).split())
    gold_vocab = set(_loc_normalize(gold_text).split())
    if not pred_vocab or not gold_vocab:
        return 0.0

    overlap = len(pred_vocab & gold_vocab)
    precision = overlap / len(pred_vocab)
    recall = overlap / len(gold_vocab)
    if precision + recall > 0:
        return 2 * precision * recall / (precision + recall)
    return 0.0
|
||||
|
||||
def loc_multi_f1(prediction: str, ground_truth: str) -> float:
    """
    Multi-answer F1 for comma-separated answers.

    Prediction and ground truth are split on commas. For every ground-truth
    part the best ``loc_f1_score`` against any prediction part is taken, and
    those best scores are averaged. Returns 0.0 when either side has no
    non-empty part.
    """
    pred_text = "" if prediction is None else str(prediction)
    gold_text = "" if ground_truth is None else str(ground_truth)

    candidates = [piece.strip() for piece in pred_text.split(',') if piece.strip()]
    references = [piece.strip() for piece in gold_text.split(',') if piece.strip()]
    if not candidates or not references:
        return 0.0

    best_per_reference = [
        max(loc_f1_score(candidate, reference) for candidate in candidates)
        for reference in references
    ]
    return sum(best_per_reference) / len(best_per_reference)
|
||||
|
||||
# Canonical LoCoMo category names. Items may carry the category as a number
# ("category") or as a string ("cat" / "type"); both are mapped to these names.

# Numeric category id -> canonical category name.
CATEGORY_MAP_NUM_TO_NAME = {
    4: "Single-Hop",
    1: "Multi-Hop",
    3: "Open Domain",
    2: "Temporal",
}

# Lower-cased spelling variants of a category string -> canonical category name.
_TYPE_ALIASES = {
    "single-hop": "Single-Hop",
    "singlehop": "Single-Hop",
    "single hop": "Single-Hop",
    "multi-hop": "Multi-Hop",
    "multihop": "Multi-Hop",
    "multi hop": "Multi-Hop",
    "open domain": "Open Domain",
    "opendomain": "Open Domain",
    "temporal": "Temporal",
}
|
||||
|
||||
def get_category_label(item: Dict[str, Any]) -> str:
    """
    Return the canonical category name of a QA item.

    Lookup order:
        1. a non-blank string under "cat" (mapped through ``_TYPE_ALIASES``,
           otherwise returned stripped as-is)
        2. an integer under "category" (mapped through
           ``CATEGORY_MAP_NUM_TO_NAME``, "unknown" when not listed)
        3. a non-blank string under "type" (handled like "cat")

    Falls back to "unknown" when none of the fields is usable.
    """
    def _from_text(value: Any) -> Optional[str]:
        # Only non-blank strings count as a textual category.
        if not isinstance(value, str):
            return None
        stripped = value.strip()
        if not stripped:
            return None
        return _TYPE_ALIASES.get(stripped.lower(), stripped)

    label = _from_text(item.get("cat"))
    if label is not None:
        return label

    number = item.get("category")
    if isinstance(number, int):
        return CATEGORY_MAP_NUM_TO_NAME.get(number, "unknown")

    label = _from_text(item.get("type"))
    return "unknown" if label is None else label
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 12000) -> str:
    """
    Select the contexts most relevant to a question, based on its keywords.

    Each context is scored (+2 per keyword occurrence, +5 for a length of
    101-1999 characters or +2 for 2000 and more, +3 for the first three
    contexts). Contexts are taken in descending score order while they fit
    into ``max_chars``; the first one that does not fit may be reduced to
    its keyword-bearing lines. Progress information is printed.

    Args:
        contexts: Candidate context strings.
        question: Question being answered.
        max_chars: Character budget for the selected context text. The
            "\\n\\n" separators and the truncation marker are not counted.

    Returns:
        The selected contexts joined by blank lines, or "" when ``contexts``
        is empty.
    """
    if not contexts:
        return ""

    # Question keywords: words longer than two characters that are not stop words.
    stop_words = {'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were', 'the', 'a', 'an', 'and', 'or', 'but'}
    question_words = {
        word
        for word in re.findall(r'\b\w+\b', question.lower())
        if word not in stop_words and len(word) > 2
    }

    print(f"🔍 问题关键词: {question_words}")

    # Score every context.
    scored_contexts = []
    for position, context in enumerate(contexts):
        context_lower = context.lower()
        score = 0

        # Keyword score: two points per occurrence; also count distinct hits.
        keyword_matches = 0
        for word in question_words:
            if word in context_lower:
                keyword_matches += 1
                score += context_lower.count(word) * 2

        # Length score: moderate lengths are preferred.
        context_len = len(context)
        if 100 < context_len < 2000:
            score += 5
        elif context_len >= 2000:
            score += 2

        # Position bonus for the first three contexts.
        if position < 3:
            score += 3

        scored_contexts.append((score, context, keyword_matches))

    # Stable sort: equal scores keep their original order.
    scored_contexts.sort(key=lambda x: x[0], reverse=True)

    selected = []
    total_chars = 0
    selected_count = 0

    print("📊 上下文相关性分析:")
    for score, context, matches in scored_contexts[:5]:  # show the top five only
        print(f" - 得分: {score}, 关键词匹配: {matches}, 长度: {len(context)}")

    for score, context, matches in scored_contexts:
        if total_chars + len(context) <= max_chars:
            selected.append(context)
            total_chars += len(context)
            selected_count += 1
            continue

        # The context does not fit. If it scored well and enough room is
        # left, keep only its keyword-bearing lines.
        if score > 10 and total_chars < max_chars - 500:
            # Reserve 100 characters so the excerpt plus marker stays within budget.
            budget = max_chars - total_chars - 100
            relevant_lines = []
            used = 0

            for line in context.split('\n'):
                if not any(word in line.lower() for word in question_words):
                    continue
                # Count the joining newline for every line after the first.
                needed = len(line) + (1 if relevant_lines else 0)
                # Only accept a line that fits; a single long line used to
                # push the excerpt past max_chars.
                if used + needed <= budget:
                    relevant_lines.append(line)
                    used += needed

            truncated = '\n'.join(relevant_lines)
            if len(truncated) > 100:  # require a meaningful amount of content
                selected.append(truncated + "\n[相关内容截断...]")
                total_chars += len(truncated)
                selected_count += 1

        # NOTE(review): the scan ends at the first context that does not fit;
        # confirm this matches the intended loop structure.
        break

    result = "\n\n".join(selected)
    print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {total_chars}字符")
    return result
|
||||
|
||||
|
||||
def get_search_params_by_category(category: str):
    """
    Return the retrieval parameters tuned for a question category.

    The result is a dict with "limit" (number of results to retrieve) and
    "max_chars" (context character budget). Categories not listed below get
    the default ``{"limit": 16, "max_chars": 12000}``.
    """
    presets = {
        name: {"limit": limit, "max_chars": max_chars}
        for name, limit, max_chars in (
            ("Multi-Hop", 20, 15000),
            ("Temporal", 16, 10000),
            ("Open Domain", 24, 18000),
            ("Single-Hop", 12, 8000),
        )
    }
    fallback = {"limit": 16, "max_chars": 12000}
    return presets.get(category, fallback)
|
||||
|
||||
|
||||
def _locomo_mean(values: List[float]) -> float:
    """Return the arithmetic mean of *values*, or 0.0 when the list is empty."""
    return sum(values) / len(values) if values else 0.0


def _locomo_percentile(sorted_vals: List[float], p: float) -> float:
    """Return the linearly interpolated percentile *p* (0..1) of an ascending list."""
    if not sorted_vals:
        return 0.0
    if len(sorted_vals) == 1:
        return sorted_vals[0]
    k = (len(sorted_vals) - 1) * p
    lower = int(k)
    upper = min(lower + 1, len(sorted_vals) - 1)
    if lower == upper:
        return sorted_vals[lower]
    return sorted_vals[lower] + (sorted_vals[upper] - sorted_vals[lower]) * (k - lower)


def _locomo_clip(text: Any, limit: int = 200) -> str:
    """Return *text* as a string, cut to *limit* characters plus '...' when longer."""
    text = "" if text is None else str(text)
    return text[:limit] + "..." if len(text) > limit else text


def _locomo_entity_summary(entities: List[Dict[str, Any]], top_k: int = 3) -> str:
    """Format up to *top_k* entities (highest score first) as 'EntitySummary:' lines."""
    scored = [e for e in entities if e.get("score") is not None]
    if scored:
        top_entities = sorted(scored, key=lambda e: e.get("score", 0), reverse=True)[:top_k]
    else:
        # No scores available: fall back to the order the search returned.
        top_entities = entities[:top_k]

    summary_lines = []
    for ent in top_entities:
        name = str(ent.get("name", "")).strip()
        if not name:
            continue
        meta = []
        etype = str(ent.get("entity_type", "")).strip()
        if etype:
            meta.append(f"type={etype}")
        score = ent.get("score")
        if isinstance(score, (int, float)):
            meta.append(f"score={score:.3f}")
        suffix = (" [" + "; ".join(meta) + "]") if meta else ""
        summary_lines.append(f"EntitySummary: {name}{suffix}")
    return "\n".join(summary_lines)


def _locomo_build_contexts(buckets: Dict[str, List[Dict[str, Any]]], keyword_mode: bool) -> List[str]:
    """Turn retrieved records into context strings: text records first, entities last."""
    contexts: List[str] = []
    text_fields = (
        ("chunks", "content"),
        ("dialogues", "content"),
        ("statements", "statement"),
        ("summaries", "summary"),
    )
    for bucket_name, field in text_fields:
        for record in buckets[bucket_name]:
            text = str(record.get(field, "")).strip()
            if text:
                contexts.append(text)

    entities = buckets["entities"]
    if keyword_mode:
        # Keyword hits may carry no score, so only the first five names are listed.
        entity_names = [str(e.get("name", "")).strip() for e in entities[:5] if e.get("name")]
        if entity_names:
            contexts.append(f"EntitySummary: {', '.join(entity_names)}")
    elif entities:
        summary = _locomo_entity_summary(entities)
        if summary:
            contexts.append(summary)
    return contexts


async def _locomo_retrieve(connector, embedder, query: str, end_user_id, search_type: str, limit: int):
    """Run one retrieval and return ``(raw_results, buckets)``.

    ``buckets`` always has the keys "chunks", "dialogues", "statements",
    "entities" and "summaries"; record types a search mode does not return
    are left as empty lists. Hybrid search falls back to embedding search
    when it raises or yields nothing usable.
    """
    include = ["chunks", "statements", "entities", "summaries"]
    chunks: List[Any] = []
    dialogues: List[Any] = []
    statements: List[Any] = []
    entities: List[Any] = []
    summaries: List[Any] = []

    async def _embedding_search():
        return await search_graph_by_embedding(
            connector=connector,
            embedder_client=embedder,
            query_text=query,
            end_user_id=end_user_id,
            limit=limit,
            include=list(include),
        )

    if search_type == "embedding":
        results = await _embedding_search()
        chunks = results.get("chunks", [])
        statements = results.get("statements", [])
        entities = results.get("entities", [])
        summaries = results.get("summaries", [])
        print(f"✅ 嵌入检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")

    elif search_type == "keyword":
        results = await search_graph(
            connector=connector,
            q=query,
            end_user_id=end_user_id,
            limit=limit
        )
        dialogues = results.get("dialogues", [])
        statements = results.get("statements", [])
        entities = results.get("entities", [])
        print(f"🔤 关键词检索找到 {len(dialogues)} 条对话, {len(statements)} 条陈述, {len(entities)} 个实体")

    else:  # hybrid
        print("🔀 使用混合检索(旧版本)...")
        try:
            results = await run_hybrid_search(
                query_text=query,
                search_type=search_type,
                end_user_id=end_user_id,
                limit=limit,
                include=list(include),
                output_path=None,
                rerank_alpha=0.6,
                use_forgetting_rerank=False,
                use_llm_rerank=False
            )
            if not (results and isinstance(results, dict)):
                raise ValueError("混合检索返回空结果")

            # Hybrid results nest the merged lists under "reranked_results";
            # otherwise the lists are read from the top level.
            source = results["reranked_results"] if "reranked_results" in results else results
            chunks = source.get("chunks", [])
            statements = source.get("statements", [])
            entities = source.get("entities", [])
            summaries = source.get("summaries", [])

            if chunks or statements or entities or summaries:
                print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 陈述, {len(entities)} 实体, {len(summaries)} 摘要")
            else:
                reranked = results.get("reranked_results", {})
                if reranked and isinstance(reranked, dict):
                    print(f"✅ 混合检索成功(使用旧格式reranked结果): {len(chunks)} chunks, {len(statements)} 陈述")
                else:
                    raise ValueError("混合检索返回空结果")
        except Exception as e:
            # Best-effort: any hybrid failure degrades to plain embedding search.
            print(f"❌ 混合检索失败: {e},回退到嵌入检索")
            results = await _embedding_search()
            chunks = results.get("chunks", [])
            statements = results.get("statements", [])
            entities = results.get("entities", [])
            summaries = results.get("summaries", [])
            print(f"✅ 回退嵌入检索成功: {len(chunks)} chunks, {len(statements)} 陈述")

    buckets = {
        "chunks": chunks,
        "dialogues": dialogues,
        "statements": statements,
        "entities": entities,
        "summaries": summaries,
    }
    return results, buckets


async def run_locomo_eval(
    sample_size: int = 1,
    end_user_id: str | None = None,
    search_limit: int = 8,
    context_char_budget: int = 4000,
    llm_temperature: float = 0.0,
    llm_max_tokens: int = 32,
    search_type: str = "hybrid",
    output_path: str | None = None,
    skip_ingest_if_exists: bool = True,
    llm_timeout: float = 10.0,
    llm_max_retries: int = 1
) -> Dict[str, Any]:
    """Run the LoCoMo QA benchmark end to end and return the aggregated report.

    Steps: load ``locomo10.json``, ingest the first conversation, then for each
    selected question retrieve context, ask the LLM, and score the answer.

    Args:
        sample_size: Number of QA pairs to evaluate, taken in file order.
        end_user_id: Group to ingest into and search; defaults to the
            module-level ``SELECTED_end_user_id``.
        search_limit: Recorded in the report only; the per-category limit from
            ``get_search_params_by_category`` is what is actually applied.
        context_char_budget: Recorded in the report only; the per-category
            ``max_chars`` is what is actually applied.
        llm_temperature: Recorded in the report only.
        llm_max_tokens: Recorded in the report only.
        search_type: "embedding", "keyword", or anything else for hybrid.
        output_path: When given, the report is also written there as JSON.
        skip_ingest_if_exists: Recorded in the report only; ingestion always runs.
        llm_timeout: Recorded in the report only.
        llm_max_retries: Recorded in the report only.

    Returns:
        A dict with overall and per-category metrics, context and latency
        statistics, per-question samples, the run parameters and a timestamp.

    Raises:
        FileNotFoundError: If the dataset file is missing.
    """
    end_user_id = end_user_id or SELECTED_end_user_id
    data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"数据集文件不存在: {data_path}\n"
            f"请将 locomo10.json 放置在: {os.path.dirname(data_path)}"
        )
    with open(data_path, "r", encoding="utf-8") as f:
        raw = json.load(f)

    # The dataset is either a list of conversation objects or a single object;
    # each object carries its own "qa" list.
    entries = raw if isinstance(raw, list) else [raw]

    # Only the first conversation is ingested.
    max_dialogues_to_ingest = 1
    contents: List[str] = []
    print(f"📊 找到 {len(entries)} 个对话对象,只摄入前 {max_dialogues_to_ingest} 条")

    for i, entry in enumerate(entries[:max_dialogues_to_ingest]):
        if not isinstance(entry, dict):
            continue

        conv = entry.get("conversation", {})
        sample_id = entry.get("sample_id", f"unknown_{i}")
        print(f"🔍 处理对话 {i+1}: {sample_id}")

        lines: List[str] = []
        if isinstance(conv, dict):
            # Collect the messages of every "session_*" list.
            session_count = 0
            for key, val in conv.items():
                if isinstance(val, list) and key.startswith("session_"):
                    session_count += 1
                    for msg in val:
                        role = msg.get("speaker") or "用户"
                        text = str(msg.get("text") or "").strip()
                        if not text:
                            continue
                        lines.append(f"{role}: {text}")
            print(f" - 包含 {session_count} 个session, {len(lines)} 条消息")

        if not lines:
            print(f"⚠️ 警告: 对话 {sample_id} 没有对话内容,跳过摄入")
            continue

        contents.append("\n".join(lines))

    print(f"📥 总共摄入 {len(contents)} 个对话的conversation内容")

    # QA pairs are drawn from every conversation; sample_size caps how many.
    qa_pool: List[Dict[str, Any]] = [qa for entry in entries for qa in entry.get("qa", [])]
    items: List[Dict[str, Any]] = qa_pool[:sample_size]

    print(f"🎯 将评测 {len(items)} 个QA对,数据库中只包含 {len(contents)} 个对话")

    connector = Neo4jConnector()
    try:
        # Setup runs inside the try so the connector is closed even when
        # ingestion or client construction fails.
        print("🔄 强制重新摄入纯净的对话数据...")
        await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)

        llm_client = get_llm_client(llm_id)
        cfg_dict = get_embedder_config(embedding_id)
        embedder = OpenAIEmbedderClient(
            model_config=RedBearModelConfig.model_validate(cfg_dict)
        )

        latencies_llm: List[float] = []
        latencies_search: List[float] = []
        # Context diagnostics, one entry per question.
        per_query_context_counts: List[int] = []
        per_query_context_avg_tokens: List[float] = []
        per_query_context_chars: List[int] = []
        per_query_context_tokens_total: List[int] = []
        samples: List[Dict[str, Any]] = []
        # Generic metrics.
        f1s: List[float] = []
        b1s: List[float] = []
        jss: List[float] = []
        # LoCoMo-style F1 (multi-answer F1 for category 1).
        loc_f1s: List[float] = []
        # Per-category aggregation.
        cat_counts: Dict[str, int] = {}
        cat_f1s: Dict[str, List[float]] = {}
        cat_b1s: Dict[str, List[float]] = {}
        cat_jss: Dict[str, List[float]] = {}
        cat_loc_f1s: Dict[str, List[float]] = {}

        for item in items:
            q = item.get("question", "")
            ref = item.get("answer", "")
            ref_str = str(ref) if ref is not None else ""
            cat = get_category_label(item)

            print(f"\n=== 处理问题: {q} ===")

            search_params = get_search_params_by_category(cat)
            adjusted_limit = search_params["limit"]
            max_chars = search_params["max_chars"]

            print(f"🏷️ 类别: {cat}, 检索参数: limit={adjusted_limit}, max_chars={max_chars}")

            t0 = time.time()
            contexts_all: List[str] = []
            search_results = None

            try:
                search_results, buckets = await _locomo_retrieve(
                    connector, embedder, q, end_user_id, search_type, adjusted_limit
                )
                contexts_all = _locomo_build_contexts(buckets, search_type == "keyword")
            except Exception as e:
                # A failed retrieval must not abort the run: the question is
                # answered without context instead.
                print(f"❌ {search_type}检索失败: {e}")
                contexts_all = []
                search_results = None
            else:
                # Drop contexts that contain the reference answer verbatim.
                # NOTE(review): presumably guards against answer leakage from
                # ingested data — confirm; it also removes genuine evidence
                # that happens to quote the answer.
                answer_text = ref_str.strip()
                filtered_contexts = []
                for context in contexts_all:
                    if answer_text and answer_text in str(context):
                        print("🚫 过滤掉包含标准答案的上下文")
                        continue
                    filtered_contexts.append(context)

                print(f"📊 过滤后保留 {len(filtered_contexts)} 个上下文 (原 {len(contexts_all)} 个)")
                contexts_all = filtered_contexts

                # Debug preview only; kept outside the retrieval try so a
                # formatting problem cannot discard the retrieved contexts.
                print("🔍 检索结果详情:")
                if search_results:
                    output_data = {
                        "statements": [
                            {
                                "statement": _locomo_clip(s.get("statement", "")),
                                "score": s.get("score", 0.0)
                            }
                            for s in buckets["statements"][:2]
                        ],
                        "dialogues": [
                            {
                                "uuid": d.get("uuid", ""),
                                "end_user_id": d.get("end_user_id", ""),
                                "content": _locomo_clip(d.get("content", "")),
                                "score": d.get("score", 0.0)
                            }
                            for d in buckets["dialogues"][:2]
                        ],
                        "entities": [
                            {
                                "name": e.get("name", ""),
                                "entity_type": e.get("entity_type", ""),
                                "score": e.get("score", 0.0)
                            }
                            for e in buckets["entities"][:2]
                        ]
                    }
                    # default=str: this is console output, so unusual value
                    # types are stringified rather than raising.
                    print(json.dumps(output_data, ensure_ascii=False, indent=2, default=str))
                else:
                    print(" 无检索结果")

            t1 = time.time()
            latencies_search.append((t1 - t0) * 1000)

            if contexts_all:
                context_text = smart_context_selection(contexts_all, q, max_chars=max_chars)

                # Final guard in case the selection still exceeds the budget.
                if len(context_text) > max_chars:
                    print(f"⚠️ 智能选择后仍然过长 ({len(context_text)}字符),进行最终截断")
                    context_text = context_text[:max_chars] + "\n\n[最终截断...]"

                # A fixed anchor date keeps relative-time resolution reproducible.
                anchor_date = datetime(2023, 5, 8)
                context_text = _resolve_relative_times(context_text, anchor_date)
                context_text = f"Reference date: {anchor_date.date().isoformat()}\n\n" + context_text

                print(f"📝 最终上下文长度: {len(context_text)} 字符")

                print("🔍 上下文预览:")
                for j, context in enumerate(contexts_all[:3]):
                    preview = context[:150].replace('\n', ' ')
                    print(f" 上下文{j+1}: {preview}...")
            else:
                print("❌ 没有检索到有效上下文")
                context_text = "No relevant context found."

            per_query_context_counts.append(len(contexts_all))
            per_query_context_avg_tokens.append(avg_context_tokens([context_text]))
            per_query_context_chars.append(len(context_text))
            per_query_context_tokens_total.append(len(_loc_normalize(context_text).split()))

            messages = [
                {"role": "system", "content": (
                    "You are a precise QA assistant. Answer following these rules:\n"
                    "1) Extract the EXACT information mentioned in the context\n"
                    "2) For time questions: calculate actual dates from relative times\n"
                    "3) Return ONLY the answer text in simplest form\n"
                    "4) For dates, use format 'DD Month YYYY' (e.g., '7 May 2023')\n"
                    "5) If no clear answer found, respond with 'Unknown'"
                )},
                {"role": "user", "content": f"Question: {q}\n\nContext:\n{context_text}"},
            ]

            t2 = time.time()
            resp = await llm_client.chat(messages=messages)
            t3 = time.time()
            latencies_llm.append((t3 - t2) * 1000)

            # Accept both an object with ``.content`` and an OpenAI-style dict.
            if hasattr(resp, 'content'):
                pred = resp.content.strip()
            elif isinstance(resp, dict):
                pred = resp["choices"][0]["message"]["content"].strip()
            else:
                pred = "Unknown"

            f1_val = common_f1(str(pred), ref_str)
            b1_val = bleu1(str(pred), ref_str)
            j_val = jaccard(str(pred), ref_str)

            f1s.append(f1_val)
            b1s.append(b1_val)
            jss.append(j_val)

            cat_counts[cat] = cat_counts.get(cat, 0) + 1
            cat_f1s.setdefault(cat, []).append(f1_val)
            cat_b1s.setdefault(cat, []).append(b1_val)
            cat_jss.setdefault(cat, []).append(j_val)

            # LoCoMo F1: category 1 uses the multi-answer F1, all others the
            # single-answer F1.
            if item.get("category") == 1:
                loc_val = loc_multi_f1(str(pred), ref_str)
            else:
                loc_val = loc_f1_score(str(pred), ref_str)
            loc_f1s.append(loc_val)
            cat_loc_f1s.setdefault(cat, []).append(loc_val)

            samples.append({
                "question": q,
                "answer": ref_str,
                "category": cat,
                "prediction": pred,
                "metrics": {
                    "f1": f1_val,
                    "b1": b1_val,
                    "j": j_val,
                    "loc_f1": loc_val
                },
                "retrieval": {
                    "retrieved_documents": len(contexts_all),
                    "context_length": len(context_text),
                    "search_limit": adjusted_limit,
                    "max_chars": max_chars
                },
                "timing": {
                    "search_ms": (t1 - t0) * 1000,
                    "llm_ms": (t3 - t2) * 1000
                }
            })

            print(f"🤖 LLM 回答: {pred}")
            print(f"✅ 正确答案: {ref_str}")
            print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}, LoCoMo F1: {loc_val:.3f}")

        # Per-category averages plus Jaccard dispersion (std and IQR).
        by_category: Dict[str, Dict[str, float | int]] = {}
        for c in cat_counts:
            j_list = cat_jss.get(c, [])
            j_sorted = sorted(j_list)
            j_iqr = _locomo_percentile(j_sorted, 0.75) - _locomo_percentile(j_sorted, 0.25)
            by_category[c] = {
                "count": cat_counts[c],
                "f1": _locomo_mean(cat_f1s.get(c, [])),
                "b1": _locomo_mean(cat_b1s.get(c, [])),
                "j": _locomo_mean(j_list),
                "j_std": statistics.stdev(j_list) if len(j_list) > 1 else 0.0,
                "j_iqr": j_iqr if j_list else 0.0,
                "loc_f1": _locomo_mean(cat_loc_f1s.get(c, [])),
            }

        # Sum (not mean) of LoCoMo F1 per category.
        cum_accuracy_by_category = {c: sum(cat_loc_f1s.get(c, [])) for c in cat_counts}

        result = {
            "dataset": "locomo",
            "items": len(items),
            "metrics": {
                "f1": _locomo_mean(f1s),
                "b1": _locomo_mean(b1s),
                "j": _locomo_mean(jss),
                "loc_f1": _locomo_mean(loc_f1s),
            },
            "by_category": by_category,
            "category_counts": cat_counts,
            "cum_accuracy_by_category": cum_accuracy_by_category,
            "context": {
                "avg_tokens": _locomo_mean(per_query_context_avg_tokens),
                "avg_chars": _locomo_mean(per_query_context_chars),
                "count_avg": _locomo_mean(per_query_context_counts),
                "avg_memory_tokens": _locomo_mean(per_query_context_tokens_total),
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "samples": samples,
            "params": {
                "end_user_id": end_user_id,
                "search_limit": search_limit,
                "context_char_budget": context_char_budget,
                "search_type": search_type,
                "llm_id": llm_id,
                "retrieval_embedding_id": embedding_id,
                "chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"),
                "skip_ingest_if_exists": skip_ingest_if_exists,
                "llm_timeout": llm_timeout,
                "llm_max_retries": llm_max_retries,
                "llm_temperature": llm_temperature,
                "llm_max_tokens": llm_max_tokens
            },
            "timestamp": datetime.now().isoformat()
        }

        if output_path:
            try:
                # A bare filename has an empty dirname, and os.makedirs("")
                # raises, so the directory is only created when there is one.
                out_dir = os.path.dirname(output_path)
                if out_dir:
                    os.makedirs(out_dir, exist_ok=True)
                with open(output_path, "w", encoding="utf-8") as f:
                    json.dump(result, f, ensure_ascii=False, indent=2)
                print(f"✅ 结果已保存到: {output_path}")
            except (OSError, TypeError, ValueError) as e:
                # Saving is best-effort; the report is still returned.
                print(f"❌ 保存结果失败: {e}")
        return result
    finally:
        await connector.close()
|
||||
|
||||
|
||||
def main():
    """Parse command-line options, run the LoCoMo evaluation and print a summary."""
    cli = argparse.ArgumentParser(description="Run LoCoMo evaluation with Qwen search")
    # One (flag, add_argument keyword arguments) pair per option, in help order.
    option_specs = [
        ("--sample_size", {"type": int, "default": 1, "help": "Number of samples to evaluate"}),
        ("--end_user_id", {"type": str, "default": None, "help": "Group ID for retrieval"}),
        ("--search_limit", {"type": int, "default": 8, "help": "Search limit per query"}),
        ("--context_char_budget", {"type": int, "default": 12000, "help": "Max characters for context"}),
        ("--llm_temperature", {"type": float, "default": 0.0, "help": "LLM temperature"}),
        ("--llm_max_tokens", {"type": int, "default": 32, "help": "LLM max tokens"}),
        ("--search_type", {
            "type": str,
            "default": "embedding",
            "choices": ["keyword", "embedding", "hybrid"],
            "help": "Search type",
        }),
        ("--output_path", {"type": str, "default": None, "help": "Output path for results"}),
        ("--skip_ingest_if_exists", {"action": "store_true", "help": "Skip ingest if group exists"}),
        ("--llm_timeout", {"type": float, "default": 10.0, "help": "LLM timeout in seconds"}),
        ("--llm_max_retries", {"type": int, "default": 1, "help": "LLM max retries"}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    options = cli.parse_args()

    load_dotenv()

    # Each parsed option name is also a keyword parameter of run_locomo_eval.
    result = asyncio.run(run_locomo_eval(**vars(options)))

    print("\n" + "=" * 50)
    print("📊 最终评测结果:")
    print(f" 样本数量: {result['items']}")
    overall = result['metrics']
    print(f" F1: {overall['f1']:.3f}")
    print(f" BLEU-1: {overall['b1']:.3f}")
    print(f" Jaccard: {overall['j']:.3f}")
    print(f" LoCoMo F1: {overall['loc_f1']:.3f}")
    print(f" 平均上下文长度: {result['context']['avg_chars']:.0f} 字符")
    latency = result['latency']
    print(f" 平均检索延迟: {latency['search']['mean']:.1f}ms")
    print(f" 平均LLM延迟: {latency['llm']['mean']:.1f}ms")

    per_category = result['by_category']
    if per_category:
        print("\n📈 按类别细分:")
        for label, stats in per_category.items():
            print(f" {label}:")
            print(f" 样本数: {stats['count']}")
            print(f" F1: {stats['f1']:.3f}")
            print(f" LoCoMo F1: {stats['loc_f1']:.3f}")
            print(f" Jaccard: {stats['j']:.3f} (±{stats['j_std']:.3f}, IQR={stats['j_iqr']:.3f})")
|
||||
|
||||
|
||||
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,559 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
# The file ".env.evaluation" is looked up in the parent of the directory that
# contains this script. When present, its values replace variables already set
# in the environment (override=True). This runs at import time, before the
# project imports below.
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
    """Select the contexts most relevant to *question* within a character budget.

    Each context is scored by how often the question's keywords occur in it
    (substring match, case-insensitive), with small bonuses for medium length
    and for being among the first three retrieved. Contexts are then taken in
    descending score order until the next one no longer fits. If that context
    scored above 10 and more than 200 characters of budget remain, only its
    keyword-bearing lines are kept, followed by a truncation marker.

    Args:
        contexts: Retrieved context strings; ``None`` or empty entries are skipped.
        question: The question used to derive keywords.
        max_chars: Upper bound on the length of the returned text, including
            the separators between contexts and the truncation marker.

    Returns:
        The selected contexts joined by blank lines; "" when nothing fits or
        *contexts* is empty.
    """
    if not contexts:
        return ""

    separator = "\n\n"
    marker = "\n[相关内容截断...]"
    stop_words = {
        'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were',
        'the', 'a', 'an', 'and', 'or', 'but'
    }
    question_words = {
        w for w in re.findall(r"\b\w+\b", (question or "").lower())
        if w not in stop_words and len(w) > 2
    }

    scored = []
    for i, ctx in enumerate(contexts):
        text = ctx or ""
        if not text:
            # An empty entry carries no information and would only add separators.
            continue
        lowered = text.lower()
        score = sum(lowered.count(w) * 2 for w in question_words)
        length = len(text)
        if 100 < length < 2000:
            score += 5
        elif length >= 2000:
            score += 2
        if i < 3:
            # Favour the top-ranked retrieval hits.
            score += 3
        scored.append((score, text))

    # list.sort is stable, so equally scored contexts keep retrieval order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    selected: List[str] = []
    used = 0  # exact length of separator.join(selected)
    for score, text in scored:
        sep_cost = len(separator) if selected else 0
        if used + sep_cost + len(text) <= max_chars:
            selected.append(text)
            used += sep_cost + len(text)
            continue

        # First context that does not fit: optionally keep its relevant lines,
        # then stop adding contexts.
        if score > 10 and used < max_chars - 200:
            room = max_chars - used - sep_cost - len(marker)
            kept_lines: List[str] = []
            size = 0
            for line in text.split('\n'):
                if not any(w in line.lower() for w in question_words):
                    continue
                # +1 for the newline that joins this line to the previous one.
                cost = len(line) + (1 if kept_lines else 0)
                if size + cost > room:
                    continue
                kept_lines.append(line)
                size += cost
            truncated = '\n'.join(kept_lines)
            if len(truncated) > 50:
                selected.append(truncated + marker)
        break

    return separator.join(selected)
|
||||
|
||||
|
||||
def extract_question_keywords(question: str, max_keywords: int = 8) -> List[str]:
    """Extract up to *max_keywords* distinct keywords from a question.

    The question is lower-cased and split into words (letters, digits,
    underscores and hyphens); stop words and words shorter than three
    characters are dropped. The order of first occurrence is preserved.

    Args:
        question: The question text; ``None`` or "" yields no keywords.
        max_keywords: Maximum number of keywords to return; zero or a
            negative value yields an empty list.

    Returns:
        The keywords in order of first appearance.
    """
    if max_keywords <= 0:
        return []
    ql = (question or "").lower()
    stop_words = {
        'what', 'when', 'where', 'who', 'why', 'how', 'did', 'do', 'does', 'is', 'are', 'was', 'were',
        'the', 'a', 'an', 'and', 'or', 'but', 'of', 'to', 'in', 'on', 'for', 'with', 'from', 'that', 'this'
    }
    words = re.findall(r"\b[\w-]+\b", ql)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique = dict.fromkeys(w for w in words if w not in stop_words and len(w) >= 3)
    return list(unique)[:max_keywords]
|
||||
|
||||
|
||||
def analyze_contexts_simple(contexts: List[str], keywords: List[str], top_n: int = 5) -> List[Dict[str, int | float]]:
|
||||
"""对上下文进行简单相关性打分,仅用于控制台可视化。
|
||||
|
||||
评分: score = match_count*200 + min(len(text), 100000)/100
|
||||
"""
|
||||
results = []
|
||||
for ctx in contexts:
|
||||
tl = (ctx or "").lower()
|
||||
match_count = sum(1 for k in keywords if k in tl)
|
||||
length = len(ctx)
|
||||
score = match_count * 200 + min(length, 100000) / 100.0
|
||||
results.append({"score": float(f"{score:.0f}"), "match": match_count, "length": length})
|
||||
results.sort(key=lambda x: (x["score"], x["match"], x["length"]), reverse=True)
|
||||
return results[:max(top_n, 0)]
|
||||
|
||||
|
||||
# 纯测试脚本不进行摄入;若需摄入请使用 evaluate_qa.py
|
||||
|
||||
|
||||
def load_dataset_memsciqa(data_path: str) -> List[Dict[str, Any]]:
    """Read a JSON Lines dataset file into a list.

    Blank lines are ignored, and lines that fail to parse as JSON are skipped
    without stopping the load.

    Args:
        data_path: Path of the ``.jsonl`` file.

    Returns:
        The parsed value of every valid line, in file order.

    Raises:
        FileNotFoundError: If *data_path* does not exist.
    """
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"未找到数据集: {data_path}")

    records: List[Dict[str, Any]] = []
    with open(data_path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            payload = raw_line.strip()
            if payload:
                try:
                    parsed = json.loads(payload)
                except Exception:
                    # Skip a malformed line but keep loading the rest.
                    pass
                else:
                    records.append(parsed)
    return records
|
||||
|
||||
|
||||
def _memsci_dedupe(items, key_fn):
    """Return ``items`` without duplicates (as judged by ``key_fn``), keeping first-seen order."""
    seen = set()
    unique = []
    for item in items:
        key = key_fn(item)
        if key not in seen:
            seen.add(key)
            unique.append(item)
    return unique


def _memsci_split_results(results: Dict[str, Any], search_type: str):
    """Split a raw search result dict into ``(dialogues, statements, entities)`` lists.

    For ``hybrid`` searches whose result carries ``embedding_search`` /
    ``keyword_search`` sub-dicts, both halves are merged and de-duplicated.
    Otherwise the flat ``dialogues`` / ``statements`` / ``entities`` keys are read.
    """
    emb = results.get("embedding_search")
    kw = results.get("keyword_search")
    if search_type == "hybrid" and (isinstance(emb, dict) or isinstance(kw, dict)):
        emb = emb if isinstance(emb, dict) else {}
        kw = kw if isinstance(kw, dict) else {}
        dialogues = _memsci_dedupe(
            list(emb.get("dialogues") or []) + list(kw.get("dialogues") or []),
            lambda d: (str(d.get("uuid", "")), str(d.get("content", ""))),
        )
        statements = _memsci_dedupe(
            list(emb.get("statements") or []) + list(kw.get("statements") or []),
            lambda s: str(s.get("statement", "")),
        )
        entities = _memsci_dedupe(
            list(emb.get("entities") or []) + list(kw.get("entities") or []),
            lambda e: str(e.get("name", "")),
        )
        return dialogues, statements, entities

    # Flat result: single-mode searches, and hybrid runs, which this script
    # issues through the embedding search call. Previously hybrid always
    # looked for the nested layout and silently ended up with no context.
    return (
        list(results.get("dialogues") or []),
        list(results.get("statements") or []),
        list(results.get("entities") or []),
    )


def _memsci_entity_summary(entities: List[Dict[str, Any]]) -> str:
    """Render up to three entities (highest ``score`` first when scores exist) as summary lines."""
    scored = [e for e in entities if e.get("score") is not None]
    if scored:
        top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3]
    else:
        top_entities = entities[:3]

    summary_lines = []
    for e in top_entities:
        name = str(e.get("name", "")).strip()
        if not name:
            continue
        etype = str(e.get("entity_type", "")).strip()
        score = e.get("score")
        meta = []
        if etype:
            meta.append(f"type={etype}")
        if isinstance(score, (int, float)):
            meta.append(f"score={score:.3f}")
        summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
    return "\n".join(summary_lines)


def _memsci_parse_prediction(resp: Any) -> str:
    """Extract the answer text from the response shapes the LLM client may return."""
    if hasattr(resp, 'content'):
        return resp.content.strip()
    if isinstance(resp, dict) and "choices" in resp and len(resp["choices"]) > 0:
        return resp["choices"][0]["message"]["content"].strip()
    if isinstance(resp, dict) and "content" in resp:
        return resp["content"].strip()
    if isinstance(resp, str):
        return resp.strip()
    print(f"⚠️ LLM响应格式异常: {type(resp)} - {resp}")
    return "Unknown"


def _memsci_print_retrieval_debug(question: str, reference: Any, contexts_all: List[str], retrieved_counts: Dict[str, int]) -> None:
    """Print retrieval diagnostics (counts, keywords, relevance scores, previews, answer hits)."""
    if retrieved_counts:
        print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
    print(f"📊 有效上下文数量: {len(contexts_all)}")
    q_keywords = extract_question_keywords(question, max_keywords=8)
    if q_keywords:
        print(f"🔍 问题关键词: {set(q_keywords)}")
    if not contexts_all:
        return

    analysis = analyze_contexts_simple(contexts_all, q_keywords, top_n=5)
    if analysis:
        print("📊 上下文相关性分析:")
        for a in analysis:
            print(f" - 得分: {int(a['score'])}, 关键词匹配: {a['match']}, 长度: {a['length']}")

    # Preview the retrieved contexts to help explain "Unknown" answers.
    print("🔎 上下文预览(最多前10条,每条截断展示):")
    for i, ctx in enumerate(contexts_all[:10]):
        preview = str(ctx).replace("\n", " ")
        if len(preview) > 300:
            preview = preview[:300] + "..."
        print(f" [{i+1}] 长度: {len(ctx)} | 片段: {preview}")

    # Report which contexts literally contain the reference answer.
    ref_lower = (str(reference) or "").lower()
    if ref_lower:
        hits = [i + 1 for i, ctx in enumerate(contexts_all) if ref_lower in str(ctx).lower()]
        print(f"🔗 参考答案命中上下文条数: {len(hits)}" + (f" | 命中索引: {hits}" if hits else ""))


async def run_memsciqa_test(
    sample_size: int = 3,
    end_user_id: str | None = None,
    search_limit: int = 8,
    context_char_budget: int = 4000,
    llm_temperature: float = 0.0,
    llm_max_tokens: int = 64,
    search_type: str = "embedding",
    data_path: str | None = None,
    start_index: int = 0,
    verbose: bool = True,
) -> Dict[str, Any]:
    """Run the memsciqa QA benchmark against data already stored in the graph.

    For every selected sample the function retrieves context from the graph,
    picks the most relevant text within ``context_char_budget``, asks the LLM
    for an answer and scores it (exact match, F1, BLEU-1, Jaccard). This
    script performs no ingestion.

    Args:
        sample_size: Number of samples to evaluate; ``None`` or ``<= 0``
            evaluates everything from ``start_index`` onward.
        end_user_id: Graph group to search; defaults to ``"group_memsci"``.
        search_limit: Result limit passed to the search call.
        context_char_budget: Character budget for the context sent to the LLM.
        llm_temperature: Recorded in the result ``params`` only; not applied
            to the LLM call in this function.
        llm_max_tokens: Recorded in the result ``params`` only; not applied
            to the LLM call in this function.
        search_type: ``"embedding"``, ``"keyword"`` or ``"hybrid"``. Hybrid
            runs use the embedding search call.
        data_path: Dataset path; defaults to
            ``<parent of this file's directory>/dataset/msc_self_instruct.jsonl``.
        start_index: Index of the first sample to evaluate.
        verbose: Print per-sample progress and diagnostics.

    Returns:
        A dict with aggregated ``metrics``, ``context`` statistics,
        ``latency`` statistics, per-sample ``samples``, the run ``params``
        and a ``timestamp``.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    # Default to the dedicated memsci group id.
    end_user_id = end_user_id or "group_memsci"

    # Resolve the dataset path. dataset_dir is always bound so the error
    # message below works for custom paths as well.
    dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
    if not data_path:
        data_path = str(dataset_dir / "msc_self_instruct.jsonl")
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"数据集文件不存在: {data_path}\n"
            f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
        )

    all_items = load_dataset_memsciqa(data_path)
    if sample_size is None or sample_size <= 0:
        items = all_items[start_index:]
    else:
        items = all_items[start_index:start_index + sample_size]

    # LLM client (pure test run: no ingestion).
    llm = get_llm_client(os.getenv("EVAL_LLM_ID"))

    # Per-query accumulators.
    latencies_llm: List[float] = []
    latencies_search: List[float] = []
    contexts_used: List[str] = []  # full context text, used for token statistics
    per_query_context_chars: List[int] = []
    per_query_context_counts: List[int] = []
    correct_flags: List[float] = []
    f1s: List[float] = []
    b1s: List[float] = []
    jss: List[float] = []
    samples: List[Dict[str, Any]] = []

    connector = Neo4jConnector()
    try:
        embedder = None
        if search_type in ("embedding", "hybrid"):
            cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
            embedder = OpenAIEmbedderClient(
                model_config=RedBearModelConfig.model_validate(cfg_dict)
            )

        total_items = len(items)
        for idx, item in enumerate(items):
            if verbose:
                print(f"\n🧪 评估样本: {idx+1}/{total_items}")
            question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
            reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")

            # --- Retrieval ---------------------------------------------------
            # NOTE(review): "chunks" is requested but only "dialogues" is read
            # from the result below; confirm the result schema of the search
            # helpers.
            t0 = time.time()
            results = None
            try:
                if search_type in ("embedding", "hybrid"):
                    results = await search_graph_by_embedding(
                        connector=connector,
                        embedder_client=embedder,
                        query_text=question,
                        end_user_id=end_user_id,
                        limit=search_limit,
                        include=["chunks", "statements", "entities", "summaries"],
                    )
                elif search_type == "keyword":
                    results = await search_graph(
                        connector=connector,
                        q=question,
                        end_user_id=end_user_id,
                        limit=search_limit,
                        include=["chunks", "statements", "entities", "summaries"],
                    )
            except Exception as e:
                # Keep evaluating with an empty context, but say so: a silent
                # failure here would just look like bad retrieval quality.
                results = None
                print(f"⚠️ 检索异常: {e}")
            t1 = time.time()
            search_ms = (t1 - t0) * 1000
            latencies_search.append(search_ms)

            # --- Context construction ---------------------------------------
            contexts_all: List[str] = []
            retrieved_counts: Dict[str, int] = {}
            if results:
                dialogues, statements, entities = _memsci_split_results(results, search_type)
                retrieved_counts = {
                    "dialogues": len(dialogues),
                    "statements": len(statements),
                    "entities": len(entities),
                }
                for d in dialogues:
                    text = str(d.get("content", "")).strip()
                    if text:
                        contexts_all.append(text)
                for s in statements:
                    text = str(s.get("statement", "")).strip()
                    if text:
                        contexts_all.append(text)
                if entities:
                    summary = _memsci_entity_summary(entities)
                    if summary:
                        contexts_all.append(summary)

            if verbose:
                _memsci_print_retrieval_debug(question, reference, contexts_all, retrieved_counts)

            context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
            if not context_text:
                context_text = "No relevant context found."
            contexts_used.append(context_text)
            per_query_context_chars.append(len(context_text))
            per_query_context_counts.append(len(contexts_all))

            if verbose:
                selected_count = (context_text.count("\n\n") + 1) if context_text else 0
                print(f"✅ 智能选择: {selected_count}个上下文, 总长度: {len(context_text)}字符")
                # Show the concatenated context so the answer can be checked by eye.
                concat_preview = context_text.replace("\n", " ")
                if len(concat_preview) > 600:
                    concat_preview = concat_preview[:600] + "..."
                print(f"🧵 拼接上下文预览: {concat_preview}")

            # --- LLM answer --------------------------------------------------
            messages = [
                {
                    "role": "system",
                    "content": (
                        "You are a QA assistant. Answer in English. Follow these guidelines:\n"
                        "1) If the context contains information to answer the question, provide a concise answer based on the context;\n"
                        "2) If the context does not contain enough information to answer the question, respond with 'Unknown';\n"
                        "3) Keep your answer brief and to the point;\n"
                        "4) Do not add explanations or additional text beyond the answer."
                    ),
                },
                {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
            ]

            t2 = time.time()
            try:
                resp = await llm.chat(messages=messages)
                pred = _memsci_parse_prediction(resp)

                # If the model says "Unknown" although the reference literally
                # appears in the retrieved context, the prompt is the suspect.
                if pred.lower() in ["unknown", ""]:
                    ref_lower = (str(reference) or "").lower()
                    if ref_lower and any(ref_lower in ctx.lower() for ctx in contexts_all):
                        print("⚠️ 参考答案在上下文中存在但LLM返回Unknown,检查提示词")
            except Exception as e:
                pred = "Unknown"
                print(f"⚠️ LLM调用异常: {e}")
            t3 = time.time()
            llm_ms = (t3 - t2) * 1000
            latencies_llm.append(llm_ms)

            # --- Scoring -----------------------------------------------------
            exact = exact_match(pred, reference)
            correct_flags.append(exact)
            f1_val = f1_score(str(pred), str(reference))
            b1_val = bleu1(str(pred), str(reference))
            j_val = jaccard(str(pred), str(reference))
            f1s.append(f1_val)
            b1s.append(b1_val)
            jss.append(j_val)

            if verbose:
                print(f"🤖 LLM 回答: {pred}")
                print(f"✅ 正确答案: {reference}")
                print(f"📈 当前指标 - F1: {f1_val:.3f}, BLEU-1: {b1_val:.3f}, Jaccard: {j_val:.3f}")
                print(f"⏱️ 延迟 - 检索: {search_ms:.0f}ms, LLM: {llm_ms:.0f}ms")

            samples.append({
                "question": str(question),
                "answer": str(reference),
                "prediction": str(pred),
                "metrics": {
                    "f1": f1_val,
                    "b1": b1_val,
                    "j": j_val
                },
                "retrieval": {
                    "retrieved_documents": len(contexts_all),
                    "context_length": len(context_text),
                    "search_limit": search_limit,
                    "max_chars": context_char_budget
                },
                "timing": {
                    "search_ms": search_ms,
                    "llm_ms": llm_ms
                }
            })

        # --- Aggregation -------------------------------------------------------
        acc = sum(correct_flags) / max(len(correct_flags), 1)
        ctx_avg_tokens = avg_context_tokens(contexts_used)
        result = {
            "dataset": "memsciqa",
            "items": len(items),
            "metrics": {
                # Exact-match accuracy was computed before but never reported.
                "accuracy": acc,
                "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
                "b1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
                "j": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
            },
            "context": {
                "avg_tokens": ctx_avg_tokens,
                "avg_chars": (sum(per_query_context_chars) / max(len(per_query_context_chars), 1)) if per_query_context_chars else 0.0,
                "count_avg": (sum(per_query_context_counts) / max(len(per_query_context_counts), 1)) if per_query_context_counts else 0.0,
                "avg_memory_tokens": 0.0
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "samples": samples,
            "params": {
                "end_user_id": end_user_id,
                "search_limit": search_limit,
                "context_char_budget": context_char_budget,
                "llm_temperature": llm_temperature,
                "llm_max_tokens": llm_max_tokens,
                "search_type": search_type,
                "start_index": start_index,
                "llm_id": os.getenv("EVAL_LLM_ID"),
                "retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID")
            },
            "timestamp": datetime.now().isoformat(),
        }
    finally:
        # Always release the connection, including on errors mid-run; closing
        # stays best-effort so it cannot mask the real outcome.
        try:
            await connector.close()
        except Exception as e:
            print(f"⚠️ 关闭 Neo4j 连接失败: {e}")
    return result
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the memsciqa test script.

    Parses the CLI options, runs ``run_memsciqa_test``, prints the result as
    JSON and saves it to ``--output`` (or to a timestamped file in a
    ``results`` directory next to this script). A failure to save is reported
    but does not raise.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
    parser.add_argument("--sample-size", type=int, default=10, help="样本数量(<=0 表示全部)")
    parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
    parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
    parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
    parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
    parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
    parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
    parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大输出 token")
    parser.add_argument("--search-type", type=str, default="embedding", choices=["embedding","keyword","hybrid"], help="检索类型(hybrid 等同于 embedding)")
    # The default dataset lives in the "dataset" directory, not "data".
    parser.add_argument("--data-path", type=str, default=None, help="数据集路径(默认 dataset/msc_self_instruct.jsonl)")
    parser.add_argument("--output", type=str, default=None, help="将评估结果保存到指定文件路径(JSON)")
    parser.add_argument("--verbose", action="store_true", default=True, help="打印过程日志(默认开启)")
    parser.add_argument("--quiet", action="store_true", help="关闭过程日志")
    args = parser.parse_args()

    # --all overrides --sample-size; 0 means "evaluate everything".
    sample_size = 0 if args.all else args.sample_size
    # --quiet wins over --verbose (which is on by default).
    verbose_flag = False if args.quiet else args.verbose

    result = asyncio.run(
        run_memsciqa_test(
            sample_size=sample_size,
            # The option is declared as --group-id, so argparse stores it as
            # ``group_id``; ``args.end_user_id`` does not exist.
            end_user_id=args.group_id,
            search_limit=args.search_limit,
            context_char_budget=args.context_char_budget,
            llm_temperature=args.llm_temperature,
            llm_max_tokens=args.llm_max_tokens,
            search_type=args.search_type,
            data_path=args.data_path,
            start_index=args.start_index,
            verbose=verbose_flag,
        )
    )

    print(json.dumps(result, ensure_ascii=False, indent=2))

    # Persist the result.
    out_path = args.output
    if not out_path:
        eval_dir = os.path.dirname(os.path.abspath(__file__))
        dataset_results_dir = os.path.join(eval_dir, "results")
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_path = os.path.join(dataset_results_dir, f"memsciqa_{result['params']['search_type']}_{ts}.json")
    try:
        out_dir = os.path.dirname(out_path)
        # A bare filename has an empty dirname and os.makedirs("") raises,
        # so only create the directory when there is one.
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"\n💾 结果已保存: {out_path}")
    except Exception as e:
        print(f"⚠️ 结果保存失败: {e}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,369 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
    """Pick the contexts most relevant to a question and join them within a budget.

    Each non-empty context is scored as follows:

    * +2 for every occurrence of a question keyword (question words minus
      stop words, longer than two characters);
    * +5 when its length is between 100 and 2000 characters (exclusive), or
      +2 when it is 2000 characters or longer;
    * +3 when it is one of the first three entries of ``contexts``.

    Contexts are then added in descending score order (ties keep input order)
    while their summed length stays within ``max_chars``. The first context
    that does not fit ends the selection; if it scored above 10 and at least
    200 characters of budget remain, its keyword-bearing lines are appended
    as a truncated excerpt first.

    Args:
        contexts: Candidate context strings; ``None`` and empty entries are
            ignored.
        question: The question used to derive keywords.
        max_chars: Character budget for the selected text. Separators and the
            truncation marker are not counted.

    Returns:
        The selected contexts joined by blank lines, or ``""`` when there is
        nothing to select.
    """
    if not contexts:
        return ""
    import re

    # Question keywords with stop words and very short tokens removed.
    question_lower = (question or "").lower()
    stop_words = {
        'what','when','where','who','why','how','did','do','does','is','are','was','were',
        'the','a','an','and','or','but'
    }
    question_words = {
        w for w in re.findall(r"\b\w+\b", question_lower)
        if w not in stop_words and len(w) > 2
    }

    # Score every usable context; ``i`` is the position in the original list
    # so the "first three" bonus is unaffected by skipped entries.
    scored = []
    for i, ctx in enumerate(contexts):
        if not ctx:
            # None used to crash on len(); empty text adds nothing useful.
            continue
        ctx_lower = ctx.lower()
        score = sum(ctx_lower.count(w) * 2 for w in question_words)
        length = len(ctx)
        if 100 < length < 2000:
            score += 5
        elif length >= 2000:
            score += 2
        if i < 3:
            score += 3
        scored.append((score, ctx))

    # list.sort is stable, so equal scores keep their input order.
    scored.sort(key=lambda pair: pair[0], reverse=True)

    selected: List[str] = []
    total = 0
    for score, ctx in scored:
        if total + len(ctx) <= max_chars:
            selected.append(ctx)
            total += len(ctx)
            continue

        # First context that does not fit: optionally keep an excerpt, then stop.
        if score > 10 and total < max_chars - 200:
            remaining = max_chars - total
            rel_lines: List[str] = []
            cur = 0
            for line in ctx.split('\n'):
                if cur < remaining - 50 and any(w in line.lower() for w in question_words):
                    rel_lines.append(line)
                    cur += len(line)
            # A single long line could overshoot the budget; clip the excerpt
            # to what is actually left.
            truncated = '\n'.join(rel_lines)[:remaining]
            if len(truncated) > 50:
                selected.append(truncated + "\n[相关内容截断...]")
                total += len(truncated)
        break

    return "\n\n".join(selected)
|
||||
|
||||
|
||||
def build_context_from_dialog(dialog_obj: Dict[str, Any]) -> str:
    """Render the ``dialog`` turns of a dataset item as ``speaker: text`` lines.

    Turns whose ``text`` is missing or empty are left out. The remaining
    turns are joined with newlines in their original order.
    """
    rendered_turns = (
        f"{turn.get('speaker', '')}: {turn.get('text', '')}"
        for turn in dialog_obj.get("dialog", [])
        if turn.get("text", "")
    )
    return "\n".join(rendered_turns)
|
||||
|
||||
|
||||
def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Merge dialogue hits from a search result, embedding hits first.

    Embedding hits come from ``results["embedding_search"]["dialogues"]`` when
    that sub-result is a dict, otherwise from a top-level ``dialogues`` list.
    Keyword hits come from ``results["keyword_search"]["dialogues"]``.
    Duplicates, identified by the ``(uuid, content)`` pair, are dropped in
    favour of the first occurrence. ``None`` yields an empty list.
    """
    if results is None:
        return []

    embedding_part = results.get("embedding_search")
    if isinstance(embedding_part, dict):
        primary = embedding_part.get("dialogues", []) or []
    elif isinstance(results.get("dialogues"), list):
        primary = results.get("dialogues", []) or []
    else:
        primary = []

    keyword_part = results.get("keyword_search")
    if isinstance(keyword_part, dict):
        secondary = keyword_part.get("dialogues", []) or []
    else:
        secondary = []

    known_keys = set()
    combined: List[Dict[str, Any]] = []
    # Walk the embedding hits first so they win over keyword duplicates.
    for source in (primary, secondary):
        for dialogue in source:
            identity = (str(dialogue.get("uuid", "")), str(dialogue.get("content", "")))
            if identity in known_keys:
                continue
            known_keys.add(identity)
            combined.append(dialogue)
    return combined
|
||||
|
||||
|
||||
|
||||
def _memsci_add_unique(texts, contexts: List[str], seen: set, limit: int) -> None:
    """Append unseen, non-empty ``texts`` to ``contexts``; stop once an append brings it to ``limit``."""
    for text in texts:
        if text and text not in seen:
            contexts.append(text)
            seen.add(text)
            if len(contexts) >= limit:
                break


def _memsci_eval_contexts(results: Any, search_type: str, search_limit: int) -> List[str]:
    """Turn a search result into context strings: dialogues, statements, then an entity summary."""
    contexts: List[str] = []
    if not results:
        return contexts

    if search_type == "hybrid":
        emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
        kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
        all_dialogs = emb.get("dialogues", []) + kw.get("dialogues", [])
        all_statements = emb.get("statements", []) + kw.get("statements", [])
        all_entities = emb.get("entities", []) + kw.get("entities", [])

        # Simple de-duplication by text; dialogues and statements share one limit.
        seen_texts: set = set()
        _memsci_add_unique(
            (str(d.get("content", "")).strip() for d in all_dialogs),
            contexts, seen_texts, search_limit,
        )
        _memsci_add_unique(
            (str(s.get("statement", "")).strip() for s in all_statements),
            contexts, seen_texts, search_limit,
        )

        # Entity summary: at most three distinct names.
        names: List[str] = []
        for e in all_entities:
            name = str(e.get("name", "")).strip()
            if name and name not in names:
                names.append(name)
                if len(names) >= 3:
                    break
    else:
        for d in results.get("dialogues", []):
            text = str(d.get("content", "")).strip()
            if text:
                contexts.append(text)
        for s in results.get("statements", []):
            text = str(s.get("statement", "")).strip()
            if text:
                contexts.append(text)
        names = [str(e.get("name", "")).strip() for e in results.get("entities", [])[:3] if e.get("name")]

    if names:
        contexts.append("EntitySummary: " + ", ".join(names))
    return contexts


async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None, skip_ingest: bool = False) -> Dict[str, Any]:
    """Ingest memsciqa samples into the graph and evaluate QA over them.

    The first ``sample_size`` records of ``dataset/msc_self_instruct.jsonl``
    are loaded. Unless ``skip_ingest`` is set, each record's dialogue is
    ingested as a single context. Every question is then answered from
    retrieved context and scored with exact match, F1, BLEU-1 and Jaccard.

    Args:
        sample_size: Number of non-blank dataset lines to evaluate.
        end_user_id: Graph group to ingest into and search. When omitted, the
            ``EVAL_END_USER_ID`` then ``EVAL_GROUP_ID`` environment variables
            are used, falling back to ``"benchmark_default"``.
        search_limit: Result limit passed to the search call; in hybrid mode
            it also caps the merged dialogue/statement contexts.
        context_char_budget: Character budget for the context sent to the LLM.
        llm_temperature: Accepted for interface compatibility; not used by
            this function.
        llm_max_tokens: Accepted for interface compatibility; not used by
            this function.
        search_type: ``"keyword"``, ``"embedding"`` or ``"hybrid"``.
        memory_config: Accepted for interface compatibility; not used by
            this function.
        skip_ingest: When true, skip ingestion and evaluate against data
            already present in the graph.

    Returns:
        A dict with ``dataset``, ``items``, aggregated ``metrics``,
        ``latency`` statistics and ``avg_context_tokens``.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard

    # SELECTED_GROUP_ID was never defined in this module; resolve the default
    # group from the evaluation environment instead.
    end_user_id = (
        end_user_id
        or os.getenv("EVAL_END_USER_ID")
        or os.getenv("EVAL_GROUP_ID")
        or "benchmark_default"
    )

    # Load data
    dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
    data_path = dataset_dir / "msc_self_instruct.jsonl"
    if not os.path.exists(data_path):
        raise FileNotFoundError(
            f"数据集文件不存在: {data_path}\n"
            f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
        )
    with open(data_path, "r", encoding="utf-8") as f:
        # Blank lines (e.g. a trailing newline) would make json.loads fail.
        raw_lines = [ln for ln in f if ln.strip()]
    items: List[Dict[str, Any]] = [json.loads(ln) for ln in raw_lines[:sample_size]]

    # One context (the full dialogue transcript) per sample: each memsciqa
    # sample has exactly one dialogue.
    if not skip_ingest:
        contexts: List[str] = [build_context_from_dialog(item) for item in items]
        await ingest_contexts_via_full_pipeline(contexts, end_user_id)

    # LLM client (called asynchronously below)
    from app.db import get_db

    db = next(get_db())
    try:
        llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
    finally:
        db.close()

    # Evaluate each item
    connector = Neo4jConnector()
    latencies_llm: List[float] = []
    latencies_search: List[float] = []
    contexts_used: List[str] = []
    correct_flags: List[float] = []
    f1s: List[float] = []
    b1s: List[float] = []
    jss: List[float] = []
    try:
        for item in items:
            question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
            reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")

            # Retrieval over dialogues, statements and entities.
            t0 = time.time()
            try:
                results = await run_hybrid_search(
                    query_text=question,
                    search_type=search_type,
                    end_user_id=end_user_id,
                    limit=search_limit,
                    include=["dialogues", "statements", "entities"],
                    output_path=None,
                )
            except Exception as e:
                # Continue with an empty context, but make the failure visible.
                results = None
                print(f"⚠️ 检索异常: {e}")
            t1 = time.time()
            latencies_search.append((t1 - t0) * 1000)

            # Build the context and trim it to the budget.
            contexts_all = _memsci_eval_contexts(results, search_type, search_limit)
            context_text = smart_context_selection(contexts_all, question, max_chars=context_char_budget) if contexts_all else ""
            if not context_text:
                context_text = "No relevant context found."
            # Keep the full text: token statistics computed on a 200-character
            # prefix would under-report the context actually sent to the LLM.
            contexts_used.append(context_text)

            # Call the LLM
            messages = [
                {"role": "system", "content": "You are a QA assistant. Answer in English. Strictly follow: 1) If the context contains the answer, copy the shortest exact span from the context as the answer; 2) If the answer cannot be determined from the context, respond with 'Unknown'; 3) Return ONLY the answer text, no explanations."},
                {"role": "user", "content": f"Question: {question}\n\nContext:\n{context_text}"},
            ]
            t2 = time.time()
            resp = await llm_client.chat(messages=messages)
            t3 = time.time()
            latencies_llm.append((t3 - t2) * 1000)
            pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())

            # Metrics: F1, BLEU-1, Jaccard; exact match kept for accuracy.
            correct_flags.append(exact_match(pred, reference))
            f1s.append(f1_score(str(pred), str(reference)))
            b1s.append(bleu1(str(pred), str(reference)))
            jss.append(jaccard(str(pred), str(reference)))

        # Aggregate metrics
        acc = sum(correct_flags) / max(len(correct_flags), 1)
        ctx_avg_tokens = avg_context_tokens(contexts_used)
        result = {
            "dataset": "memsciqa",
            "items": len(items),
            "metrics": {
                "accuracy": acc,
                "f1": (sum(f1s) / max(len(f1s), 1)) if f1s else 0.0,
                "bleu1": (sum(b1s) / max(len(b1s), 1)) if b1s else 0.0,
                "jaccard": (sum(jss) / max(len(jss), 1)) if jss else 0.0,
            },
            "latency": {
                "search": latency_stats(latencies_search),
                "llm": latency_stats(latencies_llm),
            },
            "avg_context_tokens": ctx_avg_tokens,
        }
        return result
    finally:
        await connector.close()
|
||||
|
||||
|
||||
def main():
    """Command-line entry point for the memsciqa evaluation.

    Option defaults can be supplied through ``MEMSCIQA_*`` environment
    variables (sample size, search limit, context budget, LLM max tokens,
    skip-ingest flag, output directory). The evaluation result is printed as
    JSON and written to a timestamped file in the output directory; a failure
    to save is reported but does not raise.
    """
    # Load environment variables first
    load_dotenv()

    # Get defaults from environment variables
    env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE")
    env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT")
    env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET")
    env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS")
    env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes")
    env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR")

    # Convert to appropriate types with fallback to code defaults
    default_sample_size = int(env_sample_size) if env_sample_size else 1
    default_search_limit = int(env_search_limit) if env_search_limit else 8
    default_context_budget = int(env_context_budget) if env_context_budget else 4000
    default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64
    default_output_dir = env_output_dir if env_output_dir else None

    parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")

    # The environment-derived defaults were computed but never handed to
    # argparse for these three options; wire them in.
    parser.add_argument("--sample-size", type=int, default=default_sample_size, help="评测样本数量")
    parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id,默认使用环境变量")
    parser.add_argument("--search-limit", type=int, default=default_search_limit, help="每类检索最大返回数")
    parser.add_argument("--context-char-budget", type=int, default=default_context_budget, help="上下文字符预算")

    parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
    parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens,
                        help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})")
    parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
    parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest,
                        help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})")
    parser.add_argument("--output-dir", type=str, default=default_output_dir,
                        help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})")
    args = parser.parse_args()

    result = asyncio.run(
        run_memsciqa_eval(
            sample_size=args.sample_size,
            end_user_id=args.end_user_id,
            search_limit=args.search_limit,
            context_char_budget=args.context_char_budget,
            llm_temperature=args.llm_temperature,
            llm_max_tokens=args.llm_max_tokens,
            search_type=args.search_type,
            skip_ingest=args.skip_ingest,
        )
    )

    # Print results to console
    print(json.dumps(result, ensure_ascii=False, indent=2))

    # Resolve the output directory: default to "results" next to this script,
    # and anchor relative paths at this script's directory.
    script_dir = Path(__file__).resolve().parent
    if args.output_dir is None:
        output_dir = script_dir / "results"
    elif not Path(args.output_dir).is_absolute():
        output_dir = script_dir / args.output_dir
    else:
        output_dir = Path(args.output_dir)

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = output_dir / f"memsciqa_{timestamp_str}.json"

    try:
        # Creating the directory is part of the best-effort save, so a
        # permission problem is reported instead of crashing after the run.
        output_dir.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        print(f"\n✅ 结果已保存到: {output_path}")
    except Exception as e:
        print(f"\n❌ 保存结果失败: {e}")


if __name__ == "__main__":
    main()
|
||||
@@ -1,147 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import run_locomo_eval
|
||||
|
||||
|
||||
async def run(
|
||||
dataset: str,
|
||||
sample_size: int,
|
||||
reset_group: bool,
|
||||
end_user_id: str | None,
|
||||
judge_model: str | None = None,
|
||||
search_limit: int | None = None,
|
||||
context_char_budget: int | None = None,
|
||||
llm_temperature: float | None = None,
|
||||
llm_max_tokens: int | None = None,
|
||||
search_type: str | None = None,
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# Use environment variable with fallback chain if not provided
|
||||
if end_user_id is None:
|
||||
end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default")
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
await connector.delete_group(end_user_id)
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
if dataset == "locomo":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_locomo_eval(**kwargs)
|
||||
|
||||
if dataset == "memsciqa":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
return await run_memsciqa_eval(**kwargs)
|
||||
|
||||
if dataset == "longmemeval":
|
||||
kwargs: Dict[str, Any] = {"sample_size": sample_size, "end_user_id": end_user_id}
|
||||
if search_limit is not None:
|
||||
kwargs["search_limit"] = search_limit
|
||||
if context_char_budget is not None:
|
||||
kwargs["context_char_budget"] = context_char_budget
|
||||
if llm_temperature is not None:
|
||||
kwargs["llm_temperature"] = llm_temperature
|
||||
if llm_max_tokens is not None:
|
||||
kwargs["llm_max_tokens"] = llm_max_tokens
|
||||
if search_type is not None:
|
||||
kwargs["search_type"] = search_type
|
||||
if start_index is not None:
|
||||
kwargs["start_index"] = start_index
|
||||
if max_contexts_per_item is not None:
|
||||
kwargs["max_contexts_per_item"] = max_contexts_per_item
|
||||
return await run_longmemeval_test(**kwargs)
|
||||
raise ValueError(f"未知数据集: {dataset}")
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: run one benchmark and save its result as JSON.

    Parses the command-line options, runs the selected dataset through
    ``run()``, prints the result to stdout as indented JSON, and writes the
    same JSON to a file. The file is ``--output`` when given; otherwise it is
    ``<this script's directory>/<dataset>/results/<dataset>_<sample_size>.json``.
    Missing parent directories are created.
    """
    load_dotenv()
    parser = argparse.ArgumentParser(description="统一评估入口:memsciqa / longmemeval / locomo")
    parser.add_argument("--dataset", choices=["memsciqa", "longmemeval", "locomo"], required=True)
    parser.add_argument("--sample-size", type=int, default=1, help="先用一条数据跑通")
    parser.add_argument("--reset-group", action="store_true", help="运行前清空当前 end_user_id 的图数据")
    # dest is set explicitly because the value is read as args.end_user_id
    # below; without it argparse stores this flag as args.group_id and the
    # lookup raises AttributeError.
    # NOTE(review): the help text mentions runtime.json, but when this flag
    # is omitted run() falls back to the EVAL_END_USER_ID environment
    # variable - confirm which is intended.
    parser.add_argument("--group-id", dest="end_user_id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
    parser.add_argument("--judge-model", type=str, default=None, help="可选:longmemeval 判别式评测模型名")
    parser.add_argument("--search-limit", type=int, default=None, help="检索返回的对话节点数量上限(不提供则使用各脚本默认)")
    parser.add_argument("--context-char-budget", type=int, default=None, help="上下文字符预算(不提供则使用各脚本默认)")
    parser.add_argument("--llm-temperature", type=float, default=None, help="生成温度(不提供则使用各脚本默认)")
    parser.add_argument("--llm-max-tokens", type=int, default=None, help="最大生成 tokens(不提供则使用各脚本默认)")
    parser.add_argument("--search-type", type=str, default=None, choices=["keyword", "embedding", "hybrid"], help="检索类型(不提供则使用各脚本默认)")
    # Only passed through to longmemeval; ignored for the other datasets.
    parser.add_argument("--start-index", type=int, default=None, help="仅 longmemeval:起始样本索引(不提供则用脚本默认)")
    parser.add_argument("--max-contexts-per-item", type=int, default=None, help="仅 longmemeval:每条样本摄入的上下文数量上限(不提供则用脚本默认)")
    parser.add_argument("--output", type=str, default=None, help="可选:将评估结果保存到指定文件路径(JSON);不提供时默认保存到 evaluation/<dataset>/results 目录")
    args = parser.parse_args()

    # Keyword arguments keep each CLI option tied to the parameter it feeds,
    # independent of the order of run()'s signature.
    result = asyncio.run(run(
        dataset=args.dataset,
        sample_size=args.sample_size,
        reset_group=args.reset_group,
        end_user_id=args.end_user_id,
        judge_model=args.judge_model,
        search_limit=args.search_limit,
        context_char_budget=args.context_char_budget,
        llm_temperature=args.llm_temperature,
        llm_max_tokens=args.llm_max_tokens,
        search_type=args.search_type,
        start_index=args.start_index,
        max_contexts_per_item=args.max_contexts_per_item,
    ))
    print(json.dumps(result, ensure_ascii=False, indent=2))

    # Choose where the JSON result is written.
    if args.output:
        out_path = args.output
    else:
        eval_dir = os.path.dirname(os.path.abspath(__file__))
        out_path = os.path.join(
            eval_dir,
            args.dataset,
            "results",
            f"{args.dataset}_{args.sample_size}.json",
        )

    # A bare file name has no directory component; only create one if present.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(result, f, ensure_ascii=False, indent=2)
    print(f"\n结果已保存到: {out_path}")
|
||||
|
||||
|
||||
# Run the unified evaluation CLI only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
|
||||
main()
|
||||
Submodule redbear-mem-benchmark updated: d9a00be62d...558c023dad
Reference in New Issue
Block a user