Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop

This commit is contained in:
Mark
2026-01-26 19:19:59 +08:00
37 changed files with 3072 additions and 1073 deletions

View File

@@ -0,0 +1,224 @@
# ============================================================================
# 基准测试统一配置文件示例
# ============================================================================
# 复制此文件为 .env.evaluation 并根据需要修改
# 支持的基准测试LoCoMo、LongMemEval、MemSciQA
# ============================================================================
# ============================================================================
# 通用配置(所有基准测试共用)
# ============================================================================
# ----------------------------------------------------------------------------
# Neo4j 配置
# ----------------------------------------------------------------------------
# 默认 Group ID建议各基准测试使用独立的 group
EVAL_GROUP_ID=benchmark_default
# ----------------------------------------------------------------------------
# 模型配置(必需)
# ----------------------------------------------------------------------------
# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID
#
# 如何获取模型 ID
# 1. 查询数据库SELECT id, model_name FROM models WHERE is_active = true;
# 2. 或通过系统管理界面查看
# 3. 确保模型可用且配置正确
# LLM 模型 ID必填
EVAL_LLM_ID=your_llm_model_id_here
# Embedding 模型 ID必填
EVAL_EMBEDDING_ID=your_embedding_model_id_here
# ----------------------------------------------------------------------------
# 检索参数
# ----------------------------------------------------------------------------
# 检索类型: "keyword", "embedding", "hybrid"
EVAL_SEARCH_TYPE=hybrid
# 检索结果数量限制(默认值)
EVAL_SEARCH_LIMIT=12
# 上下文最大字符数(默认值)
EVAL_MAX_CONTEXT_CHARS=8000
# ----------------------------------------------------------------------------
# LLM 参数
# ----------------------------------------------------------------------------
# LLM 温度参数0.0 = 确定性输出)
EVAL_LLM_TEMPERATURE=0.0
# LLM 最大生成 token 数
EVAL_LLM_MAX_TOKENS=32
# LLM 超时时间(秒)
EVAL_LLM_TIMEOUT=10.0
# LLM 最大重试次数
EVAL_LLM_MAX_RETRIES=1
# ----------------------------------------------------------------------------
# 数据处理参数
# ----------------------------------------------------------------------------
# Chunker 策略
EVAL_CHUNKER_STRATEGY=RecursiveChunker
# 是否在导入前清空现有数据
EVAL_RESET_ON_INGEST=true
# 是否保存详细日志
EVAL_SAVE_DETAILED_LOGS=true
# ============================================================================
# LoCoMo 基准测试专用配置
# ============================================================================
# 数据集locomo10.json
# 运行python locomo_benchmark.py --sample_size 20
# ----------------------------------------------------------------------------
# Group IDLoCoMo 专用)
LOCOMO_GROUP_ID=locomo_benchmark
# 测试样本数量
# 建议值20快速测试、100中等测试、1986完整测试
LOCOMO_SAMPLE_SIZE=20
# 检索结果数量限制
LOCOMO_SEARCH_LIMIT=12
# 上下文最大字符数
LOCOMO_CONTEXT_CHAR_BUDGET=8000
# 导入的对话数量
LOCOMO_MAX_DIALOGUES=1
# 跳过数据摄入true=跳过false=摄入)
# 首次运行设置为 false后续运行可设置为 true 以节省时间
LOCOMO_SKIP_INGEST=false
# 结果保存目录
LOCOMO_OUTPUT_DIR=locomo/results
# ============================================================================
# LongMemEval 基准测试专用配置
# ============================================================================
# 数据集longmemeval_oracle_zh.json
# 运行python longmemeval_benchmark.py --sample_size 3
# 特点:支持时间推理问题的增强检索
# ----------------------------------------------------------------------------
# Group IDLongMemEval 专用)
LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3
# 测试样本数量(<=0 表示全部样本)
LONGMEMEVAL_SAMPLE_SIZE=3
# 起始样本索引
LONGMEMEVAL_START_INDEX=0
# 检索结果数量限制
LONGMEMEVAL_SEARCH_LIMIT=8
# 上下文最大字符数
LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000
# LLM 最大生成 token 数
LONGMEMEVAL_LLM_MAX_TOKENS=16
# 每条样本最多摄入的上下文段数
LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2
# 是否保存分块结果
LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true
# 自定义分块输出路径(留空使用默认)
LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH=
# 摄入前是否清空组数据
LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false
# 是否跳过摄入,仅检索评估
LONGMEMEVAL_SKIP_INGEST=false
# 结果保存目录
LONGMEMEVAL_OUTPUT_DIR=longmemeval/results
# ============================================================================
# MemSciQA 基准测试专用配置
# ============================================================================
# 数据集msc_self_instruct.jsonl
# 运行python memsciqa_benchmark.py --sample_size 1
# 特点:对话记忆检索评估
# ----------------------------------------------------------------------------
# Group IDMemSciQA 专用,独立数据集)
MEMSCIQA_GROUP_ID=memsciqa_benchmark
# 测试样本数量
MEMSCIQA_SAMPLE_SIZE=1 # 0或者-1标识测试数据集中的所有样本
# 检索结果数量限制
MEMSCIQA_SEARCH_LIMIT=8
# 上下文最大字符数
MEMSCIQA_CONTEXT_CHAR_BUDGET=4000
# LLM 最大生成 token 数
MEMSCIQA_LLM_MAX_TOKENS=64
# 跳过数据摄入true=跳过false=摄入)
# 首次运行设置为 false后续运行可设置为 true 以节省时间
MEMSCIQA_SKIP_INGEST=false
# 结果保存目录(相对于 memsciqa 脚本所在目录)
# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/
MEMSCIQA_OUTPUT_DIR=results
# ============================================================================
# 高级配置(可选)
# ============================================================================
# BM25 权重用于混合检索0.0-1.0
EVAL_RERANK_ALPHA=0.6
# 是否使用遗忘重排序
EVAL_USE_FORGETTING_RERANK=false
# 是否使用 LLM 重排序
EVAL_USE_LLM_RERANK=false
# 连接重置间隔(每 N 个问题重置一次)
EVAL_RESET_INTERVAL=5
# 性能阈值(低于此值触发重置)
EVAL_PERFORMANCE_THRESHOLD=0.6
# ============================================================================
# 快速配置指南
# ============================================================================
# 1. 复制此文件为 .env.evaluation
# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID
# 3. 根据需要修改各基准测试的专用配置
# 4. 运行测试:
# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20
# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all
# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10
# 配置优先级:
# 命令行参数 > 特定配置(如 LOCOMO_*> 通用配置EVAL_*> 代码默认值
# ============================================================================
# 执行LoCoMo测试
# 只摄入前5条消息评估3个问题最小测试
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
#
# 如果数据已经摄入,跳过摄入阶段直接测试
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
# 执行longmemeval测试
# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest
# 执行memsciqa测试
# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1

View File

