Merge branch 'develop' of github.com:SuanmoSuanyangTechnology/MemoryBear into develop
This commit is contained in:
224
api/app/core/memory/evaluation/.env.evaluation.example
Normal file
224
api/app/core/memory/evaluation/.env.evaluation.example
Normal file
@@ -0,0 +1,224 @@
|
||||
# ============================================================================
|
||||
# 基准测试统一配置文件示例
|
||||
# ============================================================================
|
||||
# 复制此文件为 .env.evaluation 并根据需要修改
|
||||
# 支持的基准测试:LoCoMo、LongMemEval、MemSciQA
|
||||
# ============================================================================
|
||||
|
||||
# ============================================================================
|
||||
# 通用配置(所有基准测试共用)
|
||||
# ============================================================================
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# Neo4j 配置
|
||||
# ----------------------------------------------------------------------------
|
||||
# 默认 Group ID(建议各基准测试使用独立的 group)
|
||||
EVAL_GROUP_ID=benchmark_default
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 模型配置(必需)
|
||||
# ----------------------------------------------------------------------------
|
||||
# ⚠️ 必填:从数据库 models 表中选择有效的模型 ID
|
||||
#
|
||||
# 如何获取模型 ID:
|
||||
# 1. 查询数据库:SELECT id, model_name FROM models WHERE is_active = true;
|
||||
# 2. 或通过系统管理界面查看
|
||||
# 3. 确保模型可用且配置正确
|
||||
|
||||
# LLM 模型 ID(必填)
|
||||
EVAL_LLM_ID=your_llm_model_id_here
|
||||
|
||||
# Embedding 模型 ID(必填)
|
||||
EVAL_EMBEDDING_ID=your_embedding_model_id_here
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 检索参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# 检索类型: "keyword", "embedding", "hybrid"
|
||||
EVAL_SEARCH_TYPE=hybrid
|
||||
|
||||
# 检索结果数量限制(默认值)
|
||||
EVAL_SEARCH_LIMIT=12
|
||||
|
||||
# 上下文最大字符数(默认值)
|
||||
EVAL_MAX_CONTEXT_CHARS=8000
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# LLM 参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# LLM 温度参数(0.0 = 确定性输出)
|
||||
EVAL_LLM_TEMPERATURE=0.0
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
EVAL_LLM_MAX_TOKENS=32
|
||||
|
||||
# LLM 超时时间(秒)
|
||||
EVAL_LLM_TIMEOUT=10.0
|
||||
|
||||
# LLM 最大重试次数
|
||||
EVAL_LLM_MAX_RETRIES=1
|
||||
|
||||
# ----------------------------------------------------------------------------
|
||||
# 数据处理参数
|
||||
# ----------------------------------------------------------------------------
|
||||
# Chunker 策略
|
||||
EVAL_CHUNKER_STRATEGY=RecursiveChunker
|
||||
|
||||
# 是否在导入前清空现有数据
|
||||
EVAL_RESET_ON_INGEST=true
|
||||
|
||||
# 是否保存详细日志
|
||||
EVAL_SAVE_DETAILED_LOGS=true
|
||||
|
||||
# ============================================================================
|
||||
# LoCoMo 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:locomo10.json
|
||||
# 运行:python locomo_benchmark.py --sample_size 20
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(LoCoMo 专用)
|
||||
LOCOMO_GROUP_ID=locomo_benchmark
|
||||
|
||||
# 测试样本数量
|
||||
# 建议值:20(快速测试)、100(中等测试)、1986(完整测试)
|
||||
LOCOMO_SAMPLE_SIZE=20
|
||||
|
||||
# 检索结果数量限制
|
||||
LOCOMO_SEARCH_LIMIT=12
|
||||
|
||||
# 上下文最大字符数
|
||||
LOCOMO_CONTEXT_CHAR_BUDGET=8000
|
||||
|
||||
# 导入的对话数量
|
||||
LOCOMO_MAX_DIALOGUES=1
|
||||
|
||||
# 跳过数据摄入(true=跳过,false=摄入)
|
||||
# 首次运行设置为 false,后续运行可设置为 true 以节省时间
|
||||
LOCOMO_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录
|
||||
LOCOMO_OUTPUT_DIR=locomo/results
|
||||
|
||||
# ============================================================================
|
||||
# LongMemEval 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:longmemeval_oracle_zh.json
|
||||
# 运行:python longmemeval_benchmark.py --sample_size 3
|
||||
# 特点:支持时间推理问题的增强检索
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(LongMemEval 专用)
|
||||
LONGMEMEVAL_GROUP_ID=longmemeval_zh_bak_3
|
||||
|
||||
# 测试样本数量(<=0 表示全部样本)
|
||||
LONGMEMEVAL_SAMPLE_SIZE=3
|
||||
|
||||
# 起始样本索引
|
||||
LONGMEMEVAL_START_INDEX=0
|
||||
|
||||
# 检索结果数量限制
|
||||
LONGMEMEVAL_SEARCH_LIMIT=8
|
||||
|
||||
# 上下文最大字符数
|
||||
LONGMEMEVAL_CONTEXT_CHAR_BUDGET=4000
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
LONGMEMEVAL_LLM_MAX_TOKENS=16
|
||||
|
||||
# 每条样本最多摄入的上下文段数
|
||||
LONGMEMEVAL_MAX_CONTEXTS_PER_ITEM=2
|
||||
|
||||
# 是否保存分块结果
|
||||
LONGMEMEVAL_SAVE_CHUNK_OUTPUT=true
|
||||
|
||||
# 自定义分块输出路径(留空使用默认)
|
||||
LONGMEMEVAL_SAVE_CHUNK_OUTPUT_PATH=
|
||||
|
||||
# 摄入前是否清空组数据
|
||||
LONGMEMEVAL_RESET_GROUP_BEFORE_INGEST=false
|
||||
|
||||
# 是否跳过摄入,仅检索评估
|
||||
LONGMEMEVAL_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录
|
||||
LONGMEMEVAL_OUTPUT_DIR=longmemeval/results
|
||||
|
||||
# ============================================================================
|
||||
# MemSciQA 基准测试专用配置
|
||||
# ============================================================================
|
||||
# 数据集:msc_self_instruct.jsonl
|
||||
# 运行:python memsciqa_benchmark.py --sample_size 1
|
||||
# 特点:对话记忆检索评估
|
||||
# ----------------------------------------------------------------------------
|
||||
|
||||
# Group ID(MemSciQA 专用,独立数据集)
|
||||
MEMSCIQA_GROUP_ID=memsciqa_benchmark
|
||||
|
||||
# 测试样本数量
|
||||
# 0 或 -1 表示使用测试数据集中的所有样本
|
||||
MEMSCIQA_SAMPLE_SIZE=1
|
||||
|
||||
# 检索结果数量限制
|
||||
MEMSCIQA_SEARCH_LIMIT=8
|
||||
|
||||
# 上下文最大字符数
|
||||
MEMSCIQA_CONTEXT_CHAR_BUDGET=4000
|
||||
|
||||
# LLM 最大生成 token 数
|
||||
MEMSCIQA_LLM_MAX_TOKENS=64
|
||||
|
||||
# 跳过数据摄入(true=跳过,false=摄入)
|
||||
# 首次运行设置为 false,后续运行可设置为 true 以节省时间
|
||||
MEMSCIQA_SKIP_INGEST=false
|
||||
|
||||
# 结果保存目录(相对于 memsciqa 脚本所在目录)
|
||||
# 使用 "results" 会保存到 api/app/core/memory/evaluation/memsciqa/results/
|
||||
MEMSCIQA_OUTPUT_DIR=results
|
||||
|
||||
# ============================================================================
|
||||
# 高级配置(可选)
|
||||
# ============================================================================
|
||||
|
||||
# BM25 权重(用于混合检索,0.0-1.0)
|
||||
EVAL_RERANK_ALPHA=0.6
|
||||
|
||||
# 是否使用遗忘重排序
|
||||
EVAL_USE_FORGETTING_RERANK=false
|
||||
|
||||
# 是否使用 LLM 重排序
|
||||
EVAL_USE_LLM_RERANK=false
|
||||
|
||||
# 连接重置间隔(每 N 个问题重置一次)
|
||||
EVAL_RESET_INTERVAL=5
|
||||
|
||||
# 性能阈值(低于此值触发重置)
|
||||
EVAL_PERFORMANCE_THRESHOLD=0.6
|
||||
|
||||
# ============================================================================
|
||||
# 快速配置指南
|
||||
# ============================================================================
|
||||
# 1. 复制此文件为 .env.evaluation
|
||||
# 2. 修改 EVAL_LLM_ID 和 EVAL_EMBEDDING_ID 为你的模型 ID
|
||||
# 3. 根据需要修改各基准测试的专用配置
|
||||
# 4. 运行测试:
|
||||
# - LoCoMo: python locomo/locomo_benchmark.py --sample_size 20
|
||||
# - LongMemEval: python longmemeval/longmemeval_benchmark.py --sample_size 3 --all
|
||||
# - MemSciQA: python memsciqa/memsciqa_benchmark.py --sample_size 10
|
||||
# 配置优先级:
|
||||
# 命令行参数 > 特定配置(如 LOCOMO_*)> 通用配置(EVAL_*)> 代码默认值
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# 执行LoCoMo测试
|
||||
# 只摄入前5条消息,评估3个问题(最小测试)
|
||||
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
|
||||
#
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试
|
||||
# python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
|
||||
|
||||
|
||||
# 执行longmemeval测试
|
||||
# python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 10 --max-contexts-per-item 3 --reset-group-before-ingest
|
||||
|
||||
# 执行memsciqa测试
|
||||
# python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 1
|
||||
13
api/app/core/memory/evaluation/.gitignore
vendored
Normal file
13
api/app/core/memory/evaluation/.gitignore
vendored
Normal file
@@ -0,0 +1,13 @@
|
||||
# 忽略实际的评估配置文件(包含敏感信息)
|
||||
.env.evaluation
|
||||
|
||||
# 保留示例文件
|
||||
!.env.evaluation.example
|
||||
|
||||
# 忽略测试结果文件
|
||||
*/results/*.json
|
||||
*/results/*.log
|
||||
|
||||
# 忽略数据集文件(文件过大,不应提交到 Git)
|
||||
dataset/*.json
|
||||
dataset/*.jsonl
|
||||
@@ -1,30 +1,748 @@
|
||||
⏬数据集下载地址:
|
||||
Locomo10.json:https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json:https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl:https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
上方数据集下载好后全部放入app/core/memory/data文件夹中
|
||||
# 1.数据集下载地址
|
||||
Locomo10.json : https://github.com/snap-research/locomo/tree/main/data
|
||||
LongMemEval_oracle.json : https://huggingface.co/datasets/xiaowu0162/longmemeval-cleaned
|
||||
msc_self_instruct.jsonl : https://huggingface.co/datasets/MemGPT/MSC-Self-Instruct
|
||||
|
||||
全流程基准测试运行:
|
||||
locomo:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset locomo --sample-size 1 --reset-group --group-id yyw1 --search-type hybrid --search-limit 8 --context-char-budget 12000 --llm-max-tokens 32
|
||||
LongMemEval:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset longmemeval --sample-size 10 --start-index 0 --group-id longmemeval_zh_bak_2 --search-limit 8 --context-char-budget 4000 --search-type hybrid --max-contexts-per-item 2 --reset-group
|
||||
memsciqa:
|
||||
python -m app.core.memory.evaluation.run_eval --dataset memsciqa --sample-size 10 --reset-group --group-id group_memsci
|
||||
数据集下载之后保存至api\app\core\memory\evaluation\dataset目录下
|
||||
# 2.配置说明
|
||||
文件api\app\core\memory\evaluation\.env.evaluation.example对三个基准测试所需配置有着详细的说明
|
||||
**实际配置文件**:api\app\core\memory\evaluation\.env.evaluation
|
||||
```python
|
||||
# 当使用不带配置参数的命令行执行基准测试,基准测试所需的配置参数根据.env.evaluation中的参数执行
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark
|
||||
```
|
||||
**检查 Neo4j 中指定的 group_id 是否已摄入数据**
|
||||
```python
|
||||
# 1. 进入交互模式
|
||||
python -m app.core.memory.evaluation.check_enduser_data
|
||||
|
||||
单独检索评估运行命令:
|
||||
python -m app.core.memory.evaluation.locomo.locomo_test
|
||||
python -m app.core.memory.evaluation.longmemeval.test_eval
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa-test
|
||||
需要先在项目中修改需要检测评估的group_id。
|
||||
# 2. 选择 "1" 检查指定 group
|
||||
# 3. 输入 group_id,例如: locomo_benchmark
|
||||
# 4. 选择是否显示详细统计 (y/n)
|
||||
```
|
||||
# 3.locomo
|
||||
|
||||
参数及解释:
|
||||
● --dataset longmemeval - 指定数据集
|
||||
● --sample-size 10 - 评估10个样本
|
||||
● --start-index 0 - 从第0个样本开始
|
||||
● --group-id longmemeval_zh_bak_2 - 使用指定的组ID
|
||||
● --search-limit 8 - 检索限制8条
|
||||
● --context-char-budget 4000 - 上下文字符预算4000
|
||||
● --search-type hybrid - 使用混合检索
|
||||
● --max-contexts-per-item 2 - 每个样本最多摄入2个上下文
|
||||
● --reset-group - 运行前清空组数据
|
||||
### (1)locomo执行命令
|
||||
```python
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 只摄入前5条消息,评估3个问题(最小测试)
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 3 --max_ingest_messages 5
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.locomo.locomo_benchmark --sample_size 5 --skip_ingest
|
||||
```
|
||||
### (2)locomo结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "locomo",
|
||||
"sample_size": 0,
|
||||
"timestamp": "2026-01-26T11:24:28.239156",
|
||||
"params": {
|
||||
"group_id": "locomo_benchmark",
|
||||
"search_type": "hybrid",
|
||||
"search_limit": 12,
|
||||
"context_char_budget": 8000,
|
||||
"llm_id": "2c9b0782-7a85-4740-ba84-4baf77f256c4",
|
||||
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2"
|
||||
},
|
||||
"overall_metrics": {
|
||||
"f1": 0.0,
|
||||
"bleu1": 0.0,
|
||||
"jaccard": 0.0,
|
||||
"locomo_f1": 0.0
|
||||
},
|
||||
"by_category": {},
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"context_stats": {
|
||||
"avg_retrieved_docs": 0.0,
|
||||
"avg_context_chars": 0.0,
|
||||
"avg_context_tokens": 0.0
|
||||
},
|
||||
"samples": []
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标 (overall_metrics)
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`f1`** (F1 Score): 精确率和召回率的调和平均值
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量检索和生成答案的准确性
|
||||
- 这是最重要的综合性能指标
|
||||
- 优秀标准:> 0.85
|
||||
|
||||
- **`bleu1`** (BLEU-1): 单词级别的匹配度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量生成答案与标准答案的词汇重叠度
|
||||
- 关注词汇层面的准确性
|
||||
|
||||
- **`jaccard`** (Jaccard 相似度): 集合相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量答案集合的相似性
|
||||
- 计算公式:交集大小 / 并集大小
|
||||
|
||||
- **`locomo_f1`**: Locomo 特定的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,针对 Locomo 数据集优化的评估指标
|
||||
- 考虑了长对话记忆的特殊性
|
||||
|
||||
##### 2. 性能指标 (latency)
|
||||
|
||||
**⚡ 关键效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均延迟
|
||||
- `p50`: 中位数延迟(50%的请求在此时间内完成)
|
||||
- `p95`: 95分位数延迟(95%的请求在此时间内完成)
|
||||
- `iqr`: 四分位距(Q3-Q1,衡量稳定性)
|
||||
- **越低越好**,衡量记忆检索速度
|
||||
- 优秀标准:p95 < 2000ms
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距(越小越稳定)
|
||||
- **越低越好**,衡量答案生成速度
|
||||
- 优秀标准:p95 < 3000ms
|
||||
|
||||
##### 3. 上下文统计 (context_stats)
|
||||
|
||||
**📊 资源效率指标:**
|
||||
|
||||
- **`avg_retrieved_docs`**: 平均检索文档数
|
||||
- 反映检索策略的广度
|
||||
- 需要平衡:太少可能信息不足,太多增加噪音和延迟
|
||||
- 建议范围:8-15 个文档
|
||||
|
||||
- **`avg_context_chars`**: 平均上下文字符数
|
||||
- 反映检索内容的总量
|
||||
- 应在满足准确性前提下尽量精简
|
||||
- 受 `context_char_budget` 参数限制
|
||||
|
||||
- **`avg_context_tokens`**: 平均上下文 token 数
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响 API 调用成本和推理速度
|
||||
- 成本效益比 = f1 / avg_context_tokens
|
||||
|
||||
##### 4. 分类统计 (by_category)
|
||||
|
||||
- 按问题类型分类的性能指标
|
||||
- 帮助识别系统在不同场景下的强弱项
|
||||
- 可针对性优化特定类型的问题
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `f1` 和 `locomo_f1` 提升 → 核心能力提升
|
||||
- 目标:f1 > 0.85
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency.p95` 降低 → 用户体验提升
|
||||
- 目标:search.p95 < 2000ms, llm.p95 < 3000ms
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens` 降低(在保持 f1 前提下)→ 成本优化
|
||||
- `iqr` 降低 → 性能稳定性提升
|
||||
# 4.longmemeval
|
||||
支持时间推理问题的增强检索
|
||||
### (1)执行命令
|
||||
```python
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 不带参数运行 - 使用环境变量
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark
|
||||
|
||||
# 命令行参数覆盖环境变量
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --sample-size 2
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.longmemeval.longmemeval_benchmark --skip_ingest
|
||||
```
|
||||
### (2)结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "longmemeval",
|
||||
"items": 1,
|
||||
"accuracy_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"f1_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"jaccard_by_type": {
|
||||
"single-session-user": 1.0
|
||||
},
|
||||
"samples": [
|
||||
{
|
||||
"question": "What degree did I graduate with?",
|
||||
"prediction": "Business Administration",
|
||||
"answer": "Business Administration",
|
||||
"question_type": "single-session-user",
|
||||
"is_temporal": false,
|
||||
"question_id": "e47becba",
|
||||
"options": [],
|
||||
"context_count": 13,
|
||||
"context_chars": 1268,
|
||||
"retrieved_dialogue_count": 0,
|
||||
"retrieved_statement_count": 12,
|
||||
"metrics": {
|
||||
"exact_match": true,
|
||||
"f1": 1.0,
|
||||
"jaccard": 1.0
|
||||
},
|
||||
"timing": {
|
||||
"search_ms": 1483.100175857544,
|
||||
"llm_ms": 995.8682060241699
|
||||
}
|
||||
}
|
||||
],
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 1483.100175857544,
|
||||
"p50": 1483.100175857544,
|
||||
"p95": 1483.100175857544,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 995.8682060241699,
|
||||
"p50": 995.8682060241699,
|
||||
"p95": 995.8682060241699,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"context": {
|
||||
"avg_tokens": 204.0,
|
||||
"avg_chars": 1268,
|
||||
"count_avg": 13
|
||||
},
|
||||
"params": {
|
||||
"group_id": "longmemeval_zh_bak_3",
|
||||
"search_limit": 8,
|
||||
"context_char_budget": 4000,
|
||||
"search_type": "hybrid",
|
||||
"llm_id": "6dc52e1b-9cec-4194-af66-a74c6307fc3f",
|
||||
"embedding_id": "e2a6392d-ca63-4d59-a523-647420b59cb2",
|
||||
"sample_size": 1,
|
||||
"start_index": 0
|
||||
},
|
||||
"timestamp": "2026-01-24T21:36:10.818308",
|
||||
"metric_summary": {
|
||||
"score_accuracy": 100.0,
|
||||
"latency_median_s": 2.478968381881714,
|
||||
"latency_iqr_s": 0.0,
|
||||
"avg_context_tokens_k": 0.204
|
||||
},
|
||||
"diagnostics": {
|
||||
"duplicate_previews_top": [],
|
||||
"unique_preview_count": 1
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`accuracy_by_type`**: 按问题类型分类的准确率
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,1.0 表示 100% 准确
|
||||
- 问题类型包括:
|
||||
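- `single-session-user`: 单会话中由用户提及的信息
|
||||
- `single-session-assistant`: 单会话中由助手提及的信息
|
||||
- `single-session-preference`: 单会话中的用户偏好信息
|
||||
- `multi-session`: 需要综合多个会话的信息
|
||||
- `temporal-reasoning`: 时间推理问题
|
||||
- `knowledge-update`: 知识更新问题(信息随时间发生变化)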
|
||||
- 可以识别系统在不同场景下的强弱项
|
||||
|
||||
- **`f1_by_type`**: 按问题类型的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,综合评估精确率和召回率
|
||||
- 比单纯的准确率更全面
|
||||
|
||||
- **`jaccard_by_type`**: 按问题类型的 Jaccard 相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量答案集合匹配度
|
||||
- 对于集合类答案特别有用
|
||||
|
||||
##### 2. 样本级指标 (samples)
|
||||
|
||||
**详细诊断指标:**
|
||||
|
||||
- **`metrics.exact_match`**: 精确匹配(布尔值)
|
||||
- **true 越多越好**,最严格的评估标准
|
||||
- 要求预测答案与标准答案完全一致
|
||||
|
||||
- **`metrics.f1`**: 单个样本的 F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量单个问题的回答质量
|
||||
|
||||
- **`is_temporal`**: 是否为时间推理问题
|
||||
- 布尔值,标识问题是否涉及时间推理
|
||||
- 时间推理问题通常更具挑战性
|
||||
|
||||
- **`context_count`**: 检索到的上下文数量
|
||||
- 反映检索策略的有效性
|
||||
- 建议范围:8-15 个上下文片段
|
||||
|
||||
- **`retrieved_dialogue_count`**: 检索到的对话数
|
||||
- **`retrieved_statement_count`**: 检索到的陈述数
|
||||
- 这两个指标帮助理解检索的内容类型分布
|
||||
- 可用于优化检索策略
|
||||
|
||||
- **`timing.search_ms`**: 单个问题的检索延迟(毫秒)
|
||||
- **`timing.llm_ms`**: 单个问题的 LLM 推理延迟(毫秒)
|
||||
- **越低越好**,反映单次查询的响应速度
|
||||
|
||||
##### 3. 汇总指标 (metric_summary)
|
||||
|
||||
**📊 关键 KPI:**
|
||||
|
||||
- **`score_accuracy`**: 总体准确率百分比
|
||||
- 范围:0.0 - 100.0
|
||||
- **越高越好**,最直观的性能指标
|
||||
- 优秀标准:> 90.0
|
||||
|
||||
- **`latency_median_s`**: 中位延迟(秒)
|
||||
- **越低越好**,反映真实响应速度
|
||||
- 优秀标准:< 3.0 秒
|
||||
|
||||
- **`latency_iqr_s`**: 延迟四分位距(秒)
|
||||
- **越低越好**,反映性能稳定性
|
||||
- 越小说明响应时间越稳定
|
||||
|
||||
- **`avg_context_tokens_k`**: 平均上下文 token 数(千)
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响 API 调用成本
|
||||
- 成本效益比 = score_accuracy / (avg_context_tokens_k * 1000)
|
||||
|
||||
##### 4. 上下文统计 (context)
|
||||
|
||||
- **`avg_tokens`**: 平均 token 数
|
||||
- **`avg_chars`**: 平均字符数
|
||||
- **`count_avg`**: 平均上下文片段数
|
||||
- 这些指标反映检索内容的规模
|
||||
- 需要在准确性和效率之间平衡
|
||||
|
||||
##### 5. 性能指标 (latency)
|
||||
|
||||
**⚡ 效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均延迟
|
||||
- `p50`: 中位数延迟
|
||||
- `p95`: 95分位数延迟
|
||||
- `iqr`: 四分位距
|
||||
- **越低越好**,衡量记忆检索速度
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距
|
||||
- **越低越好**,衡量答案生成速度
|
||||
|
||||
##### 6. 诊断信息 (diagnostics)
|
||||
|
||||
- **`duplicate_previews_top`**: 重复预览统计
|
||||
- 列出出现频率最高的重复内容
|
||||
- 帮助发现检索冗余问题
|
||||
- 应该尽量减少重复
|
||||
|
||||
- **`unique_preview_count`**: 唯一预览数量
|
||||
- 反映检索多样性
|
||||
- **越高越好**,说明检索到的内容更丰富
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `score_accuracy` 提升 → 核心能力提升
|
||||
- 目标:> 90.0%
|
||||
- 各类型的 `accuracy_by_type` 均衡提升 → 全面能力提升
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency_median_s` 降低 → 用户体验提升
|
||||
- 目标:< 3.0 秒
|
||||
- `exact_match` 比例提升 → 精确度提升
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens_k` 降低(在保持准确性前提下)→ 成本优化
|
||||
- `unique_preview_count` 提升 → 检索多样性提升
|
||||
- `latency_iqr_s` 降低 → 性能稳定性提升
|
||||
|
||||
**特殊关注:**
|
||||
- 时间推理问题(`is_temporal: true`)的准确率
|
||||
- 多会话问题的准确率(通常更具挑战性)
|
||||
# 5.memsciqa
|
||||
对话记忆检索评估
|
||||
### (1)执行命令
|
||||
```python
|
||||
# 首先进入api目录
|
||||
cd api
|
||||
|
||||
# 不带参数运行 - 使用环境变量
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark
|
||||
|
||||
# 命令行参数覆盖环境变量
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --sample-size 100
|
||||
|
||||
# 如果数据已经摄入,跳过摄入阶段直接测试(使用skip_ingest参数)
|
||||
python -m app.core.memory.evaluation.memsciqa.memsciqa_benchmark --skip_ingest
|
||||
```
|
||||
### (2)结果说明
|
||||
|
||||
#### 结果示例
|
||||
```json
|
||||
{
|
||||
"dataset": "memsciqa",
|
||||
"items": 1,
|
||||
"metrics": {
|
||||
"accuracy": 0.0,
|
||||
"f1": 0.0,
|
||||
"bleu1": 0.0,
|
||||
"jaccard": 0.0
|
||||
},
|
||||
"latency": {
|
||||
"search": {
|
||||
"mean": 0.0,
|
||||
"p50": 0.0,
|
||||
"p95": 0.0,
|
||||
"iqr": 0.0
|
||||
},
|
||||
"llm": {
|
||||
"mean": 3067.7285194396973,
|
||||
"p50": 3067.7285194396973,
|
||||
"p95": 3067.7285194396973,
|
||||
"iqr": 0.0
|
||||
}
|
||||
},
|
||||
"avg_context_tokens": 4.0
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数详解
|
||||
|
||||
##### 1. 核心评估指标 (metrics)
|
||||
|
||||
**🎯 关键进步指标:**
|
||||
|
||||
- **`accuracy`**: 准确率
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,最直接的性能指标
|
||||
- 衡量系统回答正确的问题比例
|
||||
- 优秀标准:> 0.85
|
||||
|
||||
- **`f1`**: F1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,平衡精确率和召回率
|
||||
- 计算公式:2 * (precision * recall) / (precision + recall)
|
||||
- 比单纯的准确率更全面,特别适合不平衡数据集
|
||||
|
||||
- **`bleu1`**: BLEU-1 分数
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量词汇级别的匹配度
|
||||
- 关注生成答案与标准答案的单词重叠
|
||||
- 源自机器翻译评估,适用于自然语言生成
|
||||
|
||||
- **`jaccard`**: Jaccard 相似度
|
||||
- 范围:0.0 - 1.0
|
||||
- **越高越好**,衡量集合相似性
|
||||
- 计算公式:|A ∩ B| / |A ∪ B|
|
||||
- 对于多答案或集合类问题特别有用
|
||||
|
||||
##### 2. 性能指标 (latency)
|
||||
|
||||
**⚡ 效率指标:**
|
||||
|
||||
- **`search`**: 检索延迟统计(单位:毫秒)
|
||||
- `mean`: 平均检索延迟
|
||||
- `p50`: 中位数延迟(50%的请求在此时间内完成)
|
||||
- `p95`: 95分位数延迟(95%的请求在此时间内完成)
|
||||
- `iqr`: 四分位距(Q3-Q1,衡量稳定性)
|
||||
- **越低越好**,衡量记忆检索效率
|
||||
- 优秀标准:p95 < 2000ms
|
||||
|
||||
- **`llm`**: LLM 推理延迟统计(单位:毫秒)
|
||||
- `mean`: 平均推理时间
|
||||
- `p50`: 中位数推理时间
|
||||
- `p95`: 95分位数推理时间
|
||||
- `iqr`: 四分位距(越小越稳定)
|
||||
- **越低越好**,衡量答案生成速度
|
||||
- 优秀标准:p95 < 3000ms
|
||||
- 注意:LLM 延迟通常占总延迟的大部分
|
||||
|
||||
##### 3. 资源指标
|
||||
|
||||
- **`avg_context_tokens`**: 平均上下文 token 数
|
||||
- **越低越好**(在保持准确性前提下)
|
||||
- 直接影响:
|
||||
- API 调用成本(按 token 计费)
|
||||
- 推理速度(token 越多越慢)
|
||||
- 上下文窗口占用
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- 建议范围:根据模型上下文窗口和成本预算调整
|
||||
|
||||
##### 4. 数据集特点
|
||||
|
||||
- **`items`**: 评估的问题数量
|
||||
- 样本量越大,评估结果越可靠
|
||||
- 建议至少 100 个样本以获得稳定的评估结果
|
||||
|
||||
- **对话记忆特性**:
|
||||
- MemSciQA 专注于对话历史中的记忆检索
|
||||
- 评估系统从多轮对话中提取和回忆信息的能力
|
||||
- 模拟真实的对话场景
|
||||
|
||||
#### 系统进步衡量标准
|
||||
|
||||
**一级指标(最重要):**
|
||||
- `accuracy` 提升 → 核心能力提升
|
||||
- 目标:> 0.85
|
||||
- `f1` 提升 → 综合性能提升
|
||||
- 目标:> 0.80
|
||||
|
||||
**二级指标(重要):**
|
||||
- `latency.p95` 降低 → 用户体验提升
|
||||
- search.p95 目标:< 2000ms
|
||||
- llm.p95 目标:< 3000ms
|
||||
- `iqr` 降低 → 性能稳定性提升
|
||||
|
||||
**三级指标(辅助):**
|
||||
- `avg_context_tokens` 降低(在保持准确性前提下)→ 成本优化
|
||||
- `bleu1` 和 `jaccard` 提升 → 答案质量提升
|
||||
|
||||
**综合评估:**
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- 该比值越高,说明系统在相同成本下性能越好
|
||||
- 总延迟 = search.p95 + llm.p95
|
||||
- 应控制在 5 秒以内以保证良好的用户体验
|
||||
|
||||
#### 优化建议
|
||||
|
||||
**提升准确性:**
|
||||
- 优化检索算法(调整 hybrid search 参数)
|
||||
- 改进 embedding 模型质量
|
||||
- 增加检索上下文数量(`search_limit`)
|
||||
- 优化 prompt 工程
|
||||
|
||||
**提升效率:**
|
||||
- 减少不必要的检索文档
|
||||
- 使用更快的 LLM 模型或量化版本
|
||||
- 实施缓存策略(相似问题复用结果)
|
||||
- 优化数据库索引
|
||||
|
||||
**平衡性能:**
|
||||
- 监控 accuracy vs latency 的权衡
|
||||
- 监控 accuracy vs cost (tokens) 的权衡
|
||||
- 根据业务需求调整优先级
|
||||
|
||||
|
||||
---
|
||||
|
||||
# 6. 三个基准测试对比总结
|
||||
|
||||
## 6.1 测试特点对比
|
||||
|
||||
| 基准测试 | 主要评估目标 | 数据集特点 | 适用场景 |
|
||||
|---------|------------|-----------|---------|
|
||||
| **Locomo** | 长对话记忆检索 | 长对话历史,多轮交互 | 评估长期记忆保持和检索能力 |
|
||||
| **LongMemEval** | 时间推理和多会话记忆 | 支持时间推理,多会话场景 | 评估时间感知和跨会话记忆能力 |
|
||||
| **MemSciQA** | 对话记忆问答 | 对话历史问答 | 评估对话上下文理解和记忆提取 |
|
||||
|
||||
## 6.2 核心指标对比
|
||||
|
||||
### 准确性指标
|
||||
|
||||
| 指标 | Locomo | LongMemEval | MemSciQA | 说明 |
|
||||
|-----|--------|-------------|----------|------|
|
||||
| **F1 Score** | ✅ | ✅ | ✅ | 所有测试都使用,最重要的综合指标 |
|
||||
| **Accuracy** | ❌ | ✅ | ✅ | 直观的准确率指标 |
|
||||
| **BLEU-1** | ✅ | ❌ | ✅ | 词汇级别匹配度 |
|
||||
| **Jaccard** | ✅ | ✅ | ✅ | 集合相似度 |
|
||||
| **Exact Match** | ❌ | ✅ | ❌ | 最严格的评估标准 |
|
||||
|
||||
### 性能指标
|
||||
|
||||
所有三个测试都包含:
|
||||
- **检索延迟** (search latency): mean, p50, p95, iqr
|
||||
- **LLM 延迟** (llm latency): mean, p50, p95, iqr
|
||||
- **上下文统计**: token 数、字符数、文档数
|
||||
|
||||
## 6.3 关键进步指标优先级
|
||||
|
||||
### 🥇 一级指标(必须关注)
|
||||
|
||||
1. **准确性指标**
|
||||
- Locomo: `f1`, `locomo_f1`
|
||||
- LongMemEval: `score_accuracy`, `accuracy_by_type`
|
||||
- MemSciQA: `accuracy`, `f1`
|
||||
- **目标**: > 85% 或 > 0.85
|
||||
|
||||
2. **综合性能**
|
||||
- 所有测试的 F1 分数应保持一致性
|
||||
- 不同类型问题的准确率应均衡
|
||||
|
||||
### 🥈 二级指标(重要)
|
||||
|
||||
3. **响应延迟**
|
||||
- `latency.p95` (95分位数延迟)
|
||||
- **目标**:
|
||||
- search.p95 < 2000ms
|
||||
- llm.p95 < 3000ms
|
||||
- 总延迟 < 5000ms
|
||||
|
||||
4. **性能稳定性**
|
||||
- `iqr` (四分位距)
|
||||
- **目标**: 越小越好,说明性能稳定
|
||||
|
||||
### 🥉 三级指标(优化)
|
||||
|
||||
5. **成本效率**
|
||||
- `avg_context_tokens`
|
||||
- **目标**: 在保持准确性前提下最小化
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
|
||||
6. **检索质量**
|
||||
- `avg_retrieved_docs` 的合理性
|
||||
- `unique_preview_count` (LongMemEval)
|
||||
- 检索内容的多样性和相关性
|
||||
|
||||
## 6.4 系统优化路径
|
||||
|
||||
### 阶段一:提升准确性(优先级最高)
|
||||
|
||||
**目标**: 所有测试的准确率 > 85%
|
||||
|
||||
**优化方向**:
|
||||
1. 改进 embedding 模型质量
|
||||
2. 优化检索算法(hybrid search 参数)
|
||||
3. 增加检索上下文数量(`search_limit`)
|
||||
4. 优化 prompt 工程
|
||||
5. 改进记忆存储结构
|
||||
|
||||
**监控指标**:
|
||||
- Locomo: `f1`, `locomo_f1`
|
||||
- LongMemEval: `score_accuracy`, `exact_match` 比例
|
||||
- MemSciQA: `accuracy`, `f1`
|
||||
|
||||
### 阶段二:优化性能(准确性达标后)
|
||||
|
||||
**目标**: p95 延迟 < 5 秒,性能稳定
|
||||
|
||||
**优化方向**:
|
||||
1. 优化数据库索引和查询
|
||||
2. 实施缓存策略
|
||||
3. 使用更快的 LLM 模型
|
||||
4. 并行化检索和推理
|
||||
5. 减少不必要的检索
|
||||
|
||||
**监控指标**:
|
||||
- `latency.p50`, `latency.p95`
|
||||
- `iqr` (稳定性)
|
||||
- 各阶段耗时分布
|
||||
|
||||
### 阶段三:降低成本(性能达标后)
|
||||
|
||||
**目标**: 在保持准确性和性能前提下,最小化成本
|
||||
|
||||
**优化方向**:
|
||||
1. 精简检索上下文
|
||||
2. 优化 context 选择策略
|
||||
3. 使用更小的 LLM 模型
|
||||
4. 实施智能缓存
|
||||
5. 批处理优化
|
||||
|
||||
**监控指标**:
|
||||
- `avg_context_tokens`
|
||||
- 成本效益比 = accuracy / avg_context_tokens
|
||||
- API 调用成本
|
||||
|
||||
## 6.5 评估最佳实践
|
||||
|
||||
### 测试执行建议
|
||||
|
||||
1. **初始测试**: 使用小样本快速验证
|
||||
```bash
|
||||
--sample_size 10
|
||||
```
|
||||
|
||||
2. **完整评估**: 使用足够大的样本量
|
||||
```bash
|
||||
--sample_size 100 # 或更多
|
||||
```
|
||||
|
||||
3. **增量测试**: 数据已摄入时跳过摄入阶段
|
||||
```bash
|
||||
--skip_ingest
|
||||
```
|
||||
|
||||
4. **参数调优**: 系统性地调整参数并记录结果
|
||||
- 调整 `search_limit`: 4, 8, 12, 16
|
||||
- 调整 `context_char_budget`: 2000, 4000, 8000
|
||||
- 尝试不同的 `search_type`: keyword, embedding, hybrid
|
||||
|
||||
### 结果分析建议
|
||||
|
||||
1. **横向对比**: 比较三个测试的结果,识别系统的强弱项
|
||||
2. **纵向对比**: 跟踪同一测试在不同版本的表现
|
||||
3. **分类分析**: 关注不同问题类型的性能差异
|
||||
4. **异常诊断**: 分析失败案例,找出根本原因
|
||||
|
||||
### 持续监控
|
||||
|
||||
建议建立监控仪表板,跟踪:
|
||||
- 核心指标趋势(准确率、延迟)
|
||||
- 成本效益比趋势
|
||||
- 不同问题类型的性能分布
|
||||
- 异常样本和失败模式
|
||||
|
||||
## 6.6 性能基准参考
|
||||
|
||||
### 优秀水平(Production Ready)
|
||||
|
||||
- **准确性**: accuracy/f1 > 0.90
|
||||
- **延迟**: p95 < 3 秒
|
||||
- **稳定性**: iqr < 500ms
|
||||
- **成本效益**: accuracy/tokens > 0.0001
|
||||
|
||||
### 良好水平(Acceptable)
|
||||
|
||||
- **准确性**: accuracy/f1 > 0.85
|
||||
- **延迟**: p95 < 5 秒
|
||||
- **稳定性**: iqr < 1000ms
|
||||
- **成本效益**: accuracy/tokens > 0.00005
|
||||
|
||||
### 需要改进(Below Target)
|
||||
|
||||
- **准确性**: accuracy/f1 < 0.85
|
||||
- **延迟**: p95 > 5 秒
|
||||
- **稳定性**: iqr > 1000ms
|
||||
- **成本效益**: accuracy/tokens < 0.00005
|
||||
|
||||
---
|
||||
|
||||
**注**: 以上标准仅供参考,实际目标应根据具体业务需求和资源约束调整。
|
||||
|
||||
371
api/app/core/memory/evaluation/check_enduser_data.py
Normal file
371
api/app/core/memory/evaluation/check_enduser_data.py
Normal file
@@ -0,0 +1,371 @@
|
||||
"""
|
||||
交互式 Neo4j End User 数据检查工具
|
||||
|
||||
用于查询指定 end_user_id 在 Neo4j 中是否存在数据,以及数据的详细统计信息。
|
||||
|
||||
使用方法:
|
||||
python check_enduser_data.py
|
||||
python check_enduser_data.py --group-id locomo_benchmark
|
||||
python check_enduser_data.py --group-id memsciqa_benchmark --detailed
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# The evaluation settings file sits next to this script. When it is present,
# its values are loaded with override=True, i.e. they replace variables that
# are already set in the process environment.
eval_config_path = Path(__file__).resolve().parent.joinpath(".env.evaluation")
if eval_config_path.exists():
    load_dotenv(eval_config_path, override=True)
    print(f"✅ 加载评估配置: {eval_config_path}\n")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def check_group_exists(end_user_id: str) -> Dict[str, Any]:
    """
    Summarise the Neo4j data stored under a given end_user_id.

    Three Cypher queries are issued in sequence: the total number of nodes
    whose ``end_user_id`` property matches, the node count grouped by label
    set, and the number of distinct relationships attached to those nodes.

    Args:
        end_user_id: Value matched against the ``end_user_id`` node property.

    Returns:
        A dict with the keys ``exists`` (True when at least one node matched),
        ``total_nodes``, ``total_relationships`` and ``nodes_by_type`` (the
        rows returned by the per-label query, passed through unchanged).
    """
    neo4j = Neo4jConnector()

    def _first_or_zero(rows, column):
        # An empty result set is reported as a count of zero.
        return rows[0][column] if rows else 0

    try:
        # Total number of nodes belonging to this end user.
        cypher_node_total = """
        MATCH (n {end_user_id: $end_user_id})
        RETURN count(n) as total_nodes
        """
        node_total_rows = await neo4j.execute_query(cypher_node_total, end_user_id=end_user_id)
        node_total = _first_or_zero(node_total_rows, "total_nodes")

        # Node counts broken down by label set, largest group first.
        cypher_per_label = """
        MATCH (n {end_user_id: $end_user_id})
        RETURN labels(n) as labels, count(n) as count
        ORDER BY count DESC
        """
        per_label_rows = await neo4j.execute_query(cypher_per_label, end_user_id=end_user_id)

        # Distinct relationships touching any of this end user's nodes.
        cypher_rel_total = """
        MATCH (n {end_user_id: $end_user_id})-[r]-()
        RETURN count(DISTINCT r) as total_relationships
        """
        rel_total_rows = await neo4j.execute_query(cypher_rel_total, end_user_id=end_user_id)
        rel_total = _first_or_zero(rel_total_rows, "total_relationships")

        summary = {
            "exists": node_total > 0,
            "total_nodes": node_total,
            "total_relationships": rel_total,
            "nodes_by_type": per_label_rows,
        }
        return summary
    finally:
        # Always release the connector, whether the queries succeeded or raised.
        await neo4j.close()
|
||||
|
||||
|
||||
async def get_detailed_stats(end_user_id: str) -> Dict[str, Any]:
    """
    Collect per-label statistics for the nodes of one end_user_id.

    One query is issued for each of the Chunk, Statement, Entity, Dialogue and
    Summary labels. A label only appears in the result when its count is
    greater than zero.

    Args:
        end_user_id: The end_user ID to inspect.

    Returns:
        A dictionary that may contain the keys "chunks", "statements",
        "entities", "dialogues" and "summaries". Every entry has a "count";
        "chunks" additionally has "avg_content_length" (an int) and
        "entities" additionally has "unique_types".
    """
    # (result key, Cypher query, columns copied from the single result row)
    stat_queries = (
        (
            "chunks",
            """
            MATCH (c:Chunk {end_user_id: $end_user_id})
            RETURN count(c) as count,
                   avg(size(c.content)) as avg_content_length
            """,
            ("count", "avg_content_length"),
        ),
        (
            "statements",
            """
            MATCH (s:Statement {end_user_id: $end_user_id})
            RETURN count(s) as count
            """,
            ("count",),
        ),
        (
            "entities",
            """
            MATCH (e:Entity {end_user_id: $end_user_id})
            RETURN count(e) as count,
                   count(DISTINCT e.entity_type) as unique_types
            """,
            ("count", "unique_types"),
        ),
        (
            "dialogues",
            """
            MATCH (d:Dialogue {end_user_id: $end_user_id})
            RETURN count(d) as count
            """,
            ("count",),
        ),
        (
            "summaries",
            """
            MATCH (s:Summary {end_user_id: $end_user_id})
            RETURN count(s) as count
            """,
            ("count",),
        ),
    )

    connector = Neo4jConnector()

    try:
        stats = {}

        for key, query, columns in stat_queries:
            rows = await connector.execute_query(query, end_user_id=end_user_id)
            # Labels without any node for this end_user are left out entirely.
            if not rows or rows[0]["count"] <= 0:
                continue

            entry = {column: rows[0][column] for column in columns}
            if "avg_content_length" in entry:
                # A missing or zero average is reported as 0; any other value
                # is truncated to an int.
                average = entry["avg_content_length"]
                entry["avg_content_length"] = int(average) if average else 0
            stats[key] = entry

        return stats

    finally:
        await connector.close()
|
||||
|
||||
|
||||
async def list_all_end_users() -> list:
    """
    List every end_user_id present in the database.

    Returns:
        The query rows, each holding "end_user_id" and "node_count", ordered
        by node_count descending.
    """
    cypher = """
    MATCH (n)
    WHERE n.end_user_id IS NOT NULL
    RETURN DISTINCT n.end_user_id as end_user_id, count(n) as node_count
    ORDER BY node_count DESC
    """
    neo4j = Neo4jConnector()

    try:
        return await neo4j.execute_query(cypher)
    finally:
        # Runs after the query result is obtained, and also when it raises.
        await neo4j.close()
|
||||
|
||||
|
||||
def print_results(end_user_id: str, stats: Dict[str, Any], detailed_stats: Dict[str, Any] = None):
    """
    Print the lookup results for one end_user_id to stdout.

    Args:
        end_user_id: End User ID shown in the report header.
        stats: Basic statistics; must provide "exists", "total_nodes",
            "total_relationships" and "nodes_by_type".
        detailed_stats: Optional per-label statistics. Only the sections
            present in the dictionary ("chunks", "statements", "entities",
            "dialogues", "summaries") are printed.
    """
    rule = "=" * 60

    print(f"\n{rule}")
    print(f"📊 End User ID: {end_user_id}")
    print(f"{rule}\n")

    if not stats["exists"]:
        print("❌ 该 end_user_id 不存在数据")
        print("\n💡 提示: 请先运行基准测试以摄入数据")
        return

    print("✅ 该 end_user_id 存在数据\n")
    print("📈 基本统计:")
    print(f" 总节点数: {stats['total_nodes']}")
    print(f" 总关系数: {stats['total_relationships']}")

    if stats["nodes_by_type"]:
        print("\n📋 节点类型分布:")
        for item in stats["nodes_by_type"]:
            # A node without any label yields an empty list; show a
            # placeholder instead of a blank name in front of the colon.
            labels = ", ".join(item["labels"]) or "(无标签)"
            count = item["count"]
            print(f" {labels}: {count}")

    if detailed_stats:
        print("\n🔍 详细统计:")

        if "chunks" in detailed_stats:
            print(f" Chunks: {detailed_stats['chunks']['count']} 个")
            print(f" 平均内容长度: {detailed_stats['chunks']['avg_content_length']} 字符")

        if "statements" in detailed_stats:
            print(f" Statements: {detailed_stats['statements']['count']} 个")

        if "entities" in detailed_stats:
            print(f" Entities: {detailed_stats['entities']['count']} 个")
            print(f" 唯一类型数: {detailed_stats['entities']['unique_types']}")

        if "dialogues" in detailed_stats:
            print(f" Dialogues: {detailed_stats['dialogues']['count']} 个")

        if "summaries" in detailed_stats:
            print(f" Summaries: {detailed_stats['summaries']['count']} 个")

    print(f"\n{rule}\n")
|
||||
|
||||
|
||||
async def interactive_mode():
    """
    Run the menu-driven console session.

    The menu is shown repeatedly until option 3 is chosen. Option 1 asks for
    an end_user_id (and whether to include detailed statistics) and prints the
    report; option 2 lists every end_user_id with its node count.

    The session also ends cleanly when a prompt hits end of input (closed
    stdin, Ctrl-D) or an interrupt is raised while an iteration is running,
    instead of terminating with a traceback.
    """
    rule = "=" * 60

    print("\n" + rule)
    print("🔍 Neo4j End User 数据检查工具 - 交互模式")
    print(rule + "\n")

    while True:
        print("\n请选择操作:")
        print(" 1. 检查指定 end_user_id")
        print(" 2. 列出所有 end_user_id")
        print(" 3. 退出")

        try:
            choice = input("\n请输入选项 (1-3): ").strip()

            if choice == "1":
                end_user_id = input("\n请输入 end_user_id: ").strip()
                if not end_user_id:
                    print("❌ end_user_id 不能为空")
                    continue

                detailed = input("是否显示详细统计? (y/n, 默认 n): ").strip().lower() == 'y'

                print("\n🔄 正在查询...")
                stats = await check_group_exists(end_user_id)

                # Detailed statistics are only fetched when there is data.
                detailed_stats = None
                if detailed and stats["exists"]:
                    detailed_stats = await get_detailed_stats(end_user_id)

                print_results(end_user_id, stats, detailed_stats)

            elif choice == "2":
                print("\n🔄 正在查询所有 end_user_id...")
                end_users = await list_all_end_users()

                if not end_users:
                    print("\n❌ 数据库中没有任何 end_user 数据")
                else:
                    print(f"\n{rule}")
                    print("📋 数据库中的所有 End User ID")
                    print(f"{rule}\n")

                    for idx, end_user in enumerate(end_users, 1):
                        print(f" {idx}. {end_user['end_user_id']}")
                        print(f" 节点数: {end_user['node_count']}")

                    print(f"\n{rule}\n")

            elif choice == "3":
                print("\n👋 再见!")
                break

            else:
                print("\n❌ 无效的选项,请重新选择")

        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError when stdin is exhausted; treat it (and
            # an interrupt) like choosing "exit".
            print("\n👋 再见!")
            break
|
||||
|
||||
|
||||
async def main():
    """
    Command-line entry point.

    Without --end-user-id and --list-all the interactive menu is started.
    --list-all prints every end_user_id and returns, even when --end-user-id
    is also given. Otherwise the given end_user_id is checked, with the
    per-label statistics added when --detailed is set and data exists.
    """
    cli = argparse.ArgumentParser(
        description="检查 Neo4j 中指定 end_user_id 的数据情况",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
# 交互模式
python check_group_data.py

# 检查指定 end_user
python check_group_data.py --end-user-id locomo_benchmark

# 检查并显示详细统计
python check_group_data.py --end-user-id memsciqa_benchmark --detailed

# 列出所有 end_user
python check_group_data.py --list-all
""",
    )
    cli.add_argument("--end-user-id", type=str, help="要检查的 end_user ID")
    cli.add_argument("--detailed", action="store_true", help="显示详细统计信息")
    cli.add_argument("--list-all", action="store_true", help="列出所有 end_user_id")

    args = cli.parse_args()

    # Neither a target nor --list-all: fall back to the interactive menu.
    if not (args.end_user_id or args.list_all):
        await interactive_mode()
        return

    # Listing takes priority over checking a single end_user.
    if args.list_all:
        print("\n🔄 正在查询所有 end_user_id...")
        end_users = await list_all_end_users()

        if end_users:
            rule = "=" * 60
            print(f"\n{rule}")
            print("📋 数据库中的所有 End User ID")
            print(f"{rule}\n")

            for position, row in enumerate(end_users, start=1):
                print(f" {position}. {row['end_user_id']}")
                print(f" 节点数: {row['node_count']}")

            print(f"\n{rule}\n")
        else:
            print("\n❌ 数据库中没有任何 end_user 数据")
        return

    # Only reachable with an end_user_id, because of the first guard.
    target = args.end_user_id
    print(f"\n🔄 正在查询 end_user_id: {target}...")
    stats = await check_group_exists(target)

    detailed_stats = None
    if args.detailed and stats["exists"]:
        print("🔄 正在获取详细统计...")
        detailed_stats = await get_detailed_stats(target)

    print_results(target, stats, detailed_stats)
|
||||
|
||||
|
||||
# Script entry point: run the async CLI when the file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -2,7 +2,7 @@ import math
|
||||
import re
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
# 评估指标的实现
|
||||
def _normalize(text: str) -> List[str]:
|
||||
"""Lowercase, strip punctuation, and split into tokens."""
|
||||
text = text.lower().strip()
|
||||
|
||||
@@ -4,15 +4,17 @@ This file contains Cypher queries for searching dialogues, entities, and chunks.
|
||||
Placed in evaluation directory to avoid circular imports with src modules.
|
||||
"""
|
||||
|
||||
# 应该是neo4j browser的cypher语句,需要修改文件名
|
||||
|
||||
# Entity search queries
|
||||
SEARCH_ENTITIES_BY_NAME = """
|
||||
MATCH (e:Entity)
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.name = $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
SEARCH_ENTITIES_BY_NAME_FALLBACK = """
|
||||
MATCH (e:Entity)
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.name CONTAINS $name
|
||||
RETURN e
|
||||
"""
|
||||
|
||||
@@ -1,34 +1,33 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import LLMClient
|
||||
from app.core.memory.models.message_models import (
|
||||
ConversationContext,
|
||||
ConversationMessage,
|
||||
DialogData,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import DialogueChunker
|
||||
from app.core.memory.models.message_models import DialogData, ConversationContext, ConversationMessage
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent / "app" / "core" / "memory" / "evaluation" / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
|
||||
# 使用新的模块化架构
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.chunk_extraction import (
|
||||
DialogueChunker,
|
||||
)
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_CHUNKER_STRATEGY,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
|
||||
# Import from database module
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
# Cypher queries for evaluation
|
||||
# Note: Entity, chunk, and dialogue search queries have been moved to evaluation/dialogue_queries.py
|
||||
@@ -41,11 +40,14 @@ async def ingest_contexts_via_full_pipeline(
|
||||
embedding_name: str | None = None,
|
||||
save_chunk_output: bool = False,
|
||||
save_chunk_output_path: str | None = None,
|
||||
reset_group: bool = False,
|
||||
) -> bool:
|
||||
"""DEPRECATED: 此函数使用旧的流水线架构,建议使用新的 ExtractionOrchestrator
|
||||
"""
|
||||
使用新的 ExtractionOrchestrator 运行完整的提取流水线
|
||||
|
||||
Run the full extraction pipeline on provided dialogue contexts and save to Neo4j.
|
||||
This function mirrors the steps in main(), but starts from raw text contexts.
|
||||
This function uses the new ExtractionOrchestrator architecture for better maintainability.
|
||||
|
||||
Args:
|
||||
contexts: List of dialogue texts, each containing lines like "role: message".
|
||||
end_user_id: Group ID to assign to generated DialogData and graph nodes.
|
||||
@@ -53,25 +55,59 @@ async def ingest_contexts_via_full_pipeline(
|
||||
embedding_name: Optional embedding model ID; defaults to SELECTED_EMBEDDING_ID.
|
||||
save_chunk_output: If True, write chunked DialogData list to a JSON file for debugging.
|
||||
save_chunk_output_path: Optional output path; defaults to src/chunker_test_output.txt.
|
||||
reset_group: If True, clear existing data for this group before ingestion.
|
||||
Returns:
|
||||
True if data saved successfully, False otherwise.
|
||||
"""
|
||||
chunker_strategy = chunker_strategy or SELECTED_CHUNKER_STRATEGY
|
||||
embedding_name = embedding_name or SELECTED_EMBEDDING_ID
|
||||
chunker_strategy = chunker_strategy or os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
|
||||
embedding_name = embedding_name or os.getenv("EVAL_EMBEDDING_ID")
|
||||
|
||||
# Check if we should reset from environment variable if not explicitly set
|
||||
if not reset_group:
|
||||
reset_group = os.getenv("EVAL_RESET_ON_INGEST", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
# Step 0: Reset group if requested
|
||||
if reset_group:
|
||||
print(f"[Ingestion] 🗑️ 清空 end_user '{end_user_id}' 的现有数据...")
|
||||
try:
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 删除该 end_user 的所有节点和关系
|
||||
query = """
|
||||
MATCH (n {end_user_id: $end_user_id})
|
||||
DETACH DELETE n
|
||||
"""
|
||||
await connector.execute_query(query, end_user_id=end_user_id)
|
||||
print(f"[Ingestion] ✅ End User '{end_user_id}' 已清空")
|
||||
finally:
|
||||
await connector.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] ⚠️ 清空 end_user 失败: {e}")
|
||||
# 继续执行,不中断摄入流程
|
||||
|
||||
# Initialize llm client with graceful fallback
|
||||
# Step 1: Initialize LLM client
|
||||
llm_client = None
|
||||
llm_available = True
|
||||
try:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
# 使用评估配置中的 LLM ID
|
||||
llm_id = os.getenv("EVAL_LLM_ID")
|
||||
if not llm_id:
|
||||
print("[Ingestion] ❌ EVAL_LLM_ID not set in .env.evaluation")
|
||||
return False
|
||||
|
||||
from app.db import get_db
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
llm_client = get_llm_client(llm_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] LLM client unavailable, will skip LLM-dependent steps: {e}")
|
||||
llm_available = False
|
||||
print(f"[Ingestion] LLM client unavailable: {e}")
|
||||
return False
|
||||
|
||||
# Step A: Build DialogData list from contexts with robust parsing
|
||||
# Step 2: Parse contexts and create DialogData with chunks
|
||||
print(f"[Ingestion] Parsing {len(contexts)} contexts...")
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
dialog_data_list: List[DialogData] = []
|
||||
|
||||
@@ -94,7 +130,7 @@ async def ingest_contexts_via_full_pipeline(
|
||||
line = raw.strip()
|
||||
if not line:
|
||||
continue
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)$', line)
|
||||
m = re.match(r'^\s*([^::]+)\s*[::]\s*(.+)', line)
|
||||
if m:
|
||||
role = m.group(1).strip()
|
||||
msg = m.group(2).strip()
|
||||
@@ -118,10 +154,12 @@ async def ingest_contexts_via_full_pipeline(
|
||||
dialog_data_list.append(dialog)
|
||||
|
||||
if not dialog_data_list:
|
||||
print("No dialogs to process for ingestion.")
|
||||
print("[Ingestion] No dialogs to process.")
|
||||
return False
|
||||
|
||||
# Optionally save chunking outputs for debugging
|
||||
print(f"[Ingestion] Parsed {len(dialog_data_list)} dialogs with chunks")
|
||||
|
||||
# Step 3: Optionally save chunking outputs for debugging
|
||||
if save_chunk_output:
|
||||
try:
|
||||
def _serialize_datetime(obj):
|
||||
@@ -137,124 +175,185 @@ async def ingest_contexts_via_full_pipeline(
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
with open(out_path, "w", encoding="utf-8") as f:
|
||||
json.dump(combined_output, f, ensure_ascii=False, indent=4, default=_serialize_datetime)
|
||||
print(f"Saved chunking results to: {out_path}")
|
||||
print(f"[Ingestion] Saved chunking results to: {out_path}")
|
||||
except Exception as e:
|
||||
print(f"Failed to save chunking results: {e}")
|
||||
print(f"[Ingestion] Failed to save chunking results: {e}")
|
||||
|
||||
# Step B-G: 使用新的 ExtractionOrchestrator 执行完整的提取流水线
|
||||
if not llm_available:
|
||||
print("[Ingestion] Skipping extraction pipeline (no LLM).")
|
||||
return False
|
||||
|
||||
# 初始化 embedder 客户端
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
# Step 4: Initialize embedder client
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.db import get_db
|
||||
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
embedder_config_dict = MemoryConfigService(db).get_embedder_config(embedding_name or SELECTED_EMBEDDING_ID)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
db = next(get_db())
|
||||
try:
|
||||
embedder_config_dict = get_embedder_config(embedding_name, db)
|
||||
embedder_config = RedBearModelConfig(**embedder_config_dict)
|
||||
embedder_client = OpenAIEmbedderClient(embedder_config)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Failed to initialize embedder client: {e}")
|
||||
print("[Ingestion] Skipping extraction pipeline (embedder initialization failed).")
|
||||
return False
|
||||
|
||||
# Step 5: Initialize Neo4j connector
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化并运行 ExtractionOrchestrator
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
config = get_pipeline_config()
|
||||
# Step 6: 构建 MemoryConfig(从环境变量直接构建,不依赖数据库)
|
||||
print("[Ingestion] 构建 MemoryConfig from environment variables...")
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
try:
|
||||
# 从环境变量获取配置参数
|
||||
llm_id = os.getenv("EVAL_LLM_ID")
|
||||
embedding_id = os.getenv("EVAL_EMBEDDING_ID")
|
||||
chunker_strategy_env = os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker")
|
||||
|
||||
if not llm_id or not embedding_id:
|
||||
print("[Ingestion] ❌ EVAL_LLM_ID or EVAL_EMBEDDING_ID is not set in .env.evaluation")
|
||||
print("[Ingestion] Please set both EVAL_LLM_ID and EVAL_EMBEDDING_ID")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# 从数据库获取模型信息(仅用于显示名称)
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
# 获取 LLM 模型信息(从 model_configs 表)
|
||||
llm_result = db.execute(
|
||||
text("SELECT name FROM model_configs WHERE id = :id"),
|
||||
{"id": llm_id}
|
||||
).fetchone()
|
||||
llm_model_name = llm_result[0] if llm_result else "Unknown LLM"
|
||||
|
||||
# 获取 Embedding 模型信息(从 model_configs 表)
|
||||
emb_result = db.execute(
|
||||
text("SELECT name FROM model_configs WHERE id = :id"),
|
||||
{"id": embedding_id}
|
||||
).fetchone()
|
||||
embedding_model_name = emb_result[0] if emb_result else "Unknown Embedding"
|
||||
except Exception as e:
|
||||
# 如果查询失败,使用默认名称
|
||||
print(f"[Ingestion] Warning: Failed to query model names from database: {e}")
|
||||
llm_model_name = f"LLM ({llm_id[:8]}...)"
|
||||
embedding_model_name = f"Embedding ({embedding_id[:8]}...)"
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# 构建 MemoryConfig 对象(使用最小必需配置)
|
||||
from uuid import uuid4
|
||||
memory_config = MemoryConfig(
|
||||
config_id=0, # 评估环境不需要真实的 config_id
|
||||
config_name="evaluation_config",
|
||||
workspace_id=uuid4(), # 临时 workspace_id
|
||||
workspace_name="evaluation_workspace",
|
||||
tenant_id=uuid4(), # 临时 tenant_id
|
||||
llm_model_id=UUID(llm_id),
|
||||
llm_model_name=llm_model_name,
|
||||
embedding_model_id=UUID(embedding_id),
|
||||
embedding_model_name=embedding_model_name,
|
||||
storage_type="neo4j",
|
||||
chunker_strategy=chunker_strategy_env,
|
||||
reflexion_enabled=False,
|
||||
reflexion_iteration_period=3,
|
||||
reflexion_range="partial",
|
||||
reflexion_baseline="TIME",
|
||||
loaded_at=datetime.now(),
|
||||
# 可选字段使用默认值
|
||||
rerank_model_id=None,
|
||||
rerank_model_name=None,
|
||||
llm_params={},
|
||||
embedding_params={},
|
||||
config_version="2.0",
|
||||
)
|
||||
|
||||
print(f"[Ingestion] ✅ 构建 MemoryConfig 成功")
|
||||
print(f"[Ingestion] LLM: {llm_model_name}")
|
||||
print(f"[Ingestion] Embedding: {embedding_model_name}")
|
||||
print(f"[Ingestion] Chunker: {chunker_strategy_env}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] ❌ Failed to build MemoryConfig: {e}")
|
||||
print(f"[Ingestion] Please check:")
|
||||
print(f"[Ingestion] 1. EVAL_LLM_ID and EVAL_EMBEDDING_ID are set in .env.evaluation")
|
||||
print(f"[Ingestion] 2. Model IDs exist in the models table")
|
||||
print(f"[Ingestion] 3. Database connection is working")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# Step 7: Initialize and run ExtractionOrchestrator
|
||||
print("[Ingestion] Running extraction pipeline with ExtractionOrchestrator...")
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
config = MemoryConfigService.get_pipeline_config(memory_config)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=connector,
|
||||
config=config,
|
||||
embedding_id=str(memory_config.embedding_model_id), # 传递 embedding_id
|
||||
)
|
||||
|
||||
# 创建一个包装的 orchestrator 来修复时间提取器的输出
|
||||
# 保存原始的 _assign_extracted_data 方法
|
||||
original_assign = orchestrator._assign_extracted_data
|
||||
|
||||
def clean_temporal_value(value):
|
||||
"""清理 temporal_validity 字段的值,将无效值转换为 None"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
# 处理字符串形式的 'null', 'None', 空字符串等
|
||||
if value.lower() in ('null', 'none', '') or value.strip() == '':
|
||||
return None
|
||||
return value
|
||||
|
||||
async def patched_assign_extracted_data(*args, **kwargs):
|
||||
"""包装方法:在赋值后清理 temporal_validity 中的无效字符串"""
|
||||
result = await original_assign(*args, **kwargs)
|
||||
try:
|
||||
# Run the complete extraction pipeline
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
|
||||
# 清理返回的 dialog_data_list 中的 temporal_validity
|
||||
for dialog in result:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
# 清理 valid_at 和 invalid_at
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
return result
|
||||
|
||||
# 替换方法
|
||||
orchestrator._assign_extracted_data = patched_assign_extracted_data
|
||||
|
||||
# 同时包装 _create_nodes_and_edges 方法,在创建节点前再次清理
|
||||
original_create = orchestrator._create_nodes_and_edges
|
||||
|
||||
async def patched_create_nodes_and_edges(dialog_data_list_arg):
|
||||
"""包装方法:在创建节点前再次清理 temporal_validity"""
|
||||
# 最后一次清理,确保万无一失
|
||||
for dialog in dialog_data_list_arg:
|
||||
if hasattr(dialog, 'chunks') and dialog.chunks:
|
||||
for chunk in dialog.chunks:
|
||||
if hasattr(chunk, 'statements') and chunk.statements:
|
||||
for statement in chunk.statements:
|
||||
if hasattr(statement, 'temporal_validity') and statement.temporal_validity:
|
||||
tv = statement.temporal_validity
|
||||
if hasattr(tv, 'valid_at'):
|
||||
tv.valid_at = clean_temporal_value(tv.valid_at)
|
||||
if hasattr(tv, 'invalid_at'):
|
||||
tv.invalid_at = clean_temporal_value(tv.invalid_at)
|
||||
# Handle different return formats:
|
||||
# - Pilot mode: 7 values (without dedup_details)
|
||||
# - Normal mode: 8 values (with dedup_details at the end)
|
||||
if len(result) == 8:
|
||||
# Normal mode: includes dedup_details
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
_, # dedup_details - not needed here
|
||||
) = result
|
||||
elif len(result) == 7:
|
||||
# Pilot mode or older version: no dedup_details
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
else:
|
||||
raise ValueError(f"Unexpected number of return values: {len(result)}")
|
||||
|
||||
return await original_create(dialog_data_list_arg)
|
||||
|
||||
orchestrator._create_nodes_and_edges = patched_create_nodes_and_edges
|
||||
|
||||
# 运行完整的提取流水线
|
||||
# orchestrator.run 返回 7 个元素的元组
|
||||
result = await orchestrator.run(dialog_data_list, is_pilot_run=False)
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
entity_nodes,
|
||||
statement_chunk_edges,
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
) = result
|
||||
|
||||
# statement_chunk_edges 已经由 orchestrator 创建,无需重复创建
|
||||
print(f"[Ingestion] Extraction completed: {len(statement_nodes)} statements, {len(entity_nodes)} entities")
|
||||
|
||||
except ValueError as e:
|
||||
# If unpacking fails, provide helpful error message
|
||||
print(f"[Ingestion] Extraction pipeline result unpacking failed: {e}")
|
||||
print(f"[Ingestion] Result type: {type(result)}, length: {len(result) if hasattr(result, '__len__') else 'N/A'}")
|
||||
if hasattr(result, '__len__') and len(result) > 0:
|
||||
print(f"[Ingestion] First element type: {type(result[0])}")
|
||||
await connector.close()
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[Ingestion] Extraction pipeline failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
# Step G: 生成记忆摘要
|
||||
# Step 7: Generate memory summaries
|
||||
print("[Ingestion] Generating memory summaries...")
|
||||
try:
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs=dialog_data_list,
|
||||
@@ -266,7 +365,8 @@ async def ingest_contexts_via_full_pipeline(
|
||||
print(f"[Ingestion] Warning: Failed to generate memory summaries: {e}")
|
||||
summaries = []
|
||||
|
||||
# Step H: Save to Neo4j
|
||||
# Step 8: Save to Neo4j
|
||||
print("[Ingestion] Saving to Neo4j...")
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=dialogue_nodes,
|
||||
@@ -284,18 +384,21 @@ async def ingest_contexts_via_full_pipeline(
|
||||
try:
|
||||
await add_memory_summary_nodes(summaries, connector)
|
||||
await add_memory_summary_statement_edges(summaries, connector)
|
||||
print(f"Successfully saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
print(f"[Ingestion] Saved {len(summaries)} memory summary nodes to Neo4j")
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to save summary nodes: {e}")
|
||||
print(f"[Ingestion] Warning: Failed to save summary nodes: {e}")
|
||||
|
||||
await connector.close()
|
||||
|
||||
if success:
|
||||
print("Successfully saved extracted data to Neo4j!")
|
||||
print("[Ingestion] Successfully saved all data to Neo4j!")
|
||||
else:
|
||||
print("Failed to save data to Neo4j")
|
||||
print("[Ingestion] Failed to save data to Neo4j")
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to save data to Neo4j: {e}")
|
||||
print(f"[Ingestion] Failed to save data to Neo4j: {e}")
|
||||
await connector.close()
|
||||
return False
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,30 +1,29 @@
|
||||
# file name: check_neo4j_connection_fixed.py
|
||||
import asyncio
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 1
|
||||
# 添加项目根目录到路径
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
project_root = str(current_dir.parent)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# 关键:将 src 目录置于最前,确保从当前仓库加载模块
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
if src_dir not in sys.path:
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
# Load main .env
|
||||
load_dotenv()
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
# Get group_id from config
|
||||
group_id = os.getenv("EVAL_GROUP_ID", "locomo_test")
|
||||
print(f"✅ 使用配置的 group_id: {group_id}")
|
||||
|
||||
# 首先定义 _loc_normalize 函数,因为其他函数依赖它
|
||||
def _loc_normalize(text: str) -> str:
|
||||
text = str(text) if text is not None else ""
|
||||
@@ -37,7 +36,7 @@ def _loc_normalize(text: str) -> str:
|
||||
|
||||
# 尝试从 metrics.py 导入基础指标
|
||||
try:
|
||||
from common.metrics import bleu1, f1_score, jaccard
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
print("✅ 从 metrics.py 导入基础指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 metrics.py 导入失败: {e}")
|
||||
@@ -107,23 +106,8 @@ except ImportError as e:
|
||||
|
||||
# 尝试从 qwen_search_eval.py 导入 LoCoMo 特定指标
|
||||
try:
|
||||
# 添加 evaluation 目录路径
|
||||
evaluation_dir = os.path.join(project_root, "evaluation")
|
||||
if evaluation_dir not in sys.path:
|
||||
sys.path.insert(0, evaluation_dir)
|
||||
|
||||
# 尝试从不同位置导入
|
||||
try:
|
||||
from locomo.qwen_search_eval import (
|
||||
_resolve_relative_times,
|
||||
loc_f1_score,
|
||||
loc_multi_f1,
|
||||
)
|
||||
print("✅ 从 locomo.qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError:
|
||||
from qwen_search_eval import _resolve_relative_times, loc_f1_score, loc_multi_f1
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
|
||||
from app.core.memory.evaluation.locomo.qwen_search_eval import loc_f1_score, loc_multi_f1, _resolve_relative_times
|
||||
print("✅ 从 qwen_search_eval 导入 LoCoMo 特定指标成功")
|
||||
except ImportError as e:
|
||||
print(f"❌ 从 qwen_search_eval.py 导入失败: {e}")
|
||||
# 回退到本地实现 LoCoMo 特定函数
|
||||
@@ -429,31 +413,36 @@ def enhanced_context_selection(contexts: List[str], question: str, question_inde
|
||||
|
||||
async def run_enhanced_evaluation():
|
||||
"""使用增强方法进行完整评估 - 解决中间性能衰减问题"""
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
|
||||
# 修正导入路径:使用 app.core.memory.src 前缀
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.graph_search import search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# Get model IDs from config
|
||||
llm_id = os.getenv("EVAL_LLM_ID", "6dc52e1b-9cec-4194-af66-a74c6307fc3f")
|
||||
embedding_id = os.getenv("EVAL_EMBEDDING_ID", "e2a6392d-ca63-4d59-a523-647420b59cb2")
|
||||
|
||||
# 加载数据
|
||||
# 获取项目根目录
|
||||
current_file = os.path.abspath(__file__)
|
||||
evaluation_dir = os.path.dirname(os.path.dirname(current_file)) # evaluation目录
|
||||
memory_dir = os.path.dirname(evaluation_dir) # memory目录
|
||||
data_path = os.path.join(memory_dir, "data", "locomo10.json")
|
||||
# 加载数据 - 使用统一的 dataset 目录
|
||||
data_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "dataset", "locomo10.json")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 locomo10.json 放置在: api/app/core/memory/evaluation/dataset/"
|
||||
)
|
||||
|
||||
print(f"✅ 找到数据文件: {data_path}")
|
||||
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
|
||||
@@ -463,64 +452,109 @@ async def run_enhanced_evaluation():
|
||||
qa_items.extend(entry.get("qa", []))
|
||||
else:
|
||||
qa_items.extend(raw.get("qa", []))
|
||||
|
||||
items = qa_items[:20] # 测试多少个问题
|
||||
|
||||
# 测试多少个问题 - 可通过环境变量设置
|
||||
sample_size = int(os.getenv("LOCOMO_SAMPLE_SIZE", "20"))
|
||||
items = qa_items[:sample_size]
|
||||
print(f"📊 将测试 {len(items)} 个问题(总共 {len(qa_items)} 个可用)")
|
||||
|
||||
# 初始化增强监控器
|
||||
monitor = EnhancedEvaluationMonitor(reset_interval=5, performance_threshold=0.6)
|
||||
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
# 获取数据库会话并初始化 LLM 客户端
|
||||
from app.db import get_db
|
||||
db = next(get_db())
|
||||
|
||||
# 初始化embedder
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 初始化连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 初始化结果字典
|
||||
results = {
|
||||
"questions": [],
|
||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
||||
"category_metrics": {},
|
||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
||||
"performance_trend": "stable",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"enhanced_strategy": True
|
||||
}
|
||||
|
||||
total_f1 = 0.0
|
||||
total_bleu1 = 0.0
|
||||
total_jaccard = 0.0
|
||||
total_loc_f1 = 0.0
|
||||
total_context_length = 0
|
||||
total_retrieved_docs = 0
|
||||
category_stats = {}
|
||||
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
monitor.question_count += 1
|
||||
llm = get_llm_client(llm_id, db)
|
||||
|
||||
# 初始化embedder
|
||||
cfg_dict = get_embedder_config(embedding_id, db)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 🔧 创建 MemoryConfig 对象用于搜索
|
||||
# 方案1:如果有配置ID,从数据库加载
|
||||
config_id = os.getenv("EVAL_CONFIG_ID")
|
||||
if config_id:
|
||||
print(f"📋 从数据库加载配置 ID: {config_id}")
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config = memory_config_service.load_memory_config(config_id, service_name="locomo_test")
|
||||
else:
|
||||
# 方案2:创建临时配置对象用于测试
|
||||
print(f"📋 创建临时测试配置")
|
||||
from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
# 将字符串 ID 转换为 UUID
|
||||
try:
|
||||
embedding_uuid = UUID(embedding_id)
|
||||
llm_uuid = UUID(llm_id)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"无效的 UUID 格式: {e}")
|
||||
|
||||
memory_config = MemoryConfig(
|
||||
config_id=1, # 临时 ID
|
||||
config_name="locomo_test_config",
|
||||
workspace_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 workspace
|
||||
workspace_name="test_workspace",
|
||||
tenant_id=UUID("00000000-0000-0000-0000-000000000000"), # 临时 tenant
|
||||
embedding_model_id=embedding_uuid,
|
||||
embedding_model_name="test_embedding",
|
||||
llm_model_id=llm_uuid,
|
||||
llm_model_name="test_llm",
|
||||
storage_type="neo4j",
|
||||
chunker_strategy="RecursiveChunker",
|
||||
reflexion_enabled=False,
|
||||
reflexion_iteration_period=3,
|
||||
reflexion_range="partial",
|
||||
reflexion_baseline="Time",
|
||||
loaded_at=datetime.now()
|
||||
)
|
||||
|
||||
print(f"✅ MemoryConfig 已准备: embedding_id={memory_config.embedding_model_id}, llm_id={memory_config.llm_model_id}")
|
||||
|
||||
# 初始化连接器
|
||||
connector = Neo4jConnector()
|
||||
|
||||
# 获取近期性能用于重置判断
|
||||
recent_performance = monitor.get_recent_performance()
|
||||
# 初始化结果字典
|
||||
results = {
|
||||
"questions": [],
|
||||
"overall_metrics": {"f1": 0.0, "b1": 0.0, "j": 0.0, "loc_f1": 0.0},
|
||||
"category_metrics": {},
|
||||
"retrieval_stats": {"total_questions": len(items), "avg_context_length": 0, "avg_retrieved_docs": 0},
|
||||
"performance_trend": "stable",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"enhanced_strategy": True
|
||||
}
|
||||
|
||||
# 增强的重置判断
|
||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
||||
if should_reset and i > 0:
|
||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
||||
await connector.close()
|
||||
connector = Neo4jConnector() # 创建新连接
|
||||
print("✅ 连接重置完成")
|
||||
total_f1 = 0.0
|
||||
total_bleu1 = 0.0
|
||||
total_jaccard = 0.0
|
||||
total_loc_f1 = 0.0
|
||||
total_context_length = 0
|
||||
total_retrieved_docs = 0
|
||||
category_stats = {}
|
||||
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
try:
|
||||
for i, item in enumerate(items):
|
||||
monitor.question_count += 1
|
||||
|
||||
# 获取近期性能用于重置判断
|
||||
recent_performance = monitor.get_recent_performance()
|
||||
|
||||
# 增强的重置判断
|
||||
should_reset = monitor.should_reset_connections(current_f1=recent_performance)
|
||||
if should_reset and i > 0:
|
||||
print(f"🔄 重置Neo4j连接 (问题 {i+1}/{len(items)}, 近期性能: {recent_performance:.3f})...")
|
||||
await connector.close()
|
||||
connector = Neo4jConnector() # 创建新连接
|
||||
print("✅ 连接重置完成")
|
||||
|
||||
q = item.get("question", "")
|
||||
ref = item.get("answer", "")
|
||||
ref_str = str(ref) if ref is not None else ""
|
||||
|
||||
print(f"\n🔍 [{i+1}/{len(items)}] 问题: {q}")
|
||||
print(f"✅ 真实答案: {ref_str}")
|
||||
@@ -548,10 +582,12 @@ async def run_enhanced_evaluation():
|
||||
contexts_all = []
|
||||
|
||||
try:
|
||||
# 使用统一的搜索服务
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
# 使用旧版本的搜索服务(重构前的版本)
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
print("🔀 使用混合搜索服务...")
|
||||
print(f"🔀 使用混合搜索服务(旧版本)...")
|
||||
print(f"📍 检索参数: group_id={group_id}, limit=20, search_type=hybrid")
|
||||
print(f"📍 查询文本: {q}")
|
||||
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
@@ -559,15 +595,27 @@ async def run_enhanced_evaluation():
|
||||
end_user_id="locomo_sk",
|
||||
limit=20,
|
||||
include=["statements", "chunks", "entities", "summaries"],
|
||||
alpha=0.6, # BM25权重
|
||||
embedding_id=SELECTED_EMBEDDING_ID
|
||||
output_path=None,
|
||||
memory_config=memory_config, # 🔧 添加必需的 memory_config 参数
|
||||
rerank_alpha=0.6, # BM25权重
|
||||
use_forgetting_rerank=False,
|
||||
use_llm_rerank=False
|
||||
)
|
||||
|
||||
# 处理搜索结果 - 新的搜索服务返回统一的结构
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
# 处理搜索结果 - 旧版本返回包含 reranked_results 的结构
|
||||
# 对于 hybrid 搜索,使用 reranked_results
|
||||
if "reranked_results" in search_results:
|
||||
reranked = search_results["reranked_results"]
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
else:
|
||||
# 单一搜索类型的结果
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
print(f"✅ 混合检索成功: {len(chunks)} chunks, {len(statements)} 条陈述, {len(entities)} 个实体, {len(summaries)} 个摘要")
|
||||
|
||||
@@ -609,6 +657,8 @@ async def run_enhanced_evaluation():
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
except Exception as e:
|
||||
print(f"❌ 检索失败: {e}")
|
||||
import traceback
|
||||
print(f"详细错误信息:\n{traceback.format_exc()}")
|
||||
contexts_all = []
|
||||
|
||||
t1 = time.time()
|
||||
@@ -728,14 +778,17 @@ async def run_enhanced_evaluation():
|
||||
|
||||
print("="*60)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 评估过程中发生错误: {e}")
|
||||
# 即使出错,也返回已有的结果
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
print(f"❌ 评估过程中发生错误: {e}")
|
||||
# 即使出错,也返回已有的结果
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
db.close() # 关闭数据库会话
|
||||
|
||||
# 计算总体指标
|
||||
n = len(items)
|
||||
|
||||
@@ -15,8 +15,14 @@ import json
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
from app.core.memory.utils.definitions import PROJECT_ROOT
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
|
||||
|
||||
@@ -82,7 +88,7 @@ def load_locomo_data(
|
||||
return qa_items[:sample_size]
|
||||
|
||||
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
def extract_conversations(data_path: str, max_dialogues: int = 1, max_messages_per_dialogue: Optional[int] = None) -> List[str]:
|
||||
"""
|
||||
Extract conversation texts from LoCoMo data for ingestion.
|
||||
|
||||
@@ -93,6 +99,7 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
Args:
|
||||
data_path: Path to locomo10.json file
|
||||
max_dialogues: Maximum number of dialogues to extract (default: 1)
|
||||
max_messages_per_dialogue: Maximum messages per dialogue (default: None = all messages)
|
||||
|
||||
Returns:
|
||||
List of conversation strings formatted for ingestion.
|
||||
@@ -141,13 +148,21 @@ def extract_conversations(data_path: str, max_dialogues: int = 1) -> List[str]:
|
||||
continue
|
||||
|
||||
lines.append(f"{role}: {text}")
|
||||
|
||||
# Limit messages if specified
|
||||
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
|
||||
break
|
||||
|
||||
# Break outer loop if we've reached the message limit
|
||||
if max_messages_per_dialogue and len(lines) >= max_messages_per_dialogue:
|
||||
break
|
||||
|
||||
if lines:
|
||||
contents.append("\n".join(lines))
|
||||
|
||||
return contents
|
||||
|
||||
|
||||
# 时间解析:将相对时间表达转换为绝对日期
|
||||
def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
||||
"""
|
||||
Resolve relative temporal references to absolute dates.
|
||||
@@ -225,6 +240,8 @@ def resolve_temporal_references(text: str, anchor_date: datetime) -> str:
|
||||
t,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
|
||||
# 中文支持
|
||||
t = re.sub(
|
||||
r"\bnext\s+week\b",
|
||||
(anchor_date + timedelta(days=7)).date().isoformat(),
|
||||
@@ -345,6 +362,50 @@ def select_and_format_information(
|
||||
|
||||
return "\n\n".join(selected)
|
||||
|
||||
# 记忆系统核心能力:写入与读取
|
||||
async def ingest_conversations_if_needed(
    conversations: List[str],
    end_user_id: str,
    reset: bool = False
) -> bool:
    """
    Ingest raw conversation texts by delegating to the extraction pipeline.

    This is a thin wrapper around ``ingest_contexts_via_full_pipeline``. The
    pipeline itself is defined elsewhere; per the original author's notes it
    chunks the dialogues, extracts statements/entities, embeds them and stores
    the result in Neo4j so that retrieval has memory to search
    (NOTE(review): pipeline internals are not visible here - confirm).

    Args:
        conversations: Raw conversation texts, passed to the pipeline as its
            ``contexts`` argument (the pipeline calls conversation texts
            "contexts").
        end_user_id: Identifier forwarded unchanged to the pipeline.
        reset: Forwarded to the pipeline as ``reset_group``.

    Returns:
        Whatever the pipeline returns on success (expected to be a truthy
        success flag), or ``False`` if the pipeline raised any exception.
    """
    try:
        # Chunk output is always saved; only the reset flag is caller-controlled.
        return await ingest_contexts_via_full_pipeline(
            contexts=conversations,
            end_user_id=end_user_id,
            save_chunk_output=True,
            reset_group=reset,
        )
    except Exception as e:
        # Best-effort by design: report the failure and signal it via the
        # return value instead of propagating.
        print(f"[Ingestion] Failed to ingest conversations: {e}")
        return False
|
||||
|
||||
async def retrieve_relevant_information(
|
||||
question: str,
|
||||
@@ -385,7 +446,7 @@ async def retrieve_relevant_information(
|
||||
search_graph,
|
||||
search_graph_by_embedding
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
contexts_all: List[str] = []
|
||||
|
||||
|
||||
@@ -2,43 +2,29 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from typing import List, Dict, Any
|
||||
import statistics
|
||||
import re
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
bleu1,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, bleu1, jaccard, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
# 参考 evaluation/locomo/evaluation.py 的 F1 计算逻辑(移除外部依赖,内联实现)
|
||||
@@ -265,7 +251,10 @@ async def run_locomo_eval(
|
||||
end_user_id = end_user_id or SELECTED_end_user_id
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "locomo10.json")
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "locomo10.json")
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 locomo10.json 放置在: {dataset_dir}"
|
||||
)
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
raw = json.load(f)
|
||||
# LoCoMo 数据结构:顶层为若干对象,每个对象下有 qa 列表
|
||||
@@ -343,13 +332,9 @@ async def run_locomo_eval(
|
||||
await ingest_contexts_via_full_pipeline(contents, end_user_id, save_chunk_output=True)
|
||||
|
||||
# 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
llm_client = get_llm_client(llm_id)
|
||||
# 初始化embedder用于直接调用
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
cfg_dict = get_embedder_config(embedding_id)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -480,8 +465,8 @@ async def run_locomo_eval(
|
||||
contexts_all.append(f"EntitySummary: {', '.join(entity_names)}")
|
||||
|
||||
else: # hybrid
|
||||
# 🎯 关键修复:混合检索使用更严格的回退机制
|
||||
print("🔀 使用混合检索(带回退机制)...")
|
||||
# 使用旧版本的混合检索(重构前)
|
||||
print("🔀 使用混合检索(旧版本)...")
|
||||
try:
|
||||
search_results = await run_hybrid_search(
|
||||
query_text=q,
|
||||
@@ -490,16 +475,26 @@ async def run_locomo_eval(
|
||||
limit=adjusted_limit,
|
||||
include=["chunks", "statements", "entities", "summaries"],
|
||||
output_path=None,
|
||||
rerank_alpha=0.6,
|
||||
use_forgetting_rerank=False,
|
||||
use_llm_rerank=False
|
||||
)
|
||||
|
||||
# 🎯 关键修复:正确处理混合检索的扁平结构
|
||||
# 新的API返回扁平结构,直接从顶层获取结果
|
||||
# 处理旧版本的返回结构(包含 reranked_results)
|
||||
if search_results and isinstance(search_results, dict):
|
||||
# 新API返回扁平结构:直接从顶层获取
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
# 对于 hybrid 搜索,使用 reranked_results
|
||||
if "reranked_results" in search_results:
|
||||
reranked = search_results["reranked_results"]
|
||||
chunks = reranked.get("chunks", [])
|
||||
statements = reranked.get("statements", [])
|
||||
entities = reranked.get("entities", [])
|
||||
summaries = reranked.get("summaries", [])
|
||||
else:
|
||||
# 单一搜索类型的结果
|
||||
chunks = search_results.get("chunks", [])
|
||||
statements = search_results.get("statements", [])
|
||||
entities = search_results.get("entities", [])
|
||||
summaries = search_results.get("summaries", [])
|
||||
|
||||
# 检查是否有有效结果
|
||||
if chunks or statements or entities or summaries:
|
||||
@@ -799,8 +794,9 @@ async def run_locomo_eval(
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"llm_id": llm_id,
|
||||
"retrieval_embedding_id": embedding_id,
|
||||
"chunker_strategy": os.getenv("EVAL_CHUNKER_STRATEGY", "RecursiveChunker"),
|
||||
"skip_ingest_if_exists": skip_ingest_if_exists,
|
||||
"llm_timeout": llm_timeout,
|
||||
"llm_max_retries": llm_max_retries,
|
||||
|
||||
@@ -2,100 +2,67 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 确保可以找到 src 及项目根路径
|
||||
import sys
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_PROJECT_ROOT = str(_THIS_DIR.parents[2])
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
try:
|
||||
# 优先从 extraction_utils1 导入
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline, # type: ignore
|
||||
)
|
||||
except Exception:
|
||||
ingest_contexts_via_full_pipeline = None # 在运行时做兜底检查
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
# 兜底:简单的大小写不敏感比较
|
||||
def exact_match(pred: str, ref: str) -> bool:
|
||||
return str(pred).strip().lower() == str(ref).strip().lower()
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
|
||||
|
||||
def load_dataset_any(path: str) -> List[Dict[str, Any]]:
    """Load a dataset file and return its records as a list of dicts.

    Supported layouts:
      1. A standard JSON array of objects: ``[{...}, {...}]``
      2. A single JSON object: ``{...}``
      3. JSON Lines or several JSON documents written one after another
         (each document may span multiple lines).

    Entries that are not dicts are dropped. Text that cannot be decoded is
    skipped up to the next newline, so one bad line does not discard the
    rest of the file.

    Args:
        path: Path to a UTF-8 encoded dataset file.

    Returns:
        The dict records found in the file, in file order (possibly empty).

    Raises:
        OSError: If the file cannot be opened or read.
    """

    def _dict_records(value: Any) -> List[Dict[str, Any]]:
        # Normalise one decoded JSON value to the dict records it contains.
        if isinstance(value, dict):
            return [value]
        if isinstance(value, list):
            return [item for item in value if isinstance(item, dict)]
        return []

    with open(path, "r", encoding="utf-8") as f:
        content = f.read().strip()

    # Fast path: the whole file is one JSON document.
    try:
        return _dict_records(json.loads(content))
    except json.JSONDecodeError:
        pass

    # Fallback: decode consecutive JSON documents. This covers JSONL as well
    # as pretty-printed objects concatenated in one file.
    decoder = json.JSONDecoder()
    items: List[Dict[str, Any]] = []
    pos, size = 0, len(content)
    while pos < size:
        # raw_decode does not skip leading whitespace itself.
        if content[pos].isspace():
            pos += 1
            continue
        try:
            value, pos = decoder.raw_decode(content, pos)
        except json.JSONDecodeError:
            # Undecodable text: resume at the start of the next line.
            newline = content.find("\n", pos)
            if newline == -1:
                break
            pos = newline + 1
            continue
        items.extend(_dict_records(value))

    return items
|
||||
|
||||
|
||||
@@ -624,7 +591,7 @@ def _resolve_relative_times_cn_en(text: str, anchor: datetime) -> str:
|
||||
|
||||
async def run_longmemeval_test(
|
||||
sample_size: int = 3,
|
||||
end_user_id: str = "longmemeval_zh_bak_3",
|
||||
end_user_id: str | None = None,
|
||||
search_limit: int = 8,
|
||||
context_char_budget: int = 4000,
|
||||
llm_temperature: float = 0.0,
|
||||
@@ -639,18 +606,22 @@ async def run_longmemeval_test(
|
||||
skip_ingest: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""LongMemEval 评估测试:增强时间推理能力"""
|
||||
|
||||
# Use environment variable with fallback chain
|
||||
if end_user_id is None:
|
||||
end_user_id = os.getenv("LONGMEMEVAL_END_USER_ID") or os.getenv("EVAL_END_USER_ID", "longmemeval_zh_bak_3")
|
||||
|
||||
# 数据路径
|
||||
if not data_path:
|
||||
# 固定使用中文数据集:data/longmemeval_oracle_zh.json
|
||||
zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json")
|
||||
zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json")
|
||||
if os.path.exists(zh_proj):
|
||||
data_path = zh_proj
|
||||
elif os.path.exists(zh_cwd):
|
||||
data_path = zh_cwd
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
# 固定使用中文数据集:dataset/longmemeval_oracle_zh.json
|
||||
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
|
||||
data_path = str(dataset_dir / "longmemeval_oracle_zh.json")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}"
|
||||
)
|
||||
|
||||
qa_list: List[Dict[str, Any]] = load_dataset_any(data_path)
|
||||
# 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾
|
||||
@@ -702,16 +673,19 @@ async def run_longmemeval_test(
|
||||
)
|
||||
|
||||
# 初始化组件(摄入后再初始化连接器)- 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
from app.db import get_db
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
|
||||
cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"), db)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
connector = Neo4jConnector()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
|
||||
# 指标收集
|
||||
latencies_llm: List[float] = []
|
||||
@@ -768,10 +742,10 @@ async def run_longmemeval_test(
|
||||
if stmt_text:
|
||||
contexts_all.append(stmt_text)
|
||||
|
||||
# for sm in summaries:
|
||||
# summary_text = str(sm.get("summary", "")).strip()
|
||||
# if summary_text:
|
||||
# contexts_all.append(summary_text)
|
||||
for sm in summaries:
|
||||
summary_text = str(sm.get("summary", "")).strip()
|
||||
if summary_text:
|
||||
contexts_all.append(summary_text)
|
||||
|
||||
# 实体摘要(最多3个)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
@@ -1228,8 +1202,8 @@ async def run_longmemeval_test(
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"llm_id": os.getenv("EVAL_LLM_ID"),
|
||||
"embedding_id": os.getenv("EVAL_EMBEDDING_ID"),
|
||||
"sample_size": sample_size,
|
||||
"start_index": start_index,
|
||||
},
|
||||
@@ -1288,7 +1262,7 @@ def main():
|
||||
parser.add_argument("--sample-size", type=int, default=3, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
||||
parser.add_argument("--group-id", type=str, default="longmemeval_zh_bak_3", help="图数据库 Group ID")
|
||||
parser.add_argument("--end-user-id", type=str, default=None, help="图数据库 End User ID,默认使用环境变量")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="检索条数上限")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
@@ -1349,7 +1323,8 @@ def main():
|
||||
|
||||
# 保存结果到文件
|
||||
try:
|
||||
out_dir = os.path.join(PROJECT_ROOT, "evaluation", "longmemeval", "results")
|
||||
# 使用相对路径而不是 PROJECT_ROOT
|
||||
out_dir = Path(__file__).resolve().parent / "results"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
out_path = os.path.join(out_dir, f"longmemeval_{result['params']['search_type']}_{ts}.json")
|
||||
@@ -2,81 +2,67 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import statistics
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
# 与现有评估脚本保持一致的导入方式
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
jaccard,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
except Exception:
|
||||
# 兜底:简单的大小写不敏感比较
|
||||
def exact_match(pred: str, ref: str) -> bool:
|
||||
return str(pred).strip().lower() == str(ref).strip().lower()
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.dialogue_queries import SEARCH_ENTITIES_BY_NAME
|
||||
from app.core.memory.evaluation.common.metrics import f1_score as common_f1, jaccard, latency_stats, avg_context_tokens
|
||||
from app.core.memory.evaluation.common.metrics import exact_match
|
||||
|
||||
|
||||
def load_dataset_any(path: str) -> List[Dict[str, Any]]:
|
||||
"""健壮地加载数据集(兼容 list 或多段 JSON)。"""
|
||||
"""健壮地加载数据集,支持三种格式:
|
||||
1. 标准 JSON 数组: [{...}, {...}]
|
||||
2. 单个 JSON 对象: {...}
|
||||
3. JSONL 格式(每行一个 JSON): {...}\n{...}\n{...}
|
||||
"""
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
s = f.read().strip()
|
||||
content = f.read().strip()
|
||||
|
||||
# 尝试标准 JSON 解析
|
||||
try:
|
||||
obj = json.loads(s)
|
||||
if isinstance(obj, list):
|
||||
return obj
|
||||
elif isinstance(obj, dict):
|
||||
return [obj]
|
||||
data = json.loads(content)
|
||||
if isinstance(data, list):
|
||||
return [item for item in data if isinstance(item, dict)]
|
||||
elif isinstance(data, dict):
|
||||
return [data]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
dec = json.JSONDecoder()
|
||||
idx = 0
|
||||
items: List[Dict[str, Any]] = []
|
||||
while idx < len(s):
|
||||
while idx < len(s) and s[idx].isspace():
|
||||
idx += 1
|
||||
if idx >= len(s):
|
||||
break
|
||||
|
||||
# 尝试 JSONL 格式(每行一个 JSON 对象)
|
||||
items = []
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj, end = dec.raw_decode(s, idx)
|
||||
if isinstance(obj, list):
|
||||
for it in obj:
|
||||
if isinstance(it, dict):
|
||||
items.append(it)
|
||||
elif isinstance(obj, dict):
|
||||
obj = json.loads(line)
|
||||
if isinstance(obj, dict):
|
||||
items.append(obj)
|
||||
idx = end
|
||||
elif isinstance(obj, list):
|
||||
items.extend(item for item in obj if isinstance(item, dict))
|
||||
except json.JSONDecodeError:
|
||||
nl = s.find("\n", idx)
|
||||
if nl == -1:
|
||||
break
|
||||
idx = nl + 1
|
||||
continue
|
||||
|
||||
return items
|
||||
|
||||
|
||||
@@ -640,15 +626,15 @@ async def run_longmemeval_test(
|
||||
|
||||
# 数据路径
|
||||
if not data_path:
|
||||
# 固定使用中文数据集:data/longmemeval_oracle_zh.json
|
||||
zh_proj = os.path.join(PROJECT_ROOT, "data", "longmemeval_oracle_zh.json")
|
||||
zh_cwd = os.path.join(os.getcwd(), "data", "longmemeval_oracle_zh.json")
|
||||
if os.path.exists(zh_proj):
|
||||
data_path = zh_proj
|
||||
elif os.path.exists(zh_cwd):
|
||||
data_path = zh_cwd
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/longmemeval_oracle_zh.json,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
# 固定使用中文数据集:dataset/longmemeval_oracle_zh.json
|
||||
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
|
||||
data_path = str(dataset_dir / "longmemeval_oracle_zh.json")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 longmemeval_oracle_zh.json 放置在: {dataset_dir}"
|
||||
)
|
||||
|
||||
qa_list: List[Dict[str, Any]] = load_dataset_any(data_path)
|
||||
# 支持评估全部样本:当 sample_size <= 0 时,取从 start_index 到末尾
|
||||
@@ -658,13 +644,9 @@ async def run_longmemeval_test(
|
||||
items = qa_list[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化组件 - 使用异步LLM客户端
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"))
|
||||
connector = Neo4jConnector()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -1203,8 +1185,8 @@ async def run_longmemeval_test(
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"embedding_id": SELECTED_EMBEDDING_ID,
|
||||
"llm_id": os.getenv("EVAL_LLM_ID"),
|
||||
"embedding_id": os.getenv("EVAL_EMBEDDING_ID"),
|
||||
"sample_size": sample_size,
|
||||
"start_index": start_index,
|
||||
},
|
||||
|
||||
@@ -2,81 +2,30 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
# 路径与模块导入保持与现有评估脚本一致
|
||||
import sys
|
||||
from typing import List, Dict, Any
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
_THIS_DIR = Path(__file__).resolve().parent
|
||||
_PROJECT_ROOT = str(_THIS_DIR.parents[1])
|
||||
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
|
||||
for _p in (_SRC_DIR, _PROJECT_ROOT):
|
||||
if _p not in sys.path:
|
||||
sys.path.insert(0, _p)
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
print(f"✅ 加载评估配置: {eval_config_path}")
|
||||
|
||||
# 对齐 locomo_test 的检索逻辑:直接使用 graph_search 与 Neo4jConnector/Embedder1
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.graph_search import search_graph, search_graph_by_embedding
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用与 evaluate_qa.py 相同的检索函数
|
||||
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.memory.utils.config.config_utils import get_embedder_config
|
||||
|
||||
try:
|
||||
from app.core.memory.evaluation.common.metrics import bleu1, f1_score, jaccard
|
||||
except Exception:
|
||||
# 兜底:简单实现(必要时)
|
||||
def f1_score(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
tp = len(set(ps) & set(rs))
|
||||
if tp == 0:
|
||||
return 0.0
|
||||
precision = tp / len(ps)
|
||||
recall = tp / len(rs)
|
||||
if precision + recall == 0:
|
||||
return 0.0
|
||||
return 2 * precision * recall / (precision + recall)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
|
||||
def bleu1(pred: str, ref: str) -> float:
|
||||
ps = pred.lower().split()
|
||||
rs = ref.lower().split()
|
||||
if not ps or not rs:
|
||||
return 0.0
|
||||
overlap = len([w for w in ps if w in rs])
|
||||
return overlap / max(len(ps), 1)
|
||||
|
||||
def jaccard(pred: str, ref: str) -> float:
|
||||
ps = set(pred.lower().split())
|
||||
rs = set(ref.lower().split())
|
||||
union = len(ps | rs)
|
||||
if union == 0:
|
||||
return 0.0
|
||||
return len(ps & rs) / union
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
@@ -219,16 +168,16 @@ async def run_memsciqa_test(
|
||||
# 默认使用指定的 memsci 组 ID
|
||||
end_user_id = end_user_id or "group_memsci"
|
||||
|
||||
# 数据路径解析(项目根与当前工作目录兜底)
|
||||
# 数据路径解析
|
||||
if not data_path:
|
||||
proj_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
cwd_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
if os.path.exists(proj_path):
|
||||
data_path = proj_path
|
||||
elif os.path.exists(cwd_path):
|
||||
data_path = cwd_path
|
||||
else:
|
||||
raise FileNotFoundError("未找到数据集: data/msc_self_instruct.jsonl,请确保其存在于项目根目录或当前工作目录的 data 目录下。")
|
||||
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
|
||||
data_path = str(dataset_dir / "msc_self_instruct.jsonl")
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
|
||||
)
|
||||
|
||||
# 加载数据
|
||||
all_items = load_dataset_memsciqa(data_path)
|
||||
@@ -238,17 +187,13 @@ async def run_memsciqa_test(
|
||||
items = all_items[start_index:start_index + sample_size]
|
||||
|
||||
# 初始化 LLM(纯测试:不进行摄入)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
llm = get_llm_client(os.getenv("EVAL_LLM_ID"))
|
||||
|
||||
# 初始化 Neo4j 连接与向量检索 Embedder(对齐 locomo_test)
|
||||
connector = Neo4jConnector()
|
||||
embedder = None
|
||||
if search_type in ("embedding", "hybrid"):
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
cfg_dict = config_service.get_embedder_config(SELECTED_EMBEDDING_ID)
|
||||
cfg_dict = get_embedder_config(os.getenv("EVAL_EMBEDDING_ID"))
|
||||
embedder = OpenAIEmbedderClient(
|
||||
model_config=RedBearModelConfig.model_validate(cfg_dict)
|
||||
)
|
||||
@@ -273,7 +218,7 @@ async def run_memsciqa_test(
|
||||
question = item.get("self_instruct", {}).get("B", "") or item.get("question", "")
|
||||
reference = item.get("self_instruct", {}).get("A", "") or item.get("answer", "")
|
||||
|
||||
# 三路检索:chunks/statements/entities/summaries(对齐 qwen_search_eval.py)
|
||||
# 检索:使用与 evaluate_qa.py 相同的 run_hybrid_search
|
||||
t0 = time.time()
|
||||
results = None
|
||||
try:
|
||||
@@ -302,57 +247,94 @@ async def run_memsciqa_test(
|
||||
search_ms = (t1 - t0) * 1000
|
||||
latencies_search.append(search_ms)
|
||||
|
||||
# 构建上下文:包含 chunks、陈述、摘要和实体(对齐 qwen_search_eval.py)
|
||||
# 构建上下文:与 evaluate_qa.py 完全一致的逻辑
|
||||
contexts_all: List[str] = []
|
||||
retrieved_counts: Dict[str, int] = {}
|
||||
if results:
|
||||
chunks = results.get("chunks", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
summaries = results.get("summaries", [])
|
||||
# 处理 hybrid 搜索结果
|
||||
if search_type == "hybrid":
|
||||
emb = results.get("embedding_search", {}) if isinstance(results.get("embedding_search"), dict) else {}
|
||||
kw = results.get("keyword_search", {}) if isinstance(results.get("keyword_search"), dict) else {}
|
||||
emb_dialogs = emb.get("dialogues", [])
|
||||
emb_statements = emb.get("statements", [])
|
||||
emb_entities = emb.get("entities", [])
|
||||
kw_dialogs = kw.get("dialogues", [])
|
||||
kw_statements = kw.get("statements", [])
|
||||
kw_entities = kw.get("entities", [])
|
||||
all_dialogs = emb_dialogs + kw_dialogs
|
||||
all_statements = emb_statements + kw_statements
|
||||
all_entities = emb_entities + kw_entities
|
||||
|
||||
# 简单去重
|
||||
seen_dialog = set()
|
||||
dialogues = []
|
||||
for d in all_dialogs:
|
||||
key = (str(d.get("uuid", "")), str(d.get("content", "")))
|
||||
if key not in seen_dialog:
|
||||
dialogues.append(d)
|
||||
seen_dialog.add(key)
|
||||
|
||||
seen_stmt = set()
|
||||
statements = []
|
||||
for s in all_statements:
|
||||
key = str(s.get("statement", ""))
|
||||
if key not in seen_stmt:
|
||||
statements.append(s)
|
||||
seen_stmt.add(key)
|
||||
|
||||
seen_ent = set()
|
||||
entities = []
|
||||
for e in all_entities:
|
||||
key = str(e.get("name", ""))
|
||||
if key not in seen_ent:
|
||||
entities.append(e)
|
||||
seen_ent.add(key)
|
||||
else:
|
||||
# embedding 或 keyword 单独搜索
|
||||
dialogues = results.get("dialogues", [])
|
||||
statements = results.get("statements", [])
|
||||
entities = results.get("entities", [])
|
||||
|
||||
retrieved_counts = {
|
||||
"chunks": len(chunks),
|
||||
"dialogues": len(dialogues),
|
||||
"statements": len(statements),
|
||||
"entities": len(entities),
|
||||
"summaries": len(summaries),
|
||||
}
|
||||
# 优先使用 chunks
|
||||
for c in chunks:
|
||||
text = str(c.get("content", "")).strip()
|
||||
|
||||
# 构建上下文文本
|
||||
for d in dialogues:
|
||||
text = str(d.get("content", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 statements
|
||||
|
||||
for s in statements:
|
||||
text = str(s.get("statement", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 然后是 summaries
|
||||
for sm in summaries:
|
||||
text = str(sm.get("summary", "")).strip()
|
||||
if text:
|
||||
contexts_all.append(text)
|
||||
# 实体摘要:最多加入前3个高分实体(对齐 qwen_search_eval.py)
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
# 实体摘要
|
||||
if entities:
|
||||
scored = [e for e in entities if e.get("score") is not None]
|
||||
top_entities = sorted(scored, key=lambda x: x.get("score", 0), reverse=True)[:3] if scored else entities[:3]
|
||||
if top_entities:
|
||||
summary_lines = []
|
||||
for e in top_entities:
|
||||
name = str(e.get("name", "")).strip()
|
||||
etype = str(e.get("entity_type", "")).strip()
|
||||
score = e.get("score")
|
||||
if name:
|
||||
meta = []
|
||||
if etype:
|
||||
meta.append(f"type={etype}")
|
||||
if isinstance(score, (int, float)):
|
||||
meta.append(f"score={score:.3f}")
|
||||
summary_lines.append(f"EntitySummary: {name}{(' [' + '; '.join(meta) + ']') if meta else ''}")
|
||||
if summary_lines:
|
||||
contexts_all.append("\n".join(summary_lines))
|
||||
|
||||
if verbose:
|
||||
if retrieved_counts:
|
||||
print(f"✅ 检索成功: {retrieved_counts.get('chunks',0)} chunks, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
||||
print(f"✅ 检索成功: {retrieved_counts.get('dialogues',0)} dialogues, {retrieved_counts.get('statements',0)} 条陈述, {retrieved_counts.get('entities',0)} 个实体, {retrieved_counts.get('summaries',0)} 个摘要")
|
||||
print(f"📊 有效上下文数量: {len(contexts_all)}")
|
||||
q_keywords = extract_question_keywords(question, max_keywords=8)
|
||||
if q_keywords:
|
||||
@@ -507,8 +489,8 @@ async def run_memsciqa_test(
|
||||
"llm_max_tokens": llm_max_tokens,
|
||||
"search_type": search_type,
|
||||
"start_index": start_index,
|
||||
"llm_id": SELECTED_LLM_ID,
|
||||
"retrieval_embedding_id": SELECTED_EMBEDDING_ID
|
||||
"llm_id": os.getenv("EVAL_LLM_ID"),
|
||||
"retrieval_embedding_id": os.getenv("EVAL_EMBEDDING_ID")
|
||||
},
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
@@ -522,7 +504,7 @@ async def run_memsciqa_test(
|
||||
def main():
|
||||
load_dotenv()
|
||||
parser = argparse.ArgumentParser(description="memsciqa 测试脚本(三路检索 + 智能上下文选择)")
|
||||
parser.add_argument("--sample-size", type=int, default=30, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--sample-size", type=int, default=10, help="样本数量(<=0 表示全部)")
|
||||
parser.add_argument("--all", action="store_true", help="评估全部样本(覆盖 --sample-size)")
|
||||
parser.add_argument("--start-index", type=int, default=0, help="起始样本索引")
|
||||
parser.add_argument("--group-id", type=str, default="group_memsci", help="图数据库 Group ID(默认 group_memsci)")
|
||||
|
||||
@@ -4,35 +4,20 @@ import json
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List
|
||||
from typing import List, Dict, Any
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent.parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
avg_context_tokens,
|
||||
exact_match,
|
||||
latency_stats,
|
||||
)
|
||||
from app.core.memory.evaluation.extraction_utils import (
|
||||
ingest_contexts_via_full_pipeline,
|
||||
)
|
||||
from app.core.memory.storage_services.search import run_hybrid_search
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
PROJECT_ROOT,
|
||||
SELECTED_EMBEDDING_ID,
|
||||
SELECTED_GROUP_ID,
|
||||
SELECTED_LLM_ID,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.src.search import run_hybrid_search # 使用旧版本(重构前)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.evaluation.extraction_utils import ingest_contexts_via_full_pipeline
|
||||
from app.core.memory.evaluation.common.metrics import exact_match, latency_stats, avg_context_tokens
|
||||
|
||||
|
||||
def smart_context_selection(contexts: List[str], question: str, max_chars: int = 4000) -> str:
|
||||
@@ -135,24 +120,37 @@ def _combine_dialogues_for_hybrid(results: Dict[str, Any]) -> List[Dict[str, Any
|
||||
return merged
|
||||
|
||||
|
||||
|
||||
async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None, search_limit: int = 8, context_char_budget: int = 4000, llm_temperature: float = 0.0, llm_max_tokens: int = 64, search_type: str = "hybrid", memory_config: "MemoryConfig" = None) -> Dict[str, Any]:
|
||||
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||
|
||||
# Load data
|
||||
data_path = os.path.join(PROJECT_ROOT, "data", "msc_self_instruct.jsonl")
|
||||
dataset_dir = Path(__file__).resolve().parent.parent / "dataset"
|
||||
data_path = dataset_dir / "msc_self_instruct.jsonl"
|
||||
|
||||
if not os.path.exists(data_path):
|
||||
data_path = os.path.join(os.getcwd(), "data", "msc_self_instruct.jsonl")
|
||||
raise FileNotFoundError(
|
||||
f"数据集文件不存在: {data_path}\n"
|
||||
f"请将 msc_self_instruct.jsonl 放置在: {dataset_dir}"
|
||||
)
|
||||
with open(data_path, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
items: List[Dict[str, Any]] = [json.loads(l) for l in lines[:sample_size]]
|
||||
|
||||
|
||||
# 改为:每条样本仅摄入一个上下文(完整对话转录),避免多上下文摄入
|
||||
# 说明:memsciqa 数据集的每个样本天然只有一个对话,保持按样本一上下文的策略
|
||||
contexts: List[str] = [build_context_from_dialog(item) for item in items]
|
||||
await ingest_contexts_via_full_pipeline(contexts, end_user_id)
|
||||
|
||||
# LLM client (使用异步调用)
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(SELECTED_LLM_ID)
|
||||
from app.db import get_db
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
llm_client = get_llm_client(os.getenv("EVAL_LLM_ID"), db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Evaluate each item
|
||||
connector = Neo4jConnector()
|
||||
@@ -177,7 +175,6 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
|
||||
limit=search_limit,
|
||||
include=["dialogues", "statements", "entities"],
|
||||
output_path=None,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
except Exception:
|
||||
results = None
|
||||
@@ -261,11 +258,7 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
|
||||
pred = resp.content.strip() if hasattr(resp, 'content') else (resp["choices"][0]["message"]["content"].strip() if isinstance(resp, dict) else str(resp).strip())
|
||||
# Metrics: F1, BLEU-1, Jaccard; keep exact match for reference
|
||||
correct_flags.append(exact_match(pred, reference))
|
||||
from app.core.memory.evaluation.common.metrics import (
|
||||
bleu1,
|
||||
f1_score,
|
||||
jaccard,
|
||||
)
|
||||
from app.core.memory.evaluation.common.metrics import f1_score, bleu1, jaccard
|
||||
f1s.append(f1_score(str(pred), str(reference)))
|
||||
b1s.append(bleu1(str(pred), str(reference)))
|
||||
jss.append(jaccard(str(pred), str(reference)))
|
||||
@@ -295,15 +288,39 @@ async def run_memsciqa_eval(sample_size: int = 1, end_user_id: str | None = None
|
||||
|
||||
|
||||
def main():
|
||||
# Load environment variables first
|
||||
load_dotenv()
|
||||
|
||||
# Get defaults from environment variables
|
||||
env_sample_size = os.getenv("MEMSCIQA_SAMPLE_SIZE")
|
||||
env_search_limit = os.getenv("MEMSCIQA_SEARCH_LIMIT")
|
||||
env_context_budget = os.getenv("MEMSCIQA_CONTEXT_CHAR_BUDGET")
|
||||
env_llm_max_tokens = os.getenv("MEMSCIQA_LLM_MAX_TOKENS")
|
||||
env_skip_ingest = os.getenv("MEMSCIQA_SKIP_INGEST", "false").lower() in ("true", "1", "yes")
|
||||
env_output_dir = os.getenv("MEMSCIQA_OUTPUT_DIR")
|
||||
|
||||
# Convert to appropriate types with fallback to code defaults
|
||||
default_sample_size = int(env_sample_size) if env_sample_size else 1
|
||||
default_search_limit = int(env_search_limit) if env_search_limit else 8
|
||||
default_context_budget = int(env_context_budget) if env_context_budget else 4000
|
||||
default_llm_max_tokens = int(env_llm_max_tokens) if env_llm_max_tokens else 64
|
||||
default_output_dir = env_output_dir if env_output_dir else None
|
||||
|
||||
parser = argparse.ArgumentParser(description="Evaluate DMR (memsciqa) with graph search and Qwen")
|
||||
|
||||
parser.add_argument("--sample-size", type=int, default=1, help="评测样本数量")
|
||||
parser.add_argument("--group-id", type=str, default=None, help="可选 end_user_id,默认取 runtime.json")
|
||||
parser.add_argument("--end-user-id", type=str, default=None, help="可选 end_user_id,默认使用环境变量")
|
||||
parser.add_argument("--search-limit", type=int, default=8, help="每类检索最大返回数")
|
||||
parser.add_argument("--context-char-budget", type=int, default=4000, help="上下文字符预算")
|
||||
|
||||
parser.add_argument("--llm-temperature", type=float, default=0.0, help="LLM 温度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=64, help="LLM 最大生成长度")
|
||||
parser.add_argument("--llm-max-tokens", type=int, default=default_llm_max_tokens,
|
||||
help=f"LLM 最大生成长度 (env: MEMSCIQA_LLM_MAX_TOKENS={env_llm_max_tokens or 'not set'})")
|
||||
parser.add_argument("--search-type", type=str, choices=["keyword","embedding","hybrid"], default="hybrid", help="检索类型")
|
||||
parser.add_argument("--skip-ingest", action="store_true", default=env_skip_ingest,
|
||||
help=f"跳过数据摄入,使用 Neo4j 中的现有数据 (env: MEMSCIQA_SKIP_INGEST={os.getenv('MEMSCIQA_SKIP_INGEST', 'false')})")
|
||||
parser.add_argument("--output-dir", type=str, default=default_output_dir,
|
||||
help=f"结果保存目录 (env: MEMSCIQA_OUTPUT_DIR={env_output_dir or 'not set'})")
|
||||
args = parser.parse_args()
|
||||
|
||||
result = asyncio.run(
|
||||
@@ -315,9 +332,37 @@ def main():
|
||||
llm_temperature=args.llm_temperature,
|
||||
llm_max_tokens=args.llm_max_tokens,
|
||||
search_type=args.search_type,
|
||||
skip_ingest=args.skip_ingest,
|
||||
)
|
||||
)
|
||||
|
||||
# Print results to console
|
||||
print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
# Save results to file
|
||||
output_dir = args.output_dir
|
||||
if output_dir is None:
|
||||
# Use absolute path to ensure results are saved in the correct location
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
output_dir = script_dir / "results"
|
||||
elif not Path(output_dir).is_absolute():
|
||||
# If relative path, make it relative to this script's directory
|
||||
script_dir = Path(__file__).resolve().parent
|
||||
output_dir = script_dir / output_dir
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
output_path = output_dir / f"memsciqa_{timestamp_str}.json"
|
||||
|
||||
try:
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(result, f, ensure_ascii=False, indent=2)
|
||||
print(f"\n✅ 结果已保存到: {output_path}")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 保存结果失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -2,20 +2,16 @@ import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Add src directory to Python path for proper imports when running from evaluation directory
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'src'))
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except Exception:
|
||||
def load_dotenv():
|
||||
return None
|
||||
# Load evaluation config
|
||||
eval_config_path = Path(__file__).resolve().parent / ".env.evaluation"
|
||||
if eval_config_path.exists():
|
||||
load_dotenv(eval_config_path, override=True)
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.utils.config.definitions import SELECTED_GROUP_ID, PROJECT_ROOT
|
||||
|
||||
from app.core.memory.evaluation.memsciqa.evaluate_qa import run_memsciqa_eval
|
||||
from app.core.memory.evaluation.longmemeval.qwen_search_eval import run_longmemeval_test
|
||||
@@ -36,8 +32,9 @@ async def run(
|
||||
start_index: int | None = None,
|
||||
max_contexts_per_item: int | None = None,
|
||||
) -> Dict[str, Any]:
|
||||
# 恢复原始风格:统一入口做路由,并沿用各数据集既有默认
|
||||
end_user_id = end_user_id or SELECTED_GROUP_ID
|
||||
# Use environment variable with fallback chain if not provided
|
||||
if end_user_id is None:
|
||||
end_user_id = os.getenv("EVAL_END_USER_ID", "benchmark_default")
|
||||
|
||||
if reset_group:
|
||||
connector = Neo4jConnector()
|
||||
|
||||
@@ -1064,13 +1064,16 @@ class ExtractionOrchestrator:
|
||||
if statement.triplet_extraction_info:
|
||||
triplet_info = statement.triplet_extraction_info
|
||||
|
||||
# 创建实体索引到ID的映射
|
||||
# 创建实体索引到ID的映射(支持多种索引方式)
|
||||
entity_idx_to_id = {}
|
||||
|
||||
# 创建实体节点
|
||||
for entity_idx, entity in enumerate(triplet_info.entities):
|
||||
# 映射实体索引到实体ID
|
||||
# 映射实体索引到实体ID(使用多个键以提高容错性)
|
||||
# 1. 使用实体自己的 entity_idx
|
||||
entity_idx_to_id[entity.entity_idx] = entity.id
|
||||
# 2. 使用枚举索引(从0开始)
|
||||
entity_idx_to_id[entity_idx] = entity.id
|
||||
|
||||
if entity.id not in entity_id_set:
|
||||
entity_connect_strength = getattr(entity, 'connect_strength', 'Strong')
|
||||
@@ -1149,9 +1152,18 @@ class ExtractionOrchestrator:
|
||||
relationship_result
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"跳过三元组 - 无法找到实体ID: subject_id={triplet.subject_id}, "
|
||||
f"object_id={triplet.object_id}, statement_id={statement.id}"
|
||||
# 改进的警告信息,包含更多调试信息
|
||||
missing_subject = "subject" if not subject_entity_id else ""
|
||||
missing_object = "object" if not object_entity_id else ""
|
||||
missing_both = " and " if (not subject_entity_id and not object_entity_id) else ""
|
||||
|
||||
logger.debug(
|
||||
f"跳过三元组 - 无法找到{missing_subject}{missing_both}{missing_object}实体ID: "
|
||||
f"subject_id={triplet.subject_id} ({triplet.subject_name}), "
|
||||
f"object_id={triplet.object_id} ({triplet.object_name}), "
|
||||
f"predicate={triplet.predicate}, "
|
||||
f"statement_id={statement.id}, "
|
||||
f"available_indices={sorted(entity_idx_to_id.keys())}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.orm import relationship
|
||||
|
||||
from app.base.type import PydanticType
|
||||
from app.db import Base
|
||||
from app.schemas import ModelParameters
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
|
||||
|
||||
class AgentConfig(Base):
|
||||
|
||||
@@ -10,7 +10,7 @@ from sqlalchemy.orm import relationship
|
||||
|
||||
from app.base.type import PydanticType
|
||||
from app.db import Base
|
||||
from app.schemas import ModelParameters
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
|
||||
|
||||
class OrchestrationMode(StrEnum):
|
||||
|
||||
@@ -4,7 +4,7 @@ import datetime
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer
|
||||
|
||||
from app.schemas import ModelParameters
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
|
||||
|
||||
# ==================== 子 Agent 配置 ====================
|
||||
|
||||
@@ -5,7 +5,7 @@ import uuid
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.schemas import ModelParameters
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
from app.services.conversation_state_manager import ConversationStateManager
|
||||
from app.models import ModelConfig, AgentConfig
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
@@ -188,7 +188,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"config_desc": config.config_desc,
|
||||
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
|
||||
"end_user_id": config.end_user_id,
|
||||
"user_id": config.user_id,
|
||||
"config_id_old": int(config.user_id),
|
||||
"apply_id": config.apply_id,
|
||||
"llm_id": config.llm_id,
|
||||
"embedding_id": config.embedding_id,
|
||||
|
||||
@@ -57,7 +57,7 @@ def dict_to_model_parameters(data: Optional[Dict[str, Any]]) -> Optional[Any]:
|
||||
if data is None:
|
||||
return None
|
||||
|
||||
from app.schemas import ModelParameters
|
||||
from app.schemas.app_schema import ModelParameters
|
||||
|
||||
if isinstance(data, ModelParameters):
|
||||
return data
|
||||
|
||||
@@ -8,6 +8,7 @@ import { type FC, useRef, useEffect } from 'react'
|
||||
import clsx from 'clsx'
|
||||
import Markdown from '@/components/Markdown'
|
||||
import type { ChatContentProps } from './types'
|
||||
import { Spin } from 'antd'
|
||||
|
||||
/**
|
||||
* 聊天内容显示组件
|
||||
@@ -21,7 +22,8 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
empty,
|
||||
labelPosition = 'bottom',
|
||||
labelFormat,
|
||||
errorDesc
|
||||
errorDesc,
|
||||
renderRuntime
|
||||
}) => {
|
||||
// 滚动容器引用,用于控制自动滚动到底部
|
||||
const scrollContainerRef = useRef<(HTMLDivElement | null)>(null)
|
||||
@@ -45,8 +47,8 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
'rb:left-0 rb:text-left': item.role === 'assistant', // 助手消息左对齐
|
||||
})}>
|
||||
{/* 流式加载时且内容为空则不显示 */}
|
||||
{streamLoading && item.content === ''
|
||||
? null
|
||||
{streamLoading && item.content === '' && !renderRuntime
|
||||
? <Spin />
|
||||
: <>
|
||||
{/* 顶部标签(如时间戳、用户名等) */}
|
||||
{labelPosition === 'top' &&
|
||||
@@ -55,16 +57,17 @@ const ChatContent: FC<ChatContentProps> = ({
|
||||
</div>
|
||||
}
|
||||
{/* 消息气泡框 */}
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-[520px] rb:wrap-break-word', contentClassNames, {
|
||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-130 rb:wrap-break-word', contentClassNames, {
|
||||
// 错误消息样式(内容为null且非助手消息)
|
||||
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null,
|
||||
'rb:border-[rgba(255,93,52,0.30)] rb:bg-[rgba(255,93,52,0.08)] rb:text-[#FF5D34]': errorDesc && item.role === 'assistant' && item.content === null && !renderRuntime,
|
||||
// 助手消息样式
|
||||
'rb:bg-[rgba(21,94,239,0.08)] rb:border-[rgba(21,94,239,0.30)]': item.role === 'user',
|
||||
// 用户消息样式
|
||||
'rb:bg-[#FFFFFF] rb:border-[#EBEBEB]': item.role === 'assistant' && (item.content || item.content === ''),
|
||||
'rb:bg-[#FFFFFF] rb:border-[#EBEBEB]': item.role === 'assistant' && (item.content || item.content === '' || typeof renderRuntime === 'function'),
|
||||
})}>
|
||||
{item.subContent && renderRuntime && renderRuntime(item, index)}
|
||||
{/* 使用Markdown组件渲染消息内容 */}
|
||||
<Markdown content={item.content ?? errorDesc ?? ''} />
|
||||
<Markdown content={renderRuntime ? item.content ?? '' : item.content ?? errorDesc ?? ''} />
|
||||
</div>
|
||||
{/* 底部标签(如时间戳、用户名等) */}
|
||||
{labelPosition === 'bottom' &&
|
||||
|
||||
@@ -19,7 +19,9 @@ export interface ChatItem {
|
||||
/** 消息内容 */
|
||||
content?: string | null;
|
||||
/** 创建时间 */
|
||||
created_at?: number | string
|
||||
created_at?: number | string;
|
||||
status?: string;
|
||||
subContent?: Record<string, any>[]
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -81,4 +83,5 @@ export interface ChatContentProps {
|
||||
/** 标签格式化函数 */
|
||||
labelFormat: (item: ChatItem) => any;
|
||||
errorDesc?: string;
|
||||
renderRuntime?: (item: ChatItem, index: number) => ReactNode;
|
||||
}
|
||||
@@ -15,7 +15,7 @@ interface ApiResponse<T> {
|
||||
interface CustomSelectProps extends Omit<SelectProps, 'filterOption'> {
|
||||
url: string;
|
||||
params?: Record<string, unknown>;
|
||||
valueKey?: string;
|
||||
valueKey?: string | string[];
|
||||
labelKey?: string;
|
||||
placeholder?: string;
|
||||
hasAll?: boolean;
|
||||
@@ -66,11 +66,18 @@ const CustomSelect: FC<CustomSelectProps> = ({
|
||||
{...props}
|
||||
>
|
||||
{hasAll && <Select.Option value={null}>{allTitle || t('common.all')}</Select.Option>}
|
||||
{displayOptions.map((option) => (
|
||||
<Select.Option key={option[valueKey]} value={option[valueKey]}>
|
||||
{String(option[labelKey])}
|
||||
</Select.Option>
|
||||
))}
|
||||
{displayOptions.map((option) => {
|
||||
const getValue = () => {
|
||||
if (typeof valueKey === 'string') return option[valueKey];
|
||||
return valueKey.find(key => option[key] != null) ? option[valueKey.find(key => option[key] != null)!] : undefined;
|
||||
};
|
||||
const value = getValue();
|
||||
return (
|
||||
<Select.Option key={value} value={value}>
|
||||
{String(option[labelKey])}
|
||||
</Select.Option>
|
||||
);
|
||||
})}
|
||||
</Select>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,6 +6,9 @@ import CopyBtn from './CopyBtn';
|
||||
|
||||
type ICodeBlockProps = {
|
||||
value: string;
|
||||
needCopy?: boolean;
|
||||
size?: 'small' | 'default';
|
||||
showLineNumbers?: boolean;
|
||||
}
|
||||
|
||||
// enum languageType {
|
||||
@@ -16,6 +19,9 @@ type ICodeBlockProps = {
|
||||
|
||||
const CodeBlock: FC<ICodeBlockProps> = ({
|
||||
value,
|
||||
needCopy = true,
|
||||
size = 'default',
|
||||
showLineNumbers = false
|
||||
}) => {
|
||||
|
||||
return (
|
||||
@@ -23,24 +29,26 @@ const CodeBlock: FC<ICodeBlockProps> = ({
|
||||
<SyntaxHighlighter
|
||||
style={atelierHeathLight}
|
||||
customStyle={{
|
||||
padding: '16px 20px 16px 24px',
|
||||
padding: '8px 12px 8px 12px',
|
||||
backgroundColor: '#F0F3F8',
|
||||
borderRadius: 8,
|
||||
fontSize: size === 'small' ? 12 : 14,
|
||||
wordBreak: 'break-all'
|
||||
}}
|
||||
language="json"
|
||||
showLineNumbers={false}
|
||||
showLineNumbers={showLineNumbers}
|
||||
PreTag="div"
|
||||
>
|
||||
{value}
|
||||
</SyntaxHighlighter>
|
||||
<CopyBtn
|
||||
{needCopy && <CopyBtn
|
||||
value={value}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 20,
|
||||
right: 20,
|
||||
}}
|
||||
/>
|
||||
/>}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1982,6 +1982,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
arrange: 'Arrange',
|
||||
redo: 'Redo',
|
||||
undo: 'Undo',
|
||||
|
||||
input: 'Input',
|
||||
output: 'Output',
|
||||
error: 'Error Message',
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: 'Emotion Engine Configuration',
|
||||
|
||||
@@ -2076,6 +2076,10 @@ export const zh = {
|
||||
arrange: '整理',
|
||||
redo: '重做',
|
||||
undo: '撤销',
|
||||
|
||||
input: '输入',
|
||||
output: '输出',
|
||||
error: '错误信息',
|
||||
},
|
||||
emotionEngine: {
|
||||
emotionEngineConfig: '情感引擎配置',
|
||||
|
||||
@@ -123,6 +123,20 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe
|
||||
let response = await makeSSERequest(url, data, token || '', config);
|
||||
|
||||
switch (response.status) {
|
||||
case 500:
|
||||
case 502:
|
||||
const errorData = await response.json();
|
||||
errorData.error || i18n.t('common.serviceUpgrading');
|
||||
message.warning(errorData.error || i18n.t('common.serviceUpgrading'));
|
||||
break
|
||||
case 400:
|
||||
const error = await response.json();
|
||||
message.warning(error.error);
|
||||
throw error || 'Bad Request';
|
||||
case 504:
|
||||
const errorJson = await response.json();
|
||||
message.warning(errorJson.error || i18n.t('common.serverError'));
|
||||
break
|
||||
case 401:
|
||||
if (url?.includes('/public')) {
|
||||
return message.warning(i18n.t('common.publicApiCannotRefreshToken'));
|
||||
|
||||
@@ -79,7 +79,7 @@ const SelectWrapper: FC<{ title: string, desc: string, name: string | string[],
|
||||
placeholder={t('common.pleaseSelect')}
|
||||
url={url}
|
||||
hasAll={false}
|
||||
valueKey='config_id'
|
||||
valueKey={['config_id_old', 'config_id']}
|
||||
labelKey="config_name"
|
||||
/>
|
||||
</Form.Item>
|
||||
@@ -126,12 +126,14 @@ const Agent = forwardRef<AgentRef>((_props, ref) => {
|
||||
getApplicationConfig(id as string).then(res => {
|
||||
const response = res as Config
|
||||
let allTools = Array.isArray(response.tools) ? response.tools : []
|
||||
const memoryContent = response.memory?.memory_content
|
||||
const convertedMemoryContent = memoryContent && !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent
|
||||
form.setFieldsValue({
|
||||
...response,
|
||||
tools: allTools,
|
||||
memory: {
|
||||
...response.memory,
|
||||
memory_content: response.memory?.memory_content ? Number(response.memory?.memory_content) : undefined
|
||||
memory_content: convertedMemoryContent
|
||||
}
|
||||
})
|
||||
setData({
|
||||
|
||||
@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
|
||||
useEffect(() => {
|
||||
if (values?.retrieve_type) {
|
||||
const fieldsToReset = Object.keys(values).filter(key =>
|
||||
key !== 'kb_id' && key !== 'retrieve_type'
|
||||
key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
|
||||
) as (keyof KnowledgeConfigForm)[];
|
||||
form.resetFields(fieldsToReset);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { forwardRef, useImperativeHandle, useState, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import clsx from 'clsx'
|
||||
import { Input, Form, App } from 'antd'
|
||||
import { Space, Button } from 'antd'
|
||||
import { Input, Form, App, Space, Button, Collapse } from 'antd'
|
||||
import { CheckCircleFilled, CloseCircleFilled, LoadingOutlined } from '@ant-design/icons'
|
||||
import CodeBlock from '@/components/Markdown/CodeBlock'
|
||||
|
||||
import ChatIcon from '@/assets/images/application/chat.png'
|
||||
import RbDrawer from '@/components/RbDrawer';
|
||||
@@ -13,8 +14,11 @@ import ChatContent from '@/components/Chat/ChatContent'
|
||||
import type { ChatItem } from '@/components/Chat/types'
|
||||
import ChatSendIcon from '@/assets/images/application/chatSend.svg'
|
||||
import dayjs from 'dayjs'
|
||||
import type { ChatRef, VariableConfigModalRef, StartVariableItem, GraphRef } from '../../types'
|
||||
import type { ChatRef, VariableConfigModalRef, GraphRef } from '../../types'
|
||||
import { type SSEMessage } from '@/utils/stream'
|
||||
import type { Variable } from '../Properties/VariableList/types'
|
||||
import styles from './chat.module.css'
|
||||
import Markdown from '@/components/Markdown'
|
||||
|
||||
const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId, graphRef }, ref) => {
|
||||
const { t } = useTranslation()
|
||||
@@ -24,7 +28,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
const [open, setOpen] = useState(false)
|
||||
const [loading, setLoading] = useState(false)
|
||||
const [chatList, setChatList] = useState<ChatItem[]>([])
|
||||
const [variables, setVariables] = useState<StartVariableItem[]>([])
|
||||
const [variables, setVariables] = useState<Variable[]>([])
|
||||
const [streamLoading, setStreamLoading] = useState(false)
|
||||
const [conversationId, setConversationId] = useState<string | null>(null)
|
||||
|
||||
@@ -39,7 +43,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
if (startNodes.length) {
|
||||
const curVariables = startNodes[0].config.variables?.defaultValue
|
||||
|
||||
curVariables.forEach((vo: StartVariableItem) => {
|
||||
curVariables.forEach((vo: Variable) => {
|
||||
if (typeof vo.default !== 'undefined') {
|
||||
vo.value = vo.default
|
||||
}
|
||||
@@ -60,7 +64,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
const handleEditVariables = () => {
|
||||
variableConfigModalRef.current?.handleOpen(variables)
|
||||
}
|
||||
const handleSave = (values: StartVariableItem[]) => {
|
||||
const handleSave = (values: Variable[]) => {
|
||||
setVariables([...values])
|
||||
}
|
||||
const handleSend = () => {
|
||||
@@ -97,13 +101,28 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
created_at: Date.now(),
|
||||
subContent: [],
|
||||
}])
|
||||
|
||||
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||
setStreamLoading(false)
|
||||
|
||||
data.forEach(item => {
|
||||
const { chunk, conversation_id } = item.data as { chunk: string; conversation_id: string | null; };
|
||||
const { chunk, conversation_id, node_id, input, output, error, elapsed_time, status } = item.data as {
|
||||
chunk: string;
|
||||
conversation_id: string | null;
|
||||
node_id: string;
|
||||
node_name?: string;
|
||||
input?: any;
|
||||
output?: any;
|
||||
elapsed_time?: string;
|
||||
error?: any;
|
||||
state: Record<string, any>;
|
||||
status?: 'completed' | 'failed'
|
||||
};
|
||||
|
||||
const node = graphRef.current?.getNodes().find(n => n.id === node_id);
|
||||
const { name, icon } = node?.getData() || {}
|
||||
|
||||
console.log('node', node?.getData())
|
||||
|
||||
switch(item.event) {
|
||||
case 'message':
|
||||
@@ -119,6 +138,66 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
return newList
|
||||
})
|
||||
break
|
||||
case 'node_start':
|
||||
setChatList(prev => {
|
||||
const newList = [...prev]
|
||||
const lastIndex = newList.length - 1
|
||||
if (lastIndex >= 0) {
|
||||
const newSubContent = newList[lastIndex].subContent || []
|
||||
const filterIndex = newSubContent.findIndex(vo => vo.id === node_id)
|
||||
if (filterIndex > -1) {
|
||||
newSubContent[filterIndex] = {
|
||||
...newSubContent[filterIndex],
|
||||
node_id: node_id,
|
||||
node_name: name,
|
||||
icon,
|
||||
content: {},
|
||||
}
|
||||
} else {
|
||||
newSubContent.push({
|
||||
id: node_id,
|
||||
node_id: node_id,
|
||||
node_name: name,
|
||||
icon,
|
||||
content: {},
|
||||
})
|
||||
}
|
||||
newList[lastIndex] = {
|
||||
...newList[lastIndex],
|
||||
subContent: newSubContent
|
||||
}
|
||||
}
|
||||
return newList
|
||||
})
|
||||
break
|
||||
case 'node_end':
|
||||
case 'node_error':
|
||||
setChatList(prev => {
|
||||
const newList = [...prev]
|
||||
const lastIndex = newList.length - 1
|
||||
if (lastIndex >= 0) {
|
||||
const newSubContent = newList[lastIndex].subContent || []
|
||||
const filterIndex = newSubContent.findIndex(vo => vo.node_id === node_id)
|
||||
if (filterIndex > -1 && newSubContent[filterIndex].content) {
|
||||
newSubContent[filterIndex] = {
|
||||
...newSubContent[filterIndex],
|
||||
content: {
|
||||
input,
|
||||
output,
|
||||
error,
|
||||
},
|
||||
status: status || 'completed',
|
||||
elapsed_time
|
||||
}
|
||||
}
|
||||
newList[lastIndex] = {
|
||||
...newList[lastIndex],
|
||||
subContent: newSubContent
|
||||
}
|
||||
}
|
||||
return newList
|
||||
})
|
||||
break
|
||||
case 'workflow_end':
|
||||
setChatList(prev => {
|
||||
const newList = [...prev]
|
||||
@@ -126,6 +205,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
if (lastIndex >= 0) {
|
||||
newList[lastIndex] = {
|
||||
...newList[lastIndex],
|
||||
status,
|
||||
content: newList[lastIndex].content === '' ? null : newList[lastIndex].content
|
||||
}
|
||||
}
|
||||
@@ -142,14 +222,31 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
}
|
||||
|
||||
form.setFieldValue('message', undefined)
|
||||
setStreamLoading(true)
|
||||
draftRun(appId, {
|
||||
message: message,
|
||||
variables: params,
|
||||
stream: true,
|
||||
conversation_id: conversationId
|
||||
}, handleStreamMessage)
|
||||
.catch((error) => {
|
||||
setChatList(prev => {
|
||||
const newList = [...prev]
|
||||
const lastIndex = newList.length - 1
|
||||
if (lastIndex >= 0) {
|
||||
newList[lastIndex] = {
|
||||
...newList[lastIndex],
|
||||
status: 'failed',
|
||||
content: null,
|
||||
subContent: error.error
|
||||
}
|
||||
}
|
||||
return newList
|
||||
})
|
||||
})
|
||||
.finally(() => {
|
||||
setLoading(false)
|
||||
setStreamLoading(false)
|
||||
})
|
||||
}
|
||||
// 暴露给父组件的方法
|
||||
@@ -158,6 +255,11 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
handleClose
|
||||
}));
|
||||
|
||||
const getStatus = (status?: string) => {
|
||||
return status === 'completed' ? 'rb:text-[#369F21]' : status === 'failed' ? 'rb:text-[#FF5D34]' : 'rb:text-[#5B6167]'
|
||||
}
|
||||
|
||||
console.log('chatList', chatList)
|
||||
return (
|
||||
<RbDrawer
|
||||
title={<div className="rb:flex rb:items-center rb:gap-2.5">
|
||||
@@ -173,10 +275,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
onClose={handleClose}
|
||||
>
|
||||
<ChatContent
|
||||
classNames={{
|
||||
'rb:mx-[16px] rb:pt-[24px] rb:h-[calc(100%-76px)]': true,
|
||||
|
||||
}}
|
||||
classNames="rb:mx-[16px] rb:pt-[24px] rb:h-[calc(100%-76px)]"
|
||||
contentClassNames="rb:max-w-[400px]!'"
|
||||
empty={<Empty url={ChatIcon} title={t('application.chatEmpty')} isNeedSubTitle={false} size={[240, 200]} className="rb:h-full" />}
|
||||
data={chatList}
|
||||
@@ -184,6 +283,87 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
labelPosition="bottom"
|
||||
labelFormat={(item) => dayjs(item.created_at).locale('en').format('MMMM D, YYYY [at] h:mm A')}
|
||||
errorDesc={t('application.ReplyException')}
|
||||
renderRuntime={(item, index) => {
|
||||
return (
|
||||
<div key={index} className="rb:w-100 rb:mb-2">
|
||||
<Collapse
|
||||
className={styles[item.status || 'default']}
|
||||
items={[{
|
||||
key: 0,
|
||||
label: <div className={getStatus(item.status)}>
|
||||
{item.status === 'completed' ? <CheckCircleFilled className="rb:mr-1" /> : item.status === 'failed' ? <CloseCircleFilled className="rb:mr-1" /> : <LoadingOutlined className="rb:mr-1" />}
|
||||
{t('application.workflow')}
|
||||
</div>,
|
||||
className: styles.collapseItem,
|
||||
children: (
|
||||
Array.isArray(item.subContent)
|
||||
? <Space size={8} direction="vertical" className="rb:w-full!">
|
||||
{item.subContent?.map(vo => (
|
||||
<Collapse
|
||||
key={vo.node_id}
|
||||
items={[{
|
||||
key: vo.node_id,
|
||||
label: <div className={clsx("rb:flex rb:justify-between rb:items-center", getStatus(vo.status))}>
|
||||
<div className="rb:flex rb:items-center rb:gap-1 rb:flex-1">
|
||||
{vo.icon && <img src={vo.icon} className="rb:size-4" />}
|
||||
<div className="rb:wrap-break-word rb:line-clamp-1">{vo.node_name || vo.node_id}</div>
|
||||
</div>
|
||||
<span>
|
||||
{typeof vo.elapsed_time == 'number' && <>{vo.elapsed_time?.toFixed(3)}ms</>}
|
||||
{vo.status === 'completed' ? <CheckCircleFilled className="rb:ml-1" /> : vo.status === 'failed' ? <CloseCircleFilled className="rb:ml-1" /> : <LoadingOutlined className="rb:ml-1" />}
|
||||
</span>
|
||||
</div>,
|
||||
className: styles.collapseItem,
|
||||
children: (
|
||||
<Space size={8} direction="vertical" className="rb:w-full!">
|
||||
{vo.status === 'failed' &&
|
||||
<div className={clsx("rb:bg-[#F0F3F8] rb:rounded-md", getStatus(vo.status))}>
|
||||
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
||||
{t(`workflow.error`)}
|
||||
<Button
|
||||
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
||||
size="small"
|
||||
>{t('common.copy')}</Button>
|
||||
</div>
|
||||
<div className="rb:pb-2 rb:px-3 rb:max-h-40 rb:overflow-auto">
|
||||
<Markdown content={vo.content?.error || ''} />
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
{['input', 'output'].map(key => (
|
||||
<div key={key} className="rb:bg-[#F0F3F8] rb:rounded-md">
|
||||
<div className="rb:py-2 rb:px-3 rb:flex rb:justify-between rb:items-center rb:text-[12px]">
|
||||
{t(`workflow.${key}`)}
|
||||
<Button
|
||||
className="rb:py-0! rb:px-1! rb:text-[12px]!"
|
||||
size="small"
|
||||
>{t('common.copy')}</Button>
|
||||
</div>
|
||||
<div className="rb:max-h-40 rb:overflow-auto">
|
||||
<CodeBlock
|
||||
size="small"
|
||||
value={typeof vo.content === 'object' && vo.content?.[key] ? JSON.stringify(vo.content[key], null, 2) : '{}'}
|
||||
needCopy={false}
|
||||
showLineNumbers={true}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</Space>
|
||||
)
|
||||
}]}
|
||||
/>
|
||||
))}
|
||||
</Space>
|
||||
: <div className={clsx("rb:bg-[#FBFDFF] rb:rounded-md rb:py-2 rb:px-3 ", getStatus('failed'))}>
|
||||
<Markdown content={item.subContent || ''} />
|
||||
</div>
|
||||
)
|
||||
}]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}}
|
||||
/>
|
||||
<div className="rb:flex rb:items-center rb:gap-2.5 rb:p-4">
|
||||
<Form form={form} style={{width: 'calc(100% - 54px)'}}>
|
||||
|
||||
45
web/src/views/Workflow/components/Chat/chat.module.css
Normal file
45
web/src/views/Workflow/components/Chat/chat.module.css
Normal file
@@ -0,0 +1,45 @@
|
||||
.completed {
|
||||
background-color: rgba(54, 159, 33, 0.06);
|
||||
border-color: rgba(54, 159, 33, 0.25);
|
||||
border-radius: 8px;
|
||||
}
|
||||
.failed {
|
||||
background-color: rgba(255, 138, 76, 0.08);
|
||||
border-color: rgba(255, 138, 76, 0.20);
|
||||
border-radius: 8px;
|
||||
}
|
||||
.default {
|
||||
background-color: rgba(91, 97, 103, 0.08);
|
||||
border-color: rgba(91, 97, 103, 0.30);
|
||||
border-radius: 8px;
|
||||
}
|
||||
.collapse-item {
|
||||
font-size: 12px;
|
||||
line-height: 16px;
|
||||
}
|
||||
.collapse-item:global(.ant-collapse-item>.ant-collapse-header) {
|
||||
padding: 8px 12px;
|
||||
}
|
||||
.collapse-item:global(.ant-collapse-item>.ant-collapse-header .ant-collapse-expand-icon) {
|
||||
height: 16px;
|
||||
}
|
||||
.completed:global(.ant-collapse .ant-collapse-content),
|
||||
.failed:global(.ant-collapse .ant-collapse-content) {
|
||||
background-color: transparent;
|
||||
border-top: none;
|
||||
}
|
||||
:global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) {
|
||||
padding-top: 0;
|
||||
}
|
||||
.collapse-item :global(.ant-collapse) {
|
||||
/* background-color: #F0F3F8; */
|
||||
background-color: #FBFDFF;
|
||||
border-radius: 6px;
|
||||
}
|
||||
.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child),
|
||||
.collapse-item :global(.ant-collapse>.ant-collapse-item:last-child>.ant-collapse-header) {
|
||||
border-radius: 0 0 6px 6px;
|
||||
}
|
||||
.collapse-item :global(.ant-collapse .ant-collapse-content>.ant-collapse-content-box) {
|
||||
padding: 0 4px 4px 4px;
|
||||
}
|
||||
@@ -66,7 +66,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
|
||||
useEffect(() => {
|
||||
if (values?.retrieve_type) {
|
||||
const fieldsToReset = Object.keys(values).filter(key =>
|
||||
key !== 'kb_id' && key !== 'retrieve_type'
|
||||
key !== 'kb_id' && key !== 'retrieve_type' && key !== 'top_k'
|
||||
) as (keyof KnowledgeConfigForm)[];
|
||||
form.resetFields(fieldsToReset);
|
||||
}
|
||||
@@ -108,6 +108,7 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
|
||||
label: t(`application.${key}`),
|
||||
value: key,
|
||||
}))}
|
||||
// onChange={handleChange}
|
||||
/>
|
||||
</FormItem>
|
||||
{/* Top K */}
|
||||
@@ -116,13 +117,12 @@ const KnowledgeConfigModal = forwardRef<KnowledgeConfigModalRef, KnowledgeConfig
|
||||
label={t('application.top_k')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
extra={t('application.top_k_desc')}
|
||||
initialValue={5}
|
||||
>
|
||||
<InputNumber
|
||||
style={{ width: '100%' }}
|
||||
min={1}
|
||||
max={20}
|
||||
onChange={(value) => form.setFieldValue('top_k', value)}
|
||||
// onChange={(value) => form.setFieldValue('top_k', value)}
|
||||
/>
|
||||
</FormItem>
|
||||
{/* 语义相似度阈值 similarity_threshold */}
|
||||
|
||||
@@ -200,7 +200,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config_id: {
|
||||
type: 'customSelect',
|
||||
url: memoryConfigListUrl,
|
||||
valueKey: 'config_id',
|
||||
valueKey: ['config_id_old', 'config_id'],
|
||||
labelKey: 'config_name'
|
||||
},
|
||||
search_switch: {
|
||||
@@ -223,7 +223,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config_id: {
|
||||
type: 'customSelect',
|
||||
url: memoryConfigListUrl,
|
||||
valueKey: 'config_id',
|
||||
valueKey: ['config_id_old', 'config_id'],
|
||||
labelKey: 'config_name'
|
||||
}
|
||||
}
|
||||
@@ -284,7 +284,7 @@ export const nodeLibrary: NodeLibrary[] = [
|
||||
config: {
|
||||
input: {
|
||||
type: 'variableList',
|
||||
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop'],
|
||||
filterNodeTypes: ['knowledge-retrieval', 'iteration', 'loop', 'parameter-extractor'],
|
||||
filterVariableNames: ['message']
|
||||
},
|
||||
parallel: {
|
||||
|
||||
@@ -14,7 +14,7 @@ export interface NodeConfig {
|
||||
|
||||
url?: string;
|
||||
params?: { [key: string]: unknown; }
|
||||
valueKey?: string;
|
||||
valueKey?: string | string[];
|
||||
labelKey?: string;
|
||||
|
||||
defaultValue?: any;
|
||||
|
||||
Reference in New Issue
Block a user