Initial commit

Ke Sun
2025-11-30 18:22:17 +08:00
commit aea2fe391e
449 changed files with 83030 additions and 0 deletions


@@ -0,0 +1,445 @@
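# Memory Module Utility Functions
This directory contains all utility functions used by the Memory module, managed in one place to improve maintainability and reusability.
## Directory Structure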
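```
app/core/memory/utils/
├── __init__.py                  # Package initializer; lists the public submodules in __all__
├── README.md                    # This document
├── config/                      # Configuration management
│   ├── __init__.py              # Config module initializer
│   ├── config_utils.py          # Configuration management utilities
│   ├── definitions.py           # Global definitions and constants
│   ├── overrides.py             # Runtime configuration overrides
│   ├── get_data.py              # Data retrieval utilities
│   ├── litellm_config.py        # LiteLLM configuration and monitoring
│   └── config_optimization.py   # Configuration optimization utilities
├── log/                         # Log management
│   ├── __init__.py              # Log module initializer
│   ├── logging_utils.py         # Logging utilities
│   └── audit_logger.py          # Audit logging
├── prompt/                      # Prompt management
│   ├── __init__.py              # Prompt module initializer
│   ├── prompt_utils.py          # Prompt rendering utilities
│   ├── template_render.py       # Template rendering utilities
│   └── prompts/                 # Jinja2 prompt templates
│       ├── entity_dedup.jinja2        # Entity deduplication prompt
│       ├── extract_statement.jinja2   # Statement extraction prompt
│       ├── extract_temporal.jinja2    # Temporal information extraction prompt
│       ├── extract_triplet.jinja2     # Triplet extraction prompt
│       ├── memory_summary.jinja2      # Memory summary prompt
│       ├── evaluate.jinja2            # Evaluation prompt
│       ├── reflexion.jinja2           # Reflexion prompt
│       ├── system.jinja2              # System prompt
│       └── user.jinja2                # User prompt
├── llm/                         # LLM utilities
│   ├── __init__.py              # LLM module initializer
│   └── llm_utils.py             # LLM client utilities
├── data/                        # Data processing
│   ├── __init__.py              # Data module initializer
│   ├── text_utils.py            # Text processing utilities
│   ├── time_utils.py            # Time processing utilities
│   └── ontology.py              # Ontology definitions (predicates, labels, etc.)
├── paths/                       # Path management
│   ├── __init__.py              # Paths module initializer
│   └── output_paths.py          # Output path management
├── visualization/               # Visualization
│   ├── __init__.py              # Visualization module initializer
│   └── forgetting_visualizer.py # Forgetting curve visualization
└── self_reflexion_utils/        # Self-reflexion utilities
    ├── __init__.py              # Reflexion module initializer
    ├── evaluate.py              # Conflict evaluation
    ├── reflexion.py             # Reflexion processing
    └── self_reflexion.py        # Main self-reflexion logic
```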
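## Module Categories
### 1. Configuration Management (config/)
The configuration management module contains all configuration-related utility functions and definitions.
#### config_utils.py
Provides configuration loading and management:
- `get_model_config(model_id)` - Get the LLM model configuration
- `get_embedder_config(embedding_id)` - Get the embedding model configuration
- `get_neo4j_config()` - Get the Neo4j database configuration
- `get_chunker_config(chunker_strategy)` - Get the chunking strategy configuration
- `get_pipeline_config()` - Get the pipeline configuration
- `get_pruning_config()` - Get the semantic pruning configuration
- `get_picture_config(llm_name)` - Get the image model configuration
- `get_voice_config(llm_name)` - Get the voice model configuration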
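
The lookup order that `get_chunker_config` follows (an exact match in `chunker_list`, then a built-in default, then the first configured chunker relabelled as a fallback) can be sketched as below. This is a simplified stand-alone version for illustration: the function name `resolve_chunker_config`, its `defaults` argument, and the sample values are assumptions and not part of the module.

```python
from typing import Any, Dict, List


def resolve_chunker_config(
    strategy: str,
    chunker_list: List[Dict[str, Any]],
    defaults: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
    """Resolve a chunker config: exact match, then default, then fallback."""
    # 1) Exact match among the configured chunkers
    for cfg in chunker_list:
        if cfg.get("chunker_strategy") == strategy:
            return cfg
    # 2) Built-in default for strategies that may be absent from the config
    if strategy in defaults:
        return defaults[strategy]
    # 3) Fallback: copy the first configured chunker and relabel it
    if chunker_list:
        fallback = dict(chunker_list[0])
        fallback["chunker_strategy"] = strategy
        return fallback
    raise ValueError(f"Chunker '{strategy}' not found and no fallback available")


# Hypothetical sample data
configured = [{"chunker_strategy": "RecursiveChunker", "chunk_size": 800}]
builtin = {"LLMChunker": {"chunker_strategy": "LLMChunker", "chunk_size": 1000}}

exact = resolve_chunker_config("RecursiveChunker", configured, builtin)
default = resolve_chunker_config("LLMChunker", configured, builtin)
relabelled = resolve_chunker_config("SentenceChunker", configured, builtin)
```

Copying the first entry before relabelling it keeps the configured list untouched, so a later exact-match lookup still sees the original strategy name.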
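#### definitions.py
Global definitions and constants:
- `CONFIG` - Base configuration (loaded from config.json)
- `RUNTIME_CONFIG` - Runtime configuration (loaded from runtime.json or the database)
- `PROJECT_ROOT` - Project root directory path
- Selection constants (LLM, embedding model, chunking strategy, etc.)
- `reload_configuration_from_database(config_id)` - Reload the configuration dynamically
#### overrides.py
Runtime configuration overrides:
- `load_unified_config(project_root)` - Load the unified configuration
#### get_data.py
Data retrieval utilities:
- `get_data(host_id)` - Fetch data from the SQL database
#### litellm_config.py
LiteLLM configuration and monitoring:
- `LiteLLMConfig` - LiteLLM configuration class
- `setup_litellm_enhanced(max_retries)` - Set up the enhanced LiteLLM configuration
- `get_usage_summary()` - Get a usage statistics summary
- `print_usage_summary()` - Print usage statistics
- `get_instant_qps(module)` - Get instantaneous QPS data
- `print_instant_qps(module)` - Print instantaneous QPS information
#### config_optimization.py
Optional configuration optimization utilities:
- LRU caching, cache warm-up, cache monitoring metrics, dynamic TTL strategy, and configuration version tracking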
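
To illustrate the caching idea behind `config_optimization.py`, here is a minimal sketch of an LRU cache with a per-entry TTL built on `collections.OrderedDict`. The class name `TinyLRUCache` and its fields are illustrative only; the module's own cache class additionally takes a lock and records statistics.

```python
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Any, Dict, Optional, Tuple


class TinyLRUCache:
    """Illustrative LRU cache with a per-entry TTL (not the module's class)."""

    def __init__(self, max_size: int, ttl: timedelta) -> None:
        self.max_size = max_size
        self.ttl = ttl
        # Maps key -> (stored_at, value); dictionary order tracks recency
        self._items: "OrderedDict[str, Tuple[datetime, Dict[str, Any]]]" = OrderedDict()

    def get(self, key: str) -> Optional[Dict[str, Any]]:
        entry = self._items.get(key)
        if entry is None:
            return None
        stored_at, value = entry
        if datetime.now() - stored_at >= self.ttl:
            # Expired entries are dropped and reported as a miss
            del self._items[key]
            return None
        self._items.move_to_end(key)  # mark as most recently used
        return value

    def put(self, key: str, value: Dict[str, Any]) -> None:
        if key not in self._items and len(self._items) >= self.max_size:
            self._items.popitem(last=False)  # evict the least recently used
        self._items[key] = (datetime.now(), value)
        self._items.move_to_end(key)


cache = TinyLRUCache(max_size=2, ttl=timedelta(minutes=5))
cache.put("a", {"n": 1})
cache.put("b", {"n": 2})
cache.get("a")            # touch "a" so that "b" becomes the oldest entry
cache.put("c", {"n": 3})  # the cache is full, so "b" is evicted
```

Moving an entry to the end on every hit is what turns plain insertion order into least-recently-used order.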
### 2. Log Management (log/)
The log management module contains all logging-related utility functions.
#### logging_utils.py
Logging utilities:
- `log_prompt_rendering(role, content)` - Log prompt rendering
- `log_template_rendering(template_name, context)` - Log template rendering
- `log_time(operation, duration)` - Log the duration of an operation
- `prompt_logger` - Prompt logger
#### audit_logger.py
Audit logging:
- `audit_logger` - Audit logger
- Records key system operations and security events
### 3. LLM Utilities (llm/)
The LLM utilities module contains all utility functions related to LLM clients. LiteLLM configuration and monitoring live in `config/litellm_config.py` (see section 1).
#### llm_utils.py
LLM client utilities:
- `get_llm_client(llm_id)` - Get an LLM client instance
- `get_reranker_client(rerank_id)` - Get a reranker client instance
- `handle_response(response)` - Process an LLM response
### 4. Prompt Management (prompt/)
The prompt management module contains all utility functions for prompt rendering and template management.
#### prompt_utils.py
Prompt rendering utilities (using Jinja2 templates):
- `get_prompts(message)` - Get the system and user prompts
- `render_statement_extraction_prompt(...)` - Render the statement extraction prompt
- `render_temporal_extraction_prompt(...)` - Render the temporal information extraction prompt
- `render_entity_dedup_prompt(...)` - Render the entity deduplication prompt
- `render_triplet_extraction_prompt(...)` - Render the triplet extraction prompt
- `render_memory_summary_prompt(...)` - Render the memory summary prompt
- `prompt_env` - Jinja2 environment object
#### template_render.py
Template rendering utilities (for evaluation and reflexion):
- `render_evaluate_prompt(evaluate_data, schema)` - Render the evaluation prompt
- `render_reflexion_prompt(data, schema)` - Render the reflexion prompt
#### prompts/
Directory of Jinja2 template files; contains all prompt templates.
### 5. Data Processing (data/)
The data processing module contains all data-processing utility functions.
#### text_utils.py
Text processing utilities:
- `escape_lucene_query(query)` - Escape special characters in a Lucene query
- `extract_plain_query(query_input)` - Extract a plain-text query from various input formats
#### time_utils.py
Time processing utilities:
- `validate_date_format(date_str)` - Validate the date format (YYYY-MM-DD)
- `normalize_date(date_str)` - Normalize a date string
- `normalize_date_safe(date_str, default)` - Normalize a date string safely (with a default value)
- `preprocess_date_string(date_str)` - Preprocess a date string
#### ontology.py
Ontology definitions:
- `PREDICATE_DEFINITIONS` - Predicate definition dictionary
- `LABEL_DEFINITIONS` - Label definition dictionary
- `Predicate` - Predicate enum
- `StatementType` - Statement type enum
- `TemporalInfo` - Temporal information enum
- `RelevenceInfo` - Relevance information enum
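### 6. Self-Reflexion Utilities (self_reflexion_utils/)
The self-reflexion utilities module provides memory conflict detection and reflexion processing.
#### evaluate.py
Conflict evaluation:
- `conflict(evaluate_data, schema)` - Evaluate memory conflicts
#### reflexion.py
Reflexion processing:
- `reflexion(data, schema)` - Run reflexion processing
#### self_reflexion.py
Main self-reflexion logic:
- `self_reflexion(...)` - Main self-reflexion function
### 7. Data Models
`json_schema` and `messages` have been migrated to `app.schemas.memory_storage_schema`.
#### json_schema.py
JSON Schema data models:
- `BaseDataSchema` - Base data model
- `ConflictResultSchema` - Conflict result model
- `ConflictSchema` - Conflict model
- `ReflexionSchema` - Reflexion model
- `ResolvedSchema` - Resolution model
- `ReflexionResultSchema` - Reflexion result model
#### messages.py
API message models:
- `ConfigKey` - Configuration key model
- `ChunkerStrategy` - Chunking strategy enum
- `ConfigParams` - Configuration parameters model
- `ConfigParamsCreate` - Model for creating configuration parameters
- `ConfigUpdate` - Configuration update model
- `ConfigUpdateExtracted` - Model for updating the extraction engine configuration
- `ConfigUpdateForget` - Model for updating the forgetting engine configuration
- `ConfigPilotRun` - Pilot run configuration model
- `ConfigFilter` - Configuration filter model
- `ApiResponse` - API response model
- `ok(msg, data)` - Success response constructor
- `fail(msg, error_code, data)` - Failure response constructor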
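
As a rough illustration of how a pair of `ok` / `fail` constructors can wrap a common response envelope, consider the sketch below. The `ApiResponseSketch` class, its field names, and the default argument values are assumptions made for this example; the real `ApiResponse` model may be shaped differently.

```python
from dataclasses import asdict, dataclass
from typing import Any, Optional


@dataclass
class ApiResponseSketch:
    """Hypothetical response envelope; the real ApiResponse fields may differ."""

    success: bool
    msg: str
    error_code: Optional[str] = None
    data: Any = None


def ok(msg: str = "OK", data: Any = None) -> ApiResponseSketch:
    """Build a success response."""
    return ApiResponseSketch(success=True, msg=msg, data=data)


def fail(msg: str, error_code: str = "ERROR", data: Any = None) -> ApiResponseSketch:
    """Build a failure response carrying an error code."""
    return ApiResponseSketch(success=False, msg=msg, error_code=error_code, data=data)


created = asdict(ok("config created", {"config_id": "config_123"}))
missing = asdict(fail("config not found", error_code="NOT_FOUND"))
```

Routing every endpoint through two small constructors keeps the envelope consistent, so callers can branch on one success flag.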
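### 8. Visualization (visualization/)
The visualization module contains all visualization-related utility functions.
#### forgetting_visualizer.py
Forgetting curve visualization:
- `export_memory_curve_numpy(...)` - Export a memory curve as NumPy arrays
- `export_memory_curves_multiple_strengths(...)` - Export memory curves for multiple strengths
- `export_parameter_sweep_numpy(...)` - Export parameter sweep results
- `visualize_forgetting_curve(...)` - Visualize the forgetting curve
- `plot_3d_forgetting_surface(...)` - Plot a 3D forgetting curve surface
- `create_comparison_visualization(...)` - Create a comparison visualization
- `save_memory_curves_to_file(...)` - Save memory curves to a file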
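
For readers unfamiliar with forgetting curves, the sketch below samples the classic Ebbinghaus form R(t) = exp(-t / S), where S is the memory strength. This formula is an assumption chosen for illustration: the model actually used by `forgetting_visualizer.py` is not shown in this document, and the function name `sample_forgetting_curve` and its parameters are hypothetical. Plain lists are used instead of NumPy to keep the example dependency-free.

```python
import math
from typing import List, Tuple


def sample_forgetting_curve(
    strength: float, max_days: float, steps: int
) -> Tuple[List[float], List[float]]:
    """Sample R(t) = exp(-t / strength) at `steps` evenly spaced time points."""
    if strength <= 0 or steps < 2:
        raise ValueError("strength must be positive and steps at least 2")
    # Evenly spaced time points from 0 to max_days inclusive
    times = [max_days * i / (steps - 1) for i in range(steps)]
    # Retention decays exponentially; larger strength means slower forgetting
    retention = [math.exp(-t / strength) for t in times]
    return times, retention


times, retention = sample_forgetting_curve(strength=5.0, max_days=10.0, steps=5)
```

Retention starts at 1.0 and falls to 1/e once the elapsed time equals the strength.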
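### 9. Path Management (paths/)
The path management module contains all path-related utility functions.
#### output_paths.py
Output path management:
- `get_output_dir()` - Get the output directory
- `get_output_path(filename)` - Get the path of an output file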
## Usage Examples
### Configuration Management
```python
from app.core.memory.utils.config import get_model_config, get_pipeline_config
from app.core.memory.utils.config.definitions import SELECTED_LLM_ID

# Get a model configuration
model_config = get_model_config("model_id_123")
# Get the pipeline configuration
pipeline_config = get_pipeline_config()
# Use a global constant
llm_id = SELECTED_LLM_ID
```
### Log Management
```python
from app.core.memory.utils.log import log_prompt_rendering, log_time, audit_logger

# Log prompt rendering
log_prompt_rendering('user', 'Hello, world!')
# Log the duration of an operation
log_time('extraction', 1.23)
# Use the audit logger
audit_logger.info('User action performed')
```
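
Durations such as the one passed to `log_time` are usually measured around a block of work. A minimal sketch of that pattern with a context manager follows; the `timed` helper and the logger name are hypothetical and use the standard `logging` module in place of the project's own helpers.

```python
import logging
import time
from contextlib import contextmanager
from typing import Iterator

logger = logging.getLogger("memory.timing")  # hypothetical logger name


@contextmanager
def timed(operation: str) -> Iterator[None]:
    """Measure the wrapped block and log its duration in seconds."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs even if the block raises, so failures are timed as well
        duration = time.perf_counter() - start
        logger.info("%s took %.3fs", operation, duration)


with timed("extraction"):
    sum(range(1000))  # stand-in for real work
```

In project code the `logger.info` call could presumably be swapped for `log_time(operation, duration)`.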
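### LLM Utilities
```python
from app.core.memory.utils.llm import get_llm_client

async def say_hello():
    # Get an LLM client
    llm_client = get_llm_client("llm_id_456")
    # Call the LLM (`await` is only valid inside an async function)
    response = await llm_client.chat([
        {"role": "user", "content": "Hello"}
    ])
    return response
```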
### Prompt Rendering
```python
from app.core.memory.utils.prompt import render_statement_extraction_prompt
from app.core.memory.utils.data.ontology import LABEL_DEFINITIONS

async def build_statement_prompt(schema):
    # Render the statement extraction prompt; `schema` is supplied by the caller
    return await render_statement_extraction_prompt(
        chunk_content="Conversation content...",
        definitions=LABEL_DEFINITIONS,
        json_schema=schema,
        granularity=2
    )
```
### Data Processing
```python
from app.core.memory.utils.data.time_utils import normalize_date
from app.core.memory.utils.data.text_utils import escape_lucene_query

# Normalize a date
normalized = normalize_date("2025/10/28")  # returns "2025-10-28"
# Escape a Lucene query
escaped = escape_lucene_query("user:admin AND status:active")
```
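
To show what escaping a Lucene query involves, here is a stand-alone sketch that backslash-escapes the characters Lucene's classic query syntax treats as special. It is not the implementation of `escape_lucene_query`: the name `escape_lucene_sketch` is hypothetical and the real function may handle additional cases.

```python
import re

# Characters with special meaning in Lucene's classic query syntax
_LUCENE_SPECIAL = re.compile(r'([+\-!(){}\[\]^"~*?:\\/&|])')


def escape_lucene_sketch(query: str) -> str:
    """Prefix every Lucene special character with a backslash."""
    return _LUCENE_SPECIAL.sub(r"\\\1", query)


escaped = escape_lucene_sketch("user:admin AND (status:active)")
```

Note that keyword operators such as `AND` are left untouched by this sketch; only punctuation is escaped.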
### Runtime Configuration Overrides
```python
from app.core.memory.utils import apply_runtime_overrides_with_config_id

# Override the configuration using a specific config_id
runtime_cfg = {"selections": {}}
updated_cfg = apply_runtime_overrides_with_config_id(
    project_root="/path/to/project",
    runtime_cfg=runtime_cfg,
    config_id="config_123"
)
```
## Migration Notes
### Migrating from Old Paths
If your code uses the old import paths, update it as follows:
**Old paths (before November 2024):**
```python
from app.core.memory.src.utils.config_utils import get_model_config
from app.core.memory.src.utils.prompt_utils import render_statement_extraction_prompt
from app.core.memory.src.data_config_api.utils.messages import ok, fail
```
**Intermediate paths (November 2024):**
```python
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.schemas.memory_storage_schema import ok, fail
```
**New paths (after November 27, 2024):**
```python
# Configuration
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import get_model_config  # shorthand import
# Logging
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.log import log_prompt_rendering  # shorthand import
# Other utilities
from app.core.memory.utils.prompt import prompt_utils
from app.schemas.memory_storage_schema import ok, fail
```
### Directory Restructuring (November 27, 2024)
The utils directory has been fully reorganized by function:
**Structure before the restructuring:**
- All files lived directly under `app/core/memory/utils/`
**Structure after the restructuring:**
- `config/` - Configuration management files
- `log/` - Log management files
- `prompt/` - Prompt management files
- `llm/` - LLM utility files
- `data/` - Data processing files
- `paths/` - Path management files
- `visualization/` - Visualization files
- `self_reflexion_utils/` - Self-reflexion utilities (already existed)
**Import path changes:**
```python
# Old imports
from app.core.memory.utils.config_utils import get_model_config
from app.core.memory.utils.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt_utils import render_statement_extraction_prompt
# New imports
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.log.logging_utils import log_prompt_rendering
from app.core.memory.utils.prompt.prompt_utils import render_statement_extraction_prompt
# Or use the shorthand imports
from app.core.memory.utils.config import get_model_config
from app.core.memory.utils.log import log_prompt_rendering
from app.core.memory.utils.prompt import render_statement_extraction_prompt
```
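## Maintenance Guide
### Adding a New Utility Function
1. Add the function to the appropriate module file
2. Export the function in `__init__.py`
3. Document it in this README
4. Write unit tests
### Removing an Old Utility Function
1. Confirm that no code uses the function
2. Delete the function from the module file
3. Remove the export from `__init__.py`
4. Update this README
### Refactoring a Utility Function
1. Maintain backward compatibility (using aliases or wrappers)
2. Update all code that uses the function
3. Update the documentation and tests
4. Remove the old version when the time is right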
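
One common way to keep an old function name working during a refactor is a thin wrapper that emits a `DeprecationWarning` and delegates to the new function. The sketch below shows the idea; `deprecated_alias`, `normalize_name`, and `clean_name` are hypothetical names used only for this example.

```python
import functools
import warnings
from typing import Any, Callable


def deprecated_alias(new_func: Callable[..., Any], old_name: str) -> Callable[..., Any]:
    """Return a wrapper that warns, then delegates to `new_func`."""

    @functools.wraps(new_func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        warnings.warn(
            f"{old_name} is deprecated; use {new_func.__name__} instead",
            DeprecationWarning,
            stacklevel=2,  # point the warning at the caller, not the wrapper
        )
        return new_func(*args, **kwargs)

    return wrapper


def normalize_name(value: str) -> str:
    """Hypothetical refactored utility."""
    return value.strip().lower()


# Old name kept alive until all callers have migrated
clean_name = deprecated_alias(normalize_name, "clean_name")
```

Once no caller triggers the warning any more, the alias can be deleted as described in step 4.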
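## Notes
1. **Backward compatibility**: all utility functions should remain backward compatible to avoid breaking existing code
2. **Complete documentation**: every function should have a clear docstring
3. **Type annotations**: use type annotations to improve readability
4. **Error handling**: utility functions should handle errors appropriately
5. **Test coverage**: every utility function should have unit tests
## Related Documents
- [Memory module architecture design](../.kiro/specs/memory-refactoring/design.md)
- [Memory module requirements](../.kiro/specs/memory-refactoring/requirements.md)
- [Memory module task list](../.kiro/specs/memory-refactoring/tasks.md)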


@@ -0,0 +1,65 @@
"""
Utility package for the Memory module.

This package contains all utility functions used by the Memory module, organized by function.

Directory structure:
- config/: configuration management (config_utils, definitions, overrides, get_data, litellm_config, config_optimization)
- log/: log management (logging_utils, audit_logger)
- prompt/: prompt management (prompt_utils, template_render, prompts/)
- llm/: LLM utilities (llm_utils)
- data/: data processing (text_utils, time_utils, ontology)
- paths/: path management (output_paths)
- visualization/: visualization (forgetting_visualizer)
- self_reflexion_utils/: self-reflexion utilities (evaluate, reflexion, self_reflexion)

Notes:
- json_schema and messages have been migrated to app.schemas.memory_storage_schema
- All utility functions have been sorted into the matching subdirectory

Usage examples:
    # Configuration management
    from app.core.memory.utils.config import get_model_config
    from app.core.memory.utils.config.definitions import SELECTED_LLM_ID

    # Log management
    from app.core.memory.utils.log import log_prompt_rendering, audit_logger

    # Prompt management
    from app.core.memory.utils.prompt import render_statement_extraction_prompt

    # LLM utilities
    from app.core.memory.utils.llm import get_llm_client

    # Data processing
    from app.core.memory.utils.data import text_utils, time_utils
    from app.core.memory.utils.data.ontology import Predicate, StatementType

    # Path management
    from app.core.memory.utils.paths import get_output_dir

    # Visualization
    from app.core.memory.utils.visualization import visualize_forgetting_curve

    # Self-reflexion
    from app.core.memory.utils.self_reflexion_utils import self_reflexion
"""
# No module-level imports are performed in __init__.py, to avoid circular imports.
# Import the module you need directly, for example:
#     from app.core.memory.utils.config import config_utils
#     from app.core.memory.utils.log import logging_utils
#     from app.core.memory.utils.data import text_utils
#     from app.core.memory.utils.prompt import prompt_utils

__all__ = [
    # Submodules
    "config",
    "log",
    "prompt",
    "llm",
    "data",
    "paths",
    "visualization",
    "self_reflexion_utils",
]


@@ -0,0 +1,82 @@
"""
Configuration management module.

Contains all configuration-related utility functions and definitions.
"""
# Re-export commonly used functions and constants from submodules for backward compatibility
from .config_utils import (
    get_model_config,
    get_embedder_config,
    get_neo4j_config,
    get_chunker_config,
    get_pipeline_config,
    get_pruning_config,
    get_picture_config,
    get_voice_config,
)
from .definitions import (
    CONFIG,
    RUNTIME_CONFIG,
    PROJECT_ROOT,
    SELECTED_LLM_ID,
    SELECTED_EMBEDDING_ID,
    SELECTED_GROUP_ID,
    SELECTED_RERANK_ID,
    SELECTED_LLM_PICTURE_NAME,
    SELECTED_LLM_VOICE_NAME,
    REFLEXION_ENABLED,
    REFLEXION_ITERATION_PERIOD,
    REFLEXION_RANGE,
    REFLEXION_BASELINE,
    reload_configuration_from_database,
)
from .overrides import load_unified_config
from .get_data import get_data

# litellm_config is imported on demand to avoid circular dependencies
# from .litellm_config import (
#     LiteLLMConfig,
#     setup_litellm_enhanced,
#     get_usage_summary,
#     print_usage_summary,
#     get_instant_qps,
#     print_instant_qps,
# )

__all__ = [
    # config_utils
    "get_model_config",
    "get_embedder_config",
    "get_neo4j_config",
    "get_chunker_config",
    "get_pipeline_config",
    "get_pruning_config",
    "get_picture_config",
    "get_voice_config",
    # definitions
    "CONFIG",
    "RUNTIME_CONFIG",
    "PROJECT_ROOT",
    "SELECTED_LLM_ID",
    "SELECTED_EMBEDDING_ID",
    "SELECTED_GROUP_ID",
    "SELECTED_RERANK_ID",
    "SELECTED_LLM_PICTURE_NAME",
    "SELECTED_LLM_VOICE_NAME",
    "REFLEXION_ENABLED",
    "REFLEXION_ITERATION_PERIOD",
    "REFLEXION_RANGE",
    "REFLEXION_BASELINE",
    "reload_configuration_from_database",
    # overrides
    "load_unified_config",
    # get_data
    "get_data",
    # litellm_config - import directly from .litellm_config when needed
    # "LiteLLMConfig",
    # "setup_litellm_enhanced",
    # "get_usage_summary",
    # "print_usage_summary",
    # "get_instant_qps",
    # "print_instant_qps",
]


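@@ -0,0 +1,398 @@
"""
Configuration management optimization module.

Provides optional optimizations for configuration management, including:
- LRU caching strategy
- Cache warm-up
- Cache monitoring metrics
- Dynamic TTL strategy
- Configuration version control

These optimizations are optional; the current basic implementation already covers most needs.
"""
import logging
import statistics
import threading
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Dict, Any, List, Optional, Tuple

logger = logging.getLogger(__name__)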
class LRUConfigCache:
    """
    LRU (Least Recently Used) configuration cache.

    When the cache reaches its maximum capacity, the least recently used
    configuration is evicted automatically.
    """

    def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)):
        """
        Initialize the LRU cache.

        Args:
            max_size: Maximum cache capacity
            ttl: Cache expiry time
        """
        self.max_size = max_size
        self.ttl = ttl
        self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
        self._timestamps: Dict[str, datetime] = {}
        self._lock = threading.RLock()
        # Statistics
        self._stats = {
            'hits': 0,
            'misses': 0,
            'evictions': 0,
            'load_times': []
        }

    def get(self, config_id: str) -> Optional[Dict[str, Any]]:
        """
        Get a configuration (if it exists and has not expired).

        Args:
            config_id: Configuration ID

        Returns:
            The configuration dict, or None if it is missing or expired
        """
        with self._lock:
            if config_id not in self._cache:
                self._stats['misses'] += 1
                return None
            # Check for expiry
            timestamp = self._timestamps.get(config_id)
            if timestamp and (datetime.now() - timestamp) >= self.ttl:
                # Expired: remove the entry
                self._cache.pop(config_id, None)
                self._timestamps.pop(config_id, None)
                self._stats['misses'] += 1
                return None
            # Hit: move to the end (mark as most recently used)
            self._cache.move_to_end(config_id)
            self._stats['hits'] += 1
            return self._cache[config_id]

    def put(self, config_id: str, config: Dict[str, Any]) -> None:
        """
        Add or update a configuration.

        Args:
            config_id: Configuration ID
            config: Configuration dict
        """
        with self._lock:
            if config_id in self._cache:
                # Update an existing configuration
                self._cache.move_to_end(config_id)
            else:
                # Add a new configuration
                if len(self._cache) >= self.max_size:
                    # Cache is full: remove the least recently used entry
                    oldest_id, _ = self._cache.popitem(last=False)
                    self._timestamps.pop(oldest_id, None)
                    self._stats['evictions'] += 1
                    logger.debug(f"[LRUCache] Evicted config: {oldest_id}")
            self._cache[config_id] = config
            self._timestamps[config_id] = datetime.now()

    def clear(self, config_id: Optional[str] = None) -> None:
        """
        Clear the cache.

        Args:
            config_id: If given, clear only this configuration; otherwise clear everything
        """
        with self._lock:
            if config_id:
                self._cache.pop(config_id, None)
                self._timestamps.pop(config_id, None)
            else:
                self._cache.clear()
                self._timestamps.clear()

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Statistics dict
        """
        with self._lock:
            total = self._stats['hits'] + self._stats['misses']
            hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0
            return {
                'cache_size': len(self._cache),
                'max_size': self.max_size,
                'total_requests': total,
                'cache_hits': self._stats['hits'],
                'cache_misses': self._stats['misses'],
                'evictions': self._stats['evictions'],
                'hit_rate': hit_rate,
                'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0
            }

    def record_load_time(self, load_time_ms: float) -> None:
        """
        Record a load time.

        Args:
            load_time_ms: Load time in milliseconds
        """
        with self._lock:
            self._stats['load_times'].append(load_time_ms)
            # Keep only the most recent 1000 records
            if len(self._stats['load_times']) > 1000:
                self._stats['load_times'] = self._stats['load_times'][-1000:]
class ConfigCacheWarmer:
    """
    Configuration cache warmer.

    Preloads frequently used configurations at system startup to reduce
    first-request latency.
    """

    @staticmethod
    def warmup(config_ids: List[str], load_func) -> Dict[str, bool]:
        """
        Warm up the cache.

        Args:
            config_ids: List of configuration IDs to preload
            load_func: Configuration loading function

        Returns:
            The load result for each configuration
        """
        results = {}
        logger.info(f"[CacheWarmer] Warming up {len(config_ids)} configs")
        for config_id in config_ids:
            try:
                result = load_func(config_id)
                results[config_id] = result
                if result:
                    logger.debug(f"[CacheWarmer] Warmed up config: {config_id}")
                else:
                    logger.warning(f"[CacheWarmer] Failed to warm up config: {config_id}")
            except Exception as e:
                logger.error(f"[CacheWarmer] Error warming up config: {config_id}, error: {e}")
                results[config_id] = False
        success_count = sum(1 for r in results.values() if r)
        logger.info(f"[CacheWarmer] Warm-up finished: {success_count}/{len(config_ids)} succeeded")
        return results
class DynamicTTLStrategy:
    """
    Dynamic TTL strategy.

    Adjusts the cache expiry time according to the configuration type and
    how often it changes.
    """

    # Predefined TTL strategies
    TTL_STRATEGIES = {
        'production': timedelta(minutes=30),   # production configs are fairly stable
        'staging': timedelta(minutes=15),      # staging configs are moderately stable
        'development': timedelta(minutes=5),   # development configs change frequently
        'testing': timedelta(minutes=1),       # testing configs expire quickly
        'default': timedelta(minutes=5)        # default strategy
    }

    @classmethod
    def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta:
        """
        Get the TTL for a configuration.

        Args:
            config_id: Configuration ID
            config_type: Configuration type (production/staging/development/testing)

        Returns:
            TTL interval
        """
        if config_type and config_type in cls.TTL_STRATEGIES:
            return cls.TTL_STRATEGIES[config_type]
        # Infer the type from config_id
        if 'prod' in config_id.lower():
            return cls.TTL_STRATEGIES['production']
        elif 'stag' in config_id.lower():
            return cls.TTL_STRATEGIES['staging']
        elif 'dev' in config_id.lower():
            return cls.TTL_STRATEGIES['development']
        elif 'test' in config_id.lower():
            return cls.TTL_STRATEGIES['testing']
        return cls.TTL_STRATEGIES['default']
class ConfigVersionManager:
    """
    Configuration version manager.

    Tracks configuration versions so that cached copies of an old version
    are invalidated when a configuration is updated.
    """

    def __init__(self):
        self._versions: Dict[str, str] = {}
        self._lock = threading.RLock()

    def get_version(self, config_id: str) -> Optional[str]:
        """
        Get the version of a configuration.

        Args:
            config_id: Configuration ID

        Returns:
            The version string, or None if it does not exist
        """
        with self._lock:
            return self._versions.get(config_id)

    def set_version(self, config_id: str, version: str) -> None:
        """
        Set the version of a configuration.

        Args:
            config_id: Configuration ID
            version: Version string
        """
        with self._lock:
            old_version = self._versions.get(config_id)
            self._versions[config_id] = version
            if old_version and old_version != version:
                logger.info(f"[VersionManager] Config version updated: {config_id} {old_version} -> {version}")

    def check_version(self, config_id: str, cached_version: Optional[str]) -> bool:
        """
        Check whether a cached version is still valid.

        Args:
            config_id: Configuration ID
            cached_version: Cached version string

        Returns:
            True if the versions match; False if they differ or either is missing
        """
        with self._lock:
            current_version = self._versions.get(config_id)
            if not current_version or not cached_version:
                return False
            return current_version == cached_version

    def invalidate(self, config_id: str) -> None:
        """
        Invalidate the version of a configuration.

        Args:
            config_id: Configuration ID
        """
        with self._lock:
            if config_id in self._versions:
                # Generate a new version string
                import uuid
                new_version = str(uuid.uuid4())
                self._versions[config_id] = new_version
                logger.info(f"[VersionManager] Config version invalidated: {config_id} -> {new_version}")
class CacheMonitor:
    """
    Cache monitor.

    Provides cache performance monitoring and reporting.
    """

    def __init__(self, cache: LRUConfigCache):
        self.cache = cache

    def get_report(self) -> str:
        """
        Generate a cache performance report.

        Returns:
            The formatted report string
        """
        stats = self.cache.get_stats()
        report = f"""
Configuration cache performance report
======================================
Capacity: {stats['cache_size']}/{stats['max_size']}
Total requests: {stats['total_requests']}
Cache hits: {stats['cache_hits']}
Cache misses: {stats['cache_misses']}
Hit rate: {stats['hit_rate']:.2f}%
Evictions: {stats['evictions']}
Average load time: {stats['avg_load_time']:.2f}ms
"""
        return report

    def log_stats(self) -> None:
        """Write the statistics to the log."""
        stats = self.cache.get_stats()
        logger.info(
            f"[CacheMonitor] Cache stats - "
            f"capacity: {stats['cache_size']}/{stats['max_size']}, "
            f"hit rate: {stats['hit_rate']:.2f}%, "
            f"evictions: {stats['evictions']}"
        )
# Usage example
def example_usage():
    """
    Example of using the optimization features.
    """
    # 1. Use the LRU cache
    lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5))
    # Get a configuration
    config = lru_cache.get("config_001")
    if config is None:
        # Cache miss: load from the database
        config = {"llm_name": "openai/gpt-4"}
        lru_cache.put("config_001", config)

    # 2. Warm up the cache
    def load_config(config_id):
        # The actual configuration loading logic goes here
        return True

    warmer = ConfigCacheWarmer()
    results = warmer.warmup(["config_001", "config_002"], load_config)

    # 3. Dynamic TTL
    ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production")
    print(f"TTL: {ttl}")

    # 4. Version management
    version_manager = ConfigVersionManager()
    version_manager.set_version("config_001", "v1.0.0")
    # Check the version
    is_valid = version_manager.check_version("config_001", "v1.0.0")

    # 5. Monitoring
    monitor = CacheMonitor(lru_cache)
    print(monitor.get_report())


if __name__ == "__main__":
    example_usage()


@@ -0,0 +1,267 @@
import uuid
import json
from typing import Optional
from sqlalchemy.orm import Session
from fastapi.exceptions import HTTPException
from fastapi import status
from app.core.memory.utils.config.definitions import CONFIG, RUNTIME_CONFIG
from app.core.memory.models.variate_config import (
    ExtractionPipelineConfig,
    DedupConfig,
    StatementExtractionConfig,
    ForgettingEngineConfig,
)
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db
from app.models.models_model import ModelConfig, ModelApiKey
from app.services.model_service import ModelConfigService


def get_model_config(model_id: str, db: Session | None = None) -> dict:
    if db is None:
        db_gen = get_db()  # get_db is typically a generator
        db = next(db_gen)  # obtain the actual Session
    config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
    if not config:
        print(f"Model ID {model_id} does not exist")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Model ID does not exist")
    apiConfig: ModelApiKey = config.api_keys[0]
    # Timeout and retry settings are read from the environment
    from app.core.config import settings
    model_config = {
        "model_name": apiConfig.model_name,
        "provider": apiConfig.provider,
        "api_key": apiConfig.api_key,
        "base_url": apiConfig.api_base,
        "model_config_id": apiConfig.model_config_id,
        "type": config.type,
        # Timeout and retry settings, to avoid LLM request timeouts
        "timeout": settings.LLM_TIMEOUT,  # read from the environment (default 120 s)
        "max_retries": settings.LLM_MAX_RETRIES,  # read from the environment (default 2)
    }
    # Append the configuration to logs/model_config.log
    with open("logs/model_config.log", "a", encoding="utf-8") as f:
        f.write(f"Model ID: {model_id}\n")
        f.write(f"Model configuration:\n{model_config}\n")
        f.write("=============================\n\n")
    return model_config
def get_embedder_config(embedding_id: str, db: Session | None = None) -> dict:
    if db is None:
        db_gen = get_db()  # get_db is typically a generator
        db = next(db_gen)  # obtain the actual Session
    config = ModelConfigService.get_model_by_id(db=db, model_id=embedding_id)
    if not config:
        print(f"Embedding model ID {embedding_id} does not exist")
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Embedding model ID does not exist")
    apiConfig: ModelApiKey = config.api_keys[0]
    model_config = {
        "model_name": apiConfig.model_name,
        "provider": apiConfig.provider,
        "api_key": apiConfig.api_key,
        "base_url": apiConfig.api_base,
        "model_config_id": apiConfig.model_config_id,
        # Ensure required field for RedBearModelConfig validation
        "type": config.type,
        # Timeout and retry settings, to avoid embedding service timeouts
        "timeout": 120.0,  # embedding service timeout (seconds)
        "max_retries": 5,  # maximum number of retries
    }
    # Append the configuration to logs/embedder_config.log
    with open("logs/embedder_config.log", "a", encoding="utf-8") as f:
        f.write(f"Embedding model ID: {embedding_id}\n")
        f.write(f"Embedding model configuration:\n{model_config}\n")
        f.write("=============================\n\n")
    return model_config
def get_neo4j_config() -> dict:
    """Retrieves the Neo4j configuration from the config file."""
    return CONFIG.get("neo4j", {})


def get_picture_config(llm_name: str) -> dict:
    """Retrieves the configuration for a specific picture recognition model from the config file."""
    for model_config in CONFIG.get("picture_recognition", []):
        if model_config["llm_name"] == llm_name:
            return model_config
    raise ValueError(f"Model '{llm_name}' not found in config.json")


def get_voice_config(llm_name: str) -> dict:
    """Retrieves the configuration for a specific voice recognition model from the config file."""
    for model_config in CONFIG.get("voice_recognition", []):
        if model_config["llm_name"] == llm_name:
            return model_config
    raise ValueError(f"Model '{llm_name}' not found in config.json")
def get_chunker_config(chunker_strategy: str) -> dict:
    """Retrieves the configuration for a specific chunker strategy.

    Enhancements:
    - Supports default configs for `LLMChunker` and `HybridChunker` if not present.
    - Falls back to the first available chunker config when the requested one is missing.
    """
    # 1) Try to find exact match in config
    chunker_list = CONFIG.get("chunker_list", [])
    for chunker_config in chunker_list:
        if chunker_config.get("chunker_strategy") == chunker_strategy:
            return chunker_config
    # 2) Provide sane defaults for newer strategies
    default_configs = {
        "LLMChunker": {
            "chunker_strategy": "LLMChunker",
            "embedding_model": "BAAI/bge-m3",
            "chunk_size": 1000,
            "threshold": 0.8,
            "min_sentences": 2,
            "language": "zh",
            "skip_window": 1,
            "min_characters_per_chunk": 100,
        },
        "HybridChunker": {
            "chunker_strategy": "HybridChunker",
            "embedding_model": "BAAI/bge-m3",
            "chunk_size": 512,
            "threshold": 0.8,
            "min_sentences": 2,
            "language": "zh",
            "skip_window": 1,
            "min_characters_per_chunk": 100,
        },
    }
    if chunker_strategy in default_configs:
        return default_configs[chunker_strategy]
    # 3) Fallback: use first available config but tag with requested strategy
    if chunker_list:
        fallback = chunker_list[0].copy()
        fallback["chunker_strategy"] = chunker_strategy
        # Non-fatal notice for visibility in logs if any
        print(f"Warning: Using first available chunker config as fallback for '{chunker_strategy}'")
        return fallback
    # 4) If no configs available at all
    raise ValueError(
        f"Chunker '{chunker_strategy}' not found in config.json and no default or fallback available"
    )
def get_pipeline_config() -> ExtractionPipelineConfig:
    """Build ExtractionPipelineConfig using only runtime.json values.

    Behavior:
    - Read `deduplication` section from runtime.json if present.
    - Read `statement_extraction` section from runtime.json if present.
    - Read `forgetting_engine` section from runtime.json if present.
    - If the new dedup switch is absent, check legacy top-level `enable_llm_dedup` key.
    - Do NOT fall back to environment variables.
    - Unspecified fields use model defaults defined in DedupConfig.
    """
    dedup_rc = RUNTIME_CONFIG.get("deduplication", {}) or {}
    stmt_rc = RUNTIME_CONFIG.get("statement_extraction", {}) or {}
    forget_rc = RUNTIME_CONFIG.get("forgetting_engine", {}) or {}

    # Assemble kwargs from runtime.json only
    kwargs = {}
    # LLM switch: prefer new key, then legacy top-level, default False
    if "enable_llm_dedup_blockwise" in dedup_rc:
        kwargs["enable_llm_dedup_blockwise"] = bool(dedup_rc.get("enable_llm_dedup_blockwise"))
    else:
        # Legacy top-level fallback inside runtime.json only
        legacy = RUNTIME_CONFIG.get("enable_llm_dedup")
        if legacy is not None:
            kwargs["enable_llm_dedup_blockwise"] = bool(legacy)
        else:
            kwargs["enable_llm_dedup_blockwise"] = False  # conservative default
    # Disambiguation switch: only from runtime.json deduplication section
    if "enable_llm_disambiguation" in dedup_rc:
        kwargs["enable_llm_disambiguation"] = bool(dedup_rc.get("enable_llm_disambiguation"))
    # Optional LLM fallback gating
    if "enable_llm_fallback_only_on_borderline" in dedup_rc:
        kwargs["enable_llm_fallback_only_on_borderline"] = bool(dedup_rc.get("enable_llm_fallback_only_on_borderline"))
    # Optional fuzzy thresholds: use values if provided; otherwise rely on DedupConfig defaults
    for key in (
        "fuzzy_name_threshold_strict",
        "fuzzy_type_threshold_strict",
        "fuzzy_overall_threshold",
        "fuzzy_unknown_type_name_threshold",
        "fuzzy_unknown_type_type_threshold",
    ):
        if key in dedup_rc:
            kwargs[key] = dedup_rc[key]
    # Optional weights and bonuses for overall scoring
    for key in (
        "name_weight",
        "desc_weight",
        "type_weight",
        "context_bonus",
        "llm_fallback_floor",
        "llm_fallback_ceiling",
    ):
        if key in dedup_rc:
            kwargs[key] = dedup_rc[key]
    # Optional LLM iterative dedup parameters
    for key in (
        "llm_block_size",
        "llm_block_concurrency",
        "llm_pair_concurrency",
        "llm_max_rounds",
    ):
        if key in dedup_rc:
            kwargs[key] = dedup_rc[key]
    dedup_config = DedupConfig(**kwargs)

    # Build StatementExtractionConfig from runtime.json
    stmt_kwargs = {}
    for key in (
        "statement_granularity",
        "temperature",
        "include_dialogue_context",
        "max_dialogue_context_chars",
    ):
        if key in stmt_rc:
            stmt_kwargs[key] = stmt_rc[key]
    stmt_config = StatementExtractionConfig(**stmt_kwargs)

    # Build ForgettingEngineConfig from runtime.json
    forget_kwargs = {}
    for key in ("offset", "lambda_time", "lambda_mem"):
        if key in forget_rc:
            forget_kwargs[key] = forget_rc[key]
    forget_config = ForgettingEngineConfig(**forget_kwargs)

    return ExtractionPipelineConfig(
        statement_extraction=stmt_config,
        deduplication=dedup_config,
        forgetting_engine=forget_config,
    )
def get_pruning_config() -> dict:
    """Retrieve semantic pruning config from runtime.json.

    Returns a dict suitable for PruningConfig.model_validate.

    Structure in runtime.json:
    {
        "pruning": {
            "enabled": true,
            "scene": "education" | "online_service" | "outbound",
            "threshold": 0.5
        }
    }
    """
    pruning_rc = RUNTIME_CONFIG.get("pruning", {}) or {}
    return {
        "pruning_switch": bool(pruning_rc.get("enabled", False)),
        "pruning_scene": pruning_rc.get("scene", "education"),
        "pruning_threshold": float(pruning_rc.get("threshold", 0.5)),
    }


@@ -0,0 +1,360 @@
"""
Configuration loading module - three-stage architecture (migrated to unified configuration management).

This module now uses the global configuration system (app/core/config.py)
to load and manage configuration while remaining backward compatible.

Stage 1: load configuration from runtime.json (path A)
Stage 2: load configuration from the database (path B, based on the config_id in dbrun.json)
Stage 3: expose configuration constants to the project (where paths A and B converge)
"""
import os
import json
import threading
from typing import Any, Dict, Optional
from datetime import datetime, timedelta

try:
    from dotenv import load_dotenv
    load_dotenv()
except Exception:
    pass

# Import unified configuration system
try:
    from app.core.config import settings
    USE_UNIFIED_CONFIG = True
except ImportError:
    USE_UNIFIED_CONFIG = False
    settings = None

# PROJECT_ROOT should point to the app/core/memory/ directory
# __file__ = app/core/memory/utils/config/definitions.py
# os.path.dirname(__file__) = app/core/memory/utils/config
# os.path.dirname(...) = app/core/memory/utils
# os.path.dirname(...) = app/core/memory
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Global configuration lock, used for thread safety
_config_lock = threading.RLock()

# Load the base configuration (config.json) through the global configuration system
if USE_UNIFIED_CONFIG:
    CONFIG = settings.load_memory_config()
else:
    # Fallback to legacy loading
    config_path = os.path.join(PROJECT_ROOT, "config.json")
    try:
        with open(config_path, "r") as f:
            CONFIG = json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        print("Warning: config.json not found or is malformed. Using default settings.")
        CONFIG = {}

DEFAULT_VALUES = {
    "llm_name": "openai/qwen-plus",
    "embedding_name": "openai/nomic-embed-text:v1.5",
    "chunker_strategy": "RecursiveChunker",
    "group_id": "group_123",
    "user_id": "default_user",
    "apply_id": "default_apply",
    "llm_agent_name": "openai/qwen-plus",
    "llm_verify_name": "openai/qwen-plus",
    "llm_image_recognition": "openai/qwen-plus",
    "llm_voice_recognition": "openai/qwen-plus",
    "prompt_level": "DEBUG",
    "reflexion_iteration_period": "3",
    "reflexion_range": "retrieval",
    "reflexion_baseline": "TIME",
}
# Stage 1: load configuration from runtime.json (path A)
def _load_from_runtime_json() -> Dict[str, Any]:
    """
    Load the runtime configuration through the unified configuration loader.

    The loader in overrides.py applies sources in priority order:
    1. Database configuration (if dbrun.json contains config_id/group_id)
    2. Environment variable configuration
    3. Default configuration from runtime.json

    Returns:
        Dict[str, Any]: the runtime configuration dictionary
    """
    try:
        # Preferred path: the unified loader from overrides.py
        from app.core.memory.utils.config.overrides import load_unified_config
        return load_unified_config(PROJECT_ROOT)
    except Exception:
        # Fallback: read runtime.json directly
        runtime_config_path = os.path.join(PROJECT_ROOT, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            # runtime.json is missing or malformed; use an empty configuration
            return {"selections": {}}
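# Stage 2: load configuration from the database (path B); now folded into the unified loader
# Note: this function has been superseded by the unified loader used in _load_from_runtime_json
# and is kept only for backward compatibility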
def _load_from_database() -> Optional[Dict[str, Any]]:
"""
从数据库加载配置(基于 dbrun.json 中的 config_id
注意:此函数已被统一配置加载器替代,现在直接调用 _load_from_runtime_json
即可获得包含数据库配置的完整配置。
Returns:
Optional[Dict[str, Any]]: 配置字典
"""
try:
# 直接使用统一配置加载器
return _load_from_runtime_json()
except Exception:
return None
# 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)
def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
"""
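Expose the runtime configuration as module-level constants for the project.
This is where path A (runtime.json) and path B (database) converge:
whatever the source, the configuration is exposed through this one function.
Args:
runtime_cfg: the runtime configuration dictionary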
"""
global RUNTIME_CONFIG, SELECTIONS, LOGGING_CONFIG
global LANGFUSE_ENABLED, AGENTA_ENABLED, PROMPT_LOG_LEVEL_NAME
global SELECTED_LLM_NAME, SELECTED_EMBEDDING_NAME, SELECTED_CHUNKER_STRATEGY
global SELECTED_GROUP_ID, SELECTED_USER_ID, SELECTED_APPLY_ID, SELECTED_TEST_DATA_INDICES
global SELECTED_LLM_AGENT_NAME, SELECTED_LLM_VERIFY_NAME, SELECTED_LLM_PICTURE_NAME, SELECTED_LLM_VOICE_NAME
global SELECTED_LLM_ID, SELECTED_EMBEDDING_ID, SELECTED_RERANK_ID
global REFLEXION_CONFIG, REFLEXION_ENABLED, REFLEXION_ITERATION_PERIOD, REFLEXION_RANGE, REFLEXION_BASELINE
RUNTIME_CONFIG = runtime_cfg
# 可观测性配置
LANGFUSE_ENABLED = RUNTIME_CONFIG.get("langfuse", {}).get("enabled", False)
AGENTA_ENABLED = RUNTIME_CONFIG.get("agenta", {}).get("enabled", False)
# 日志配置
LOGGING_CONFIG = RUNTIME_CONFIG.get("logging", {})
PROMPT_LOG_LEVEL_NAME = LOGGING_CONFIG.get("prompt_level", DEFAULT_VALUES["prompt_level"])
# 选择配置
SELECTIONS = RUNTIME_CONFIG.get("selections", {})
# 基础模型选择
SELECTED_LLM_NAME = SELECTIONS.get("llm_name", DEFAULT_VALUES["llm_name"])
SELECTED_EMBEDDING_NAME = SELECTIONS.get("embedding_name", DEFAULT_VALUES["embedding_name"])
SELECTED_CHUNKER_STRATEGY = SELECTIONS.get("chunker_strategy", DEFAULT_VALUES["chunker_strategy"])
# 分组和用户配置
SELECTED_GROUP_ID = SELECTIONS.get("group_id", DEFAULT_VALUES["group_id"])
SELECTED_USER_ID = SELECTIONS.get("user_id", DEFAULT_VALUES["user_id"])
SELECTED_APPLY_ID = SELECTIONS.get("apply_id", DEFAULT_VALUES["apply_id"])
SELECTED_TEST_DATA_INDICES = SELECTIONS.get("test_data_indices", None)
# 专用 LLM 配置
SELECTED_LLM_AGENT_NAME = SELECTIONS.get("llm_agent_name", DEFAULT_VALUES["llm_agent_name"])
SELECTED_LLM_VERIFY_NAME = SELECTIONS.get("llm_verify_name", DEFAULT_VALUES["llm_verify_name"])
SELECTED_LLM_PICTURE_NAME = SELECTIONS.get("llm_image_recognition", DEFAULT_VALUES["llm_image_recognition"])
SELECTED_LLM_VOICE_NAME = SELECTIONS.get("llm_voice_recognition", DEFAULT_VALUES["llm_voice_recognition"])
# 模型 ID 配置
SELECTED_LLM_ID = SELECTIONS.get("llm_id", None)
SELECTED_EMBEDDING_ID = SELECTIONS.get("embedding_id", None)
SELECTED_RERANK_ID = SELECTIONS.get("rerank_id", None)
# 反思配置
REFLEXION_CONFIG = RUNTIME_CONFIG.get("reflexion", {})
REFLEXION_ENABLED = REFLEXION_CONFIG.get("enabled", False)
REFLEXION_ITERATION_PERIOD = REFLEXION_CONFIG.get("iteration_period", DEFAULT_VALUES["reflexion_iteration_period"])
REFLEXION_RANGE = REFLEXION_CONFIG.get("reflexion_range", DEFAULT_VALUES["reflexion_range"])
REFLEXION_BASELINE = REFLEXION_CONFIG.get("baseline", DEFAULT_VALUES["reflexion_baseline"])
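# Initialization: use the unified configuration loader
def _initialize_configuration() -> None:
    """
    Initialize the configuration through the unified configuration loader.

    Load priority (handled by overrides.py):
    1. Database configuration (if dbrun.json contains config_id/group_id)
    2. Environment variable configuration (.env)
    3. Default configuration from runtime.json
    """
    try:
        # The unified loader already applies every priority level
        runtime_config = _load_from_runtime_json()
        # Expose the result as module-level constants
        _expose_runtime_constants(runtime_config)
    except Exception:
        # Initialization failed; fall back to an empty configuration
        _expose_runtime_constants({"selections": {}})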
# 模块加载时自动初始化配置
_initialize_configuration()
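# Public API: dynamically reload the configuration
def reload_configuration_from_database(config_id: int | str, force_reload: bool = False) -> bool:
"""
Dynamically reload the configuration from the database through the unified loader.
Used to switch configuration at runtime, for example when the frontend passes a new config_id.
Note: this function only overrides the configuration in memory; it does not modify runtime.json.
Args:
config_id: configuration ID (an integer or a string; converted automatically)
force_reload: kept for backward compatibility (the caching logic has been removed)
Returns:
bool: whether the configuration was reloaded successfully
"""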
import logging
logger = logging.getLogger(__name__)
# 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
with _config_lock:
try:
from app.core.memory.utils.config.overrides import load_unified_config
except Exception as e:
logger.error(f"[definitions] 导入统一配置加载器失败: {e}")
# 记录配置加载失败
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"error": f"Import failed: {str(e)}"}
)
return False
try:
logger.info(f"[definitions] Reloading configuration, config_id={config_id}")
# 使用统一配置加载器(指定 config_id
updated_cfg = load_unified_config(PROJECT_ROOT, config_id=config_id)
# 检查是否成功加载
if not updated_cfg or not updated_cfg.get('selections'):
logger.error(f"[definitions] 配置加载失败:数据库中未找到 config_id={config_id} 的配置")
# 记录配置加载失败
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"reason": "config not found in database"}
)
return False
# 重新暴露常量
_expose_runtime_constants(updated_cfg)
logger.info(f"[definitions] 配置重新加载成功,已暴露常量")
logger.debug(f"[definitions] 配置详情: LLM_ID={updated_cfg.get('selections', {}).get('llm_id')}, "
f"EMBEDDING_ID={updated_cfg.get('selections', {}).get('embedding_id')}")
# 记录成功的配置加载
if audit_logger:
selections = updated_cfg.get('selections', {})
audit_logger.log_config_load(
config_id=config_id,
user_id=selections.get('user_id', None),
group_id=selections.get('group_id', None),
success=True,
details={
"llm_id": selections.get('llm_id'),
"embedding_id": selections.get('embedding_id'),
"chunker_strategy": selections.get('chunker_strategy')
}
)
return True
except Exception as e:
logger.error(f"[definitions] 重新加载配置时发生异常: {e}", exc_info=True)
# 记录配置加载异常
if audit_logger:
audit_logger.log_config_load(
config_id=config_id,
success=False,
details={"error": str(e)}
)
return False
def get_current_config_id() -> Optional[str]:
"""
获取当前使用的 config_id
Returns:
Optional[str]: 当前的 config_id如果未设置则返回 None
"""
return SELECTIONS.get("config_id", None)
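def ensure_fresh_config(config_id: Optional[int | str] = None) -> bool:
    """
    Ensure the latest configuration is in use (call before every write operation).

    If config_id is given, load that configuration;
    otherwise read dbrun.json and load the latest configuration.

    Args:
        config_id: optional configuration ID (an integer or a string; converted automatically)

    Returns:
        bool: whether the configuration was loaded successfully
    """
    import logging
    logger = logging.getLogger(__name__)
    with _config_lock:
        try:
            if config_id:
                # Load the explicitly requested config_id
                logger.debug(f"[definitions] Loading requested configuration, config_id={config_id}")
                return reload_configuration_from_database(config_id)
            # No config_id given: reload from the database
            logger.debug("[definitions] Reloading the latest configuration from the database")
            memory_config = _load_from_database()
            if not memory_config or not memory_config.get('selections'):
                logger.warning("[definitions] Could not load configuration from the database; keeping the current configuration")
                return False
            # Re-expose the module-level constants
            _expose_runtime_constants(memory_config)
            return True
        except Exception as e:
            logger.error(f"[definitions] Failed to load configuration: {e}", exc_info=True)
            return False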
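
The load priority documented in this file (database over environment variables over runtime.json) amounts to a layered merge. The following is a minimal sketch of that idea, assuming each source has already been read into a plain dict. `merge_layers` and the sample dicts are hypothetical names used for illustration only; they are not part of `overrides.py`, whose actual merging logic is not shown here.

```python
from typing import Any, Dict


def merge_layers(*layers: Dict[str, Any]) -> Dict[str, Any]:
    """Merge config dicts; later layers win, nested dicts merge key by key."""
    merged: Dict[str, Any] = {}
    for layer in layers:
        for key, value in layer.items():
            if isinstance(value, dict) and isinstance(merged.get(key), dict):
                # Both sides are dicts: merge them into a new dict
                merged[key] = merge_layers(merged[key], value)
            else:
                merged[key] = value
    return merged


# Hypothetical sample data, listed from lowest to highest priority
runtime_json = {"selections": {"llm_name": "openai/qwen-plus", "group_id": "group_123"}}
env_layer = {"selections": {"group_id": "group_env"}}
db_layer = {"selections": {"group_id": "group_db", "llm_id": "abc"}}

config = merge_layers(runtime_json, env_layer, db_layer)
print(config["selections"])
```

Passing the lowest-priority source first means a key set in the database layer overrides the same key from the other two, while keys that only exist in runtime.json survive untouched.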


@@ -0,0 +1,93 @@
import json
import os
import uuid
from typing import List, Dict, Any, Optional
from sqlalchemy.orm import Session
from app.db import get_db
from app.models.retrieval_info import RetrievalInfo
from app.schemas.memory_storage_schema import BaseDataSchema
import logging
logger = logging.getLogger(__name__)
async def _load_(data: List[Any]) -> List[Dict]:
    """Parse rows of JSON-encoded retrieve_info payloads into normalized dicts."""
    target_keys = [
        "id",
        "statement",
        "group_id",
        "chunk_id",
        "created_at",
        "expired_at",
        "valid_at",
        "invalid_at",
    ]
    results = []
    for row in data or []:
        # Pull the raw payload out of whichever row shape we were given
        if isinstance(row, (tuple, list)) and row:
            s = row[0]
        elif hasattr(row, "retrieve_info"):
            s = getattr(row, "retrieve_info")
        elif isinstance(row, dict) and "retrieve_info" in row:
            s = row.get("retrieve_info")
        elif hasattr(row, "_mapping") and "retrieve_info" in getattr(row, "_mapping"):
            s = row._mapping["retrieve_info"]
        else:
            s = row
        if s is None:
            continue
        if isinstance(s, bytes):
            try:
                s = s.decode("utf-8")
            except Exception:
                continue
        s = str(s).strip()
        if not s or s == "[]":
            continue
        try:
            parsed = json.loads(s)
        except Exception:
            continue
        items = parsed if isinstance(parsed, list) else [parsed]
        for item in items:
            # Skip anything that is not a JSON object (bare strings, numbers, ...)
            if not isinstance(item, dict):
                continue
            # Accept the plural "statements" key as an alias for "statement"
            if "statement" not in item and "statements" in item:
                item["statement"] = item.get("statements") or ""
            normalized = {k: item.get(k, "") for k in target_keys}
            results.append(normalized)
    return results
async def get_data(host_id: uuid.UUID) -> List[Dict]:
    """
    Fetch the retrieval records for host_id from the database.
    """
    # Obtain a database session from the get_db generator
    db: Session = next(get_db())
    try:
        data = db.query(RetrievalInfo.retrieve_info).filter(RetrievalInfo.host_id == host_id).all()
        # Parse the rows into a list of dicts
        results = await _load_(data)
        return results
    except Exception as e:
        logger.error(f"failed to get data from database, host_id: {host_id}, error: {e}")
        # Bare raise keeps the original traceback intact
        raise
    finally:
        try:
            db.close()
        except Exception:
            pass
if __name__ == "__main__":
import asyncio
# 从数据库中获取数据
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
data = asyncio.run(get_data(host_id))
print(type(data))
print(data)
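
To make the normalization step in `_load_` concrete, here is a small standalone sketch of what happens to a single payload. It assumes the payload is a JSON string holding either one object or a list of objects; `normalize_payload` and the sample string are hypothetical and exist only for this illustration.

```python
import json
from typing import Any, Dict, List

TARGET_KEYS = ["id", "statement", "group_id", "chunk_id",
               "created_at", "expired_at", "valid_at", "invalid_at"]


def normalize_payload(payload: str) -> List[Dict[str, Any]]:
    """Parse one JSON payload and project every object onto TARGET_KEYS."""
    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return []
    items = parsed if isinstance(parsed, list) else [parsed]
    rows = []
    for item in items:
        if not isinstance(item, dict):
            continue
        # Accept the plural "statements" key as an alias
        if "statement" not in item and "statements" in item:
            item["statement"] = item["statements"] or ""
        # Missing keys become empty strings so every row has the same shape
        rows.append({key: item.get(key, "") for key in TARGET_KEYS})
    return rows


# Hypothetical sample payload
sample = '[{"id": "s1", "statements": "Alice met Bob.", "group_id": "g1"}]'
rows = normalize_payload(sample)
print(rows[0])
```

Every returned row carries the same eight keys, so downstream code can index them without checking for presence.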


@@ -0,0 +1,90 @@
import os
import re
import uuid
import random
import string
from typing import List, Dict, Optional
# 生成包含字母(大小写)和数字的随机字符串
def generate_random_string(length=16):
characters = string.ascii_letters + string.digits
return ''.join(random.choice(characters) for _ in range(length))
def get_example_data() -> List[Dict[str, Optional[str]]]:
    """
    Load example statements from the statement-extraction log.

    Each entry in the log looks like:
        Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。
        Created At: 2025-11-28 19:28:38.256421
        Expired At: None
        Valid At: None
        Invalid At: None

    The entries are returned in this shape:
    [
        {
            "id": id,
            "group_id": group_id,
            "statement": Content,
            "created_at": Created At,
            "expired_at": Expired At,
            "valid_at": Valid At,
            "invalid_at": Invalid At,
            "chunk_id": "86da9022710c40eaa5f518a294c398d2",
            "entity_ids": []
        },
        ...
    ]
    """
    # Path to the log file
    log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt")
    # Nothing to parse if the log does not exist
    if not os.path.exists(log_file_path):
        return []
    with open(log_file_path, "r", encoding="utf-8") as f:
        content = f.read()

    def _field(block: str, label: str) -> Optional[str]:
        # (?<![A-Za-z] ) keeps "Id:" from matching inside "Group Id:" or "Chunk Id:";
        # [ \t]* (rather than \s*) keeps an empty value from swallowing the next line.
        match = re.search(rf"(?<![A-Za-z] ){re.escape(label)}:[ \t]*(.+)", block)
        return match.group(1).strip() if match else None

    results = []
    # Split the log into one block per "Statement N:" header
    statement_blocks = re.split(r"Statement \d+:", content)
    for block in statement_blocks[1:]:  # skip the text before the first header
        statement = _field(block, "Content")
        if statement is None:
            continue
        statement_data = {
            "id": _field(block, "Id") or generate_random_string(),
            "group_id": _field(block, "Group Id") or "group_example",
            "statement": statement,
            "created_at": _field(block, "Created At"),
            "expired_at": _field(block, "Expired At"),
            "valid_at": _field(block, "Valid At"),
            "invalid_at": _field(block, "Invalid At"),
            "chunk_id": _field(block, "Chunk Id") or "chunk_example",
            "entity_ids": [],
        }
        # The log writes missing timestamps as the literal string "None"
        for key in ("created_at", "expired_at", "valid_at", "invalid_at"):
            if statement_data[key] == "None":
                statement_data[key] = None
        results.append(statement_data)
    return results
if __name__ == "__main__":
print(f"获取数据如下:\n {get_example_data()}")
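
The block-splitting approach used by `get_example_data` can be tried without a log file. The sketch below runs the same idea on an in-memory sample. The log layout (`Statement N:` headers followed by one `Label: value` pair per line, each at the start of its line) is an assumption inferred from the docstring, and `parse_statements`, `read_field`, and `SAMPLE_LOG` are hypothetical names.

```python
import re
from typing import Dict, List, Optional

# Hypothetical log excerpt; the real file layout may differ
SAMPLE_LOG = """Statement 1:
Id: a1
Group Id: g1
Content: Alice met Bob in Paris.
Created At: 2025-11-28 19:28:38.256421
Expired At: None
Statement 2:
Content: Bob works at Acme.
Created At: 2025-11-29 08:00:00
Expired At: None
"""


def read_field(block: str, label: str) -> Optional[str]:
    """Return the value of `label` in `block`, or None if absent or 'None'."""
    # ^ with re.MULTILINE anchors the label to the start of a line,
    # so "Id" does not match inside "Group Id"
    match = re.search(rf"^{re.escape(label)}:[ \t]*(.+)$", block, flags=re.MULTILINE)
    if match is None:
        return None
    value = match.group(1).strip()
    return None if value == "None" else value


def parse_statements(text: str) -> List[Dict[str, Optional[str]]]:
    """Split the log on 'Statement N:' headers and read each block's fields."""
    parsed = []
    for block in re.split(r"Statement \d+:", text)[1:]:
        parsed.append({
            "id": read_field(block, "Id"),
            "group_id": read_field(block, "Group Id"),
            "statement": read_field(block, "Content"),
            "created_at": read_field(block, "Created At"),
            "expired_at": read_field(block, "Expired At"),
        })
    return parsed


statements = parse_statements(SAMPLE_LOG)
for entry in statements:
    print(entry)
```

Anchoring each label to the start of a line is one way to keep short labels such as `Id` from matching inside longer ones; it only works if the log really places one field per line.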


@@ -0,0 +1,516 @@
"""
LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring
"""
import litellm
from typing import Dict, Any, List
import json
from datetime import datetime, timedelta
import os
import time
from collections import defaultdict
import threading
from queue import Queue
class LiteLLMConfig:
"""Configuration class for LiteLLM with enhanced retry and tracking capabilities"""
def __init__(self):
self.usage_data = []
self.error_data = []
self.module_stats = defaultdict(lambda: {
'requests': 0,
'tokens_in': 0,
'tokens_out': 0,
'cost': 0.0,
'errors': 0,
'start_time': None,
'last_request_time': None,
'request_timestamps': [], # Store precise timestamps
'current_qps': 0.0,
'max_qps': 0.0,
'qps_history': [] # Store QPS measurements over time
})
self.start_time = datetime.now()
self.global_request_timestamps = []
self.global_max_qps = 0.0
# Rate limiting for AWS Bedrock (conservative limits)
self.rate_limits = {
'bedrock': {
'requests_per_minute': 2, # AWS Bedrock default is very low
'requests_per_second': 0.033, # 2/60 = 0.033 RPS
'last_request_time': 0,
'request_queue': Queue(),
'lock': threading.Lock()
}
}
self.rate_limiting_enabled = True
def setup_enhanced_config(self, max_retries: int = 3):
"""Configure LiteLLM with retry logic and instant QPS tracking"""
litellm.num_retries = max_retries
litellm.request_timeout = 300
litellm.retry_policy = {
"RateLimitError": {
"max_retries": 5,
"exponential_backoff": True,
"initial_delay": 1,
"max_delay": 60,
"jitter": True
},
"APIConnectionError": {
"max_retries": 3,
"exponential_backoff": True,
"initial_delay": 2,
"max_delay": 30,
"jitter": True
},
"InternalServerError": {
"max_retries": 2,
"exponential_backoff": True,
"initial_delay": 5,
"max_delay": 60,
"jitter": True
},
"BadRequestError": {
"max_retries": 1,
"exponential_backoff": False,
"initial_delay": 1,
"max_delay": 5
}
}
litellm.success_callback = [self._success_callback]
litellm.failure_callback = [self._failure_callback]
litellm.completion_cost_tracking = True
litellm.set_verbose = False
litellm.modify_params = True
print("✅ LiteLLM configured with instant QPS tracking and rate limiting")
def _success_callback(self, kwargs, completion_response, start_time, end_time):
"""Callback for successful requests with module-specific QPS tracking"""
try:
# Extract usage information
usage = completion_response.get('usage', {})
model = kwargs.get('model', 'unknown')
# Extract module information from metadata or model name
module = self._extract_module_name(kwargs, model)
# Calculate cost
cost = 0.0
try:
cost = litellm.completion_cost(completion_response)
except Exception:
pass
# Calculate duration
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
# Record usage data
usage_record = {
"timestamp": datetime.now().isoformat(),
"model": model,
"module": module,
"input_tokens": usage.get('prompt_tokens', 0),
"output_tokens": usage.get('completion_tokens', 0),
"total_tokens": usage.get('total_tokens', 0),
"cost": cost,
"duration_seconds": duration_seconds,
"status": "success"
}
self.usage_data.append(usage_record)
# Update module-specific stats for QPS tracking
self._update_module_stats(module, usage_record, success=True)
# Print real-time feedback
print(f"{model}: {usage_record['input_tokens']} in / {usage_record['output_tokens']} out tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s")
except Exception as e:
print(f"Warning: Success callback failed: {e}")
def _failure_callback(self, kwargs, completion_response, start_time, end_time):
"""Callback for failed requests with module-specific error tracking"""
try:
model = kwargs.get('model', 'unknown')
module = self._extract_module_name(kwargs, model)
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
# Handle different error response formats
error_message = "Unknown error"
error_type = "UnknownError"
# According to LiteLLM docs, completion_response contains the exception for failures
if completion_response is not None:
error_message = str(completion_response)
error_type = type(completion_response).__name__
# Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events)
elif 'exception' in kwargs:
exception = kwargs['exception']
error_message = str(exception)
error_type = type(exception).__name__
# Check for other error formats in kwargs
elif 'error' in kwargs:
error = kwargs['error']
error_message = str(error)
error_type = type(error).__name__
# Check log_event_type to confirm this is a failure event
log_event_type = kwargs.get('log_event_type', '')
if log_event_type == 'failed_api_call' and 'exception' in kwargs:
exception = kwargs['exception']
error_message = str(exception)
error_type = type(exception).__name__
error_record = {
"timestamp": datetime.now().isoformat(),
"model": model,
"module": module,
"error": error_message,
"error_type": error_type,
"duration_seconds": duration_seconds,
"status": "failed"
}
self.error_data.append(error_record)
# Update module-specific stats for error tracking
self._update_module_stats(module, error_record, success=False)
# Print error feedback
print(f"{model}: {error_type} - {error_message[:100]}")
except Exception as e:
print(f"Warning: Failure callback failed: {e}")
# Debug: print the actual parameters to understand the structure
print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}")
print(f"Debug - completion_response type: {type(completion_response)}")
print(f"Debug - completion_response: {completion_response}")
def _should_rate_limit(self, model: str) -> bool:
"""Check if the model should be rate limited"""
if not self.rate_limiting_enabled:
return False
return model.startswith('bedrock/') or 'bedrock' in model.lower()
def _enforce_rate_limit(self, model: str):
"""Enforce rate limiting for AWS Bedrock models"""
if not self._should_rate_limit(model):
return
provider = 'bedrock'
if provider not in self.rate_limits:
return
rate_config = self.rate_limits[provider]
with rate_config['lock']:
current_time = time.time()
time_since_last = current_time - rate_config['last_request_time']
min_interval = 1.0 / rate_config['requests_per_second']
if time_since_last < min_interval:
sleep_time = min_interval - time_since_last
print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}")
time.sleep(sleep_time)
rate_config['last_request_time'] = time.time()
def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str:
"""Extract module name from request context"""
# Try to get module from metadata
metadata = kwargs.get('metadata', {})
if 'module' in metadata:
return metadata['module']
# Try to infer from model name or other context
if 'claude' in model.lower():
return 'bedrock_client'
elif 'gpt' in model.lower() or 'openai' in model.lower():
return 'openai_client'
elif 'embed' in model.lower():
return 'embedder'
else:
return 'unknown'
def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool):
"""Update module-specific statistics with instant QPS tracking"""
current_timestamp = time.time()
current_time = datetime.now()
# Initialize module stats if first request
if self.module_stats[module]['start_time'] is None:
self.module_stats[module]['start_time'] = current_time
# Update counters
self.module_stats[module]['requests'] += 1
self.module_stats[module]['last_request_time'] = current_time
self.module_stats[module]['request_timestamps'].append(current_timestamp)
self.global_request_timestamps.append(current_timestamp)
# Calculate instant QPS for this module
self._calculate_instant_qps(module, current_timestamp)
# Calculate global instant QPS
self._calculate_global_instant_qps(current_timestamp)
if success:
self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0)
self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0)
self.module_stats[module]['cost'] += record.get('cost', 0.0)
else:
self.module_stats[module]['errors'] += 1
def _calculate_instant_qps(self, module: str, current_timestamp: float):
"""Calculate instant QPS for a specific module using sliding window"""
# Keep only timestamps from last 1 second for instant QPS
cutoff_time = current_timestamp - 1.0
timestamps = self.module_stats[module]['request_timestamps']
# Remove old timestamps
self.module_stats[module]['request_timestamps'] = [
ts for ts in timestamps if ts >= cutoff_time
]
# Calculate current QPS (requests in last second)
current_qps = len(self.module_stats[module]['request_timestamps'])
self.module_stats[module]['current_qps'] = current_qps
# Update max QPS if current is higher
if current_qps > self.module_stats[module]['max_qps']:
self.module_stats[module]['max_qps'] = current_qps
# Store QPS history (keep last 60 measurements)
self.module_stats[module]['qps_history'].append(current_qps)
if len(self.module_stats[module]['qps_history']) > 60:
self.module_stats[module]['qps_history'].pop(0)
def _calculate_global_instant_qps(self, current_timestamp: float):
"""Calculate global instant QPS across all modules"""
# Keep only timestamps from last 1 second
cutoff_time = current_timestamp - 1.0
self.global_request_timestamps = [
ts for ts in self.global_request_timestamps if ts >= cutoff_time
]
# Calculate current global QPS
current_global_qps = len(self.global_request_timestamps)
# Update max global QPS
if current_global_qps > self.global_max_qps:
self.global_max_qps = current_global_qps
def get_instant_qps(self, module: str = None) -> Dict[str, Any]:
"""Get instant QPS data for modules"""
if module:
if module in self.module_stats:
return {
'module': module,
'current_qps': self.module_stats[module]['current_qps'],
'max_qps': self.module_stats[module]['max_qps'],
'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0
}
else:
return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0}
else:
# Return data for all modules plus global
result = {
'global': {
'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]),
'max_qps': self.global_max_qps
},
'modules': {}
}
for mod in self.module_stats.keys():
result['modules'][mod] = {
'current_qps': self.module_stats[mod]['current_qps'],
'max_qps': self.module_stats[mod]['max_qps'],
'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0
}
return result
def get_usage_summary(self) -> Dict[str, Any]:
"""Get essential usage statistics"""
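if not self.usage_data and not self.error_data:
return {
"total_requests": 0,
"total_cost": 0.0,
"error_rate": 0.0,
"message": "No usage data available"
}
# Failed calls count as requests too, which keeps error_rate within 0-100%
total_errors = len(self.error_data)
total_requests = len(self.usage_data) + total_errors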
total_cost = sum(record['cost'] for record in self.usage_data)
total_input_tokens = sum(record['input_tokens'] for record in self.usage_data)
total_output_tokens = sum(record['output_tokens'] for record in self.usage_data)
# Calculate session duration
duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60
# Build module statistics
module_stats = {}
for module, stats in self.module_stats.items():
if stats['requests'] > 0:
module_stats[module] = {
"requests": stats['requests'],
"errors": stats['errors'],
"success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0,
"tokens_in": stats['tokens_in'],
"tokens_out": stats['tokens_out'],
"cost": stats['cost'],
"current_qps": stats['current_qps'],
"max_qps": stats['max_qps']
}
return {
"session_duration_minutes": duration_minutes,
"total_requests": total_requests,
"total_errors": total_errors,
"error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_cost": total_cost,
"module_stats": module_stats,
"global_max_qps": self.global_max_qps
}
def print_usage_summary(self):
"""Print essential usage summary"""
stats = self.get_usage_summary()
if stats.get('message'):
print(f"📊 {stats['message']}")
return
print(f"\n📊 USAGE SUMMARY")
print(f"{'='*50}")
print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
print(f"📈 Requests: {stats['total_requests']}")
print(f"❌ Errors: {stats['total_errors']}")
print(f"💰 Cost: ${stats['total_cost']:.4f}")
print(f"🏆 Global Max QPS: {stats['global_max_qps']}")
# Module statistics
if stats.get('module_stats'):
print(f"\n📦 MODULES:")
for module, mod_stats in stats['module_stats'].items():
print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")
print(f"{'='*50}")
def save_usage_data(self, filename: str = "litellm_usage.json"):
"""Save usage data to JSON file"""
data = {
"summary": self.get_usage_summary(),
"detailed_usage": self.usage_data,
"errors": self.error_data,
"export_timestamp": datetime.now().isoformat()
}
with open(filename, 'w') as f:
json.dump(data, f, indent=2)
print(f"📁 Usage data saved to {filename}")
def reset_tracking(self):
"""Reset all tracking data"""
self.usage_data = []
self.error_data = []
self.module_stats = defaultdict(lambda: {
'requests': 0,
'tokens_in': 0,
'tokens_out': 0,
'cost': 0.0,
'errors': 0,
'start_time': None,
'last_request_time': None,
'request_timestamps': [],
'current_qps': 0.0,
'max_qps': 0.0,
'qps_history': []
})
self.global_request_timestamps = []
self.global_max_qps = 0.0
self.start_time = datetime.now()
print("🔄 All tracking data reset")
# Global instance for easy access
litellm_config = LiteLLMConfig()
def setup_litellm_enhanced(max_retries: int = 3):
"""
Quick setup function for LiteLLM enhanced configuration
Args:
max_retries: Maximum number of retries for failed requests
"""
litellm_config.setup_enhanced_config(max_retries)
return litellm_config
def get_usage_summary():
"""Get current usage summary"""
return litellm_config.get_usage_summary()
def print_usage_summary():
"""Print current usage summary"""
litellm_config.print_usage_summary()
def save_usage_data(filename: str = "litellm_usage.json"):
"""Save usage data to file"""
litellm_config.save_usage_data(filename)
def get_instant_qps(module: str = None) -> Dict[str, Any]:
"""Get instant QPS data for modules"""
return litellm_config.get_instant_qps(module)
def print_instant_qps(module: str = None):
"""Print instant QPS information"""
qps_data = get_instant_qps(module)
print(f"\n⚡ INSTANT QPS MONITOR")
print(f"{'='*60}")
if module:
print(f"Module: {qps_data['module']}")
print(f" Current QPS: {qps_data['current_qps']}")
print(f" Max QPS: {qps_data['max_qps']}")
print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}")
else:
# Global stats
global_data = qps_data.get('global', {})
print(f"🌍 GLOBAL:")
print(f" Current QPS: {global_data.get('current_qps', 0)}")
print(f" Max QPS: {global_data.get('max_qps', 0)}")
# Module stats
modules = qps_data.get('modules', {})
if modules:
print(f"\n📦 MODULES:")
for mod, data in modules.items():
print(f" {mod}:")
print(f" Current: {data['current_qps']} QPS")
print(f" Max: {data['max_qps']} QPS")
print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS")
print(f"{'='*60}")
def reset_tracking():
"""Reset all tracking data"""
litellm_config.reset_tracking()
def get_module_stats() -> Dict[str, Dict[str, Any]]:
"""Get detailed module statistics"""
summary = get_usage_summary()
return summary.get('module_stats', {})
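
The instant-QPS figures in `LiteLLMConfig` come from a one-second sliding window over request timestamps. The sketch below isolates that technique. It swaps the list rebuild used above for a `collections.deque`, which drops expired entries from the left; `SlidingWindowQPS` is a hypothetical class for illustration and is not part of this module. Timestamps are passed in explicitly so the example is deterministic.

```python
from collections import deque
from typing import Deque


class SlidingWindowQPS:
    """Count requests seen in the trailing `window` seconds."""

    def __init__(self, window: float = 1.0) -> None:
        self.window = window
        self.timestamps: Deque[float] = deque()
        self.max_qps = 0

    def record(self, now: float) -> int:
        """Register a request at time `now` and return the current count."""
        self.timestamps.append(now)
        cutoff = now - self.window
        # Drop timestamps that fell out of the window (oldest first)
        while self.timestamps and self.timestamps[0] < cutoff:
            self.timestamps.popleft()
        current = len(self.timestamps)
        self.max_qps = max(self.max_qps, current)
        return current


tracker = SlidingWindowQPS()
# Explicit timestamps (in seconds) stand in for time.time()
readings = [tracker.record(t) for t in (0.0, 0.2, 0.4, 1.1, 2.5)]
print(readings, tracker.max_qps)
```

Because the window is only trimmed when a request arrives, the stored value goes stale during idle periods; a reader that needs a live figure would trim against the current time before counting.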


@@ -0,0 +1,611 @@
"""
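Runtime configuration overrides: the unified configuration loader
This module is the unified configuration loader. It loads configuration from several sources and applies overrides by priority.
Configuration sources, from highest to lowest priority:
1. Database configuration (PostgreSQL data_config table)
2. Environment variable configuration (.env file)
3. Default configuration (runtime.json file)
Supported loading modes:
- By config_id (read from dbrun.json or passed in by the frontend)
- By group_id (read from dbrun.json)
- Environment variable overrides (supporting the INTERNAL/EXTERNAL network modes)
Main features:
- Read configuration from the PostgreSQL database
- Read configuration from environment variables
- Read default configuration from runtime.json
- Override configuration items by priority (in memory only; files are not modified)
- Support multiple configuration sections: selections, statement_extraction, deduplication, forgetting_engine, pruning, reflexion
Use cases:
- Automatic configuration loading at application startup
- Dynamic reloading when the frontend switches configuration
- Configuration isolation in multi-tenant deployments
- Automatic switching between internal and external network environments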
"""
import os
import json
import socket
from typing import Optional, Dict, Any, Literal
NetworkMode = Literal['internal', 'external']
def _set_if_present(target: Dict[str, Any], target_key: str, src: Dict[str, Any], src_key: str, caster):
    """Set target[target_key] from src[src_key] when the source value exists and is not None.

    Conversion failures are swallowed, leaving the target unchanged.

    Args:
        target: destination dictionary
        target_key: key in the destination dictionary
        src: source dictionary
        src_key: key in the source dictionary
        caster: type-conversion function
    """
    try:
        value = src.get(src_key)
        if value is not None:
            target[target_key] = caster(value)
    except Exception:
        pass
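def _to_bool(val: Any) -> bool:
"""Convert values of various types to a boolean.
Supported inputs:
- bool: returned as-is
- int/float: non-zero is True
- str: "true", "1", "on", "yes" are True; "false", "0", "off", "no" are False
An unrecognized string falls back to bool(val).
Args:
val: the value to convert
Returns:
bool: the converted boolean
"""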
try:
if isinstance(val, bool):
return val
if isinstance(val, (int, float)):
return bool(val)
if isinstance(val, str):
m = val.strip().lower()
if m in {"true", "1", "on", "yes"}:
return True
if m in {"false", "0", "off", "no"}:
return False
return bool(val)
except Exception:
return False
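def _make_pgsql_conn() -> Optional[object]:
"""Create a PostgreSQL database connection.
Connection parameters come from environment variables:
- DB_HOST: database host (default: localhost)
- DB_PORT: database port (default: 5432)
- DB_USER: database user name
- DB_PASSWORD: database password
- DB_NAME: database name
Returns:
Optional[object]: the connection object, or None on failure
"""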
host = os.getenv("DB_HOST", "localhost")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
dbname = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
from psycopg2.extras import RealDictCursor # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host,
port=port,
user=user,
password=password,
dbname=dbname,
)
conn.autocommit = True
return conn
except Exception:
return None
def _fetch_db_config_by_group_id(group_id: str) -> Optional[Dict[str, Any]]:
"""根据 group_id 从数据库查询配置
Args:
group_id: 组标识符
Returns:
Optional[Dict[str, Any]]: 配置字典,未找到时返回 None
"""
conn = _make_pgsql_conn()
if conn is None:
return None
try:
from psycopg2.extras import RealDictCursor # type: ignore
cur = conn.cursor(cursor_factory=RealDictCursor)
try:
cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
except Exception:
pass
sql = (
"SELECT group_id, user_id, apply_id, chunker_strategy, "
" enable_llm_dedup_blockwise, enable_llm_disambiguation "
"FROM data_config WHERE group_id = %s ORDER BY updated_at DESC LIMIT 1"
)
cur.execute(sql, (group_id,))
row = cur.fetchone()
return row if row else None
except Exception:
return None
finally:
try:
cur.close()
except Exception:
pass
try:
conn.close()
except Exception:
pass
def _fetch_db_config_by_config_id(config_id: int | str) -> Optional[Dict[str, Any]]:
    """Fetch configuration from the database by config_id.

    Args:
        config_id: configuration identifier (an integer or a string; converted to an integer)

    Returns:
        Optional[Dict[str, Any]]: the configuration row, or None if not found
    """
    # config_id is an Integer column in the database, so convert it first
    try:
        config_id_int = int(config_id)
    except (ValueError, TypeError):
        return None
    conn = _make_pgsql_conn()
    if conn is None:
        return None
    cur = None
    try:
        from psycopg2.extras import RealDictCursor  # type: ignore
        cur = conn.cursor(cursor_factory=RealDictCursor)
        try:
            cur.execute("SET TIME ZONE %s", ("Asia/Shanghai",))
        except Exception:
            pass
        sql = (
            "SELECT config_id, group_id, user_id, apply_id, chunker_strategy, "
            " enable_llm_dedup_blockwise, enable_llm_disambiguation, "
            " deep_retrieval, t_type_strict, t_name_strict, t_overall, state, "
            " statement_granularity, include_dialogue_context, max_context, "
            " \"offset\" AS offset, lambda_time, lambda_mem, "
            " pruning_enabled, pruning_scene, pruning_threshold, "
            " llm_id, embedding_id "
            "FROM data_config WHERE config_id = %s LIMIT 1"
        )
        cur.execute(sql, (config_id_int,))
        row = cur.fetchone()
        return row if row else None
    except Exception:
        return None
    finally:
        # Close the cursor first, then the connection; ignore close errors
        try:
            if cur is not None:
                cur.close()
        except Exception:
            pass
        try:
            conn.close()
        except Exception:
            pass
def _load_dbrun_group_id(project_root: str) -> Optional[str]:
"""Read group_id from dbrun.json.
Args:
project_root: Path to the project root directory
Returns:
Optional[str]: The group_id, or None if not found
"""
try:
path = os.path.join(project_root, "dbrun.json")
if not os.path.isfile(path):
return None
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
if "group_id" in data:
return str(data.get("group_id"))
sel = data.get("selections", {})
if isinstance(sel, dict) and "group_id" in sel:
return str(sel.get("group_id"))
return None
except Exception:
return None
def _load_dbrun_config_id(project_root: str) -> Optional[str]:
"""Read config_id from dbrun.json.
Args:
project_root: Path to the project root directory
Returns:
Optional[str]: The config_id, or None if not found
"""
try:
path = os.path.join(project_root, "dbrun.json")
if not os.path.isfile(path):
return None
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict):
if "config_id" in data:
return str(data.get("config_id"))
sel = data.get("selections", {})
if isinstance(sel, dict) and "config_id" in sel:
return str(sel.get("config_id"))
return None
except Exception:
return None
def _apply_overrides_from_db_row(
    runtime_cfg: Dict[str, Any],
    db_row: Optional[Dict[str, Any]],
    identifier: str,
    identifier_type: str = "config_id"
) -> Dict[str, Any]:
    """Apply a database row on top of the runtime configuration (shared handler).

    Args:
        runtime_cfg: Runtime configuration dictionary.
        db_row: Row returned by the database query.
        identifier: Identifier value (a group_id or a config_id).
        identifier_type: Identifier kind ("group_id" or "config_id").

    Returns:
        Dict[str, Any]: The runtime configuration after the overrides.
    """
    try:
        selections = runtime_cfg.setdefault("selections", {})
        selections[identifier_type] = identifier
        if not db_row:
            return runtime_cfg
        # Override the selections fields
        for tk in ("group_id", "user_id", "apply_id", "chunker_strategy", "state",
                   "t_type_strict", "t_name_strict", "t_overall",
                   "statement_granularity", "include_dialogue_context"):
            _set_if_present(selections, tk, db_row, tk, str)
        # UUID fields: store the standard hyphenated string form
        for uuid_field in ("llm_id", "embedding_id"):
            if db_row.get(uuid_field) is not None:
                try:
                    selections[uuid_field] = str(db_row.get(uuid_field))
                except Exception:
                    pass
        # Override the statement_extraction fields
        stmt = runtime_cfg.setdefault("statement_extraction", {})
        _set_if_present(stmt, "statement_granularity", db_row, "statement_granularity", int)
        _set_if_present(stmt, "include_dialogue_context", db_row, "include_dialogue_context", _to_bool)
        _set_if_present(stmt, "max_dialogue_context_chars", db_row, "max_context", int)
        # Override the deduplication fields
        dedup = runtime_cfg.setdefault("deduplication", {})
        for tk in ("enable_llm_dedup_blockwise", "enable_llm_disambiguation"):
            _set_if_present(dedup, tk, db_row, tk, _to_bool)
        _set_if_present(dedup, "deep_retrieval", db_row, "deep_retrieval", _to_bool)
        # Override the forgetting_engine fields
        forgetting = runtime_cfg.setdefault("forgetting_engine", {})
        _set_if_present(forgetting, "offset", db_row, "offset", float)
        _set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float)
        _set_if_present(forgetting, "lambda_mem", db_row, "lambda_mem", float)
        # Override the pruning fields
        pruning = runtime_cfg.setdefault("pruning", {})
        _set_if_present(pruning, "enabled", db_row, "pruning_enabled", _to_bool)
        _set_if_present(pruning, "scene", db_row, "pruning_scene", str)
        # The threshold is converted to float and clamped to [0.0, 0.9]
        try:
            if "pruning_threshold" in db_row and db_row.get("pruning_threshold") is not None:
                thr = float(db_row.get("pruning_threshold"))
                thr = max(0.0, min(0.9, thr))
                pruning["threshold"] = thr
        except Exception:
            pass
        return runtime_cfg
    except Exception:
        return runtime_cfg
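def apply_runtime_overrides_by_group(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
"""Override the runtime configuration from the database, keyed by group_id.
Workflow:
1. Read group_id from dbrun.json
2. Query the database configuration for that group_id
3. Override the runtime configuration (in memory only)
Args:
project_root: Path to the project root directory
runtime_cfg: Runtime configuration dictionary
Returns:
Dict[str, Any]: The runtime configuration after the overrides
"""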
try:
selected_gid = _load_dbrun_group_id(project_root)
if not selected_gid:
return runtime_cfg
db_row = _fetch_db_config_by_group_id(selected_gid)
if not db_row:
# If the database has no configuration, still record the group_id
runtime_cfg.setdefault("selections", {})["group_id"] = selected_gid
return runtime_cfg
return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_gid, "group_id")
except Exception:
return runtime_cfg
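def apply_runtime_overrides_by_config(project_root: str, runtime_cfg: Dict[str, Any]) -> Dict[str, Any]:
"""Override the runtime configuration from the database, keyed by config_id (read from dbrun.json).
Workflow:
1. Read config_id from dbrun.json
2. Query the database configuration for that config_id
3. Override the runtime configuration (in memory only)
Args:
project_root: Path to the project root directory
runtime_cfg: Runtime configuration dictionary
Returns:
Dict[str, Any]: The runtime configuration after the overrides
"""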
try:
selected_cid = _load_dbrun_config_id(project_root)
if not selected_cid:
return runtime_cfg
db_row = _fetch_db_config_by_config_id(selected_cid)
return _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
except Exception:
return runtime_cfg
def apply_runtime_overrides_with_config_id(
project_root: str,
runtime_cfg: Dict[str, Any],
config_id: str
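) -> tuple[Dict[str, Any], bool]:
"""Override the runtime configuration from the database using an explicit config_id (dbrun.json is not read).
Intended for cases where the front end switches configurations dynamically.
Args:
project_root: Path to the project root directory
runtime_cfg: Runtime configuration dictionary
config_id: Configuration identifier
Returns:
tuple[Dict[str, Any], bool]: (runtime configuration after the overrides, whether it was loaded from the database)
"""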
try:
selected_cid = str(config_id).strip()
if not selected_cid:
return runtime_cfg, False
db_row = _fetch_db_config_by_config_id(selected_cid)
if db_row is None:
return runtime_cfg, False
updated_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, selected_cid, "config_id")
return updated_cfg, True
except Exception as e:
pass
return runtime_cfg, False
# ============================================================================
# The functions below are commented out: automatic network-mode detection is no longer needed
# ============================================================================
# def get_server_ip() -> str:
# """
# Get the IP address of the current server.
#
# Returns:
# The server IP address as a string
# """
# try:
# # Option 1: read it from the environment variable (preferred)
# server_ip = os.getenv('SERVER_IP')
# if server_ip and server_ip not in ['127.0.0.1', 'localhost', '0.0.0.0']:
# return server_ip
#
# # Option 2: resolve it through socket
# hostname = socket.gethostname()
# ip_address = socket.gethostbyname(hostname)
#
# # If this is a loopback address, try to find the real IP
# if ip_address.startswith('127.'):
# # Connect to an external address to discover the local IP
# s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
# try:
# s.connect(('8.8.8.8', 80))
# ip_address = s.getsockname()[0]
# finally:
# s.close()
#
# return ip_address
# except Exception as e:
# print(f"[overrides] Failed to get the server IP: {e}; using the default 127.0.0.1")
# return '127.0.0.1'
# def auto_detect_network_mode() -> NetworkMode:
# """
# Detect the network mode automatically from the server IP.
#
# Rules:
# - If the server IP is in the internal IP list → internal
# - Any other IP → external
#
# The internal IP list can be customised through the INTERNAL_SERVER_IPS
# environment variable (comma-separated).
#
# Returns:
# 'internal' or 'external'
# """
# server_ip = get_server_ip()
#
# # Read the internal IP list from the environment (multiple comma-separated IPs are supported)
# internal_ips_str = os.getenv('INTERNAL_SERVER_IPS', '119.45.181.55')
# internal_ips = [ip.strip() for ip in internal_ips_str.split(',')]
#
# # Check whether the current IP is in the internal IP list
# if server_ip in internal_ips:
# print(f"[overrides] Auto-detected server IP {server_ip} is internal; using the INTERNAL configuration")
# return 'internal'
# else:
# print(f"[overrides] Auto-detected server IP {server_ip} is external; using the EXTERNAL configuration")
# return 'external'
# ============================================================================
# Environment-variable overrides are deprecated and no longer used
# ============================================================================
# def _apply_env_var_overrides(runtime_cfg: Dict[str, Any], network_mode: NetworkMode = None, force_override: bool = False) -> Dict[str, Any]:
# """
# Override the configuration from environment variables (deprecated).
# """
# return runtime_cfg
def load_unified_config(
    project_root: str,
    config_id: Optional[int | str] = None,
    group_id: Optional[str] = None,
    network_mode: NetworkMode = None,
    env_override_models: bool = True
) -> Dict[str, Any]:
    """Unified configuration loader.

    Starts from runtime.json (lowest priority) and overlays the PostgreSQL row
    (highest priority) selected by the first identifier available, in this order:
    1. the config_id argument
    2. the group_id argument
    3. the config_id read from dbrun.json
    4. the group_id read from dbrun.json

    Args:
        project_root: Path to the project root directory.
        config_id: Configuration ID (int or str, optional); when omitted, dbrun.json is consulted.
        group_id: Group ID (optional).
        network_mode: Deprecated; kept only for backward compatibility.
        env_override_models: Deprecated; kept only for backward compatibility.

    Returns:
        Dict[str, Any]: The final runtime configuration.
    """
    try:
        # Step 1: load runtime.json as the base configuration
        runtime_config_path = os.path.join(project_root, "runtime.json")
        try:
            with open(runtime_config_path, "r", encoding="utf-8") as f:
                runtime_cfg = json.load(f)
        except (FileNotFoundError, json.JSONDecodeError):
            runtime_cfg = {"selections": {}}
        # Step 2: overlay the database configuration (highest priority)
        if config_id:
            # An explicitly passed config_id comes first
            db_row = _fetch_db_config_by_config_id(config_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, config_id, "config_id")
        elif group_id:
            # Then an explicitly passed group_id
            db_row = _fetch_db_config_by_group_id(group_id)
            if db_row:
                runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, group_id, "group_id")
        else:
            # Otherwise fall back to the identifiers stored in dbrun.json
            dbrun_config_id = _load_dbrun_config_id(project_root)
            if dbrun_config_id:
                db_row = _fetch_db_config_by_config_id(dbrun_config_id)
                if db_row:
                    runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_config_id, "config_id")
            else:
                dbrun_group_id = _load_dbrun_group_id(project_root)
                if dbrun_group_id:
                    db_row = _fetch_db_config_by_group_id(dbrun_group_id)
                    if db_row:
                        runtime_cfg = _apply_overrides_from_db_row(runtime_cfg, db_row, dbrun_group_id, "group_id")
        return runtime_cfg
    except Exception:
        return {"selections": {}}
# Backward-compatible alias
apply_runtime_overrides = apply_runtime_overrides_by_config
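
The override step above copies database columns into nested sections of the runtime configuration, casting each value and clamping the pruning threshold. The sketch below illustrates that pattern in isolation. The definitions of `_set_if_present` and `_to_bool` are not part of this section, so `set_if_present`, `clamp_threshold`, and `apply_row` here are hypothetical stand-ins, not the module's actual helpers.

```python
from typing import Any, Callable, Dict, Optional


def set_if_present(target: Dict[str, Any], target_key: str,
                   source: Dict[str, Any], source_key: str,
                   cast: Callable[[Any], Any]) -> None:
    """Copy source[source_key] into target[target_key] when it is set and castable."""
    value = source.get(source_key)
    if value is None:
        return
    try:
        target[target_key] = cast(value)
    except (TypeError, ValueError):
        pass  # keep the existing default when the value cannot be converted


def clamp_threshold(value: Any, low: float = 0.0, high: float = 0.9) -> float:
    """Convert to float and clamp into [low, high]."""
    return max(low, min(high, float(value)))


def apply_row(runtime_cfg: Dict[str, Any],
              db_row: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Overlay a database row on the runtime configuration (assumed shape)."""
    if not db_row:
        return runtime_cfg
    forgetting = runtime_cfg.setdefault("forgetting_engine", {})
    set_if_present(forgetting, "lambda_time", db_row, "lambda_time", float)
    pruning = runtime_cfg.setdefault("pruning", {})
    set_if_present(pruning, "threshold", db_row, "pruning_threshold", clamp_threshold)
    return runtime_cfg


base = {"forgetting_engine": {"lambda_time": 0.5}, "pruning": {"threshold": 0.3}}
row = {"lambda_time": "0.8", "pruning_threshold": 1.5, "pruning_scene": None}
merged = apply_row(base, row)
```

Passing the clamp as the cast keeps every field on the same code path; a value that cannot be converted leaves the runtime.json default untouched.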

View File

@@ -0,0 +1,43 @@
"""
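Data processing module
Contains all data-processing utilities, including text handling, time handling, and the ontology definitions.
"""
# Re-export commonly used functions and classes from the submodules for backward compatibility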
from .text_utils import (
escape_lucene_query,
extract_plain_query,
)
from .time_utils import (
validate_date_format,
normalize_date,
normalize_date_safe,
preprocess_date_string,
)
from .ontology import (
PREDICATE_DEFINITIONS,
LABEL_DEFINITIONS,
Predicate,
StatementType,
TemporalInfo,
RelevenceInfo,
)
__all__ = [
# text_utils
"escape_lucene_query",
"extract_plain_query",
# time_utils
"validate_date_format",
"normalize_date",
"normalize_date_safe",
"preprocess_date_string",
# ontology
"PREDICATE_DEFINITIONS",
"LABEL_DEFINITIONS",
"Predicate",
"StatementType",
"TemporalInfo",
"RelevenceInfo",
]

View File

@@ -0,0 +1,199 @@
from enum import StrEnum
# Use jinja template.render
PREDICATE_DEFINITIONS = {
"IS_A": "Denotes a class-or-type relationship between two entities (e.g., 'Model Y IS_A electric-SUV'). Includes 'is' and 'was'.",
"HAS_A": "Denotes a part-whole relationship between two entities (e.g., 'Model Y HAS_A electric-engine'). Includes 'has' and 'had'.",
"LOCATED_IN": "Specifies geographic or organisational containment or proximity (e.g., headquarters LOCATED_IN Berlin).",
"HOLDS_ROLE": "Connects a person to a formal office or title within an organisation (CEO, Chair, Director, etc.).",
"PRODUCES": "Indicates that an entity manufactures, builds, or creates a product, service, or infrastructure (includes scale-ups and component inclusion).",
"SELLS": "Marks a commercial seller-to-customer relationship for a product or service (markets, distributes, sells).",
"LAUNCHED": "Captures the official first release, shipment, or public start of a product, service, or initiative.",
"DEVELOPED": "Shows design, R&D, or innovation origin of a technology, product, or capability. Includes 'researched' or 'created'.",
"ADOPTED_BY": "Indicates that a technology or product has been taken up, deployed, or implemented by another entity.",
"INVESTS_IN": "Represents the flow of capital or resources from one entity into another (equity, funding rounds, strategic investment).",
"COLLABORATES_WITH": "Generic partnership, alliance, joint venture, or licensing relationship between entities.",
"SUPPLIES": "Captures vendor-client supply-chain links or dependencies (provides to, sources from).",
"HAS_REVENUE": "Associates an entity with a revenue amount or metric—actual, reported, or projected.",
"INCREASED": "Expresses an upward change in a metric (revenue, market share, output) relative to a prior period or baseline.",
"DECREASED": "Expresses a downward change in a metric relative to a prior period or baseline.",
"RESULTED_IN": "Captures a causal relationship where one event or factor leads to a specific outcome (positive or negative).",
"TARGETS": "Denotes a strategic objective, market segment, or customer group that an entity seeks to reach.",
"PART_OF": "Expresses hierarchical membership or subset relationships (division, subsidiary, managed by, belongs to).",
"DISCONTINUED": "Indicates official end-of-life, shutdown, or termination of a product, service, or relationship.",
"SECURED": "Marks the successful acquisition of funding, contracts, assets, or rights by an entity.",
"MENTIONS": "Denotes a reference or mention of an entity in a text or document.",
# Overly broad predicates have been removed:
# "MENTIONS": "Denotes a reference or mention of an entity in a text or document." ,
# "FEELS" : "Denotes a subjective opinion or feeling about an entity (e.g., 'I feel like X').Includes 'THINKS'.",
# "HELPS" :"Express a action that make it easier or possible for (someone) to do something by offering one's services or resources. Includes 'assist', 'aid' and 'support' " ,
# "IS_DOING" : "Denotes a subjective action or activity about an entity (e.g., 'I am doing X').Includes 'DOES'.",
# "LIKES": "Express enjoy or approve of something or someone (e.g., 'I like roses').Includes 'LIKES'.",
# "DISLIKES": "Express dislike or disapprove of something or someone (e.g., 'I dislike roses').Includes 'DISLIKES'.",
# "HAS_ATTRIBUTE": "Express that an entity has a certain attribute (e.g., 'X has a red car').Includes 'HAS'.",
}
LABEL_DEFINITIONS: dict[str, dict[str, dict[str, str]]] = {
"statement_labelling": {
"FACT": dict(
definition=(
"Statements that are objective and can be independently "
"verified or falsified through evidence."
),
date_handling_guidance=(
"These statements can be made up of multiple static and "
"dynamic temporal events marking, for example, the start, end, "
"and duration of the fact the statement describes."
),
date_handling_example=(
"'Company A owns Company B in 2022', 'X caused Y to happen', "
"or 'John said X at Event' are verifiable facts which currently "
"hold true unless we have a contradictory fact."
),
),
"OPINION": dict(
definition=(
"Statements that contain personal opinions, feelings, values, "
"or judgments that are not independently verifiable. It also "
"includes hypothetical and speculative statements."
),
date_handling_guidance=(
"This statement is always static. It is a record of the date the "
"opinion was made."
),
date_handling_example=(
"'I like Company A's strategy', 'X may have caused Y to happen', "
"or 'The event felt like X' are opinions and down to the reporter's "
"interpretation."
),
),
"PREDICTION": dict(
definition=(
"Uncertain statements about the future on something that might happen, "
"a hypothetical outcome, unverified claims. "
"If the tense of the statement changed, the statement "
"would then become a fact."
),
date_handling_guidance=(
"This statement is always static. It is a record of the date the "
"prediction was made."
),
date_handling_example=(
"'It is rumoured that Dave will resign next month', 'Company A expects "
"X to happen', or 'X suggests Y' are all predictions."
),
),
"SUGGESTION": dict(
definition=(
"A proposal or recommendation for action, often implying a future course of conduct. "
"It's not a statement of fact or a prediction, but rather an advised path. "
"It's a suggestion for action that is not yet implemented."
),
date_handling_guidance=(
"This statement is always static."
),
date_handling_example=(
"'They should launch the new product next quarter', 'You could try a different approach', "
"or 'I would recommend moving the headquarters to Berlin' are all suggestions."
),
),
},
"temporal_labelling": {
"STATIC": dict(
definition=(
"Often past tense, think -ed verbs, describing single points-in-time. "
"These statements are valid from the day they occurred and are never "
"invalid. Refer to single points in time at which an event occurred, "
"the fact X occurred on that date will always hold true."
),
date_handling_guidance=(
"The valid_at date is the date the event occurred. The invalid_at date "
"is None."
),
date_handling_example=(
"'John was appointed CEO on 4th Jan 2024', 'Company A reported X percent "
"growth from last FY', or 'X resulted in Y to happen' are valid the day "
"they occurred and are never invalid."
),
),
"DYNAMIC": dict(
definition=(
"Often present tense, think -ing verbs, describing a period of time. "
"These statements are valid for a specific period of time and are usually "
"invalidated by a Static fact marking the end of the event or start of a "
"contradictory new one. The statement could already be referring to a "
"discrete time period (invalid) or may be an ongoing relationship (not yet "
"invalid)."
),
date_handling_guidance=(
"The valid_at date is the date the event started. The invalid_at date is "
"the date the event or relationship ended, for ongoing events this is None."
),
date_handling_example=(
"'John is the CEO', 'Company A remains a market leader', or 'X is continuously "
"causing Y to decrease' are valid from when the event started and are invalidated "
"by a new event."
),
),
"ATEMPORAL": dict(
definition=(
"Statements that will always hold true regardless of time therefore have no "
"temporal bounds."
),
date_handling_guidance=(
"These statements are assumed to be atemporal and have no temporal bounds. Both "
"their valid_at and invalid_at are None."
),
date_handling_example=(
"'A stock represents a unit of ownership in a company', 'The earth is round', or "
"'Europe is a continent'. These statements are true regardless of time."
),
),
},
}
class Predicate(StrEnum):
"""Enumeration of normalised predicates."""
IS_A = "IS_A"
HAS_A = "HAS_A"
LOCATED_IN = "LOCATED_IN"
HOLDS_ROLE = "HOLDS_ROLE"
PRODUCES = "PRODUCES"
SELLS = "SELLS"
LAUNCHED = "LAUNCHED"
DEVELOPED = "DEVELOPED"
ADOPTED_BY = "ADOPTED_BY"
INVESTS_IN = "INVESTS_IN"
COLLABORATES_WITH = "COLLABORATES_WITH"
SUPPLIES = "SUPPLIES"
HAS_REVENUE = "HAS_REVENUE"
INCREASED = "INCREASED"
DECREASED = "DECREASED"
RESULTED_IN = "RESULTED_IN"
TARGETS = "TARGETS"
PART_OF = "PART_OF"
DISCONTINUED = "DISCONTINUED"
SECURED = "SECURED"
MENTIONS = "MENTIONS"
class StatementType(StrEnum):
FACT = "FACT"
OPINION = "OPINION"
PREDICTION = "PREDICTION"
SUGGESTION = "SUGGESTION"
class TemporalInfo(StrEnum):
ATEMPORAL = "ATEMPORAL"
STATIC = "STATIC"
DYNAMIC = "DYNAMIC"
# Relevance labelling for statements
class RelevenceInfo(StrEnum):
RELEVANT = "RELEVANT"
IRRELEVANT = "IRRELEVANT"
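
Because the enums above derive from `StrEnum`, their members compare equal to plain strings, which makes it straightforward to validate a label returned by an LLM against the ontology. The following is a minimal sketch of that idea with a three-member copy of `Predicate`; `parse_predicate` is a hypothetical helper, not something defined in this module, and the fallback class only exists so the sketch also runs on interpreters older than Python 3.11.

```python
from enum import Enum
from typing import Optional

try:
    from enum import StrEnum  # available from Python 3.11
except ImportError:
    class StrEnum(str, Enum):
        """Minimal stand-in for older interpreters."""


class Predicate(StrEnum):
    IS_A = "IS_A"
    LOCATED_IN = "LOCATED_IN"
    MENTIONS = "MENTIONS"


def parse_predicate(raw: str) -> Optional[Predicate]:
    """Map free-form text to a Predicate member, or None if it is not in the ontology."""
    key = "_".join(raw.strip().upper().split())
    try:
        return Predicate(key)
    except ValueError:
        return None


located = parse_predicate(" located in ")
unknown = parse_predicate("FEELS")
```

Returning `None` for an unknown label lets the caller decide whether to drop the triplet or fall back to a generic predicate.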

View File

@@ -0,0 +1,81 @@
import json
def escape_lucene_query(query: str) -> str:
    """Escape Lucene special characters in a free-text query.

    This prevents ParseException when using Neo4j full-text procedures.
    """
    if query is None:
        return ""
    s = str(query)
    # Normalize whitespace
    s = s.replace("\r", " ").replace("\n", " ").strip()
    # Escape backslashes first so the backslashes added below are not escaped again
    s = s.replace('\\', '\\\\')
    # Remaining Lucene reserved tokens/special characters
    specials = ['&&', '||', '+', '-', '!', '(', ')', '{', '}', '[', ']', '^', '"', '~', '*', '?', ':']
    for token in specials:
        s = s.replace(token, f"\\{token}")
    return s
def extract_plain_query(query_input: str) -> str:
"""Extract clean, plain-text query from various input forms.
- Strips surrounding quotes and whitespace
- If input looks like JSON, prefers the 'original' field
- Falls back to the raw string when parsing fails
"""
if query_input is None:
return ""
# Directly handle dict-like input
if isinstance(query_input, dict):
original = query_input.get("original")
if isinstance(original, str) and original.strip():
return original.strip()
context = query_input.get("context")
if isinstance(context, dict):
for key, val in context.items():
if isinstance(key, str) and key.strip():
return key.strip()
if isinstance(val, list) and val:
first = val[0]
if isinstance(first, str) and first.strip():
return first.strip()
# Fallback to string conversion below
s = str(query_input).strip()
# Remove surrounding single/double quotes if present
if (s.startswith("'") and s.endswith("'")) or (s.startswith('"') and s.endswith('"')):
s = s[1:-1].strip()
# Attempt to parse JSON and extract the 'original' field
if s.startswith("{") and s.endswith("}"):
try:
data = json.loads(s)
# Prefer 'original' field if available
original = data.get("original")
if isinstance(original, str) and original.strip():
return original.strip()
# Fallbacks: try common nested structures
context = data.get("context")
if isinstance(context, dict):
# Take the first key or first string value in context
for key, val in context.items():
if isinstance(key, str) and key.strip():
return key.strip()
if isinstance(val, list) and val:
first = val[0]
if isinstance(first, str) and first.strip():
return first.strip()
except Exception:
# Not valid JSON; keep as-is after best-effort unescape below
pass
# Best-effort unescape common escaped newlines/tabs without altering unicode
s = s.replace("\\n", " ").replace("\\t", " ")
return s
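
Escaping by repeated `str.replace` depends on the order in which tokens are processed, since a backslash inserted for one token can be escaped again by a later pass. As an alternative, a single regular-expression pass visits each character once. This is a minimal sketch using the same token list as `escape_lucene_query`; the name `escape_lucene_single_pass` is hypothetical and not part of this module.

```python
import re

# Two-character operators are tried before the single-character class.
_LUCENE_SPECIALS = re.compile(r'(&&|\|\||[\\+\-!(){}\[\]^"~*?:])')


def escape_lucene_single_pass(query):
    """Prefix every Lucene special token with a backslash in one pass."""
    if query is None:
        return ""
    s = str(query).replace("\r", " ").replace("\n", " ").strip()
    return _LUCENE_SPECIALS.sub(r"\\\1", s)


escaped_and = escape_lucene_single_pass("a&&b")
escaped_path = escape_lucene_single_pass("C:\\x")
```

Here `a&&b` becomes `a\&&b`, and the Windows-style input `C:\x` becomes `C\:\\x`, with the original backslash escaped exactly once.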

View File

@@ -0,0 +1,127 @@
import re
from dateutil import parser
from datetime import datetime
def validate_date_format(date_str: str) -> bool:
    """
    Check whether the string has the shape YYYY-M-D or YYYY-MM-DD.
    Only the layout is checked; the month and day are not validated against the calendar.
    """
    pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
    return bool(re.match(pattern, date_str))
def normalize_date(date_str: str) -> str:
    """
    Normalise a date string in one of several common formats to YYYY-MM-DD.

    Args:
        date_str: Date string in any supported format

    Returns:
        The date as YYYY-MM-DD; if parsing fails, the input with spaces and separators removed
    """
    if not date_str or not isinstance(date_str, str):
        return date_str
    # Strip surrounding whitespace and drop spaces and separators (/ . _ -)
    date_str = date_str.strip().replace(' ', '').replace('/', '').replace('.', '').replace('_', '').replace('-', '')
    try:
        # Preprocess: recognise and normalise special layouts
        preprocessed_str = preprocess_date_string(date_str)
        # Parse with dateutil.parser
        dt = parser.parse(preprocessed_str, dayfirst=False, yearfirst=True)
        return dt.strftime('%Y-%m-%d')
    except (ValueError, TypeError, OverflowError):
        # If the flexible parser fails, fall back to explicit format matching
        return fallback_parse(date_str)
def preprocess_date_string(date_str: str) -> str:
    """Preprocess a date string and normalise special layouts."""
    # Digit-only input is handled first; otherwise the separator-optional
    # pattern below would also match it and split the digits incorrectly.
    if re.fullmatch(r'\d{8}', date_str):  # YYYYMMDD
        return f"{date_str[:4]}-{date_str[4:6]}-{date_str[6:8]}"
    if re.fullmatch(r'\d{6}', date_str):  # YYMMDD or MMDDYY (ambiguous)
        if 1 <= int(date_str[:2]) <= 12:  # probably MMDDYY
            return f"20{date_str[4:6]}-{date_str[:2]}-{date_str[2:4]}"
        return f"20{date_str[:2]}-{date_str[2:4]}-{date_str[4:6]}"  # probably YYMMDD
    # Four-digit year followed by month and day with optional separators,
    # e.g. "20259/28" (no separator between year and month) -> "2025-09-28"
    match = re.match(r'^(\d{4})[/.\-_]?(\d{1,2})[/.\-_]?(\d{1,2})$', date_str)
    if match:
        year, month, day = match.groups()
        return f"{year}-{month.zfill(2)}-{day.zfill(2)}"
    # Unify mixed separators to "-"
    return re.sub(r'[/._]', '-', date_str)
def fallback_parse(date_str: str) -> str:
    """Fallback parser that tries a fixed list of formats."""
    # Try common date formats in order
    formats_to_try = [
        '%Y-%m-%d', '%Y/%m/%d', '%Y.%m.%d',
        '%Y%m%d', '%y%m%d',
        '%m-%d-%Y', '%m/%d/%Y', '%m.%d.%Y',
        '%d-%m-%Y', '%d/%m/%Y', '%d.%m.%Y',
        '%Y-%m', '%Y/%m', '%Y.%m'
    ]
    for fmt in formats_to_try:
        try:
            dt = datetime.strptime(date_str, fmt)
            return dt.strftime('%Y-%m-%d')
        except ValueError:
            continue
    # Every format failed: return the string unchanged
    return date_str
def normalize_date_safe(date_str: str, default: str | None = None) -> str:
    """
    Safe date normalisation with a default value.

    Args:
        date_str: Date string
        default: Value returned when parsing fails

    Returns:
        The normalised date string, or the default (the original string when no default is given)
    """
    try:
        result = normalize_date(date_str)
        # Check that the result has a valid date layout
        if validate_date_format(result):
            return result
        return default if default is not None else date_str
    except Exception:
        return default if default is not None else date_str
if __name__ == "__main__":
start_dates = ["2025/10/28", "2025.10.28", "2025_10_28", "20251028"]
for date in start_dates:
print(normalize_date_safe(date))
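
For callers that only need a fixed set of layouts, the format-matching idea in `fallback_parse` can stand on its own without `dateutil`. The sketch below assumes a small, explicit format list (`KNOWN_FORMATS`) and returns `None` on failure rather than the input string; `to_iso_date` is a hypothetical name used for illustration only.

```python
from datetime import datetime
from typing import Optional

# Assumed set of accepted layouts; extend as needed.
KNOWN_FORMATS = ("%Y-%m-%d", "%Y/%m/%d", "%Y.%m.%d", "%Y_%m_%d", "%Y%m%d")


def to_iso_date(text: str) -> Optional[str]:
    """Return the date as YYYY-MM-DD, or None when no known format matches."""
    cleaned = text.strip()
    for fmt in KNOWN_FORMATS:
        try:
            return datetime.strptime(cleaned, fmt).strftime("%Y-%m-%d")
        except ValueError:
            continue
    return None


samples = ("2025/10/28", "2025.10.28", "2025_10_28", "20251028", "not a date")
results = [to_iso_date(s) for s in samples]
```

Returning `None` makes a parse failure explicit, whereas returning the input unchanged requires the caller to re-validate the result.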

View File

@@ -0,0 +1,18 @@
"""
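LLM utilities module
Contains all utilities related to LLM clients.
"""
# Re-export commonly used functions from the submodule for backward compatibility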
from .llm_utils import (
get_llm_client,
get_reranker_client,
handle_response,
)
__all__ = [
"get_llm_client",
"get_reranker_client",
"handle_response",
]

View File

@@ -0,0 +1,77 @@
import os
from pydantic import BaseModel
from app.core.memory.src.llm_tools.openai_client import OpenAIClient
from app.core.memory.utils.config.config_utils import get_model_config
from app.core.memory.utils.config import definitions as config_defs
from app.core.models.base import RedBearModelConfig
async def handle_response(response: BaseModel) -> dict:
    # The argument is a model instance (not the class), so model_dump() can be called on it
    return response.model_dump()
def get_llm_client(llm_id: str | None = None):
llm_id = llm_id or config_defs.SELECTED_LLM_ID
# Validate LLM ID exists before attempting to get config
if not llm_id:
raise ValueError("LLM ID is required but was not provided")
try:
model_config = get_model_config(llm_id)
except Exception as e:
# Re-raise with clear error message about invalid LLM ID
raise ValueError(f"Invalid LLM ID '{llm_id}': {str(e)}") from e
try:
# Debug printing removed to keep terminal output clean
# print(model_config)
llm_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),type_=model_config.get("type"))
# print(llm.dict())
return llm_client
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize LLM client for model '{model_name}': {str(e)}") from e
def get_reranker_client(rerank_id: str | None = None):
"""
Get an LLM client configured for reranking.
Args:
rerank_id: Optional reranker model ID. If None, uses SELECTED_RERANK_ID.
Returns:
OpenAIClient: Initialized client for the reranker model
Raises:
ValueError: If rerank_id is invalid or client initialization fails
"""
rerank_id = rerank_id or config_defs.SELECTED_RERANK_ID
# Validate rerank ID exists before attempting to get config
if not rerank_id:
raise ValueError("Rerank ID is required but was not provided")
try:
model_config = get_model_config(rerank_id)
except Exception as e:
# Re-raise with clear error message about invalid rerank ID
raise ValueError(f"Invalid rerank ID '{rerank_id}': {str(e)}") from e
try:
reranker_client = OpenAIClient(RedBearModelConfig(
model_name=model_config.get("model_name"),
provider=model_config.get("provider"),
api_key=model_config.get("api_key"),
base_url=model_config.get("base_url")
),type_=model_config.get("type"))
return reranker_client
except Exception as e:
model_name = model_config.get('model_name', 'unknown')
raise ValueError(f"Failed to initialize reranker client for model '{model_name}': {str(e)}") from e
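
Both client factories above follow the same three steps: fall back to a default ID, reject a missing ID, and re-raise lookup failures as `ValueError` with the original exception chained. A minimal sketch of that pattern, with no real client involved, is shown below. `MODEL_REGISTRY`, `lookup_model_config`, and `resolve_model_config` are hypothetical stand-ins for illustration; they are not the project's `get_model_config` or `OpenAIClient`.

```python
from typing import Any, Dict, Optional

# Hypothetical in-memory registry standing in for the real configuration store.
MODEL_REGISTRY: Dict[str, Dict[str, Any]] = {
    "demo-llm": {"model_name": "demo-model", "provider": "openai"},
}


def lookup_model_config(model_id: str) -> Dict[str, Any]:
    return MODEL_REGISTRY[model_id]  # raises KeyError for unknown IDs


def resolve_model_config(model_id: Optional[str] = None,
                         default_id: Optional[str] = None) -> Dict[str, Any]:
    """Resolve a model configuration, wrapping lookup errors in ValueError."""
    model_id = model_id or default_id
    if not model_id:
        raise ValueError("LLM ID is required but was not provided")
    try:
        return lookup_model_config(model_id)
    except Exception as e:
        # "from e" keeps the original exception available as __cause__
        raise ValueError(f"Invalid LLM ID '{model_id}': {e}") from e


resolved = resolve_model_config("demo-llm")
failure = None
try:
    resolve_model_config("missing")
except ValueError as err:
    failure = err
```

Chaining with `from e` means callers see one exception type while the traceback still shows the underlying cause.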

View File

@@ -0,0 +1,24 @@
"""
Logging module
Contains all logging-related utilities.
"""
# Re-export commonly used functions from the submodules for backward compatibility
from .logging_utils import (
log_prompt_rendering,
log_template_rendering,
log_time,
prompt_logger,
)
from .audit_logger import audit_logger
__all__ = [
# logging_utils
"log_prompt_rendering",
"log_template_rendering",
"log_time",
"prompt_logger",
# audit_logger
"audit_logger",
]

View File

@@ -0,0 +1,182 @@
"""
Configuration audit logger
Provides dedicated audit logging for tracking configuration changes and operation records.
"""
import logging
import os
from datetime import datetime
from typing import Optional, Dict, Any
def _format_value(value: Any) -> str:
    """
    Format a value as a string for audit log output.

    UUID objects are rendered in their standard hyphenated form by str().

    Args:
        value: The value to format

    Returns:
        str: The formatted string
    """
    if value is None:
        return "None"
    return str(value)
class ConfigAuditLogger:
"""Configuration audit logger."""
def __init__(self, log_file: str = "logs/config_audit.log"):
"""
Initialise the audit logger.
Args:
log_file: Path to the log file
"""
self.logger = logging.getLogger("config_audit")
self.logger.setLevel(logging.INFO)
# Avoid adding duplicate handlers
if not self.logger.handlers:
# Make sure the log directory exists
log_dir = os.path.dirname(log_file)
if log_dir and not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
# Create the file handler
handler = logging.FileHandler(log_file, encoding='utf-8')
formatter = logging.Formatter(
'%(asctime)s [AUDIT] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
def log_config_load(
self,
config_id: str,
user_id: Optional[str] = None,
group_id: Optional[str] = None,
success: bool = True,
details: Optional[Dict[str, Any]] = None
):
"""
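Record a configuration load event.
Args:
config_id: Configuration ID
user_id: User ID (optional)
group_id: Group ID (optional)
success: Whether the load succeeded
details: Additional details (optional)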
"""
result = "SUCCESS" if success else "FAILED"
msg = (
f"CONFIG_LOAD config_id={config_id} "
f"user={user_id or 'N/A'} group={group_id or 'N/A'} "
f"result={result}"
)
if details:
# Format the details, making sure every value is converted to a string
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
msg += f" details=[{details_str}]"
self.logger.info(msg)
def log_config_change(
self,
config_id: str,
old_values: Dict[str, Any],
new_values: Dict[str, Any],
user_id: Optional[str] = None
):
"""
Record a configuration change event.
Args:
config_id: Configuration ID
old_values: Previous configuration values
new_values: New configuration values
user_id: User ID (optional)
"""
changes = []
for key in new_values:
if key in old_values and old_values[key] != new_values[key]:
changes.append(f"{key}: {old_values[key]} -> {new_values[key]}")
if changes:
msg = (
f"CONFIG_CHANGE config_id={config_id} "
f"user={user_id or 'N/A'} "
f"changes=[{', '.join(changes)}]"
)
self.logger.info(msg)
def log_operation(
self,
operation: str,
config_id: str,
group_id: str,
success: bool = True,
duration: Optional[float] = None,
error: Optional[str] = None,
details: Optional[Dict[str, Any]] = None
):
"""
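Record an operation event.
Args:
operation: Operation type (WRITE, READ, etc.)
config_id: Configuration ID
group_id: Group ID
success: Whether the operation succeeded
duration: Operation duration in seconds
error: Error message (optional)
details: Additional details (optional)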
"""
result = "SUCCESS" if success else "FAILED"
msg = (
f"{operation.upper()} config_id={config_id} "
f"group={group_id} result={result}"
)
if duration is not None:
msg += f" duration={duration:.2f}s"
if error:
msg += f" error={error}"
if details:
# Format the details, making sure every value is converted to a string
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
msg += f" details=[{details_str}]"
self.logger.info(msg)
def log_cache_event(
self,
event_type: str,
config_id: Optional[str] = None,
details: Optional[Dict[str, Any]] = None
):
"""
Record a cache event.
Args:
event_type: Event type (HIT, MISS, CLEAR, EXPIRE)
config_id: Configuration ID (optional)
details: Additional details (optional)
"""
msg = f"CACHE_{event_type.upper()}"
if config_id:
msg += f" config_id={config_id}"
if details:
# Format the details, making sure every value is converted to a string
details_str = ", ".join(f"{k}={_format_value(v)}" for k, v in details.items())
msg += f" details=[{details_str}]"
self.logger.info(msg)
# Global audit logger instance
audit_logger = ConfigAuditLogger()
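
The `if not self.logger.handlers` guard in `ConfigAuditLogger.__init__` matters because `logging.getLogger` returns the same logger object for the same name, so constructing the class twice would otherwise write every audit line twice. The sketch below shows the guard with an in-memory stream in place of the log file; `build_audit_logger` and the logger name `config_audit_demo` are assumptions made for illustration.

```python
import io
import logging


def build_audit_logger(name: str, stream: io.StringIO) -> logging.Logger:
    """Return a logger that gets exactly one handler, however often this is called."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    if not logger.handlers:  # avoid attaching a second handler on repeated calls
        handler = logging.StreamHandler(stream)
        handler.setFormatter(logging.Formatter("[AUDIT] %(message)s"))
        logger.addHandler(handler)
    return logger


buffer = io.StringIO()
first = build_audit_logger("config_audit_demo", buffer)
second = build_audit_logger("config_audit_demo", buffer)
first.info("CONFIG_LOAD config_id=%s user=%s result=%s", 42, "N/A", "SUCCESS")
lines = buffer.getvalue().splitlines()
```

Without the guard, the second call would add another handler and `lines` would contain the same message twice.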

View File

@@ -0,0 +1,38 @@
"""Logging utilities for prompt rendering and timing.
This module provides backward-compatible access to memory module logging utilities
that have been unified into the centralized logging system (app.core.logging_config).
All logging functions are now imported from the centralized configuration to ensure
consistent behavior, formatting, and configuration across the entire application.
For new code, consider importing directly from app.core.logging_config:
from app.core.logging_config import log_prompt_rendering, log_template_rendering, log_time
This module maintains backward compatibility for existing code that imports from here.
"""
# Import from centralized logging configuration
from app.core.logging_config import (
log_prompt_rendering as _log_prompt_rendering,
log_template_rendering as _log_template_rendering,
log_time as _log_time,
get_prompt_logger as _get_prompt_logger,
)
# Re-export functions to maintain backward compatibility
log_prompt_rendering = _log_prompt_rendering
log_template_rendering = _log_template_rendering
log_time = _log_time
# Re-export prompt_logger for backward compatibility with code that uses it directly
# This provides the same logger instance that was previously created in this module
prompt_logger = _get_prompt_logger()
# Expose functions in __all__ for explicit exports
__all__ = [
'log_prompt_rendering',
'log_template_rendering',
'log_time',
'prompt_logger',
]

View File

@@ -0,0 +1,16 @@
"""
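Path management module
Contains all utilities related to path management.
"""
# Re-export commonly used functions from the submodule for backward compatibility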
from .output_paths import (
get_output_dir,
get_output_path,
)
__all__ = [
"get_output_dir",
"get_output_path",
]

@@ -0,0 +1,133 @@
"""
Output Path Management for Memory Module

This module provides utilities for managing output file paths in the memory module.
All output files are now centralized in the logs/memory-output directory.

Migration from: app/core/memory/src/pipeline_output/
Migration to: logs/memory-output/
"""
import os
from pathlib import Path
from typing import Optional

try:
    from app.core.config import settings
    USE_UNIFIED_CONFIG = True
except ImportError:
    USE_UNIFIED_CONFIG = False
    settings = None


def get_output_dir() -> str:
    """
    Get the base output directory for memory module files.

    Returns:
        str: Path to the output directory
    """
    if USE_UNIFIED_CONFIG:
        return settings.MEMORY_OUTPUT_DIR
    else:
        # Fallback to default path
        return "logs/memory-output"


def get_output_path(filename: str) -> str:
    """
    Get the full path for a memory module output file.

    Args:
        filename: Name of the output file

    Returns:
        str: Full path to the output file
    """
    if USE_UNIFIED_CONFIG:
        return settings.get_memory_output_path(filename)
    else:
        # Fallback to default path
        return os.path.join("logs/memory-output", filename)


def ensure_output_dir() -> None:
    """
    Ensure the output directory exists.

    Creates the directory if it doesn't exist.
    """
    if USE_UNIFIED_CONFIG:
        settings.ensure_memory_output_dir()
    else:
        # Fallback: create directory manually
        output_dir = Path("logs/memory-output")
        output_dir.mkdir(parents=True, exist_ok=True)


# Standard output file names (for consistency across the module)
class OutputFiles:
    """Standard output file names for the memory module."""

    # Chunker output
    CHUNKER_TEST_OUTPUT = "chunker_test_output.txt"

    # Preprocessing output
    PREPROCESSED_DATA = "preprocessed_data.json"
    PRUNED_DATA = "pruned_data.json"
    PRUNED_TERMINAL = "pruned_terminal.json"

    # Extraction output
    STATEMENT_EXTRACTION = "statement_extraction.txt"
    RELATIONS_OUTPUT = "relations_output.txt"
    EXTRACTED_TRIPLETS = "extracted_triplets.txt"
    EXTRACTED_ENTITIES_EDGES = "extracted_entities_edges.txt"
    EXTRACTED_TEMPORAL_DATA = "extracted_temporal_data.txt"

    # Deduplication output
    DEDUP_ENTITY_OUTPUT = "dedup_entity_output.txt"

    # Summary output
    EXTRACTED_RESULT = "extracted_result.json"
    EXTRACTED_RESULT_READABLE = "extracted_result_readable.txt"

    # Analytics output
    USER_DASHBOARD = "User-Dashboard.json"
    SIGNBOARD = "Signboard.json"


def get_standard_output_path(file_constant: str) -> str:
    """
    Get the full path for a standard output file.

    Args:
        file_constant: One of the OutputFiles constants

    Returns:
        str: Full path to the output file
    """
    return get_output_path(file_constant)


# Backward compatibility: Legacy path resolution
def resolve_legacy_path(legacy_path: str) -> str:
    """
    Resolve a legacy pipeline_output path to the new unified output path.

    This function helps migrate code that uses hardcoded pipeline_output paths.

    Args:
        legacy_path: Path containing 'pipeline_output'

    Returns:
        str: New path using unified output directory
    """
    if "pipeline_output" in legacy_path:
        # Extract filename from legacy path
        filename = os.path.basename(legacy_path)
        return get_output_path(filename)
    return legacy_path


# Aliases for backward compatibility with test code
get_memory_output_dir = get_output_dir
get_memory_output_path = get_output_path
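
The legacy-path resolution above can be tried in isolation with the sketch below. It assumes the fallback branch only (no `app.core.config.settings` available), so `DEFAULT_OUTPUT_DIR` is an illustrative constant that mirrors the hardcoded `logs/memory-output` default rather than a name from the module.

```python
import os

# Assumption: mirrors the fallback directory used when settings cannot be imported.
DEFAULT_OUTPUT_DIR = "logs/memory-output"


def get_output_path(filename: str) -> str:
    """Fallback-only version of get_output_path."""
    return os.path.join(DEFAULT_OUTPUT_DIR, filename)


def resolve_legacy_path(legacy_path: str) -> str:
    """Map any path mentioning 'pipeline_output' onto the unified directory."""
    if "pipeline_output" in legacy_path:
        # Only the file name survives; intermediate directories are dropped.
        return get_output_path(os.path.basename(legacy_path))
    return legacy_path


legacy = "app/core/memory/src/pipeline_output/pruned_data.json"
print(resolve_legacy_path(legacy))
```

One consequence worth noting: because only the base name is kept, two legacy files with the same name in different `pipeline_output` subdirectories would resolve to the same new path.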

@@ -0,0 +1,34 @@
"""
Prompt management module.

Contains all utility functions related to prompt rendering and template management.
"""

# Re-export commonly used functions from submodules for backward compatibility
from .prompt_utils import (
    get_prompts,
    render_statement_extraction_prompt,
    render_temporal_extraction_prompt,
    render_entity_dedup_prompt,
    render_triplet_extraction_prompt,
    render_memory_summary_prompt,
    prompt_env,
)
from .template_render import (
    render_evaluate_prompt,
    render_reflexion_prompt,
)

__all__ = [
    # prompt_utils
    "get_prompts",
    "render_statement_extraction_prompt",
    "render_temporal_extraction_prompt",
    "render_entity_dedup_prompt",
    "render_triplet_extraction_prompt",
    "render_memory_summary_prompt",
    "prompt_env",
    # template_render
    "render_evaluate_prompt",
    "render_reflexion_prompt",
]

@@ -0,0 +1,240 @@
import os
from jinja2 import Environment, FileSystemLoader
from app.core.memory.utils.log.logging_utils import log_prompt_rendering, log_template_rendering

# Setup Jinja2 environment
# Get the directory of this file (app/core/memory/utils/prompt/)
current_dir = os.path.dirname(os.path.abspath(__file__))
prompt_dir = os.path.join(current_dir, "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))


async def get_prompts(message: str) -> list[dict]:
    """
    Renders system and user prompts using Jinja2 templates.
    """
    system_template = prompt_env.get_template("system.jinja2")
    user_template = prompt_env.get_template("user.jinja2")

    system_prompt = system_template.render()
    user_prompt = user_template.render(message=message)

    # Log the rendered results to the prompt log (same structure as the sample log)
    log_prompt_rendering('system', system_prompt)
    log_prompt_rendering('user', user_prompt)
    # Optional: log template rendering info (only takes effect when prompt_templates.log exists)
    log_template_rendering('system.jinja2', {})
    log_template_rendering('user.jinja2', {'message': message})

    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]


async def render_statement_extraction_prompt(
    chunk_content: str,
    definitions: dict,
    json_schema: dict,
    granularity: int | None = None,
    include_dialogue_context: bool = False,
    dialogue_content: str | None = None,
    max_dialogue_chars: int | None = None,
) -> str:
    """
    Renders the statement extraction prompt using the extract_statement.jinja2 template.

    Args:
        chunk_content: The content of the chunk to process
        definitions: Label definitions for statement classification
        json_schema: JSON schema for the expected output format
        granularity: Optional granularity level passed to the template (it branches on 1, 2, and 3)
        include_dialogue_context: Whether to include the dialogue as extra context
        dialogue_content: Dialogue text used as context when include_dialogue_context is True
        max_dialogue_chars: Optional cap on the dialogue context length; ignored unless a positive int

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("extract_statement.jinja2")

    # Optional clipping of dialogue context
    ctx = None
    if include_dialogue_context and dialogue_content:
        try:
            if isinstance(max_dialogue_chars, int) and max_dialogue_chars > 0:
                ctx = dialogue_content[:max_dialogue_chars]
            else:
                ctx = dialogue_content
        except Exception:
            ctx = dialogue_content

    rendered_prompt = template.render(
        inputs={"chunk": chunk_content},
        definitions=definitions,
        json_schema=json_schema,
        granularity=granularity,
        include_dialogue_context=include_dialogue_context,
        dialogue_context=ctx,
    )
    # Log the rendered result to the prompt log (same structure as the sample log)
    log_prompt_rendering('statement extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_statement.jinja2', {
        'inputs': 'chunk',
        'definitions': 'LABEL_DEFINITIONS',
        'json_schema': 'StatementExtractionResponse.schema',
        'granularity': 'int|None',
        'include_dialogue_context': include_dialogue_context,
        'dialogue_context_len': (len(ctx) if isinstance(ctx, str) else 0),
    })

    return rendered_prompt
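
The dialogue-context clipping in the function above can be sketched as a small pure function. `clip_dialogue_context` is a hypothetical helper name used only for illustration; it is meant to mirror the branch logic shown, not to replace it.

```python
from typing import Optional


def clip_dialogue_context(
    include_dialogue_context: bool,
    dialogue_content: Optional[str],
    max_dialogue_chars: Optional[int],
) -> Optional[str]:
    """Return the dialogue context to render, or None when it is disabled or empty."""
    if not (include_dialogue_context and dialogue_content):
        return None
    # Clip only when a positive integer limit is given; otherwise pass through.
    if isinstance(max_dialogue_chars, int) and max_dialogue_chars > 0:
        return dialogue_content[:max_dialogue_chars]
    return dialogue_content
```

Note that the limit counts characters, not tokens or words, so a clipped context may end in the middle of a sentence.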


async def render_temporal_extraction_prompt(
    ref_dates: dict,
    statement: dict,
    temporal_guide: dict,
    statement_guide: dict,
    json_schema: dict,
) -> str:
    """
    Renders the temporal extraction prompt using the extract_temporal.jinja2 template.

    Args:
        ref_dates: Reference dates for context.
        statement: The statement to process.
        temporal_guide: Guidance on temporal types.
        statement_guide: Guidance on statement types.
        json_schema: JSON schema for the expected output format.

    Returns:
        Rendered prompt content as a string.
    """
    template = prompt_env.get_template("extract_temporal.jinja2")
    # Merge the two dicts; on duplicate keys the value from `statement` wins
    inputs = ref_dates | statement
    rendered_prompt = template.render(
        inputs=inputs,
        temporal_guide=temporal_guide,
        statement_guide=statement_guide,
        json_schema=json_schema,
    )
    # Log the rendered result to the prompt log (same structure as the sample log)
    log_prompt_rendering('temporal extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_temporal.jinja2', {
        'inputs': 'ref_dates|statement',
        'temporal_guide': 'dict',
        'statement_guide': 'dict',
        'json_schema': 'Temporal.schema'
    })

    return rendered_prompt
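
The `ref_dates | statement` expression in the function above relies on the dict union operator (Python 3.9+). The sketch below shows its merge semantics with made-up values; the key names `publication_date`, `statement_type`, and `temporal_type` are the ones the temporal template reads from `inputs`, while the values and the `statement` key are purely illustrative.

```python
# Illustrative values only; real callers pass their own dictionaries.
ref_dates = {"publication_date": "2024-03-15", "statement_type": "FACT"}
statement = {
    "statement": "Sarah Chen really likes roses and lilies.",
    "statement_type": "OPINION",
    "temporal_type": "STATIC",
}

# Dict union: keys from both sides, right-hand operand wins on conflicts.
inputs = ref_dates | statement

# Neither source dictionary is modified by the merge.
assert ref_dates["statement_type"] == "FACT"
print(inputs["statement_type"])
```

If a key could legitimately appear in both dictionaries, the operand order decides which value reaches the template, so swapping the operands would change the rendered prompt.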


def render_entity_dedup_prompt(
    entity_a: dict,
    entity_b: dict,
    context: dict,
    json_schema: dict,
    disambiguation_mode: bool = False,
) -> str:
    """
    Render the entity deduplication prompt using the entity_dedup.jinja2 template.

    Args:
        entity_a: Dict of entity A attributes
        entity_b: Dict of entity B attributes
        context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements)
        json_schema: JSON schema for the structured output (EntityDedupDecision)
        disambiguation_mode: If True, render the disambiguation guidelines instead of the deduplication ones

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("entity_dedup.jinja2")
    rendered_prompt = template.render(
        entity_a=entity_a,
        entity_b=entity_b,
        same_group=context.get("same_group", False),
        type_ok=context.get("type_ok", False),
        type_similarity=context.get("type_similarity", 0.0),
        name_text_sim=context.get("name_text_sim", 0.0),
        name_embed_sim=context.get("name_embed_sim", 0.0),
        name_contains=context.get("name_contains", False),
        co_occurrence=context.get("co_occurrence", False),
        relation_statements=context.get("relation_statements", []),
        json_schema=json_schema,
        disambiguation_mode=disambiguation_mode,
    )

    # prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===")
    # prompt_logger.info(rendered_prompt)
    # prompt_logger.info("\n" + "="*50 + "\n")

    return rendered_prompt


# async def render_entity_dedup_prompt(
#     entity_a: dict,
#     entity_b: dict,
#     context: dict,
#     json_schema: dict,
# ) -> str:
#     """
#     Render the entity deduplication prompt using the entity_dedup.jinja2 template.
#     Args:
#         entity_a: Dict of entity A attributes


async def render_triplet_extraction_prompt(
    statement: str,
    chunk_content: str,
    json_schema: dict,
    predicate_instructions: dict | None = None,
) -> str:
    """
    Renders the triplet extraction prompt using the extract_triplet.jinja2 template.

    Args:
        statement: Statement text to process
        chunk_content: The content of the chunk to process
        json_schema: JSON schema for the expected output format
        predicate_instructions: Optional predicate instructions

    Returns:
        Rendered prompt content as string
    """
    template = prompt_env.get_template("extract_triplet.jinja2")
    rendered_prompt = template.render(
        statement=statement,
        chunk_content=chunk_content,
        json_schema=json_schema,
        predicate_instructions=predicate_instructions
    )
    # Log the rendered result to the prompt log (same structure as the sample log)
    log_prompt_rendering('triplet extraction', rendered_prompt)
    # Optional: log template rendering info
    log_template_rendering('extract_triplet.jinja2', {
        'statement': 'str',
        'chunk_content': 'str',
        'json_schema': 'TripletExtractionResponse.schema',
        'predicate_instructions': 'PREDICATE_DEFINITIONS'
    })

    return rendered_prompt


async def render_memory_summary_prompt(
    chunk_texts: str,
    json_schema: dict,
    max_words: int = 200,
) -> str:
    """
    Renders the memory summary prompt using the memory_summary.jinja2 template.

    Args:
        chunk_texts: Concatenated text of conversation chunks
        json_schema: JSON schema for the expected output format
        max_words: Maximum words for the summary

    Returns:
        Rendered prompt content as string.
    """
    template = prompt_env.get_template("memory_summary.jinja2")
    rendered_prompt = template.render(
        chunk_texts=chunk_texts,
        json_schema=json_schema,
        max_words=max_words,
    )
    log_prompt_rendering('memory summary', rendered_prompt)
    log_template_rendering('memory_summary.jinja2', {
        'chunk_texts_len': len(chunk_texts or ""),
        'max_words': max_words,
        'json_schema': 'MemorySummaryResponse.schema'
    })
    return rendered_prompt

@@ -0,0 +1,60 @@
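===Task===
You are an entity deduplication/disambiguation assistant. You will be given detailed information about two entities along with context. Strictly follow the guidelines to decide whether they refer to the same real-world entity, and perform type disambiguation when needed.
Mode: {{ 'Disambiguation mode' if disambiguation_mode else 'Deduplication mode' }}
===Input===
Entity A:
- Name: "{{ entity_a.name | default('') }}"
- Type: "{{ entity_a.entity_type | default('') }}"
- Description: "{{ entity_a.description | default('') }}"
- Aliases: {{ entity_a.aliases | default([]) }}
- Summary: "{{ entity_a.fact_summary | default('') }}"
- Connection strength: "{{ entity_a.connect_strength | default('') }}"
Entity B:
- Name: "{{ entity_b.name | default('') }}"
- Type: "{{ entity_b.entity_type | default('') }}"
- Description: "{{ entity_b.description | default('') }}"
- Aliases: {{ entity_b.aliases | default([]) }}
- Summary: "{{ entity_b.fact_summary | default('') }}"
- Connection strength: "{{ entity_b.connect_strength | default('') }}"
Context:
- Same group: {{ same_group | default(false) }}
- Types match or type unknown: {{ type_ok | default(false) }}
- Type similarity (0-1): {{ type_similarity | default(0.0) }}
- Name text similarity (0-1): {{ name_text_sim | default(0.0) }}
- Name embedding similarity (0-1): {{ name_embed_sim | default(0.0) }}
- Name containment: {{ name_contains | default(false) }}
- Co-occurrence in context (the same statement refers to both): {{ co_occurrence | default(false) }}
- Relation statements involving both (from entity-entity edges):
{% for s in relation_statements %}
- {{ s }}
{% endfor %}
===Decision Guidelines===
{% if disambiguation_mode %}
- This is a "same name, different type" disambiguation scenario. Decide whether the two refer to the same real-world entity.
- Weigh name text/embedding similarity, aliases, descriptions, summaries, and contextual relations (co-occurrence and relation statements) together.
- If you cannot be sufficiently certain, be conservative: do not merge, and recommend blocking this pair from other fuzzy/heuristic merges (block_pair=true).
- If a merge is warranted (should_merge=true), choose the "canonical entity" (canonical_idx) and, where possible, give a suggested unified type (suggested_type); the suggested type must be consistent with the context.
- Canonical entity priority: the entity with the higher connection strength (strong/both); if equal, keep the one with the richer description/summary; if still equal, keep entity A (canonical_idx=0).
{% else %}
- If the entity types are the same, or either is UNKNOWN/empty, the pair may pass as a candidate; if the types clearly conflict (e.g., person vs. object), judge them to be different entities unless aliases and descriptions are highly consistent.
- Weigh name text/embedding similarity, aliases, descriptions, summaries, and contextual relations together to decide whether they are the same entity.
- When the two co-occur in context, or explicit relation statements support identity (e.g., the same object is mentioned repeatedly or aliases correspond), the decision threshold may be lowered moderately.
- Conservative decision: when you cannot be sufficiently certain, do not merge (same_entity=false).
- If a merge is warranted, choose the more suitable one as the "canonical entity to keep" (canonical_idx):
- Prefer the one with the stronger connection strength (strong/both); if equal, keep the one with the richer description/summary; if still equal, keep entity A (canonical_idx=0).
{% endif %}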
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
The output language should always be the same as the input language.
{{ json_schema }}
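
The canonical-entity priority in the guidelines above (connection strength first, then richer description/summary, then entity A) could be expressed deterministically as in the sketch below. This is only an illustration under stated assumptions: `pick_canonical_idx` is a hypothetical helper, "strong" is assumed to mean a `connect_strength` of `strong` or `both`, and "richer" is approximated by the combined length of `description` and `fact_summary`. The template itself leaves these judgments to the model.

```python
from typing import Any, Dict

# Assumption: these two values are the ones treated as a strong connection.
STRONG_LEVELS = {"strong", "both"}


def is_strong(entity: Dict[str, Any]) -> bool:
    return str(entity.get("connect_strength") or "").lower() in STRONG_LEVELS


def richness(entity: Dict[str, Any]) -> int:
    # Assumption: text length is a rough proxy for "richer description/summary".
    return len(entity.get("description") or "") + len(entity.get("fact_summary") or "")


def pick_canonical_idx(entity_a: Dict[str, Any], entity_b: Dict[str, Any]) -> int:
    """Return 0 to keep entity A or 1 to keep entity B."""
    if is_strong(entity_a) != is_strong(entity_b):
        return 0 if is_strong(entity_a) else 1
    if richness(entity_a) != richness(entity_b):
        return 0 if richness(entity_a) > richness(entity_b) else 1
    return 0  # full tie: keep entity A
```

A helper like this might serve as a sanity check on the model's `canonical_idx`, though whether the pipeline does anything of the sort is not shown here.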

@@ -0,0 +1,19 @@
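You will receive a set of memory objects: {{ evaluate_data }}.
Task: judge along multiple dimensions whether these memories conflict with existing memories, and give the memory each conflict corresponds to. (Redundancy does not count as a conflict.)
Output exactly one valid JSON object that strictly follows the structure below:
{
"data": [ ...array of memory objects with the same structure as the input... ],
"conflict": true or false,
"conflict_memory": the memory object it conflicts with if conflict is true, otherwise null
}
You must:
- Output JSON only; do not add explanations or extra text.
- Use standard double quotes and escape inner quotes when necessary.
- Keep field names and structure consistent with the given schema.
Schema reference:
[
{{ json_schema }}
]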
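
A caller would likely want to check the model's reply against the structure requested above before using it. The following is a minimal sketch of such a check using only the standard library; `parse_evaluate_response` is a hypothetical name, and the real pipeline may validate differently (for example with a schema model).

```python
import json
from typing import Any, Dict

REQUIRED_KEYS = {"data", "conflict", "conflict_memory"}


def parse_evaluate_response(raw: str) -> Dict[str, Any]:
    """Parse the reply and enforce the structure the prompt asks for."""
    result = json.loads(raw)
    if not isinstance(result, dict):
        raise ValueError("expected a single JSON object")
    missing = REQUIRED_KEYS - result.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    if not isinstance(result["data"], list):
        raise ValueError("'data' must be an array")
    if not isinstance(result["conflict"], bool):
        raise ValueError("'conflict' must be true or false")
    if not result["conflict"] and result["conflict_memory"] is not None:
        raise ValueError("'conflict_memory' must be null when there is no conflict")
    return result
```

Because `json.JSONDecodeError` subclasses `ValueError`, a caller can catch `ValueError` alone to handle both malformed JSON and structural problems.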

@@ -0,0 +1,49 @@
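{#
Dialogue-level extraction and relevance-judgment template (used to speed up pruning)
Input: pruning_scene, dialog_text
Output: strict JSON (no extra text), with the fields:
- is_related: bool, whether the dialogue is related to the selected scene
- times: [string], time-related text extracted from the dialogue (dates, times, time ranges, validity periods, etc.)
- ids: [string], numbers/IDs/order numbers/application numbers/account numbers, etc.
- amounts: [string], amounts/fees/prices (with units or currency symbols)
- contacts: [string], contact details (phone/mobile/email/WeChat/QQ, etc.)
- addresses: [string], address/location-related text
- keywords: [string], other important keywords worth keeping (terms strongly related to the scene)
Requirements:
- Output only the JSON above with identical key names; do not output explanations, prefixes, or suffixes, and do not include comments.
- times/ids/amounts/contacts/addresses/keywords contain only fragments of the source text or simple normalized strings.
- Output only the keys above; avoid extra explanations or fields.
#}
{% set scene_instructions = {
'education': 'Education scene: teaching, courses, exams, homework, teacher/student interaction, learning resources, school administration, etc.',
'online_service': 'Online customer service scene: customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.',
'outbound': 'Outbound call scene: outbound phone calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up visit records, etc.'
} %}
{% set scene_key = pruning_scene %}
{% if scene_key not in scene_instructions %}
{% set scene_key = 'education' %}
{% endif %}
{% set instruction = scene_instructions[scene_key] %}
Based on the full dialogue below, perform a one-pass extraction for this scene and judge relevance:
Scene description: {{ instruction }}
Full dialogue:
"""
{{ dialog_text }}
"""
Output strict JSON only (fixed keys, any order):
{
"is_related": <true or false>,
"times": [<string>...],
"ids": [<string>...],
"amounts": [<string>...],
"contacts": [<string>...],
"addresses": [<string>...],
"keywords": [<string>...]
}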
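
The template's scene fallback and its fixed output keys can be mirrored on the Python side, roughly as sketched below. `resolve_scene` and `has_expected_keys` are hypothetical helpers for illustration; the scene names and key names are taken from the template above.

```python
from typing import Any, Dict, Optional

SCENES = ("education", "online_service", "outbound")
DEFAULT_SCENE = "education"  # the template falls back to this scene
EXPECTED_KEYS = {
    "is_related", "times", "ids", "amounts", "contacts", "addresses", "keywords",
}


def resolve_scene(pruning_scene: Optional[str]) -> str:
    """Same rule as the template: unknown or missing scenes become 'education'."""
    return pruning_scene if pruning_scene in SCENES else DEFAULT_SCENE


def has_expected_keys(result: Dict[str, Any]) -> bool:
    """True when the reply carries exactly the keys the prompt demands."""
    return set(result) == EXPECTED_KEYS
```

Since the fallback is silent, a misspelled scene name such as `"outbund"` would be treated as `education` without any error, which may be worth logging on the caller's side.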

@@ -0,0 +1,207 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
===Tasks===
Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines.
Each statement must be labeled as per the criteria mentioned below.
===Inputs===
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{val}}
{% endfor %}
{% endif %}
===Extraction Instructions===
{% if granularity %}
{% if granularity == 3 %}
Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one.
Context-Independent: Statements must be understandable without needing to read the entire conversation.
{% elif granularity == 2 %}
Extract statements at the sentence level. Each statement should correspond to a single, complete thought (typically a full sentence from the source) but be rephrased for maximum clarity, removing conversational filler (e.g., 'um,' 'like,' interjections).
{% elif granularity == 1 %}
Extract only essence sentences and summarize the chunk into multiple, standalone statements, each focusing on factual statements, user preferences, relationships, and salient temporal context.
{% endif %}
{% endif %}
Context Resolution Requirements:
- Resolve demonstrative pronouns ("that," "this," "those","这个", "那个") to their specific referents
- If a statement contains vague references that cannot be resolved from the conversation context, either:
a) Expand the statement to include the missing context from earlier in the conversation
b) Mark the statement as requiring additional context
c) Skip extraction if the statement becomes meaningless without context
Conversational Context & Co-reference Resolution:
- Attribute every statement to the participant who uttered it.
- If the participant list provides a name for a speaker (e.g., "李雪 (用户)"), use the specific name ("李雪") in the extracted statement, not the generic role ("用户").
- Resolve all pronouns to the specific person or entity from the conversation's context.
- Identify and resolve abstract references to their specific names if mentioned.
- Expand abbreviations and acronyms to their full form.
{% if include_dialogue_context %}
===Full Dialogue Context===
The following is the complete dialogue context to help you understand references, pronouns, and conversational flow:
{{ dialogue_context }}
===End of Dialogue Context===
{% endif %}
Filtering and Formatting:
- Extract only declarative statements.
DO NOT extract questions, commands, greetings, or conversational filler.
Temporal Precision:
Include any explicit dates, times, or quantitative qualifiers.
If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements.
{%- if definitions %}
{%- for section_key, section_dict in definitions.items() %}
==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ====
{%- for category, details in section_dict.items() %}
{{ loop.index }}. {{ category }}
- Definition: {{ details.get("definition", "") }}
{% endfor -%}
{% endfor -%}
{% endif -%}
===Examples===
Example 1: English Conversation
Example Chunk: """
Date: March 15, 2024
Participants:
- Sarah Chen (User)
- Assistant (AI)
User: "I've been trying watercolor painting recently and painted some flowers."
AI: "Watercolor painting is very interesting! Watercolor paints are typically made from pigments mixed with binders like gum arabic. How do you like it?"
User: "I think the color combinations could use some improvement, but I really like roses and lilies."
"""
Example Output: {
"statements": [
{
"statement": "Sarah Chen has been trying watercolor painting recently.",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "Sarah Chen painted some flowers.",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "Watercolor paints are typically made from pigments mixed with binders like gum arabic.",
"statement_type": "FACT",
"temporal_type": "ATEMPORAL",
"relevance": "IRRELEVANT"
},
{
"statement": "Sarah Chen thinks the color combinations in her watercolor paintings could use some improvement.",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "Sarah Chen really likes roses and lilies.",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
}
]
}
Example 2: Chinese Conversation (中文对话示例)
Example Chunk: """
日期: 2024年3月15日
参与者:
- 张曼婷 (用户)
- 小助手 (AI助手)
用户: "我最近在尝试水彩画,画了一些花朵。"
AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。"
"""
Example Output: {
"statements": [
{
"statement": "张曼婷最近在尝试水彩画。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "张曼婷画了一些花朵。",
"statement_type": "FACT",
"temporal_type": "DYNAMIC",
"relevance": "RELEVANT"
},
{
"statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。",
"statement_type": "FACT",
"temporal_type": "ATEMPORAL",
"relevance": "IRRELEVANT"
},
{
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
"statement_type": "OPINION",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
},
{
"statement": "张曼婷很喜欢玫瑰和百合。",
"statement_type": "FACT",
"temporal_type": "STATIC",
"relevance": "RELEVANT"
}
]
}
===End of Examples===
===Reflection Process===
After extracting statements, perform the following self-review steps:
**Step 1: Attribution Check**
- Confirm every statement is properly attributed to the correct speaker
- Verify speaker names are used consistently throughout
- Check that AI assistant statements are properly attributed
**Step 2: Completeness Review**
- Ensure no important declarative statements were missed
- Check that temporal information is preserved
**Step 3: Classification Validation**
- Review statement_type classifications (FACT/OPINION/PREDICTION/SUGGESTION)
- Verify temporal_type assignments (STATIC/DYNAMIC/ATEMPORAL)
- Ensure classifications align with the provided definitions
**Step 4: Final Quality Check**
- Remove any questions, commands, or conversational filler
- Verify JSON format compliance
- Confirm output language matches input language
**Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "John said: \"I really like this book.\""
**LANGUAGE REQUIREMENT:**
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
- Preserve the original language and do not translate
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
{{ json_schema }}
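
Step 3 of the reflection process above names the allowed labels, so a reply can be screened for out-of-vocabulary values with a short check like the one below. This is a sketch, not the project's validator: `find_invalid_statements` is a hypothetical name, and the `relevance` values are inferred from the examples in this template rather than from a stated list.

```python
from typing import Any, Dict, List

ALLOWED_LABELS = {
    "statement_type": {"FACT", "OPINION", "PREDICTION", "SUGGESTION"},
    "temporal_type": {"STATIC", "DYNAMIC", "ATEMPORAL"},
    # Assumption: inferred from the example outputs only.
    "relevance": {"RELEVANT", "IRRELEVANT"},
}


def find_invalid_statements(payload: Dict[str, Any]) -> List[int]:
    """Return the indexes of statements with a missing or unknown label."""
    invalid = []
    for idx, item in enumerate(payload.get("statements", [])):
        if any(item.get(field) not in allowed for field, allowed in ALLOWED_LABELS.items()):
            invalid.append(idx)
    return invalid
```

An empty result means every statement uses known labels; it says nothing about whether the labels are the right ones for the text.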

@@ -0,0 +1,81 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
{#
This prompt (template) is adapted from [getzep/graphiti]
Licensed under the Apache License, Version 2.0
Original work:
https://github.com/getzep/graphiti/blob/main/graphiti_core/prompts/extract_edge_dates.py
Modifications made by Ke Sun on 2025-09-01
See the LICENSE file for the full Apache 2.0 license text.
#}
# Task
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
# Input Data
{% if inputs %}
{% for key, val in inputs.items() %}
- {{ key }}: {{val}}
{% endfor %}
{% endif %}
# Temporal Fields
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
# Extraction Rules
## Core Principles
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
2. **Use the reference/publication date as "now"** when interpreting relative times
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
4. **For point-in-time events**, set only `valid_at`
## Date Format Requirements
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
- If no time specified, use `00:00:00` (midnight)
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
- If only month mentioned, use first or last day of month
- Always include timezone (use `Z` for UTC if unspecified)
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
## Statement Type Rules
{{ inputs.get("statement_type") | upper }} Statement Guidance:
{%for key, guide in statement_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
**Special Cases:**
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
## Temporal Type Rules
{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance:
{% for key, guide in temporal_guide.items() %}
- {{ tidy(key) | capitalize }}: {{ guide }}
{% endfor %}
{% if inputs.get('quarter') and inputs.get('publication_date') %}
## Quarter Reference
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
{% endif %}
# Output Requirements
## JSON Formatting (CRITICAL)
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes (“”) or other Unicode variants
2. Escape internal quotes with backslash: `\"`
3. No line breaks within JSON string values
4. Properly close and comma-separate all fields
## Language
Output language must match input language.
{{ json_schema }}
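
The date format rules above (midnight when no time is given, first or last day for year-only and month-only mentions, `Z` for UTC) can be illustrated with a small standard-library sketch. `to_iso_utc` is a hypothetical helper written for this example; the template asks the model to apply these rules, and nothing here implies the pipeline post-processes dates this way.

```python
import calendar
from datetime import datetime, timezone
from typing import Optional


def to_iso_utc(
    year: int,
    month: Optional[int] = None,
    day: Optional[int] = None,
    end: bool = False,
) -> str:
    """Format a possibly partial date as YYYY-MM-DDTHH:MM:SS.SSSSSSZ at midnight UTC."""
    if month is None:
        # Year only: first day of the year, or the last one when `end` is set.
        month, day = (12, 31) if end else (1, 1)
    elif day is None:
        # Month only: first day, or the last day of that month when `end` is set.
        day = calendar.monthrange(year, month)[1] if end else 1
    moment = datetime(year, month, day, tzinfo=timezone.utc)
    return moment.strftime("%Y-%m-%dT%H:%M:%S.%f") + "Z"


print(to_iso_utc(2024))               # year only, start of range
print(to_iso_utc(2024, 2, end=True))  # month only, end of range
```

`calendar.monthrange` accounts for leap years, so the month-only end date for February 2024 falls on the 29th.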

@@ -0,0 +1,248 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ')}}
{%- endmacro %}
===Task===
Extract entities and knowledge triplets from the given statement.
===Inputs===
**Chunk Content:** "{{ chunk_content }}"
**Statement:** "{{ statement }}"
===Guidelines===
**Entity Extraction:**
- Extract entities with their types and context-independent descriptions
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
**Triplet Extraction:**
- Extract (subject, predicate, object) triplets where:
- Subject: main entity performing the action or being described
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
- Object: entity, value, or concept affected by the predicate
- Exclude all temporal expressions from every field
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
- Do NOT translate predicate tokens
- Do NOT include `statement_id` field (assigned automatically)
**When NOT to extract triplets:**
- Non-propositional utterances (emotions, fillers, onomatopoeia)
- No clear predicate from the given definitions applies
- Standalone noun phrases or checklist items (e.g., "三脚架", "备用电池") → extract as entities only
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
{%- if predicate_instructions -%}
**Predicate Instructions:**
Use ONLY these predicates. If none fits, set triplets to [].
{%- for pred, instruction in predicate_instructions.items() %}
- {{ pred }}: {{ instruction }}
{%- endfor -%}
{%- endif -%}
===Examples===
**Example 1 (English):** "I plan to travel to Paris next week and visit the Louvre."
Output:
{
"triplets": [
{
"subject_name": "I",
"subject_id": 0,
"predicate": "PLANS_TO_VISIT",
"object_name": "Paris",
"object_id": 1,
"value": null
},
{
"subject_name": "I",
"subject_id": 0,
"predicate": "PLANS_TO_VISIT",
"object_name": "Louvre",
"object_id": 2,
"value": null
}
],
"entities": [
{
"entity_idx": 0,
"name": "I",
"type": "Person",
"description": "The user"
},
{
"entity_idx": 1,
"name": "Paris",
"type": "Location",
"description": "Capital city of France"
},
{
"entity_idx": 2,
"name": "Louvre",
"type": "Location",
"description": "World-famous museum located in Paris"
}
]
}
**Example 2 (English):** "John Smith works at Google and is responsible for AI product development."
Output:
{
"triplets": [
{
"subject_name": "John Smith",
"subject_id": 0,
"predicate": "WORKS_AT",
"object_name": "Google",
"object_id": 1,
"value": null
},
{
"subject_name": "John Smith",
"subject_id": 0,
"predicate": "RESPONSIBLE_FOR",
"object_name": "AI product development",
"object_id": 2,
"value": null
}
],
"entities": [
{
"entity_idx": 0,
"name": "John Smith",
"type": "Person",
"description": "Individual person name"
},
{
"entity_idx": 1,
"name": "Google",
"type": "Organization",
"description": "American technology company"
},
{
"entity_idx": 2,
"name": "AI product development",
"type": "WorkRole",
"description": "Artificial intelligence product development work"
}
]
}
**Example 3 (Chinese):** "我计划下周去巴黎旅行,参观卢浮宫。"
Output:
{
"triplets": [
{
"subject_name": "我",
"subject_id": 0,
"predicate": "PLANS_TO_VISIT",
"object_name": "巴黎",
"object_id": 1,
"value": null
},
{
"subject_name": "我",
"subject_id": 0,
"predicate": "PLANS_TO_VISIT",
"object_name": "卢浮宫",
"object_id": 2,
"value": null
}
],
"entities": [
{
"entity_idx": 0,
"name": "我",
"type": "Person",
"description": "用户本人"
},
{
"entity_idx": 1,
"name": "巴黎",
"type": "Location",
"description": "法国首都城市"
},
{
"entity_idx": 2,
"name": "卢浮宫",
"type": "Location",
"description": "位于巴黎的世界著名博物馆"
}
]
}
**Example 4 (Chinese):** "张明在腾讯工作负责AI产品开发。"
Output:
{
"triplets": [
{
"subject_name": "张明",
"subject_id": 0,
"predicate": "WORKS_AT",
"object_name": "腾讯",
"object_id": 1,
"value": null
},
{
"subject_name": "张明",
"subject_id": 0,
"predicate": "RESPONSIBLE_FOR",
"object_name": "AI产品开发",
"object_id": 2,
"value": null
}
],
"entities": [
{
"entity_idx": 0,
"name": "张明",
"type": "Person",
"description": "个人姓名"
},
{
"entity_idx": 1,
"name": "腾讯",
"type": "Organization",
"description": "中国科技公司"
},
{
"entity_idx": 2,
"name": "AI产品开发",
"type": "WorkRole",
"description": "人工智能产品研发工作"
}
]
}
**Example 5 (Entity Only):** "Tripod" or "三脚架"
Output:
{
"triplets": [],
"entities": [
{
"entity_idx": 0,
"name": "Tripod",
"type": "Equipment",
"description": "Photography equipment accessory"
}
]
}
===Output Format===
**JSON Requirements:**
- Use only ASCII double quotes (") for JSON structure
- Never use Chinese quotation marks (“”) or Unicode quotes
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
- The output language should ALWAYS match the input language
- If input is in English, extract statements in English
- If input is in Chinese, extract statements in Chinese
- Preserve the original language and do not translate
{{ json_schema }}
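
In the examples above, each triplet's `subject_id` and `object_id` matches the `entity_idx` of an entry in `entities`. Assuming that correspondence is intended, a reply can be checked for dangling references with the sketch below; `find_dangling_triplets` is a hypothetical helper, not part of the module.

```python
from typing import Any, Dict, List


def find_dangling_triplets(payload: Dict[str, Any]) -> List[int]:
    """Return indexes of triplets that point at an entity_idx not in 'entities'."""
    known = {entity.get("entity_idx") for entity in payload.get("entities", [])}
    return [
        idx
        for idx, triplet in enumerate(payload.get("triplets", []))
        if triplet.get("subject_id") not in known or triplet.get("object_id") not in known
    ]
```

The entity-only case from Example 5 (an empty `triplets` array) passes trivially, since there is nothing to cross-check.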

@@ -0,0 +1,29 @@
{% macro tidy(name) -%}
{{ name.replace('_', ' ') }}
{%- endmacro %}
=== Task ===
Summarize the provided conversation chunks into a concise Memory summary.
=== Requirements ===
- Focus on factual statements, user preferences, relationships, and salient temporal context.
- Avoid repetition and filler; be specific.
- Keep it under {{ max_words or 200 }} words.
- Output must be valid JSON conforming to the schema below.
=== Input ===
{% if chunk_texts %}
{{ chunk_texts }}
{% endif %}
=== Output Schema ===
**CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks (“”) or other Unicode quotes
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
3. Ensure all JSON strings are properly closed and comma-separated
4. Do not include line breaks within JSON string values
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
The output language should always be the same as the input language.
Return only JSON that matches the schema below:
{{ json_schema }}

@@ -0,0 +1,23 @@
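You will receive a conflict-judgment object: {{ data }}.
Task: analyze why the conflict arose, propose a solution, and produce the memory as it should look after being marked invalid.
Output exactly one valid JSON object that strictly follows this structure:
{
"conflict": same structure as the input, containing data and conflict_memory,
"reflexion": { "reason": string, "solution": string },
"resolved": {
"original_memory_id": id of the memory that was marked invalid,
"resolved_memory": the complete memory object after being marked invalid
}
}
Rules:
- Output JSON only; do not add explanations or extra text.
- Use standard double quotes and escape inner quotes where necessary.
- Field names and structure must match the given schema.
- When conflict is false, resolved must be null.
- conflict.data must be an array; wrap it in [ ] even if it holds a single object.
Schema reference: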
[
{{ json_schema }}
]
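The structural rules in this prompt can also be checked in code before the result is used. The sketch below is a hypothetical helper (`check_reflexion_output` is not part of the repository) and rests on one assumption that the prompt leaves open: that the boolean flag sits at `conflict["conflict"]`, mirroring the `c.get("conflict") is True` filter used elsewhere in this commit. The memory ids in the sample are made up.

```python
import json
from typing import List


def check_reflexion_output(raw: str) -> List[str]:
    """Return a list of rule violations; an empty list means the output passed."""
    try:
        obj = json.loads(raw)
    except json.JSONDecodeError as exc:
        return [f"not valid JSON: {exc}"]
    if not isinstance(obj, dict):
        return ["top level must be a JSON object"]

    problems: List[str] = []
    conflict = obj.get("conflict")
    if not isinstance(conflict, dict):
        problems.append("conflict must be an object")
        conflict = {}
    elif not isinstance(conflict.get("data"), list):
        problems.append("conflict.data must be an array")

    reflexion = obj.get("reflexion")
    if not isinstance(reflexion, dict) or not all(
        isinstance(reflexion.get(key), str) for key in ("reason", "solution")
    ):
        problems.append("reflexion needs string fields reason and solution")

    resolved = obj.get("resolved")
    # Assumption: the boolean flag lives at conflict["conflict"].
    if conflict.get("conflict") is False and resolved is not None:
        problems.append("resolved must be null when conflict is false")
    if resolved is not None and not (
        isinstance(resolved, dict)
        and "original_memory_id" in resolved
        and isinstance(resolved.get("resolved_memory"), dict)
    ):
        problems.append("resolved needs original_memory_id and resolved_memory")
    return problems


sample = {
    "conflict": {"data": [{"id": "m1"}], "conflict_memory": {"id": "m2"}, "conflict": True},
    "reflexion": {"reason": "r", "solution": "s"},
    "resolved": {
        "original_memory_id": "m1",
        "resolved_memory": {"id": "m1", "invalid_at": "2025-01-01"},
    },
}
report = check_reflexion_output(json.dumps(sample))
print(report)
```

Returning a list of messages rather than raising makes it easy to log every violation at once and to skip a single bad item without aborting the whole batch.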


@@ -0,0 +1,2 @@
You are an AI assistant that extracts entity nodes from conversational messages.
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.


@@ -0,0 +1,5 @@
You are given a conversation context and a CURRENT MESSAGE.
Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities.
{{ message }}


@@ -0,0 +1,42 @@
import os
from jinja2 import Environment, FileSystemLoader
from typing import List, Dict, Any
# Setup Jinja2 environment
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any]) -> str:
"""
Renders the evaluate prompt using the evaluate.jinja2 template.
Args:
evaluate_data: The data to evaluate
schema: The JSON schema to use for the output.
Returns:
Rendered prompt content as string
"""
template = prompt_env.get_template("evaluate.jinja2")
rendered_prompt = template.render(evaluate_data=evaluate_data, json_schema=schema)
return rendered_prompt
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any]) -> str:
"""
Renders the reflexion prompt using the reflexion.jinja2 template.
Args:
data: The conflict data to reflect on.
schema: The JSON schema to use for the output.
Returns:
Rendered prompt content as a string.
"""
template = prompt_env.get_template("reflexion.jinja2")
rendered_prompt = template.render(data=data, json_schema=schema)
return rendered_prompt


@@ -0,0 +1,16 @@
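# -*- coding: utf-8 -*-
"""Self-reflexion utilities.
This module provides the core functions of the self-reflexion engine:
- memory conflict detection
- reflexion execution
- memory update
Migrated from app.core.memory.src.data_config_api.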
"""
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
__all__ = ["conflict", "reflexion", "self_reflexion"]


@@ -0,0 +1,49 @@
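# -*- coding: utf-8 -*-
"""Memory conflict detection module.
Provides memory conflict detection, using an LLM to judge whether the memory data contains conflicts.
Migrated from app.core.memory.src.data_config_api.evaluate.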
"""
import logging
from typing import List, Any
import time
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ConflictResultSchema
from pydantic import BaseModel
async def conflict(evaluate_data: List[Any]) -> List[Any]:
"""
Evaluates memory conflict using the evaluate.jinja2 template.
Args:
evaluate_data: List of memory records to check for conflicts.
Returns:
A list of conflict results (JSON array).
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")
print(f"====== 冲突判定开始 ======\n")
start_time = time.time()
response = await client.response_structured(messages, ConflictResultSchema)
end_time = time.time()
print(f"冲突判定耗时: {end_time - start_time}")
print(f"冲突判定原始输出:(type={type(response)})\n{response}")
if not response:
logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。")
return []
try:
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
except Exception:
try:
return [response.dict()]
except Exception:
logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。")
return [response]
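The nested `try`/`except` at the end of `conflict` exists to turn whatever the LLM client returns into a list of plain dicts. One way to express the same intent with explicit branches is sketched below. `to_dict_list` and `FakeResult` are hypothetical names used only for illustration, and the sketch raises on unknown types instead of wrapping them, which is a deliberate difference from the code above.

```python
from typing import Any, Dict, List


def to_dict_list(response: Any) -> List[Dict[str, Any]]:
    """Normalise a structured LLM response into a list of plain dicts."""
    if response is None:
        return []
    items = response if isinstance(response, list) else [response]
    normalised: List[Dict[str, Any]] = []
    for item in items:
        if isinstance(item, dict):
            normalised.append(item)
        elif hasattr(item, "model_dump"):  # Pydantic v2 style
            normalised.append(item.model_dump())
        elif hasattr(item, "dict"):  # Pydantic v1 style
            normalised.append(item.dict())
        else:
            raise TypeError(f"cannot normalise {type(item).__name__}")
    return normalised


class FakeResult:
    """Stand-in for a Pydantic model, used only for this demo."""

    def model_dump(self) -> Dict[str, Any]:
        return {"conflict": True, "data": []}


normalised_result = to_dict_list(FakeResult())
print(normalised_result)
```

Duck typing on `model_dump` and `dict` keeps the helper independent of the installed Pydantic version, and a single helper could be shared by both the conflict and the reflexion modules.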


@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
"""Reflexion execution module.
Provides reflexion execution, using an LLM to reflect on conflicting memories and resolve them.
Migrated from app.core.memory.src.data_config_api.reflexion.
"""
import logging
from typing import List, Any
import time
from app.core.memory.utils.prompt.template_render import render_reflexion_prompt
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ReflexionResultSchema
from pydantic import BaseModel
async def reflexion(ref_data: List[Any]) -> List[Any]:
"""
Reflects on the given conflict data using the reflexion.jinja2 template.
Args:
ref_data: List of records to reflect on.
Returns:
A list of reflexion results (JSON array).
"""
from app.core.memory.utils.config import definitions as config_defs
client = get_llm_client(config_defs.SELECTED_LLM_ID)
rendered_prompt = await render_reflexion_prompt(ref_data, ReflexionResultSchema)
messages = [{"role": "user", "content": rendered_prompt}]
print(f"提示词长度: {len(rendered_prompt)}")
print(f"====== 反思开始 ======\n")
start_time = time.time()
response = await client.response_structured(messages, ReflexionResultSchema)
end_time = time.time()
print(f"反思耗时: {end_time - start_time}")
print(f"反思原始输出:(type={type(response)})\n{response}")
if not response:
logging.error("LLM 反思输出解析失败,返回空列表以继续流程。")
return []
# Always return a list of dicts so the main self-reflexion flow can update the database
try:
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
except Exception:
try:
return [response.dict()]
except Exception:
logging.warning("无法标准化反思返回类型,尝试直接封装为列表。")
return [response]
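Both `conflict` and `reflexion` time the LLM call with `time.time()` and report it through `print`. A possible alternative, sketched here with only the standard library, is a small context manager that uses the monotonic `time.perf_counter()` clock and the `logging` module. The `timed` helper and the `reflexion_demo` logger name are assumptions made for this example.

```python
import logging
import time
from contextlib import contextmanager
from typing import Dict, Iterator

logger = logging.getLogger("reflexion_demo")


@contextmanager
def timed(label: str) -> Iterator[Dict[str, float]]:
    """Measure the wall-clock time of the enclosed block and log it."""
    stats: Dict[str, float] = {}
    start = time.perf_counter()
    try:
        yield stats
    finally:
        # Recorded even if the block raises, so failures are timed too.
        stats["seconds"] = time.perf_counter() - start
        logger.info("%s took %.3f s", label, stats["seconds"])


with timed("reflexion") as stats:
    total = sum(range(1000))  # stand-in for the awaited LLM call
```

The same `with` block works inside an `async def`, because the context manager itself does not need to await anything.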


@@ -0,0 +1,250 @@
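# -*- coding: utf-8 -*-
"""Self-reflexion main execution module.
Provides the main flow of the self-reflexion engine:
- fetching reflexion data
- conflict detection
- reflexion execution
- memory update
Migrated from app.core.memory.src.data_config_api.self_reflexion.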
"""
import os
import json
import logging
import asyncio
from typing import List, Dict, Any
import uuid
from app.core.memory.utils.config.definitions import (
REFLEXION_ENABLED,
REFLEXION_ITERATION_PERIOD,
REFLEXION_RANGE,
REFLEXION_BASELINE,
)
from app.db import get_db
from sqlalchemy.orm import Session
from app.models.retrieval_info import RetrievalInfo
from app.core.memory.utils.config.get_data import get_data
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
from app.repositories.neo4j.cypher_queries import UPDATE_STATEMENT_INVALID_AT
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# Concurrency limit (can be overridden through an environment variable)
CONCURRENCY = int(os.getenv("REFLEXION_CONCURRENCY", "5"))
# Make sure INFO-level logs are written to the terminal
_root_logger = logging.getLogger()
if not _root_logger.handlers:
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
else:
_root_logger.setLevel(logging.INFO)
async def get_reflexion_data(host_id: uuid.UUID) -> List[Any]:
"""
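Fetch the memory data to examine, according to the configured reflexion range.
Args:
host_id: Host ID.
Returns:
A list of memory records that fall within the reflexion range.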
"""
if REFLEXION_RANGE == "retrieval":
return await get_data(host_id)
elif REFLEXION_RANGE == "database":
return []
else:
raise ValueError(f"未知的反思范围: {REFLEXION_RANGE}")
async def run_conflict(conflict_data: List[Any]) -> List[Any]:
"""
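Check whether the reflexion data contains conflicts.
Args:
conflict_data: List of memory records to check.
Returns:
The list of conflicting memories if any conflict exists, otherwise an empty list.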
"""
if not conflict_data:
return []
conflict_data = await conflict(conflict_data)
# Keep only the entries that actually conflict (conflict == True)
try:
return [c for c in conflict_data if isinstance(c, dict) and c.get("conflict") is True]
except Exception:
return []
async def run_reflexion(reflexion_data: List[Any]) -> Any:
"""
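Run reflexion to resolve the conflicts.
Args:
reflexion_data: List of conflicts to reflect on.
Returns:
The reflexion results with conflicts resolved (as returned by the LLM).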
"""
if not reflexion_data:
return []
# Reflect on each conflict concurrently to shorten the overall wait
sem = asyncio.Semaphore(CONCURRENCY)
async def _reflex_one(item: Any) -> Dict[str, Any] | None:
async with sem:
try:
result_list = await reflexion([item])
if not result_list:
return None
obj = result_list[0]
if hasattr(obj, "model_dump"):
return obj.model_dump()
elif hasattr(obj, "dict"):
return obj.dict()
elif isinstance(obj, dict):
return obj
except Exception as e:
logging.warning(f"反思失败,跳过一项: {e}")
return None
tasks = [_reflex_one(item) for item in reflexion_data]
results = await asyncio.gather(*tasks, return_exceptions=False)
return [r for r in results if r]
async def update_memory(solved_data: List[Any], host_id: uuid.UUID) -> str:
"""
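Update the memory store with the memories produced by conflict resolution.
Args:
solved_data: Memories after conflict resolution (as returned by the LLM).
host_id: Host ID.
Returns:
The update result (success or failure).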
"""
flag = False
if not solved_data:
return "数据缺失,更新失败"
if not isinstance(solved_data, list):
return "数据格式错误,更新失败"
neo4j_connector = Neo4jConnector()
try:
print(f"====== 更新记忆开始 ======\n")
sem = asyncio.Semaphore(CONCURRENCY)
success_count = 0
async def _update_one(item: Dict[str, Any]) -> bool:
async with sem:
try:
if not isinstance(item, dict):
return False
if not item:
return False
resolved = item.get("resolved")
if not isinstance(resolved, dict) or not resolved:
logging.warning(f"反思结果无可更新内容,跳过此项: {item}")
return False
resolved_mem = resolved.get("resolved_memory")
if not isinstance(resolved_mem, dict) or not resolved_mem:
logging.warning(f"反思结果缺少 resolved_memory跳过此项: {item}")
return False
group_id = resolved_mem.get("group_id")
memory_id = resolved_mem.get("id")  # avoid shadowing the built-in id()
# Use the invalid_at field as the new invalidation time
new_invalid_at = resolved_mem.get("invalid_at")
if not all([group_id, memory_id, new_invalid_at]):
logging.warning(f"记忆更新参数缺失,跳过此项: {item}")
return False
await neo4j_connector.execute_query(
UPDATE_STATEMENT_INVALID_AT,
group_id=group_id,
id=memory_id,
new_invalid_at=new_invalid_at,
)
return True
except Exception as e:
logging.error(f"更新单条记忆失败: {e}")
return False
tasks = [_update_one(item) for item in solved_data if isinstance(item, dict)]
results = await asyncio.gather(*tasks, return_exceptions=False)
success_count = sum(1 for r in results if r)
logging.info(f"成功更新 {success_count} 条记忆")
flag = success_count > 0
return "更新成功" if flag else "更新失败"
except Exception as e:
logging.error(f"更新记忆库失败: {e}")
return "更新失败"
finally:
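if flag:  # Delete the retrieval records from the database
db: Session = next(get_db())
try:
# Query.delete() returns the number of rows matched, which is what the log should report
deleted = db.query(RetrievalInfo).filter(RetrievalInfo.host_id == host_id).delete()
db.commit()
logging.info(f"成功删除 {deleted} 条检索数据")
except Exception as e:
db.rollback()
logging.error(f"删除数据库中的检索数据失败: {e}")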
async def _append_json(label: str, data: Any) -> None:
"""Record conflict memories (written in a background thread to avoid blocking the event loop)."""
def _write():
with open("reflexion_data.json", "a", encoding="utf-8") as f:
f.write(f"### {label} ###\n")
json.dump(data, f, ensure_ascii=False, indent=4)
f.write("\n\n")
# Await the background thread from inside the coroutine to avoid never-awaited coroutine warnings
await asyncio.to_thread(_write)
async def self_reflexion(host_id: uuid.UUID) -> str:
"""
Self-reflexion engine: runs the reflexion flow.
Args:
host_id: Host ID.
Returns:
A string describing the reflexion result.
"""
if not REFLEXION_ENABLED:
return "未开启反思..."
print(f"====== 自我反思流程开始 ======\n")
reflexion_data = await get_reflexion_data(host_id)
if not reflexion_data:
print(f"====== 自我反思流程结束 ======\n")
return "无反思数据,结束反思"
print(f"反思数据获取成功,共 {len(reflexion_data)}")
conflict_data = await run_conflict(reflexion_data)
if not conflict_data:
print(f"====== 自我反思流程结束 ======\n")
return "无冲突,无需反思"
print(f"冲突记忆类型: {type(conflict_data)}")
await _append_json("conflict", conflict_data)
solved_data = await run_reflexion(conflict_data)
if not solved_data:
print(f"====== 自我反思流程结束 ======\n")
return "反思失败,未解决冲突"
print(f"解决冲突后的记忆类型: {type(solved_data)}")
await _append_json("solved_data", solved_data)
result = await update_memory(solved_data, host_id)
print(f"更新记忆库结果: {result}")
print(f"====== 自我反思流程结束 ======\n")
return result
if __name__ == "__main__":
import asyncio
# host_id = uuid.UUID("3f6ff1eb-50c7-4765-8e89-e4566be33333")
host_id = uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122")
asyncio.run(self_reflexion(host_id))
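`run_reflexion` and `update_memory` both cap concurrency with an `asyncio.Semaphore` and collect results with `asyncio.gather`. The pattern can be tried in isolation with the sketch below, where `asyncio.sleep` stands in for the LLM or Neo4j call. The limit of 2, the `process` coroutine and the counters are assumptions made for the demo.

```python
import asyncio
from typing import List, Optional

CONCURRENCY = 2  # hypothetical limit for the demo
active = 0
peak = 0


async def process(item: int, sem: asyncio.Semaphore) -> Optional[int]:
    """Square even numbers; fail on odd ones to imitate a failed LLM call."""
    global active, peak
    async with sem:
        active += 1
        peak = max(peak, active)
        try:
            await asyncio.sleep(0.01)  # stand-in for awaiting the LLM
            if item % 2:
                raise ValueError(f"cannot process {item}")
            return item * item
        except ValueError:
            return None  # skip the failed item, as _reflex_one does
        finally:
            active -= 1


async def run_all(items: List[int]) -> List[int]:
    sem = asyncio.Semaphore(CONCURRENCY)
    results = await asyncio.gather(*(process(i, sem) for i in items))
    # Filter on "is not None" so that legitimate falsy results such as 0 survive.
    return [r for r in results if r is not None]


squares = asyncio.run(run_all([0, 1, 2, 3, 4]))
print(squares, peak)  # [0, 4, 16] 2
```

Because each worker catches its own exception and returns `None`, `gather` never sees a failure and one bad item cannot cancel the rest of the batch. Results come back in input order.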


@@ -0,0 +1,26 @@
"""
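Visualization module.
Contains all visualization utilities, mainly for plotting the forgetting curve.
"""
# Re-export commonly used functions from the submodule for backward compatibility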
from .forgetting_visualizer import (
export_memory_curve_numpy,
export_memory_curves_multiple_strengths,
export_parameter_sweep_numpy,
visualize_forgetting_curve,
plot_3d_forgetting_surface,
create_comparison_visualization,
save_memory_curves_to_file,
)
__all__ = [
"export_memory_curve_numpy",
"export_memory_curves_multiple_strengths",
"export_parameter_sweep_numpy",
"visualize_forgetting_curve",
"plot_3d_forgetting_surface",
"create_comparison_visualization",
"save_memory_curves_to_file",
]


@@ -0,0 +1,386 @@
"""
Memory Visualization Utilities
This module provides visualization functions for the modified Ebbinghaus forgetting curve
and utilities to export memory curves as numpy arrays.
"""
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Tuple, List, Dict, Any
import math
def export_memory_curve_numpy(forgetting_engine,
time_range: Tuple[float, float] = (0, 10),
memory_strength: float = 1.0,
num_points: int = 1000) -> Tuple[np.ndarray, np.ndarray]:
"""
Export memory curve as numpy arrays for time and retention values.
Args:
forgetting_engine: Instance of ForgettingEngine
time_range: Tuple of (start_time, end_time)
memory_strength: Memory strength value to use
num_points: Number of points to generate
Returns:
Tuple of (time_array, retention_array)
"""
start_time, end_time = time_range
time_array = np.linspace(start_time, end_time, num_points)
retention_array = np.array([
forgetting_engine.forgetting_curve(t, memory_strength)
for t in time_array
])
return time_array, retention_array
def export_memory_curves_multiple_strengths(forgetting_engine,
time_range: Tuple[float, float] = (0, 10),
memory_strengths: Optional[List[float]] = None,
num_points: int = 1000) -> Dict[str, np.ndarray]:
"""
Export memory curves for multiple memory strengths as numpy arrays.
Args:
forgetting_engine: Instance of ForgettingEngine
time_range: Tuple of (start_time, end_time)
memory_strengths: List of memory strength values
num_points: Number of points to generate
Returns:
Dictionary with 'time' and retention arrays for each strength
"""
if memory_strengths is None:
memory_strengths = [0.5, 1.0, 2.0, 5.0]
start_time, end_time = time_range
time_array = np.linspace(start_time, end_time, num_points)
result = {'time': time_array}
for strength in memory_strengths:
retention_array = np.array([
forgetting_engine.forgetting_curve(t, strength)
for t in time_array
])
result[f'strength_{strength}'] = retention_array
return result
def export_parameter_sweep_numpy(base_engine,
parameter_name: str,
parameter_values: List[float],
time_range: Tuple[float, float] = (0, 10),
memory_strength: float = 1.0,
num_points: int = 1000) -> Dict[str, np.ndarray]:
"""
Export memory curves for parameter sweep as numpy arrays.
Args:
base_engine: Base ForgettingEngine instance
parameter_name: Name of parameter to sweep ('offset', 'lambda_time', 'lambda_mem')
parameter_values: List of parameter values to test
time_range: Tuple of (start_time, end_time)
memory_strength: Memory strength value to use
num_points: Number of points to generate
Returns:
Dictionary with 'time' and retention arrays for each parameter value
"""
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
from app.core.memory.models.variate_config import ForgettingEngineConfig
start_time, end_time = time_range
time_array = np.linspace(start_time, end_time, num_points)
result = {'time': time_array}
for param_value in parameter_values:
# Create new engine with modified parameter
if parameter_name == 'offset':
config = ForgettingEngineConfig(offset=param_value, lambda_time=base_engine.lambda_time, lambda_mem=base_engine.lambda_mem)
elif parameter_name == 'lambda_time':
config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=param_value, lambda_mem=base_engine.lambda_mem)
elif parameter_name == 'lambda_mem':
config = ForgettingEngineConfig(offset=base_engine.offset, lambda_time=base_engine.lambda_time, lambda_mem=param_value)
else:
raise ValueError(f"Unknown parameter: {parameter_name}")
engine = ForgettingEngine(config)
retention_array = np.array([
engine.forgetting_curve(t, memory_strength)
for t in time_array
])
result[f'{parameter_name}_{param_value}'] = retention_array
return result
def visualize_forgetting_curve(forgetting_engine,
max_time: float = 10.0,
memory_strengths: Optional[List[float]] = None,
figsize: Tuple[int, int] = (12, 8)) -> None:
"""
Visualize the modified Ebbinghaus forgetting curve.
Args:
forgetting_engine: Instance of ForgettingEngine
max_time: Maximum time to plot
memory_strengths: List of memory strength values to plot
figsize: Figure size for the plot
"""
if memory_strengths is None:
memory_strengths = [0.5, 1.0, 2.0, 5.0]
# Create time array
t = np.linspace(0, max_time, 1000)
# Create subplots
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=figsize)
fig.suptitle('Modified Ebbinghaus Forgetting Curve Analysis', fontsize=16, fontweight='bold')
# Plot 1: Different memory strengths
ax1.set_title('Effect of Memory Strength (S)')
for S in memory_strengths:
retention = [forgetting_engine.forgetting_curve(time, S) for time in t]
ax1.plot(t, retention, label=f'S = {S}', linewidth=2)
ax1.set_xlabel('Time')
ax1.set_ylabel('Memory Retention')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 1)
# Plot 2: Different lambda_time values
ax2.set_title('Effect of λ_time')
lambda_times = [0.5, 1.0, 0.3]
lambda_mem = [0.5,0.3,1.0]
offset_mem = [0.1,0.05,0.2]
for i in range(len(lambda_times)):
lt = lambda_times[i]
lm = lambda_mem[i]
off = offset_mem[i]
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
from app.core.memory.models.variate_config import ForgettingEngineConfig
config = ForgettingEngineConfig(offset=off, lambda_time=lt, lambda_mem=lm)
temp_engine = ForgettingEngine(config)
retention = [temp_engine.forgetting_curve(time, 1.0) for time in t]
ax2.plot(t, retention, label=f'λ_time = {lt}, λ_mem = {lm}, offset = {off}', linewidth=2)
ax2.set_xlabel('Time')
ax2.set_ylabel('Memory Retention')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 1)
plt.tight_layout()
plt.show()
def plot_3d_forgetting_surface(forgetting_engine,
max_time: float = 10.0,
max_strength: float = 5.0,
figsize: Tuple[int, int] = (12, 9)) -> None:
"""
Create a 3D surface plot of the forgetting curve.
Args:
forgetting_engine: Instance of ForgettingEngine
max_time: Maximum time to plot
max_strength: Maximum memory strength to plot
figsize: Figure size for the plot
"""
# Create meshgrid
t = np.linspace(0.1, max_time, 50)
S = np.linspace(0.1, max_strength, 50)
T, S_mesh = np.meshgrid(t, S)
# Calculate retention for each point
R = np.zeros_like(T)
for i in range(T.shape[0]):
for j in range(T.shape[1]):
R[i, j] = forgetting_engine.forgetting_curve(T[i, j], S_mesh[i, j])
# Create 3D plot
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot(111, projection='3d')
surface = ax.plot_surface(T, S_mesh, R, cmap='viridis', alpha=0.8)
ax.set_xlabel('Time (t)')
ax.set_ylabel('Memory Strength (S)')
ax.set_zlabel('Memory Retention (R)')
ax.set_title(f'3D Forgetting Curve Surface\n(offset={forgetting_engine.offset}, λ_time={forgetting_engine.lambda_time}, λ_mem={forgetting_engine.lambda_mem})')
# Add colorbar
fig.colorbar(surface, shrink=0.5, aspect=5)
plt.show()
def create_comparison_visualization(forgetting_engine, figsize: Tuple[int, int] = (15, 10)) -> None:
"""
Create a comparison visualization of different curve configurations.
Args:
forgetting_engine: Instance of ForgettingEngine
figsize: Figure size for the plot
"""
# Create figure with multiple subplots
fig, axes = plt.subplots(2, 2, figsize=figsize)
fig.suptitle('Modified Ebbinghaus Forgetting Curve - Parameter Comparison', fontsize=16, fontweight='bold')
t = np.linspace(0, 10, 100)
# Plot 1: Original vs Modified curve
ax1 = axes[0, 0]
ax1.set_title('Original vs Modified Ebbinghaus Curve')
# Original Ebbinghaus: R = e^(-t/S)
S = 2.0
original = np.exp(-t / S)
ax1.plot(t, original, 'r--', label='Original: R = e^(-t/S)', linewidth=2)
# Modified with offset
modified = [forgetting_engine.forgetting_curve(time, S) for time in t]
ax1.plot(t, modified, 'b-', label='Modified: offset + (1-offset)*e^(-λ_time*t/λ_mem*S)', linewidth=2)
ax1.set_xlabel('Time')
ax1.set_ylabel('Memory Retention')
ax1.legend()
ax1.grid(True, alpha=0.3)
ax1.set_ylim(0, 1)
# Plot 2: Different offset values
ax2 = axes[0, 1]
ax2.set_title('Effect of Offset Parameter')
for offset in [0.0, 0.1, 0.2, 0.3]:
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
from app.core.memory.models.variate_config import ForgettingEngineConfig
config = ForgettingEngineConfig(offset=offset, lambda_time=1.0, lambda_mem=1.0)
engine = ForgettingEngine(config)
retention = [engine.forgetting_curve(time, 1.0) for time in t]
ax2.plot(t, retention, label=f'offset = {offset}', linewidth=2)
ax2.set_xlabel('Time')
ax2.set_ylabel('Memory Retention')
ax2.legend()
ax2.grid(True, alpha=0.3)
ax2.set_ylim(0, 1)
# Plot 3: Lambda time effect
ax3 = axes[1, 0]
ax3.set_title('Effect of λ_time (Time Sensitivity)')
for lambda_time in [0.5, 1.0, 2.0, 3.0]:
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
from app.core.memory.models.variate_config import ForgettingEngineConfig
config = ForgettingEngineConfig(offset=0.1, lambda_time=lambda_time, lambda_mem=1.0)
engine = ForgettingEngine(config)
retention = [engine.forgetting_curve(time, 1.0) for time in t]
ax3.plot(t, retention, label=f'λ_time = {lambda_time}', linewidth=2)
ax3.set_xlabel('Time')
ax3.set_ylabel('Memory Retention')
ax3.legend()
ax3.grid(True, alpha=0.3)
ax3.set_ylim(0, 1)
# Plot 4: Memory strength effect
ax4 = axes[1, 1]
ax4.set_title('Effect of Memory Strength (S)')
for strength in [0.5, 1.0, 2.0, 4.0]:
retention = [forgetting_engine.forgetting_curve(time, strength) for time in t]
ax4.plot(t, retention, label=f'S = {strength}', linewidth=2)
ax4.set_xlabel('Time')
ax4.set_ylabel('Memory Retention')
ax4.legend()
ax4.grid(True, alpha=0.3)
ax4.set_ylim(0, 1)
plt.tight_layout()
plt.show()
def save_memory_curves_to_file(forgetting_engine,
filename: str,
time_range: Tuple[float, float] = (0, 10),
memory_strengths: Optional[List[float]] = None,
num_points: int = 1000,
format: str = 'npz') -> None:
"""
Save memory curves to file in various formats.
Args:
forgetting_engine: Instance of ForgettingEngine
filename: Output filename (without extension)
time_range: Tuple of (start_time, end_time)
memory_strengths: List of memory strength values
num_points: Number of points to generate
format: Output format ('npz', 'csv', 'json')
"""
if memory_strengths is None:
memory_strengths = [0.5, 1.0, 2.0, 5.0]
curves_data = export_memory_curves_multiple_strengths(
forgetting_engine, time_range, memory_strengths, num_points
)
if format == 'npz':
np.savez(f"{filename}.npz", **curves_data)
elif format == 'csv':
import pandas as pd
df = pd.DataFrame(curves_data)
df.to_csv(f"{filename}.csv", index=False)
elif format == 'json':
import json
# Convert numpy arrays to lists for JSON serialization
json_data = {k: v.tolist() if isinstance(v, np.ndarray) else v
for k, v in curves_data.items()}
with open(f"{filename}.json", 'w') as f:
json.dump(json_data, f, indent=2)
else:
raise ValueError(f"Unsupported format: {format}")
if __name__ == "__main__":
# Example usage
from app.core.memory.storage_services.forgetting_engine import ForgettingEngine
print("Memory Visualization Utilities Demo")
print("=" * 40)
# Create engine
from app.core.memory.models.variate_config import ForgettingEngineConfig
config = ForgettingEngineConfig(offset=0.1, lambda_time=0.5, lambda_mem=0.5)
engine = ForgettingEngine(config)
# # Export single curve as numpy
# time_arr, retention_arr = export_memory_curve_numpy(engine, (0, 10), 1.0, 100)
# print(f"Exported single curve: {len(time_arr)} points")
# print(f"Time range: {time_arr[0]:.2f} to {time_arr[-1]:.2f}")
# print(f"Retention range: {retention_arr.min():.4f} to {retention_arr.max():.4f}")
# # Export multiple curves
# curves = export_memory_curves_multiple_strengths(engine, (0, 10), [0.5, 1.0, 2.0])
# print(f"\nExported multiple curves: {list(curves.keys())}")
# # Parameter sweep
# param_sweep = export_parameter_sweep_numpy(engine, 'offset', [0.0, 0.1, 0.2, 0.3])
# print(f"Parameter sweep results: {list(param_sweep.keys())}")
# print("\nVisualization functions are ready to use!")
visualize_forgetting_curve(engine)
create_comparison_visualization(engine)
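The real `ForgettingEngine` is not part of this file, so the exact curve is not visible here. Going by the plot label used above, a plausible reading is R(t) = offset + (1 - offset) * exp(-λ_time * t / (λ_mem * S)). The sketch below implements that assumed form with the standard library only, so the curve's behaviour can be checked without the engine or NumPy. The default parameters are the ones from the demo block above, and `linspace` is a small stand-in for `np.linspace`.

```python
import math
from typing import List


def forgetting_curve(t: float, strength: float, offset: float = 0.1,
                     lambda_time: float = 0.5, lambda_mem: float = 0.5) -> float:
    """Assumed form: offset + (1 - offset) * exp(-lambda_time * t / (lambda_mem * S))."""
    return offset + (1.0 - offset) * math.exp(-lambda_time * t / (lambda_mem * strength))


def linspace(start: float, stop: float, num: int) -> List[float]:
    """Pure-Python stand-in for np.linspace (endpoints included, num >= 2)."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


times = linspace(0.0, 10.0, 11)
weak = [forgetting_curve(t, strength=0.5) for t in times]
strong = [forgetting_curve(t, strength=5.0) for t in times]

for t, w, s in zip(times, weak, strong):
    print(f"t={t:4.1f}  S=0.5 -> {w:.4f}  S=5.0 -> {s:.4f}")
```

Under this form retention starts at 1 for t = 0, decays monotonically, and levels off at `offset` rather than at 0. A larger memory strength S slows the decay, which is the behaviour the "Effect of Memory Strength" panels are meant to show.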