@@ -0,0 +1,13 @@
# 忽略实际的评估配置文件(包含敏感信息)
.env.evaluation
# 保留示例文件
!.env.evaluation.example
# 忽略测试结果文件
*/results/*.json
*/results/*.log
# 忽略数据集文件(文件过大,不应提交到 Git
dataset/*.json
dataset/*.jsonl

View File

@@ -1,30 +1,748 @@
数据集下载地址 # 1.数据集下载地址
Locomo10.jsonhttps://github.com/snap-research/locomo/tree/main/data Locomo10.json : https://github.com/snap-research/locomo/tree/main/data
LongMemEval_oracle.jsonhttps://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
上方数据集下载好后全部放入app/core/memory/data文件夹中
全流程基准测试运行: 数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下
locomo # 2.配置说明
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32 文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明
LongMemEval **实际配置文件**api\app\core\memory\evaluation\.env.evaluation
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group ```python
memsciqa # 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci python -m app.core.memory.evaluation.locomo.locomo_benchmark
```
**检查neo4j指定的grou_id是否摄入数据**
```python
# 1. 进入交互模式
python -m app.core.memory.evaluation.check_enduser_data
单独检索评估运行命令: # 2. 选择 "1" 检查指定 group
python -m app.core.memory.evaluation.locomo.locomo_test # 3. 输入 group_id例如: locomo_benchmark
python -m app.core.memory.evaluation.longmemeval.test_eval # 4. 选择是否显示详细统计 (y/n)
python -m app.core.memory.evaluation.memsciqa.memsciqa-test ```
需要先在项目中修改需要检测评估的group_id。 # 3.locomo
参数及解释: ### (1)locomo执行命令
● --dataset longmemeval - 指定数据集 ```python
● --sample-size 10 - 评估10个样本 # 首先进入api目录
● --start-index 0 - 从第0个样本开始 cd api
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
● --search-limit 8 - 检索限制8条 # 只摄入前5条消息评估3个问题最小测试
● --context-char-budget 4000 - 上下文字符预算4000 python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
● --search-type hybrid - 使用混合检索
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文 # 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
● --reset-group - 运行前清空组数据 python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
```
### (2)locomo结果说明
#### 结果示例
```json
{
"dataset": "locomo",
"sample_size": 0,
"timestamp": "2026-01-26T11:24:28.239156",
"params": {
"group_id": "locomo_benchmark",
"search_type": "hybrid",
"search_limit": 12,
"context_char_budget": 8000,
"llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4",
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2"
},
"overall_metrics": {
"f1": 0.0,
"bleu1": 0.0,
"jaccard": 0.0,
"locomo_f1": 0.0
},
"by_category": {},
"latency": {
"search": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
},
"llm": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
}
},
"context_stats": {
"avg_retrieved_docs": 0.0,
"avg_context_chars": 0.0,
"avg_context_tokens": 0.0
},
"samples": []
}
```
#### 参数详解
##### 1. 核心评估指标 (overall_metrics)
**🎯 关键进步指标:**
- **`f1`** (F1 Score): 精确率和召回率的调和平均值
- 范围0.0 - 1.0
- **越高越好**,衡量检索和生成答案的准确性
- 这是最重要的综合性能指标
- 优秀标准:> 0.85
- **`bleu1`** (BLEU-1): 单词级别的匹配度
- 范围0.0 - 1.0
- **越高越好**,衡量生成答案与标准答案的词汇重叠度
- 关注词汇层面的准确性
- **`jaccard`** (Jaccard 相似度): 集合相似度
- 范围0.0 - 1.0
- **越高越好**,衡量答案集合的相似性
- 计算公式:交集大小 / 并集大小
- **`locomo_f1`**: Locomo 特定的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,针对 Locomo 数据集优化的评估指标
- 考虑了长对话记忆的特殊性
##### 2. 性能指标 (latency)
**⚡ 关键效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均延迟
- `p50`: 中位数延迟50%的请求在此时间内完成)
- `p95`: 95分位数延迟95%的请求在此时间内完成)
- `iqr`: 四分位距Q3-Q1衡量稳定性
- **越低越好**,衡量记忆检索速度
- 优秀标准p95 < 2000ms
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距(越小越稳定)
- **越低越好**,衡量答案生成速度
- 优秀标准p95 < 3000ms
##### 3. 上下文统计 (context_stats)
**📊 资源效率指标:**
- **`avg_retrieved_docs`**: 平均检索文档数
- 反映检索策略的广度
- 需要平衡:太少可能信息不足,太多增加噪音和延迟
- 建议范围8-15 个文档
- **`avg_context_chars`**: 平均上下文字符数
- 反映检索内容的总量
- 应在满足准确性前提下尽量精简
- 受 `context_char_budget` 参数限制
- **`avg_context_tokens`**: 平均上下文 token 数
- **越低越好**(在保持准确性前提下)
- 直接影响 API 调用成本和推理速度
- 成本效益比 = f1 / avg_context_tokens
##### 4. 分类统计 (by_category)
- 按问题类型分类的性能指标
- 帮助识别系统在不同场景下的强弱项
- 可针对性优化特定类型的问题
#### 系统进步衡量标准
**一级指标(最重要):**
- `f1` 和 `locomo_f1` 提升 → 核心能力提升
- 目标f1 > 0.85
**二级指标(重要):**
- `latency.p95` 降低 → 用户体验提升
- 目标search.p95 < 2000ms, llm.p95 < 3000ms
**三级指标(辅助):**
- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化
- `iqr` 降低 → 性能稳定性提升
# 4.longmemeval
支持时间推理问题的增强检索
### (1)执行命令
```python
# 首先进入api目录
cd api
# 不带参数运行 - 使用环境变量
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark
# 命令行参数覆盖环境变量
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest
```
### (2)结果说明
#### 结果示例
```json
{
"dataset": "longmemeval",
"items": 1,
"accuracy_by_type": {
"single-session-user": 1.0
},
"f1_by_type": {
"single-session-user": 1.0
},
"jaccard_by_type": {
"single-session-user": 1.0
},
"samples": [
{
"question": "What degree did I graduate with?",
"prediction": "Business Administration",
"answer": "Business Administration",
"question_type": "single-session-user",
"is_temporal": false,
"question_id": "e47becba",
"options": [],
"context_count": 13,
"context_chars": 1268,
"retrieved_dialogue_count": 0,
"retrieved_statement_count": 12,
"metrics": {
"exact_match": true,
"f1": 1.0,
"jaccard": 1.0
},
"timing": {
"search_ms": 1483.100175857544,
"llm_ms": 995.8682060241699
}
}
],
"latency": {
"search": {
"mean": 1483.100175857544,
"p50": 1483.100175857544,
"p95": 1483.100175857544,
"iqr": 0.0
},
"llm": {
"mean": 995.8682060241699,
"p50": 995.8682060241699,
"p95": 995.8682060241699,
"iqr": 0.0
}
},
"context": {
"avg_tokens": 204.0,
"avg_chars": 1268,
"count_avg": 13
},
"params": {
"group_id": "longmemeval_zh_bak_3",
"search_limit": 8,
"context_char_budget": 4000,
"search_type": "hybrid",
"llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f",
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2",
"sample_size": 1,
"start_index": 0
},
"timestamp": "2026-01-24T21:36:10.818308",
"metric_summary": {
"score_accuracy": 100.0,
"latency_median_s": 2.478968381881714,
"latency_iqr_s": 0.0,
"avg_context_tokens_k": 0.204
},
"diagnostics": {
"duplicate_previews_top": [],
"unique_preview_count": 1
}
}
```
#### 参数详解
##### 1. 核心评估指标
**🎯 关键进步指标:**
- **`accuracy_by_type`**: 按问题类型分类的准确率
- 范围0.0 - 1.0
- **越高越好**1.0 表示 100% 准确
- 问题类型包括:
- `single-session-user`: 单会话用户信息
- `single-session-event`: 单会话事件信息
- `multi-session-user`: 多会话用户信息
- `multi-session-event`: 多会话事件信息
- 可以识别系统在不同场景下的强弱项
- **`f1_by_type`**: 按问题类型的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,综合评估精确率和召回率
- 比单纯的准确率更全面
- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度
- 范围0.0 - 1.0
- **越高越好**,衡量答案集合匹配度
- 对于集合类答案特别有用
##### 2. 样本级指标 (samples)
**详细诊断指标:**
- **`metrics.exact_match`**: 精确匹配(布尔值)
- **true 越多越好**,最严格的评估标准
- 要求预测答案与标准答案完全一致
- **`metrics.f1`**: 单个样本的 F1 分数
- 范围0.0 - 1.0
- **越高越好**,衡量单个问题的回答质量
- **`is_temporal`**: 是否为时间推理问题
- 布尔值,标识问题是否涉及时间推理
- 时间推理问题通常更具挑战性
- **`context_count`**: 检索到的上下文数量
- 反映检索策略的有效性
- 建议范围8-15 个上下文片段
- **`retrieved_dialogue_count`**: 检索到的对话数
- **`retrieved_statement_count`**: 检索到的陈述数
- 这两个指标帮助理解检索的内容类型分布
- 可用于优化检索策略
- **`timing.search_ms`**: 单个问题的检索延迟(毫秒)
- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒)
- **越低越好**,反映单次查询的响应速度
##### 3. 汇总指标 (metric_summary)
**📊 关键 KPI**
- **`score_accuracy`**: 总体准确率百分比
- 范围0.0 - 100.0
- **越高越好**,最直观的性能指标
- 优秀标准:> 90.0
- **`latency_median_s`**: 中位延迟(秒)
- **越低越好**,反映真实响应速度
- 优秀标准:< 3.0 秒
- **`latency_iqr_s`**: 延迟四分位距(秒)
- **越低越好**,反映性能稳定性
- 越小说明响应时间越稳定
- **`avg_context_tokens_k`**: 平均上下文 token 数(千)
- **越低越好**(在保持准确性前提下)
- 直接影响 API 调用成本
- 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000)
##### 4. 上下文统计 (context)
- **`avg_tokens`**: 平均 token 数
- **`avg_chars`**: 平均字符数
- **`count_avg`**: 平均上下文片段数
- 这些指标反映检索内容的规模
- 需要在准确性和效率之间平衡
##### 5. 性能指标 (latency)
**⚡ 效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均延迟
- `p50`: 中位数延迟
- `p95`: 95分位数延迟
- `iqr`: 四分位距
- **越低越好**,衡量记忆检索速度
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距
- **越低越好**,衡量答案生成速度
##### 6. 诊断信息 (diagnostics)
- **`duplicate_previews_top`**: 重复预览统计
- 列出出现频率最高的重复内容
- 帮助发现检索冗余问题
- 应该尽量减少重复
- **`unique_preview_count`**: 唯一预览数量
- 反映检索多样性
- **越高越好**,说明检索到的内容更丰富
#### 系统进步衡量标准
**一级指标(最重要):**
- `score_accuracy` 提升 → 核心能力提升
- 目标:> 90.0%
- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升
**二级指标(重要):**
- `latency_median_s` 降低 → 用户体验提升
- 目标:< 3.0 秒
- `exact_match` 比例提升 → 精确度提升
**三级指标(辅助):**
- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化
- `unique_preview_count` 提升 → 检索多样性提升
- `latency_iqr_s` 降低 → 性能稳定性提升
**特殊关注:**
- 时间推理问题(`is_temporal: true`)的准确率
- 多会话问题的准确率(通常更具挑战性)
# 5.memsciqa
对话记忆检索评估
### (1)执行命令
```python
# 首先进入api目录
cd api
# 不带参数运行 - 使用环境变量
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark
# 命令行参数覆盖环境变量
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest
```
### (2)结果说明
#### 结果示例
```json
{
"dataset": "memsciqa",
"items": 1,
"metrics": {
"accuracy": 0.0,
"f1": 0.0,
"bleu1": 0.0,
"jaccard": 0.0
},
"latency": {
"search": {
"mean": 0.0,
"p50": 0.0,
"p95": 0.0,
"iqr": 0.0
},
"llm": {
"mean": 3067.7285194396973,
"p50": 3067.7285194396973,
"p95": 3067.7285194396973,
"iqr": 0.0
}
},
"avg_context_tokens": 4.0
}
```
#### 参数详解
##### 1. 核心评估指标 (metrics)
**🎯 关键进步指标:**
- **`accuracy`**: 准确率
- 范围0.0 - 1.0
- **越高越好**,最直接的性能指标
- 衡量系统回答正确的问题比例
- 优秀标准:> 0.85
- **`f1`**: F1 分数
- 范围0.0 - 1.0
- **越高越好**,平衡精确率和召回率
- 计算公式2 * (precision * recall) / (precision + recall)
- 比单纯的准确率更全面,特别适合不平衡数据集
- **`bleu1`**: BLEU-1 分数
- 范围0.0 - 1.0
- **越高越好**,衡量词汇级别的匹配度
- 关注生成答案与标准答案的单词重叠
- 源自机器翻译评估,适用于自然语言生成
- **`jaccard`**: Jaccard 相似度
- 范围0.0 - 1.0
- **越高越好**,衡量集合相似性
- 计算公式:|A ∩ B| / |A B|
- 对于多答案或集合类问题特别有用
##### 2. 性能指标 (latency)
**⚡ 效率指标:**
- **`search`**: 检索延迟统计(单位:毫秒)
- `mean`: 平均检索延迟
- `p50`: 中位数延迟50%的请求在此时间内完成)
- `p95`: 95分位数延迟95%的请求在此时间内完成)
- `iqr`: 四分位距Q3-Q1衡量稳定性
- **越低越好**,衡量记忆检索效率
- 优秀标准p95 < 2000ms
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
- `mean`: 平均推理时间
- `p50`: 中位数推理时间
- `p95`: 95分位数推理时间
- `iqr`: 四分位距(越小越稳定)
- **越低越好**,衡量答案生成速度
- 优秀标准p95 < 3000ms
- 注意LLM 延迟通常占总延迟的大部分
##### 3. 资源指标
- **`avg_context_tokens`**: 平均上下文 token 数
- **越低越好**(在保持准确性前提下)
- 直接影响:
- API 调用成本(按 token 计费)
- 推理速度token 越多越慢)
- 上下文窗口占用
- 成本效益比 = accuracy / avg_context_tokens
- 建议范围:根据模型上下文窗口和成本预算调整
##### 4. 数据集特点
- **`items`**: 评估的问题数量
- 样本量越大,评估结果越可靠
- 建议至少 100 个样本以获得稳定的评估结果
- **对话记忆特性**
- MemSciQA 专注于对话历史中的记忆检索
- 评估系统从多轮对话中提取和回忆信息的能力
- 模拟真实的对话场景
#### 系统进步衡量标准
**一级指标(最重要):**
- `accuracy` 提升 → 核心能力提升
- 目标:> 0.85
- `f1` 提升 → 综合性能提升
- 目标:> 0.80
**二级指标(重要):**
- `latency.p95` 降低 → 用户体验提升
- search.p95 目标:< 2000ms
- llm.p95 目标:< 3000ms
- `iqr` 降低 → 性能稳定性提升
**三级指标(辅助):**
- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化
- `bleu1` 和 `jaccard` 提升 → 答案质量提升
**综合评估:**
- 成本效益比 = accuracy / avg_context_tokens
- 该比值越高,说明系统在相同成本下性能越好
- 总延迟 = search.p95 + llm.p95
- 应控制在 5 秒以内以保证良好的用户体验
#### 优化建议
**提升准确性:**
- 优化检索算法(调整 hybrid search 参数)
- 改进 embedding 模型质量
- 增加检索上下文数量(`search_limit`
- 优化 prompt 工程
**提升效率:**
- 减少不必要的检索文档
- 使用更快的 LLM 模型或量化版本
- 实施缓存策略(相似问题复用结果)
- 优化数据库索引
**平衡性能:**
- 监控 accuracy vs latency 的权衡
- 监控 accuracy vs cost (tokens) 的权衡
- 根据业务需求调整优先级
---
# 6. 三个基准测试对比总结
## 6.1 测试特点对比
| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 |
|---------|------------|-----------|---------|
| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 |
| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 |
| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 |
## 6.2 核心指标对比
### 准确性指标
| 指标 | Locomo | LongMemEval | MemSciQA | 说明 |
|-----|--------|-------------|----------|------|
| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 |
| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 |
| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 |
| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 |
| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 |
### 性能指标
所有三个测试都包含:
- **检索延迟** (search latency): mean, p50, p95, iqr
- **LLM 延迟** (llm latency): mean, p50, p95, iqr
- **上下文统计**: token 数、字符数、文档数
## 6.3 关键进步指标优先级
### 🥇 一级指标(必须关注)
1. **准确性指标**
- Locomo: `f1`, `locomo_f1`
- LongMemEval: `score_accuracy`, `accuracy_by_type`
- MemSciQA: `accuracy`, `f1`
- **目标**: > 85% 或 > 0.85
2. **综合性能**
- 所有测试的 F1 分数应保持一致性
- 不同类型问题的准确率应均衡
### 🥈 二级指标(重要)
3. **响应延迟**
- `latency.p95` (95分位数延迟)
- **目标**:
- search.p95 < 2000ms
- llm.p95 < 3000ms
- 总延迟 < 5000ms
4. **性能稳定性**
- `iqr` (四分位距)
- **目标**: 越小越好,说明性能稳定
### 🥉 三级指标(优化)
5. **成本效率**
- `avg_context_tokens`
- **目标**: 在保持准确性前提下最小化
- 成本效益比 = accuracy / avg_context_tokens
6. **检索质量**
- `avg_retrieved_docs` 的合理性
- `unique_preview_count` (LongMemEval)
- 检索内容的多样性和相关性
## 6.4 系统优化路径
### 阶段一:提升准确性(优先级最高)
**目标**: 所有测试的准确率 > 85%
**优化方向**:
1. 改进 embedding 模型质量
2. 优化检索算法hybrid search 参数)
3. 增加检索上下文数量(`search_limit`
4. 优化 prompt 工程
5. 改进记忆存储结构
**监控指标**:
- Locomo: `f1`, `locomo_f1`
- LongMemEval: `score_accuracy`, `exact_match` 比例
- MemSciQA: `accuracy`, `f1`
### 阶段二:优化性能(准确性达标后)
**目标**: p95 延迟 < 5 秒,性能稳定
**优化方向**:
1. 优化数据库索引和查询
2. 实施缓存策略
3. 使用更快的 LLM 模型
4. 并行化检索和推理
5. 减少不必要的检索
**监控指标**:
- `latency.p50`, `latency.p95`
- `iqr` (稳定性)
- 各阶段耗时分布
### 阶段三:降低成本(性能达标后)
**目标**: 在保持准确性和性能前提下,最小化成本
**优化方向**:
1. 精简检索上下文
2. 优化 context 选择策略
3. 使用更小的 LLM 模型
4. 实施智能缓存
5. 批处理优化
**监控指标**:
- `avg_context_tokens`
- 成本效益比 = accuracy / avg_context_tokens
- API 调用成本
## 6.5 评估最佳实践
### 测试执行建议
1. **初始测试**: 使用小样本快速验证
```bash
--sample_size 10
```
2. **完整评估**: 使用足够大的样本量
```bash
--sample_size 100 # 或更多
```
3. **增量测试**: 数据已摄入时跳过摄入阶段
```bash
--skip_ingest
```
4. **参数调优**: 系统性地调整参数并记录结果
- 调整 `search_limit`: 4, 8, 12, 16
- 调整 `context_char_budget`: 2000, 4000, 8000
- 尝试不同的 `search_type`: vector, keyword, hybrid
### 结果分析建议
1. **横向对比**: 比较三个测试的结果,识别系统的强弱项
2. **纵向对比**: 跟踪同一测试在不同版本的表现
3. **分类分析**: 关注不同问题类型的性能差异
4. **异常诊断**: 分析失败案例,找出根本原因
### 持续监控
建议建立监控仪表板,跟踪:
- 核心指标趋势(准确率、延迟)
- 成本效益比趋势
- 不同问题类型的性能分布
- 异常样本和失败模式
## 6.6 性能基准参考
### 优秀水平Production Ready
- **准确性**: accuracy/f1 > 0.90
- **延迟**: p95 < 3 秒
- **稳定性**: iqr < 500ms
- **成本效益**: accuracy/tokens > 0.0001
### 良好水平Acceptable
- **准确性**: accuracy/f1 > 0.85
- **延迟**: p95 < 5 秒
- **稳定性**: iqr < 1000ms
- **成本效益**: accuracy/tokens > 0.00005
### 需要改进Below Target
- **准确性**: accuracy/f1 < 0.85
- **延迟**: p95 > 5 秒
- **稳定性**: iqr > 1000ms
- **成本效益**: accuracy/tokens < 0.00005
---
**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。

View File

@@ -0,0 +1,371 @@
"""
交互式 Neo4j End User 数据检查工具
用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。
使用方法:
python check_group_data.py
python check_group_data.py --group-id locomo_benchmark
python check_group_data.py --group-id memsciqa_benchmark --detailed
"""
import asyncio
import argparse
import os
from pathlib import Path
from typing import Dict, Any
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}\n")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def check_group_exists(end_user_id: str) -> Dict[str, Any]:
"""
检查指定 end_user_id 是否存在数据
Args:
end_user_id: 要检查的 end_user ID
Returns:
包含统计信息的字典
"""
connector = Neo4jConnector()
try:
# 查询该 end_user 的节点总数
query_total = """
MATCH (n {end_user_id: $end_user_id})
RETURN count(n) as total_nodes
"""
result_total = await connector.execute_query(query_total, end_user_id=end_user_id)
total_nodes = result_total[0]["total_nodes"] if result_total else 0
# 查询各类型节点的数量
query_by_type = """
MATCH (n {end_user_id: $end_user_id})
RETURN labels(n) as labels, count(n) as count
ORDER BY count DESC
"""
result_by_type = await connector.execute_query(query_by_type, end_user_id=end_user_id)
# 查询关系数量
query_relationships = """
MATCH (n {end_user_id: $end_user_id})-[r]-()
RETURN count(DISTINCT r) as total_relationships
"""
result_rel = await connector.execute_query(query_relationships, end_user_id=end_user_id)
total_relationships = result_rel[0]["total_relationships"] if result_rel else 0
return {
"exists": total_nodes > 0,
"total_nodes": total_nodes,
"total_relationships": total_relationships,
"nodes_by_type": result_by_type
}
finally:
await connector.close()
async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]:
"""
获取详细的统计信息
Args:
end_user_id: 要检查的 end_user ID
Returns:
详细统计信息字典
"""
connector = Neo4jConnector()
try:
stats = {}
# Chunk 节点统计
query_chunks = """
MATCH (c:Chunk {end_user_id: $end_user_id})
RETURN count(c) as count,
avg(size(c.content)) as avg_content_length
"""
result_chunks = await connector.execute_query(query_chunks, end_user_id=end_user_id)
if result_chunks and result_chunks[0]["count"] > 0:
stats["chunks"] = {
"count": result_chunks[0]["count"],
"avg_content_length": int(result_chunks[0]["avg_content_length"]) if result_chunks[0]["avg_content_length"] else 0
}
# Statement 节点统计
query_statements = """
MATCH (s:Statement {end_user_id: $end_user_id})
RETURN count(s) as count
"""
result_statements = await connector.execute_query(query_statements, end_user_id=end_user_id)
if result_statements and result_statements[0]["count"] > 0:
stats["statements"] = {
"count": result_statements[0]["count"]
}
# Entity 节点统计
query_entities = """
MATCH (e:Entity {end_user_id: $end_user_id})
RETURN count(e) as count,
count(DISTINCT e.entity_type) as unique_types
"""
result_entities = await connector.execute_query(query_entities, end_user_id=end_user_id)
if result_entities and result_entities[0]["count"] > 0:
stats["entities"] = {
"count": result_entities[0]["count"],
"unique_types": result_entities[0]["unique_types"]
}
# Dialogue 节点统计
query_dialogues = """
MATCH (d:Dialogue {end_user_id: $end_user_id})
RETURN count(d) as count
"""
result_dialogues = await connector.execute_query(query_dialogues, end_user_id=end_user_id)
if result_dialogues and result_dialogues[0]["count"] > 0:
stats["dialogues"] = {
"count": result_dialogues[0]["count"]
}
# Summary 节点统计
query_summaries = """
MATCH (s:Summary {end_user_id: $end_user_id})
RETURN count(s) as count
"""
result_summaries = await connector.execute_query(query_summaries, end_user_id=end_user_id)
if result_summaries and result_summaries[0]["count"] > 0:
stats["summaries"] = {
"count": result_summaries[0]["count"]
}
return stats
finally:
await connector.close()
async def list_all_end_users() -> list:
"""
列出数据库中所有的 end_user_id
Returns:
end_user_id 列表及其节点数量
"""
connector = Neo4jConnector()
try:
query = """
MATCH (n)
WHERE n.end_user_id IS NOT NULL
RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count
ORDER BY node_count DESC
"""
results = await connector.execute_query(query)
return results
finally:
await connector.close()
def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None):
"""
打印查询结果
Args:
end_user_id: End User ID
stats: 基本统计信息
detailed_stats: 详细统计信息(可选)
"""
print(f"\n{'='*60}")
print(f"📊 End User ID: {end_user_id}")
print(f"{'='*60}\n")
if not stats["exists"]:
print("❌ 该 end_user_id 不存在数据")
print("\n💡 提示: 请先运行基准测试以摄入数据")
return
print(f"✅ 该 end_user_id 存在数据\n")
print(f"📈 基本统计:")
print(f" 总节点数: {stats['total_nodes']}")
print(f" 总关系数: {stats['total_relationships']}")
if stats["nodes_by_type"]:
print(f"\n📋 节点类型分布:")
for item in stats["nodes_by_type"]:
labels = ", ".join(item["labels"])
count = item["count"]
print(f" {labels}: {count}")
if detailed_stats:
print(f"\n🔍 详细统计:")
if "chunks" in detailed_stats:
print(f" Chunks: {detailed_stats['chunks']['count']}")
print(f" 平均内容长度: {detailed_stats['chunks']['avg_content_length']} 字符")
if "statements" in detailed_stats:
print(f" Statements: {detailed_stats['statements']['count']}")
if "entities" in detailed_stats:
print(f" Entities: {detailed_stats['entities']['count']}")
print(f" 唯一类型数: {detailed_stats['entities']['unique_types']}")
if "dialogues" in detailed_stats:
print(f" Dialogues: {detailed_stats['dialogues']['count']}")
if "summaries" in detailed_stats:
print(f" Summaries: {detailed_stats['summaries']['count']}")
print(f"\n{'='*60}\n")
async def interactive_mode():
"""
交互式模式
"""
print("\n" + "="*60)
print("🔍 Neo4j End User 数据检查工具 - 交互模式")
print("="*60 + "\n")
while True:
print("\n请选择操作:")
print(" 1. 检查指定 end_user_id")
print(" 2. 列出所有 end_user_id")
print(" 3. 退出")
choice = input("\n请输入选项 (1-3): ").strip()
if choice == "1":
end_user_id = input("\n请输入 end_user_id: ").strip()
if not end_user_id:
print("❌ end_user_id 不能为空")
continue
detailed = input("是否显示详细统计? (y/n, 默认 n): ").strip().lower() == 'y'
print("\n🔄 正在查询...")
stats = await check_group_exists(end_user_id)
detailed_stats = None
if detailed and stats["exists"]:
detailed_stats = await get_detailed_stats(end_user_id)
print_results(end_user_id, stats, detailed_stats)
elif choice == "2":
print("\n🔄 正在查询所有 end_user_id...")
end_users = await list_all_end_users()
if not end_users:
print("\n❌ 数据库中没有任何 end_user 数据")
else:
print(f"\n{'='*60}")
print(f"📋 数据库中的所有 End User ID")
print(f"{'='*60}\n")
for idx, end_user in enumerate(end_users, 1):
print(f" {idx}. {end_user['end_user_id']}")
print(f" 节点数: {end_user['node_count']}")
print(f"\n{'='*60}\n")
elif choice == "3":
print("\n👋 再见!")
break
else:
print("\n❌ 无效的选项,请重新选择")
async def main():
"""
主函数
"""
parser = argparse.ArgumentParser(
description="检查 Neo4j 中指定 end_user_id 的数据情况",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 交互模式
python check_group_data.py
# 检查指定 end_user
python check_group_data.py --end-user-id locomo_benchmark
# 检查并显示详细统计
python check_group_data.py --end-user-id memsciqa_benchmark --detailed
# 列出所有 end_user
python check_group_data.py --list-all
"""
)
parser.add_argument(
"--end-user-id",
type=str,
help="要检查的 end_user ID"
)
parser.add_argument(
"--detailed",
action="store_true",
help="显示详细统计信息"
)
parser.add_argument(
"--list-all",
action="store_true",
help="列出所有 end_user_id"
)
args = parser.parse_args()
# 如果没有提供任何参数,进入交互模式
if not args.end_user_id and not args.list_all:
await interactive_mode()
return
# 列出所有 end_user
if args.list_all:
print("\n🔄 正在查询所有 end_user_id...")
end_users = await list_all_end_users()
if not end_users:
print("\n❌ 数据库中没有任何 end_user 数据")
else:
print(f"\n{'='*60}")
print(f"📋 数据库中的所有 End User ID")
print(f"{'='*60}\n")
for idx, end_user in enumerate(end_users, 1):
print(f" {idx}. {end_user['end_user_id']}")
print(f" 节点数: {end_user['node_count']}")
print(f"\n{'='*60}\n")
return
# 检查指定 end_user
if args.end_user_id:
print(f"\n🔄 正在查询 end_user_id: {args.end_user_id}...")
stats = await check_group_exists(args.end_user_id)
detailed_stats = None
if args.detailed and stats["exists"]:
print("🔄 正在获取详细统计...")
detailed_stats = await get_detailed_stats(args.end_user_id)
print_results(args.end_user_id, stats, detailed_stats)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -2,7 +2,7 @@ import math
import re import re
from typing import List, Dict from typing import List, Dict
# 评估指标的实现
def _normalize(text: str) -> List[str]: def _normalize(text: str) -> List[str]:
"""Lowercase, strip punctuation, and split into tokens.""" """Lowercase, strip punctuation, and split into tokens."""
text = text.lower().strip() text = text.lower().strip()

View File

@@ -4,15 +4,17 @@ This file contains Cypher queries for searching dialogues, entities, and chunks.
Placed in evaluation directory to avoid circular imports with src modules. Placed in evaluation directory to avoid circular imports with src modules.
""" """
# 应该是neo4j browser的cypher语句需要修改文件名
# Entity search queries # Entity search queries
SEARCH_ENTITIES_BY_NAME = """ SEARCH_ENTITIES_BY_NAME = """
MATCH (e:Entity) MATCH (e:ExtractedEntity)
WHERE e.name = $name WHERE e.name = $name
RETURN e RETURN e
""" """
SEARCH_ENTITIES_BY_NAME_FALLBACK = """ SEARCH_ENTITIES_BY_NAME_FALLBACK = """
MATCH (e:Entity) MATCH (e:ExtractedEntity)
WHERE e.name CONTAINS $name WHERE e.name CONTAINS $name
RETURN e RETURN e
""" """

View File

@@ -1,34 +1,33 @@
import os
import asyncio import asyncio
import json import json
import os from typing import List, Dict, Any, Optional
import re
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional from uuid import UUID
import re
from app.core.memory.llm_tools.openai_client import LLMClient from app.core.memory.llm_tools.openai_client import LLMClient
from app.core.memory.models.message_models import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
ConversationContext, from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
ConversationMessage, import os
DialogData, import sys
) from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent / "app" / "core" / "memory" / "evaluation" / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.llm.llm_utils import get_llm_client
# 使用新的模块化架构 # 使用新的模块化架构
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ( from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
ExtractionOrchestrator,
)
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
DialogueChunker,
)
from app.core.memory.utils.config.definitions import (
SELECTED_CHUNKER_STRATEGY,
SELECTED_EMBEDDING_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
# Import from database module # Import from database module
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Cypher queries for evaluation # Cypher queries for evaluation
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py # Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
@@ -41,11 +40,14 @@ async def ingest_contexts_via_full_pipeline(
embedding_name: str | None = None, embedding_name: str | None = None,
save_chunk_output: bool = False, save_chunk_output: bool = False,
save_chunk_output_path: str | None = None, save_chunk_output_path: str | None = None,
reset_group: bool = False,
) -> bool: ) -> bool:
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator """
使用新的 ExtractionOrchestrator 运行完整的提取流水线
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j. Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
This function mirrors the steps in main(), but starts from raw text contexts. This function uses the new ExtractionOrchestrator architecture for better maintainability.
Args: Args:
contexts: List of dialogue texts, each containing lines like "role: message". contexts: List of dialogue texts, each containing lines like "role: message".
end_user_id: Group ID to assign to generated DialogData and graph nodes. end_user_id: Group ID to assign to generated DialogData and graph nodes.
@@ -53,25 +55,59 @@ async def ingest_contexts_via_full_pipeline(
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID. embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging. save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt. save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
reset_group: If True, clear existing data for this group before ingestion.
Returns: Returns:
True if data saved successfully, False otherwise. True if data saved successfully, False otherwise.
""" """
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
embedding_name = embedding_name or SELECTED_EMBEDDING_ID embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID")
# Check if we should reset from environment variable if not explicitly set
if not reset_group:
reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes")
# Step 0: Reset group if requested
if reset_group:
print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...")
try:
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
try:
# 删除该 end_user 的所有节点和关系
query = """
MATCH (n {end_user_id: $end_user_id})
DETACH DELETE n
"""
await connector.execute_query(query, end_user_id=end_user_id)
print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空")
finally:
await connector.close()
except Exception as e:
print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}")
# 继续执行,不中断摄入流程
# Initialize llm client with graceful fallback # Step 1: Initialize LLM client
llm_client = None llm_client = None
llm_available = True
try: try:
from app.core.memory.utils.config import definitions as config_defs # 使用评估配置中的 LLM ID
with get_db_context() as db: llm_id = os.getenv("EVAL_LLM_ID")
factory = MemoryClientFactory(db) if not llm_id:
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID) print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation")
return False
from app.db import get_db
db = next(get_db())
try:
llm_client = get_llm_client(llm_id, db)
finally:
db.close()
except Exception as e: except Exception as e:
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}") print(f"[Ingestion] LLM client unavailable: {e}")
llm_available = False return False
# Step A: Build DialogData list from contexts with robust parsing # Step 2: Parse contexts and create DialogData with chunks
print(f"[Ingestion] Parsing {len(contexts)} contexts...")
chunker = DialogueChunker(chunker_strategy) chunker = DialogueChunker(chunker_strategy)
dialog_data_list: List[DialogData] = [] dialog_data_list: List[DialogData] = []
@@ -94,7 +130,7 @@ async def ingest_contexts_via_full_pipeline(
line = raw.strip() line = raw.strip()
if not line: if not line:
continue continue
m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)$', line) m = re.match(r'^\s*([^:]+)\s*[:]\s*(.+)', line)
if m: if m:
role = m.group(1).strip() role = m.group(1).strip()
msg = m.group(2).strip() msg = m.group(2).strip()
@@ -118,10 +154,12 @@ async def ingest_contexts_via_full_pipeline(
dialog_data_list.append(dialog) dialog_data_list.append(dialog)
if not dialog_data_list: if not dialog_data_list:
print("No dialogs to process for ingestion.") print("[Ingestion] No dialogs to process.")
return False return False
# Optionally save chunking outputs for debugging print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks")
# Step 3: Optionally save chunking outputs for debugging
if save_chunk_output: if save_chunk_output:
try: try:
def _serialize_datetime(obj): def _serialize_datetime(obj):
@@ -137,124 +175,185 @@ async def ingest_contexts_via_full_pipeline(
combined_output = [dd.model_dump() for dd in dialog_data_list] combined_output = [dd.model_dump() for dd in dialog_data_list]
with open(out_path, "w", encoding="utf-8") as f: with open(out_path, "w", encoding="utf-8") as f:
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime) json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
print(f"Saved chunking results to: {out_path}") print(f"[Ingestion] Saved chunking results to: {out_path}")
except Exception as e: except Exception as e:
print(f"Failed to save chunking results: {e}") print(f"[Ingestion] Failed to save chunking results: {e}")
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线 # Step 4: Initialize embedder client
if not llm_available:
print("[Ingestion] Skipping extraction pipeline (no LLM).")
return False
# 初始化 embedder 客户端
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.services.memory_config_service import MemoryConfigService from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.db import get_db
try: try:
with get_db_context() as db: db = next(get_db())
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID) try:
embedder_config = RedBearModelConfig(**embedder_config_dict) embedder_config_dict = get_embedder_config(embedding_name, db)
embedder_client = OpenAIEmbedderClient(embedder_config) embedder_config = RedBearModelConfig(**embedder_config_dict)
embedder_client = OpenAIEmbedderClient(embedder_config)
finally:
db.close()
except Exception as e: except Exception as e:
print(f"[Ingestion] Failed to initialize embedder client: {e}") print(f"[Ingestion] Failed to initialize embedder client: {e}")
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
return False return False
# Step 5: Initialize Neo4j connector
connector = Neo4jConnector() connector = Neo4jConnector()
# 初始化并运行 ExtractionOrchestrator # Step 6: 构建 MemoryConfig从环境变量直接构建不依赖数据库
from app.core.memory.utils.config.config_utils import get_pipeline_config print("[Ingestion] 构建 MemoryConfig from environment variables...")
config = get_pipeline_config() from app.schemas.memory_config_schema import MemoryConfig
try:
# 从环境变量获取配置参数
llm_id = os.getenv("EVAL_LLM_ID")
embedding_id = os.getenv("EVAL_EMBEDDING_ID")
chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
if not llm_id or not embedding_id:
print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation")
print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID")
await connector.close()
return False
# 从数据库获取模型信息(仅用于显示名称)
from app.db import get_db
db = next(get_db())
try:
from sqlalchemy import text
# 获取 LLM 模型信息(从 model_configs 表)
llm_result = db.execute(
text("SELECT name FROM model_configs WHERE id = :id"),
{"id": llm_id}
).fetchone()
llm_model_name = llm_result[0] if llm_result else "Unknown LLM"
# 获取 Embedding 模型信息(从 model_configs 表)
emb_result = db.execute(
text("SELECT name FROM model_configs WHERE id = :id"),
{"id": embedding_id}
).fetchone()
embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding"
except Exception as e:
# 如果查询失败,使用默认名称
print(f"[Ingestion] Warning: Failed to query model names from database: {e}")
llm_model_name = f"LLM ({llm_id[:8]}...)"
embedding_model_name = f"Embedding ({embedding_id[:8]}...)"
finally:
db.close()
# 构建 MemoryConfig 对象(使用最小必需配置)
from uuid import uuid4
memory_config = MemoryConfig(
config_id=0, # 评估环境不需要真实的 config_id
config_name="evaluation_config",
workspace_id=uuid4(), # 临时 workspace_id
workspace_name="evaluation_workspace",
tenant_id=uuid4(), # 临时 tenant_id
llm_model_id=UUID(llm_id),
llm_model_name=llm_model_name,
embedding_model_id=UUID(embedding_id),
embedding_model_name=embedding_model_name,
storage_type="neo4j",
chunker_strategy=chunker_strategy_env,
reflexion_enabled=False,
reflexion_iteration_period=3,
reflexion_range="partial",
reflexion_baseline="TIME",
loaded_at=datetime.now(),
# 可选字段使用默认值
rerank_model_id=None,
rerank_model_name=None,
llm_params={},
embedding_params={},
config_version="2.0",
)
print(f"[Ingestion] ✅ 构建 MemoryConfig 成功")
print(f"[Ingestion] LLM: {llm_model_name}")
print(f"[Ingestion] Embedding: {embedding_model_name}")
print(f"[Ingestion] Chunker: {chunker_strategy_env}")
except Exception as e:
print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}")
print(f"[Ingestion] Please check:")
print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation")
print(f"[Ingestion] 2. Model IDs exist in the models table")
print(f"[Ingestion] 3. Database connection is working")
await connector.close()
return False
# Step 7: Initialize and run ExtractionOrchestrator
print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...")
from app.services.memory_config_service import MemoryConfigService
config = MemoryConfigService.get_pipeline_config(memory_config)
orchestrator = ExtractionOrchestrator( orchestrator = ExtractionOrchestrator(
llm_client=llm_client, llm_client=llm_client,
embedder_client=embedder_client, embedder_client=embedder_client,
connector=connector, connector=connector,
config=config, config=config,
embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id
) )
# 创建一个包装的 orchestrator 来修复时间提取器的输出 try:
# 保存原始的 _assign_extracted_data 方法 # Run the complete extraction pipeline
original_assign = orchestrator._assign_extracted_data result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
def clean_temporal_value(value):
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
if value is None:
return None
if isinstance(value, str):
# 处理字符串形式的 'null', 'None', 空字符串等
if value.lower() in ('null', 'none', '') or value.strip() == '':
return None
return value
async def patched_assign_extracted_data(*args, **kwargs):
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
result = await original_assign(*args, **kwargs)
# 清理返回的 dialog_data_list 中的 temporal_validity # Handle different return formats:
for dialog in result: # - Pilot mode: 7 values (without dedup_details)
if hasattr(dialog, 'chunks') and dialog.chunks: # - Normal mode: 8 values (with dedup_details at the end)
for chunk in dialog.chunks: if len(result) == 8:
if hasattr(chunk, 'statements') and chunk.statements: # Normal mode: includes dedup_details
for statement in chunk.statements: (
if hasattr(statement, 'temporal_validity') and statement.temporal_validity: dialogue_nodes,
tv = statement.temporal_validity chunk_nodes,
# 清理 valid_at 和 invalid_at statement_nodes,
if hasattr(tv, 'valid_at'): entity_nodes,
tv.valid_at = clean_temporal_value(tv.valid_at) statement_chunk_edges,
if hasattr(tv, 'invalid_at'): statement_entity_edges,
tv.invalid_at = clean_temporal_value(tv.invalid_at) entity_entity_edges,
return result _, # dedup_details - not needed here
) = result
# 替换方法 elif len(result) == 7:
orchestrator._assign_extracted_data = patched_assign_extracted_data # Pilot mode or older version: no dedup_details
(
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理 dialogue_nodes,
original_create = orchestrator._create_nodes_and_edges chunk_nodes,
statement_nodes,
async def patched_create_nodes_and_edges(dialog_data_list_arg): entity_nodes,
"""包装方法:在创建节点前再次清理 temporal_validity""" statement_chunk_edges,
# 最后一次清理,确保万无一失 statement_entity_edges,
for dialog in dialog_data_list_arg: entity_entity_edges,
if hasattr(dialog, 'chunks') and dialog.chunks: ) = result
for chunk in dialog.chunks: else:
if hasattr(chunk, 'statements') and chunk.statements: raise ValueError(f"Unexpected number of return values: {len(result)}")
for statement in chunk.statements:
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
tv = statement.temporal_validity
if hasattr(tv, 'valid_at'):
tv.valid_at = clean_temporal_value(tv.valid_at)
if hasattr(tv, 'invalid_at'):
tv.invalid_at = clean_temporal_value(tv.invalid_at)
return await original_create(dialog_data_list_arg) print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities")
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges except ValueError as e:
# If unpacking fails, provide helpful error message
# 运行完整的提取流水线 print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}")
# orchestrator.run 返回 7 个元素的元组 print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}")
result = await orchestrator.run(dialog_data_list, is_pilot_run=False) if hasattr(result, '__len__') and len(result) > 0:
( print(f"[Ingestion] First element type: {type(result[0])}")
dialogue_nodes, await connector.close()
chunk_nodes, return False
statement_nodes, except Exception as e:
entity_nodes, print(f"[Ingestion] Extraction pipeline failed: {e}")
statement_chunk_edges, import traceback
statement_entity_edges, traceback.print_exc()
entity_entity_edges, await connector.close()
) = result return False
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
# Step G: 生成记忆摘要 # Step 7: Generate memory summaries
print("[Ingestion] Generating memory summaries...") print("[Ingestion] Generating memory summaries...")
try: try:
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import ( from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
memory_summary_generation, memory_summary_generation,
) )
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
summaries = await memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs=dialog_data_list, chunked_dialogs=dialog_data_list,
@@ -266,7 +365,8 @@ async def ingest_contexts_via_full_pipeline(
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}") print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
summaries = [] summaries = []
# Step H: Save to Neo4j # Step 8: Save to Neo4j
print("[Ingestion] Saving to Neo4j...")
try: try:
success = await save_dialog_and_statements_to_neo4j( success = await save_dialog_and_statements_to_neo4j(
dialogue_nodes=dialogue_nodes, dialogue_nodes=dialogue_nodes,
@@ -284,18 +384,21 @@ async def ingest_contexts_via_full_pipeline(
try: try:
await add_memory_summary_nodes(summaries, connector) await add_memory_summary_nodes(summaries, connector)
await add_memory_summary_statement_edges(summaries, connector) await add_memory_summary_statement_edges(summaries, connector)
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j") print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j")
except Exception as e: except Exception as e:
print(f"Warning: Failed to save summary nodes: {e}") print(f"[Ingestion] Warning: Failed to save summary nodes: {e}")
await connector.close() await connector.close()
if success: if success:
print("Successfully saved extracted data to Neo4j!") print("[Ingestion] Successfully saved all data to Neo4j!")
else: else:
print("Failed to save data to Neo4j") print("[Ingestion] Failed to save data to Neo4j")
return success return success
except Exception as e: except Exception as e:
print(f"Failed to save data to Neo4j: {e}") print(f"[Ingestion] Failed to save data to Neo4j: {e}")
await connector.close()
return False return False

File diff suppressed because it is too large Load Diff

View File

@@ -1,30 +1,29 @@
# file name: check_neo4j_connection_fixed.py # file name: check_neo4j_connection_fixed.py
import asyncio import asyncio
import json
import math
import os import os
import re
import sys import sys
import json
import time import time
import math
import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List from typing import List, Dict, Any
from pathlib import Path from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
# 1 # Load main .env
# 添加项目根目录到路径
current_dir = Path(__file__).resolve().parent
project_root = str(current_dir.parent)
if project_root not in sys.path:
sys.path.insert(0, project_root)
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
src_dir = os.path.join(project_root, "src")
if src_dir not in sys.path:
sys.path.insert(0, src_dir)
load_dotenv() load_dotenv()
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
# Get group_id from config
group_id = os.getenv("EVAL_GROUP_ID", "locomo_test")
print(f"✅ 使用配置的 group_id: {group_id}")
# 首先定义 _loc_normalize 函数,因为其他函数依赖它 # 首先定义 _loc_normalize 函数,因为其他函数依赖它
def _loc_normalize(text: str) -> str: def _loc_normalize(text: str) -> str:
text = str(text) if text is not None else "" text = str(text) if text is not None else ""
@@ -37,7 +36,7 @@ def _loc_normalize(text: str) -> str:
# 尝试从 metrics.py 导入基础指标 # 尝试从 metrics.py 导入基础指标
try: try:
from common.metrics import bleu1, f1_score, jaccard from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
print("✅ 从 metrics.py 导入基础指标成功") print("✅ 从 metrics.py 导入基础指标成功")
except ImportError as e: except ImportError as e:
print(f"❌ 从 metrics.py 导入失败: {e}") print(f"❌ 从 metrics.py 导入失败: {e}")
@@ -107,23 +106,8 @@ except ImportError as e:
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标 # 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
try: try:
# 添加 evaluation 目录路径 from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
evaluation_dir = os.path.join(project_root, "evaluation") print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
if evaluation_dir not in sys.path:
sys.path.insert(0, evaluation_dir)
# 尝试从不同位置导入
try:
from locomo.qwen_search_eval import (
_resolve_relative_times,
loc_f1_score,
loc_multi_f1,
)
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError:
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
except ImportError as e: except ImportError as e:
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}") print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
# 回退到本地实现 LoCoMo 特定函数 # 回退到本地实现 LoCoMo 特定函数
@@ -429,31 +413,36 @@ def enhanced_context_selection(contexts: List[str], question: str, question_inde
async def run_enhanced_evaluation(): async def run_enhanced_evaluation():
"""使用增强方法进行完整评估 - 解决中间性能衰减问题""" """使用增强方法进行完整评估 - 解决中间性能衰减问题"""
try: from dotenv import load_dotenv
from dotenv import load_dotenv from uuid import UUID
except Exception: from datetime import datetime
def load_dotenv(): from dataclasses import dataclass
return None
# 修正导入路径:使用 app.core.memory.src 前缀 # 修正导入路径:使用 app.core.memory.src 前缀
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.neo4j.graph_search import search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.schemas.memory_config_schema import MemoryConfig
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
# Get model IDs from config
llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
# 加载数据 # 加载数据 - 使用统一的 dataset 目录
# 获取项目根目录 data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json")
current_file = os.path.abspath(__file__)
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录 if not os.path.exists(data_path):
memory_dir = os.path.dirname(evaluation_dir) # memory目录 raise FileNotFoundError(
data_path = os.path.join(memory_dir, "data", "locomo10.json") f"数据集文件不存在: {data_path}\n"
f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/"
)
print(f"✅ 找到数据文件: {data_path}")
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f) raw = json.load(f)
@@ -463,64 +452,109 @@ async def run_enhanced_evaluation():
qa_items.extend(entry.get("qa", [])) qa_items.extend(entry.get("qa", []))
else: else:
qa_items.extend(raw.get("qa", [])) qa_items.extend(raw.get("qa", []))
items = qa_items[:20] # 测试多少个问题 # 测试多少个问题 - 可通过环境变量设置
sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20"))
items = qa_items[:sample_size]
print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)")
# 初始化增强监控器 # 初始化增强监控器
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6) monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
with get_db_context() as db: # 获取数据库会话并初始化 LLM 客户端
factory = MemoryClientFactory(db) from app.db import get_db
llm = factory.get_llm_client(SELECTED_LLM_ID) db = next(get_db())
# 初始化embedder
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 初始化连接器
connector = Neo4jConnector()
# 初始化结果字典
results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
total_f1 = 0.0
total_bleu1 = 0.0
total_jaccard = 0.0
total_loc_f1 = 0.0
total_context_length = 0
total_retrieved_docs = 0
category_stats = {}
try: try:
for i, item in enumerate(items): llm = get_llm_client(llm_id, db)
monitor.question_count += 1
# 初始化embedder
cfg_dict = get_embedder_config(embedding_id, db)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 🔧 创建 MemoryConfig 对象用于搜索
# 方案1如果有配置ID从数据库加载
config_id = os.getenv("EVAL_CONFIG_ID")
if config_id:
print(f"📋 从数据库加载配置 ID: {config_id}")
memory_config_service = MemoryConfigService(db)
memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test")
else:
# 方案2创建临时配置对象用于测试
print(f"📋 创建临时测试配置")
from uuid import UUID
from datetime import datetime
# 将字符串 ID 转换为 UUID
try:
embedding_uuid = UUID(embedding_id)
llm_uuid = UUID(llm_id)
except ValueError as e:
raise ValueError(f"无效的 UUID 格式: {e}")
memory_config = MemoryConfig(
config_id=1, # 临时 ID
config_name="locomo_test_config",
workspace_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 workspace
workspace_name="test_workspace",
tenant_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 tenant
embedding_model_id=embedding_uuid,
embedding_model_name="test_embedding",
llm_model_id=llm_uuid,
llm_model_name="test_llm",
storage_type="neo4j",
chunker_strategy="RecursiveChunker",
reflexion_enabled=False,
reflexion_iteration_period=3,
reflexion_range="partial",
reflexion_baseline="Time",
loaded_at=datetime.now()
)
print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}")
# 初始化连接器
connector = Neo4jConnector()
# 获取近期性能用于重置判断 # 初始化结果字典
recent_performance = monitor.get_recent_performance() results = {
"questions": [],
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
"category_metrics": {},
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
"performance_trend": "stable",
"timestamp": datetime.now().isoformat(),
"enhanced_strategy": True
}
# 增强的重置判断 total_f1 = 0.0
should_reset = monitor.should_reset_connections(current_f1=recent_performance) total_bleu1 = 0.0
if should_reset and i > 0: total_jaccard = 0.0
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...") total_loc_f1 = 0.0
await connector.close() total_context_length = 0
connector = Neo4jConnector() # 创建新连接 total_retrieved_docs = 0
print("✅ 连接重置完成") category_stats = {}
q = item.get("question", "") try:
ref = item.get("answer", "") for i, item in enumerate(items):
ref_str = str(ref) if ref is not None else "" monitor.question_count += 1
# 获取近期性能用于重置判断
recent_performance = monitor.get_recent_performance()
# 增强的重置判断
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
if should_reset and i > 0:
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
await connector.close()
connector = Neo4jConnector() # 创建新连接
print("✅ 连接重置完成")
q = item.get("question", "")
ref = item.get("answer", "")
ref_str = str(ref) if ref is not None else ""
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}") print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
print(f"✅ 真实答案: {ref_str}") print(f"✅ 真实答案: {ref_str}")
@@ -548,10 +582,12 @@ async def run_enhanced_evaluation():
contexts_all = [] contexts_all = []
try: try:
# 使用统一的搜索服务 # 使用旧版本的搜索服务(重构前的版本)
from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
print("🔀 使用混合搜索服务...") print(f"🔀 使用混合搜索服务(旧版本)...")
print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid")
print(f"📍 查询文本: {q}")
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
@@ -559,15 +595,27 @@ async def run_enhanced_evaluation():
end_user_id="locomo_sk", end_user_id="locomo_sk",
limit=20, limit=20,
include=["statements", "chunks", "entities", "summaries"], include=["statements", "chunks", "entities", "summaries"],
alpha=0.6, # BM25权重 output_path=None,
embedding_id=SELECTED_EMBEDDING_ID memory_config=memory_config, # 🔧 添加必需的 memory_config 参数
rerank_alpha=0.6, # BM25权重
use_forgetting_rerank=False,
use_llm_rerank=False
) )
# 处理搜索结果 - 新的搜索服务返回统一的结构 # 处理搜索结果 - 旧版本返回包含 reranked_results 的结构
chunks = search_results.get("chunks", []) # 对于 hybrid 搜索,使用 reranked_results
statements = search_results.get("statements", []) if "reranked_results" in search_results:
entities = search_results.get("entities", []) reranked = search_results["reranked_results"]
summaries = search_results.get("summaries", []) chunks = reranked.get("chunks", [])
statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
# 单一搜索类型的结果
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要") print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
@@ -609,6 +657,8 @@ async def run_enhanced_evaluation():
print(f"📊 有效上下文数量: {len(contexts_all)}") print(f"📊 有效上下文数量: {len(contexts_all)}")
except Exception as e: except Exception as e:
print(f"❌ 检索失败: {e}") print(f"❌ 检索失败: {e}")
import traceback
print(f"详细错误信息:\n{traceback.format_exc()}")
contexts_all = [] contexts_all = []
t1 = time.time() t1 = time.time()
@@ -728,14 +778,17 @@ async def run_enhanced_evaluation():
print("="*60) print("="*60)
except Exception as e: except Exception as e:
print(f"❌ 评估过程中发生错误: {e}") print(f"❌ 评估过程中发生错误: {e}")
# 即使出错,也返回已有的结果 # 即使出错,也返回已有的结果
import traceback import traceback
traceback.print_exc() traceback.print_exc()
finally:
await connector.close()
finally: finally:
await connector.close() db.close() # 关闭数据库会话
# 计算总体指标 # 计算总体指标
n = len(items) n = len(items)

View File

@@ -15,8 +15,14 @@ import json
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
from app.core.memory.utils.definitions import PROJECT_ROOT
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
@@ -82,7 +88,7 @@ def load_locomo_data(
return qa_items[:sample_size] return qa_items[:sample_size]
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]: def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]:
""" """
Extract conversation texts from LoCoMo data for ingestion. Extract conversation texts from LoCoMo data for ingestion.
@@ -93,6 +99,7 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
Args: Args:
data_path: Path to locomo10.json file data_path: Path to locomo10.json file
max_dialogues: Maximum number of dialogues to extract (default: 1) max_dialogues: Maximum number of dialogues to extract (default: 1)
max_messages_per_dialogue: Maximum messages per dialogue (default: None = all messages)
Returns: Returns:
List of conversation strings formatted for ingestion. List of conversation strings formatted for ingestion.
@@ -141,13 +148,21 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
continue continue
lines.append(f"{role}: {text}") lines.append(f"{role}: {text}")
# Limit messages if specified
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
break
# Break outer loop if we've reached the message limit
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
break
if lines: if lines:
contents.append("\n".join(lines)) contents.append("\n".join(lines))
return contents return contents
# 时间解析:将相对时间表达转换为绝对日期
def resolve_temporal_references(text: str, anchor_date: datetime) -> str: def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
""" """
Resolve relative temporal references to absolute dates. Resolve relative temporal references to absolute dates.
@@ -225,6 +240,8 @@ def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
t, t,
flags=re.IGNORECASE flags=re.IGNORECASE
) )
# 中文支持
t = re.sub( t = re.sub(
r"\bnext\s+week\b", r"\bnext\s+week\b",
(anchor_date + timedelta(days=7)).date().isoformat(), (anchor_date + timedelta(days=7)).date().isoformat(),
@@ -345,6 +362,50 @@ def select_and_format_information(
return "\n\n".join(selected) return "\n\n".join(selected)
# 记忆系统核心能力:写入与读取
async def ingest_conversations_if_needed(
conversations: List[str],
end_user_id: str,
reset: bool = False
) -> bool:
"""
Wrapper for conversation ingestion using external extraction pipeline.
This function populates the Neo4j database with processed conversation data
(chunks, statements, entities) so that the retrieval system has memory to search.
The ingestion process:
1. Parses conversation text into dialogue messages
2. Chunks the dialogues into semantic units
3. Extracts statements and entities using LLM
4. Generates embeddings for all content
5. Stores everything in Neo4j graph database
Args:
conversations: List of raw conversation texts from LoCoMo dataset
Example: ["User: I went to Paris. AI: When was that?", ...]
end_user_id: Target end_user ID for database storage
reset: Whether to clear existing data first (not implemented in wrapper)
Returns:
True if successful, False otherwise
Note:
The external function uses "contexts" to mean "conversation texts".
This runs the full extraction pipeline: chunking → entity extraction →
statement extraction → embedding → Neo4j storage.
"""
try:
success = await ingest_contexts_via_full_pipeline(
contexts=conversations,
end_user_id=end_user_id,
save_chunk_output=True,
reset_group=reset
)
return success
except Exception as e:
print(f"[Ingestion] Failed to ingest conversations: {e}")
return False
async def retrieve_relevant_information( async def retrieve_relevant_information(
question: str, question: str,
@@ -385,7 +446,7 @@ async def retrieve_relevant_information(
search_graph, search_graph,
search_graph_by_embedding search_graph_by_embedding
) )
from app.core.memory.storage_services.search import run_hybrid_search from app.core.memory.src.search import run_hybrid_search
contexts_all: List[str] = [] contexts_all: List[str] = []

View File

@@ -2,43 +2,29 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import statistics
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List from typing import List, Dict, Any
import statistics
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
import re import re
from pathlib import Path
from dotenv import load_dotenv
# Load evaluation config
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
bleu1,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现) # 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
@@ -265,7 +251,10 @@ async def run_locomo_eval(
end_user_id = end_user_id or SELECTED_end_user_id end_user_id = end_user_id or SELECTED_end_user_id
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json") data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
if not os.path.exists(data_path): if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "locomo10.json") raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 locomo10.json 放置在: {dataset_dir}"
)
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
raw = json.load(f) raw = json.load(f)
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表 # LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
@@ -343,13 +332,9 @@ async def run_locomo_eval(
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True) await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
# 使用异步LLM客户端 # 使用异步LLM客户端
with get_db_context() as db: llm_client = get_llm_client(llm_id)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化embedder用于直接调用 # 初始化embedder用于直接调用
with get_db_context() as db: cfg_dict = get_embedder_config(embedding_id)
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )
@@ -480,8 +465,8 @@ async def run_locomo_eval(
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}") contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
else: # hybrid else: # hybrid
# 🎯 关键修复:混合检索使用更严格的回退机制 # 使用旧版本的混合检索(重构前)
print("🔀 使用混合检索(带回退机制...") print("🔀 使用混合检索(旧版本...")
try: try:
search_results = await run_hybrid_search( search_results = await run_hybrid_search(
query_text=q, query_text=q,
@@ -490,16 +475,26 @@ async def run_locomo_eval(
limit=adjusted_limit, limit=adjusted_limit,
include=["chunks", "statements", "entities", "summaries"], include=["chunks", "statements", "entities", "summaries"],
output_path=None, output_path=None,
rerank_alpha=0.6,
use_forgetting_rerank=False,
use_llm_rerank=False
) )
# 🎯 关键修复:正确处理混合检索的扁平结构 # 处理旧版本的返回结构(包含 reranked_results
# 新的API返回扁平结构直接从顶层获取结果
if search_results and isinstance(search_results, dict): if search_results and isinstance(search_results, dict):
# 新API返回扁平结构直接从顶层获取 # 对于 hybrid 搜索,使用 reranked_results
chunks = search_results.get("chunks", []) if "reranked_results" in search_results:
statements = search_results.get("statements", []) reranked = search_results["reranked_results"]
entities = search_results.get("entities", []) chunks = reranked.get("chunks", [])
summaries = search_results.get("summaries", []) statements = reranked.get("statements", [])
entities = reranked.get("entities", [])
summaries = reranked.get("summaries", [])
else:
# 单一搜索类型的结果
chunks = search_results.get("chunks", [])
statements = search_results.get("statements", [])
entities = search_results.get("entities", [])
summaries = search_results.get("summaries", [])
# 检查是否有有效结果 # 检查是否有有效结果
if chunks or statements or entities or summaries: if chunks or statements or entities or summaries:
@@ -799,8 +794,9 @@ async def run_locomo_eval(
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
"llm_id": SELECTED_LLM_ID, "llm_id": llm_id,
"retrieval_embedding_id": SELECTED_EMBEDDING_ID, "retrieval_embedding_id": embedding_id,
"chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"),
"skip_ingest_if_exists": skip_ingest_if_exists, "skip_ingest_if_exists": skip_ingest_if_exists,
"llm_timeout": llm_timeout, "llm_timeout": llm_timeout,
"llm_max_retries": llm_max_retries, "llm_max_retries": llm_max_retries,

View File

@@ -2,100 +2,67 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import time
import re import re
import statistics import statistics
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List from typing import List, Dict, Any
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 确保可以找到 src 及项目根路径
import sys
from pathlib import Path from pathlib import Path
_THIS_DIR = Path(__file__).resolve().parent from dotenv import load_dotenv
_PROJECT_ROOT = str(_THIS_DIR.parents[2])
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") # Load evaluation config
for _p in (_SRC_DIR, _PROJECT_ROOT): eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if _p not in sys.path: if eval_config_path.exists():
sys.path.insert(0, _p) load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
# 与现有评估脚本保持一致的导入方式
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
try:
# 优先从 extraction_utils1 导入
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline, # type: ignore
)
except Exception:
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.services.memory_config_service import MemoryConfigService from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
try: from app.core.memory.utils.config.config_utils import get_embedder_config
from app.core.memory.evaluation.common.metrics import exact_match from app.core.memory.utils.llm.llm_utils import get_llm_client
except Exception: from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
# 兜底:简单的大小写不敏感比较 from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
def exact_match(pred: str, ref: str) -> bool: from app.core.memory.evaluation.common.metrics import exact_match
return str(pred).strip().lower() == str(ref).strip().lower()
def load_dataset_any(path: str) -> List[Dict[str, Any]]: def load_dataset_any(path: str) -> List[Dict[str, Any]]:
"""健壮地加载数据集(兼容 list 或多段 JSON""" """健壮地加载数据集,支持三种格式:
1. 标准 JSON 数组: [{...}, {...}]
2. 单个 JSON 对象: {...}
3. JSONL 格式每行一个 JSON: {...}\n{...}\n{...}
"""
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
s = f.read().strip() content = f.read().strip()
# 尝试标准 JSON 解析
try: try:
obj = json.loads(s) data = json.loads(content)
if isinstance(obj, list): if isinstance(data, list):
return obj return [item for item in data if isinstance(item, dict)]
elif isinstance(obj, dict): elif isinstance(data, dict):
return [obj] return [data]
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
dec = json.JSONDecoder()
idx = 0 # 尝试 JSONL 格式(每行一个 JSON 对象)
items: List[Dict[str, Any]] = [] items = []
while idx < len(s): for line in content.splitlines():
while idx < len(s) and s[idx].isspace(): line = line.strip()
idx += 1 if not line:
if idx >= len(s): continue
break
try: try:
obj, end = dec.raw_decode(s, idx) obj = json.loads(line)
if isinstance(obj, list): if isinstance(obj, dict):
for it in obj:
if isinstance(it, dict):
items.append(it)
elif isinstance(obj, dict):
items.append(obj) items.append(obj)
idx = end elif isinstance(obj, list):
items.extend(item for item in obj if isinstance(item, dict))
except json.JSONDecodeError: except json.JSONDecodeError:
nl = s.find("\n", idx) continue
if nl == -1:
break
idx = nl + 1
return items return items
@@ -624,7 +591,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
async def run_longmemeval_test( async def run_longmemeval_test(
sample_size: int = 3, sample_size: int = 3,
end_user_id: str = "longmemeval_zh_bak_3", end_user_id: str | None = None,
search_limit: int = 8, search_limit: int = 8,
context_char_budget: int = 4000, context_char_budget: int = 4000,
llm_temperature: float = 0.0, llm_temperature: float = 0.0,
@@ -639,18 +606,22 @@ async def run_longmemeval_test(
skip_ingest: bool = False, skip_ingest: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""LongMemEval 评估测试:增强时间推理能力""" """LongMemEval 评估测试:增强时间推理能力"""
# Use environment variable with fallback chain
if end_user_id is None:
end_user_id = os.getenv("LONGMEMEVAL_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "longmemeval_zh_bak_3")
# 数据路径 # 数据路径
if not data_path: if not data_path:
# 固定使用中文数据集data/longmemeval_oracle_zh.json # 固定使用中文数据集dataset/longmemeval_oracle_zh.json
zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") data_path = str(dataset_dir / "longmemeval_oracle_zh.json")
if os.path.exists(zh_proj):
data_path = zh_proj if not os.path.exists(data_path):
elif os.path.exists(zh_cwd): raise FileNotFoundError(
data_path = zh_cwd f"数据集文件不存在: {data_path}\n"
else: f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}"
raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json请确保其存在于项目根目录或当前工作目录的 data 目录下。") )
qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) qa_list: List[Dict[str, Any]] = load_dataset_any(data_path)
# 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾
@@ -702,16 +673,19 @@ async def run_longmemeval_test(
) )
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端 # 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
with get_db_context() as db: from app.db import get_db
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID) db = next(get_db())
try:
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"), db)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
finally:
db.close()
connector = Neo4jConnector() connector = Neo4jConnector()
with get_db_context() as db:
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict)
)
# 指标收集 # 指标收集
latencies_llm: List[float] = [] latencies_llm: List[float] = []
@@ -768,10 +742,10 @@ async def run_longmemeval_test(
if stmt_text: if stmt_text:
contexts_all.append(stmt_text) contexts_all.append(stmt_text)
# for sm in summaries: for sm in summaries:
# summary_text = str(sm.get("summary", "")).strip() summary_text = str(sm.get("summary", "")).strip()
# if summary_text: if summary_text:
# contexts_all.append(summary_text) contexts_all.append(summary_text)
# 实体摘要最多3个 # 实体摘要最多3个
scored = [e for e in entities if e.get("score") is not None] scored = [e for e in entities if e.get("score") is not None]
@@ -1228,8 +1202,8 @@ async def run_longmemeval_test(
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
"llm_id": SELECTED_LLM_ID, "llm_id": os.getenv("EVAL_LLM_ID"),
"embedding_id": SELECTED_EMBEDDING_ID, "embedding_id": os.getenv("EVAL_EMBEDDING_ID"),
"sample_size": sample_size, "sample_size": sample_size,
"start_index": start_index, "start_index": start_index,
}, },
@@ -1288,7 +1262,7 @@ def main():
parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)") parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size") parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID") parser.add_argument("--end-user-id", type=str, default=None, help="图数据库 End User ID默认使用环境变量")
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限") parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
@@ -1349,7 +1323,8 @@ def main():
# 保存结果到文件 # 保存结果到文件
try: try:
out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results") # 使用相对路径而不是 PROJECT_ROOT
out_dir = Path(__file__).resolve().parent / "results"
os.makedirs(out_dir, exist_ok=True) os.makedirs(out_dir, exist_ok=True)
ts = datetime.now().strftime("%Y%m%d_%H%M%S") ts = datetime.now().strftime("%Y%m%d_%H%M%S")
out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json") out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json")

View File

@@ -2,81 +2,67 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import time
import re import re
import statistics import statistics
import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List from typing import List, Dict, Any
from pathlib import Path
try: from dotenv import load_dotenv
from dotenv import load_dotenv
except Exception: # Load evaluation config
def load_dotenv(): eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
return None if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
# 与现有评估脚本保持一致的导入方式 # 与现有评估脚本保持一致的导入方式
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
jaccard,
latency_stats,
)
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
try: from app.core.models.base import RedBearModelConfig
from app.core.memory.evaluation.common.metrics import exact_match from app.core.memory.utils.config.config_utils import get_embedder_config
except Exception: from app.core.memory.utils.llm.llm_utils import get_llm_client
# 兜底:简单的大小写不敏感比较 from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
def exact_match(pred: str, ref: str) -> bool: from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
return str(pred).strip().lower() == str(ref).strip().lower() from app.core.memory.evaluation.common.metrics import exact_match
def load_dataset_any(path: str) -> List[Dict[str, Any]]: def load_dataset_any(path: str) -> List[Dict[str, Any]]:
"""健壮地加载数据集(兼容 list 或多段 JSON""" """健壮地加载数据集,支持三种格式:
1. 标准 JSON 数组: [{...}, {...}]
2. 单个 JSON 对象: {...}
3. JSONL 格式(每行一个 JSON: {...}\n{...}\n{...}
"""
with open(path, "r", encoding="utf-8") as f: with open(path, "r", encoding="utf-8") as f:
s = f.read().strip() content = f.read().strip()
# 尝试标准 JSON 解析
try: try:
obj = json.loads(s) data = json.loads(content)
if isinstance(obj, list): if isinstance(data, list):
return obj return [item for item in data if isinstance(item, dict)]
elif isinstance(obj, dict): elif isinstance(data, dict):
return [obj] return [data]
except json.JSONDecodeError: except json.JSONDecodeError:
pass pass
dec = json.JSONDecoder()
idx = 0 # 尝试 JSONL 格式(每行一个 JSON 对象)
items: List[Dict[str, Any]] = [] items = []
while idx < len(s): for line in content.splitlines():
while idx < len(s) and s[idx].isspace(): line = line.strip()
idx += 1 if not line:
if idx >= len(s): continue
break
try: try:
obj, end = dec.raw_decode(s, idx) obj = json.loads(line)
if isinstance(obj, list): if isinstance(obj, dict):
for it in obj:
if isinstance(it, dict):
items.append(it)
elif isinstance(obj, dict):
items.append(obj) items.append(obj)
idx = end elif isinstance(obj, list):
items.extend(item for item in obj if isinstance(item, dict))
except json.JSONDecodeError: except json.JSONDecodeError:
nl = s.find("\n", idx) continue
if nl == -1:
break
idx = nl + 1
return items return items
@@ -640,15 +626,15 @@ async def run_longmemeval_test(
# 数据路径 # 数据路径
if not data_path: if not data_path:
# 固定使用中文数据集data/longmemeval_oracle_zh.json # 固定使用中文数据集dataset/longmemeval_oracle_zh.json
zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json") dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json") data_path = str(dataset_dir / "longmemeval_oracle_zh.json")
if os.path.exists(zh_proj):
data_path = zh_proj if not os.path.exists(data_path):
elif os.path.exists(zh_cwd): raise FileNotFoundError(
data_path = zh_cwd f"数据集文件不存在: {data_path}\n"
else: f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}"
raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json请确保其存在于项目根目录或当前工作目录的 data 目录下。") )
qa_list: List[Dict[str, Any]] = load_dataset_any(data_path) qa_list: List[Dict[str, Any]] = load_dataset_any(data_path)
# 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾 # 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾
@@ -658,13 +644,9 @@ async def run_longmemeval_test(
items = qa_list[start_index:start_index + sample_size] items = qa_list[start_index:start_index + sample_size]
# 初始化组件 - 使用异步LLM客户端 # 初始化组件 - 使用异步LLM客户端
with get_db_context() as db: llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"))
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
connector = Neo4jConnector() connector = Neo4jConnector()
with get_db_context() as db: cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )
@@ -1203,8 +1185,8 @@ async def run_longmemeval_test(
"search_limit": search_limit, "search_limit": search_limit,
"context_char_budget": context_char_budget, "context_char_budget": context_char_budget,
"search_type": search_type, "search_type": search_type,
"llm_id": SELECTED_LLM_ID, "llm_id": os.getenv("EVAL_LLM_ID"),
"embedding_id": SELECTED_EMBEDDING_ID, "embedding_id": os.getenv("EVAL_EMBEDDING_ID"),
"sample_size": sample_size, "sample_size": sample_size,
"start_index": start_index, "start_index": start_index,
}, },

View File

@@ -2,81 +2,30 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import re
import time import time
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List from typing import List, Dict, Any
import re
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
# 路径与模块导入保持与现有评估脚本一致
import sys
from pathlib import Path from pathlib import Path
_THIS_DIR = Path(__file__).resolve().parent from dotenv import load_dotenv
_PROJECT_ROOT = str(_THIS_DIR.parents[1])
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") # Load evaluation config
for _p in (_SRC_DIR, _PROJECT_ROOT): eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if _p not in sys.path: if eval_config_path.exists():
sys.path.insert(0, _p) load_dotenv(eval_config_path, override=True)
print(f"✅ 加载评估配置: {eval_config_path}")
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.models.base import RedBearModelConfig
from app.db import get_db_context
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.services.memory_config_service import MemoryConfigService from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.models.base import RedBearModelConfig
from app.core.memory.utils.config.config_utils import get_embedder_config
try: from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
except Exception:
# 兜底:简单实现(必要时)
def f1_score(pred: str, ref: str) -> float:
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
tp = len(set(ps) & set(rs))
if tp == 0:
return 0.0
precision = tp / len(ps)
recall = tp / len(rs)
if precision + recall == 0:
return 0.0
return 2 * precision * recall / (precision + recall)
def bleu1(pred: str, ref: str) -> float: from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
ps = pred.lower().split()
rs = ref.lower().split()
if not ps or not rs:
return 0.0
overlap = len([w for w in ps if w in rs])
return overlap / max(len(ps), 1)
def jaccard(pred: str, ref: str) -> float:
ps = set(pred.lower().split())
rs = set(ref.lower().split())
union = len(ps | rs)
if union == 0:
return 0.0
return len(ps & rs) / union
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
@@ -219,16 +168,16 @@ async def run_memsciqa_test(
# 默认使用指定的 memsci 组 ID # 默认使用指定的 memsci 组 ID
end_user_id = end_user_id or "group_memsci" end_user_id = end_user_id or "group_memsci"
# 数据路径解析(项目根与当前工作目录兜底) # 数据路径解析
if not data_path: if not data_path:
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") data_path = str(dataset_dir / "msc_self_instruct.jsonl")
if os.path.exists(proj_path):
data_path = proj_path if not os.path.exists(data_path):
elif os.path.exists(cwd_path): raise FileNotFoundError(
data_path = cwd_path f"数据集文件不存在: {data_path}\n"
else: f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl请确保其存在于项目根目录或当前工作目录的 data 目录下。") )
# 加载数据 # 加载数据
all_items = load_dataset_memsciqa(data_path) all_items = load_dataset_memsciqa(data_path)
@@ -238,17 +187,13 @@ async def run_memsciqa_test(
items = all_items[start_index:start_index + sample_size] items = all_items[start_index:start_index + sample_size]
# 初始化 LLM纯测试不进行摄入 # 初始化 LLM纯测试不进行摄入
with get_db_context() as db: llm = get_llm_client(os.getenv("EVAL_LLM_ID"))
factory = MemoryClientFactory(db)
llm = factory.get_llm_client(SELECTED_LLM_ID)
# 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test # 初始化 Neo4j 连接与向量检索 Embedder对齐 locomo_test
connector = Neo4jConnector() connector = Neo4jConnector()
embedder = None embedder = None
if search_type in ("embedding", "hybrid"): if search_type in ("embedding", "hybrid"):
with get_db_context() as db: cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
config_service = MemoryConfigService(db)
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
embedder = OpenAIEmbedderClient( embedder = OpenAIEmbedderClient(
model_config=RedBearModelConfig.model_validate(cfg_dict) model_config=RedBearModelConfig.model_validate(cfg_dict)
) )
@@ -273,7 +218,7 @@ async def run_memsciqa_test(
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "") question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "") reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
# 三路检索:chunks/statements/entities/summaries对齐 qwen_search_eval.py # 检索:使用与 evaluate_qa.py 相同的 run_hybrid_search
t0 = time.time() t0 = time.time()
results = None results = None
try: try:
@@ -302,57 +247,94 @@ async def run_memsciqa_test(
search_ms = (t1 - t0) * 1000 search_ms = (t1 - t0) * 1000
latencies_search.append(search_ms) latencies_search.append(search_ms)
# 构建上下文:包含 chunks、陈述、摘要和实体对齐 qwen_search_eval.py # 构建上下文:与 evaluate_qa.py 完全一致的逻辑
contexts_all: List[str] = [] contexts_all: List[str] = []
retrieved_counts: Dict[str, int] = {} retrieved_counts: Dict[str, int] = {}
if results: if results:
chunks = results.get("chunks", []) # 处理 hybrid 搜索结果
statements = results.get("statements", []) if search_type == "hybrid":
entities = results.get("entities", []) emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
summaries = results.get("summaries", []) kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
emb_dialogs = emb.get("dialogues", [])
emb_statements = emb.get("statements", [])
emb_entities = emb.get("entities", [])
kw_dialogs = kw.get("dialogues", [])
kw_statements = kw.get("statements", [])
kw_entities = kw.get("entities", [])
all_dialogs = emb_dialogs + kw_dialogs
all_statements = emb_statements + kw_statements
all_entities = emb_entities + kw_entities
# 简单去重
seen_dialog = set()
dialogues = []
for d in all_dialogs:
key = (str(d.get("uuid", "")), str(d.get("content", "")))
if key not in seen_dialog:
dialogues.append(d)
seen_dialog.add(key)
seen_stmt = set()
statements = []
for s in all_statements:
key = str(s.get("statement", ""))
if key not in seen_stmt:
statements.append(s)
seen_stmt.add(key)
seen_ent = set()
entities = []
for e in all_entities:
key = str(e.get("name", ""))
if key not in seen_ent:
entities.append(e)
seen_ent.add(key)
else:
# embedding 或 keyword 单独搜索
dialogues = results.get("dialogues", [])
statements = results.get("statements", [])
entities = results.get("entities", [])
retrieved_counts = { retrieved_counts = {
"chunks": len(chunks), "dialogues": len(dialogues),
"statements": len(statements), "statements": len(statements),
"entities": len(entities), "entities": len(entities),
"summaries": len(summaries),
} }
# 优先使用 chunks
for c in chunks: # 构建上下文文本
text = str(c.get("content", "")).strip() for d in dialogues:
text = str(d.get("content", "")).strip()
if text: if text:
contexts_all.append(text) contexts_all.append(text)
# 然后是 statements
for s in statements: for s in statements:
text = str(s.get("statement", "")).strip() text = str(s.get("statement", "")).strip()
if text: if text:
contexts_all.append(text) contexts_all.append(text)
# 然后是 summaries
for sm in summaries: # 实体摘要
text = str(sm.get("summary", "")).strip() if entities:
if text: scored = [e for e in entities if e.get("score") is not None]
contexts_all.append(text) top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
# 实体摘要最多加入前3个高分实体对齐 qwen_search_eval.py if top_entities:
scored = [e for e in entities if e.get("score") is not None] summary_lines = []
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3] for e in top_entities:
if top_entities: name = str(e.get("name", "")).strip()
summary_lines = [] etype = str(e.get("entity_type", "")).strip()
for e in top_entities: score = e.get("score")
name = str(e.get("name", "")).strip() if name:
etype = str(e.get("entity_type", "")).strip() meta = []
score = e.get("score") if etype:
if name: meta.append(f"type={etype}")
meta = [] if isinstance(score, (int, float)):
if etype: meta.append(f"score={score:.3f}")
meta.append(f"type={etype}") summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if isinstance(score, (int, float)): if summary_lines:
meta.append(f"score={score:.3f}") contexts_all.append("\n".join(summary_lines))
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
if summary_lines:
contexts_all.append("\n".join(summary_lines))
if verbose: if verbose:
if retrieved_counts: if retrieved_counts:
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要") print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
print(f"📊 有效上下文数量: {len(contexts_all)}") print(f"📊 有效上下文数量: {len(contexts_all)}")
q_keywords = extract_question_keywords(question, max_keywords=8) q_keywords = extract_question_keywords(question, max_keywords=8)
if q_keywords: if q_keywords:
@@ -507,8 +489,8 @@ async def run_memsciqa_test(
"llm_max_tokens": llm_max_tokens, "llm_max_tokens": llm_max_tokens,
"search_type": search_type, "search_type": search_type,
"start_index": start_index, "start_index": start_index,
"llm_id": SELECTED_LLM_ID, "llm_id": os.getenv("EVAL_LLM_ID"),
"retrieval_embedding_id": SELECTED_EMBEDDING_ID "retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID")
}, },
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} }
@@ -522,7 +504,7 @@ async def run_memsciqa_test(
def main(): def main():
load_dotenv() load_dotenv()
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)") parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)") parser.add_argument("--sample-size", type=int, default=10, help="样本数量(<=0 表示全部)")
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size") parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size")
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引") parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci") parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID默认 group_memsci")

View File

@@ -4,35 +4,20 @@ import json
import os import os
import time import time
from datetime import datetime from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List from typing import List, Dict, Any
from pathlib import Path
from dotenv import load_dotenv
if TYPE_CHECKING: # Load evaluation config
from app.schemas.memory_config_schema import MemoryConfig eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
if eval_config_path.exists():
load_dotenv(eval_config_path, override=True)
try:
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.core.memory.evaluation.common.metrics import (
avg_context_tokens,
exact_match,
latency_stats,
)
from app.core.memory.evaluation.extraction_utils import (
ingest_contexts_via_full_pipeline,
)
from app.core.memory.storage_services.search import run_hybrid_search
from app.core.memory.utils.config.definitions import (
PROJECT_ROOT,
SELECTED_EMBEDDING_ID,
SELECTED_GROUP_ID,
SELECTED_LLM_ID,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str: def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
@@ -135,24 +120,37 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
return merged return merged
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]: async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
end_user_id = end_user_id or SELECTED_GROUP_ID end_user_id = end_user_id or SELECTED_GROUP_ID
# Load data # Load data
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl") dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
data_path = dataset_dir / "msc_self_instruct.jsonl"
if not os.path.exists(data_path): if not os.path.exists(data_path):
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl") raise FileNotFoundError(
f"数据集文件不存在: {data_path}\n"
f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
)
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
lines = f.readlines() lines = f.readlines()
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]] items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入 # 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
# 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略 # 说明memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
contexts: List[str] = [build_context_from_dialog(item) for item in items] contexts: List[str] = [build_context_from_dialog(item) for item in items]
await ingest_contexts_via_full_pipeline(contexts, end_user_id) await ingest_contexts_via_full_pipeline(contexts, end_user_id)
# LLM client (使用异步调用) # LLM client (使用异步调用)
with get_db_context() as db: from app.db import get_db
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(SELECTED_LLM_ID) db = next(get_db())
try:
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
finally:
db.close()
# Evaluate each item # Evaluate each item
connector = Neo4jConnector() connector = Neo4jConnector()
@@ -177,7 +175,6 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
limit=search_limit, limit=search_limit,
include=["dialogues", "statements", "entities"], include=["dialogues", "statements", "entities"],
output_path=None, output_path=None,
memory_config=memory_config,
) )
except Exception: except Exception:
results = None results = None
@@ -261,11 +258,7 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip()) pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference # Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
correct_flags.append(exact_match(pred, reference)) correct_flags.append(exact_match(pred, reference))
from app.core.memory.evaluation.common.metrics import ( from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
bleu1,
f1_score,
jaccard,
)
f1s.append(f1_score(str(pred), str(reference))) f1s.append(f1_score(str(pred), str(reference)))
b1s.append(bleu1(str(pred), str(reference))) b1s.append(bleu1(str(pred), str(reference)))
jss.append(jaccard(str(pred), str(reference))) jss.append(jaccard(str(pred), str(reference)))
@@ -295,15 +288,39 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
def main(): def main():
# Load environment variables first
load_dotenv() load_dotenv()
# Get defaults from environment variables
env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE")
env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT")
env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET")
env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS")
env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes")
env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR")
# Convert to appropriate types with fallback to code defaults
default_sample_size = int(env_sample_size) if env_sample_size else 1
default_search_limit = int(env_search_limit) if env_search_limit else 8
default_context_budget = int(env_context_budget) if env_context_budget else 4000
default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64
default_output_dir = env_output_dir if env_output_dir else None
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen") parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量") parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id默认取 runtime.json") parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id默认使用环境变量")
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数") parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算") parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度") parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度") parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens,
help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})")
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型") parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest,
help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})")
parser.add_argument("--output-dir", type=str, default=default_output_dir,
help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})")
args = parser.parse_args() args = parser.parse_args()
result = asyncio.run( result = asyncio.run(
@@ -315,9 +332,37 @@ def main():
llm_temperature=args.llm_temperature, llm_temperature=args.llm_temperature,
llm_max_tokens=args.llm_max_tokens, llm_max_tokens=args.llm_max_tokens,
search_type=args.search_type, search_type=args.search_type,
skip_ingest=args.skip_ingest,
) )
) )
# Print results to console
print(json.dumps(result, ensure_ascii=False, indent=2)) print(json.dumps(result, ensure_ascii=False, indent=2))
# Save results to file
output_dir = args.output_dir
if output_dir is None:
# Use absolute path to ensure results are saved in the correct location
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / "results"
elif not Path(output_dir).is_absolute():
# If relative path, make it relative to this script's directory
script_dir = Path(__file__).resolve().parent
output_dir = script_dir / output_dir
else:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = output_dir / f"memsciqa_{timestamp_str}.json"
try:
with open(output_path, "w", encoding="utf-8") as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n✅ 结果已保存到: {output_path}")
except Exception as e:
print(f"\n❌ 保存结果失败: {e}")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -2,20 +2,16 @@ import argparse
import asyncio import asyncio
import json import json
import os import os
import sys
from typing import Any, Dict from typing import Any, Dict
from pathlib import Path
from dotenv import load_dotenv
# Add src directory to Python path for proper imports when running from evaluation directory # Load evaluation config
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src')) eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
if eval_config_path.exists():
try: load_dotenv(eval_config_path, override=True)
from dotenv import load_dotenv
except Exception:
def load_dotenv():
return None
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
@@ -36,8 +32,9 @@ async def run(
start_index: int | None = None, start_index: int | None = None,
max_contexts_per_item: int | None = None, max_contexts_per_item: int | None = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认 # Use environment variable with fallback chain if not provided
end_user_id = end_user_id or SELECTED_GROUP_ID if end_user_id is None:
end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default")
if reset_group: if reset_group:
connector = Neo4jConnector() connector = Neo4jConnector()

View File

@@ -1064,13 +1064,16 @@ class ExtractionOrchestrator:
if statement.triplet_extraction_info: if statement.triplet_extraction_info:
triplet_info = statement.triplet_extraction_info triplet_info = statement.triplet_extraction_info
# 创建实体索引到ID的映射 # 创建实体索引到ID的映射(支持多种索引方式)
entity_idx_to_id = {} entity_idx_to_id = {}
# 创建实体节点 # 创建实体节点
for entity_idx, entity in enumerate(triplet_info.entities): for entity_idx, entity in enumerate(triplet_info.entities):
# 映射实体索引到实体ID # 映射实体索引到实体ID(使用多个键以提高容错性)
# 1. 使用实体自己的 entity_idx
entity_idx_to_id[entity.entity_idx] = entity.id entity_idx_to_id[entity.entity_idx] = entity.id
# 2. 使用枚举索引从0开始
entity_idx_to_id[entity_idx] = entity.id
if entity.id not in entity_id_set: if entity.id not in entity_id_set:
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong') entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
@@ -1149,9 +1152,18 @@ class ExtractionOrchestrator:
relationship_result relationship_result
) )
else: else:
logger.warning( # 改进的警告信息,包含更多调试信息
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, " missing_subject = "subject" if not subject_entity_id else ""
f"object_id={triplet.object_id}, statement_id={statement.id}" missing_object = "object" if not object_entity_id else ""
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
logger.debug(
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
f"object_id={triplet.object_id} ({triplet.object_name}), "
f"predicate={triplet.predicate}, "
f"statement_id={statement.id}, "
f"available_indices={sorted(entity_idx_to_id.keys())}"
) )
logger.info( logger.info(

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType from app.base.type import PydanticType
from app.db import Base from app.db import Base
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
class AgentConfig(Base): class AgentConfig(Base):

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship
from app.base.type import PydanticType from app.base.type import PydanticType
from app.db import Base from app.db import Base
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
class OrchestrationMode(StrEnum): class OrchestrationMode(StrEnum):

View File

@@ -4,7 +4,7 @@ import datetime
from typing import Optional, List, Dict, Any, Union from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, ConfigDict, field_serializer from pydantic import BaseModel, Field, ConfigDict, field_serializer
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
# ==================== 子 Agent 配置 ==================== # ==================== 子 Agent 配置 ====================

View File

@@ -5,7 +5,7 @@ import uuid
from typing import Dict, Any, List, Optional, Tuple from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
from app.services.conversation_state_manager import ConversationStateManager from app.services.conversation_state_manager import ConversationStateManager
from app.models import ModelConfig, AgentConfig from app.models import ModelConfig, AgentConfig
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger

View File

@@ -188,7 +188,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
"config_desc": config.config_desc, "config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None, "workspace_id": str(config.workspace_id) if config.workspace_id else None,
"end_user_id": config.end_user_id, "end_user_id": config.end_user_id,
"user_id": config.user_id, "config_id_old": int(config.user_id),
"apply_id": config.apply_id, "apply_id": config.apply_id,
"llm_id": config.llm_id, "llm_id": config.llm_id,
"embedding_id": config.embedding_id, "embedding_id": config.embedding_id,

View File

@@ -57,7 +57,7 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]:
if data is None: if data is None:
return None return None
from app.schemas import ModelParameters from app.schemas.app_schema import ModelParameters
if isinstance(data, ModelParameters): if isinstance(data, ModelParameters):
return data return data

View File

@@ -8,6 +8,7 @@ import { type FC, useRef, useEffect } from 'react'
import clsx from 'clsx' import clsx from 'clsx'
import Markdown from '@/components/Markdown' import Markdown from '@/components/Markdown'
import type { ChatContentProps } from './types' import type { ChatContentProps } from './types'
import { Spin } from 'antd'
/** /**
* 聊天内容显示组件 * 聊天内容显示组件
@@ -21,7 +22,8 @@ const ChatContent: FC<ChatContentProps> = ({
empty, empty,
labelPosition = 'bottom', labelPosition = 'bottom',
labelFormat, labelFormat,
errorDesc errorDesc,
renderRuntime
}) => { }) => {
// 滚动容器引用,用于控制自动滚动到底部 // 滚动容器引用,用于控制自动滚动到底部
const scrollContainerRef = useRef<(HTMLDivElement | null)>(null) const scrollContainerRef = useRef<(HTMLDivElement | null)>(null)
@@ -45,8 +47,8 @@ const ChatContent: FC<ChatContentProps> = ({
'rb:left-0 rb:text-left': item.role === 'assistant', // 助手消息左对齐 'rb:left-0 rb:text-left': item.role === 'assistant', // 助手消息左对齐
})}> })}>
{/* 流式加载时且内容为空则不显示 */} {/* 流式加载时且内容为空则不显示 */}
{streamLoading && item.content === '' {streamLoading && item.content === '' && !renderRuntime
? null ? <Spin />
: <> : <>
{/* 顶部标签(如时间戳、用户名等) */} {/* 顶部标签(如时间戳、用户名等) */}
{labelPosition === 'top' && {labelPosition === 'top' &&
@@ -55,16 +57,17 @@ const ChatContent: FC<ChatContentProps> = ({
</div> </div>
} }
{/* 消息气泡框 */} {/* 消息气泡框 */}
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, { <div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-130 rb:wrap-break-word', contentClassNames, {
// 错误消息样式内容为null且非助手消息 // 错误消息样式内容为null且非助手消息
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null, 'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null && !renderRuntime,
// 助手消息样式 // 助手消息样式
'rb:bg-[rgba(21,94,239,0.08)] rb:border-[rgba(21,94,239,0.30)]': item.role === 'user', 'rb:bg-[rgba(21,94,239,0.08)] rb:border-[rgba(21,94,239,0.30)]': item.role === 'user',
// 用户消息样式 // 用户消息样式
'rb:bg-[#FFFFFF] rb:border-[#EBEBEB]': item.role === 'assistant' && (item.content || item.content === ''), 'rb:bg-[#FFFFFF] rb:border-[#EBEBEB]': item.role === 'assistant' && (item.content || item.content === '' || typeof renderRuntime === 'function'),
})}> })}>
{item.subContent && renderRuntime && renderRuntime(item, index)}
{/* 使用Markdown组件渲染消息内容 */} {/* 使用Markdown组件渲染消息内容 */}
<Markdown content={item.content ?? errorDesc ?? ''} /> <Markdown content={renderRuntime ? item.content ?? '' : item.content ?? errorDesc ?? ''} />
</div> </div>
{/* 底部标签(如时间戳、用户名等) */} {/* 底部标签(如时间戳、用户名等) */}
{labelPosition === 'bottom' && {labelPosition === 'bottom' &&

View File

@@ -19,7 +19,9 @@ export interface ChatItem {
/** 消息内容 */ /** 消息内容 */
content?: string | null; content?: string | null;
/** 创建时间 */ /** 创建时间 */
created_at?: number | string created_at?: number | string;
status?: string;
subContent?: Record<string, any>[]
} }
/** /**
@@ -81,4 +83,5 @@ export interface ChatContentProps {
/** 标签格式化函数 */ /** 标签格式化函数 */
labelFormat: (item: ChatItem) => any; labelFormat: (item: ChatItem) => any;
errorDesc?: string; errorDesc?: string;
renderRuntime?: (item: ChatItem, index: number) => ReactNode;
} }

