Initial commit
445
app/core/memory/utils/README.md
Normal file
@@ -0,0 +1,445 @@
# Memory Module Utilities

This directory contains all utility functions used by the Memory module. Keeping them in one place improves maintainability and reuse.

## Directory Structure

```
app/core/memory/utils/
├── __init__.py                      # Package initializer; exports the public interface
├── README.md                        # This document
├── config/                          # Configuration management
│   ├── __init__.py                  # Config package initializer
│   ├── config_utils.py              # Configuration utilities
│   ├── definitions.py               # Global definitions and constants
│   ├── overrides.py                 # Runtime configuration overrides
│   ├── get_data.py                  # Data retrieval utilities
│   ├── litellm_config.py            # LiteLLM configuration and monitoring
│   └── config_optimization.py       # Configuration optimization utilities
├── log/                             # Logging
│   ├── __init__.py                  # Log package initializer
│   ├── logging_utils.py             # Logging utilities
│   └── audit_logger.py              # Audit log
├── prompt/                          # Prompt management
│   ├── __init__.py                  # Prompt package initializer
│   ├── prompt_utils.py              # Prompt rendering utilities
│   ├── template_render.py           # Template rendering utilities
│   └── prompts/                     # Jinja2 prompt templates
│       ├── entity_dedup.jinja2      # Entity deduplication prompt
│       ├── extract_statement.jinja2 # Statement extraction prompt
│       ├── extract_temporal.jinja2  # Temporal information extraction prompt
│       ├── extract_triplet.jinja2   # Triplet extraction prompt
│       ├── memory_summary.jinja2    # Memory summary prompt
│       ├── evaluate.jinja2          # Evaluation prompt
│       ├── reflexion.jinja2         # Reflexion prompt
│       ├── system.jinja2            # System prompt
│       └── user.jinja2              # User prompt
├── llm/                             # LLM tools
│   ├── __init__.py                  # LLM package initializer
│   └── llm_utils.py                 # LLM client utilities
├── data/                            # Data processing
│   ├── __init__.py                  # Data package initializer
│   ├── text_utils.py                # Text processing utilities
│   ├── time_utils.py                # Time processing utilities
│   └── ontology.py                  # Ontology definitions (predicates, labels, etc.)
├── paths/                           # Path management
│   ├── __init__.py                  # Paths package initializer
│   └── output_paths.py              # Output path management
├── visualization/                   # Visualization
│   ├── __init__.py                  # Visualization package initializer
│   └── forgetting_visualizer.py     # Forgetting curve visualization
└── self_reflexion_utils/            # Self-reflexion tools
    ├── __init__.py                  # Reflexion package initializer
    ├── evaluate.py                  # Conflict evaluation
    ├── reflexion.py                 # Reflexion processing
    └── self_reflexion.py            # Main self-reflexion logic
```

## Modules

### 1. Configuration Management (config/)

The configuration module contains all configuration-related utilities and definitions.

#### config_utils.py
Configuration loading and management:
- `get_model_config(model_id)` - Get an LLM model configuration
- `get_embedder_config(embedding_id)` - Get an embedding model configuration
- `get_neo4j_config()` - Get the Neo4j database configuration
- `get_chunker_config(chunker_strategy)` - Get a chunking strategy configuration
- `get_pipeline_config()` - Get the pipeline configuration
- `get_pruning_config()` - Get the semantic pruning configuration
- `get_picture_config(llm_name)` - Get a picture-recognition model configuration
- `get_voice_config(llm_name)` - Get a voice-recognition model configuration
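
The docstring of `get_chunker_config` in `config_utils.py` describes a lookup order: an exact match in the configured list, then built-in defaults for newer strategies, then the first available entry. The sketch below illustrates that order only. `lookup_chunker`, the `RecursiveChunker` name, the sample values, and the final `ValueError` for an empty list are assumptions made for illustration, not the module's actual code.

```python
from typing import Any, Dict, List

# Hypothetical built-in defaults for strategies that may be missing from the config file.
DEFAULTS: Dict[str, Dict[str, Any]] = {
    "LLMChunker": {"chunker_strategy": "LLMChunker", "chunk_size": 1000},
}


def lookup_chunker(strategy: str, chunker_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Resolve a chunker config: exact match, then defaults, then first available."""
    # 1) Exact match in the configured list
    for entry in chunker_list:
        if entry.get("chunker_strategy") == strategy:
            return entry
    # 2) Built-in default for a known strategy
    if strategy in DEFAULTS:
        return DEFAULTS[strategy]
    # 3) Fall back to the first configured chunker
    if chunker_list:
        return chunker_list[0]
    # Assumption: with nothing configured, fail loudly rather than guess.
    raise ValueError(f"No chunker configuration available for '{strategy}'")


# "RecursiveChunker" is a made-up strategy name used only for this example.
configured = [{"chunker_strategy": "RecursiveChunker", "chunk_size": 256}]
print(lookup_chunker("RecursiveChunker", configured)["chunk_size"])
print(lookup_chunker("LLMChunker", configured)["chunk_size"])
print(lookup_chunker("UnknownChunker", configured)["chunker_strategy"])
```

Falling back to the first entry keeps callers working when a strategy name is misspelled, at the cost of hiding the mistake; logging the fallback is one way to keep it visible.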
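
#### definitions.py
Global definitions and constants:
- `CONFIG` - Base configuration (loaded from config.json)
- `RUNTIME_CONFIG` - Runtime configuration (loaded from runtime.json or the database)
- `PROJECT_ROOT` - Project root path
- Selection constants (LLM, embedding model, chunking strategy, etc.)
- `reload_configuration_from_database(config_id)` - Reload the configuration dynamically

#### overrides.py
Runtime configuration overrides:
- `load_unified_config(project_root)` - Load the unified configuration

#### get_data.py
Data retrieval utilities:
- `get_data(host_id)` - Fetch data from the SQL database

#### litellm_config.py
LiteLLM configuration and monitoring (import these from `config.litellm_config` directly; `config/__init__.py` does not re-export them, to avoid circular imports):
- `LiteLLMConfig` - LiteLLM configuration class
- `setup_litellm_enhanced(max_retries)` - Set up the enhanced LiteLLM configuration
- `get_usage_summary()` - Get a usage statistics summary
- `print_usage_summary()` - Print usage statistics
- `get_instant_qps(module)` - Get instantaneous QPS data
- `print_instant_qps(module)` - Print instantaneous QPS information

#### config_optimization.py
Optional optimizations for configuration caching:
- `LRUConfigCache` - LRU cache with TTL expiry and hit/miss statistics
- `ConfigCacheWarmer` - Preloads frequently used configurations at startup
- `DynamicTTLStrategy` - Chooses a TTL by configuration type
- `ConfigVersionManager` - Tracks configuration versions to invalidate stale cache entries
- `CacheMonitor` - Reports cache performance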
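
To illustrate the LRU-plus-TTL idea behind `LRUConfigCache`, here is a condensed, self-contained sketch. `TinyLRUCache` and its injectable `clock` parameter are illustrative assumptions and are not part of `config_optimization.py`; the real class also tracks statistics and uses a lock.

```python
import time
from collections import OrderedDict
from typing import Any, Callable, Dict, Optional, Tuple


class TinyLRUCache:
    """Condensed LRU cache with per-entry expiry (illustrative only)."""

    def __init__(self, max_size: int = 2, ttl_seconds: float = 300.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.max_size = max_size
        self.ttl_seconds = ttl_seconds
        self._clock = clock  # injectable so expiry can be tested without sleeping
        self._entries: "OrderedDict[str, Tuple[float, Dict[str, Any]]]" = OrderedDict()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        item = self._entries.get(key)
        if item is None:
            return None
        stored_at, value = item
        if self._clock() - stored_at >= self.ttl_seconds:
            del self._entries[key]  # expired: drop it and report a miss
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key: str, value: Dict[str, Any]) -> None:
        if key not in self._entries and len(self._entries) >= self.max_size:
            self._entries.popitem(last=False)  # evict the least recently used entry
        self._entries[key] = (self._clock(), value)
        self._entries.move_to_end(key)


cache = TinyLRUCache(max_size=2)
cache.put("a", {"llm": "model-a"})
cache.put("b", {"llm": "model-b"})
cache.get("a")                      # touch "a" so "b" becomes the eviction candidate
cache.put("c", {"llm": "model-c"})  # cache is full, so "b" is evicted
print(cache.get("b"), cache.get("a"))
```

A monotonic clock is used here so that wall-clock adjustments cannot expire or revive entries; that is a design choice of the sketch, not a statement about the module.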

### 3. LLM Tools (llm/)

The LLM tools module contains all utilities related to LLM clients.

#### llm_utils.py
LLM client utilities:
- `get_llm_client(llm_id)` - Get an LLM client instance
- `get_reranker_client(rerank_id)` - Get a reranker client instance
- `handle_response(response)` - Process an LLM response

### 4. Prompt Management (prompt/)

The prompt module contains all utilities for prompt rendering and template management.

#### prompt_utils.py
Prompt rendering utilities (using Jinja2 templates):
- `get_prompts(message)` - Get the system and user prompts
- `render_statement_extraction_prompt(...)` - Render the statement extraction prompt
- `render_temporal_extraction_prompt(...)` - Render the temporal information extraction prompt
- `render_entity_dedup_prompt(...)` - Render the entity deduplication prompt
- `render_triplet_extraction_prompt(...)` - Render the triplet extraction prompt
- `render_memory_summary_prompt(...)` - Render the memory summary prompt
- `prompt_env` - Jinja2 environment object

#### template_render.py
Template rendering utilities (for evaluation and reflexion):
- `render_evaluate_prompt(evaluate_data, schema)` - Render the evaluation prompt
- `render_reflexion_prompt(data, schema)` - Render the reflexion prompt

#### prompts/
Directory of Jinja2 template files; contains all prompt templates

### 5. Data Processing (data/)

The data module contains all data processing utilities.

#### text_utils.py
Text processing utilities:
- `escape_lucene_query(query)` - Escape special characters in a Lucene query
- `extract_plain_query(query_input)` - Extract a plain-text query from various input formats
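
As an illustration of what escaping a Lucene query involves, the sketch below backslash-escapes the characters that Lucene's classic query parser treats as syntax. `escape_lucene` and its regular expression are assumptions written for this example; the actual `escape_lucene_query` may cover a different character set.

```python
import re

# Characters and operators treated as syntax by Lucene's classic query parser:
# + - && || ! ( ) { } [ ] ^ " ~ * ? : \ /
_LUCENE_SPECIAL = re.compile(r'([+\-!(){}\[\]^"~*?:\\/]|&&|\|\|)')


def escape_lucene(query: str) -> str:
    """Prefix each special token with a backslash so it is matched literally."""
    return _LUCENE_SPECIAL.sub(r"\\\1", query)


print(escape_lucene("user:admin AND (status:active)"))
```

Note that the words `AND`, `OR`, and `NOT` are left untouched here; whether they should be neutralized as well depends on how the query string is later assembled.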

#### time_utils.py
Time processing utilities:
- `validate_date_format(date_str)` - Validate the date format (YYYY-MM-DD)
- `normalize_date(date_str)` - Normalize a date string
- `normalize_date_safe(date_str, default)` - Normalize a date string safely (with a default value)
- `preprocess_date_string(date_str)` - Preprocess a date string
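
The relationship between a strict normalizer and its "safe" variant can be sketched as follows. The function names `normalize` and `normalize_safe` and the list of accepted layouts are assumptions for illustration; the real `time_utils` functions may accept other formats.

```python
from datetime import datetime
from typing import Optional

# Input layouts this sketch accepts; the real module may support others.
_FORMATS = ("%Y-%m-%d", "%Y/%m/%d", "%Y.%m.%d", "%Y%m%d")


def normalize(date_str: str) -> str:
    """Return the date as YYYY-MM-DD, or raise ValueError if no layout matches."""
    cleaned = date_str.strip()
    for fmt in _FORMATS:
        try:
            return datetime.strptime(cleaned, fmt).strftime("%Y-%m-%d")
        except ValueError:
            continue
    raise ValueError(f"Unrecognized date: {date_str!r}")


def normalize_safe(date_str: str, default: Optional[str] = None) -> Optional[str]:
    """Like normalize(), but return `default` instead of raising."""
    try:
        return normalize(date_str)
    except ValueError:
        return default


print(normalize("2025/10/28"))
print(normalize_safe("not a date", default="1970-01-01"))
```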

#### ontology.py
Ontology definitions:
- `PREDICATE_DEFINITIONS` - Dictionary of predicate definitions
- `LABEL_DEFINITIONS` - Dictionary of label definitions
- `Predicate` - Predicate enum
- `StatementType` - Statement type enum
- `TemporalInfo` - Temporal information enum
- `RelevenceInfo` - Relevance information enum

### 2. Logging (log/)

The logging module contains all utilities related to log recording.

#### logging_utils.py
Logging utilities:
- `log_prompt_rendering(role, content)` - Log a rendered prompt
- `log_template_rendering(template_name, context)` - Log a template rendering
- `log_time(operation, duration)` - Log how long an operation took
- `prompt_logger` - Prompt logger
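
One way to feed a `log_time(operation, duration)` style function is a small timing context manager. The sketch below is an assumption-based illustration: `log_time_sketch` stands in for the real `log_time`, and `timed` and the logger name are hypothetical.

```python
import logging
import time
from contextlib import contextmanager
from typing import Iterator

logger = logging.getLogger("memory.timing")  # hypothetical logger name


def log_time_sketch(operation: str, duration: float) -> None:
    """Stand-in for log_time: record how long an operation took, in seconds."""
    logger.info("%s took %.3fs", operation, duration)


@contextmanager
def timed(operation: str) -> Iterator[None]:
    """Measure the wrapped block and report it even if the block raises."""
    start = time.perf_counter()
    try:
        yield
    finally:
        log_time_sketch(operation, time.perf_counter() - start)


logging.basicConfig(level=logging.INFO)
with timed("extraction"):
    sum(range(1000))
```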

#### audit_logger.py
Audit log:
- `audit_logger` - Audit logger
- Records key system operations and security events

### 6. Self-Reflexion Tools (self_reflexion_utils/)

The self-reflexion module provides memory conflict detection and reflexion processing.

#### evaluate.py
Conflict evaluation:
- `conflict(evaluate_data, schema)` - Evaluate memory conflicts

#### reflexion.py
Reflexion processing:
- `reflexion(data, schema)` - Run reflexion processing

#### self_reflexion.py
Main self-reflexion logic:
- `self_reflexion(...)` - Main self-reflexion function

### 7. Data Models

Both modules in this section have been migrated to `app.schemas.memory_storage_schema`; import the names below from there.

#### json_schema.py
JSON Schema data models:
- `BaseDataSchema` - Base data model
- `ConflictResultSchema` - Conflict result model
- `ConflictSchema` - Conflict model
- `ReflexionSchema` - Reflexion model
- `ResolvedSchema` - Resolution model
- `ReflexionResultSchema` - Reflexion result model

#### messages.py
API message models:
- `ConfigKey` - Configuration key model
- `ChunkerStrategy` - Chunking strategy enum
- `ConfigParams` - Configuration parameters model
- `ConfigParamsCreate` - Model for creating configuration parameters
- `ConfigUpdate` - Configuration update model
- `ConfigUpdateExtracted` - Extraction engine configuration update model
- `ConfigUpdateForget` - Forgetting engine configuration update model
- `ConfigPilotRun` - Pilot run configuration model
- `ConfigFilter` - Configuration filter model
- `ApiResponse` - API response model
- `ok(msg, data)` - Success response constructor
- `fail(msg, error_code, data)` - Failure response constructor

### 8. Visualization (visualization/)

The visualization module contains all visualization utilities.

#### forgetting_visualizer.py
Forgetting curve visualization:
- `export_memory_curve_numpy(...)` - Export a memory curve as a NumPy array
- `export_memory_curves_multiple_strengths(...)` - Export memory curves for multiple strengths
- `export_parameter_sweep_numpy(...)` - Export parameter sweep results
- `visualize_forgetting_curve(...)` - Visualize a forgetting curve
- `plot_3d_forgetting_surface(...)` - Plot a 3D forgetting surface
- `create_comparison_visualization(...)` - Create a comparison visualization
- `save_memory_curves_to_file(...)` - Save memory curves to a file
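
For readers unfamiliar with forgetting curves, the sketch below tabulates one curve per memory strength. It assumes the classic Ebbinghaus-style form R = exp(-t / S); this document does not state which formula `forgetting_visualizer.py` uses, so the model, the function names, and the sample values are all assumptions.

```python
import math
from typing import Dict, List


def retention(t_days: float, strength: float) -> float:
    """Assumed model: retention R = exp(-t / S) for elapsed time t and strength S."""
    return math.exp(-t_days / strength)


def curves_for_strengths(days: List[float], strengths: List[float]) -> Dict[float, List[float]]:
    """Tabulate one retention curve per memory strength."""
    return {s: [retention(t, s) for t in days] for s in strengths}


table = curves_for_strengths([0, 1, 7], [1.0, 5.0])
for strength, values in table.items():
    print(strength, [round(v, 3) for v in values])
```

Under this model a larger strength flattens the curve, which is why exporting several strengths side by side is useful for comparison.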

### 9. Path Management (paths/)

The paths module contains all path management utilities.

#### output_paths.py
Output path management:
- `get_output_dir()` - Get the output directory
- `get_output_path(filename)` - Get the path of an output file

## Usage Examples

### Configuration Management

```python
from app.core.memory.utils.config import get_model_config, get_pipeline_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID

# Get a model configuration
model_config = get_model_config("model_id_123")

# Get the pipeline configuration
pipeline_config = get_pipeline_config()

# Use a global constant
llm_id = SELECTED_LLM_ID
```

### Logging

```python
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger

# Log a rendered prompt
log_prompt_rendering('user', 'Hello, world!')

# Log how long an operation took
log_time('extraction', 1.23)

# Use the audit log
audit_logger.info('User action performed')
```

### LLM Tools

```python
from app.core.memory.utils.llm import get_llm_client


async def main():
    # Get an LLM client
    llm_client = get_llm_client("llm_id_456")

    # Call the LLM (await must run inside a coroutine)
    response = await llm_client.chat([
        {"role": "user", "content": "Hello"}
    ])
    return response
```

### Prompt Rendering

```python
from app.core.memory.utils.prompt import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS


async def build_prompt(schema):
    # Render the statement extraction prompt; `schema` is supplied by the caller
    prompt = await render_statement_extraction_prompt(
        chunk_content="Conversation content...",
        definitions=LABEL_DEFINITIONS,
        json_schema=schema,
        granularity=2
    )
    return prompt
```

### Data Processing

```python
from app.core.memory.utils.data.time_utils import normalize_date
from app.core.memory.utils.data.text_utils import escape_lucene_query

# Normalize a date
normalized = normalize_date("2025/10/28")  # returns "2025-10-28"

# Escape a Lucene query
escaped = escape_lucene_query("user:admin AND status:active")
```

### Runtime Configuration Overrides

```python
# Note: the package __init__.py performs no module-level imports, so this
# root-level import may need to point at the submodule that defines the function.
from app.core.memory.utils import apply_runtime_overrides_with_config_id

# Override the configuration using a specific config_id
runtime_cfg = {"selections": {}}
updated_cfg = apply_runtime_overrides_with_config_id(
    project_root="/path/to/project",
    runtime_cfg=runtime_cfg,
    config_id="config_123"
)
```

## Migration Notes

### Migrating from Old Paths

If your code uses the old import paths, update it as follows:

**Old paths (before November 2024):**
```python
from app.core.memory.src.utils.config_utils import get_model_config
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
from app.core.memory.src.data_config_api.utils.messages import ok, fail
```

**Intermediate paths (November 2024):**
```python
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.schemas.memory_storage_schema import ok, fail
```

**New paths (after November 27, 2024):**
```python
# Configuration
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import get_model_config  # shorthand import

# Logging
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.log import log_prompt_rendering  # shorthand import

# Other utilities
from app.core.memory.utils.prompt import prompt_utils
from app.schemas.memory_storage_schema import ok, fail
```

### Directory Restructuring (November 27, 2024)

The utils directory has been fully reorganized by function:

**Structure before the restructuring:**
- All files lived directly under `app/core/memory/utils/`

**Structure after the restructuring:**
- `config/` - Configuration management files
- `log/` - Logging files
- `prompt/` - Prompt management files
- `llm/` - LLM tool files
- `data/` - Data processing files
- `paths/` - Path management files
- `visualization/` - Visualization files
- `self_reflexion_utils/` - Self-reflexion tools (already existed)

**Import path changes:**
```python
# Old imports
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt

# New imports
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt

# Or use the shorthand imports
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.log import log_prompt_rendering
from app.core.memory.utils.prompt import render_statement_extraction_prompt
```

## Maintenance Guide

### Adding a Utility Function

1. Add the function to the appropriate module file
2. Export the function in `__init__.py`
3. Document it in this README
4. Write unit tests

### Removing an Old Utility Function

1. Confirm that no code uses the function
2. Remove the function from the module file
3. Remove the export from `__init__.py`
4. Update this README

### Refactoring a Utility Function

1. Keep backward compatibility (use an alias or a wrapper)
2. Update all code that uses the function
3. Update the documentation and tests
4. Remove the old version when appropriate
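
Step 1 above (an alias or wrapper that keeps the old name working) can be sketched as follows. `deprecated_alias`, `normalize_whitespace`, and `clean_text` are hypothetical names invented for this example; nothing here is taken from the Memory module.

```python
import functools
import warnings
from typing import Any, Callable


def deprecated_alias(new_func: Callable[..., Any], old_name: str) -> Callable[..., Any]:
    """Return a wrapper that forwards to new_func and warns about the old name."""
    @functools.wraps(new_func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not the wrapper
        )
        return new_func(*args, **kwargs)
    return wrapper


def normalize_whitespace(text: str) -> str:
    """Hypothetical refactored function: collapse runs of whitespace."""
    return " ".join(text.split())


# Old name kept temporarily so existing callers continue to work.
clean_text = deprecated_alias(normalize_whitespace, "clean_text")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(clean_text("a   b"), len(caught))
```

Emitting `DeprecationWarning` gives callers a migration window and makes remaining uses easy to find in test output before the old name is removed in step 4.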

## Notes

1. **Backward compatibility**: All utility functions should remain backward compatible to avoid breaking existing code
2. **Documentation**: Every function should have a clear docstring
3. **Type annotations**: Use type annotations to improve readability
4. **Error handling**: Utility functions should handle errors appropriately
5. **Test coverage**: All utility functions should have unit tests

## Related Documents

- [Memory module architecture design](../.kiro/specs/memory-refactoring/design.md)
- [Memory module requirements](../.kiro/specs/memory-refactoring/requirements.md)
- [Memory module task list](../.kiro/specs/memory-refactoring/tasks.md)
65
app/core/memory/utils/__init__.py
Normal file
@@ -0,0 +1,65 @@
"""
Memory module utilities package.

This package contains all utility functions used by the Memory module, organized by function.

Directory structure:
- config/: configuration management (config_utils, definitions, overrides, get_data, litellm_config, config_optimization)
- log/: logging (logging_utils, audit_logger)
- prompt/: prompt management (prompt_utils, template_render, prompts/)
- llm/: LLM tools (llm_utils)
- data/: data processing (text_utils, time_utils, ontology)
- paths/: path management (output_paths)
- visualization/: visualization (forgetting_visualizer)
- self_reflexion_utils/: self-reflexion tools (evaluate, reflexion, self_reflexion)

Notes:
- json_schema and messages have been migrated to app.schemas.memory_storage_schema
- All utility functions are grouped into the matching subdirectory

Usage:
    # Configuration management
    from app.core.memory.utils.config import get_model_config
    from app.core.memory.utils.config.definitions import SELECTED_LLM_ID

    # Logging
    from app.core.memory.utils.log import log_prompt_rendering, audit_logger

    # Prompt management
    from app.core.memory.utils.prompt import render_statement_extraction_prompt

    # LLM tools
    from app.core.memory.utils.llm import get_llm_client

    # Data processing
    from app.core.memory.utils.data import text_utils, time_utils
    from app.core.memory.utils.data.ontology import Predicate, StatementType

    # Path management
    from app.core.memory.utils.paths import get_output_dir

    # Visualization
    from app.core.memory.utils.visualization import visualize_forgetting_curve

    # Self-reflexion
    from app.core.memory.utils.self_reflexion_utils import self_reflexion
"""

# No module-level imports are performed in __init__.py, to avoid circular imports.
# Import the module you need directly, for example:
# from app.core.memory.utils.config import config_utils
# from app.core.memory.utils.log import logging_utils
# from app.core.memory.utils.data import text_utils
# from app.core.memory.utils.prompt import prompt_utils

__all__ = [
    # Submodules
    "config",
    "log",
    "prompt",
    "llm",
    "data",
    "paths",
    "visualization",
    "self_reflexion_utils",
]
82
app/core/memory/utils/config/__init__.py
Normal file
@@ -0,0 +1,82 @@
"""
Configuration management module.

Contains all configuration-related utilities and definitions.
"""

# Re-export commonly used functions and constants from submodules for backward compatibility
from .config_utils import (
    get_model_config,
    get_embedder_config,
    get_neo4j_config,
    get_chunker_config,
    get_pipeline_config,
    get_pruning_config,
    get_picture_config,
    get_voice_config,
)
from .definitions import (
    CONFIG,
    RUNTIME_CONFIG,
    PROJECT_ROOT,
    SELECTED_LLM_ID,
    SELECTED_EMBEDDING_ID,
    SELECTED_GROUP_ID,
    SELECTED_RERANK_ID,
    SELECTED_LLM_PICTURE_NAME,
    SELECTED_LLM_VOICE_NAME,
    REFLEXION_ENABLED,
    REFLEXION_ITERATION_PERIOD,
    REFLEXION_RANGE,
    REFLEXION_BASELINE,
    reload_configuration_from_database,
)
from .overrides import load_unified_config
from .get_data import get_data
# litellm_config is imported on demand to avoid circular dependencies
# from .litellm_config import (
#     LiteLLMConfig,
#     setup_litellm_enhanced,
#     get_usage_summary,
#     print_usage_summary,
#     get_instant_qps,
#     print_instant_qps,
# )

__all__ = [
    # config_utils
    "get_model_config",
    "get_embedder_config",
    "get_neo4j_config",
    "get_chunker_config",
    "get_pipeline_config",
    "get_pruning_config",
    "get_picture_config",
    "get_voice_config",
    # definitions
    "CONFIG",
    "RUNTIME_CONFIG",
    "PROJECT_ROOT",
    "SELECTED_LLM_ID",
    "SELECTED_EMBEDDING_ID",
    "SELECTED_GROUP_ID",
    "SELECTED_RERANK_ID",
    "SELECTED_LLM_PICTURE_NAME",
    "SELECTED_LLM_VOICE_NAME",
    "REFLEXION_ENABLED",
    "REFLEXION_ITERATION_PERIOD",
    "REFLEXION_RANGE",
    "REFLEXION_BASELINE",
    "reload_configuration_from_database",
    # overrides
    "load_unified_config",
    # get_data
    "get_data",
    # litellm_config - import directly from .litellm_config when needed
    # "LiteLLMConfig",
    # "setup_litellm_enhanced",
    # "get_usage_summary",
    # "print_usage_summary",
    # "get_instant_qps",
    # "print_instant_qps",
]
398
app/core/memory/utils/config/config_optimization.py
Normal file
@@ -0,0 +1,398 @@
"""
Configuration management optimizations.

Provides optional optimizations for configuration management, including:
- LRU caching
- Cache warm-up
- Cache monitoring metrics
- Dynamic TTL strategy
- Configuration version control

These optimizations are optional; the current basic implementation already covers most needs.
"""
import logging
import statistics
import threading
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple

logger = logging.getLogger(__name__)


class LRUConfigCache:
    """
    LRU (Least Recently Used) configuration cache.

    When the cache reaches its maximum capacity, the least recently used configuration is evicted.
    """

    def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)):
        """
        Initialize the LRU cache.

        Args:
            max_size: Maximum cache capacity
            ttl: Time after which a cached entry expires
        """
        self.max_size = max_size
        self.ttl = ttl
        self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        self._timestamps: Dict[str, datetime] = {}
        self._lock = threading.RLock()

        # Statistics
        self._stats = {
            'hits': 0,
            'misses': 0,
            'evictions': 0,
            'load_times': []
        }

    def get(self, config_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a configuration (if present and not expired).

        Args:
            config_id: Configuration ID

        Returns:
            The configuration dict, or None if it is missing or expired
        """
        with self._lock:
            if config_id not in self._cache:
                self._stats['misses'] += 1
                return None

            # Check for expiry
            timestamp = self._timestamps.get(config_id)
            if timestamp and (datetime.now() - timestamp) >= self.ttl:
                # Expired: remove it
                self._cache.pop(config_id, None)
                self._timestamps.pop(config_id, None)
                self._stats['misses'] += 1
                return None

            # Hit: move to the end (mark as most recently used)
            self._cache.move_to_end(config_id)
            self._stats['hits'] += 1
            return self._cache[config_id]

    def put(self, config_id: str, config: Dict[str, Any]) -> None:
        """
        Add or update a configuration.

        Args:
            config_id: Configuration ID
            config: Configuration dict
        """
        with self._lock:
            if config_id in self._cache:
                # Update an existing configuration
                self._cache.move_to_end(config_id)
            else:
                # Add a new configuration
                if len(self._cache) >= self.max_size:
                    # Cache is full: remove the least recently used configuration
                    oldest_id, _ = self._cache.popitem(last=False)
                    self._timestamps.pop(oldest_id, None)
                    self._stats['evictions'] += 1
                    logger.debug(f"[LRUCache] Evicted configuration: {oldest_id}")

            self._cache[config_id] = config
            self._timestamps[config_id] = datetime.now()

    def clear(self, config_id: Optional[str] = None) -> None:
        """
        Clear the cache.

        Args:
            config_id: If given, clear only that configuration; otherwise clear everything
        """
        with self._lock:
            if config_id:
                self._cache.pop(config_id, None)
                self._timestamps.pop(config_id, None)
            else:
                self._cache.clear()
                self._timestamps.clear()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Statistics dict
        """
        with self._lock:
            total = self._stats['hits'] + self._stats['misses']
            hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0

            return {
                'cache_size': len(self._cache),
                'max_size': self.max_size,
                'total_requests': total,
                'cache_hits': self._stats['hits'],
                'cache_misses': self._stats['misses'],
                'evictions': self._stats['evictions'],
                'hit_rate': hit_rate,
                'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0
            }

    def record_load_time(self, load_time_ms: float) -> None:
        """
        Record a load time.

        Args:
            load_time_ms: Load time in milliseconds
        """
        with self._lock:
            self._stats['load_times'].append(load_time_ms)
            # Keep only the most recent 1000 records
            if len(self._stats['load_times']) > 1000:
                self._stats['load_times'] = self._stats['load_times'][-1000:]


class ConfigCacheWarmer:
    """
    Configuration cache warmer.

    Preloads frequently used configurations at startup to reduce first-request latency.
    """

    @staticmethod
    def warmup(config_ids: List[str], load_func) -> Dict[str, bool]:
        """
        Warm up the cache.

        Args:
            config_ids: IDs of the configurations to preload
            load_func: Configuration loading function

        Returns:
            Load result for each configuration
        """
        results = {}

        logger.info(f"[CacheWarmer] Warming up {len(config_ids)} configurations")

        for config_id in config_ids:
            try:
                result = load_func(config_id)
                # Store a bool so the result matches the declared return type
                results[config_id] = bool(result)
                if result:
                    logger.debug(f"[CacheWarmer] Warmed up configuration: {config_id}")
                else:
                    logger.warning(f"[CacheWarmer] Failed to warm up configuration: {config_id}")
            except Exception as e:
                logger.error(f"[CacheWarmer] Error warming up configuration: {config_id}, error: {e}")
                results[config_id] = False

        success_count = sum(1 for r in results.values() if r)
        logger.info(f"[CacheWarmer] Warm-up finished: {success_count}/{len(config_ids)} succeeded")

        return results


class DynamicTTLStrategy:
    """
    Dynamic TTL strategy.

    Chooses the cache expiry time from the configuration type, which is either
    given explicitly or inferred from the configuration ID.
    """

    # Predefined TTL strategies
    TTL_STRATEGIES = {
        'production': timedelta(minutes=30),   # production configs are fairly stable
        'staging': timedelta(minutes=15),      # staging configs are moderately stable
        'development': timedelta(minutes=5),   # development configs change often
        'testing': timedelta(minutes=1),       # testing configs expire quickly
        'default': timedelta(minutes=5)        # default strategy
    }

    @classmethod
    def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta:
        """
        Get the TTL for a configuration.

        Args:
            config_id: Configuration ID
            config_type: Configuration type (production/staging/development/testing)

        Returns:
            TTL interval
        """
        if config_type and config_type in cls.TTL_STRATEGIES:
            return cls.TTL_STRATEGIES[config_type]

        # Infer the type from config_id
        if 'prod' in config_id.lower():
            return cls.TTL_STRATEGIES['production']
        elif 'stag' in config_id.lower():
            return cls.TTL_STRATEGIES['staging']
        elif 'dev' in config_id.lower():
            return cls.TTL_STRATEGIES['development']
        elif 'test' in config_id.lower():
            return cls.TTL_STRATEGIES['testing']

        return cls.TTL_STRATEGIES['default']


class ConfigVersionManager:
    """
    Configuration version manager.

    Tracks configuration versions so that cached copies of an old version can be
    invalidated when a configuration is updated.
    """

    def __init__(self):
        self._versions: Dict[str, str] = {}
        self._lock = threading.RLock()

    def get_version(self, config_id: str) -> Optional[str]:
        """
        Get the version of a configuration.

        Args:
            config_id: Configuration ID

        Returns:
            The version, or None if it does not exist
        """
        with self._lock:
            return self._versions.get(config_id)

    def set_version(self, config_id: str, version: str) -> None:
        """
        Set the version of a configuration.

        Args:
            config_id: Configuration ID
            version: Version string
        """
        with self._lock:
            old_version = self._versions.get(config_id)
            self._versions[config_id] = version

            if old_version and old_version != version:
                logger.info(f"[VersionManager] Configuration version updated: {config_id} {old_version} -> {version}")

    def check_version(self, config_id: str, cached_version: Optional[str]) -> bool:
        """
        Check whether a cached version is still valid.

        Args:
            config_id: Configuration ID
            cached_version: Cached version string

        Returns:
            True if the versions match, False if they differ or either is missing
        """
        with self._lock:
            current_version = self._versions.get(config_id)

            if not current_version or not cached_version:
                return False

            return current_version == cached_version

    def invalidate(self, config_id: str) -> None:
        """
        Invalidate the version of a configuration.

        Args:
            config_id: Configuration ID
        """
        with self._lock:
            if config_id in self._versions:
                # Generate a new version so cached copies no longer match
                import uuid
                new_version = str(uuid.uuid4())
                self._versions[config_id] = new_version
                logger.info(f"[VersionManager] Configuration version invalidated: {config_id} -> {new_version}")


class CacheMonitor:
    """
    Cache monitor.

    Provides cache performance monitoring and reporting.
    """

    def __init__(self, cache: LRUConfigCache):
        self.cache = cache

    def get_report(self) -> str:
        """
        Generate a cache performance report.

        Returns:
            Formatted report string
        """
        stats = self.cache.get_stats()

        report = f"""
Configuration cache performance report
======================================
Cache size: {stats['cache_size']}/{stats['max_size']}
Total requests: {stats['total_requests']}
Cache hits: {stats['cache_hits']}
Cache misses: {stats['cache_misses']}
Hit rate: {stats['hit_rate']:.2f}%
Evictions: {stats['evictions']}
Average load time: {stats['avg_load_time']:.2f}ms
"""
        return report

    def log_stats(self) -> None:
        """Write the statistics to the log."""
        stats = self.cache.get_stats()
        logger.info(
            f"[CacheMonitor] Cache stats - "
            f"size: {stats['cache_size']}/{stats['max_size']}, "
            f"hit rate: {stats['hit_rate']:.2f}%, "
            f"evictions: {stats['evictions']}"
        )


# Usage example
def example_usage():
    """
    Example of using the optimization features.
    """
    # 1. Use the LRU cache
    lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5))

    # Get a configuration
    config = lru_cache.get("config_001")
    if config is None:
        # Cache miss: load from the database
        config = {"llm_name": "openai/gpt-4"}
        lru_cache.put("config_001", config)

    # 2. Warm up the cache
    def load_config(config_id):
        # Actual configuration loading logic goes here
        return True

    warmer = ConfigCacheWarmer()
    results = warmer.warmup(["config_001", "config_002"], load_config)

    # 3. Dynamic TTL
    ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production")
    print(f"TTL: {ttl}")

    # 4. Version management
    version_manager = ConfigVersionManager()
    version_manager.set_version("config_001", "v1.0.0")

    # Check the version
    is_valid = version_manager.check_version("config_001", "v1.0.0")

    # 5. Monitoring
    monitor = CacheMonitor(lru_cache)
    print(monitor.get_report())


if __name__ == "__main__":
    example_usage()
267
app/core/memory/utils/config/config_utils.py
Normal file
@@ -0,0 +1,267 @@
import uuid
import json
from typing import Optional

from sqlalchemy.orm import Session
from fastapi.exceptions import HTTPException
from fastapi import status

from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.models.variate_config import (
    ExtractionPipelineConfig,
    DedupConfig,
    StatementExtractionConfig,
    ForgettingEngineConfig,
)
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db
from app.models.models_model import ModelConfig, ModelApiKey
from app.services.model_service import ModelConfigService


def get_model_config(model_id: str, db: Session | None = None) -> dict:
    """Load an LLM model configuration from the database by model ID."""
    import os

    # Timeout and retry settings are read from environment-backed settings
    from app.core.config import settings

    db_gen = None
    if db is None:
        db_gen = get_db()  # get_db is normally a generator
        db = next(db_gen)  # obtain the actual Session

    try:
        config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
        if not config:
            print(f"Model ID {model_id} does not exist")
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model ID does not exist")
        if not config.api_keys:
            # Guard against an IndexError when the model has no API key configured
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model has no API key configured")
        apiConfig: ModelApiKey = config.api_keys[0]

        model_config = {
            "model_name": apiConfig.model_name,
            "provider": apiConfig.provider,
            "api_key": apiConfig.api_key,
            "base_url": apiConfig.api_base,
            "model_config_id": apiConfig.model_config_id,
            "type": config.type,
            # Timeout and retry settings, to avoid LLM request timeouts
            "timeout": settings.LLM_TIMEOUT,  # from the environment, default 120 seconds
            "max_retries": settings.LLM_MAX_RETRIES,  # from the environment, default 2 retries
        }
    finally:
        if db_gen is not None:
            db_gen.close()  # run get_db's cleanup so the session opened here is released

    # Append to logs/model_config.log with the API key masked
    os.makedirs("logs", exist_ok=True)
    safe_config = {**model_config, "api_key": "***"}
    with open("logs/model_config.log", "a", encoding="utf-8") as f:
        f.write(f"Model ID: {model_id}\n")
        f.write(f"Model configuration:\n{safe_config}\n")
        f.write("=============================\n\n")
    return model_config


def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
    """Load an embedding model configuration from the database by model ID."""
    import os

    db_gen = None
    if db is None:
        db_gen = get_db()  # get_db is normally a generator
        db = next(db_gen)  # obtain the actual Session

    try:
        config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
        if not config:
            print(f"Embedding model ID {embedding_id} does not exist")
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Embedding model ID does not exist")
        if not config.api_keys:
            # Guard against an IndexError when the model has no API key configured
            raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Embedding model has no API key configured")
        apiConfig: ModelApiKey = config.api_keys[0]
        model_config = {
            "model_name": apiConfig.model_name,
            "provider": apiConfig.provider,
            "api_key": apiConfig.api_key,
            "base_url": apiConfig.api_base,
            "model_config_id": apiConfig.model_config_id,
            # Ensure required field for RedBearModelConfig validation
            "type": config.type,
            # Timeout and retry settings, to avoid embedding request timeouts
            "timeout": 120.0,  # embedding service timeout (seconds)
            "max_retries": 5,  # maximum number of retries
        }
    finally:
        if db_gen is not None:
            db_gen.close()  # run get_db's cleanup so the session opened here is released

    # Append to logs/embedder_config.log with the API key masked
    os.makedirs("logs", exist_ok=True)
    safe_config = {**model_config, "api_key": "***"}
    with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
        f.write(f"Embedding model ID: {embedding_id}\n")
        f.write(f"Embedding model configuration:\n{safe_config}\n")
        f.write("=============================\n\n")
    return model_config
|
||||
|
||||
def get_neo4j_config() -> dict:
|
||||
"""Retrieves the Neo4j configuration from the config file."""
|
||||
return CONFIG.get("neo4j", {})
|
||||
def get_picture_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
for model_config in CONFIG.get("picture_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
raise ValueError(f"Model '{llm_name}' not found in config.json")
|
||||
def get_voice_config(llm_name: str) -> dict:
|
||||
"""Retrieves the configuration for a specific model from the config file."""
|
||||
for model_config in CONFIG.get("voice_recognition", []):
|
||||
if model_config["llm_name"] == llm_name:
|
||||
return model_config
|
||||
raise ValueError(f"Model '{llm_name}' not found in config.json")
|
||||
|
||||
|
||||
def get_chunker_config(chunker_strategy: str) -> dict:
|
||||
"""Retrieves the configuration for a specific chunker strategy.
|
||||
|
||||
Enhancements:
|
||||
- Supports default configs for `LLMChunker` and `HybridChunker` if not present.
|
||||
- Falls back to the first available chunker config when the requested one is missing.
|
||||
"""
|
||||
# 1) Try to find exact match in config
|
||||
chunker_list = CONFIG.get("chunker_list", [])
|
||||
for chunker_config in chunker_list:
|
||||
if chunker_config.get("chunker_strategy") == chunker_strategy:
|
||||
return chunker_config
|
||||
|
||||
# 2) Provide sane defaults for newer strategies
|
||||
default_configs = {
|
||||
"LLMChunker": {
|
||||
"chunker_strategy": "LLMChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 1000,
|
||||
"threshold": 0.8,
|
||||
"min_sentences": 2,
|
||||
"language": "zh",
|
||||
"skip_window": 1,
|
||||
"min_characters_per_chunk": 100,
|
||||
},
|
||||
"HybridChunker": {
|
||||
"chunker_strategy": "HybridChunker",
|
||||
"embedding_model": "BAAI/bge-m3",
|
||||
"chunk_size": 512,
|
||||
"threshold": 0.8,
|
||||
"min_sentences": 2,
|
||||
"language": "zh",
|
||||
"skip_window": 1,
|
||||
"min_characters_per_chunk": 100,
|
||||
},
|
||||
}
|
||||
if chunker_strategy in default_configs:
|
||||
return default_configs[chunker_strategy]
|
||||
|
||||
# 3) Fallback: use first available config but tag with requested strategy
|
||||
if chunker_list:
|
||||
fallback = chunker_list[0].copy()
|
||||
fallback["chunker_strategy"] = chunker_strategy
|
||||
# Non-fatal notice for visibility in logs if any
|
||||
print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
|
||||
return fallback
|
||||
|
||||
# 4) If no configs available at all
|
||||
raise ValueError(
|
||||
f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
|
||||
)
|
||||
|
||||
|
||||
def get_pipeline_config() -> ExtractionPipelineConfig:
|
||||
"""Build ExtractionPipelineConfig using only runtime.json values.
|
||||
|
||||
Behavior:
|
||||
- Read `deduplication` section from runtime.json if present.
|
||||
- Read `statement_extraction` section from runtime.json if present.
|
||||
- Read `forgetting_engine` section from runtime.json if present.
|
||||
- If `deduplication.enable_llm_dedup_blockwise` is absent, check the legacy top-level `enable_llm_dedup` key.
|
||||
- Do NOT fall back to environment variables.
|
||||
- Unspecified fields use model defaults defined in DedupConfig.
|
||||
"""
|
||||
dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {}
|
||||
stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
|
||||
forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}
|
||||
|
||||
# Assemble kwargs from runtime.json only
|
||||
kwargs = {}
|
||||
# LLM switch: prefer new key, then legacy top-level, default False
|
||||
if "enable_llm_dedup_blockwise" in dedup_rc:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise"))
|
||||
else:
|
||||
# Legacy top-level fallback inside runtime.json only
|
||||
legacy = RUNTIME_CONFIG.get("enable_llm_dedup")
|
||||
if legacy is not None:
|
||||
kwargs["enable_llm_dedup_blockwise"] = bool(legacy)
|
||||
else:
|
||||
kwargs["enable_llm_dedup_blockwise"] = False  # conservative default: LLM dedup stays disabled
|
||||
# Disambiguation switch: only from runtime.json deduplication section
|
||||
if "enable_llm_disambiguation" in dedup_rc:
|
||||
kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation"))
|
||||
|
||||
# Optional LLM fallback gating
|
||||
if "enable_llm_fallback_only_on_borderline" in dedup_rc:
|
||||
kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline"))
|
||||
|
||||
# Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults
|
||||
for key in (
|
||||
"fuzzy_name_threshold_strict",
|
||||
"fuzzy_type_threshold_strict",
|
||||
"fuzzy_overall_threshold",
|
||||
"fuzzy_unknown_type_name_threshold",
|
||||
"fuzzy_unknown_type_type_threshold",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
# Optional weights and bonuses for overall scoring
|
||||
for key in (
|
||||
"name_weight",
|
||||
"desc_weight",
|
||||
"type_weight",
|
||||
"context_bonus",
|
||||
"llm_fallback_floor",
|
||||
"llm_fallback_ceiling",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
# Optional LLM iterative dedup parameters
|
||||
for key in (
|
||||
"llm_block_size",
|
||||
"llm_block_concurrency",
|
||||
"llm_pair_concurrency",
|
||||
"llm_max_rounds",
|
||||
):
|
||||
if key in dedup_rc:
|
||||
kwargs[key] = dedup_rc[key]
|
||||
|
||||
dedup_config = DedupConfig(**kwargs)
|
||||
|
||||
# Build StatementExtractionConfig from runtime.json
|
||||
stmt_kwargs = {}
|
||||
for key in (
|
||||
"statement_granularity",
|
||||
"temperature",
|
||||
"include_dialogue_context",
|
||||
"max_dialogue_context_chars",
|
||||
):
|
||||
if key in stmt_rc:
|
||||
stmt_kwargs[key] = stmt_rc[key]
|
||||
stmt_config = StatementExtractionConfig(**stmt_kwargs)
|
||||
|
||||
# Build ForgettingEngineConfig from runtime.json
|
||||
forget_kwargs = {}
|
||||
for key in ("offset", "lambda_time", "lambda_mem"):
|
||||
if key in forget_rc:
|
||||
forget_kwargs[key] = forget_rc[key]
|
||||
forget_config = ForgettingEngineConfig(**forget_kwargs)
|
||||
|
||||
return ExtractionPipelineConfig(
|
||||
statement_extraction=stmt_config,
|
||||
deduplication=dedup_config,
|
||||
forgetting_engine=forget_config,
|
||||
)
|
||||
|
||||
|
||||
def get_pruning_config() -> dict:
|
||||
"""Retrieve semantic pruning config from runtime.json.
|
||||
|
||||
Returns a dict suitable for PruningConfig.model_validate.
|
||||
|
||||
Structure in runtime.json:
|
||||
{
|
||||
"pruning": {
|
||||
"enabled": true,
|
||||
"scene": "education" | "online_service" | "outbound",
|
||||
"threshold": 0.5
|
||||
}
|
||||
}
|
||||
"""
|
||||
pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {}
|
||||
|
||||
return {
|
||||
"pruning_switch": bool(pruning_rc.get("enabled", False)),
|
||||
"pruning_scene": pruning_rc.get("scene", "education"),
|
||||
"pruning_threshold": float(pruning_rc.get("threshold", 0.5)),
|
||||
}
|
||||
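The lookup order used by get_chunker_config (exact match, then a built-in default, then the first configured entry, then an error) can be sketched in isolation. This is a minimal illustration rather than the module's own code: `resolve_chunker` and `DEFAULTS` are hypothetical names, and the default values are trimmed placeholders.

```python
from typing import Any, Dict, List

# Hypothetical, trimmed defaults; the real values live inside get_chunker_config.
DEFAULTS: Dict[str, Dict[str, Any]] = {
    "LLMChunker": {"chunker_strategy": "LLMChunker", "chunk_size": 1000},
}


def resolve_chunker(strategy: str, chunker_list: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Resolve a chunker config: exact match, then defaults, then first entry."""
    # 1) Exact match in the configured list
    for cfg in chunker_list:
        if cfg.get("chunker_strategy") == strategy:
            return cfg
    # 2) Built-in default for known strategies (copied so callers cannot mutate it)
    if strategy in DEFAULTS:
        return dict(DEFAULTS[strategy])
    # 3) Fallback: reuse the first entry, relabelled with the requested strategy
    if chunker_list:
        fallback = dict(chunker_list[0])
        fallback["chunker_strategy"] = strategy
        return fallback
    # 4) Nothing to fall back on
    raise ValueError(f"Chunker '{strategy}' not found and no fallback available")
```

Copying the first entry before relabelling it keeps the configured list unchanged, which matters because the same list is consulted again on every later lookup.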
360
app/core/memory/utils/config/definitions.py
Normal file
@@ -0,0 +1,360 @@
|
||||
"""
|
||||
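Configuration loading module: three-stage architecture (migrated to unified configuration management)

This module now uses the global configuration management system (app/core/config.py)
to load and manage configuration while remaining backward compatible.

Stage 1: load configuration from runtime.json (path A)
Stage 2: load configuration from the database (path B, based on the config_id in dbrun.json)
Stage 3: expose configuration constants for the project to use (where paths A and B converge)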
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Import unified configuration system
|
||||
try:
|
||||
from app.core.config import settings
|
||||
USE_UNIFIED_CONFIG = True
|
||||
except ImportError:
|
||||
USE_UNIFIED_CONFIG = False
|
||||
settings = None
|
||||
|
||||
# PROJECT_ROOT should point to the app/core/memory/ directory
# __file__ = app/core/memory/utils/config/definitions.py
# os.path.dirname(__file__) = app/core/memory/utils/config
# os.path.dirname(...) = app/core/memory/utils
# os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Global configuration lock for thread safety
_config_lock = threading.RLock()

# Load the base configuration (config.json) through the global configuration system
|
||||
if USE_UNIFIED_CONFIG:
|
||||
CONFIG = settings.load_memory_config()
|
||||
else:
|
||||
# Fallback to legacy loading
|
||||
config_path = os.path.join(PROJECT_ROOT, "config.json")
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
CONFIG = json.load(f)
|
||||
except (FileNotFoundError, json.JSONDecodeError):
|
||||
print("Warning: config.json not found or is malformed. Using default settings.")
|
||||
CONFIG = {}
|
||||
|
||||
DEFAULT_VALUES = {
|
||||
"llm_name": "openai/qwen-plus",
|
||||
"embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
"chunker_strategy": "RecursiveChunker",
|
||||
"group_id": "group_123",
|
||||
"user_id": "default_user",
|
||||
"apply_id": "default_apply",
|
||||
"llm_agent_name": "openai/qwen-plus",
|
||||
"llm_verify_name": "openai/qwen-plus",
|
||||
"llm_image_recognition": "openai/qwen-plus",
|
||||
"llm_voice_recognition": "openai/qwen-plus",
|
||||
"prompt_level": "DEBUG",
|
||||
"reflexion_iteration_period": "3",
|
||||
"reflexion_range": "retrieval",
|
||||
"reflexion_baseline": "TIME",
|
||||
}
|
||||
|
||||
|
||||
# Stage 1: load configuration from runtime.json (path A)
def _load_from_runtime_json() -> Dict[str, Any]:
    """
    Load configuration from runtime.json through the unified configuration loader.

    Uses the unified loader in overrides.py, which loads sources in priority order:
    1. Database configuration (if dbrun.json contains config_id/group_id)
    2. Environment variable configuration
    3. Default configuration in runtime.json

    Returns:
        Dict[str, Any]: runtime configuration dictionary
    """
    try:
        # Use the unified configuration loader from overrides.py
        from app.core.memory.utils.config.overrides import load_unified_config

        runtime_cfg = load_unified_config(PROJECT_ROOT)
        return runtime_cfg
    except Exception:
        # Fallback: read runtime.json directly
        runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # runtime.json is missing or malformed: fall back to an empty configuration
            return {"selections": {}}
|
||||
|
||||
|
||||
# Stage 2: load configuration from the database (path B), now folded into the unified loader
# Note: this function has been superseded by the unified loader used in _load_from_runtime_json
# It is kept only for backward compatibility
def _load_from_database() -> Optional[Dict[str, Any]]:
    """
    Load configuration from the database (based on the config_id in dbrun.json).

    Note: this function has been superseded by the unified configuration loader.
    Calling _load_from_runtime_json directly now returns the full configuration,
    including the database configuration.

    Returns:
        Optional[Dict[str, Any]]: configuration dictionary
    """
    try:
        # Delegate to the unified configuration loader
        return _load_from_runtime_json()
    except Exception:
        return None
|
||||
|
||||
|
||||
# Stage 3: expose configuration constants (where paths A and B converge)
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
    """
    Expose the runtime configuration as global constants for the project to use.

    This is where path A (runtime.json) and path B (database) converge:
    wherever the configuration comes from, it is exposed as constants through this function.

    Args:
        runtime_cfg: runtime configuration dictionary
    """
    global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
    global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
    global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
    global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
    global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
    global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
    global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE

    RUNTIME_CONFIG = runtime_cfg

    # Observability settings
    LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False)
    AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False)

    # Logging settings
    LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
    PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])

    # Selections
    SELECTIONS = RUNTIME_CONFIG.get("selections", {})

    # Base model selection
    SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"])
    SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"])
    SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])

    # Group and user settings
    SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"])
    SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"])
    SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"])
    SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None)

    # Dedicated LLM settings
    SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
    SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
    SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
    SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])

    # Model ID settings
    SELECTED_LLM_ID = SELECTIONS.get("llm_id", None)
    SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None)
    SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None)

    # Reflexion settings
    REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
    REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False)
    REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
    REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"])
    REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"])
|
||||
|
||||
|
||||
# Initialization: use the unified configuration loader
def _initialize_configuration() -> None:
    """
    Initialize configuration using the unified configuration loader.

    Load priority (handled by overrides.py):
    1. Database configuration (if dbrun.json contains config_id/group_id)
    2. Environment variable configuration (.env)
    3. Default configuration in runtime.json
    """
    try:
        # Use the unified loader (all priority handling is already included)
        runtime_config = _load_from_runtime_json()

        # Expose as global constants
        _expose_runtime_constants(runtime_config)

    except Exception:
        # Initialization failed: fall back to an empty configuration
        _expose_runtime_constants({"selections": {}})


# Initialize configuration automatically when the module is loaded
_initialize_configuration()
|
||||
|
||||
|
||||
# Public API: reload configuration dynamically
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
    """
    Reload configuration dynamically from the database using the unified configuration loader.
    Used to switch configuration at runtime, for example when the frontend passes a new config_id.

    Note: this function only overrides the configuration in memory; it does not modify runtime.json.

    Args:
        config_id: configuration ID (int or str; converted automatically)
        force_reload: retained for backward compatibility (the caching logic has been removed)

    Returns:
        bool: whether the configuration was reloaded successfully
    """
    import logging
    logger = logging.getLogger(__name__)

    # Import the audit logger
    try:
        from app.core.memory.utils.log.audit_logger import audit_logger
    except ImportError:
        audit_logger = None

    with _config_lock:
        try:
            from app.core.memory.utils.config.overrides import load_unified_config
        except Exception as e:
            logger.error(f"[definitions] Failed to import the unified configuration loader: {e}")

            # Record the failed configuration load
            if audit_logger:
                audit_logger.log_config_load(
                    config_id=config_id,
                    success=False,
                    details={"error": f"Import failed: {str(e)}"}
                )

            return False

        try:
            logger.info(f"[definitions] Reloading configuration, config_id={config_id}")

            # Use the unified configuration loader with the given config_id
            updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)

            # Check whether loading succeeded
            if not updated_cfg or not updated_cfg.get('selections'):
                logger.error(f"[definitions] Configuration load failed: no configuration with config_id={config_id} found in the database")

                # Record the failed configuration load
                if audit_logger:
                    audit_logger.log_config_load(
                        config_id=config_id,
                        success=False,
                        details={"reason": "config not found in database"}
                    )

                return False

            # Re-expose the constants
            _expose_runtime_constants(updated_cfg)

            logger.info("[definitions] Configuration reloaded successfully; constants exposed")
            logger.debug(f"[definitions] Configuration details: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
                         f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")

            # Record the successful configuration load
            if audit_logger:
                selections = updated_cfg.get('selections', {})
                audit_logger.log_config_load(
                    config_id=config_id,
                    user_id=selections.get('user_id', None),
                    group_id=selections.get('group_id', None),
                    success=True,
                    details={
                        "llm_id": selections.get('llm_id'),
                        "embedding_id": selections.get('embedding_id'),
                        "chunker_strategy": selections.get('chunker_strategy')
                    }
                )

            return True
        except Exception as e:
            logger.error(f"[definitions] Exception while reloading configuration: {e}", exc_info=True)

            # Record the configuration load exception
            if audit_logger:
                audit_logger.log_config_load(
                    config_id=config_id,
                    success=False,
                    details={"error": str(e)}
                )

            return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_current_config_id() -> Optional[str]:
    """
    Get the config_id currently in use.

    Returns:
        Optional[str]: the current config_id, or None if it is not set
    """
    return SELECTIONS.get("config_id", None)
|
||||
|
||||
|
||||
def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
    """
    Ensure the latest configuration is in use (call before every write operation).

    If config_id is provided, load that configuration;
    otherwise read dbrun.json and load the latest configuration.

    Args:
        config_id: optional configuration ID (int or str; converted automatically)

    Returns:
        bool: whether the configuration was loaded successfully
    """
    import logging
    logger = logging.getLogger(__name__)

    with _config_lock:
        try:
            if config_id:
                # Use the specified config_id
                logger.debug(f"[definitions] Loading the specified configuration, config_id={config_id}")
                return reload_configuration_from_database(config_id)
            else:
                # Reload the configuration from the database
                logger.debug("[definitions] Reloading the latest configuration from the database")
                memory_config = _load_from_database()

                if not memory_config or not memory_config.get('selections'):
                    logger.warning("[definitions] Could not load configuration from the database; keeping the current configuration")
                    return False

                # Fixed: the original called _expose_memory_constants, which is not defined in this module
                _expose_runtime_constants(memory_config)
                return True
        except Exception as e:
            logger.error(f"[definitions] Failed to load configuration: {e}", exc_info=True)
            return False
|
||||
|
||||
|
||||
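The docstrings in definitions.py describe a priority order of database configuration, then environment variables, then runtime.json defaults. How overrides.py implements that order is not shown here, so the following is only a sketch under that assumption. `merge_selections` and its parameters are hypothetical names and do not reflect the real `load_unified_config` signature.

```python
from typing import Any, Dict, Optional


def merge_selections(
    runtime_defaults: Dict[str, Any],
    env_overrides: Optional[Dict[str, Any]] = None,
    db_overrides: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge three sources so that database > environment > runtime.json."""
    merged: Dict[str, Any] = dict(runtime_defaults)  # lowest priority first
    for layer in (env_overrides, db_overrides):  # later layers win
        if layer:
            # Skip None values so an unset override does not erase a default
            merged.update({k: v for k, v in layer.items() if v is not None})
    return merged
```

Building a new dictionary instead of updating the defaults in place means a reload can start again from the same runtime.json values each time.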
93
app/core/memory/utils/config/get_data.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.schemas.memory_storage_schema import BaseDataSchema
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
    # Keys kept in every normalized record; missing keys default to ""
    target_keys = [
        "id",
        "statement",
        "group_id",
        "chunk_id",
        "created_at",
        "expired_at",
        "valid_at",
        "invalid_at",
    ]
    results = []
    for row in data or []:
        # Pull the raw retrieve_info payload out of whichever row shape was returned
        s = None
        if isinstance(row, (tuple, list)) and row:
            s = row[0]
        elif hasattr(row, "retrieve_info"):
            s = getattr(row, "retrieve_info")
        elif isinstance(row, dict) and "retrieve_info" in row:
            s = row.get("retrieve_info")
        elif hasattr(row, "_mapping") and "retrieve_info" in getattr(row, "_mapping"):
            s = row._mapping["retrieve_info"]
        else:
            s = row
        if s is None:
            continue
        if isinstance(s, bytes):
            try:
                s = s.decode("utf-8")
            except Exception:
                try:
                    s = s.decode()
                except Exception:
                    continue
        s = str(s).strip()
        if not s or s == "[]":
            continue
        try:
            parsed = json.loads(s)
        except Exception:
            continue
        items = parsed if isinstance(parsed, list) else [parsed]
        for item in items:
            # Skip JSON values that are not objects; they have no keys to normalize
            if not isinstance(item, dict):
                continue
            # Accept the plural key "statements" as an alias for "statement"
            if "statement" not in item and "statements" in item:
                item["statement"] = item.get("statements") or ""
            normalized = {k: item.get(k, "") for k in target_keys}
            results.append(normalized)
    return results
|
||||
|
||||
|
||||
async def get_data(host_id: uuid.UUID) -> List[Dict]:
    """
    Fetch data from the database.
    """
    # Obtain a Session from the get_db generator
    db: Session = next(get_db())
    try:
        data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()

        # print(f"data:\n{data}")
        # Parse the rows into a list of dictionaries
        results = await _load_(data)
        return results
    except Exception as e:
        logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
        raise
    finally:
        try:
            db.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import asyncio

    # Fetch data from the database
    host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
    data = asyncio.run(get_data(host_id))
    print(type(data))
    print(data)
|
||||
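The normalization step in get_data.py parses each stored payload as JSON and projects every object onto a fixed set of keys. A minimal standalone sketch of that idea follows. `normalize_payload` and the shortened `TARGET_KEYS` list are hypothetical names introduced for illustration, and the input is assumed to be a JSON string.

```python
import json
from typing import Any, Dict, List

TARGET_KEYS = ["id", "statement", "group_id", "chunk_id"]  # shortened for the example


def normalize_payload(raw: str) -> List[Dict[str, Any]]:
    """Parse one JSON payload and project each object onto TARGET_KEYS."""
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return []  # unparseable payloads are skipped
    items = parsed if isinstance(parsed, list) else [parsed]
    rows = []
    for item in items:
        if not isinstance(item, dict):
            continue  # only JSON objects carry the keys of interest
        # Accept the plural key "statements" as an alias for "statement"
        if "statement" not in item and "statements" in item:
            item["statement"] = item.get("statements") or ""
        rows.append({key: item.get(key, "") for key in TARGET_KEYS})
    return rows
```

Filling missing keys with an empty string gives every record the same shape, so downstream code can index any target key without checking for its presence.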
90
app/core/memory/utils/config/get_example_data.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import random
|
||||
import string
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# Generate a random string of letters (upper and lower case) and digits
|
||||
def generate_random_string(length=16):
|
||||
characters = string.ascii_letters + string.digits
|
||||
return ''.join(random.choice(characters) for _ in range(length))
|
||||
|
||||
def get_example_data() -> List[Dict[str, Optional[str]]]:
    """
    Read data from the statement extraction log. Each entry looks like:
    Content: At Apple's China headquarters, the user and Li Hua ran into John Smith, a technical expert from the United States.
    Created At: 2025-11-28 19:28:38.256421
    Expired At: None
    Valid At: None
    Invalid At: None
    The data is assembled into the following form:
    [
        {
            "id":id,
            "group_id":group_id,
            "statement": Content,
            "created_at": Created At,
            "expired_at": Expired At,
            "valid_at": Valid At,
            "invalid_at": Invalid At,
            "chunk_id": "86da9022710c40eaa5f518a294c398d2",
            "entity_ids": []
        },
        ...
    ]
    """
    # Path of the log file
    log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt")

    # Return early if the file does not exist
    if not os.path.exists(log_file_path):
        return []

    # Read the log file
    with open(log_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Parse the data
    results = []

    # Split the content into one block per Statement
    statement_blocks = re.split(r"Statement \d+:", content)

    for block in statement_blocks[1:]:  # skip the first (empty) block
        # Extract each field
        # The lookbehinds stop "Id:" from matching inside "Group Id:" or "Chunk Id:"
        id_match = re.search(r"(?<!Group )(?<!Chunk )Id:\s*(.+?)(?=\n)", block)
        group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block)
        statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block)
        created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block)
        expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block)
        valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block)
        invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block)
        chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block)

        # Build the dictionary
        if statement_match:
            statement_data = {
                "id": id_match.group(1).strip() if id_match else generate_random_string(),
                "group_id": group_id_match.group(1).strip() if group_id_match else "group_example",
                "statement": statement_match.group(1).strip(),
                "created_at": created_at_match.group(1).strip() if created_at_match else None,
                "expired_at": expired_at_match.group(1).strip() if expired_at_match else None,
                "valid_at": valid_at_match.group(1).strip() if valid_at_match else None,
                "invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None,
                "chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example",
                "entity_ids": []
            }

            # Convert the string "None" to None
            for key in ["created_at", "expired_at", "valid_at", "invalid_at"]:
                if statement_data[key] == "None":
                    statement_data[key] = None

            results.append(statement_data)

    return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
    print(f"Retrieved data:\n {get_example_data()}")
|
||||
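The parsing approach in get_example_data.py (split the log on "Statement N:" headers, then pull labelled fields out of each block with regular expressions) can be tried on an in-memory sample. The sketch below is illustrative only: `SAMPLE_LOG` is an assumed log layout, since the real file format is not shown, and `parse_statements` is a hypothetical helper.

```python
import re
from typing import Dict, List, Optional

# Assumed layout: one "Label: value" pair per line under each Statement header.
SAMPLE_LOG = """Statement 1:
Group Id: group_a
Id: stmt-001
Content: The user met Li Hua.
Created At: 2025-11-28 19:28:38.256421
Valid At: None

Statement 2:
Content: Li Hua works at Apple.
"""


def parse_statements(text: str) -> List[Dict[str, Optional[str]]]:
    rows = []
    for block in re.split(r"Statement \d+:", text)[1:]:  # drop text before the first header
        def field(label_pattern: str) -> Optional[str]:
            match = re.search(label_pattern + r":\s*(.+?)(?=\n)", block)
            value = match.group(1).strip() if match else None
            return None if value == "None" else value  # the log writes missing values as "None"

        content = field("Content")
        if content is None:
            continue  # a block without Content is not a statement
        rows.append({
            # The lookbehinds keep "Id:" from matching inside "Group Id:" or "Chunk Id:"
            "id": field(r"(?<!Group )(?<!Chunk )Id"),
            "group_id": field("Group Id"),
            "statement": content,
            "created_at": field("Created At"),
            "valid_at": field("Valid At"),
        })
    return rows
```

In the sample, "Group Id:" appears before "Id:", which is the case where an unguarded `Id:` pattern would pick up the group value instead of the statement id.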
516
app/core/memory/utils/config/litellm_config.py
Normal file
@@ -0,0 +1,516 @@
|
||||
"""
|
||||
LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from typing import Dict, Any, List
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
class LiteLLMConfig:
|
||||
"""Configuration class for LiteLLM with enhanced retry and tracking capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.usage_data = []
|
||||
self.error_data = []
|
||||
self.module_stats = defaultdict(lambda: {
|
||||
'requests': 0,
|
||||
'tokens_in': 0,
|
||||
'tokens_out': 0,
|
||||
'cost': 0.0,
|
||||
'errors': 0,
|
||||
'start_time': None,
|
||||
'last_request_time': None,
|
||||
'request_timestamps': [], # Store precise timestamps
|
||||
'current_qps': 0.0,
|
||||
'max_qps': 0.0,
|
||||
'qps_history': [] # Store QPS measurements over time
|
||||
})
|
||||
self.start_time = datetime.now()
|
||||
self.global_request_timestamps = []
|
||||
self.global_max_qps = 0.0
|
||||
|
||||
# Rate limiting for AWS Bedrock (conservative limits)
|
||||
self.rate_limits = {
|
||||
'bedrock': {
|
||||
'requests_per_minute': 2, # AWS Bedrock default is very low
|
||||
'requests_per_second': 0.033, # 2/60 = 0.033 RPS
|
||||
'last_request_time': 0,
|
||||
'request_queue': Queue(),
|
||||
'lock': threading.Lock()
|
||||
}
|
||||
}
|
||||
self.rate_limiting_enabled = True
|
||||
|
||||
def setup_enhanced_config(self, max_retries: int = 3):
|
||||
"""Configure LiteLLM with retry logic and instant QPS tracking"""
|
||||
|
||||
litellm.num_retries = max_retries
|
||||
litellm.request_timeout = 300
|
||||
|
||||
litellm.retry_policy = {
|
||||
"RateLimitError": {
|
||||
"max_retries": 5,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 1,
|
||||
"max_delay": 60,
|
||||
"jitter": True
|
||||
},
|
||||
"APIConnectionError": {
|
||||
"max_retries": 3,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 2,
|
||||
"max_delay": 30,
|
||||
"jitter": True
|
||||
},
|
||||
"InternalServerError": {
|
||||
"max_retries": 2,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 5,
|
||||
"max_delay": 60,
|
||||
"jitter": True
|
||||
},
|
||||
"BadRequestError": {
|
||||
"max_retries": 1,
|
||||
"exponential_backoff": False,
|
||||
"initial_delay": 1,
|
||||
"max_delay": 5
|
||||
}
|
||||
}
|
||||
|
||||
litellm.success_callback = [self._success_callback]
|
||||
litellm.failure_callback = [self._failure_callback]
|
||||
litellm.completion_cost_tracking = True
|
||||
litellm.set_verbose = False
|
||||
litellm.modify_params = True
|
||||
|
||||
print("✅ LiteLLM configured with instant QPS tracking and rate limiting")
|
||||
|
||||
def _success_callback(self, kwargs, completion_response, start_time, end_time):
|
||||
"""Callback for successful requests with module-specific QPS tracking"""
|
||||
try:
|
||||
# Extract usage information
|
||||
usage = completion_response.get('usage', {})
|
||||
model = kwargs.get('model', 'unknown')
|
||||
|
||||
# Extract module information from metadata or model name
|
||||
module = self._extract_module_name(kwargs, model)
|
||||
|
||||
# Calculate cost
|
||||
cost = 0.0
|
||||
try:
|
||||
cost = litellm.completion_cost(completion_response)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Calculate duration
|
||||
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
|
||||
|
||||
# Record usage data
|
||||
usage_record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"module": module,
|
||||
"input_tokens": usage.get('prompt_tokens', 0),
|
||||
"output_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0),
|
||||
"cost": cost,
|
||||
"duration_seconds": duration_seconds,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
self.usage_data.append(usage_record)
|
||||
|
||||
# Update module-specific stats for QPS tracking
|
||||
self._update_module_stats(module, usage_record, success=True)
|
||||
|
||||
# Print real-time feedback
|
||||
print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Success callback failed: {e}")
|
||||
|
||||
def _failure_callback(self, kwargs, completion_response, start_time, end_time):
|
||||
"""Callback for failed requests with module-specific error tracking"""
|
||||
try:
|
||||
model = kwargs.get('model', 'unknown')
|
||||
module = self._extract_module_name(kwargs, model)
|
||||
|
||||
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
|
||||
|
||||
# Handle different error response formats
|
||||
error_message = "Unknown error"
|
||||
error_type = "UnknownError"
|
||||
|
||||
# According to LiteLLM docs, completion_response contains the exception for failures
|
||||
if completion_response is not None:
|
||||
error_message = str(completion_response)
|
||||
error_type = type(completion_response).__name__
|
||||
|
||||
# Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events)
|
||||
elif 'exception' in kwargs:
|
||||
exception = kwargs['exception']
|
||||
error_message = str(exception)
|
||||
error_type = type(exception).__name__
|
||||
|
||||
# Check for other error formats in kwargs
|
||||
elif 'error' in kwargs:
|
||||
error = kwargs['error']
|
||||
error_message = str(error)
|
||||
error_type = type(error).__name__
|
||||
|
||||
# Check log_event_type to confirm this is a failure event
|
||||
log_event_type = kwargs.get('log_event_type', '')
|
||||
if log_event_type == 'failed_api_call' and 'exception' in kwargs:
|
||||
exception = kwargs['exception']
|
||||
error_message = str(exception)
|
||||
error_type = type(exception).__name__
|
||||
|
||||
error_record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"module": module,
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
"duration_seconds": duration_seconds,
|
||||
"status": "failed"
|
||||
}
|
||||
|
||||
self.error_data.append(error_record)
|
||||
|
||||
# Update module-specific stats for error tracking
|
||||
self._update_module_stats(module, error_record, success=False)
|
||||
|
||||
# Print error feedback
|
||||
print(f"✗ {model}: {error_type} - {error_message[:100]}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failure callback failed: {e}")
|
||||
# Debug: print the actual parameters to understand the structure
|
||||
print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}")
|
||||
print(f"Debug - completion_response type: {type(completion_response)}")
|
||||
print(f"Debug - completion_response: {completion_response}")
|
||||
|
||||
def _should_rate_limit(self, model: str) -> bool:
|
||||
"""Check if the model should be rate limited"""
|
||||
if not self.rate_limiting_enabled:
|
||||
return False
|
||||
return model.startswith('bedrock/') or 'bedrock' in model.lower()
|
||||
|
||||
def _enforce_rate_limit(self, model: str):
|
||||
"""Enforce rate limiting for AWS Bedrock models"""
|
||||
if not self._should_rate_limit(model):
|
||||
return
|
||||
|
||||
provider = 'bedrock'
|
||||
if provider not in self.rate_limits:
|
||||
return
|
||||
|
||||
rate_config = self.rate_limits[provider]
|
||||
|
||||
with rate_config['lock']:
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - rate_config['last_request_time']
|
||||
min_interval = 1.0 / rate_config['requests_per_second']
|
||||
|
||||
if time_since_last < min_interval:
|
||||
sleep_time = min_interval - time_since_last
|
||||
print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
rate_config['last_request_time'] = time.time()
|
||||
|
||||
    def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str:
        """Extract module name from request context"""
        # Try to get module from metadata; tolerate a missing or None value
        metadata = kwargs.get('metadata') or {}
        if 'module' in metadata:
            return metadata['module']

        # Otherwise infer the module from the model name
        model_lower = model.lower()
        if 'claude' in model_lower:
            return 'bedrock_client'
        elif 'gpt' in model_lower or 'openai' in model_lower:
            return 'openai_client'
        elif 'embed' in model_lower:
            return 'embedder'
        else:
            return 'unknown'
|
||||
|
||||
def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool):
|
||||
"""Update module-specific statistics with instant QPS tracking"""
|
||||
current_timestamp = time.time()
|
||||
current_time = datetime.now()
|
||||
|
||||
# Initialize module stats if first request
|
||||
if self.module_stats[module]['start_time'] is None:
|
||||
self.module_stats[module]['start_time'] = current_time
|
||||
|
||||
# Update counters
|
||||
self.module_stats[module]['requests'] += 1
|
||||
self.module_stats[module]['last_request_time'] = current_time
|
||||
self.module_stats[module]['request_timestamps'].append(current_timestamp)
|
||||
self.global_request_timestamps.append(current_timestamp)
|
||||
|
||||
# Calculate instant QPS for this module
|
||||
self._calculate_instant_qps(module, current_timestamp)
|
||||
|
||||
# Calculate global instant QPS
|
||||
self._calculate_global_instant_qps(current_timestamp)
|
||||
|
||||
if success:
|
||||
self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0)
|
||||
self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0)
|
||||
self.module_stats[module]['cost'] += record.get('cost', 0.0)
|
||||
else:
|
||||
self.module_stats[module]['errors'] += 1
|
||||
|
||||
def _calculate_instant_qps(self, module: str, current_timestamp: float):
|
||||
"""Calculate instant QPS for a specific module using sliding window"""
|
||||
# Keep only timestamps from last 1 second for instant QPS
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
timestamps = self.module_stats[module]['request_timestamps']
|
||||
|
||||
# Remove old timestamps
|
||||
self.module_stats[module]['request_timestamps'] = [
|
||||
ts for ts in timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current QPS (requests in last second)
|
||||
current_qps = len(self.module_stats[module]['request_timestamps'])
|
||||
self.module_stats[module]['current_qps'] = current_qps
|
||||
|
||||
# Update max QPS if current is higher
|
||||
if current_qps > self.module_stats[module]['max_qps']:
|
||||
self.module_stats[module]['max_qps'] = current_qps
|
||||
|
||||
# Store QPS history (keep last 60 measurements)
|
||||
self.module_stats[module]['qps_history'].append(current_qps)
|
||||
if len(self.module_stats[module]['qps_history']) > 60:
|
||||
self.module_stats[module]['qps_history'].pop(0)
|
||||
|
||||
def _calculate_global_instant_qps(self, current_timestamp: float):
|
||||
"""Calculate global instant QPS across all modules"""
|
||||
# Keep only timestamps from last 1 second
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
self.global_request_timestamps = [
|
||||
ts for ts in self.global_request_timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current global QPS
|
||||
current_global_qps = len(self.global_request_timestamps)
|
||||
|
||||
# Update max global QPS
|
||||
if current_global_qps > self.global_max_qps:
|
||||
self.global_max_qps = current_global_qps
|
||||
|
||||
def get_instant_qps(self, module: str = None) -> Dict[str, Any]:
|
||||
"""Get instant QPS data for modules"""
|
||||
if module:
|
||||
if module in self.module_stats:
|
||||
return {
|
||||
'module': module,
|
||||
'current_qps': self.module_stats[module]['current_qps'],
|
||||
'max_qps': self.module_stats[module]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0
|
||||
}
|
||||
else:
|
||||
return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0}
|
||||
else:
|
||||
# Return data for all modules plus global
|
||||
result = {
|
||||
'global': {
|
||||
'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]),
|
||||
'max_qps': self.global_max_qps
|
||||
},
|
||||
'modules': {}
|
||||
}
|
||||
|
||||
for mod in self.module_stats.keys():
|
||||
result['modules'][mod] = {
|
||||
'current_qps': self.module_stats[mod]['current_qps'],
|
||||
'max_qps': self.module_stats[mod]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_usage_summary(self) -> Dict[str, Any]:
|
||||
"""Get essential usage statistics"""
|
||||
if not self.usage_data:
|
||||
return {
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
"error_rate": 0.0,
|
||||
"message": "No usage data available"
|
||||
}
|
||||
|
||||
total_requests = len(self.usage_data)
|
||||
total_errors = len(self.error_data)
|
||||
total_cost = sum(record['cost'] for record in self.usage_data)
|
||||
total_input_tokens = sum(record['input_tokens'] for record in self.usage_data)
|
||||
total_output_tokens = sum(record['output_tokens'] for record in self.usage_data)
|
||||
|
||||
# Calculate session duration
|
||||
duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60
|
||||
|
||||
# Build module statistics
|
||||
module_stats = {}
|
||||
for module, stats in self.module_stats.items():
|
||||
if stats['requests'] > 0:
|
||||
module_stats[module] = {
|
||||
"requests": stats['requests'],
|
||||
"errors": stats['errors'],
|
||||
"success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0,
|
||||
"tokens_in": stats['tokens_in'],
|
||||
"tokens_out": stats['tokens_out'],
|
||||
"cost": stats['cost'],
|
||||
"current_qps": stats['current_qps'],
|
||||
"max_qps": stats['max_qps']
|
||||
}
|
||||
|
||||
return {
|
||||
"session_duration_minutes": duration_minutes,
|
||||
"total_requests": total_requests,
|
||||
"total_errors": total_errors,
|
||||
"error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_cost": total_cost,
|
||||
"module_stats": module_stats,
|
||||
"global_max_qps": self.global_max_qps
|
||||
}
|
||||
|
||||
def print_usage_summary(self):
|
||||
"""Print essential usage summary"""
|
||||
stats = self.get_usage_summary()
|
||||
|
||||
if stats.get('message'):
|
||||
print(f"📊 {stats['message']}")
|
||||
return
|
||||
|
||||
print(f"\n📊 USAGE SUMMARY")
|
||||
print(f"{'='*50}")
|
||||
print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
|
||||
print(f"📈 Requests: {stats['total_requests']}")
|
||||
print(f"❌ Errors: {stats['total_errors']}")
|
||||
print(f"💰 Cost: ${stats['total_cost']:.4f}")
|
||||
print(f"🏆 Global Max QPS: {stats['global_max_qps']}")
|
||||
|
||||
# Module statistics
|
||||
if stats.get('module_stats'):
|
||||
print(f"\n📦 MODULES:")
|
||||
for module, mod_stats in stats['module_stats'].items():
|
||||
print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")
|
||||
|
||||
print(f"{'='*50}")
|
||||
|
||||
def save_usage_data(self, filename: str = "litellm_usage.json"):
|
||||
"""Save usage data to JSON file"""
|
||||
data = {
|
||||
"summary": self.get_usage_summary(),
|
||||
"detailed_usage": self.usage_data,
|
||||
"errors": self.error_data,
|
||||
"export_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
print(f"📁 Usage data saved to {filename}")
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking data"""
|
||||
self.usage_data = []
|
||||
self.error_data = []
|
||||
self.module_stats = defaultdict(lambda: {
|
||||
'requests': 0,
|
||||
'tokens_in': 0,
|
||||
'tokens_out': 0,
|
||||
'cost': 0.0,
|
||||
'errors': 0,
|
||||
'start_time': None,
|
||||
'last_request_time': None,
|
||||
'request_timestamps': [],
|
||||
'current_qps': 0.0,
|
||||
'max_qps': 0.0,
|
||||
'qps_history': []
|
||||
})
|
||||
self.global_request_timestamps = []
|
||||
self.global_max_qps = 0.0
|
||||
self.start_time = datetime.now()
|
||||
print("🔄 All tracking data reset")
|
||||
|
||||
# Global instance for easy access
|
||||
litellm_config = LiteLLMConfig()
|
||||
|
||||
def setup_litellm_enhanced(max_retries: int = 3):
|
||||
"""
|
||||
Quick setup function for LiteLLM enhanced configuration
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retries for failed requests
|
||||
"""
|
||||
litellm_config.setup_enhanced_config(max_retries)
|
||||
return litellm_config
|
||||
|
||||
def get_usage_summary():
|
||||
"""Get current usage summary"""
|
||||
return litellm_config.get_usage_summary()
|
||||
|
||||
def print_usage_summary():
|
||||
"""Print current usage summary"""
|
||||
litellm_config.print_usage_summary()
|
||||
|
||||
def save_usage_data(filename: str = "litellm_usage.json"):
|
||||
"""Save usage data to file"""
|
||||
litellm_config.save_usage_data(filename)
|
||||
|
||||
def get_instant_qps(module: str = None) -> Dict[str, Any]:
|
||||
"""Get instant QPS data for modules"""
|
||||
return litellm_config.get_instant_qps(module)
|
||||
|
||||
def print_instant_qps(module: str = None):
|
||||
"""Print instant QPS information"""
|
||||
qps_data = get_instant_qps(module)
|
||||
|
||||
print(f"\n⚡ INSTANT QPS MONITOR")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if module:
|
||||
print(f"Module: {qps_data['module']}")
|
||||
print(f" Current QPS: {qps_data['current_qps']}")
|
||||
print(f" Max QPS: {qps_data['max_qps']}")
|
||||
print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}")
|
||||
else:
|
||||
# Global stats
|
||||
global_data = qps_data.get('global', {})
|
||||
print(f"🌍 GLOBAL:")
|
||||
print(f" Current QPS: {global_data.get('current_qps', 0)}")
|
||||
print(f" Max QPS: {global_data.get('max_qps', 0)}")
|
||||
|
||||
# Module stats
|
||||
modules = qps_data.get('modules', {})
|
||||
if modules:
|
||||
print(f"\n📦 MODULES:")
|
||||
for mod, data in modules.items():
|
||||
print(f" {mod}:")
|
||||
print(f" Current: {data['current_qps']} QPS")
|
||||
print(f" Max: {data['max_qps']} QPS")
|
||||
print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS")
|
||||
|
||||
print(f"{'='*60}")
|
||||
|
||||
def reset_tracking():
|
||||
"""Reset all tracking data"""
|
||||
litellm_config.reset_tracking()
|
||||
|
||||
def get_module_stats() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get detailed module statistics"""
|
||||
summary = get_usage_summary()
|
||||
return summary.get('module_stats', {})
|
||||
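The instant-QPS bookkeeping in `litellm_config.py` keeps only the timestamps from the trailing second and treats their count as the current QPS. A minimal, standalone sketch of that sliding-window idea follows. The class name `SlidingWindowQps` and its `record` method are hypothetical and not part of the module, and the sketch assumes timestamps arrive in non-decreasing order so that a `deque` can drop old entries from the left instead of rebuilding the list on every request.

```python
import time
from collections import deque
from typing import Deque, Optional


class SlidingWindowQps:
    """Hypothetical helper: count events in the trailing window."""

    def __init__(self, window_seconds: float = 1.0) -> None:
        self.window_seconds = window_seconds
        self.timestamps: Deque[float] = deque()
        self.max_qps = 0

    def record(self, now: Optional[float] = None) -> int:
        """Register one request and return the count inside the window."""
        now = time.time() if now is None else now
        self.timestamps.append(now)

        # Drop timestamps strictly older than the cutoff (ts >= cutoff is kept,
        # matching the comparison used in _calculate_instant_qps).
        cutoff = now - self.window_seconds
        while self.timestamps and self.timestamps[0] < cutoff:
            self.timestamps.popleft()

        current = len(self.timestamps)
        self.max_qps = max(self.max_qps, current)
        return current


tracker = SlidingWindowQps()
for ts in (100.0, 100.5, 101.0, 103.0):
    print(ts, tracker.record(now=ts))
```

Passing `now` explicitly keeps the sketch deterministic; in live use the argument would be omitted so that `time.time()` is used.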
611
app/core/memory/utils/config/overrides.py
Normal file
@@ -0,0 +1,611 @@
|
||||
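"""
Runtime configuration override utilities: a unified configuration loader.

This module is the unified configuration loader. It loads configuration from
several sources and applies overrides by priority.

Configuration source priority (highest to lowest):
1. Database configuration (PostgreSQL data_config table)
2. Environment-variable configuration (.env file)
3. Default configuration (runtime.json file)

Supported loading modes:
- By config_id (read from dbrun.json or passed in by the frontend)
- By group_id (read from dbrun.json)
- Environment-variable overrides (INTERNAL/EXTERNAL network modes)

Main features:
- Read configuration from the PostgreSQL database
- Read configuration from environment variables
- Read default configuration from runtime.json
- Override configuration items by priority (in memory only; files are not modified)
- Multiple configuration sections: selections, statement_extraction, deduplication, forgetting_engine, pruning, reflexion

Use cases:
- Automatic configuration loading at application startup
- Dynamic reloading when the frontend switches configuration
- Configuration isolation in multi-tenant deployments
- Automatic switching between internal and external network environments
"""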
|
||||
import os
|
||||
import json
|
||||
import socket
|
||||
from typing import Optional, Dict, Any, Literal
|
||||
|
||||
NetworkMode = Literal['internal', 'external']
|
||||
|
||||
|
||||
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster):
    """Safely set a value on the target dict (only if the source value exists and is not None).

    Args:
        target: Target dictionary.
        target_key: Key in the target dictionary.
        src: Source dictionary.
        src_key: Key in the source dictionary.
        caster: Type-conversion function applied to the source value.
    """
    try:
        value = src.get(src_key)
        if value is not None:
            target[target_key] = caster(value)
    except Exception:
        # Best effort: a failed lookup or cast leaves the target unchanged.
        pass
|
||||
|
||||
|
||||
def _to_bool(val: Any) -> bool:
    """Convert a value of various types to a boolean.

    Supported inputs:
    - bool: returned as-is
    - int/float: non-zero is True
    - str: "true", "1", "on", "yes" are True; "false", "0", "off", "no" are False
      (case-insensitive, surrounding whitespace ignored)

    Any other value, including an unrecognised string, falls back to bool(val).

    Args:
        val: The value to convert.

    Returns:
        bool: The converted boolean value.
    """
    try:
        if isinstance(val, bool):
            return val
        if isinstance(val, (int, float)):
            return bool(val)
        if isinstance(val, str):
            m = val.strip().lower()
            if m in {"true", "1", "on", "yes"}:
                return True
            if m in {"false", "0", "off", "no"}:
                return False
        return bool(val)
    except Exception:
        return False
|
||||
|
||||
|
||||
def _make_pgsql_conn() -> Optional[object]:
    """Create a PostgreSQL database connection.

    Connection parameters come from environment variables:
    - DB_HOST: database host (default localhost)
    - DB_PORT: database port (default 5432)
    - DB_USER: database user name
    - DB_PASSWORD: database password
    - DB_NAME: database name

    Returns:
        Optional[object]: The connection object, or None on failure.
    """
    host = os.getenv("DB_HOST", "localhost")
    user = os.getenv("DB_USER")
    password = os.getenv("DB_PASSWORD")
    dbname = os.getenv("DB_NAME")
    port_str = os.getenv("DB_PORT")

    try:
        # Imported lazily so the module still loads when psycopg2 is not installed
        import psycopg2  # type: ignore

        port = int(port_str) if port_str else 5432
        conn = psycopg2.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            dbname=dbname,
        )
        conn.autocommit = True
        return conn
    except Exception:
        return None
|
||||
|
||||
|
||||
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
    """Query the database configuration for a group_id.

    When several rows match, the most recently updated one is returned.

    Args:
        group_id: Group identifier.

    Returns:
        Optional[Dict[str, Any]]: The configuration row as a dict, or None if not found.
    """
    conn = _make_pgsql_conn()
    if conn is None:
        return None

    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore
        cur = conn.cursor(cursor_factory=RealDictCursor)

        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            pass

        sql = (
            "SELECT group_id, user_id, apply_id, chunker_strategy, "
            "       enable_llm_dedup_blockwise, enable_llm_disambiguation "
            "FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
        )
        cur.execute(sql, (group_id,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        return None
    finally:
        # Release the cursor and connection even when the query failed
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
    """Query the database configuration for a config_id.

    Args:
        config_id: Configuration identifier (int or str; converted to int automatically).

    Returns:
        Optional[Dict[str, Any]]: The configuration row as a dict, or None if not found.
    """
    conn = _make_pgsql_conn()
    if conn is None:
        return None

    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore
        cur = conn.cursor(cursor_factory=RealDictCursor)

        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            pass

        # config_id is an Integer column in the database, so convert it first
        try:
            config_id_int = int(config_id)
        except (ValueError, TypeError):
            return None

        sql = (
            "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
            "       enable_llm_dedup_blockwise, enable_llm_disambiguation, "
            "       deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
            "       statement_granularity, include_dialogue_context, max_context, "
            "       \"offset\" AS offset, lambda_time, lambda_mem, "
            "       pruning_enabled, pruning_scene, pruning_threshold, "
            "       llm_id, embedding_id "
            "FROM data_config WHERE config_id = %s LIMIT 1"
        )
        cur.execute(sql, (config_id_int,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        return None
    finally:
        # Release the cursor and connection even when the query failed
        if cur is not None:
            try:
                cur.close()
            except Exception:
                pass
        try:
            conn.close()
        except Exception:
            pass
|
||||
|
||||
|
||||
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
    """Read group_id from dbrun.json.

    The key is looked up at the top level first, then under "selections".

    Args:
        project_root: Path to the project root directory.

    Returns:
        Optional[str]: The group_id, or None if not found.
    """
    try:
        path = os.path.join(project_root, "dbrun.json")
        if not os.path.isfile(path):
            return None

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, dict):
            if "group_id" in data:
                return str(data.get("group_id"))
            sel = data.get("selections", {})
            if isinstance(sel, dict) and "group_id" in sel:
                return str(sel.get("group_id"))

        return None
    except Exception:
        return None
|
||||
|
||||
|
||||
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
    """Read config_id from dbrun.json.

    The key is looked up at the top level first, then under "selections".

    Args:
        project_root: Path to the project root directory.

    Returns:
        Optional[str]: The config_id, or None if not found.
    """
    try:
        path = os.path.join(project_root, "dbrun.json")
        if not os.path.isfile(path):
            return None

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, dict):
            if "config_id" in data:
                return str(data.get("config_id"))
            sel = data.get("selections", {})
            if isinstance(sel, dict) and "config_id" in sel:
                return str(sel.get("config_id"))

        return None
    except Exception:
        return None
|
||||
|
||||
|
||||
def _apply_overrides_from_db_row(
    runtime_cfg: Dict[str, Any],
    db_row: Optional[Dict[str, Any]],
    identifier: str,
    identifier_type: str = "config_id"
) -> Dict[str, Any]:
    """Override the runtime configuration from a database row (shared handler).

    Args:
        runtime_cfg: Runtime configuration dictionary (modified in place).
        db_row: Row returned by the database query.
        identifier: Identifier value (group_id or config_id).
        identifier_type: Identifier type ("group_id" or "config_id").

    Returns:
        Dict[str, Any]: The runtime configuration after overrides.
    """
    try:
        selections = runtime_cfg.setdefault("selections", {})
        selections[identifier_type] = identifier

        if not db_row:
            return runtime_cfg

        # Override selections fields
        for tk in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
                   "t_type_strict", "t_name_strict", "t_overall",
                   "statement_granularity", "include_dialogue_context"):
            _set_if_present(selections, tk, db_row, tk, str)

        # UUID fields are stored as strings; str() on a UUID object yields the
        # standard hyphenated form
        for uuid_field in ("llm_id", "embedding_id"):
            _set_if_present(selections, uuid_field, db_row, uuid_field, str)

        # Override statement_extraction fields
        stmt = runtime_cfg.setdefault("statement_extraction", {})
        _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
        _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
        _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)

        # Override deduplication fields
        dedup = runtime_cfg.setdefault("deduplication", {})
        for tk in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation"):
            _set_if_present(dedup, tk, db_row, tk, _to_bool)
        _set_if_present(dedup, "deep_retrieval", db_row, "deep_retrieval", _to_bool)

        # Override forgetting_engine fields
        forgetting = runtime_cfg.setdefault("forgetting_engine", {})
        _set_if_present(forgetting, "offset", db_row, "offset", float)
        _set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float)
        _set_if_present(forgetting, "lambda_mem", db_row, "lambda_mem", float)

        # Override pruning fields
        pruning = runtime_cfg.setdefault("pruning", {})
        _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
        _set_if_present(pruning, "scene", db_row, "pruning_scene", str)

        # The threshold is converted to float and clamped to [0.0, 0.9]
        try:
            if "pruning_threshold" in db_row and db_row.get("pruning_threshold") is not None:
                thr = float(db_row.get("pruning_threshold"))
                thr = max(0.0, min(0.9, thr))
                pruning["threshold"] = thr
        except Exception:
            pass

        return runtime_cfg
    except Exception:
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Override the runtime configuration from the database, keyed by group_id.

    Workflow:
    1. Read group_id from dbrun.json
    2. Query the database configuration for that group_id
    3. Override the runtime configuration (in memory only)

    Args:
        project_root: Path to the project root directory.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        Dict[str, Any]: The runtime configuration after overrides.
    """
    try:
        selected_gid = _load_dbrun_group_id(project_root)
        if not selected_gid:
            return runtime_cfg

        db_row = _fetch_db_config_by_group_id(selected_gid)
        if not db_row:
            # Still record the group_id when the database has no matching row
            runtime_cfg.setdefault("selections", {})["group_id"] = selected_gid
            return runtime_cfg

        return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_gid, "group_id")
    except Exception:
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Override the runtime configuration from the database, keyed by the config_id in dbrun.json.

    Workflow:
    1. Read config_id from dbrun.json
    2. Query the database configuration for that config_id
    3. Override the runtime configuration (in memory only)

    Args:
        project_root: Path to the project root directory.
        runtime_cfg: Runtime configuration dictionary.

    Returns:
        Dict[str, Any]: The runtime configuration after overrides.
    """
    try:
        selected_cid = _load_dbrun_config_id(project_root)
        if not selected_cid:
            return runtime_cfg

        db_row = _fetch_db_config_by_config_id(selected_cid)
        return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
    except Exception:
        return runtime_cfg
|
||||
|
||||
|
||||
def apply_runtime_overrides_with_config_id(
    project_root: str,
    runtime_cfg: Dict[str, Any],
    config_id: str
) -> tuple[Dict[str, Any], bool]:
    """Override the runtime configuration from the database using an explicit config_id (dbrun.json is not read).

    Intended for the case where the frontend switches configuration dynamically.

    Args:
        project_root: Path to the project root directory.
        runtime_cfg: Runtime configuration dictionary.
        config_id: Configuration identifier.

    Returns:
        tuple[Dict[str, Any], bool]: (runtime configuration after overrides, whether it was loaded from the database)
    """
    try:
        selected_cid = str(config_id).strip()
        if not selected_cid:
            return runtime_cfg, False

        db_row = _fetch_db_config_by_config_id(selected_cid)
        if db_row is None:
            return runtime_cfg, False

        updated_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
        return updated_cfg, True
    except Exception:
        return runtime_cfg, False
|
||||
|
||||
|
||||
# ============================================================================
# The functions below are commented out: automatic network-mode detection is no longer needed
# ============================================================================

# def get_server_ip() -> str:
#     """
#     Get the IP address of the current server
#
#     Returns:
#         The server IP address as a string
#     """
#     try:
#         # Option 1: read from the environment variable (preferred)
#         server_ip = os.getenv('SERVER_IP')
#         if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
#             return server_ip
#
#         # Option 2: resolve via socket
#         hostname = socket.gethostname()
#         ip_address = socket.gethostbyname(hostname)
#
#         # If this is a loopback address, try to find the real IP
#         if ip_address.startswith('127.'):
#             # Connect to an external address to discover the local IP
#             s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
#             try:
#                 s.connect(('8.8.8.8', 80))
#                 ip_address = s.getsockname()[0]
#             finally:
#                 s.close()
#
#         return ip_address
#     except Exception as e:
#         print(f"[overrides] Failed to get server IP: {e}; falling back to 127.0.0.1")
#         return '127.0.0.1'


# def auto_detect_network_mode() -> NetworkMode:
#     """
#     Auto-detect the network mode (based on the server IP)
#
#     Rules:
#     - If the server IP is in the internal IP list -> internal
#     - Any other IP -> external
#
#     The internal IP list can be customised through the INTERNAL_SERVER_IPS
#     environment variable (comma-separated)
#
#     Returns:
#         'internal' or 'external'
#     """
#     server_ip = get_server_ip()
#
#     # Read the internal IP list from the environment (multiple IPs, comma-separated)
#     internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
#     internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
#
#     # Check whether the current IP is in the internal IP list
#     if server_ip in internal_ips:
#         print(f"[overrides] Auto-detect: server IP {server_ip} is internal, using INTERNAL configuration")
#         return 'internal'
#     else:
#         print(f"[overrides] Auto-detect: server IP {server_ip} is external, using EXTERNAL configuration")
#         return 'external'


# ============================================================================
# Environment-variable overrides are deprecated and no longer used
# ============================================================================
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
#     """
#     Override configuration from environment variables (deprecated)
#     """
#     return runtime_cfg
|
||||
|
||||
|
||||
def load_unified_config(
    project_root: str,
    config_id: Optional[int | str] = None,
    group_id: Optional[str] = None,
    network_mode: Optional[NetworkMode] = None,
    env_override_models: bool = True
) -> Dict[str, Any]:
    """
    Unified configuration loader: loads configuration by priority.

    Configuration priority:
    1. PostgreSQL database configuration (highest priority). The row is selected by,
       in order: the config_id argument, the group_id argument, the config_id in
       dbrun.json, and finally the group_id in dbrun.json.
    2. runtime.json default configuration (lowest priority)

    Args:
        project_root: Path to the project root directory.
        config_id: Configuration ID (int or str, optional; when omitted, dbrun.json is consulted).
        group_id: Group ID (optional).
        network_mode: Deprecated; kept only for backward compatibility.
        env_override_models: Deprecated; kept only for backward compatibility.

    Returns:
        Dict[str, Any]: The final runtime configuration.
    """
    try:
        # Step 1: load runtime.json as the base configuration
        runtime_config_path = os.path.join(project_root, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                runtime_cfg = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            runtime_cfg = {"selections": {}}

        # Step 2: apply the database configuration (highest priority)
        if config_id:
            # Prefer the config_id passed by the caller
            db_row = _fetch_db_config_by_config_id(config_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, config_id, "config_id")
        elif group_id:
            # Otherwise use the group_id passed by the caller
            db_row = _fetch_db_config_by_group_id(group_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, group_id, "group_id")
        else:
            # Fall back to the identifiers stored in dbrun.json
            dbrun_config_id = _load_dbrun_config_id(project_root)
            if dbrun_config_id:
                db_row = _fetch_db_config_by_config_id(dbrun_config_id)
                if db_row:
                    runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_config_id, "config_id")
            else:
                dbrun_group_id = _load_dbrun_group_id(project_root)
                if dbrun_group_id:
                    db_row = _fetch_db_config_by_group_id(dbrun_group_id)
                    if db_row:
                        runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_group_id, "group_id")
        return runtime_cfg

    except Exception:
        return {"selections": {}}
|
||||
|
||||
|
||||
# Backward-compatible alias
|
||||
apply_runtime_overrides = apply_runtime_overrides_by_config
|
||||
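The override step in `overrides.py` copies selected columns from a database row onto the defaults, casting each value and skipping anything missing or uncastable. The sketch below illustrates that pattern in isolation. The names `set_if_present` and `apply_row` are hypothetical stand-ins, not the module's own functions. Unlike `_apply_overrides_from_db_row`, which updates `runtime_cfg` in place, this version copies the defaults first so the effect is easy to inspect, and it uses plain `bool` where the module uses `_to_bool`.

```python
from typing import Any, Callable, Dict


def set_if_present(target: Dict[str, Any], target_key: str,
                   src: Dict[str, Any], src_key: str,
                   caster: Callable[[Any], Any]) -> None:
    """Copy src[src_key] into target[target_key] if present and castable."""
    value = src.get(src_key)
    if value is None:
        return
    try:
        target[target_key] = caster(value)
    except (TypeError, ValueError):
        # An uncastable value leaves the existing default untouched.
        pass


def apply_row(defaults: Dict[str, Dict[str, Any]],
              row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Return a copy of the defaults with values from a database row layered on top."""
    cfg = {section: dict(values) for section, values in defaults.items()}

    pruning = cfg.setdefault("pruning", {})
    set_if_present(pruning, "enabled", row, "pruning_enabled", bool)
    threshold = row.get("pruning_threshold")
    if threshold is not None:
        # Clamp to [0.0, 0.9], the same bounds the module applies.
        pruning["threshold"] = max(0.0, min(0.9, float(threshold)))

    forgetting = cfg.setdefault("forgetting_engine", {})
    set_if_present(forgetting, "lambda_time", row, "lambda_time", float)
    return cfg


defaults = {
    "pruning": {"enabled": False, "threshold": 0.5},
    "forgetting_engine": {"lambda_time": 0.1},
}
row = {"pruning_enabled": True, "pruning_threshold": 1.7, "lambda_time": "bad"}
print(apply_row(defaults, row))
```

In this run the out-of-range threshold is clamped to 0.9, and the unparseable `lambda_time` is ignored so the default survives.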
43
app/core/memory/utils/data/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
Data processing module.

Contains all data-processing utilities, including text processing, time processing, and ontology definitions.
"""
|
||||
|
||||
# Re-export commonly used functions and classes from submodules for backward compatibility
|
||||
from .text_utils import (
|
||||
escape_lucene_query,
|
||||
extract_plain_query,
|
||||
)
|
||||
from .time_utils import (
|
||||
validate_date_format,
|
||||
normalize_date,
|
||||
normalize_date_safe,
|
||||
preprocess_date_string,
|
||||
)
|
||||
from .ontology import (
|
||||
PREDICATE_DEFINITIONS,
|
||||
LABEL_DEFINITIONS,
|
||||
Predicate,
|
||||
StatementType,
|
||||
TemporalInfo,
|
||||
RelevenceInfo,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# text_utils
|
||||
"escape_lucene_query",
|
||||
"extract_plain_query",
|
||||
# time_utils
|
||||
"validate_date_format",
|
||||
"normalize_date",
|
||||
"normalize_date_safe",
|
||||
"preprocess_date_string",
|
||||
# ontology
|
||||
"PREDICATE_DEFINITIONS",
|
||||
"LABEL_DEFINITIONS",
|
||||
"Predicate",
|
||||
"StatementType",
|
||||
"TemporalInfo",
|
||||
"RelevenceInfo",
|
||||
]
|
||||
199
app/core/memory/utils/data/ontology.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
# Use jinja template.render
|
||||
PREDICATE_DEFINITIONS = {
|
||||
"IS_A": "Denotes a class-or-type relationship between two entities (e.g., 'Model Y IS_A electric-SUV'). Includes 'is' and 'was'.",
|
||||
"HAS_A": "Denotes a part-whole relationship between two entities (e.g., 'Model Y HAS_A electric-engine'). Includes 'has' and 'had'.",
|
||||
"LOCATED_IN": "Specifies geographic or organisational containment or proximity (e.g., headquarters LOCATED_IN Berlin).",
|
||||
"HOLDS_ROLE": "Connects a person to a formal office or title within an organisation (CEO, Chair, Director, etc.).",
|
||||
"PRODUCES": "Indicates that an entity manufactures, builds, or creates a product, service, or infrastructure (includes scale-ups and component inclusion).",
|
||||
"SELLS": "Marks a commercial seller-to-customer relationship for a product or service (markets, distributes, sells).",
|
||||
"LAUNCHED": "Captures the official first release, shipment, or public start of a product, service, or initiative.",
|
||||
"DEVELOPED": "Shows design, R&D, or innovation origin of a technology, product, or capability. Includes 'researched' or 'created'.",
|
||||
"ADOPTED_BY": "Indicates that a technology or product has been taken up, deployed, or implemented by another entity.",
|
||||
"INVESTS_IN": "Represents the flow of capital or resources from one entity into another (equity, funding rounds, strategic investment).",
|
||||
"COLLABORATES_WITH": "Generic partnership, alliance, joint venture, or licensing relationship between entities.",
|
||||
"SUPPLIES": "Captures vendor–client supply-chain links or dependencies (provides to, sources from).",
|
||||
"HAS_REVENUE": "Associates an entity with a revenue amount or metric—actual, reported, or projected.",
|
||||
"INCREASED": "Expresses an upward change in a metric (revenue, market share, output) relative to a prior period or baseline.",
|
||||
"DECREASED": "Expresses a downward change in a metric relative to a prior period or baseline.",
|
||||
"RESULTED_IN": "Captures a causal relationship where one event or factor leads to a specific outcome (positive or negative).",
|
||||
"TARGETS": "Denotes a strategic objective, market segment, or customer group that an entity seeks to reach.",
|
||||
"PART_OF": "Expresses hierarchical membership or subset relationships (division, subsidiary, managed by, belongs to).",
|
||||
"DISCONTINUED": "Indicates official end-of-life, shutdown, or termination of a product, service, or relationship.",
|
||||
"SECURED": "Marks the successful acquisition of funding, contracts, assets, or rights by an entity.",
|
||||
"MENTIONS": "Denotes a reference or mention of an entity in a text or document.",
|
||||
|
||||
# Overly broad predicates removed from the set:
|
||||
# "MENTIONS": "Denotes a reference or mention of an entity in a text or document." ,
|
||||
# "FEELS" : "Denotes a subjective opinion or feeling about an entity (e.g., 'I feel like X').Includes 'THINKS'.",
|
||||
# "HELPS" :"Express a action that make it easier or possible for (someone) to do something by offering one's services or resources. Includes 'assist', 'aid' and 'support' " ,
|
||||
# "IS_DOING" : "Denotes a subjective action or activity about an entity (e.g., 'I am doing X').Includes 'DOES'.",
|
||||
# "LIKES": "Express enjoy or approve of something or someone (e.g., 'I like roses').Includes 'LIKES'.",
|
||||
# "DISLIKES": "Express dislike or disapprove of something or someone (e.g., 'I dislike roses').Includes 'DISLIKES'.",
|
||||
# "HAS_ATTRIBUTE": "Express that an entity has a certain attribute (e.g., 'X has a red car').Includes 'HAS'.",
|
||||
|
||||
}
|
||||
|
||||
LABEL_DEFINITIONS: dict[str, dict[str, dict[str, str]]] = {
|
||||
"statement_labelling": {
|
||||
"FACT": dict(
|
||||
definition=(
|
||||
"Statements that are objective and can be independently "
|
||||
"verified or falsified through evidence."
|
||||
),
|
||||
date_handling_guidance=(
"These statements can be made up of multiple static and "
"dynamic temporal events marking for example the start, end, "
"and duration of the fact the statement describes."
|
||||
),
|
||||
date_handling_example=(
|
||||
"'Company A owns Company B in 2022', 'X caused Y to happen', "
|
||||
"or 'John said X at Event' are verifiable facts which currently "
|
||||
"hold true unless we have a contradictory fact."
|
||||
),
|
||||
),
|
||||
"OPINION": dict(
|
||||
definition=(
|
||||
"Statements that contain personal opinions, feelings, values, "
|
||||
"or judgments that are not independently verifiable. It also "
|
||||
"includes hypothetical and speculative statements."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"This statement is always static. It is a record of the date the "
|
||||
"opinion was made."
|
||||
),
|
||||
date_handling_example=(
"'I like Company A's strategy', 'X may have caused Y to happen', "
"or 'The event felt like X' are opinions and down to the reporter's "
|
||||
"interpretation."
|
||||
),
|
||||
),
|
||||
"PREDICTION": dict(
|
||||
definition=(
|
||||
"Uncertain statements about the future on something that might happen, "
|
||||
"a hypothetical outcome, unverified claims. "
|
||||
"If the tense of the statement changed, the statement "
|
||||
"would then become a fact."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"This statement is always static. It is a record of the date the "
|
||||
"prediction was made."
|
||||
),
|
||||
date_handling_example=(
|
||||
"'It is rumoured that Dave will resign next month', 'Company A expects "
|
||||
"X to happen', or 'X suggests Y' are all predictions."
|
||||
),
|
||||
),
|
||||
"SUGGESTION": dict(
|
||||
definition=(
"A proposal or recommendation for action, often implying a future course of conduct. "
"It's not a statement of fact or a prediction, but rather an advised path. "
|
||||
"It's a suggestion for action that is not yet implemented."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"This statement is always static."
|
||||
),
|
||||
date_handling_example=(
|
||||
"'They should launch the new product next quarter', 'You could try a different approach', "
|
||||
"or 'I would recommend moving the headquarters to Berlin' are all suggestions."
|
||||
),
|
||||
),
|
||||
},
|
||||
"temporal_labelling": {
|
||||
"STATIC": dict(
|
||||
definition=(
|
||||
"Often past tense, think -ed verbs, describing single points-in-time. "
|
||||
"These statements are valid from the day they occurred and are never "
|
||||
"invalid. Refer to single points in time at which an event occurred, "
|
||||
"the fact X occurred on that date will always hold true."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"The valid_at date is the date the event occurred. The invalid_at date "
|
||||
"is None."
|
||||
),
|
||||
date_handling_example=(
"'John was appointed CEO on 4th Jan 2024', 'Company A reported X percent "
"growth from last FY', or 'X resulted in Y' are valid from the day "
|
||||
"they occurred and are never invalid."
|
||||
),
|
||||
),
|
||||
"DYNAMIC": dict(
|
||||
definition=(
|
||||
"Often present tense, think -ing verbs, describing a period of time. "
|
||||
"These statements are valid for a specific period of time and are usually "
|
||||
"invalidated by a Static fact marking the end of the event or start of a "
|
||||
"contradictory new one. The statement could already be referring to a "
|
||||
"discrete time period (invalid) or may be an ongoing relationship (not yet "
|
||||
"invalid)."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"The valid_at date is the date the event started. The invalid_at date is "
|
||||
"the date the event or relationship ended, for ongoing events this is None."
|
||||
),
|
||||
date_handling_example=(
|
||||
"'John is the CEO', 'Company A remains a market leader', or 'X is continuously "
|
||||
"causing Y to decrease' are valid from when the event started and are invalidated "
|
||||
"by a new event."
|
||||
),
|
||||
),
|
||||
"ATEMPORAL": dict(
|
||||
definition=(
|
||||
"Statements that will always hold true regardless of time therefore have no "
|
||||
"temporal bounds."
|
||||
),
|
||||
date_handling_guidance=(
|
||||
"These statements are assumed to be atemporal and have no temporal bounds. Both "
|
||||
"their valid_at and invalid_at are None."
|
||||
),
|
||||
date_handling_example=(
|
||||
"'A stock represents a unit of ownership in a company', 'The earth is round', or "
|
||||
"'Europe is a continent'. These statements are true regardless of time."
|
||||
),
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
class Predicate(StrEnum):
|
||||
"""Enumeration of normalised predicates."""
|
||||
|
||||
IS_A = "IS_A"
|
||||
HAS_A = "HAS_A"
|
||||
LOCATED_IN = "LOCATED_IN"
|
||||
HOLDS_ROLE = "HOLDS_ROLE"
|
||||
PRODUCES = "PRODUCES"
|
||||
SELLS = "SELLS"
|
||||
LAUNCHED = "LAUNCHED"
|
||||
DEVELOPED = "DEVELOPED"
|
||||
ADOPTED_BY = "ADOPTED_BY"
|
||||
INVESTS_IN = "INVESTS_IN"
|
||||
COLLABORATES_WITH = "COLLABORATES_WITH"
|
||||
SUPPLIES = "SUPPLIES"
|
||||
HAS_REVENUE = "HAS_REVENUE"
|
||||
INCREASED = "INCREASED"
|
||||
DECREASED = "DECREASED"
|
||||
RESULTED_IN = "RESULTED_IN"
|
||||
TARGETS = "TARGETS"
|
||||
PART_OF = "PART_OF"
|
||||
DISCONTINUED = "DISCONTINUED"
|
||||
SECURED = "SECURED"
|
||||
MENTIONS = "MENTIONS"
|
||||
|
||||
|
||||
class StatementType(StrEnum):
|
||||
FACT = "FACT"
|
||||
OPINION = "OPINION"
|
||||
PREDICTION = "PREDICTION"
|
||||
SUGGESTION = "SUGGESTION"
|
||||
|
||||
class TemporalInfo(StrEnum):
|
||||
ATEMPORAL = "ATEMPORAL"
|
||||
STATIC = "STATIC"
|
||||
DYNAMIC = "DYNAMIC"
|
||||
|
||||
# Relevance labelling for statements
|
||||
class RelevenceInfo(StrEnum):
|
||||
RELEVANT = "RELEVANT"
|
||||
IRRELEVANT = "IRRELEVANT"
|
||||
|
||||
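The `date_handling_guidance` strings under `temporal_labelling` amount to a small rule table for `valid_at` and `invalid_at`. The sketch below encodes that table directly. `temporal_bounds` and its `started`/`ended` parameters are hypothetical names, and the enum is redeclared with `(str, Enum)` only to keep the snippet self-contained on older Python versions.

```python
from datetime import date
from enum import Enum
from typing import Optional


class TemporalInfo(str, Enum):
    ATEMPORAL = "ATEMPORAL"
    STATIC = "STATIC"
    DYNAMIC = "DYNAMIC"


def temporal_bounds(
    label: TemporalInfo,
    started: Optional[date] = None,
    ended: Optional[date] = None,
) -> tuple[Optional[date], Optional[date]]:
    """Return (valid_at, invalid_at) following the date-handling guidance."""
    if label is TemporalInfo.ATEMPORAL:
        return None, None      # no temporal bounds at all
    if label is TemporalInfo.STATIC:
        return started, None   # valid from the event date, never invalidated
    return started, ended      # DYNAMIC: still ongoing while `ended` is None
```

For example, "John was appointed CEO on 4th Jan 2024" is STATIC and keeps an empty `invalid_at`, while "John is the CEO" is DYNAMIC and receives an `invalid_at` only once a later statement ends the relationship.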
81
app/core/memory/utils/data/text_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
|
||||
s = str(query)
|
||||
# Normalize whitespace
|
||||
s = s.replace("\r", " ").replace("\n", " ").strip()
|
||||
|
||||
    # Lucene reserved tokens/special characters. The backslash is escaped first so
    # that the backslashes added for the other tokens are not escaped a second time.
    specials = ['\\', '&&', '||', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
    for token in specials:
        s = s.replace(token, f"\\{token}")
|
||||
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Falls back to the raw string when parsing fails
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
# Directly handle dict-like input
|
||||
if isinstance(query_input, dict):
|
||||
original = query_input.get("original")
|
||||
if isinstance(original, str) and original.strip():
|
||||
return original.strip()
|
||||
context = query_input.get("context")
|
||||
if isinstance(context, dict):
|
||||
for key, val in context.items():
|
||||
if isinstance(key, str) and key.strip():
|
||||
return key.strip()
|
||||
if isinstance(val, list) and val:
|
||||
first = val[0]
|
||||
if isinstance(first, str) and first.strip():
|
||||
return first.strip()
|
||||
# Fallback to string conversion below
|
||||
|
||||
s = str(query_input).strip()
|
||||
|
||||
# Remove surrounding single/double quotes if present
|
||||
if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
|
||||
s = s[1:-1].strip()
|
||||
|
||||
# Attempt to parse JSON and extract the 'original' field
|
||||
if s.startswith("{") and s.endswith("}"):
|
||||
try:
|
||||
data = json.loads(s)
|
||||
# Prefer 'original' field if available
|
||||
original = data.get("original")
|
||||
if isinstance(original, str) and original.strip():
|
||||
return original.strip()
|
||||
# Fallbacks: try common nested structures
|
||||
context = data.get("context")
|
||||
if isinstance(context, dict):
|
||||
# Take the first key or first string value in context
|
||||
for key, val in context.items():
|
||||
if isinstance(key, str) and key.strip():
|
||||
return key.strip()
|
||||
if isinstance(val, list) and val:
|
||||
first = val[0]
|
||||
if isinstance(first, str) and first.strip():
|
||||
return first.strip()
|
||||
except Exception:
|
||||
# Not valid JSON; keep as-is after best-effort unescape below
|
||||
pass
|
||||
|
||||
# Best-effort unescape common escaped newlines/tabs without altering unicode
|
||||
s = s.replace("\\n", " ").replace("\\t", " ")
|
||||
return s
|
||||
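Escaping with a sequence of `str.replace` calls depends on the order of the tokens. A single regular-expression pass avoids that dependency because each special token is visited exactly once. The sketch below is an alternative under that assumption, not the module's implementation: `escape_lucene` and `_LUCENE_SPECIALS` are hypothetical names, the token set mirrors the list in `escape_lucene_query`, and all runs of whitespace are collapsed to a single space.

```python
import re

# Two-character operators are listed before the single-character class
# so that "&&" and "||" are matched as one token.
_LUCENE_SPECIALS = re.compile(r'&&|\|\||[\\+\-!(){}\[\]^"~*?:]')


def escape_lucene(query) -> str:
    """Escape Lucene special tokens in one pass; None becomes an empty string."""
    if query is None:
        return ""
    text = " ".join(str(query).split())  # collapse newlines and repeated spaces
    return _LUCENE_SPECIALS.sub(lambda m: "\\" + m.group(0), text)
```

Because the substitution never rescans its own output, a backslash inserted for one token cannot be escaped again by a later step.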
127
app/core/memory/utils/data/time_utils.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import re
|
||||
from dateutil import parser
|
||||
from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
Check whether the string has the shape YYYY-MM-DD (one- or two-digit month and day are accepted; the calendar date itself is not validated).
|
||||
"""
|
||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||
return bool(re.match(pattern, date_str))
|
||||
|
||||
|
||||
def normalize_date(date_str: str) -> str:
    """
    Normalise a date string in any of several common formats to YYYY-MM-DD.

    Args:
        date_str: Date string in an arbitrary format.

    Returns:
        The date as a YYYY-MM-DD string. Empty or non-string input is returned unchanged.
    """
    if not date_str or not isinstance(date_str, str):
        return date_str

    # Strip surrounding whitespace and remove every separator (space, '/', '.', '_', '-')
    date_str = date_str.strip().replace(' ', '').replace('/', '').replace('.', '').replace('_', '').replace('-', '')

    try:
        # Pre-process: recognise and normalise special formats
        preprocessed_str = preprocess_date_string(date_str)

        # Parse with dateutil.parser
        dt = parser.parse(preprocessed_str, dayfirst=False, yearfirst=True)

        return dt.strftime('%Y-%m-%d')

    except (ValueError, TypeError, OverflowError):
        # If flexible parsing fails, fall back to explicit format matching
        return fallback_parse(date_str)
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
    """Pre-process a date string, normalising special formats."""

    # An 8-digit string is treated as YYYYMMDD. Handle it first so that the
    # pattern below cannot read "20251028" as a 5-digit year.
    if re.fullmatch(r'\d{8}', date_str):
        return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"

    # Handle formats such as "20259/28" (the month follows the year with no separator)
    match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
    if match:
        year, month, day = match.groups()
        # A year longer than 4 digits probably has the month run into it
        if len(year) > 4:
            # Take the first 4 digits as the year and the rest as the month
            actual_year = year[:4]
            actual_month = year[4:] + (month if month else '')
            # Reassemble
            if day:
                return f"{actual_year}-{actual_month.zfill(2)}-{day.zfill(2)}"
            else:
                return f"{actual_year}-{actual_month.zfill(2)}"
        else:
            return f"{year}-{month.zfill(2)}-{day.zfill(2)}" if day else f"{year}-{month.zfill(2)}"

    # Handle purely numeric strings with no separators
    if re.match(r'^\d{6,8}$', date_str):
        if len(date_str) == 8:  # YYYYMMDD
            return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
        elif len(date_str) == 6:  # YYMMDD or MMDDYY
            # Try the different interpretations
            if 1 <= int(date_str[:2]) <= 12:  # probably MMDDYY
                return f"20{date_str[4:6]}-{date_str[:2]}-{date_str[2:4]}"
            else:  # probably YYMMDD
                return f"20{date_str[:2]}-{date_str[2:4]}-{date_str[4:6]}"

    # Normalise mixed separators to "-"
    date_str = re.sub(r'[/\._]', '-', date_str)

    return date_str
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
    """Fallback parser that tries a list of explicit formats."""

    # Try the common date formats
    formats_to_try = [
        '%Y-%m-%d', '%Y/%m/%d', '%Y.%m.%d',
        '%Y%m%d', '%y%m%d',
        '%m-%d-%Y', '%m/%d/%Y', '%m.%d.%Y',
        '%d-%m-%Y', '%d/%m/%Y', '%d.%m.%Y',
        '%Y-%m', '%Y/%m', '%Y.%m'
    ]

    for fmt in formats_to_try:
        try:
            dt = datetime.strptime(date_str, fmt)
            return dt.strftime('%Y-%m-%d')
        except ValueError:
            continue

    # If every format fails, return the original string unchanged
    return date_str
|
||||
|
||||
|
||||
def normalize_date_safe(date_str: str, default: str | None = None) -> str:
    """
    Normalise a date string safely, returning a fallback value when parsing fails.

    Args:
        date_str: The date string.
        default: Value returned when the input cannot be normalised. When None,
            the original input is returned instead.

    Returns:
        The normalised date string, or the fallback value.
    """
    try:
        result = normalize_date(date_str)
        # Check that the result has a valid date format
        if validate_date_format(result):
            return result
        else:
            return default if default is not None else date_str
    except Exception:
        return default if default is not None else date_str
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_dates = ["2025/10/28", "2025.10.28", "2025_10_28", "20251028"]
|
||||
for date in start_dates:
|
||||
print(normalize_date_safe(date))
|
||||
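The "try explicit formats in order" strategy of `fallback_parse` can be illustrated with the standard library alone. The sketch below is a simplified variant under stated assumptions: `to_iso_date` and `_FORMATS` are hypothetical names, the format list is shortened to the year-first layouts used in the `__main__` check above, and failure is reported as `None` rather than by echoing the input.

```python
from datetime import datetime
from typing import Optional

# Year-first layouts only; each one is tried in order until one parses.
_FORMATS = ("%Y-%m-%d", "%Y/%m/%d", "%Y.%m.%d", "%Y_%m_%d", "%Y%m%d")


def to_iso_date(text: str) -> Optional[str]:
    """Return the date as YYYY-MM-DD, or None when no format matches."""
    candidate = text.strip()
    for fmt in _FORMATS:
        try:
            return datetime.strptime(candidate, fmt).strftime("%Y-%m-%d")
        except ValueError:
            continue  # wrong layout or impossible calendar date, try the next format
    return None
```

Returning `None` makes a failed parse distinguishable from a string that merely looks like a date, at the cost of requiring callers to handle the missing value.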
18
app/core/memory/utils/llm/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
LLM utilities module.

Contains the utility functions related to LLM clients.
|
||||
"""
|
||||
|
||||
# Re-export commonly used functions from the submodules for backward compatibility
|
||||
from .llm_utils import (
|
||||
get_llm_client,
|
||||
get_reranker_client,
|
||||
handle_response,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_llm_client",
|
||||
"get_reranker_client",
|
||||
"handle_response",
|
||||
]
|
||||
77
app/core/memory/utils/llm/llm_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import os
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config.config_utils import get_model_config
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
async def handle_response(response: BaseModel) -> dict:
|
||||
return response.model_dump()
|
||||
|
||||
|
||||
def get_llm_client(llm_id: str | None = None):
|
||||
llm_id = llm_id or config_defs.SELECTED_LLM_ID
|
||||
|
||||
# Validate LLM ID exists before attempting to get config
|
||||
if not llm_id:
|
||||
raise ValueError("LLM ID is required but was not provided")
|
||||
|
||||
try:
|
||||
model_config = get_model_config(llm_id)
|
||||
except Exception as e:
|
||||
# Re-raise with clear error message about invalid LLM ID
|
||||
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
# Debug print removed to keep terminal output clean
|
||||
# print(model_config)
|
||||
llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
# print(llm.dict())
|
||||
return llm_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
|
||||
|
||||
|
||||
def get_reranker_client(rerank_id: str | None = None):
|
||||
"""
|
||||
Get an LLM client configured for reranking.
|
||||
|
||||
Args:
|
||||
rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID.
|
||||
|
||||
Returns:
|
||||
OpenAIClient: Initialized client for the reranker model
|
||||
|
||||
Raises:
|
||||
ValueError: If rerank_id is invalid or client initialization fails
|
||||
"""
|
||||
rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID
|
||||
|
||||
# Validate rerank ID exists before attempting to get config
|
||||
if not rerank_id:
|
||||
raise ValueError("Rerank ID is required but was not provided")
|
||||
|
||||
try:
|
||||
model_config = get_model_config(rerank_id)
|
||||
except Exception as e:
|
||||
# Re-raise with clear error message about invalid rerank ID
|
||||
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
|
||||
|
||||
try:
|
||||
reranker_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url")
|
||||
),type_=model_config.get("type"))
|
||||
return reranker_client
|
||||
except Exception as e:
|
||||
model_name = model_config.get('model_name', 'unknown')
|
||||
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
|
||||
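`get_llm_client` and `get_reranker_client` follow the same three steps: reject an empty ID, load the model configuration, then build the client, with each failure re-raised as a `ValueError` that keeps the original exception as its cause. The sketch below factors those steps into one helper. `build_client`, `load_config`, and `make_client` are hypothetical names, and the callables stand in for `get_model_config` and the `OpenAIClient` construction.

```python
from typing import Any, Callable, Mapping


def build_client(
    model_id: str,
    load_config: Callable[[str], Mapping[str, Any]],
    make_client: Callable[[Mapping[str, Any]], Any],
    kind: str = "LLM",
) -> Any:
    """Validate the ID, load its config, and build a client, chaining any error."""
    if not model_id:
        raise ValueError(f"{kind} ID is required but was not provided")
    try:
        config = load_config(model_id)
    except Exception as exc:
        # `from exc` preserves the original traceback as __cause__
        raise ValueError(f"Invalid {kind} ID '{model_id}': {exc}") from exc
    try:
        return make_client(config)
    except Exception as exc:
        name = config.get("model_name", "unknown")
        raise ValueError(f"Failed to initialize {kind} client for model '{name}': {exc}") from exc


# Usage with an in-memory config table (assumption, for illustration only).
configs = {"m1": {"model_name": "demo"}}
client = build_client("m1", configs.__getitem__, lambda cfg: ("client", cfg["model_name"]))
```

Keeping the cause attached means callers see a uniform `ValueError` while logs still show whether the root problem was a missing config or a constructor failure.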
24
app/core/memory/utils/log/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Logging management module.

Contains the utility functions related to logging.
|
||||
"""
|
||||
|
||||
# Re-export commonly used functions from the submodules for backward compatibility
|
||||
from .logging_utils import (
|
||||
log_prompt_rendering,
|
||||
log_template_rendering,
|
||||
log_time,
|
||||
prompt_logger,
|
||||
)
|
||||
from .audit_logger import audit_logger
|
||||
|
||||
__all__ = [
|
||||
# logging_utils
|
||||
"log_prompt_rendering",
|
||||
"log_template_rendering",
|
||||
"log_time",
|
||||
"prompt_logger",
|
||||
# audit_logger
|
||||
"audit_logger",
|
||||
]
|
||||
182
app/core/memory/utils/log/audit_logger.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Configuration audit logger.

Provides dedicated audit logging for tracking configuration changes and operation records.
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
def _format_value(value: Any) -> str:
|
||||
"""
|
||||
Format a value as a string, with special handling for objects such as UUIDs.

Args:
    value: The value to format.

Returns:
    str: The formatted string.
|
||||
"""
|
||||
if value is None:
|
||||
return "None"
|
||||
elif isinstance(value, bool):
|
||||
return str(value)
|
||||
    elif hasattr(value, 'hex'):  # UUID objects expose a hex attribute
        return str(value)  # Use the standard hyphenated UUID string form
|
||||
else:
|
||||
return str(value)
|
||||
|
||||
|
||||
class ConfigAuditLogger:
    """Configuration audit logger."""
|
||||
|
||||
def __init__(self, log_file: str = "logs/config_audit.log"):
|
||||
"""
|
||||
Initialise the audit logger.

Args:
    log_file: Path to the log file.
|
||||
"""
|
||||
self.logger = logging.getLogger("config_audit")
|
||||
self.logger.setLevel(logging.INFO)
|
||||
|
||||
# Avoid adding duplicate handlers
|
||||
if not self.logger.handlers:
|
||||
# Make sure the log directory exists
|
||||
log_dir = os.path.dirname(log_file)
|
||||
if log_dir and not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# Create the file handler
|
||||
handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s [AUDIT] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
|
||||
def log_config_load(
|
||||
self,
|
||||
config_id: str,
|
||||
user_id: Optional[str] = None,
|
||||
group_id: Optional[str] = None,
|
||||
success: bool = True,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Record a configuration load event.

Args:
    config_id: Configuration ID.
    user_id: User ID (optional).
    group_id: Group ID (optional).
    success: Whether the load succeeded.
    details: Additional details (optional).
|
||||
"""
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"CONFIG_LOAD config_id={config_id} "
|
||||
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
|
||||
f"result={result}"
|
||||
)
|
||||
if details:
|
||||
# Format the details, making sure every value is converted to a string correctly
|
||||
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
|
||||
msg += f" details=[{details_str}]"
|
||||
self.logger.info(msg)
|
||||
|
||||
def log_config_change(
|
||||
self,
|
||||
config_id: str,
|
||||
old_values: Dict[str, Any],
|
||||
new_values: Dict[str, Any],
|
||||
user_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Record a configuration change event.

Args:
    config_id: Configuration ID.
    old_values: Previous configuration values.
    new_values: New configuration values.
    user_id: User ID (optional).
|
||||
"""
|
||||
changes = []
|
||||
for key in new_values:
|
||||
if key in old_values and old_values[key] != new_values[key]:
|
||||
changes.append(f"{key}: {old_values[key]} -> {new_values[key]}")
|
||||
|
||||
if changes:
|
||||
msg = (
|
||||
f"CONFIG_CHANGE config_id={config_id} "
|
||||
f"user={user_id or 'N/A'} "
|
||||
f"changes=[{', '.join(changes)}]"
|
||||
)
|
||||
self.logger.info(msg)
|
||||
|
||||
def log_operation(
|
||||
self,
|
||||
operation: str,
|
||||
config_id: str,
|
||||
group_id: str,
|
||||
success: bool = True,
|
||||
duration: Optional[float] = None,
|
||||
error: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Record an operation event.

Args:
    operation: Operation type (WRITE, READ, etc.).
    config_id: Configuration ID.
    group_id: Group ID.
    success: Whether the operation succeeded.
    duration: Time taken by the operation, in seconds.
    error: Error message (optional).
    details: Additional details (optional).
|
||||
"""
|
||||
result = "SUCCESS" if success else "FAILED"
|
||||
msg = (
|
||||
f"{operation.upper()} config_id={config_id} "
|
||||
f"group={group_id} result={result}"
|
||||
)
|
||||
if duration is not None:
|
||||
msg += f" duration={duration:.2f}s"
|
||||
if error:
|
||||
msg += f" error={error}"
|
||||
if details:
|
||||
# Format the details, making sure every value is converted to a string correctly
|
||||
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
|
||||
msg += f" details=[{details_str}]"
|
||||
self.logger.info(msg)
|
||||
|
||||
def log_cache_event(
|
||||
self,
|
||||
event_type: str,
|
||||
config_id: Optional[str] = None,
|
||||
details: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
Record a cache event.

Args:
    event_type: Event type (HIT, MISS, CLEAR, EXPIRE).
    config_id: Configuration ID (optional).
    details: Additional details (optional).
|
||||
"""
|
||||
msg = f"CACHE_{event_type.upper()}"
|
||||
if config_id:
|
||||
msg += f" config_id={config_id}"
|
||||
if details:
|
||||
# Format the details, making sure every value is converted to a string correctly
|
||||
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
|
||||
msg += f" details=[{details_str}]"
|
||||
self.logger.info(msg)
|
||||
|
||||
|
||||
# Global audit logger instance
|
||||
audit_logger = ConfigAuditLogger()
|
||||
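Two ideas carry most of `ConfigAuditLogger`: a handler is attached only when the named logger has none, and each event is flattened into a single `key=value` line. The sketch below reproduces both against an in-memory stream. `format_audit_message`, `get_audit_logger`, and the logger name `config_audit_demo` are hypothetical, and the timestamp is left out of the format so the output stays deterministic.

```python
import io
import logging
from typing import Any, Mapping, Optional


def format_audit_message(
    event: str,
    config_id: str,
    success: bool = True,
    details: Optional[Mapping[str, Any]] = None,
) -> str:
    """Flatten an audit event into one space-separated key=value line."""
    result = "SUCCESS" if success else "FAILED"
    parts = [event.upper(), f"config_id={config_id}", f"result={result}"]
    if details:
        rendered = ", ".join(f"{key}={value}" for key, value in details.items())
        parts.append(f"details=[{rendered}]")
    return " ".join(parts)


def get_audit_logger(stream, name: str = "config_audit_demo") -> logging.Logger:
    """Return a named logger, attaching a handler only on first use."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep audit lines out of the root logger
    if not logger.handlers:
        handler = logging.StreamHandler(stream)
        handler.setFormatter(logging.Formatter("[AUDIT] %(message)s"))
        logger.addHandler(handler)
    return logger


buffer = io.StringIO()
logger = get_audit_logger(buffer)
get_audit_logger(buffer)  # a second call reuses the existing handler
logger.info(format_audit_message("write", "cfg-1", details={"rows": 3}))
```

Without the `if not logger.handlers` guard, every extra instantiation would add another handler and each event would be written once per handler.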
38
app/core/memory/utils/log/logging_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Logging utilities for prompt rendering and timing.
|
||||
|
||||
This module provides backward-compatible access to memory module logging utilities
|
||||
that have been unified into the centralized logging system (app.core.logging_config).
|
||||
|
||||
All logging functions are now imported from the centralized configuration to ensure
|
||||
consistent behavior, formatting, and configuration across the entire application.
|
||||
|
||||
For new code, consider importing directly from app.core.logging_config:
|
||||
from app.core.logging_config import log_prompt_rendering, log_template_rendering, log_time
|
||||
|
||||
This module maintains backward compatibility for existing code that imports from here.
|
||||
"""
|
||||
|
||||
# Import from centralized logging configuration
|
||||
from app.core.logging_config import (
|
||||
log_prompt_rendering as _log_prompt_rendering,
|
||||
log_template_rendering as _log_template_rendering,
|
||||
log_time as _log_time,
|
||||
get_prompt_logger as _get_prompt_logger,
|
||||
)
|
||||
|
||||
# Re-export functions to maintain backward compatibility
|
||||
log_prompt_rendering = _log_prompt_rendering
|
||||
log_template_rendering = _log_template_rendering
|
||||
log_time = _log_time
|
||||
|
||||
# Re-export prompt_logger for backward compatibility with code that uses it directly
|
||||
# This provides the same logger instance that was previously created in this module
|
||||
prompt_logger = _get_prompt_logger()
|
||||
|
||||
# Expose functions in __all__ for explicit exports
|
||||
__all__ = [
|
||||
'log_prompt_rendering',
|
||||
'log_template_rendering',
|
||||
'log_time',
|
||||
'prompt_logger',
|
||||
]
|
||||
16
app/core/memory/utils/paths/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Path management module.

Contains the utility functions related to path management.
|
||||
"""
|
||||
|
||||
# Re-export commonly used functions from the submodules for backward compatibility
|
||||
from .output_paths import (
|
||||
get_output_dir,
|
||||
get_output_path,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_output_dir",
|
||||
"get_output_path",
|
||||
]
|
||||
133
app/core/memory/utils/paths/output_paths.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Output Path Management for Memory Module
|
||||
|
||||
This module provides utilities for managing output file paths in the memory module.
|
||||
All output files are now centralized in the logs/memory-output directory.
|
||||
|
||||
Migration from: app/core/memory/src/pipeline_output/
|
||||
Migration to: logs/memory-output/
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
USE_UNIFIED_CONFIG = True
|
||||
except ImportError:
|
||||
USE_UNIFIED_CONFIG = False
|
||||
settings = None
|
||||
|
||||
|
||||
def get_output_dir() -> str:
|
||||
"""
|
||||
Get the base output directory for memory module files.
|
||||
|
||||
Returns:
|
||||
str: Path to the output directory
|
||||
"""
|
||||
if USE_UNIFIED_CONFIG:
|
||||
return settings.MEMORY_OUTPUT_DIR
|
||||
else:
|
||||
# Fallback to default path
|
||||
return "logs/memory-output"
|
||||
|
||||
|
||||
def get_output_path(filename: str) -> str:
|
||||
"""
|
||||
Get the full path for a memory module output file.
|
||||
|
||||
Args:
|
||||
filename: Name of the output file
|
||||
|
||||
Returns:
|
||||
str: Full path to the output file
|
||||
"""
|
||||
if USE_UNIFIED_CONFIG:
|
||||
return settings.get_memory_output_path(filename)
|
||||
else:
|
||||
# Fallback to default path
|
||||
return os.path.join("logs/memory-output", filename)
|
||||
|
||||
|
||||
def ensure_output_dir() -> None:
|
||||
"""
    Ensure the output directory exists.
    Creates the directory if it doesn't exist.
    """
    if USE_UNIFIED_CONFIG:
        settings.ensure_memory_output_dir()
    else:
        # Fallback: create directory manually
        output_dir = Path("logs/memory-output")
        output_dir.mkdir(parents=True, exist_ok=True)


# Standard output file names (for consistency across the module)
class OutputFiles:
    """Standard output file names for the memory module."""

    # Chunker output
    CHUNKER_TEST_OUTPUT = "chunker_test_output.txt"

    # Preprocessing output
    PREPROCESSED_DATA = "preprocessed_data.json"
    PRUNED_DATA = "pruned_data.json"
    PRUNED_TERMINAL = "pruned_terminal.json"

    # Extraction output
    STATEMENT_EXTRACTION = "statement_extraction.txt"
    RELATIONS_OUTPUT = "relations_output.txt"
    EXTRACTED_TRIPLETS = "extracted_triplets.txt"
    EXTRACTED_ENTITIES_EDGES = "extracted_entities_edges.txt"
    EXTRACTED_TEMPORAL_DATA = "extracted_temporal_data.txt"

    # Deduplication output
    DEDUP_ENTITY_OUTPUT = "dedup_entity_output.txt"

    # Summary output
    EXTRACTED_RESULT = "extracted_result.json"
    EXTRACTED_RESULT_READABLE = "extracted_result_readable.txt"

    # Analytics output
    USER_DASHBOARD = "User-Dashboard.json"
    SIGNBOARD = "Signboard.json"


def get_standard_output_path(file_constant: str) -> str:
    """
    Get the full path for a standard output file.

    Args:
        file_constant: One of the OutputFiles constants

    Returns:
        str: Full path to the output file
    """
    return get_output_path(file_constant)


# Backward compatibility: Legacy path resolution
def resolve_legacy_path(legacy_path: str) -> str:
    """
    Resolve a legacy pipeline_output path to the new unified output path.

    This function helps migrate code that uses hardcoded pipeline_output paths.

    Args:
        legacy_path: Path containing 'pipeline_output'

    Returns:
        str: New path using unified output directory
    """
    if "pipeline_output" in legacy_path:
        # Extract filename from legacy path
        filename = os.path.basename(legacy_path)
        return get_output_path(filename)
    return legacy_path


# Aliases for backward compatibility with test code
get_memory_output_dir = get_output_dir
get_memory_output_path = get_output_path
|
||||
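The legacy path resolution above can be illustrated with a small standalone sketch. It assumes a stand-in `get_output_path` that simply joins the file name onto the fallback directory `logs/memory-output`; the real helper is defined earlier in `output_paths.py` and may resolve the directory through the unified settings instead.

```python
import os
from pathlib import Path

# Assumption: stand-in for the module's get_output_path, using the fallback directory.
FALLBACK_OUTPUT_DIR = Path("logs/memory-output")


def get_output_path(filename: str) -> str:
    """Hypothetical stand-in: join a file name onto the fallback output directory."""
    return str(FALLBACK_OUTPUT_DIR / filename)


def resolve_legacy_path(legacy_path: str) -> str:
    """Map a path containing 'pipeline_output' to the unified output directory."""
    if "pipeline_output" in legacy_path:
        # Only the file name survives; any legacy sub-directories are dropped.
        return get_output_path(os.path.basename(legacy_path))
    return legacy_path


legacy = resolve_legacy_path("src/pipeline_output/pruned_data.json")
untouched = resolve_legacy_path("data/other.json")
```

Paths that do not mention `pipeline_output` are returned unchanged, so the function can be wrapped around existing call sites without affecting new-style paths.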
34
app/core/memory/utils/prompt/__init__.py
Normal file
@@ -0,0 +1,34 @@
"""
Prompt management module

Contains all utility functions for prompt rendering and template management.
"""

# Export commonly used functions from submodules to keep backward compatibility
from .prompt_utils import (
    get_prompts,
    render_statement_extraction_prompt,
    render_temporal_extraction_prompt,
    render_entity_dedup_prompt,
    render_triplet_extraction_prompt,
    render_memory_summary_prompt,
    prompt_env,
)
from .template_render import (
    render_evaluate_prompt,
    render_reflexion_prompt,
)

__all__ = [
    # prompt_utils
    "get_prompts",
    "render_statement_extraction_prompt",
    "render_temporal_extraction_prompt",
    "render_entity_dedup_prompt",
    "render_triplet_extraction_prompt",
    "render_memory_summary_prompt",
    "prompt_env",
    # template_render
    "render_evaluate_prompt",
    "render_reflexion_prompt",
]
240
app/core/memory/utils/prompt/prompt_utils.py
Normal file
@@ -0,0 +1,240 @@
import os
from jinja2 import Environment, FileSystemLoader

from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering

# Setup Jinja2 environment
# Get the directory of this file (app/core/memory/utils/prompt/)
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))

async def get_prompts(message: str) -> list[dict]:
    """
    Renders system and user prompts using Jinja2 templates.
    """
    system_template = prompt_env.get_template("system.jinja2")
    user_template = prompt_env.get_template("user.jinja2")

    system_prompt = system_template.render()
    user_prompt = user_template.render(message=message)

    # Log the rendered result to the prompt log (same structure as the example logs)
    log_prompt_rendering('system', system_prompt)
    log_prompt_rendering('user', user_prompt)
    # Optional: log template rendering info (only takes effect when prompt_templates.log exists)
    log_template_rendering('system.jinja2', {})
    log_template_rendering('user.jinja2', {'message': message})
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

async def render_statement_extraction_prompt(
    chunk_content: str,
    definitions: dict,
    json_schema: dict,
    granularity: int | None = None,
    include_dialogue_context: bool = False,
    dialogue_content: str | None = None,
    max_dialogue_chars: int | None = None,
) -> str:
    """
    Renders the statement extraction prompt using the extract_statement.jinja2 template.

    Args:
        chunk_content: The content of the chunk to process
        definitions: Label definitions for statement classification
        json_schema: JSON schema for the expected output format

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("extract_statement.jinja2")
    # Optional clipping of dialogue context
    ctx = None
    if include_dialogue_context and dialogue_content:
        try:
            if isinstance(max_dialogue_chars, int) and max_dialogue_chars > 0:
                ctx = dialogue_content[:max_dialogue_chars]
            else:
                ctx = dialogue_content
        except Exception:
            ctx = dialogue_content

    rendered_prompt = template.render(
        inputs={"chunk": chunk_content},
        definitions=definitions,
        json_schema=json_schema,
        granularity=granularity,
        include_dialogue_context=include_dialogue_context,
        dialogue_context=ctx,
    )
    # Log the rendered result to the prompt log (same structure as the example logs)
    log_prompt_rendering('statement extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_statement.jinja2', {
        'inputs': 'chunk',
        'definitions': 'LABEL_DEFINITIONS',
        'json_schema': 'StatementExtractionResponse.schema',
        'granularity': 'int|None',
        'include_dialogue_context': include_dialogue_context,
        'dialogue_context_len': (len(ctx) if isinstance(ctx, str) else 0),
    })

    return rendered_prompt

async def render_temporal_extraction_prompt(
    ref_dates: dict,
    statement: dict,
    temporal_guide: dict,
    statement_guide: dict,
    json_schema: dict,
) -> str:
    """
    Renders the temporal extraction prompt using the extract_temporal.jinja2 template.

    Args:
        ref_dates: Reference dates for context.
        statement: The statement to process.
        temporal_guide: Guidance on temporal types.
        statement_guide: Guidance on statement types.
        json_schema: JSON schema for the expected output format.

    Returns:
        Rendered prompt content as a string.
    """
    template = prompt_env.get_template("extract_temporal.jinja2")
    inputs = ref_dates | statement
    rendered_prompt = template.render(
        inputs=inputs,
        temporal_guide=temporal_guide,
        statement_guide=statement_guide,
        json_schema=json_schema,
    )
    # Log the rendered result to the prompt log (same structure as the example logs)
    log_prompt_rendering('temporal extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_temporal.jinja2', {
        'inputs': 'ref_dates|statement',
        'temporal_guide': 'dict',
        'statement_guide': 'dict',
        'json_schema': 'Temporal.schema'
    })

    return rendered_prompt

def render_entity_dedup_prompt(
    entity_a: dict,
    entity_b: dict,
    context: dict,
    json_schema: dict,
    disambiguation_mode: bool = False,
) -> str:
    """
    Render the entity deduplication prompt using the entity_dedup.jinja2 template.

    Args:
        entity_a: Dict of entity A attributes
        entity_b: Dict of entity B attributes
        context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements)
        json_schema: JSON schema for the structured output (EntityDedupDecision)

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("entity_dedup.jinja2")
    rendered_prompt = template.render(
        entity_a=entity_a,
        entity_b=entity_b,
        same_group=context.get("same_group", False),
        type_ok=context.get("type_ok", False),
        type_similarity=context.get("type_similarity", 0.0),
        name_text_sim=context.get("name_text_sim", 0.0),
        name_embed_sim=context.get("name_embed_sim", 0.0),
        name_contains=context.get("name_contains", False),
        co_occurrence=context.get("co_occurrence", False),
        relation_statements=context.get("relation_statements", []),
        json_schema=json_schema,
        disambiguation_mode=disambiguation_mode,
    )

    # prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===")
    # prompt_logger.info(rendered_prompt)
    # prompt_logger.info("\n" + "="*50 + "\n")

    return rendered_prompt


# async def render_entity_dedup_prompt(
#     entity_a: dict,
#     entity_b: dict,
#     context: dict,
#     json_schema: dict,
# ) -> str:
#     """
#     Render the entity deduplication prompt using the entity_dedup.jinja2 template.

#     Args:
#         entity_a: Dict of entity A attributes
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict | None = None) -> str:
    """
    Renders the triplet extraction prompt using the extract_triplet.jinja2 template.

    Args:
        statement: Statement text to process
        chunk_content: The content of the chunk to process
        json_schema: JSON schema for the expected output format
        predicate_instructions: Optional predicate instructions

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("extract_triplet.jinja2")
    rendered_prompt = template.render(
        statement=statement,
        chunk_content=chunk_content,
        json_schema=json_schema,
        predicate_instructions=predicate_instructions
    )
    # Log the rendered result to the prompt log (same structure as the example logs)
    log_prompt_rendering('triplet extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_triplet.jinja2', {
        'statement': 'str',
        'chunk_content': 'str',
        'json_schema': 'TripletExtractionResponse.schema',
        'predicate_instructions': 'PREDICATE_DEFINITIONS'
    })

    return rendered_prompt

async def render_memory_summary_prompt(
    chunk_texts: str,
    json_schema: dict,
    max_words: int = 200,
) -> str:
    """
    Renders the memory summary prompt using the memory_summary.jinja2 template.

    Args:
        chunk_texts: Concatenated text of conversation chunks
        json_schema: JSON schema for the expected output format
        max_words: Maximum words for the summary

    Returns:
        Rendered prompt content as string.
    """
    template = prompt_env.get_template("memory_summary.jinja2")
    rendered_prompt = template.render(
        chunk_texts=chunk_texts,
        json_schema=json_schema,
        max_words=max_words,
    )
    log_prompt_rendering('memory summary', rendered_prompt)
    log_template_rendering('memory_summary.jinja2', {
        'chunk_texts_len': len(chunk_texts or ""),
        'max_words': max_words,
        'json_schema': 'MemorySummaryResponse.schema'
    })
    return rendered_prompt
|
||||
60
app/core/memory/utils/prompt/prompts/entity_dedup.jinja2
Normal file
@@ -0,0 +1,60 @@
===Task===
You are an assistant for entity deduplication/disambiguation. You will be given detailed information and context for two entities; strictly follow the guidelines to decide whether they are the same real-world entity, and perform type disambiguation when needed.

Mode: {{ 'disambiguation mode' if disambiguation_mode else 'deduplication mode' }}

===Input===
Entity A:
- Name: "{{ entity_a.name | default('') }}"
- Type: "{{ entity_a.entity_type | default('') }}"
- Description: "{{ entity_a.description | default('') }}"
- Aliases: {{ entity_a.aliases | default([]) }}
- Summary: "{{ entity_a.fact_summary | default('') }}"
- Connection strength: "{{ entity_a.connect_strength | default('') }}"

Entity B:
- Name: "{{ entity_b.name | default('') }}"
- Type: "{{ entity_b.entity_type | default('') }}"
- Description: "{{ entity_b.description | default('') }}"
- Aliases: {{ entity_b.aliases | default([]) }}
- Summary: "{{ entity_b.fact_summary | default('') }}"
- Connection strength: "{{ entity_b.connect_strength | default('') }}"

Context:
- Same group: {{ same_group | default(false) }}
- Types match or type unknown: {{ type_ok | default(false) }}
- Type similarity (0-1): {{ type_similarity | default(0.0) }}
- Name text similarity (0-1): {{ name_text_sim | default(0.0) }}
- Name embedding similarity (0-1): {{ name_embed_sim | default(0.0) }}
- Name containment: {{ name_contains | default(false) }}
- Context co-occurrence (the same statement refers to both): {{ co_occurrence | default(false) }}
- Relation statements involving both (from entity-entity edges):
{% for s in relation_statements %}
  - {{ s }}
{% endfor %}

===Decision Guidelines===
{% if disambiguation_mode %}
- This is a "same name but different types" disambiguation scenario. Decide whether the two refer to the same real-world entity.
- Weigh name text/embedding similarity, aliases, descriptions, summaries, and contextual relations (co-occurrence and relation statements) together.
- If you cannot be sufficiently certain, be conservative: do not merge, and recommend blocking this pair from other fuzzy/heuristic merges (block_pair=true).
- If a merge is needed (should_merge=true), choose the canonical entity (canonical_idx) and, where possible, give a suggested unified type (suggested_type); the suggested type must be consistent with the context.
- Canonical entity priority: the one with the higher connection strength (strong/both); if tied, keep the one with the richer description/summary; if still tied, keep entity A (canonical_idx=0).
{% else %}
- If the entity types are the same, or either is UNKNOWN/empty, the pair may pass as a candidate; if the types clearly conflict (e.g. person vs. item), judge them to be different entities unless aliases and descriptions are highly consistent.
- Weigh name text/embedding similarity, aliases, descriptions, summaries, and contextual relations to decide whether they are the same entity.
- When the context co-occurs or explicit relation statements support identity (e.g. the same object is mentioned repeatedly or aliases correspond), the decision threshold may be lowered moderately.
- Conservative decision: when you cannot be sufficiently certain, do not merge (same_entity=false).
- If a merge is needed, choose the more suitable one as the canonical entity to keep (canonical_idx):
  - Prefer the one with the stronger connection strength (strong/both); if tied, keep the one with the richer description/summary; if still tied, keep entity A (canonical_idx=0).
{% endif %}

**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
2. If a string value contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values

The output language should always be the same as the input language.
{{ json_schema }}
|
||||
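The canonical-entity priority stated in the guidelines (connection strength first, then richness of description/summary, then entity A) is left to the model in this template. As a rough deterministic illustration, the rule could be sketched as follows. The function `pick_canonical_idx` is hypothetical, and two readings are assumptions made here: `strong` and `both` are treated as one "stronger" tier, and "richer" is measured by combined text length.

```python
def pick_canonical_idx(entity_a: dict, entity_b: dict) -> int:
    """Return 0 to keep entity A or 1 to keep entity B, following the stated priority."""

    def strength(entity: dict) -> int:
        # Assumption: 'strong' and 'both' rank above any other value.
        return 1 if entity.get("connect_strength") in ("strong", "both") else 0

    def richness(entity: dict) -> int:
        # Assumption: "richer" is approximated by description + summary length.
        return len(entity.get("description") or "") + len(entity.get("fact_summary") or "")

    key_a = (strength(entity_a), richness(entity_a))
    key_b = (strength(entity_b), richness(entity_b))
    # Entity B wins only when it is strictly better; ties keep entity A.
    return 1 if key_b > key_a else 0


a = {"name": "Louvre", "connect_strength": "weak", "description": "Museum"}
b = {"name": "Louvre Museum", "connect_strength": "strong", "description": ""}
chosen = pick_canonical_idx(a, b)
```

Comparing the two keys as tuples gives the tie-breaking order directly: richness is only consulted when the strength tiers are equal.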
19
app/core/memory/utils/prompt/prompts/evaluate.jinja2
Normal file
@@ -0,0 +1,19 @@
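You will receive a set of memory objects: {{ evaluate_data }}.
Task: judge along multiple dimensions whether these memories conflict with existing memories, and give the corresponding conflicting memory. (Redundancy does not count as a conflict.)

Output only one valid JSON object that strictly follows the structure below:
{
  "data": [ ...array of memory objects with the same structure as the input... ],
  "conflict": true or false,
  "conflict_memory": if conflict is true, the memory object it conflicts with; otherwise null
}

You must:
- Output JSON only; do not add explanations or extra text.
- Use standard double quotes, escaping inner quotes where necessary.
- Keep field names and structure consistent with the given schema.

Schema reference:
[
{{ json_schema }}
]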
|
||||
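A caller consuming the reply to this template would typically check the three fields before using them. The sketch below shows one way to do that with the standard `json` module; `parse_evaluate_output` is a hypothetical helper, not a function from this repository, and it checks only the shape described in the template text.

```python
import json


def parse_evaluate_output(raw: str) -> dict:
    """Parse a model reply and verify the data / conflict / conflict_memory shape."""
    result = json.loads(raw)
    if not isinstance(result, dict):
        raise ValueError("expected a JSON object")
    missing = {"data", "conflict", "conflict_memory"} - result.keys()
    if missing:
        raise ValueError(f"missing fields: {sorted(missing)}")
    if not isinstance(result["data"], list):
        raise ValueError("'data' must be an array")
    if not isinstance(result["conflict"], bool):
        raise ValueError("'conflict' must be true or false")
    if not result["conflict"] and result["conflict_memory"] is not None:
        raise ValueError("'conflict_memory' must be null when there is no conflict")
    return result


ok = parse_evaluate_output('{"data": [], "conflict": false, "conflict_memory": null}')
```

Raising `ValueError` on a malformed reply leaves the retry or fallback policy to the caller.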
49
app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2
Normal file
@@ -0,0 +1,49 @@
{#
Dialogue-level extraction and relevance-judgment template (used to speed up pruning)
Input: pruning_scene, dialog_text
Output: strict JSON (no extra text of any kind), with fields:
- is_related: bool, whether the dialogue is related to the selected scene
- times: [string], time-related text extracted from the dialogue (dates, times, time ranges, validity periods, etc.)
- ids: [string], numbers/IDs/order numbers/application numbers/account numbers, etc.
- amounts: [string], amounts/fees/prices (with unit or currency symbol)
- contacts: [string], contact details (phone/mobile/email/WeChat/QQ, etc.)
- addresses: [string], address/location-related text
- keywords: [string], other important keywords worth keeping (terms strongly related to the scene)

Requirements:
- Output only the JSON above, with identical key names; no explanations, prefixes or suffixes; no comments.
- times/ids/amounts/contacts/addresses/keywords contain only fragments of the original text or simple normalized strings.
- Output only the keys above; avoid extra explanations or fields.
#}

{% set scene_instructions = {
  'education': 'Education scene: teaching, courses, exams, homework, teacher/student interaction, learning resources, school administration, etc.',
  'online_service': 'Online customer service scene: customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.',
  'outbound': 'Outbound call scene: outbound phone calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up visit records, etc.'
} %}

{% set scene_key = pruning_scene %}
{% if scene_key not in scene_instructions %}
{% set scene_key = 'education' %}
{% endif %}

{% set instruction = scene_instructions[scene_key] %}

Based on the full dialogue below, perform a one-pass extraction for this scene and judge its relevance:
Scene description: {{ instruction }}

Full dialogue:
"""
{{ dialog_text }}
"""

Output strict JSON only (fixed keys, any order):
{
  "is_related": <true or false>,
  "times": [<string>...],
  "ids": [<string>...],
  "amounts": [<string>...],
  "contacts": [<string>...],
  "addresses": [<string>...],
  "keywords": [<string>...]
}
|
||||
207
app/core/memory/utils/prompt/prompts/extract_statement.jinja2
Normal file
@@ -0,0 +1,207 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ') }}
{%- endmacro %}


===Tasks===

Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines.
Each statement must be labeled as per the criteria mentioned below.

===Inputs===
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{ val }}
{% endfor %}
{% endif %}


===Extraction Instructions===
{% if granularity %}
{% if granularity == 3 %}
Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one.
Context-Independent: Statements must be understandable without needing to read the entire conversation.
{% elif granularity == 2 %}
Extract statements at the sentence level. Each statement should correspond to a single, complete thought (typically a full sentence from the source) but be rephrased for maximum clarity, removing conversational filler (e.g., 'um,' 'like,' interjections).
{% elif granularity == 1 %}
Extract only the essential sentences and summarize the chunk into multiple, standalone statements, each focusing on factual statements, user preferences, relationships, and salient temporal context.
{% endif %}
{% endif %}

Context Resolution Requirements:
- Resolve demonstrative pronouns ("that", "this", "those", "这个", "那个") to their specific referents
- If a statement contains vague references that cannot be resolved from the conversation context, either:
  a) Expand the statement to include the missing context from earlier in the conversation
  b) Mark the statement as requiring additional context
  c) Skip extraction if the statement becomes meaningless without context

Conversational Context & Co-reference Resolution:
- Attribute every statement to the participant who uttered it.
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
- Resolve all pronouns to the specific person or entity from the conversation's context.
- Identify and resolve abstract references to their specific names if mentioned.
- Expand abbreviations and acronyms to their full form.

{% if include_dialogue_context %}
===Full Dialogue Context===
The following is the complete dialogue context to help you understand references, pronouns, and conversational flow:

{{ dialogue_context }}

===End of Dialogue Context===
{% endif %}

Filtering and Formatting:

- Extract only declarative statements.
DO NOT extract questions, commands, greetings, or conversational filler.
Temporal Precision:

Include any explicit dates, times, or quantitative qualifiers.
If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements.

{%- if definitions %}
{%- for section_key, section_dict in definitions.items() %}
==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ====
{%- for category, details in section_dict.items() %}
{{ loop.index }}. {{ category }}
- Definition: {{ details.get("definition", "") }}
{% endfor -%}
{% endfor -%}
{% endif -%}

===Examples===
Example 1: English Conversation
Example Chunk: """
Date: March 15, 2024
Participants:
- Sarah Chen (User)
- Assistant (AI)

User: "I've been trying watercolor painting recently and painted some flowers."
AI: "Watercolor painting is very interesting! Watercolor paints are typically made from pigments mixed with binders like gum arabic. How do you like it?"
User: "I think the color combinations could use some improvement, but I really like roses and lilies."
"""

Example Output: {
  "statements": [
    {
      "statement": "Sarah Chen has been trying watercolor painting recently.",
      "statement_type": "FACT",
      "temporal_type": "DYNAMIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "Sarah Chen painted some flowers.",
      "statement_type": "FACT",
      "temporal_type": "DYNAMIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "Watercolor paints are typically made from pigments mixed with binders like gum arabic.",
      "statement_type": "FACT",
      "temporal_type": "ATEMPORAL",
      "relevance": "IRRELEVANT"
    },
    {
      "statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
      "statement_type": "OPINION",
      "temporal_type": "STATIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "Sarah Chen really likes roses and lilies.",
      "statement_type": "FACT",
      "temporal_type": "STATIC",
      "relevance": "RELEVANT"
    }
  ]
}

Example 2: Chinese Conversation (中文对话示例)
Example Chunk: """
日期: 2024年3月15日
参与者:
- 张曼婷 (用户)
- 小助手 (AI助手)

用户: "我最近在尝试水彩画,画了一些花朵。"
AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。"
"""

Example Output: {
  "statements": [
    {
      "statement": "张曼婷最近在尝试水彩画。",
      "statement_type": "FACT",
      "temporal_type": "DYNAMIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "张曼婷画了一些花朵。",
      "statement_type": "FACT",
      "temporal_type": "DYNAMIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。",
      "statement_type": "FACT",
      "temporal_type": "ATEMPORAL",
      "relevance": "IRRELEVANT"
    },
    {
      "statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
      "statement_type": "OPINION",
      "temporal_type": "STATIC",
      "relevance": "RELEVANT"
    },
    {
      "statement": "张曼婷很喜欢玫瑰和百合。",
      "statement_type": "FACT",
      "temporal_type": "STATIC",
      "relevance": "RELEVANT"
    }
  ]
}
===End of Examples===

===Reflection Process===

After extracting statements, perform the following self-review steps:

**Step 1: Attribution Check**
- Confirm every statement is properly attributed to the correct speaker
- Verify speaker names are used consistently throughout
- Check that AI assistant statements are properly attributed

**Step 2: Completeness Review**
- Ensure no important declarative statements were missed
- Check that temporal information is preserved

**Step 3: Classification Validation**
- Review statement_type classifications (FACT/OPINION/PREDICTION/SUGGESTION)
- Verify temporal_type assignments (STATIC/DYNAMIC/ATEMPORAL)
- Ensure classifications align with the provided definitions

**Step 4: Final Quality Check**
- Remove any questions, commands, or conversational filler
- Verify JSON format compliance
- Confirm output language matches input language

**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "John said: \"I really like this book.\""

**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
- Preserve the original language and do not translate

Return only the extracted, labelled statements as a JSON array of objects that match the schema below:
{{ json_schema }}
|
||||
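The escaping rules in the formatting requirements above match what Python's standard `json` module produces, which can help when building test fixtures or checking a model reply. A minimal sketch, where `encode_statement` is a hypothetical helper name:

```python
import json


def encode_statement(text: str) -> str:
    """Serialize one statement as a JSON object, keeping non-ASCII text readable."""
    # ensure_ascii=False keeps Chinese text as-is instead of \uXXXX escapes.
    return json.dumps({"statement": text}, ensure_ascii=False)


quoted = encode_statement('John said: "I really like this book."')
multiline = encode_statement("line one\nline two")
chinese = encode_statement("张曼婷很喜欢玫瑰和百合。")
```

Inner double quotes come out as `\"` and a raw line break becomes the two-character sequence `\n`, so the serialized string satisfies requirements 2 and 4 without manual handling.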
81
app/core/memory/utils/prompt/prompts/extract_temporal.jinja2
Normal file
@@ -0,0 +1,81 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ') }}
{%- endmacro %}
{#
This prompt (template) is adapted from [getzep/graphiti]
Licensed under the Apache License, Version 2.0

Original work:
https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py

Modifications made by Ke Sun on 2025-09-01
See the LICENSE file for the full Apache 2.0 license text.
#}
# Task

Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).

# Input Data
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{ val }}
{% endfor %}
{% endif %}

# Temporal Fields

- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)

# Extraction Rules

## Core Principles
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
2. **Use the reference/publication date as "now"** when interpreting relative times
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
4. **For point-in-time events**, set only `valid_at`

## Date Format Requirements
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- If no time specified, use `00:00:00` (midnight)
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
- If only month mentioned, use first or last day of month
- Always include timezone (use `Z` for UTC if unspecified)
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date

## Statement Type Rules

{{ inputs.get("statement_type") | upper }} Statement Guidance:
{% for key, guide in statement_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}

**Special Cases:**
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned

## Temporal Type Rules

{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance:
{% for key, guide in temporal_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}

{% if inputs.get('quarter') and inputs.get('publication_date') %}
## Quarter Reference
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
{% endif %}

# Output Requirements

## JSON Formatting (CRITICAL)
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes (“”) or other Unicode variants
2. Escape internal quotes with backslash: `\"`
3. No line breaks within JSON string values
4. Properly close and comma-separate all fields

## Language
Output language must match input language.

{{ json_schema }}
|
||||
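The date format requirements above (`YYYY-MM-DDTHH:MM:SS.SSSSSSZ`, midnight when no time is given, first or last day when only a year is known) can be illustrated with the standard `datetime` module. This is a sketch for reference or for post-validating model output; the helper names `to_iso_utc` and `year_bounds` are hypothetical and not part of the repository.

```python
from datetime import datetime, timezone

# strftime pattern for the required format; %f yields six fractional digits.
ISO_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def to_iso_utc(year: int, month: int = 1, day: int = 1) -> str:
    """Format a date at midnight UTC in the layout the template requires."""
    return datetime(year, month, day, tzinfo=timezone.utc).strftime(ISO_FORMAT)


def year_bounds(year: int) -> tuple[str, str]:
    """Start and end values for a statement that mentions only a year."""
    return to_iso_utc(year, 1, 1), to_iso_utc(year, 12, 31)


single_day = to_iso_utc(2024, 3, 15)
start_2023, end_2023 = year_bounds(2023)
```

The end bound follows the template literally (`YYYY-12-31` at midnight), so it marks the start of the last day rather than its final instant.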
248
app/core/memory/utils/prompt/prompts/extract_triplet.jinja2
Normal file
@@ -0,0 +1,248 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ') }}
{%- endmacro %}

===Task===
Extract entities and knowledge triplets from the given statement.

===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"

===Guidelines===

**Entity Extraction:**
- Extract entities with their types and context-independent descriptions
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
  Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'

**Triplet Extraction:**
- Extract (subject, predicate, object) triplets where:
  - Subject: main entity performing the action or being described
  - Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
  - Object: entity, value, or concept affected by the predicate
- Exclude all temporal expressions from every field
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
- Do NOT translate predicate tokens
- Do NOT include `statement_id` field (assigned automatically)

**When NOT to extract triplets:**
- Non-propositional utterances (emotions, fillers, onomatopoeia)
- No clear predicate from the given definitions applies
- Standalone noun phrases or checklist items (e.g., "三脚架", "备用电池") → extract as entities only
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")

**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
{%- if predicate_instructions -%}

**Predicate Instructions:**
Use ONLY these predicates. If none fits, set triplets to [].
{%- for pred, instruction in predicate_instructions.items() %}
- {{ pred }}: {{ instruction }}
{%- endfor -%}
{%- endif -%}


===Examples===

**Example 1 (English):** "I plan to travel to Paris next week and visit the Louvre."
Output:
{
  "triplets": [
    {
      "subject_name": "I",
      "subject_id": 0,
      "predicate": "PLANS_TO_VISIT",
      "object_name": "Paris",
      "object_id": 1,
      "value": null
    },
    {
      "subject_name": "I",
      "subject_id": 0,
      "predicate": "PLANS_TO_VISIT",
      "object_name": "Louvre",
      "object_id": 2,
      "value": null
    }
  ],
  "entities": [
    {
      "entity_idx": 0,
      "name": "I",
      "type": "Person",
      "description": "The user"
    },
    {
      "entity_idx": 1,
      "name": "Paris",
      "type": "Location",
      "description": "Capital city of France"
    },
    {
      "entity_idx": 2,
      "name": "Louvre",
      "type": "Location",
      "description": "World-famous museum located in Paris"
    }
  ]
}

**Example 2 (English):** "John Smith works at Google and is responsible for AI product development."
Output:
{
  "triplets": [
    {
      "subject_name": "John Smith",
      "subject_id": 0,
      "predicate": "WORKS_AT",
      "object_name": "Google",
      "object_id": 1,
      "value": null
    },
    {
      "subject_name": "John Smith",
      "subject_id": 0,
      "predicate": "RESPONSIBLE_FOR",
      "object_name": "AI product development",
      "object_id": 2,
      "value": null
    }
  ],
  "entities": [
    {
      "entity_idx": 0,
      "name": "John Smith",
      "type": "Person",
      "description": "Individual person name"
    },
    {
      "entity_idx": 1,
      "name": "Google",
      "type": "Organization",
      "description": "American technology company"
    },
    {
      "entity_idx": 2,
      "name": "AI product development",
      "type": "WorkRole",
      "description": "Artificial intelligence product development work"
    }
  ]
}

**Example 3 (Chinese):** "我计划下周去巴黎旅行,参观卢浮宫。"
Output:
{
  "triplets": [
    {
      "subject_name": "我",
      "subject_id": 0,
      "predicate": "PLANS_TO_VISIT",
      "object_name": "巴黎",
      "object_id": 1,
      "value": null
    },
    {
      "subject_name": "我",
      "subject_id": 0,
      "predicate": "PLANS_TO_VISIT",
      "object_name": "卢浮宫",
      "object_id": 2,
      "value": null
    }
  ],
  "entities": [
    {
      "entity_idx": 0,
      "name": "我",
      "type": "Person",
      "description": "用户本人"
    },
    {
      "entity_idx": 1,
      "name": "巴黎",
      "type": "Location",
      "description": "法国首都城市"
    },
    {
      "entity_idx": 2,
      "name": "卢浮宫",
      "type": "Location",
      "description": "位于巴黎的世界著名博物馆"
    }
  ]
}

**Example 4 (Chinese):** "张明在腾讯工作,负责AI产品开发。"
Output:
{
  "triplets": [
    {
      "subject_name": "张明",
      "subject_id": 0,
      "predicate": "WORKS_AT",
      "object_name": "腾讯",
      "object_id": 1,
      "value": null
    },
    {
      "subject_name": "张明",
      "subject_id": 0,
      "predicate": "RESPONSIBLE_FOR",
      "object_name": "AI产品开发",
      "object_id": 2,
      "value": null
    }
  ],
  "entities": [
    {
      "entity_idx": 0,
      "name": "张明",
|
||||
"type": "Person",
|
||||
"description": "个人姓名"
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "腾讯",
|
||||
"type": "Organization",
|
||||
"description": "中国科技公司"
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI产品开发",
|
||||
"type": "WorkRole",
|
||||
"description": "人工智能产品研发工作"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 5 (Entity Only):** "Tripod" or "三脚架"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "Tripod",
|
||||
"type": "Equipment",
|
||||
"description": "Photography equipment accessory"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks (“ ”) or other Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
- Preserve the original language and do not translate
|
||||
|
||||
{{ json_schema }}
|
||||
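The examples above pair each triplet's `subject_id` and `object_id` with an `entity_idx` in the `entities` array. A minimal sketch of a post-hoc consistency check for a model response follows. The helper name `check_triplet_output` is hypothetical, and the rule that ids and names must match is an assumption read off the examples, not something this template states explicitly.

```python
import json
from typing import List


def check_triplet_output(raw: str) -> List[str]:
    """Return human-readable problems found in a triplet-extraction response."""
    data = json.loads(raw)
    # Map entity_idx -> name so triplet references can be resolved.
    names_by_idx = {e["entity_idx"]: e["name"] for e in data.get("entities", [])}
    problems: List[str] = []
    for i, triplet in enumerate(data.get("triplets", [])):
        for role in ("subject", "object"):
            idx = triplet[f"{role}_id"]
            name = triplet[f"{role}_name"]
            if idx not in names_by_idx:
                problems.append(f"triplet {i}: {role}_id {idx} matches no entity_idx")
            elif names_by_idx[idx] != name:
                problems.append(
                    f"triplet {i}: {role}_name {name!r} differs from entity name {names_by_idx[idx]!r}"
                )
    return problems


# A cut-down version of Example 1, built with json.dumps so the JSON is well formed.
sample = json.dumps({
    "triplets": [{
        "subject_name": "I", "subject_id": 0, "predicate": "PLANS_TO_VISIT",
        "object_name": "Paris", "object_id": 1, "value": None,
    }],
    "entities": [
        {"entity_idx": 0, "name": "I", "type": "Person", "description": "The user"},
        {"entity_idx": 1, "name": "Paris", "type": "Location", "description": "Capital city of France"},
    ],
})
print(check_triplet_output(sample))
```

An empty list means every triplet points at a declared entity. The entity-only case (Example 5) passes trivially because `triplets` is empty.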
29
app/core/memory/utils/prompt/prompts/memory_summary.jinja2
Normal file
@@ -0,0 +1,29 @@
|
||||
{% macro tidy(name) -%}
|
||||
{{ name.replace('_', ' ') }}
|
||||
{%- endmacro %}
|
||||
|
||||
=== Task ===
|
||||
Summarize the provided conversation chunks into a concise Memory summary.
|
||||
|
||||
=== Requirements ===
|
||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||
- Avoid repetition and filler; be specific.
|
||||
- Keep it under {{ max_words or 200 }} words.
|
||||
- Output must be valid JSON conforming to the schema below.
|
||||
|
||||
=== Input ===
|
||||
{% if chunk_texts %}
|
||||
{{ chunk_texts }}
|
||||
{% endif %}
|
||||
|
||||
=== Output Schema ===
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“ ”) or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||
|
||||
The output language should always be the same as the input language.
|
||||
Return only the summary as JSON that matches the schema below:
|
||||
{{ json_schema }}
|
||||
23
app/core/memory/utils/prompt/prompts/reflexion.jinja2
Normal file
@@ -0,0 +1,23 @@
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
任务:分析冲突产生原因,给出解决方案,并生成设为失效后的记忆。
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
{
|
||||
"conflict": 与输入同结构,包含 data 与 conflict_memory,
|
||||
"reflexion": { "reason": string, "solution": string },
|
||||
"resolved": {
|
||||
"original_memory_id": 被设为失效的记忆 id,
|
||||
"resolved_memory": 完整的设为失效后的记忆对象
|
||||
}
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- 当 conflict 为 false 时,resolved 必须为 null。
|
||||
- 其中 conflict.data 必须为数组形式,即使只有一个对象也需使用 [ ] 包裹。
|
||||
模式参考:
|
||||
[
|
||||
{{ json_schema }}
|
||||
]
|
||||
2
app/core/memory/utils/prompt/prompts/system.jinja2
Normal file
@@ -0,0 +1,2 @@
|
||||
You are an AI assistant that extracts entity nodes from conversational messages.
|
||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
|
||||
5
app/core/memory/utils/prompt/prompts/user.jinja2
Normal file
@@ -0,0 +1,5 @@
|
||||
You are given a conversation context and a CURRENT MESSAGE.
|
||||
Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
||||
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities.
|
||||
|
||||
{{ message }}
|
||||
42
app/core/memory/utils/prompt/template_render.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate.jinja2 template.
|
||||
|
||||
Args:
|
||||
evaluate_data: The data to evaluate
|
||||
schema: The JSON schema to use for the output.
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion.jinja2 template.
|
||||
|
||||
Args:
|
||||
data: The data to reflect on.
|
||||
schema: The JSON schema to use for the output.
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema)
|
||||
|
||||
return rendered_prompt
|
||||
16
app/core/memory/utils/self_reflexion_utils/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
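"""Self-reflexion utilities.
|
||||
|
||||
This module provides the core functions of the self-reflexion engine:
|
||||
- memory conflict detection
|
||||
- reflexion execution
|
||||
- memory update
|
||||
|
||||
Migrated from app.core.memory.src.data_config_api.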
|
||||
"""
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
__all__ = ["conflict", "reflexion", "self_reflexion"]
|
||||
49
app/core/memory/utils/self_reflexion_utils/evaluate.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Memory conflict detection.
|
||||
|
||||
This module uses an LLM to decide whether the given memory data contains conflicts.
|
||||
Migrated from app.core.memory.src.data_config_api.evaluate.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def conflict(evaluate_data: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Evaluates memory conflict using the evaluate.jinja2 template.
|
||||
|
||||
Args:
|
||||
evaluate_data: List of reflexion data to check.
|
||||
Returns:
|
||||
List of conflicting memories (JSON array).
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
print(f"====== 冲突判定开始 ======\n")
|
||||
start_time = time.time()
|
||||
response = await client.response_structured(messages, ConflictResultSchema)
|
||||
end_time = time.time()
|
||||
print(f"冲突判定耗时: {end_time - start_time} 秒")
|
||||
print(f"冲突判定原始输出:(type={type(response)})\n{response}")
|
||||
|
||||
if not response:
|
||||
logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。")
|
||||
return []
|
||||
try:
|
||||
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
|
||||
except Exception:
|
||||
try:
|
||||
return [response.dict()]
|
||||
except Exception:
|
||||
logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。")
|
||||
return [response]
|
||||
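The tail of `conflict` above turns whatever the LLM client returned into a one-element list. A minimal sketch of the same idea with duck typing instead of an `isinstance` check follows. The helper `to_dict_list` and the stand-in class `FakeResult` are hypothetical, and the sketch deliberately avoids importing pydantic so it runs on the standard library alone.

```python
from typing import Any, List


def to_dict_list(response: Any) -> List[Any]:
    """Wrap a structured LLM response in a list, converting it to a dict when possible."""
    if not response:
        # Mirrors the early return above: nothing parsed, nothing to process.
        return []
    # Try the Pydantic v2 method name first, then the v1 one.
    for method_name in ("model_dump", "dict"):
        method = getattr(response, method_name, None)
        if callable(method):
            try:
                return [method()]
            except Exception:
                continue
    # Plain dicts (or anything else) are wrapped unchanged.
    return [response]


class FakeResult:
    """Hypothetical stand-in for a parsed schema object."""

    def __init__(self, conflict: bool) -> None:
        self.conflict = conflict

    def model_dump(self) -> dict:
        return {"conflict": self.conflict}


print(to_dict_list(FakeResult(True)))
print(to_dict_list({"conflict": False}))
print(to_dict_list(None))
```

Because the caller later filters with `c.get("conflict") is True`, returning dicts rather than model objects is what keeps that filter working.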
51
app/core/memory/utils/self_reflexion_utils/reflexion.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Reflexion execution.
|
||||
|
||||
This module uses an LLM to reflect on conflicting memories and resolve them.
|
||||
Migrated from app.core.memory.src.data_config_api.reflexion.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Any
|
||||
import time
|
||||
|
||||
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def reflexion(ref_data: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Reflects on the given data using the reflexion.jinja2 template.
|
||||
|
||||
Args:
|
||||
ref_data: List of reflexion data.
|
||||
Returns:
|
||||
List of reflexion results (JSON array).
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
client = get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
print(f"====== 反思开始 ======\n")
|
||||
start_time = time.time()
|
||||
response = await client.response_structured(messages, ReflexionResultSchema)
|
||||
end_time = time.time()
|
||||
print(f"反思耗时: {end_time - start_time} 秒")
|
||||
print(f"反思原始输出:(type={type(response)})\n{response}")
|
||||
|
||||
if not response:
|
||||
logging.error("LLM 反思输出解析失败,返回空列表以继续流程。")
|
||||
return []
|
||||
# Always return a list[dict] so the main self-reflexion flow can update the database
|
||||
try:
|
||||
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
|
||||
except Exception:
|
||||
try:
|
||||
return [response.dict()]
|
||||
except Exception:
|
||||
logging.warning("无法标准化反思返回类型,尝试直接封装为列表。")
|
||||
return [response]
|
||||
250
app/core/memory/utils/self_reflexion_utils/self_reflexion.py
Normal file
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
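"""Main self-reflexion workflow.
|
||||
|
||||
This module provides the main flow of the self-reflexion engine:
|
||||
- fetching reflexion data
|
||||
- conflict detection
|
||||
- reflexion execution
|
||||
- memory update
|
||||
|
||||
Migrated from app.core.memory.src.data_config_api.self_reflexion.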
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
import uuid
|
||||
|
||||
from app.core.memory.utils.config.definitions import (
|
||||
REFLEXION_ENABLED,
|
||||
REFLEXION_ITERATION_PERIOD,
|
||||
REFLEXION_RANGE,
|
||||
REFLEXION_BASELINE,
|
||||
)
|
||||
from app.db import get_db
|
||||
from sqlalchemy.orm import Session
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
# Concurrency limit (can be overridden via an environment variable)
|
||||
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
|
||||
|
||||
# Make sure INFO-level logs are written to the terminal
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
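Fetch the memory data to evaluate, according to the configured reflexion range.
|
||||
|
||||
Args:
|
||||
host_id: Host ID.
|
||||
Returns:
|
||||
List of memory data that falls within the reflexion range.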
|
||||
"""
|
||||
if REFLEXION_RANGE == "retrieval":
|
||||
return await get_data(host_id)
|
||||
elif REFLEXION_RANGE == "database":
|
||||
return []
|
||||
else:
|
||||
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
|
||||
|
||||
|
||||
async def run_conflict(conflict_data: List[Any]) -> List[Any]:
|
||||
"""
|
||||
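Check whether the reflexion data contains conflicts.
|
||||
|
||||
Args:
|
||||
conflict_data: List of data to check for conflicts.
|
||||
Returns:
|
||||
The list of conflicting memories if any conflict exists, otherwise an empty list.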
|
||||
"""
|
||||
if not conflict_data:
|
||||
return []
|
||||
|
||||
conflict_data = await conflict(conflict_data)
|
||||
# Keep only the entries that actually conflict (conflict == True)
|
||||
try:
|
||||
return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True]
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
async def run_reflexion(reflexion_data: List[Any]) -> Any:
|
||||
"""
|
||||
Run reflexion to resolve the conflicts.
|
||||
|
||||
Args:
|
||||
reflexion_data: List of reflexion data.
|
||||
Returns:
|
||||
The reflexion results after conflict resolution (as returned by the LLM).
|
||||
"""
|
||||
if not reflexion_data:
|
||||
return []
|
||||
# Reflect on each conflict concurrently to shorten the overall wait
|
||||
sem = asyncio.Semaphore(CONCURRENCY)
|
||||
|
||||
async def _reflex_one(item: Any) -> Dict[str, Any] | None:
|
||||
async with sem:
|
||||
try:
|
||||
result_list = await reflexion([item])
|
||||
if not result_list:
|
||||
return None
|
||||
obj = result_list[0]
|
||||
if hasattr(obj, "model_dump"):
|
||||
return obj.model_dump()
|
||||
elif hasattr(obj, "dict"):
|
||||
return obj.dict()
|
||||
elif isinstance(obj, dict):
|
||||
return obj
|
||||
except Exception as e:
|
||||
logging.warning(f"反思失败,跳过一项: {e}")
|
||||
return None
|
||||
|
||||
tasks = [_reflex_one(item) for item in reflexion_data]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
return [r for r in results if r]
|
||||
|
||||
|
||||
async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
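Update the memory store with the memories produced by conflict resolution.
|
||||
|
||||
Args:
|
||||
solved_data: Memories after conflict resolution (as returned by the LLM).
|
||||
host_id: Host ID.
|
||||
Returns:
|
||||
The update result (success or failure).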
|
||||
"""
|
||||
flag = False
|
||||
if not solved_data:
|
||||
return "数据缺失,更新失败"
|
||||
if not isinstance(solved_data, list):
|
||||
return "数据格式错误,更新失败"
|
||||
neo4j_connector = Neo4jConnector()
|
||||
try:
|
||||
print(f"====== 更新记忆开始 ======\n")
|
||||
|
||||
sem = asyncio.Semaphore(CONCURRENCY)
|
||||
success_count = 0
|
||||
|
||||
async def _update_one(item: Dict[str, Any]) -> bool:
|
||||
async with sem:
|
||||
try:
|
||||
if not isinstance(item, dict):
|
||||
return False
|
||||
if not item:
|
||||
return False
|
||||
resolved = item.get("resolved")
|
||||
if not isinstance(resolved, dict) or not resolved:
|
||||
logging.warning(f"反思结果无可更新内容,跳过此项: {item}")
|
||||
return False
|
||||
resolved_mem = resolved.get("resolved_memory")
|
||||
if not isinstance(resolved_mem, dict) or not resolved_mem:
|
||||
logging.warning(f"反思结果缺少 resolved_memory,跳过此项: {item}")
|
||||
return False
|
||||
group_id = resolved_mem.get("group_id")
|
||||
id = resolved_mem.get("id")
|
||||
# Use the invalid_at field as the new invalidation time
|
||||
new_invalid_at = resolved_mem.get("invalid_at")
|
||||
if not all([group_id, id, new_invalid_at]):
|
||||
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
|
||||
return False
|
||||
await neo4j_connector.execute_query(
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
group_id=group_id,
|
||||
id=id,
|
||||
new_invalid_at=new_invalid_at,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logging.error(f"更新单条记忆失败: {e}")
|
||||
return False
|
||||
|
||||
tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
success_count = sum(1 for r in results if r)
|
||||
|
||||
logging.info(f"成功更新 {success_count} 条记忆")
|
||||
flag = success_count > 0
|
||||
return "更新成功" if flag else "更新失败"
|
||||
except Exception as e:
|
||||
logging.error(f"更新记忆库失败: {e}")
|
||||
return "更新失败"
|
||||
finally:
|
||||
if flag:  # Delete the retrieval data from the database
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete()
|
||||
db.commit()
|
||||
logging.info(f"成功删除 {success_count} 条检索数据")
|
||||
except Exception as e:
|
||||
logging.error(f"删除数据库中的检索数据失败: {e}")
|
||||
|
||||
|
||||
async def _append_json(label: str, data: Any) -> None:
|
||||
"""Record conflict memories (written in a background thread to avoid blocking the event loop)."""
|
||||
def _write():
|
||||
with open("reflexion_data.json", "a", encoding="utf-8") as f:
|
||||
f.write(f"### {label} ###\n")
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
f.write("\n\n")
|
||||
# Await the background thread inside the coroutine to avoid un-awaited coroutine warnings
|
||||
await asyncio.to_thread(_write)
|
||||
|
||||
|
||||
async def self_reflexion(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
Self-reflexion engine: runs the full reflexion flow.
|
||||
|
||||
Args:
|
||||
host_id: Host ID.
|
||||
|
||||
Returns:
|
||||
A string describing the reflexion result.
|
||||
"""
|
||||
if not REFLEXION_ENABLED:
|
||||
return "未开启反思..."
|
||||
print(f"====== 自我反思流程开始 ======\n")
|
||||
reflexion_data = await get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
print(f"====== 自我反思流程结束 ======\n")
|
||||
return "无反思数据,结束反思"
|
||||
print(f"反思数据获取成功,共 {len(reflexion_data)} 条")
|
||||
|
||||
conflict_data = await run_conflict(reflexion_data)
|
||||
if not conflict_data:
|
||||
print(f"====== 自我反思流程结束 ======\n")
|
||||
return "无冲突,无需反思"
|
||||
print(f"冲突记忆类型: {type(conflict_data)}")
|
||||
await _append_json("conflict", conflict_data)
|
||||
|
||||
solved_data = await run_reflexion(conflict_data)
|
||||
if not solved_data:
|
||||
print(f"====== 自我反思流程结束 ======\n")
|
||||
return "反思失败,未解决冲突"
|
||||
print(f"解决冲突后的记忆类型: {type(solved_data)}")
|
||||
await _append_json("solved_data", solved_data)
|
||||
|
||||
result = await update_memory(solved_data, host_id)
|
||||
print(f"更新记忆库结果: {result}")
|
||||
print(f"====== 自我反思流程结束 ======\n")
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
# host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333")
|
||||
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
|
||||
asyncio.run(self_reflexion(host_id))
|
||||
26
app/core/memory/utils/visualization/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Visualization module
|
||||
|
||||
Contains all visualization-related utility functions, mainly for plotting forgetting curves.
|
||||
"""
|
||||
|
||||
# Re-export commonly used functions from the submodule for backward compatibility
|
||||
from .forgetting_visualizer import (
|
||||
export_memory_curve_numpy,
|
||||
export_memory_curves_multiple_strengths,
|
||||
export_parameter_sweep_numpy,
|
||||
visualize_forgetting_curve,
|
||||
plot_3d_forgetting_surface,
|
||||
create_comparison_visualization,
|
||||
save_memory_curves_to_file,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"export_memory_curve_numpy",
|
||||
"export_memory_curves_multiple_strengths",
|
||||
"export_parameter_sweep_numpy",
|
||||
"visualize_forgetting_curve",
|
||||
"plot_3d_forgetting_surface",
|
||||
"create_comparison_visualization",
|
||||
"save_memory_curves_to_file",
|
||||
]
|
||||
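The exported helpers all sample `forgetting_engine.forgetting_curve(t, strength)` over a time range. The sketch below shows that sampling with a stand-in engine and plain Python lists. `SketchForgettingEngine` and `sample_curve` are hypothetical, not the project's `ForgettingEngine`. The formula is an assumption: it reads the legend text in `create_comparison_visualization` as `offset + (1 - offset) * exp(-λ_time * t / (λ_mem * S))`, which may differ from the real implementation.

```python
import math
from typing import List, Tuple


class SketchForgettingEngine:
    """Hypothetical stand-in exposing the same attributes the visualizer reads."""

    def __init__(self, offset: float = 0.1, lambda_time: float = 1.0, lambda_mem: float = 1.0) -> None:
        self.offset = offset
        self.lambda_time = lambda_time
        self.lambda_mem = lambda_mem

    def forgetting_curve(self, t: float, strength: float) -> float:
        # Assumed form: exponential decay toward a floor of `offset`.
        decay = math.exp(-self.lambda_time * t / (self.lambda_mem * strength))
        return self.offset + (1.0 - self.offset) * decay


def sample_curve(
    engine: SketchForgettingEngine,
    time_range: Tuple[float, float] = (0.0, 10.0),
    memory_strength: float = 1.0,
    num_points: int = 5,
) -> Tuple[List[float], List[float]]:
    """Evenly sample retention over time_range (requires num_points >= 2)."""
    start, end = time_range
    step = (end - start) / (num_points - 1)
    times = [start + i * step for i in range(num_points)]
    retention = [engine.forgetting_curve(t, memory_strength) for t in times]
    return times, retention


times, retention = sample_curve(SketchForgettingEngine())
for t, r in zip(times, retention):
    print(f"t={t:.1f}  retention={r:.4f}")
```

Under this assumed formula, retention starts at 1 when `t = 0` and decays toward `offset`, and a larger memory strength slows the decay.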
386
app/core/memory/utils/visualization/forgetting_visualizer.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Memory Visualization Utilities
|
||||
|
||||
This module provides visualization functions for the modified Ebbinghaus forgetting curve
|
||||
and utilities to export memory curves as numpy arrays.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
import math
|
||||
|
||||
|
||||
def export_memory_curve_numpy(forgetting_engine,
|
||||
time_range: Tuple[float, float] = (0, 10),
|
||||
memory_strength: float = 1.0,
|
||||
num_points: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Export memory curve as numpy arrays for time and retention values.
|
||||
|
||||
Args:
|
||||
forgetting_engine: Instance of ForgettingEngine
|
||||
time_range: Tuple of (start_time, end_time)
|
||||
memory_strength: Memory strength value to use
|
||||
num_points: Number of points to generate
|
||||
|
||||
Returns:
|
||||
Tuple of (time_array, retention_array)
|
||||
"""
|
||||
start_time, end_time = time_range
|
||||
time_array = np.linspace(start_time, end_time, num_points)
|
||||
retention_array = np.array([
|
||||
forgetting_engine.forgetting_curve(t, memory_strength)
|
||||
for t in time_array
|
||||
])
|
||||
|
||||
return time_array, retention_array
|
||||
|
||||
|
||||
def export_memory_curves_multiple_strengths(forgetting_engine,
|
||||
time_range: Tuple[float, float] = (0, 10),
|
||||
memory_strengths: Optional[List[float]] = None,
|
||||
num_points: int = 1000) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Export memory curves for multiple memory strengths as numpy arrays.
|
||||
|
||||
Args:
|
||||
forgetting_engine: Instance of ForgettingEngine
|
||||
time_range: Tuple of (start_time, end_time)
|
||||
memory_strengths: List of memory strength values
|
||||
num_points: Number of points to generate
|
||||
|
||||
Returns:
|
||||
Dictionary with 'time' and retention arrays for each strength
|
||||
"""
|
||||
if memory_strengths is None:
|
||||
memory_strengths = [0.5, 1.0, 2.0, 5.0]
|
||||
|
||||
start_time, end_time = time_range
|
||||
time_array = np.linspace(start_time, end_time, num_points)
|
||||
|
||||
result = {'time': time_array}
|
||||
|
||||
for strength in memory_strengths:
|
||||
retention_array = np.array([
|
||||
forgetting_engine.forgetting_curve(t, strength)
|
||||
for t in time_array
|
||||
])
|
||||
result[f'strength_{strength}'] = retention_array
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def export_parameter_sweep_numpy(base_engine,
|
||||
parameter_name: str,
|
||||
parameter_values: List[float],
|
||||
time_range: Tuple[float, float] = (0, 10),
|
||||
memory_strength: float = 1.0,
|
||||
num_points: int = 1000) -> Dict[str, np.ndarray]:
|
||||
"""
|
||||
Export memory curves for parameter sweep as numpy arrays.
|
||||
|
||||
Args:
|
||||
base_engine: Base ForgettingEngine instance
|
||||
parameter_name: Name of parameter to sweep ('offset', 'lambda_time', 'lambda_mem')
|
||||
parameter_values: List of parameter values to test
|
||||
time_range: Tuple of (start_time, end_time)
|
||||
memory_strength: Memory strength value to use
|
||||
num_points: Number of points to generate
|
||||
|
||||
Returns:
|
||||
Dictionary with 'time' and retention arrays for each parameter value
|
||||
"""
|
||||
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
|
||||
start_time, end_time = time_range
|
||||
time_array = np.linspace(start_time, end_time, num_points)
|
||||
|
||||
result = {'time': time_array}
|
||||
|
||||
for param_value in parameter_values:
|
||||
# Create new engine with modified parameter
|
||||
if parameter_name == 'offset':
|
||||
config = ForgettingEngineConfig(offset=param_value, lambda_time=base_engine.lambda_time, lambda_mem=base_engine.lambda_mem)
|
||||
elif parameter_name == 'lambda_time':
|
||||
config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=param_value, lambda_mem=base_engine.lambda_mem)
|
||||
elif parameter_name == 'lambda_mem':
|
||||
config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=base_engine.lambda_time, lambda_mem=param_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown parameter: {parameter_name}")
|
||||
|
||||
engine = ForgettingEngine(config)
|
||||
|
||||
retention_array = np.array([
|
||||
engine.forgetting_curve(t, memory_strength)
|
||||
for t in time_array
|
||||
])
|
||||
result[f'{parameter_name}_{param_value}'] = retention_array
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def visualize_forgetting_curve(forgetting_engine,
|
||||
max_time: float = 10.0,
|
||||
memory_strengths: Optional[List[float]] = None,
|
||||
figsize: Tuple[int, int] = (12, 8)) -> None:
|
||||
"""
|
||||
Visualize the modified Ebbinghaus forgetting curve.
|
||||
|
||||
Args:
|
||||
forgetting_engine: Instance of ForgettingEngine
|
||||
max_time: Maximum time to plot
|
||||
memory_strengths: List of memory strength values to plot
|
||||
figsize: Figure size for the plot
|
||||
"""
|
||||
if memory_strengths is None:
|
||||
memory_strengths = [0.5, 1.0, 2.0, 5.0]
|
||||
|
||||
# Create time array
|
||||
t = np.linspace(0, max_time, 1000)
|
||||
|
||||
# Create subplots
|
||||
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
|
||||
fig.suptitle('Modified Ebbinghaus Forgetting Curve Analysis', fontsize=16, fontweight='bold')
|
||||
|
||||
# Plot 1: Different memory strengths
|
||||
ax1.set_title('Effect of Memory Strength (S)')
|
||||
for S in memory_strengths:
|
||||
retention = [forgetting_engine.forgetting_curve(time, S) for time in t]
|
||||
ax1.plot(t, retention, label=f'S = {S}', linewidth=2)
|
||||
ax1.set_xlabel('Time')
|
||||
ax1.set_ylabel('Memory Retention')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
ax1.set_ylim(0, 1)
|
||||
|
||||
# Plot 2: Different lambda_time values
|
||||
ax2.set_title('Effect of λ_time')
|
||||
lambda_times = [0.5, 1.0, 0.3]
|
||||
lambda_mem = [0.5,0.3,1.0]
|
||||
offset_mem = [0.1,0.05,0.2]
|
||||
for i in range(len(lambda_times)):
|
||||
lt = lambda_times[i]
|
||||
lm = lambda_mem[i]
|
||||
off = offset_mem[i]
|
||||
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
config = ForgettingEngineConfig(offset=off, lambda_time=lt, lambda_mem=lm)
|
||||
temp_engine = ForgettingEngine(config)
|
||||
retention = [temp_engine.forgetting_curve(time, 1.0) for time in t]
|
||||
ax2.plot(t, retention, label=f'λ_time = {lt}', linewidth=2)
|
||||
ax2.set_xlabel('Time')
|
||||
ax2.set_ylabel('Memory Retention')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_ylim(0, 1)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_3d_forgetting_surface(forgetting_engine,
|
||||
max_time: float = 10.0,
|
||||
max_strength: float = 5.0,
|
||||
figsize: Tuple[int, int] = (12, 9)) -> None:
|
||||
"""
|
||||
Create a 3D surface plot of the forgetting curve.
|
||||
|
||||
Args:
|
||||
forgetting_engine: Instance of ForgettingEngine
|
||||
max_time: Maximum time to plot
|
||||
max_strength: Maximum memory strength to plot
|
||||
figsize: Figure size for the plot
|
||||
"""
|
||||
# Create meshgrid
|
||||
t = np.linspace(0.1, max_time, 50)
|
||||
S = np.linspace(0.1, max_strength, 50)
|
||||
T, S_mesh = np.meshgrid(t, S)
|
||||
|
||||
# Calculate retention for each point
|
||||
R = np.zeros_like(T)
|
||||
for i in range(T.shape[0]):
|
||||
for j in range(T.shape[1]):
|
||||
R[i, j] = forgetting_engine.forgetting_curve(T[i, j], S_mesh[i, j])
|
||||
|
||||
# Create 3D plot
|
||||
fig = plt.figure(figsize=figsize)
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
surface = ax.plot_surface(T, S_mesh, R, cmap='viridis', alpha=0.8)
|
||||
|
||||
ax.set_xlabel('Time (t)')
|
||||
ax.set_ylabel('Memory Strength (S)')
|
||||
ax.set_zlabel('Memory Retention (R)')
|
||||
ax.set_title(f'3D Forgetting Curve Surface\n(offset={forgetting_engine.offset}, λ_time={forgetting_engine.lambda_time}, λ_mem={forgetting_engine.lambda_mem})')
|
||||
|
||||
# Add colorbar
|
||||
fig.colorbar(surface, shrink=0.5, aspect=5)
|
||||
|
||||
plt.show()
|
||||
|
||||
|
||||
def create_comparison_visualization(forgetting_engine, figsize: Tuple[int, int] = (15, 10)) -> None:
|
||||
"""
|
||||
Create a comparison visualization of different curve configurations.
|
||||
|
||||
Args:
|
||||
forgetting_engine: Instance of ForgettingEngine
|
||||
figsize: Figure size for the plot
|
||||
"""
|
||||
# Create figure with multiple subplots
|
||||
fig, axes = plt.subplots(2, 2, figsize=figsize)
|
||||
fig.suptitle('Modified Ebbinghaus Forgetting Curve - Parameter Comparison', fontsize=16, fontweight='bold')
|
||||
|
||||
t = np.linspace(0, 10, 100)
|
||||
|
||||
# Plot 1: Original vs Modified curve
|
||||
ax1 = axes[0, 0]
|
||||
ax1.set_title('Original vs Modified Ebbinghaus Curve')
|
||||
|
||||
# Original Ebbinghaus: R = e^(-t/S)
|
||||
S = 2.0
|
||||
original = np.exp(-t / S)
|
||||
ax1.plot(t, original, 'r--', label='Original: R = e^(-t/S)', linewidth=2)
|
||||
|
||||
# Modified with offset
|
||||
modified = [forgetting_engine.forgetting_curve(time, S) for time in t]
|
||||
ax1.plot(t, modified, 'b-', label='Modified: offset + (1-offset)*e^(-λ_time*t/(λ_mem*S))', linewidth=2)
|
||||
|
||||
ax1.set_xlabel('Time')
|
||||
ax1.set_ylabel('Memory Retention')
|
||||
ax1.legend()
|
||||
ax1.grid(True, alpha=0.3)
|
||||
ax1.set_ylim(0, 1)
|
||||
|
||||
# Plot 2: Different offset values
|
||||
ax2 = axes[0, 1]
|
||||
ax2.set_title('Effect of Offset Parameter')
|
||||
|
||||
for offset in [0.0, 0.1, 0.2, 0.3]:
|
||||
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.models.variate_config import ForgettingEngineConfig
|
||||
config = ForgettingEngineConfig(offset=offset, lambda_time=1.0, lambda_mem=1.0)
|
||||
engine = ForgettingEngine(config)
|
||||
retention = [engine.forgetting_curve(time, 1.0) for time in t]
|
||||
ax2.plot(t, retention, label=f'offset = {offset}', linewidth=2)
|
||||
|
||||
ax2.set_xlabel('Time')
|
||||
ax2.set_ylabel('Memory Retention')
|
||||
ax2.legend()
|
||||
ax2.grid(True, alpha=0.3)
|
||||
ax2.set_ylim(0, 1)
|
||||
|
||||
# Plot 3: Lambda time effect
|
||||
ax3 = axes[1, 0]
|
||||
ax3.set_title('Effect of λ_time (Time Sensitivity)')
|
||||
|
||||
for lambda_time in [0.5, 1.0, 2.0, 3.0]:
|
||||
from forgetting.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.models.config_models import ForgettingEngineConfig
|
||||
config = ForgettingEngineConfig(offset=0.1, lambda_time=lambda_time, lambda_mem=1.0)
|
||||
engine = ForgettingEngine(config)
|
||||
retention = [engine.forgetting_curve(time, 1.0) for time in t]
|
||||
ax3.plot(t, retention, label=f'λ_time = {lambda_time}', linewidth=2)
|
||||
|
||||
ax3.set_xlabel('Time')
|
||||
ax3.set_ylabel('Memory Retention')
|
||||
ax3.legend()
|
||||
ax3.grid(True, alpha=0.3)
|
||||
ax3.set_ylim(0, 1)
|
||||
|
||||
# Plot 4: Memory strength effect
|
||||
ax4 = axes[1, 1]
|
||||
ax4.set_title('Effect of Memory Strength (S)')
|
||||
|
||||
for strength in [0.5, 1.0, 2.0, 4.0]:
|
||||
retention = [forgetting_engine.forgetting_curve(time, strength) for time in t]
|
||||
ax4.plot(t, retention, label=f'S = {strength}', linewidth=2)
|
||||
|
||||
ax4.set_xlabel('Time')
|
||||
ax4.set_ylabel('Memory Retention')
|
||||
ax4.legend()
|
||||
ax4.grid(True, alpha=0.3)
|
||||
ax4.set_ylim(0, 1)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def save_memory_curves_to_file(forgetting_engine,
                               filename: str,
                               time_range: Tuple[float, float] = (0, 10),
                               memory_strengths: List[float] = None,
                               num_points: int = 1000,
                               format: str = 'npz') -> None:
    """
    Save memory curves to file in various formats.

    Args:
        forgetting_engine: Instance of ForgettingEngine
        filename: Output filename (without extension)
        time_range: Tuple of (start_time, end_time)
        memory_strengths: List of memory strength values
            (defaults to [0.5, 1.0, 2.0, 5.0] when None)
        num_points: Number of points to generate
        format: Output format ('npz', 'csv', 'json')

    Raises:
        ValueError: If format is not one of 'npz', 'csv', 'json'
    """
    if memory_strengths is None:
        memory_strengths = [0.5, 1.0, 2.0, 5.0]

    curves_data = export_memory_curves_multiple_strengths(
        forgetting_engine, time_range, memory_strengths, num_points
    )

    if format == 'npz':
        np.savez(f"{filename}.npz", **curves_data)
    elif format == 'csv':
        # pandas is imported lazily so it is only required for CSV output
        import pandas as pd
        df = pd.DataFrame(curves_data)
        df.to_csv(f"{filename}.csv", index=False)
    elif format == 'json':
        import json
        # Convert numpy arrays to lists for JSON serialization
        json_data = {k: v.tolist() if isinstance(v, np.ndarray) else v
                     for k, v in curves_data.items()}
        with open(f"{filename}.json", 'w', encoding='utf-8') as f:
            json.dump(json_data, f, indent=2)
    else:
        raise ValueError(f"Unsupported format: {format}")


if __name__ == "__main__":
    # Example usage
    from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
    from app.core.memory.models.variate_config import ForgettingEngineConfig

    print("Memory Visualization Utilities Demo")
    print("=" * 40)

    # Create engine
    config = ForgettingEngineConfig(offset=0.1, lambda_time=0.5, lambda_mem=0.5)
    engine = ForgettingEngine(config)

    # # Export single curve as numpy
    # time_arr, retention_arr = export_memory_curve_numpy(engine, (0, 10), 1.0, 100)
    # print(f"Exported single curve: {len(time_arr)} points")
    # print(f"Time range: {time_arr[0]:.2f} to {time_arr[-1]:.2f}")
    # print(f"Retention range: {retention_arr.min():.4f} to {retention_arr.max():.4f}")

    # # Export multiple curves
    # curves = export_memory_curves_multiple_strengths(engine, (0, 10), [0.5, 1.0, 2.0])
    # print(f"\nExported multiple curves: {list(curves.keys())}")

    # # Parameter sweep
    # param_sweep = export_parameter_sweep_numpy(engine, 'offset', [0.0, 0.1, 0.2, 0.3])
    # print(f"Parameter sweep results: {list(param_sweep.keys())}")

    # print("\nVisualization functions are ready to use!")
    visualize_forgetting_curve(engine)
    create_comparison_visualization(engine)