View File

@@ -15,7 +15,7 @@ interface ApiResponse<T> {
interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> { interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
url: string; url: string;
params?: Record<string, unknown>; params?: Record<string, unknown>;
valueKey?: string; valueKey?: string | string[];
labelKey?: string; labelKey?: string;
placeholder?: string; placeholder?: string;
hasAll?: boolean; hasAll?: boolean;
@@ -66,11 +66,18 @@ const CustomSelect: FC<CustomSelectProps> = ({
{...props} {...props}
> >
{hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>} {hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>}
{displayOptions.map((option) => ( {displayOptions.map((option) => {
<Select.Option key={option[valueKey]} value={option[valueKey]}> const getValue = () => {
{String(option[labelKey])} if (typeof valueKey === 'string') return option[valueKey];
</Select.Option> return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined;
))} };
const value = getValue();
return (
<Select.Option key={value} value={value}>
{String(option[labelKey])}
</Select.Option>
);
})}
</Select> </Select>
); );
}; };

View File

@@ -6,6 +6,9 @@ import CopyBtn from './CopyBtn';
type ICodeBlockProps = { type ICodeBlockProps = {
value: string; value: string;
needCopy?: boolean;
size?: 'small' | 'default';
showLineNumbers?: boolean;
} }
// enum languageType { // enum languageType {
@@ -16,6 +19,9 @@ type ICodeBlockProps = {
const CodeBlock: FC<ICodeBlockProps> = ({ const CodeBlock: FC<ICodeBlockProps> = ({
value, value,
needCopy = true,
size = 'default',
showLineNumbers = false
}) => { }) => {
return ( return (
@@ -23,24 +29,26 @@ const CodeBlock: FC<ICodeBlockProps> = ({
<SyntaxHighlighter <SyntaxHighlighter
style={atelierHeathLight} style={atelierHeathLight}
customStyle={{ customStyle={{
padding: '16px 20px 16px 24px', padding: '8px 12px 8px 12px',
backgroundColor: '#F0F3F8', backgroundColor: '#F0F3F8',
borderRadius: 8, borderRadius: 8,
fontSize: size === 'small' ? 12 : 14,
wordBreak: 'break-all'
}} }}
language="json" language="json"
showLineNumbers={false} showLineNumbers={showLineNumbers}
PreTag="div" PreTag="div"
> >
{value} {value}
</SyntaxHighlighter> </SyntaxHighlighter>
<CopyBtn {needCopy && <CopyBtn
value={value} value={value}
style={{ style={{
position: 'absolute', position: 'absolute',
top: 20, top: 20,
right: 20, right: 20,
}} }}
/> />}
</div> </div>
) )
} }

View File

@@ -1982,6 +1982,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
arrange: 'Arrange', arrange: 'Arrange',
redo: 'Redo', redo: 'Redo',
undo: 'Undo', undo: 'Undo',
input: 'Input',
output: 'Output',
error: 'Error Message',
}, },
emotionEngine: { emotionEngine: {
emotionEngineConfig: 'Emotion Engine Configuration', emotionEngineConfig: 'Emotion Engine Configuration',

View File

@@ -2076,6 +2076,10 @@ export const zh = {
arrange: '整理', arrange: '整理',
redo: '重做', redo: '重做',
undo: '撤销', undo: '撤销',
input: '输入',
output: '输出',
error: '错误信息',
}, },
emotionEngine: { emotionEngine: {
emotionEngineConfig: '情感引擎配置', emotionEngineConfig: '情感引擎配置',

View File

@@ -123,6 +123,20 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
let response = await makeSSERequest(url, data, token || '', config); let response = await makeSSERequest(url, data, token || '', config);
switch (response.status) { switch (response.status) {
case 500:
case 502:
const errorData = await response.json();
errorData.error || i18n.t('common.serviceUpgrading');
message.warning(errorData.error || i18n.t('common.serviceUpgrading'));
break
case 400:
const error = await response.json();
message.warning(error.error);
throw error || 'Bad Request';
case 504:
const errorJson = await response.json();
message.warning(errorJson.error || i18n.t('common.serverError'));
break
case 401: case 401:
if (url?.includes('/public')) { if (url?.includes('/public')) {
return message.warning(i18n.t('common.publicApiCannotRefreshToken')); return message.warning(i18n.t('common.publicApiCannotRefreshToken'));

View File

@@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[],
placeholder={t('common.pleaseSelect')} placeholder={t('common.pleaseSelect')}
url={url} url={url}
hasAll={false} hasAll={false}
valueKey='config_id' valueKey={['config_id_old', 'config_id']}
labelKey="config_name" labelKey="config_name"
/> />
</Form.Item> </Form.Item>
@@ -126,12 +126,14 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
getApplicationConfig(id as string).then(res => { getApplicationConfig(id as string).then(res => {
const response = res as Config const response = res as Config
let allTools = Array.isArray(response.tools) ? response.tools : [] let allTools = Array.isArray(response.tools) ? response.tools : []
const memoryContent = response.memory?.memory_content
const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
form.setFieldsValue({ form.setFieldsValue({
...response, ...response,
tools: allTools, tools: allTools,
memory: { memory: {
...response.memory, ...response.memory,
memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined memory_content: convertedMemoryContent
} }
}) })
setData({ setData({

View File

@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
useEffect(() => { useEffect(() => {
if (values?.retrieve_type) { if (values?.retrieve_type) {
const fieldsToReset = Object.keys(values).filter(key => const fieldsToReset = Object.keys(values).filter(key =>
key !== 'kb_id' && key !== 'retrieve_type' key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
) as (keyof KnowledgeConfigForm)[]; ) as (keyof KnowledgeConfigForm)[];
form.resetFields(fieldsToReset); form.resetFields(fieldsToReset);
} }

View File

@@ -1,8 +1,9 @@
import { forwardRef, useImperativeHandle, useState, useRef } from 'react' import { forwardRef, useImperativeHandle, useState, useRef } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import clsx from 'clsx' import clsx from 'clsx'
import { Input, Form, App } from 'antd' import { Input, Form, App, Space, Button, Collapse } from 'antd'
import { Space, Button } from 'antd' import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons'
import CodeBlock from '@/components/Markdown/CodeBlock'
import ChatIcon from '@/assets/images/application/chat.png' import ChatIcon from '@/assets/images/application/chat.png'
import RbDrawer from '@/components/RbDrawer'; import RbDrawer from '@/components/RbDrawer';
@@ -13,8 +14,11 @@ import ChatContent from '@/components/Chat/ChatContent'
import type { ChatItem } from '@/components/Chat/types' import type { ChatItem } from '@/components/Chat/types'
import ChatSendIcon from '@/assets/images/application/chatSend.svg' import ChatSendIcon from '@/assets/images/application/chatSend.svg'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import type { ChatRef, VariableConfigModalRef, StartVariableItem, GraphRef } from '../../types' import type { ChatRef, VariableConfigModalRef, GraphRef } from '../../types'
import { type SSEMessage } from '@/utils/stream' import { type SSEMessage } from '@/utils/stream'
import type { Variable } from '../Properties/VariableList/types'
import styles from './chat.module.css'
import Markdown from '@/components/Markdown'
const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId, graphRef }, ref) => { const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId, graphRef }, ref) => {
const { t } = useTranslation() const { t } = useTranslation()
@@ -24,7 +28,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
const [open, setOpen] = useState(false) const [open, setOpen] = useState(false)
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const [chatList, setChatList] = useState<ChatItem[]>([]) const [chatList, setChatList] = useState<ChatItem[]>([])
const [variables, setVariables] = useState<StartVariableItem[]>([]) const [variables, setVariables] = useState<Variable[]>([])
const [streamLoading, setStreamLoading] = useState(false) const [streamLoading, setStreamLoading] = useState(false)
const [conversationId, setConversationId] = useState<string | null>(null) const [conversationId, setConversationId] = useState<string | null>(null)
@@ -39,7 +43,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
if (startNodes.length) { if (startNodes.length) {
const curVariables = startNodes[0].config.variables?.defaultValue const curVariables = startNodes[0].config.variables?.defaultValue
curVariables.forEach((vo: StartVariableItem) => { curVariables.forEach((vo: Variable) => {
if (typeof vo.default !== 'undefined') { if (typeof vo.default !== 'undefined') {
vo.value = vo.default vo.value = vo.default
} }
@@ -60,7 +64,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
const handleEditVariables = () => { const handleEditVariables = () => {
variableConfigModalRef.current?.handleOpen(variables) variableConfigModalRef.current?.handleOpen(variables)
} }
const handleSave = (values: StartVariableItem[]) => { const handleSave = (values: Variable[]) => {
setVariables([...values]) setVariables([...values])
} }
const handleSend = () => { const handleSend = () => {
@@ -97,13 +101,28 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
role: 'assistant', role: 'assistant',
content: '', content: '',
created_at: Date.now(), created_at: Date.now(),
subContent: [],
}]) }])
const handleStreamMessage = (data: SSEMessage[]) => { const handleStreamMessage = (data: SSEMessage[]) => {
setStreamLoading(false)
data.forEach(item => { data.forEach(item => {
const { chunk, conversation_id } = item.data as { chunk: string; conversation_id: string | null; }; const { chunk, conversation_id, node_id, input, output, error, elapsed_time, status } = item.data as {
chunk: string;
conversation_id: string | null;
node_id: string;
node_name?: string;
input?: any;
output?: any;
elapsed_time?: string;
error?: any;
state: Record<string, any>;
status?: 'completed' | 'failed'
};
const node = graphRef.current?.getNodes().find(n => n.id === node_id);
const { name, icon } = node?.getData() || {}
console.log('node', node?.getData())
switch(item.event) { switch(item.event) {
case 'message': case 'message':
@@ -119,6 +138,66 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
return newList return newList
}) })
break break
case 'node_start':
setChatList(prev => {
const newList = [...prev]
const lastIndex = newList.length - 1
if (lastIndex >= 0) {
const newSubContent = newList[lastIndex].subContent || []
const filterIndex = newSubContent.findIndex(vo => vo.id === node_id)
if (filterIndex > -1) {
newSubContent[filterIndex] = {
...newSubContent[filterIndex],
node_id: node_id,
node_name: name,
icon,
content: {},
}
} else {
newSubContent.push({
id: node_id,
node_id: node_id,
node_name: name,
icon,
content: {},
})
}
newList[lastIndex] = {
...newList[lastIndex],
subContent: newSubContent
}
}
return newList
})
break
case 'node_end':
case 'node_error':
setChatList(prev => {
const newList = [...prev]
const lastIndex = newList.length - 1
if (lastIndex >= 0) {
const newSubContent = newList[lastIndex].subContent || []
const filterIndex = newSubContent.findIndex(vo => vo.node_id === node_id)
if (filterIndex > -1 && newSubContent[filterIndex].content) {
newSubContent[filterIndex] = {
...newSubContent[filterIndex],
content: {
input,
output,
error,
},
status: status || 'completed',
elapsed_time
}
}
newList[lastIndex] = {
...newList[lastIndex],
subContent: newSubContent
}
}
return newList
})
break
case 'workflow_end': case 'workflow_end':
setChatList(prev => { setChatList(prev => {
const newList = [...prev] const newList = [...prev]
@@ -126,6 +205,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
if (lastIndex >= 0) { if (lastIndex >= 0) {
newList[lastIndex] = { newList[lastIndex] = {
...newList[lastIndex], ...newList[lastIndex],
status,
content: newList[lastIndex].content === '' ? null : newList[lastIndex].content content: newList[lastIndex].content === '' ? null : newList[lastIndex].content
} }
} }
@@ -142,14 +222,31 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
} }
form.setFieldValue('message', undefined) form.setFieldValue('message', undefined)
setStreamLoading(true)
draftRun(appId, { draftRun(appId, {
message: message, message: message,
variables: params, variables: params,
stream: true, stream: true,
conversation_id: conversationId conversation_id: conversationId
}, handleStreamMessage) }, handleStreamMessage)
.catch((error) => {
setChatList(prev => {
const newList = [...prev]
const lastIndex = newList.length - 1
if (lastIndex >= 0) {
newList[lastIndex] = {
...newList[lastIndex],
status: 'failed',
content: null,
subContent: error.error
}
}
return newList
})
})
.finally(() => { .finally(() => {
setLoading(false) setLoading(false)
setStreamLoading(false)
}) })
} }
// 暴露给父组件的方法 // 暴露给父组件的方法
@@ -158,6 +255,11 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
handleClose handleClose
})); }));
const getStatus = (status?: string) => {
return status === 'completed' ? 'rb:text-[#369F21]' : status === 'failed' ? 'rb:text-[#FF5D34]' : 'rb:text-[#5B6167]'
}
console.log('chatList', chatList)
return ( return (
<RbDrawer <RbDrawer
title={<div className="rb:flex rb:items-center rb:gap-2.5"> title={<div className="rb:flex rb:items-center rb:gap-2.5">
@@ -173,10 +275,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
onClose={handleClose} onClose={handleClose}
> >
<ChatContent <ChatContent
classNames={{ classNames="rb:mx-[16px] rb:pt-[24px] rb:h-[calc(100%-76px)]"
'rb:mx-[16px] rb:pt-[24px] rb:h-[calc(100%-76px)]': true,
}}
contentClassNames="rb:max-w-[400px]!'" contentClassNames="rb:max-w-[400px]!'"
empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />} empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
data={chatList} data={chatList}
@@ -184,6 +283,87 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
labelPosition="bottom" labelPosition="bottom"
labelFormat={(item) => dayjs(item.created_at).locale('en').format('MMMM D, YYYY [at] h:mm A')} labelFormat={(item) => dayjs(item.created_at).locale('en').format('MMMM D, YYYY [at] h:mm A')}
errorDesc={t('application.ReplyException')} errorDesc={t('application.ReplyException')}
renderRuntime={(item, index) => {
return (
<div key={index} className="rb:w-100 rb:mb-2">
<Collapse
className={styles[item.status || 'default']}
items={[{
key: 0,
label: <div className={getStatus(item.status)}>
{item.status === 'completed' ? <CheckCircleFilled className="rb:mr-1" /> : item.status === 'failed' ? <CloseCircleFilled className="rb:mr-1" /> : <LoadingOutlined className="rb:mr-1" />}
{t('application.workflow')}
</div>,
className: styles.collapseItem,
children: (
Array.isArray(item.subContent)
? <Space size={8} direction="vertical" className="rb:w-full!">
{item.subContent?.map(vo => (
<Collapse
key={vo.node_id}
items={[{
key: vo.node_id,
label: <div className={clsx("rb:flex rb:justify-between rb:items-center", getStatus(vo.status))}>
<div className="rb:flex rb:items-center rb:gap-1 rb:flex-1">
{vo.icon && <img src={vo.icon} className="rb:size-4" />}
<div className="rb:wrap-break-word rb:line-clamp-1">{vo.node_name || vo.node_id}</div>
</div>
<span>
{typeof vo.elapsed_time == 'number' && <>{vo.elapsed_time?.toFixed(3)}ms</>}
{vo.status === 'completed' ? <CheckCircleFilled className="rb:ml-1" /> : vo.status === 'failed' ? <CloseCircleFilled className="rb:ml-1" /> : <LoadingOutlined className="rb:ml-1" />}
</span>
</div>,
className: styles.collapseItem,
children: (
<Space size={8} direction="vertical" className="rb:w-full!">
{vo.status === 'failed' &&
<div className={clsx("rb:bg-[#F0F3F8] rb:rounded-md", getStatus(vo.status))}>
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
{t(`workflow.error`)}
<Button
className="rb:py-0! rb:px-1! rb:text-[12px]!"
size="small"
>{t('common.copy')}</Button>
</div>
<div className="rb:pb-2 rb:px-3 rb:max-h-40 rb:overflow-auto">
<Markdown content={vo.content?.error || ''} />
</div>
</div>
}
{['input', 'output'].map(key => (
<div key={key} className="rb:bg-[#F0F3F8] rb:rounded-md">
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
{t(`workflow.${key}`)}
<Button
className="rb:py-0! rb:px-1! rb:text-[12px]!"
size="small"
>{t('common.copy')}</Button>
</div>
<div className="rb:max-h-40 rb:overflow-auto">
<CodeBlock
size="small"
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
needCopy={false}
showLineNumbers={true}
/>
</div>
</div>
))}
</Space>
)
}]}
/>
))}
</Space>
: <div className={clsx("rb:bg-[#FBFDFF] rb:rounded-md rb:py-2 rb:px-3 ", getStatus('failed'))}>
<Markdown content={item.subContent || ''} />
</div>
)
}]}
/>
</div>
)
}}
/> />
<div className="rb:flex rb:items-center rb:gap-2.5 rb:p-4"> <div className="rb:flex rb:items-center rb:gap-2.5 rb:p-4">
<Form form={form} style={{width: 'calc(100% - 54px)'}}> <Form form={form} style={{width: 'calc(100% - 54px)'}}>

View File

@@ -0,0 +1,45 @@
.completed {
background-color: rgba(54, 159, 33, 0.06);
border-color: rgba(54, 159, 33, 0.25);
border-radius: 8px;
}
.failed {
background-color: rgba(255, 138, 76, 0.08);
border-color: rgba(255, 138, 76, 0.20);
border-radius: 8px;
}
.default {
background-color: rgba(91, 97, 103, 0.08);
border-color: rgba(91, 97, 103, 0.30);
border-radius: 8px;
}
.collapse-item {
font-size: 12px;
line-height: 16px;
}
.collapse-item:global(.ant-collapse-item>.ant-collapse-header) {
padding: 8px 12px;
}
.collapse-item:global(.ant-collapse-item>.ant-collapse-header .ant-collapse-expand-icon) {
height: 16px;
}
.completed:global(.ant-collapse .ant-collapse-content),
.failed:global(.ant-collapse .ant-collapse-content) {
background-color: transparent;
border-top: none;
}
:global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) {
padding-top: 0;
}
.collapse-item :global(.ant-collapse) {
/* background-color: #F0F3F8; */
background-color: #FBFDFF;
border-radius: 6px;
}
.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child),
.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child>.ant-collapse-header) {
border-radius: 0 0 6px 6px;
}
.collapse-item :global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) {
padding: 0 4px 4px 4px;
}

View File

@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
useEffect(() => { useEffect(() => {
if (values?.retrieve_type) { if (values?.retrieve_type) {
const fieldsToReset = Object.keys(values).filter(key => const fieldsToReset = Object.keys(values).filter(key =>
key !== 'kb_id' && key !== 'retrieve_type' key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
) as (keyof KnowledgeConfigForm)[]; ) as (keyof KnowledgeConfigForm)[];
form.resetFields(fieldsToReset); form.resetFields(fieldsToReset);
} }
@@ -108,6 +108,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
label: t(`application.${key}`), label: t(`application.${key}`),
value: key, value: key,
}))} }))}
// onChange={handleChange}
/> />
</FormItem> </FormItem>
{/* Top K */} {/* Top K */}
@@ -116,13 +117,12 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
label={t('application.top_k')} label={t('application.top_k')}
rules={[{ required: true, message: t('common.pleaseEnter') }]} rules={[{ required: true, message: t('common.pleaseEnter') }]}
extra={t('application.top_k_desc')} extra={t('application.top_k_desc')}
initialValue={5}
> >
<InputNumber <InputNumber
style={{ width: '100%' }} style={{ width: '100%' }}
min={1} min={1}
max={20} max={20}
onChange={(value) => form.setFieldValue('top_k', value)} // onChange={(value) => form.setFieldValue('top_k', value)}
/> />
</FormItem> </FormItem>
{/* 语义相似度阈值 similarity_threshold */} {/* 语义相似度阈值 similarity_threshold */}

View File

@@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [
config_id: { config_id: {
type: 'customSelect', type: 'customSelect',
url: memoryConfigListUrl, url: memoryConfigListUrl,
valueKey: 'config_id', valueKey: ['config_id_old', 'config_id'],
labelKey: 'config_name' labelKey: 'config_name'
}, },
search_switch: { search_switch: {
@@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [
config_id: { config_id: {
type: 'customSelect', type: 'customSelect',
url: memoryConfigListUrl, url: memoryConfigListUrl,
valueKey: 'config_id', valueKey: ['config_id_old', 'config_id'],
labelKey: 'config_name' labelKey: 'config_name'
} }
} }
@@ -284,7 +284,7 @@ export const nodeLibrary: NodeLibrary[] = [
config: { config: {
input: { input: {
type: 'variableList', type: 'variableList',
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'], filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor'],
filterVariableNames: ['message'] filterVariableNames: ['message']
}, },
parallel: { parallel: {

View File

@@ -14,7 +14,7 @@ export interface NodeConfig {
url?: string; url?: string;
params?: { [key: string]: unknown; } params?: { [key: string]: unknown; }
valueKey?: string; valueKey?: string | string[];
labelKey?: string; labelKey?: string;
defaultValue?: any; defaultValue?: any;