Merge #85 into develop from feature/actr-forget
[feature]actr-记忆遗忘需求开发
* feature/actr-forget: (12 commits squashed)
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
- [fix]Eliminate the interference caused by redundant code
- [feature]
1.Extended fields of the date_config table;
2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler
- [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process
- Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget
Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85
This commit is contained in:
@@ -85,6 +85,8 @@ health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
|
||||
@@ -103,6 +105,13 @@ beat_schedule_config = {
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"run-forgetting-cycle": {
|
||||
"task": "app.tasks.run_forgetting_cycle_task",
|
||||
"schedule": forgetting_cycle_schedule,
|
||||
"kwargs": {
|
||||
"config_id": None, # 使用默认配置,可以通过环境变量配置
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
|
||||
@@ -33,7 +33,7 @@ from . import (
|
||||
emotion_config_controller,
|
||||
prompt_optimizer_controller,
|
||||
tool_controller,
|
||||
home_page_controller,
|
||||
memory_forget_controller,
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
@@ -71,6 +71,6 @@ manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
manager_router.include_router(memory_reflection_controller.router)
|
||||
manager_router.include_router(tool_controller.router)
|
||||
manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(memory_forget_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
324
api/app/controllers/memory_forget_controller.py
Normal file
324
api/app/controllers/memory_forget_controller.py
Normal file
@@ -0,0 +1,324 @@
|
||||
"""
|
||||
遗忘引擎控制器模块
|
||||
|
||||
本模块提供遗忘引擎的 REST API 接口,包括:
|
||||
1. 手动触发遗忘周期
|
||||
2. 获取和更新配置
|
||||
3. 获取统计信息
|
||||
4. 获取遗忘曲线数据
|
||||
|
||||
所有接口都需要用户认证,并自动关联到当前工作空间。
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ForgettingTriggerRequest,
|
||||
ForgettingConfigResponse,
|
||||
ForgettingConfigUpdateRequest,
|
||||
ForgettingStatsResponse,
|
||||
ForgettingReportResponse,
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory/forget",
|
||||
tags=["Memory Forgetting Engine"],
|
||||
dependencies=[Depends(get_current_user)] # 所有路由都需要认证
|
||||
)
|
||||
|
||||
# 初始化服务
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
@router.post("/trigger", response_model=ApiResponse)
|
||||
async def trigger_forgetting_cycle(
|
||||
payload: ForgettingTriggerRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
|
||||
执行一次完整的遗忘周期,识别并融合低激活值节点。
|
||||
|
||||
Args:
|
||||
payload: 触发请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘报告的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
|
||||
f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, "
|
||||
f"min_days={payload.min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层执行遗忘周期
|
||||
report = await forget_service.trigger_forgetting_cycle(
|
||||
db=db,
|
||||
group_id=payload.group_id,
|
||||
max_merge_batch_size=payload.max_merge_batch_size,
|
||||
min_days_since_access=payload.min_days_since_access,
|
||||
config_id=payload.config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingReportResponse(**report)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="遗忘周期执行成功")
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"触发遗忘周期失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_config", response_model=ApiResponse)
|
||||
async def read_forgetting_config(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
|
||||
读取指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含配置信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层读取配置
|
||||
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"读取遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse)
|
||||
async def update_forgetting_config(
|
||||
payload: ForgettingConfigUpdateRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
更新遗忘引擎配置
|
||||
|
||||
更新指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
payload: 配置更新请求
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含更新结果的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 构建更新字段字典(排除 None 值和 config_id)
|
||||
update_data = {
|
||||
key: value
|
||||
for key, value in payload.model_dump(exclude_none=True).items()
|
||||
if key != 'config_id'
|
||||
}
|
||||
|
||||
# 调用服务层更新配置
|
||||
config = forget_service.update_forgetting_config(
|
||||
db=db,
|
||||
config_id=payload.config_id,
|
||||
update_fields=update_data
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingConfigResponse(**config)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"更新遗忘引擎配置失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
|
||||
@router.get("/stats", response_model=ApiResponse)
|
||||
async def get_forgetting_stats(
|
||||
group_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
group_id: 组ID(可选)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含统计信息的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
|
||||
f"group_id={group_id}, config_id={config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取统计信息
|
||||
stats = await forget_service.get_forgetting_stats(
|
||||
db=db,
|
||||
group_id=group_id,
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingStatsResponse(**stats)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘引擎统计失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
生成遗忘曲线数据用于可视化,模拟记忆激活值随时间的衰减。
|
||||
|
||||
Args:
|
||||
request: 遗忘曲线请求参数
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含遗忘曲线数据的响应
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: "
|
||||
f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层生成遗忘曲线
|
||||
result = await forget_service.get_forgetting_curve(
|
||||
db=db,
|
||||
importance_score=request.importance_score,
|
||||
days=request.days,
|
||||
config_id=request.config_id
|
||||
)
|
||||
|
||||
# 转换为响应格式
|
||||
curve_points = [
|
||||
ForgettingCurvePoint(**point)
|
||||
for point in result['curve_data']
|
||||
]
|
||||
|
||||
# 构建响应
|
||||
response_data = ForgettingCurveResponse(
|
||||
curve_data=curve_points,
|
||||
config=result['config']
|
||||
)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取遗忘曲线失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e))
|
||||
@@ -1,4 +1,3 @@
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
@@ -9,12 +8,7 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
@@ -22,8 +16,6 @@ from app.schemas.memory_storage_schema import (
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
GenerateCacheRequest,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
@@ -238,28 +230,8 @@ def update_config_extracted(
|
||||
|
||||
|
||||
# --- Forget config params ---
|
||||
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径)
|
||||
def update_config_forget(
|
||||
payload: ConfigUpdateForget,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.update_forget(payload)
|
||||
return success(data=result, msg="更新成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Update config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
|
||||
|
||||
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py
|
||||
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
|
||||
|
||||
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_extracted(
|
||||
@@ -283,28 +255,6 @@ def read_config_extracted(
|
||||
api_logger.error(f"Read config extracted failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
|
||||
|
||||
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
|
||||
def read_config_forget(
|
||||
config_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.get_forget(ConfigKey(config_id=config_id))
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Read config forget failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
|
||||
|
||||
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
|
||||
@@ -106,28 +106,32 @@ class SearchService:
|
||||
limit: int = 15,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.4,
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
Execute hybrid search with two-stage ranking.
|
||||
|
||||
Stage 1: Filter by content relevance (BM25 + Embedding)
|
||||
Stage 2: Rerank by activation values (ACTR)
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering results
|
||||
group_id: Group identifier for filtering
|
||||
question: Search query text
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
|
||||
limit: Max results per category (default: 15)
|
||||
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: BM25 weight (default: 0.6)
|
||||
activation_boost_factor: Activation impact on memory strength (default: 0.8)
|
||||
output_path: JSON output path (default: "search_results.json")
|
||||
return_raw_results: Return full metadata (default: False)
|
||||
memory_config: MemoryConfig for embedding model
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
@@ -151,6 +155,7 @@ class SearchService:
|
||||
output_path=output_path,
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
|
||||
@@ -228,6 +228,13 @@ class StatementNode(Node):
|
||||
chunk_embedding: Optional embedding vector for the parent chunk
|
||||
connect_strength: Classification of connection strength ('Strong' or 'Weak')
|
||||
config_id: Configuration ID used to process this statement
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
|
||||
access_history: List of ISO timestamp strings recording each access
|
||||
last_access_time: ISO timestamp of the most recent access
|
||||
access_count: Total number of times this node has been accessed
|
||||
"""
|
||||
# Core fields (ordered as requested)
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk")
|
||||
@@ -269,6 +276,33 @@ class StatementNode(Node):
|
||||
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), default 0.5"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
@field_validator('valid_at', 'invalid_at', mode='before')
|
||||
@classmethod
|
||||
def validate_datetime(cls, v):
|
||||
@@ -351,6 +385,13 @@ class ExtractedEntityNode(Node):
|
||||
fact_summary: Summary of facts about this entity
|
||||
connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
|
||||
config_id: Configuration ID used to process this entity (integer or string)
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
|
||||
access_history: List of ISO timestamp strings recording each access
|
||||
last_access_time: ISO timestamp of the most recent access
|
||||
access_count: Total number of times this node has been accessed
|
||||
"""
|
||||
entity_idx: int = Field(..., description="Unique identifier for the entity")
|
||||
statement_id: str = Field(..., description="Statement this entity was extracted from")
|
||||
@@ -365,6 +406,33 @@ class ExtractedEntityNode(Node):
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), default 0.5"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=0,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed"
|
||||
)
|
||||
|
||||
@field_validator('aliases', mode='before')
|
||||
@classmethod
|
||||
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
|
||||
@@ -401,6 +469,16 @@ class MemorySummaryNode(Node):
|
||||
summary_embedding: Optional embedding vector for the summary
|
||||
metadata: Additional metadata for the summary
|
||||
config_id: Configuration ID used to process this summary
|
||||
original_statement_id: ID of the original statement that was merged (for ACT-R forgetting)
|
||||
original_entity_id: ID of the original entity that was merged (for ACT-R forgetting)
|
||||
merged_at: Timestamp when the nodes were merged
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes
|
||||
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes
|
||||
access_history: List of ISO timestamp strings recording each access (reset on creation)
|
||||
last_access_time: ISO timestamp of the most recent access (set to creation time)
|
||||
access_count: Total number of times this node has been accessed (reset to 1 on creation)
|
||||
"""
|
||||
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
|
||||
dialog_id: str = Field(..., description="ID of the parent dialog")
|
||||
@@ -409,3 +487,44 @@ class MemorySummaryNode(Node):
|
||||
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
|
||||
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
|
||||
|
||||
# ACT-R Forgetting Engine Properties
|
||||
original_statement_id: Optional[str] = Field(
|
||||
None,
|
||||
description="ID of the original statement that was merged (for traceability)"
|
||||
)
|
||||
original_entity_id: Optional[str] = Field(
|
||||
None,
|
||||
description="ID of the original entity that was merged (for traceability)"
|
||||
)
|
||||
merged_at: Optional[datetime] = Field(
|
||||
None,
|
||||
description="Timestamp when the nodes were merged"
|
||||
)
|
||||
|
||||
# ACT-R Memory Activation Properties
|
||||
importance_score: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Importance score for memory activation (0.0-1.0), inherited from merged nodes"
|
||||
)
|
||||
activation_value: Optional[float] = Field(
|
||||
None,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes"
|
||||
)
|
||||
access_history: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of ISO timestamp strings recording each access (reset on creation)"
|
||||
)
|
||||
last_access_time: Optional[str] = Field(
|
||||
None,
|
||||
description="ISO timestamp of the most recent access (set to creation time)"
|
||||
)
|
||||
access_count: int = Field(
|
||||
default=1,
|
||||
ge=0,
|
||||
description="Total number of times this node has been accessed (reset to 1 on creation)"
|
||||
)
|
||||
|
||||
@@ -69,6 +69,12 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item.get(score_field)
|
||||
# 对于 activation_value,None 值保持为 None,不使用回退值
|
||||
# 这样可以区分有激活值和无激活值的节点
|
||||
if score_field == "activation_value" and score is None:
|
||||
scores.append(None) # 保持 None,稍后特殊处理
|
||||
continue
|
||||
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
scores.append(float(score))
|
||||
else:
|
||||
@@ -76,205 +82,433 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
|
||||
if not scores:
|
||||
return results
|
||||
|
||||
if len(scores) == 1:
|
||||
# Single score, set to 1.0
|
||||
|
||||
# 过滤掉 None 值,只对有效分数进行归一化
|
||||
valid_scores = [s for s in scores if s is not None]
|
||||
|
||||
if not valid_scores:
|
||||
# 所有分数都是 None,不进行归一化
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
if score_field in item or score_field == "activation_value":
|
||||
item[f"normalized_{score_field}"] = None
|
||||
return results
|
||||
|
||||
# Calculate mean and standard deviation
|
||||
mean_score = sum(scores) / len(scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
|
||||
if len(valid_scores) == 1: # Single valid score, set to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
return results
|
||||
|
||||
# Calculate mean and standard deviation (only for valid scores)
|
||||
mean_score = sum(valid_scores) / len(valid_scores)
|
||||
variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores)
|
||||
std_dev = math.sqrt(variance)
|
||||
|
||||
if std_dev == 0:
|
||||
# All scores are the same, set them to 1.0
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
# All valid scores are the same, set them to 1.0
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
item[f"normalized_{score_field}"] = 1.0
|
||||
else:
|
||||
for item in results:
|
||||
if score_field in item:
|
||||
score = item[score_field]
|
||||
# Handle None or non-numeric scores
|
||||
if score is None or not isinstance(score, (int, float)):
|
||||
score = 0.0
|
||||
# Calculate z-score
|
||||
z_score = (score - mean_score) / std_dev
|
||||
# Transform to positive range using sigmoid function
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
for item, score in zip(results, scores):
|
||||
if score_field in item or score_field == "activation_value":
|
||||
if score is None:
|
||||
# 保持 None,不进行归一化
|
||||
item[f"normalized_{score_field}"] = None
|
||||
else:
|
||||
# Calculate z-score
|
||||
z_score = (score - mean_score) / std_dev
|
||||
# Transform to positive range using sigmoid function
|
||||
normalized = 1 / (1 + math.exp(-z_score))
|
||||
item[f"normalized_{score_field}"] = normalized
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rerank_hybrid_results(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
# ============================================================================
|
||||
# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
|
||||
# ============================================================================
|
||||
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
# def rerank_hybrid_results(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined scores
|
||||
# """
|
||||
# reranked = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# # Create a combined pool of unique items
|
||||
# combined_items = {}
|
||||
#
|
||||
# # Add keyword results with BM25 scores
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
#
|
||||
# # Add or update with embedding results
|
||||
# for item in embedding_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# # Update existing item with embedding score
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# # New item from embedding search only
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate combined scores and rank
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
#
|
||||
# # Combined score: weighted average of normalized scores
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
#
|
||||
# # Keep original score for reference
|
||||
# if "score" not in item and bm25_score > 0:
|
||||
# item["score"] = bm25_score
|
||||
# elif "score" not in item and embedding_score > 0:
|
||||
# item["score"] = embedding_score
|
||||
#
|
||||
# # Sort by combined score and limit results
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
|
||||
Returns:
|
||||
Reranked results with combined scores
|
||||
"""
|
||||
reranked = {}
|
||||
# def rerank_with_forgetting_curve(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10,
|
||||
# forgetting_config: ForgettingEngineConfig | None = None,
|
||||
# now: datetime | None = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
|
||||
#
|
||||
# The forgetting curve reduces scores for older memories or weaker connections.
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
# forgetting_config: Configuration for the forgetting engine
|
||||
# now: Optional current time override for testing
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined and final scores (after forgetting)
|
||||
# """
|
||||
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
# now_dt = now or datetime.now()
|
||||
#
|
||||
# reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
#
|
||||
# # Combine two result sets by ID
|
||||
# for src_items, is_embedding in (
|
||||
# (keyword_items, False), (embedding_items, True)
|
||||
# ):
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if not item_id:
|
||||
# continue
|
||||
# existing = combined_items.get(item_id)
|
||||
# if not existing:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
# # Update normalized score from the right source
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate scores and apply forgetting weights
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
#
|
||||
# # Estimate time elapsed in days
|
||||
# dt = _parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
#
|
||||
# # Memory strength (currently set to default value)
|
||||
# memory_strength = 1.0
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
# )
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
#
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
# Normalize scores within each search type
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# Create a combined pool of unique items
|
||||
combined_items = {}
|
||||
|
||||
# Add keyword results with BM25 scores
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
|
||||
# Add or update with embedding results
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id:
|
||||
if item_id in combined_items:
|
||||
# Update existing item with embedding score
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
# New item from embedding search only
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate combined scores and rank
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = item.get("bm25_score", 0)
|
||||
embedding_score = item.get("embedding_score", 0)
|
||||
|
||||
# Combined score: weighted average of normalized scores
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
item["combined_score"] = combined_score
|
||||
|
||||
# Keep original score for reference
|
||||
if "score" not in item and bm25_score > 0:
|
||||
item["score"] = bm25_score
|
||||
elif "score" not in item and embedding_score > 0:
|
||||
item["score"] = embedding_score
|
||||
|
||||
# Sort by combined score and limit results
|
||||
sorted_items = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: x.get("combined_score", 0),
|
||||
reverse=True
|
||||
)[:limit]
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
|
||||
def rerank_with_forgetting_curve(
|
||||
def rerank_with_activation(
|
||||
keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
alpha: float = 0.6,
|
||||
limit: int = 10,
|
||||
forgetting_config: ForgettingEngineConfig | None = None,
|
||||
activation_boost_factor: float = 0.8,
|
||||
now: datetime | None = None,
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
|
||||
The forgetting curve reduces scores for older memories or weaker connections.
|
||||
|
||||
Args:
|
||||
keyword_results: Results from keyword/BM25 search
|
||||
embedding_results: Results from embedding search
|
||||
alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
limit: Maximum number of results to return per category
|
||||
forgetting_config: Configuration for the forgetting engine
|
||||
now: Optional current time override for testing
|
||||
|
||||
Returns:
|
||||
Reranked results with combined and final scores (after forgetting)
|
||||
两阶段排序:先按内容相关性筛选,再按激活值排序。
|
||||
|
||||
阶段1: content_score = alpha*BM25 + (1-alpha)*Embedding,取 Top-(limit*3)
|
||||
阶段2: 在候选中按 activation_score 排序,取 Top-limit
|
||||
无激活值的节点用于补充不足
|
||||
|
||||
返回结果中的评分字段说明:
|
||||
- bm25_score: BM25 归一化分数
|
||||
- embedding_score: Embedding 归一化分数
|
||||
- content_score: 内容相关性 = alpha*bm25 + (1-alpha)*embedding
|
||||
- activation_score: ACTR 激活值归一化分数
|
||||
- base_score: 第一阶段基础分数(等于 content_score)
|
||||
- final_score: 最终排序依据
|
||||
* 有激活值的节点:final_score = activation_score
|
||||
* 无激活值的节点:final_score = base_score
|
||||
|
||||
参数:
|
||||
keyword_results: BM25 检索结果
|
||||
embedding_results: 向量嵌入检索结果
|
||||
alpha: BM25 权重 (默认: 0.6)
|
||||
limit: 每类最大结果数
|
||||
forgetting_config: 遗忘引擎配置(当前未使用)
|
||||
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
|
||||
now: 当前时间(用于遗忘计算)
|
||||
|
||||
返回:
|
||||
带评分元数据的重排序结果,按 final_score 排序
|
||||
"""
|
||||
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
# 验证权重范围
|
||||
if not (0 <= alpha <= 1):
|
||||
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
|
||||
|
||||
# 初始化遗忘引擎(如果需要)
|
||||
engine = None
|
||||
if forgetting_config:
|
||||
engine = ForgettingEngine(forgetting_config)
|
||||
now_dt = now or datetime.now()
|
||||
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities","summaries"]:
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
# Normalize scores within each search type
|
||||
|
||||
# 步骤 1: 归一化分数
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
|
||||
# 步骤 2: 按 ID 合并结果
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# Combine two result sets by ID
|
||||
for src_items, is_embedding in (
|
||||
(keyword_items, False), (embedding_items, True)
|
||||
):
|
||||
for item in src_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
existing = combined_items.get(item_id)
|
||||
if not existing:
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0
|
||||
combined_items[item_id]["embedding_score"] = 0
|
||||
# Update normalized score from the right source
|
||||
if is_embedding:
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# Calculate scores and apply forgetting weights
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
|
||||
# Estimate time elapsed in days
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
|
||||
# 添加关键词结果
|
||||
for item in keyword_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
combined_items[item_id]["embedding_score"] = 0 # 默认值
|
||||
|
||||
# 添加或更新向量嵌入结果
|
||||
for item in embedding_items:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if not item_id:
|
||||
continue
|
||||
if item_id in combined_items:
|
||||
# 更新现有项的嵌入分数
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# Memory strength (currently set to default value)
|
||||
memory_strength = 1.0
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
)
|
||||
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
|
||||
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
|
||||
final_score = combined_score * forgetting_weight
|
||||
item["combined_score"] = final_score
|
||||
|
||||
sorted_items = sorted(
|
||||
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
)[:limit]
|
||||
|
||||
# 仅来自嵌入搜索的新项
|
||||
combined_items[item_id] = item.copy()
|
||||
combined_items[item_id]["bm25_score"] = 0 # 默认值
|
||||
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
|
||||
# 步骤 3: 归一化激活度分数
|
||||
# 为所有项准备激活度值列表
|
||||
items_list = list(combined_items.values())
|
||||
items_list = normalize_scores(items_list, "activation_value")
|
||||
|
||||
# 更新 combined_items 中的归一化激活度分数
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
# 步骤 5: 应用遗忘曲线(可选)
|
||||
if engine:
|
||||
# 计算受激活度影响的记忆强度
|
||||
importance = float(item.get("importance_score", 0.5) or 0.5)
|
||||
|
||||
# 获取 activation_value
|
||||
activation_val = item.get("activation_value")
|
||||
|
||||
# 只对有激活值的节点应用遗忘曲线
|
||||
if activation_val is not None and isinstance(activation_val, (int, float)):
|
||||
activation_val = float(activation_val)
|
||||
|
||||
# 计算记忆强度:importance_score × (1 + activation_value × boost_factor)
|
||||
memory_strength = importance * (1 + activation_val * activation_boost_factor)
|
||||
|
||||
# 计算经过的时间(天数)
|
||||
dt = _parse_datetime(item.get("created_at"))
|
||||
if dt is None:
|
||||
time_elapsed_days = 0.0
|
||||
else:
|
||||
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
|
||||
# 获取遗忘权重
|
||||
forgetting_weight = engine.calculate_weight(
|
||||
time_elapsed=time_elapsed_days,
|
||||
memory_strength=memory_strength
|
||||
)
|
||||
|
||||
# 应用到基础分数
|
||||
item["forgetting_weight"] = forgetting_weight
|
||||
item["final_score"] = base_score * forgetting_weight
|
||||
else:
|
||||
# 无激活值的节点不应用遗忘曲线,保持原始分数
|
||||
item["final_score"] = base_score
|
||||
else:
|
||||
# 不使用遗忘曲线
|
||||
item["final_score"] = base_score
|
||||
|
||||
# 步骤 6: 两阶段排序和限制
|
||||
# 第一阶段:按内容相关性(base_score)排序,取 Top-K
|
||||
first_stage_limit = limit * 3 # 可配置,取3倍候选
|
||||
first_stage_sorted = sorted(
|
||||
combined_items.values(),
|
||||
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
|
||||
reverse=True
|
||||
)[:first_stage_limit]
|
||||
|
||||
# 第二阶段:分离有激活值和无激活值的节点
|
||||
items_with_activation = []
|
||||
items_without_activation = []
|
||||
|
||||
for item in first_stage_sorted:
|
||||
activation_score = item.get("activation_score")
|
||||
# 检查是否有有效的激活值(不是 None)
|
||||
if activation_score is not None and isinstance(activation_score, (int, float)):
|
||||
items_with_activation.append(item)
|
||||
else:
|
||||
items_without_activation.append(item)
|
||||
|
||||
# 优先按激活值排序有激活值的节点
|
||||
sorted_with_activation = sorted(
|
||||
items_with_activation,
|
||||
key=lambda x: float(x.get("activation_score", 0) or 0),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
# 如果有激活值的节点不足 limit,用无激活值的节点补充
|
||||
if len(sorted_with_activation) < limit:
|
||||
needed = limit - len(sorted_with_activation)
|
||||
# 无激活值的节点保持第一阶段的内容相关性排序
|
||||
sorted_items = sorted_with_activation + items_without_activation[:needed]
|
||||
else:
|
||||
sorted_items = sorted_with_activation[:limit]
|
||||
|
||||
# 两阶段排序完成,更新 final_score 以反映实际排序依据
|
||||
# Stage 1: 按 content_score 筛选候选(已完成)
|
||||
# Stage 2: 按 activation_score 排序(已完成)
|
||||
#
|
||||
# final_score 语义:反映节点在最终结果中的排序依据
|
||||
# - 有激活值的节点:final_score = activation_score(第二阶段排序依据)
|
||||
# - 无激活值的节点:final_score = base_score(保持内容相关性分数)
|
||||
for item in sorted_items:
|
||||
activation_score = item.get("activation_score")
|
||||
if activation_score is not None and isinstance(activation_score, (int, float)):
|
||||
# 有激活值:使用激活度作为最终分数
|
||||
item["final_score"] = activation_score
|
||||
else:
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
|
||||
return reranked
|
||||
|
||||
|
||||
@@ -560,6 +794,7 @@ async def run_hybrid_search(
|
||||
output_path: str | None,
|
||||
memory_config: "MemoryConfig",
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
):
|
||||
@@ -685,30 +920,28 @@ async def run_hybrid_search(
|
||||
"search_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Apply reranking (optionally with forgetting curve)
|
||||
# Apply two-stage reranking with ACTR activation calculation
|
||||
rerank_start = time.time()
|
||||
if use_forgetting_rerank:
|
||||
# Load forgetting parameters from pipeline config
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
reranked_results = rerank_with_forgetting_curve(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha,
|
||||
limit=limit,
|
||||
forgetting_config=forgetting_cfg,
|
||||
)
|
||||
else:
|
||||
reranked_results = rerank_hybrid_results(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
|
||||
limit=limit
|
||||
)
|
||||
logger.info("Using two-stage reranking with ACTR activation")
|
||||
|
||||
# 加载遗忘引擎配置
|
||||
try:
|
||||
pc = get_pipeline_config(memory_config)
|
||||
forgetting_cfg = pc.forgetting_engine
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
|
||||
forgetting_cfg = ForgettingEngineConfig()
|
||||
|
||||
# 统一使用激活度重排序(两阶段:检索 + ACTR计算)
|
||||
reranked_results = rerank_with_activation(
|
||||
keyword_results=keyword_results,
|
||||
embedding_results=embedding_results,
|
||||
alpha=rerank_alpha,
|
||||
limit=limit,
|
||||
forgetting_config=forgetting_cfg,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
)
|
||||
|
||||
rerank_latency = time.time() - rerank_start
|
||||
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
|
||||
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
|
||||
@@ -737,6 +970,7 @@ async def run_hybrid_search(
|
||||
"search_query": query_text,
|
||||
"search_timestamp": datetime.now().isoformat(),
|
||||
"reranking_alpha": rerank_alpha,
|
||||
"activation_boost_factor": activation_boost_factor,
|
||||
"forgetting_rerank": use_forgetting_rerank,
|
||||
"llm_rerank": llm_rerank_applied,
|
||||
}
|
||||
|
||||
@@ -1,8 +1,40 @@
|
||||
"""遗忘引擎模块
|
||||
|
||||
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。
|
||||
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线和 ACT-R 认知架构理论。
|
||||
"""
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
|
||||
ACTRCalculator,
|
||||
calculate_activation,
|
||||
generate_forgetting_curve
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
|
||||
AccessHistoryManager,
|
||||
ConsistencyCheckResult
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import (
|
||||
ForgettingStrategy
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import (
|
||||
ForgettingScheduler
|
||||
)
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
calculate_forgetting_rate,
|
||||
load_actr_config_from_db,
|
||||
create_actr_calculator_from_config
|
||||
)
|
||||
|
||||
__all__ = ["ForgettingEngine"]
|
||||
__all__ = [
|
||||
"ForgettingEngine",
|
||||
"ACTRCalculator",
|
||||
"calculate_activation",
|
||||
"generate_forgetting_curve",
|
||||
"AccessHistoryManager",
|
||||
"ConsistencyCheckResult",
|
||||
"ForgettingStrategy",
|
||||
"ForgettingScheduler",
|
||||
"calculate_forgetting_rate",
|
||||
"load_actr_config_from_db",
|
||||
"create_actr_calculator_from_config"
|
||||
]
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
"""
|
||||
访问历史管理器模块
|
||||
|
||||
本模块实现访问历史的追踪、更新和一致性保证。
|
||||
负责在知识节点被访问时原子性地更新激活值相关的所有字段。
|
||||
|
||||
Classes:
|
||||
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConsistencyCheckResult(Enum):
|
||||
"""一致性检查结果枚举"""
|
||||
CONSISTENT = "consistent" # 数据一致
|
||||
INCONSISTENT_HISTORY_TIME = "inconsistent_history_time" # access_history[-1] != last_access_time
|
||||
INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count" # len(access_history) != access_count
|
||||
MISSING_ACTIVATION = "missing_activation" # 有访问历史但无激活值
|
||||
INVALID_ACTIVATION_RANGE = "invalid_activation_range" # 激活值超出有效范围
|
||||
|
||||
|
||||
class AccessHistoryManager:
|
||||
"""
|
||||
访问历史管理器
|
||||
|
||||
负责追踪知识节点的访问历史,并在访问时原子性地更新所有相关字段:
|
||||
- activation_value: 激活值
|
||||
- access_history: 访问历史时间戳数组
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
|
||||
特性:
|
||||
- 原子性更新:使用Neo4j事务确保所有字段同时更新或回滚
|
||||
- 并发安全:使用乐观锁机制防止并发冲突
|
||||
- 一致性保证:提供一致性检查和自动修复功能
|
||||
- 智能修剪:自动修剪过长的访问历史
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
max_retries: int = 3
|
||||
):
|
||||
"""
|
||||
初始化访问历史管理器
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
actr_calculator: ACT-R激活值计算器实例
|
||||
max_retries: 并发冲突时的最大重试次数(默认3次)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.max_retries = max_retries
|
||||
|
||||
async def record_access(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
记录节点访问并原子性更新所有相关字段
|
||||
|
||||
这是核心方法,实现了:
|
||||
1. 首次访问:初始化access_history,计算初始激活值
|
||||
2. 后续访问:追加访问历史,重新计算激活值
|
||||
3. 历史修剪:当历史过长时自动修剪
|
||||
4. 原子性:所有字段在单个事务中更新
|
||||
5. 并发安全:使用乐观锁重试机制
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选,用于过滤)
|
||||
current_time: 当前时间(可选,默认使用系统时间)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据,包含:
|
||||
- id: 节点ID
|
||||
- activation_value: 更新后的激活值
|
||||
- access_history: 更新后的访问历史
|
||||
- last_access_time: 最后访问时间
|
||||
- access_count: 访问次数
|
||||
- importance_score: 重要性分数
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点不存在或节点标签无效
|
||||
RuntimeError: 如果重试次数耗尽仍然失败
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
current_time_iso = current_time.isoformat()
|
||||
|
||||
# 验证节点标签
|
||||
valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
|
||||
if node_label not in valid_labels:
|
||||
raise ValueError(
|
||||
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
|
||||
)
|
||||
|
||||
# 使用乐观锁重试机制处理并发冲突
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# 步骤1:读取当前节点状态
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
|
||||
if not node_data:
|
||||
raise ValueError(
|
||||
f"Node not found: {node_label} with id={node_id}"
|
||||
)
|
||||
|
||||
# 步骤2:计算新的访问历史和激活值
|
||||
update_data = await self._calculate_update(
|
||||
node_data=node_data,
|
||||
current_time=current_time,
|
||||
current_time_iso=current_time_iso
|
||||
)
|
||||
|
||||
# 步骤3:原子性更新节点(使用事务)
|
||||
updated_node = await self._atomic_update(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功记录访问: {node_label}[{node_id}], "
|
||||
f"activation={update_data['activation_value']:.4f}, "
|
||||
f"access_count={update_data['access_count']}"
|
||||
)
|
||||
|
||||
return updated_node
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
|
||||
async def record_batch_access(
|
||||
self,
|
||||
node_ids: List[str],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
current_time: Optional[datetime] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量记录多个节点的访问
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史。
|
||||
每个节点独立更新,失败的节点不影响其他节点。
|
||||
|
||||
Args:
|
||||
node_ids: 节点ID列表
|
||||
node_label: 节点标签(所有节点必须是同一类型)
|
||||
group_id: 组ID(可选)
|
||||
current_time: 当前时间(可选)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
if current_time is None:
|
||||
current_time = datetime.now()
|
||||
|
||||
results = []
|
||||
failed_count = 0
|
||||
|
||||
for node_id in node_ids:
|
||||
try:
|
||||
updated_node = await self.record_access(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id,
|
||||
current_time=current_time
|
||||
)
|
||||
results.append(updated_node)
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
logger.warning(
|
||||
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
|
||||
f"失败 {failed_count}"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def check_consistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
"""
|
||||
检查节点数据的一致性
|
||||
|
||||
验证以下一致性规则:
|
||||
1. access_history[-1] == last_access_time
|
||||
2. len(access_history) == access_count
|
||||
3. 如果有访问历史,必须有激活值
|
||||
4. 激活值必须在有效范围内 [offset, 1.0]
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[ConsistencyCheckResult, Optional[str]]:
|
||||
- 一致性检查结果枚举
|
||||
- 错误描述(如果不一致)
|
||||
"""
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
|
||||
if not node_data:
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
access_history = node_data.get('access_history', [])
|
||||
last_access_time = node_data.get('last_access_time')
|
||||
access_count = node_data.get('access_count', 0)
|
||||
activation_value = node_data.get('activation_value')
|
||||
|
||||
# 检查1:access_history[-1] == last_access_time
|
||||
if access_history and last_access_time:
|
||||
if access_history[-1] != last_access_time:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
|
||||
f"access_history[-1]={access_history[-1]} != "
|
||||
f"last_access_time={last_access_time}"
|
||||
)
|
||||
|
||||
# 检查2:len(access_history) == access_count
|
||||
if len(access_history) != access_count:
|
||||
return (
|
||||
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
|
||||
f"len(access_history)={len(access_history)} != "
|
||||
f"access_count={access_count}"
|
||||
)
|
||||
|
||||
# 检查3:有访问历史必须有激活值
|
||||
if access_history and activation_value is None:
|
||||
return (
|
||||
ConsistencyCheckResult.MISSING_ACTIVATION,
|
||||
"Node has access_history but activation_value is None"
|
||||
)
|
||||
|
||||
# 检查4:激活值范围
|
||||
if activation_value is not None:
|
||||
offset = self.actr_calculator.offset
|
||||
if not (offset <= activation_value <= 1.0):
|
||||
return (
|
||||
ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
|
||||
f"activation_value={activation_value} out of range "
|
||||
f"[{offset}, 1.0]"
|
||||
)
|
||||
|
||||
return ConsistencyCheckResult.CONSISTENT, None
|
||||
|
||||
async def check_batch_consistency(
|
||||
self,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
limit: int = 1000
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
批量检查多个节点的一致性
|
||||
|
||||
Args:
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
limit: 检查的最大节点数
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 一致性检查报告,包含:
|
||||
- total_checked: 检查的节点总数
|
||||
- consistent_count: 一致的节点数
|
||||
- inconsistent_count: 不一致的节点数
|
||||
- inconsistencies: 不一致节点的详细信息列表
|
||||
- consistency_rate: 一致性率(0-1)
|
||||
"""
|
||||
# 查询所有相关节点
|
||||
query = f"""
|
||||
MATCH (n:{node_label})
|
||||
WHERE n.access_history IS NOT NULL
|
||||
"""
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
query += """
|
||||
RETURN n.id as id
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
params = {"limit": limit}
|
||||
if group_id:
|
||||
params["group_id"] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
node_ids = [r['id'] for r in results]
|
||||
|
||||
# 检查每个节点
|
||||
inconsistencies = []
|
||||
consistent_count = 0
|
||||
|
||||
for node_id in node_ids:
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
consistent_count += 1
|
||||
else:
|
||||
inconsistencies.append({
|
||||
'node_id': node_id,
|
||||
'result': result.value,
|
||||
'message': message
|
||||
})
|
||||
|
||||
total_checked = len(node_ids)
|
||||
inconsistent_count = len(inconsistencies)
|
||||
consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0
|
||||
|
||||
report = {
|
||||
'total_checked': total_checked,
|
||||
'consistent_count': consistent_count,
|
||||
'inconsistent_count': inconsistent_count,
|
||||
'inconsistencies': inconsistencies,
|
||||
'consistency_rate': consistency_rate
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"一致性检查完成: {node_label}, "
|
||||
f"一致率={consistency_rate:.2%}, "
|
||||
f"不一致节点={inconsistent_count}/{total_checked}"
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
async def repair_inconsistency(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
自动修复节点的数据不一致问题
|
||||
|
||||
修复策略:
|
||||
1. 如果access_history[-1] != last_access_time:使用access_history[-1]
|
||||
2. 如果len(access_history) != access_count:使用len(access_history)
|
||||
3. 如果有历史但无激活值:重新计算激活值
|
||||
4. 如果激活值超出范围:重新计算激活值
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
bool: 修复成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 检查一致性
|
||||
result, message = await self.check_consistency(
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
if result == ConsistencyCheckResult.CONSISTENT:
|
||||
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
|
||||
return True
|
||||
|
||||
# 获取节点数据
|
||||
node_data = await self._fetch_node(node_id, node_label, group_id)
|
||||
if not node_data:
|
||||
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
|
||||
return False
|
||||
|
||||
access_history = node_data.get('access_history', [])
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 准备修复数据
|
||||
repair_data = {}
|
||||
|
||||
# 修复last_access_time
|
||||
if access_history:
|
||||
repair_data['last_access_time'] = access_history[-1]
|
||||
|
||||
# 修复access_count
|
||||
repair_data['access_count'] = len(access_history)
|
||||
|
||||
# 修复activation_value
|
||||
if access_history:
|
||||
current_time = datetime.now()
|
||||
last_access_dt = datetime.fromisoformat(access_history[-1])
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in access_history
|
||||
]
|
||||
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=access_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_dt,
|
||||
importance_score=importance_score
|
||||
)
|
||||
repair_data['activation_value'] = activation_value
|
||||
|
||||
# 执行修复
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
query += """
|
||||
SET n += $repair_data
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
params = {
|
||||
'node_id': node_id,
|
||||
'repair_data': repair_data
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
await self.connector.execute_query(query, **params)
|
||||
|
||||
logger.info(
|
||||
f"成功修复节点不一致: {node_label}[{node_id}], "
|
||||
f"问题类型={result.value}"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
return False
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _fetch_node(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取节点数据
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Optional[Dict[str, Any]]: 节点数据,如果不存在返回None
|
||||
"""
|
||||
query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
query += " WHERE n.group_id = $group_id"
|
||||
query += """
|
||||
RETURN n.id as id,
|
||||
n.importance_score as importance_score,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count
|
||||
"""
|
||||
|
||||
params = {'node_id': node_id}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
if results:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
async def _calculate_update(
|
||||
self,
|
||||
node_data: Dict[str, Any],
|
||||
current_time: datetime,
|
||||
current_time_iso: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
计算更新数据
|
||||
|
||||
Args:
|
||||
node_data: 当前节点数据
|
||||
current_time: 当前时间(datetime对象)
|
||||
current_time_iso: 当前时间(ISO格式字符串)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新数据,包含所有需要更新的字段
|
||||
"""
|
||||
access_history = node_data.get('access_history', [])
|
||||
importance_score = node_data.get('importance_score', 0.5)
|
||||
|
||||
# 追加新的访问时间
|
||||
new_access_history = access_history + [current_time_iso]
|
||||
|
||||
# 修剪访问历史(如果过长)
|
||||
access_history_dt = [
|
||||
datetime.fromisoformat(ts) for ts in new_access_history
|
||||
]
|
||||
trimmed_history_dt = self.actr_calculator.trim_access_history(
|
||||
access_history=access_history_dt,
|
||||
current_time=current_time
|
||||
)
|
||||
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
|
||||
|
||||
# 计算新的激活值
|
||||
activation_value = self.actr_calculator.calculate_memory_activation(
|
||||
access_history=trimmed_history_dt,
|
||||
current_time=current_time,
|
||||
last_access_time=current_time, # 最后访问时间就是当前时间
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
# 返回所有需要更新的字段
|
||||
return {
|
||||
'activation_value': activation_value,
|
||||
'access_history': trimmed_history,
|
||||
'last_access_time': current_time_iso,
|
||||
'access_count': len(trimmed_history)
|
||||
}
|
||||
|
||||
async def _atomic_update(
|
||||
self,
|
||||
node_id: str,
|
||||
node_label: str,
|
||||
update_data: Dict[str, Any],
|
||||
group_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
原子性更新节点(使用乐观锁)
|
||||
|
||||
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
|
||||
实现乐观锁机制防止并发冲突。
|
||||
|
||||
Args:
|
||||
node_id: 节点ID
|
||||
node_label: 节点标签
|
||||
update_data: 更新数据
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 更新后的节点数据
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果更新失败或发生版本冲突
|
||||
"""
|
||||
# 定义事务函数
|
||||
async def update_transaction(tx, node_id, node_label, update_data, group_id):
|
||||
# 步骤1:读取当前节点并获取版本号
|
||||
read_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
read_query += " WHERE n.group_id = $group_id"
|
||||
read_query += """
|
||||
RETURN n.id as id,
|
||||
n.version as version,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score
|
||||
"""
|
||||
|
||||
read_params = {'node_id': node_id}
|
||||
if group_id:
|
||||
read_params['group_id'] = group_id
|
||||
|
||||
read_result = await tx.run(read_query, **read_params)
|
||||
current_node = await read_result.single()
|
||||
|
||||
if not current_node:
|
||||
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
|
||||
|
||||
# 获取当前版本号(如果不存在则为0)
|
||||
current_version = current_node.get('version', 0) or 0
|
||||
new_version = current_version + 1
|
||||
|
||||
# 步骤2:使用乐观锁更新节点
|
||||
# 只有当版本号匹配时才更新
|
||||
update_query = f"""
|
||||
MATCH (n:{node_label} {{id: $node_id}})
|
||||
"""
|
||||
if group_id:
|
||||
update_query += " WHERE n.group_id = $group_id"
|
||||
|
||||
# 添加版本检查
|
||||
if current_version > 0:
|
||||
update_query += " AND n.version = $current_version"
|
||||
else:
|
||||
# 如果节点没有版本号,检查是否为首次更新
|
||||
update_query += " AND (n.version IS NULL OR n.version = 0)"
|
||||
|
||||
update_query += """
|
||||
SET n.activation_value = $activation_value,
|
||||
n.access_history = $access_history,
|
||||
n.last_access_time = $last_access_time,
|
||||
n.access_count = $access_count,
|
||||
n.version = $new_version
|
||||
RETURN n.id as id,
|
||||
n.activation_value as activation_value,
|
||||
n.access_history as access_history,
|
||||
n.last_access_time as last_access_time,
|
||||
n.access_count as access_count,
|
||||
n.importance_score as importance_score,
|
||||
n.version as version
|
||||
"""
|
||||
|
||||
update_params = {
|
||||
'node_id': node_id,
|
||||
'current_version': current_version,
|
||||
'new_version': new_version,
|
||||
'activation_value': update_data['activation_value'],
|
||||
'access_history': update_data['access_history'],
|
||||
'last_access_time': update_data['last_access_time'],
|
||||
'access_count': update_data['access_count']
|
||||
}
|
||||
if group_id:
|
||||
update_params['group_id'] = group_id
|
||||
|
||||
update_result = await tx.run(update_query, **update_params)
|
||||
updated_node = await update_result.single()
|
||||
|
||||
if not updated_node:
|
||||
raise RuntimeError(
|
||||
f"Version conflict detected for {node_label}[{node_id}]. "
|
||||
f"Expected version {current_version}, but node was modified by another transaction."
|
||||
)
|
||||
|
||||
return dict(updated_node)
|
||||
|
||||
# 执行事务
|
||||
try:
|
||||
result = await self.connector.execute_write_transaction(
|
||||
update_transaction,
|
||||
node_id=node_id,
|
||||
node_label=node_label,
|
||||
update_data=update_data,
|
||||
group_id=group_id
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to atomically update node: {str(e)}"
|
||||
) from e
|
||||
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
ACT-R Memory Activation Calculator
|
||||
|
||||
This module implements the unified Memory Activation model based on ACT-R
|
||||
(Adaptive Control of Thought-Rational) cognitive architecture theory.
|
||||
|
||||
The calculator integrates BLA (Base-Level Activation) computation into the
|
||||
Memory Activation formula, providing a single coherent model for memory strength
|
||||
calculation that reflects both recency and frequency of access.
|
||||
|
||||
Formula: R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
|
||||
|
||||
Where:
|
||||
- R(i): Memory activation value (0 to 1)
|
||||
- offset: Minimum retention rate (prevents complete forgetting)
|
||||
- λ: Forgetting rate (lambda_time / lambda_mem)
|
||||
- t: Time since last access
|
||||
- I: Importance score (0 to 1)
|
||||
- t_k: Time since k-th access
|
||||
- d: Decay constant (typically 0.5)
|
||||
|
||||
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class ACTRCalculator:
|
||||
"""
|
||||
Unified ACT-R Memory Activation Calculator.
|
||||
|
||||
This calculator implements the Memory Activation model that combines
|
||||
recency and frequency effects into a single activation value computation.
|
||||
It replaces the separate BLA calculation with an integrated approach.
|
||||
|
||||
Attributes:
|
||||
decay_constant: Decay parameter d (typically 0.5)
|
||||
forgetting_rate: Lambda parameter λ controlling forgetting speed
|
||||
offset: Minimum retention rate (baseline memory strength)
|
||||
max_history_length: Maximum number of access records to keep
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1,
|
||||
max_history_length: int = 100
|
||||
):
|
||||
"""
|
||||
Initialize the ACT-R calculator.
|
||||
|
||||
Args:
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
max_history_length: Maximum access history length (default 100)
|
||||
"""
|
||||
self.decay_constant = decay_constant
|
||||
self.forgetting_rate = forgetting_rate
|
||||
self.offset = offset
|
||||
self.max_history_length = max_history_length
|
||||
|
||||
def calculate_memory_activation(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5
|
||||
) -> float:
|
||||
"""
|
||||
Calculate memory activation value using the unified Memory Activation formula.
|
||||
|
||||
This method computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
|
||||
|
||||
The formula integrates:
|
||||
- Recency effect: Recent accesses contribute more (via t)
|
||||
- Frequency effect: Multiple accesses strengthen memory (via Σ)
|
||||
- Importance weighting: Important memories decay slower (via I)
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps (ISO format or datetime objects)
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
|
||||
Returns:
|
||||
float: Memory activation value between offset and 1.0
|
||||
|
||||
Raises:
|
||||
ValueError: If access_history is empty or contains invalid data
|
||||
"""
|
||||
if not access_history:
|
||||
raise ValueError("access_history cannot be empty")
|
||||
|
||||
if not (0.0 <= importance_score <= 1.0):
|
||||
raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")
|
||||
|
||||
# Calculate time since last access (in days)
|
||||
time_since_last = (current_time - last_access_time).total_seconds() / 86400.0
|
||||
time_since_last = max(time_since_last, 0.0001) # Avoid division by zero
|
||||
|
||||
# Calculate BLA component: Σ(I·t_k^(-d))
|
||||
bla_sum = 0.0
|
||||
for access_time in access_history:
|
||||
# Calculate time since this access (in days)
|
||||
time_diff = (current_time - access_time).total_seconds() / 86400.0
|
||||
time_diff = max(time_diff, 0.0001) # Avoid division by zero
|
||||
|
||||
# Add weighted power-law term: I * t_k^(-d)
|
||||
bla_sum += importance_score * (time_diff ** (-self.decay_constant))
|
||||
|
||||
# Avoid division by zero in case of numerical issues
|
||||
if bla_sum <= 0:
|
||||
bla_sum = 0.0001
|
||||
|
||||
# Calculate Memory Activation: R(i) = offset + (1-offset) * exp(-λ*t / BLA)
|
||||
exponent = -self.forgetting_rate * time_since_last / bla_sum
|
||||
|
||||
# Clamp exponent to avoid numerical overflow/underflow
|
||||
exponent = max(min(exponent, 100), -100)
|
||||
|
||||
activation = self.offset + (1 - self.offset) * math.exp(exponent)
|
||||
|
||||
# Ensure activation is within valid range [offset, 1.0]
|
||||
return max(self.offset, min(1.0, activation))
|
||||
|
||||
def trim_access_history(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime
|
||||
) -> List[datetime]:
|
||||
"""
|
||||
Intelligently trim access history to prevent unbounded growth.
|
||||
|
||||
Strategy:
|
||||
- Keep all records if under max_history_length
|
||||
- If over limit, keep most recent 50% and sample from older records
|
||||
- Preserves both recent accesses (high importance) and historical pattern
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps (sorted or unsorted)
|
||||
current_time: Current time for calculation
|
||||
|
||||
Returns:
|
||||
List[datetime]: Trimmed access history
|
||||
"""
|
||||
if len(access_history) <= self.max_history_length:
|
||||
return access_history
|
||||
|
||||
# Sort by time (most recent first)
|
||||
sorted_history = sorted(access_history, reverse=True)
|
||||
|
||||
# Calculate split point (keep most recent 50%)
|
||||
keep_recent_count = self.max_history_length // 2
|
||||
|
||||
# Keep most recent 50%
|
||||
recent_records = sorted_history[:keep_recent_count]
|
||||
|
||||
# Sample from older records
|
||||
older_records = sorted_history[keep_recent_count:]
|
||||
sample_count = self.max_history_length - keep_recent_count
|
||||
|
||||
if len(older_records) <= sample_count:
|
||||
# If older records fit, keep them all
|
||||
sampled_older = older_records
|
||||
else:
|
||||
# Sample evenly from older records
|
||||
step = len(older_records) / sample_count
|
||||
sampled_older = [
|
||||
older_records[int(i * step)]
|
||||
for i in range(sample_count)
|
||||
]
|
||||
|
||||
# Combine and return
|
||||
trimmed_history = recent_records + sampled_older
|
||||
return sorted(trimmed_history, reverse=True)
|
||||
|
||||
def get_forgetting_curve( # 预测激活值,决定复习;测试不同配置效果,选择合适的d
|
||||
self,
|
||||
initial_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
days: int = 60
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate forgetting curve data for visualization.
|
||||
|
||||
This method simulates how memory activation decays over time
|
||||
for a single initial access, useful for understanding and
|
||||
visualizing the forgetting behavior.
|
||||
|
||||
Args:
|
||||
initial_time: Time of initial memory creation/access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
days: Number of days to simulate (default 60)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with keys:
|
||||
- 'day': Day number (0 to days)
|
||||
- 'activation': Memory activation value
|
||||
- 'retention_rate': Same as activation (for compatibility)
|
||||
"""
|
||||
curve_data = []
|
||||
access_history = [initial_time]
|
||||
|
||||
for day in range(days + 1):
|
||||
current_time = initial_time + timedelta(days=day)
|
||||
|
||||
try:
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=initial_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
except ValueError:
|
||||
# Handle edge cases
|
||||
activation = self.offset
|
||||
|
||||
curve_data.append({
|
||||
'day': day,
|
||||
'activation': activation,
|
||||
'retention_rate': activation # Alias for compatibility
|
||||
})
|
||||
|
||||
return curve_data
|
||||
|
||||
def calculate_forgetting_score(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5
|
||||
) -> float:
|
||||
"""
|
||||
Calculate forgetting score (inverse of activation).
|
||||
|
||||
Forgetting score = 1 - activation value
|
||||
Higher score means more likely to be forgotten.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
|
||||
Returns:
|
||||
float: Forgetting score between 0 and (1 - offset)
|
||||
"""
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
return 1.0 - activation
|
||||
|
||||
def should_forget(
|
||||
self,
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
threshold: float = 0.3
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a memory should be forgotten based on activation threshold.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
threshold: Activation threshold below which memory should be forgotten
|
||||
|
||||
Returns:
|
||||
bool: True if activation < threshold (should forget), False otherwise
|
||||
"""
|
||||
activation = self.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
return activation < threshold
|
||||
|
||||
|
||||
# Convenience functions for quick calculations
|
||||
def calculate_activation(
|
||||
access_history: List[datetime],
|
||||
current_time: datetime,
|
||||
last_access_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1
|
||||
) -> float:
|
||||
"""
|
||||
Quick function to calculate activation without creating a calculator instance.
|
||||
|
||||
Args:
|
||||
access_history: List of access timestamps
|
||||
current_time: Current time for calculation
|
||||
last_access_time: Time of most recent access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
|
||||
Returns:
|
||||
float: Memory activation value between offset and 1.0
|
||||
"""
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=decay_constant,
|
||||
forgetting_rate=forgetting_rate,
|
||||
offset=offset
|
||||
)
|
||||
return calculator.calculate_memory_activation(
|
||||
access_history=access_history,
|
||||
current_time=current_time,
|
||||
last_access_time=last_access_time,
|
||||
importance_score=importance_score
|
||||
)
|
||||
|
||||
|
||||
def generate_forgetting_curve(
|
||||
initial_time: datetime,
|
||||
importance_score: float = 0.5,
|
||||
days: int = 60,
|
||||
decay_constant: float = 0.5,
|
||||
forgetting_rate: float = 0.3,
|
||||
offset: float = 0.1
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Quick function to generate forgetting curve data.
|
||||
|
||||
Args:
|
||||
initial_time: Time of initial memory creation/access
|
||||
importance_score: Importance weight (0 to 1, default 0.5)
|
||||
days: Number of days to simulate (default 60)
|
||||
decay_constant: Decay parameter d (default 0.5)
|
||||
forgetting_rate: Forgetting rate λ (default 0.3)
|
||||
offset: Minimum retention rate (default 0.1)
|
||||
|
||||
Returns:
|
||||
List of dictionaries with forgetting curve data
|
||||
"""
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=decay_constant,
|
||||
forgetting_rate=forgetting_rate,
|
||||
offset=offset
|
||||
)
|
||||
return calculator.get_forgetting_curve(
|
||||
initial_time=initial_time,
|
||||
importance_score=importance_score,
|
||||
days=days
|
||||
)
|
||||
@@ -0,0 +1,195 @@
|
||||
"""
|
||||
遗忘引擎配置工具模块
|
||||
|
||||
本模块提供从数据库加载配置并创建遗忘引擎组件的辅助函数。
|
||||
|
||||
Functions:
|
||||
calculate_forgetting_rate: 计算遗忘速率(lambda_time / lambda_mem)
|
||||
load_actr_config_from_db: 从数据库加载 ACT-R 配置参数
|
||||
create_actr_calculator_from_config: 从配置创建 ACTRCalculator 实例
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
|
||||
"""
|
||||
计算遗忘速率
|
||||
|
||||
公式:forgetting_rate = lambda_time / lambda_mem
|
||||
|
||||
这个计算将两个独立的 lambda 参数组合成一个统一的遗忘速率参数,
|
||||
用于 ACT-R 激活值计算。
|
||||
|
||||
Args:
|
||||
lambda_time: 时间衰减参数(0-1)
|
||||
lambda_mem: 记忆衰减参数(0-1)
|
||||
|
||||
Returns:
|
||||
float: 遗忘速率
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 lambda_mem 为 0
|
||||
|
||||
Examples:
|
||||
>>> calculate_forgetting_rate(0.5, 0.5)
|
||||
1.0
|
||||
>>> calculate_forgetting_rate(0.3, 0.5)
|
||||
0.6
|
||||
"""
|
||||
if lambda_mem == 0:
|
||||
raise ValueError("lambda_mem 不能为 0")
|
||||
|
||||
forgetting_rate = lambda_time / lambda_mem
|
||||
|
||||
logger.debug(
|
||||
f"计算遗忘速率: lambda_time={lambda_time}, "
|
||||
f"lambda_mem={lambda_mem}, "
|
||||
f"forgetting_rate={forgetting_rate:.4f}"
|
||||
)
|
||||
|
||||
return forgetting_rate
|
||||
|
||||
|
||||
def load_actr_config_from_db(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库加载 ACT-R 配置参数
|
||||
|
||||
从 PostgreSQL 的 data_config 表读取配置参数,
|
||||
并计算派生参数(如 forgetting_rate)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置 ID(可选,如果为 None 则使用默认值)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 配置参数字典,包含:
|
||||
- decay_constant: 衰减常数 d
|
||||
- lambda_time: 时间衰减参数
|
||||
- lambda_mem: 记忆衰减参数
|
||||
- forgetting_rate: 遗忘速率(根据 lambda_time / lambda_mem 计算得出)
|
||||
- offset: 偏移量
|
||||
- max_history_length: 访问历史最大长度
|
||||
- forgetting_threshold: 遗忘阈值
|
||||
- min_days_since_access: 最小未访问天数
|
||||
- enable_llm_summary: 是否使用 LLM 生成摘要
|
||||
- max_merge_batch_size: 单次最大融合节点对数
|
||||
- forgetting_interval_hours: 遗忘周期间隔
|
||||
|
||||
注意:llm_id 不包含在返回的配置中,需要时由 forgetting_strategy 直接从数据库读取
|
||||
|
||||
Raises:
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
"""
|
||||
# 必须指定 config_id
|
||||
if config_id is None:
|
||||
logger.error("未指定 config_id,无法加载配置")
|
||||
raise ValueError("config_id 不能为空,必须指定一个有效的配置 ID")
|
||||
|
||||
# 从数据库加载配置
|
||||
try:
|
||||
repository = DataConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None:
|
||||
logger.error(f"配置不存在: config_id={config_id}")
|
||||
raise ValueError(f"配置不存在: config_id={config_id}")
|
||||
|
||||
# 读取配置参数(信任数据库默认值)
|
||||
lambda_time = db_config.lambda_time
|
||||
lambda_mem = db_config.lambda_mem
|
||||
decay_constant = db_config.decay_constant
|
||||
offset = db_config.offset
|
||||
max_history_length = db_config.max_history_length
|
||||
forgetting_threshold = db_config.forgetting_threshold
|
||||
min_days_since_access = db_config.min_days_since_access
|
||||
enable_llm_summary = db_config.enable_llm_summary
|
||||
max_merge_batch_size = db_config.max_merge_batch_size
|
||||
forgetting_interval_hours = db_config.forgetting_interval_hours
|
||||
|
||||
# 计算 forgetting_rate
|
||||
forgetting_rate = calculate_forgetting_rate(lambda_time, lambda_mem)
|
||||
|
||||
config = {
|
||||
'decay_constant': decay_constant,
|
||||
'lambda_time': lambda_time,
|
||||
'lambda_mem': lambda_mem,
|
||||
'forgetting_rate': forgetting_rate,
|
||||
'offset': offset,
|
||||
'max_history_length': max_history_length,
|
||||
'forgetting_threshold': forgetting_threshold,
|
||||
'min_days_since_access': min_days_since_access,
|
||||
'enable_llm_summary': enable_llm_summary,
|
||||
'max_merge_batch_size': max_merge_batch_size,
|
||||
'forgetting_interval_hours': forgetting_interval_hours
|
||||
# 注意:llm_id 不包含在配置响应中,仅在内部使用
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"成功加载 ACT-R 配置: config_id={config_id}, "
|
||||
f"forgetting_rate={forgetting_rate:.4f}"
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载 ACT-R 配置失败: config_id={config_id}, 错误: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_actr_calculator_from_config(
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
) -> ACTRCalculator:
|
||||
"""
|
||||
从数据库配置创建 ACTRCalculator 实例
|
||||
|
||||
这是创建 ACTRCalculator 的推荐方式,确保使用数据库中的配置参数。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置 ID(可选,如果为 None 则使用默认值)
|
||||
|
||||
Returns:
|
||||
ACTRCalculator: 配置好的 ACT-R 计算器实例
|
||||
|
||||
Raises:
|
||||
ValueError: 如果指定的 config_id 不存在
|
||||
|
||||
Examples:
|
||||
>>> from sqlalchemy.orm import Session
|
||||
>>> db = Session()
|
||||
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
|
||||
>>> # 使用计算器
|
||||
>>> activation = calculator.calculate_memory_activation(...)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
# 创建计算器
|
||||
calculator = ACTRCalculator(
|
||||
decay_constant=config['decay_constant'],
|
||||
forgetting_rate=config['forgetting_rate'],
|
||||
offset=config['offset'],
|
||||
max_history_length=config['max_history_length']
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"创建 ACTRCalculator: config_id={config_id}, "
|
||||
f"decay_constant={config['decay_constant']}, "
|
||||
f"forgetting_rate={config['forgetting_rate']:.4f}, "
|
||||
f"offset={config['offset']}"
|
||||
)
|
||||
|
||||
return calculator
|
||||
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
遗忘调度器模块
|
||||
|
||||
本模块实现遗忘周期的调度和管理,负责:
|
||||
1. 手动触发遗忘周期
|
||||
2. 批量处理可遗忘节点(限制批量大小)
|
||||
3. 按激活值优先级排序(激活值最低的优先)
|
||||
4. 进度跟踪和日志记录
|
||||
5. 生成遗忘报告
|
||||
|
||||
注意:定期调度功能已迁移到 Celery Beat,见 app/tasks.py 中的 run_forgetting_cycle_task
|
||||
|
||||
Classes:
|
||||
ForgettingScheduler: 遗忘调度器,提供遗忘周期管理功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForgettingScheduler:
|
||||
"""
|
||||
遗忘调度器
|
||||
|
||||
管理遗忘周期的执行,实现批量处理、优先级排序和进度跟踪功能。
|
||||
|
||||
核心功能:
|
||||
1. 运行遗忘周期:识别可遗忘节点并批量融合
|
||||
2. 优先级排序:优先处理激活值最低的节点对
|
||||
3. 批量限制:限制单次处理的节点对数量
|
||||
4. 进度跟踪:每完成 10% 记录一次日志
|
||||
5. 遗忘报告:生成详细的执行报告
|
||||
|
||||
注意:定期调度功能已迁移到 Celery Beat 定时任务
|
||||
|
||||
Attributes:
|
||||
forgetting_strategy: 遗忘策略执行器实例
|
||||
connector: Neo4j 连接器实例
|
||||
is_running: 是否正在运行遗忘周期
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
forgetting_strategy: ForgettingStrategy,
|
||||
connector: Neo4jConnector
|
||||
):
|
||||
"""
|
||||
初始化遗忘调度器
|
||||
|
||||
Args:
|
||||
forgetting_strategy: 遗忘策略执行器实例
|
||||
connector: Neo4j 连接器实例
|
||||
"""
|
||||
self.forgetting_strategy = forgetting_strategy
|
||||
self.connector = connector
|
||||
self.is_running = False
|
||||
|
||||
logger.info("初始化遗忘调度器")
|
||||
|
||||
async def run_forgetting_cycle(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
max_merge_batch_size: int = 100,
|
||||
min_days_since_access: int = 30,
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
运行一次完整的遗忘周期
|
||||
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
max_merge_batch_size: 单次最大融合节点对数(默认 100)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 遗忘报告,包含:
|
||||
- merged_count: 融合的节点对数量
|
||||
- nodes_before: 遗忘前的节点总数
|
||||
- nodes_after: 遗忘后的节点总数
|
||||
- reduction_rate: 节点减少率(0-1)
|
||||
- duration_seconds: 执行耗时(秒)
|
||||
- start_time: 开始时间(ISO 格式)
|
||||
- end_time: 结束时间(ISO 格式)
|
||||
- failed_count: 失败的融合数量
|
||||
- success_rate: 成功率(0-1)
|
||||
|
||||
Raises:
|
||||
RuntimeError: 如果已有遗忘周期正在运行
|
||||
"""
|
||||
# 检查是否已有遗忘周期在运行
|
||||
if self.is_running:
|
||||
raise RuntimeError("遗忘周期已在运行中,请等待当前周期完成")
|
||||
|
||||
self.is_running = True
|
||||
start_time = datetime.now()
|
||||
start_time_iso = start_time.isoformat()
|
||||
|
||||
logger.info(
|
||||
f"开始遗忘周期: group_id={group_id}, "
|
||||
f"max_batch={max_merge_batch_size}, "
|
||||
f"min_days={min_days_since_access}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 步骤1:统计遗忘前的节点数量
|
||||
nodes_before = await self._count_knowledge_nodes(group_id)
|
||||
logger.info(f"遗忘前节点总数: {nodes_before}")
|
||||
|
||||
# 步骤2:识别可遗忘的节点对
|
||||
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
|
||||
group_id=group_id,
|
||||
min_days_since_access=min_days_since_access
|
||||
)
|
||||
|
||||
total_forgettable = len(forgettable_pairs)
|
||||
logger.info(f"识别到 {total_forgettable} 个可遗忘节点对")
|
||||
|
||||
if total_forgettable == 0:
|
||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
report = {
|
||||
'merged_count': 0,
|
||||
'nodes_before': nodes_before,
|
||||
'nodes_after': nodes_before,
|
||||
'reduction_rate': 0.0,
|
||||
'duration_seconds': duration,
|
||||
'start_time': start_time_iso,
|
||||
'end_time': end_time.isoformat(),
|
||||
'failed_count': 0,
|
||||
'success_rate': 1.0
|
||||
}
|
||||
|
||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||
|
||||
return report
|
||||
|
||||
# 步骤3:按激活值排序(激活值最低的优先)
|
||||
# avg_activation 已经在 find_forgettable_nodes 中计算并排序
|
||||
# 这里只需要确认排序是正确的(升序)
|
||||
sorted_pairs = sorted(
|
||||
forgettable_pairs,
|
||||
key=lambda x: x['avg_activation']
|
||||
)
|
||||
|
||||
# 步骤4:限制批量大小
|
||||
pairs_to_process = sorted_pairs[:max_merge_batch_size]
|
||||
actual_batch_size = len(pairs_to_process)
|
||||
|
||||
logger.info(
|
||||
f"将处理 {actual_batch_size} 个节点对 "
|
||||
f"(限制: {max_merge_batch_size})"
|
||||
)
|
||||
|
||||
# 步骤5:批量融合节点,每 10% 记录进度
|
||||
merged_count = 0
|
||||
failed_count = 0
|
||||
skipped_count = 0 # 跳过的节点对数量(节点已被处理)
|
||||
progress_interval = max(1, actual_batch_size // 10) # 每 10% 记录一次
|
||||
|
||||
# 跟踪已处理的节点 ID,避免重复处理
|
||||
processed_statement_ids = set()
|
||||
processed_entity_ids = set()
|
||||
|
||||
# 预先过滤掉重复的节点对
|
||||
unique_pairs = []
|
||||
for pair in pairs_to_process:
|
||||
statement_id = pair['statement_id']
|
||||
entity_id = pair['entity_id']
|
||||
|
||||
# 如果节点已被标记为处理,跳过
|
||||
if statement_id in processed_statement_ids or entity_id in processed_entity_ids:
|
||||
skipped_count += 1
|
||||
logger.debug(
|
||||
f"预过滤:跳过重复节点对 Statement[{statement_id}] + Entity[{entity_id}]"
|
||||
)
|
||||
continue
|
||||
|
||||
# 标记节点为已处理
|
||||
processed_statement_ids.add(statement_id)
|
||||
processed_entity_ids.add(entity_id)
|
||||
unique_pairs.append(pair)
|
||||
|
||||
logger.info(
|
||||
f"预过滤完成:原始 {actual_batch_size} 对,去重后 {len(unique_pairs)} 对,"
|
||||
f"跳过 {skipped_count} 对重复节点"
|
||||
)
|
||||
|
||||
# 更新实际处理的批次大小
|
||||
actual_batch_size = len(unique_pairs)
|
||||
progress_interval = max(1, actual_batch_size // 10) # 重新计算进度间隔
|
||||
|
||||
for idx, pair in enumerate(unique_pairs, start=1):
|
||||
statement_id = pair['statement_id']
|
||||
entity_id = pair['entity_id']
|
||||
|
||||
try:
|
||||
# 准备节点数据
|
||||
statement_node = {
|
||||
'statement_id': statement_id,
|
||||
'statement_text': pair['statement_text'],
|
||||
'statement_activation': pair['statement_activation'],
|
||||
'statement_importance': pair['statement_importance'],
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
entity_node = {
|
||||
'entity_id': entity_id,
|
||||
'entity_name': pair['entity_name'],
|
||||
'entity_type': pair['entity_type'],
|
||||
'entity_activation': pair['entity_activation'],
|
||||
'entity_importance': pair['entity_importance'],
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
# 融合节点
|
||||
await self.forgetting_strategy.merge_nodes_to_summary(
|
||||
statement_node=statement_node,
|
||||
entity_node=entity_node,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
merged_count += 1
|
||||
|
||||
# 进度跟踪:每 10% 记录一次
|
||||
if actual_batch_size > 0 and (idx % progress_interval == 0 or idx == actual_batch_size):
|
||||
progress_pct = (idx / actual_batch_size) * 100
|
||||
logger.info(
|
||||
f"遗忘进度: {idx}/{actual_batch_size} "
|
||||
f"({progress_pct:.1f}%), "
|
||||
f"已融合: {merged_count}, 失败: {failed_count}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
# 检查是否是节点不存在的错误
|
||||
if "nodes may not exist" in str(e):
|
||||
logger.warning(
|
||||
f"节点对 ({idx}/{actual_batch_size}) 的节点不存在(可能已被其他操作删除): "
|
||||
f"Statement[{statement_id}] + Entity[{entity_id}]"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"融合节点对失败 ({idx}/{actual_batch_size}): "
|
||||
f"Statement[{statement_id}] + Entity[{entity_id}], "
|
||||
f"错误: {str(e)}"
|
||||
)
|
||||
# 继续处理剩余节点
|
||||
continue
|
||||
|
||||
# 步骤6:统计遗忘后的节点数量
|
||||
nodes_after = await self._count_knowledge_nodes(group_id)
|
||||
logger.info(f"遗忘后节点总数: {nodes_after}")
|
||||
|
||||
# 步骤7:生成遗忘报告
|
||||
end_time = datetime.now()
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
# 计算节点减少率
|
||||
if nodes_before > 0:
|
||||
reduction_rate = (nodes_before - nodes_after) / nodes_before
|
||||
else:
|
||||
reduction_rate = 0.0
|
||||
|
||||
# 计算成功率
|
||||
if actual_batch_size > 0:
|
||||
success_rate = merged_count / actual_batch_size
|
||||
else:
|
||||
success_rate = 1.0
|
||||
|
||||
report = {
|
||||
'merged_count': merged_count,
|
||||
'nodes_before': nodes_before,
|
||||
'nodes_after': nodes_after,
|
||||
'reduction_rate': reduction_rate,
|
||||
'duration_seconds': duration,
|
||||
'start_time': start_time_iso,
|
||||
'end_time': end_time.isoformat(),
|
||||
'failed_count': failed_count,
|
||||
'success_rate': success_rate
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"遗忘周期完成: "
|
||||
f"融合 {merged_count} 对节点, "
|
||||
f"失败 {failed_count} 对, "
|
||||
f"节点减少 {nodes_before - nodes_after} 个 "
|
||||
f"({reduction_rate:.2%}), "
|
||||
f"耗时 {duration:.2f} 秒"
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"遗忘周期执行失败: {str(e)}")
|
||||
raise
|
||||
|
||||
finally:
|
||||
self.is_running = False
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _count_knowledge_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""
|
||||
统计知识层节点总数
|
||||
|
||||
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
|
||||
Returns:
|
||||
int: 知识层节点总数
|
||||
"""
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
params = {}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
if results:
|
||||
return results[0]['total']
|
||||
return 0
|
||||
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
遗忘策略执行器模块
|
||||
|
||||
本模块实现基于 ACT-R 激活值的遗忘策略,负责:
|
||||
1. 识别低激活值的节点对(Statement-Entity)
|
||||
2. 将低激活值节点融合为 MemorySummary 节点
|
||||
3. 使用 LLM 生成高质量摘要(可选)
|
||||
4. 保留溯源信息并删除原始节点
|
||||
|
||||
Classes:
|
||||
ForgettingStrategy: 遗忘策略执行器,提供节点识别和融合功能
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ForgettingStrategy:
|
||||
"""
|
||||
遗忘策略执行器
|
||||
|
||||
基于 ACT-R 激活值识别和融合低价值记忆节点。
|
||||
实现了完整的遗忘周期:识别 → 融合 → 删除。
|
||||
|
||||
核心功能:
|
||||
1. 识别可遗忘节点:激活值低于阈值且长期未访问的 Statement-Entity 对
|
||||
2. 节点融合:创建 MemorySummary 节点,继承较高的激活值和重要性
|
||||
3. LLM 摘要生成:使用 LLM 生成语义摘要(可降级到简单拼接)
|
||||
4. 溯源保留:记录原始节点 ID,保持可追溯性
|
||||
|
||||
Attributes:
|
||||
connector: Neo4j 连接器实例
|
||||
actr_calculator: ACT-R 激活值计算器实例
|
||||
forgetting_threshold: 遗忘阈值(激活值低于此值的节点可被遗忘)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
actr_calculator: ACTRCalculator,
|
||||
forgetting_threshold: float = 0.3,
|
||||
enable_llm_summary: bool = True
|
||||
):
|
||||
"""
|
||||
初始化遗忘策略执行器
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器实例
|
||||
actr_calculator: ACT-R 激活值计算器实例
|
||||
forgetting_threshold: 遗忘阈值(默认 0.3)
|
||||
enable_llm_summary: 是否启用 LLM 摘要生成(默认 True)
|
||||
"""
|
||||
self.connector = connector
|
||||
self.actr_calculator = actr_calculator
|
||||
self.forgetting_threshold = forgetting_threshold
|
||||
self.enable_llm_summary = enable_llm_summary
|
||||
|
||||
logger.info(
|
||||
f"初始化遗忘策略执行器: threshold={forgetting_threshold}, "
|
||||
f"enable_llm_summary={enable_llm_summary}"
|
||||
)
|
||||
|
||||
async def calculate_forgetting_score(
|
||||
self,
|
||||
activation_value: float
|
||||
) -> float:
|
||||
"""
|
||||
计算遗忘分数
|
||||
|
||||
遗忘分数 = 1 - 激活值
|
||||
分数越高,越容易被遗忘。
|
||||
|
||||
注意:激活值已经包含了 importance_score 的权重,
|
||||
因此不需要单独考虑重要性分数。
|
||||
|
||||
Args:
|
||||
activation_value: 节点的激活值(0-1)
|
||||
|
||||
Returns:
|
||||
float: 遗忘分数(0-1),值越高越容易被遗忘
|
||||
"""
|
||||
return 1.0 - activation_value
|
||||
|
||||
async def find_forgettable_nodes(
|
||||
self,
|
||||
group_id: Optional[str] = None,
|
||||
min_days_since_access: int = 30
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
识别可遗忘的节点对
|
||||
|
||||
查找满足以下条件的 Statement-Entity 节点对:
|
||||
1. 两个节点的激活值都低于遗忘阈值
|
||||
2. 两个节点都至少 min_days_since_access 天未被访问
|
||||
3. Statement 和 Entity 之间存在关系边
|
||||
|
||||
Args:
|
||||
group_id: 组 ID(可选,用于过滤特定组的节点)
|
||||
min_days_since_access: 最小未访问天数(默认 30 天)
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 可遗忘节点对列表,每个元素包含:
|
||||
- statement_id: Statement 节点 ID
|
||||
- statement_text: Statement 文本内容
|
||||
- statement_activation: Statement 激活值
|
||||
- statement_importance: Statement 重要性分数
|
||||
- statement_last_access: Statement 最后访问时间
|
||||
- entity_id: Entity 节点 ID
|
||||
- entity_name: Entity 名称
|
||||
- entity_type: Entity 类型
|
||||
- entity_activation: Entity 激活值
|
||||
- entity_importance: Entity 重要性分数
|
||||
- entity_last_access: Entity 最后访问时间
|
||||
- avg_activation: 平均激活值(用于排序)
|
||||
"""
|
||||
# 计算时间阈值
|
||||
cutoff_time = datetime.now() - timedelta(days=min_days_since_access)
|
||||
cutoff_time_iso = cutoff_time.isoformat()
|
||||
|
||||
# 构建查询
|
||||
query = """
|
||||
MATCH (s:Statement)-[r]-(e:ExtractedEntity)
|
||||
WHERE s.activation_value IS NOT NULL
|
||||
AND e.activation_value IS NOT NULL
|
||||
AND s.activation_value < $threshold
|
||||
AND e.activation_value < $threshold
|
||||
AND s.last_access_time < $cutoff_time
|
||||
AND e.last_access_time < $cutoff_time
|
||||
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
RETURN s.id as statement_id,
|
||||
s.statement as statement_text,
|
||||
s.activation_value as statement_activation,
|
||||
s.importance_score as statement_importance,
|
||||
s.last_access_time as statement_last_access,
|
||||
e.id as entity_id,
|
||||
e.name as entity_name,
|
||||
e.entity_type as entity_type,
|
||||
e.activation_value as entity_activation,
|
||||
e.importance_score as entity_importance,
|
||||
e.last_access_time as entity_last_access,
|
||||
(s.activation_value + e.activation_value) / 2.0 as avg_activation
|
||||
ORDER BY avg_activation ASC
|
||||
"""
|
||||
|
||||
params = {
|
||||
'threshold': self.forgetting_threshold,
|
||||
'cutoff_time': cutoff_time_iso
|
||||
}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await self.connector.execute_query(query, **params)
|
||||
|
||||
logger.info(
|
||||
f"识别到 {len(results)} 个可遗忘节点对 "
|
||||
f"(threshold={self.forgetting_threshold}, "
|
||||
f"min_days={min_days_since_access})"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def merge_nodes_to_summary(
|
||||
self,
|
||||
statement_node: Dict[str, Any],
|
||||
entity_node: Dict[str, Any],
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
将 Statement 和 Entity 节点融合为 MemorySummary 节点
|
||||
|
||||
融合过程:
|
||||
1. 生成摘要内容(使用 LLM 或简单拼接)
|
||||
2. 创建 MemorySummary 节点,继承较高的激活值和重要性分数
|
||||
3. 删除原始 Statement 和 Entity 节点
|
||||
4. 保留溯源信息(original_statement_id, original_entity_id)
|
||||
|
||||
Args:
|
||||
statement_node: Statement 节点数据,必须包含:
|
||||
- statement_id: 节点 ID
|
||||
- statement_text: 文本内容
|
||||
- statement_activation: 激活值
|
||||
- statement_importance: 重要性分数
|
||||
entity_node: Entity 节点数据,必须包含:
|
||||
- entity_id: 节点 ID
|
||||
- entity_name: 实体名称
|
||||
- entity_type: 实体类型
|
||||
- entity_activation: 激活值
|
||||
- entity_importance: 重要性分数
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
str: 创建的 MemorySummary 节点 ID
|
||||
|
||||
Raises:
|
||||
ValueError: 如果节点数据不完整
|
||||
RuntimeError: 如果融合操作失败
|
||||
"""
|
||||
# 验证输入数据
|
||||
required_statement_keys = [
|
||||
'statement_id', 'statement_text',
|
||||
'statement_activation', 'statement_importance'
|
||||
]
|
||||
required_entity_keys = [
|
||||
'entity_id', 'entity_name', 'entity_type',
|
||||
'entity_activation', 'entity_importance'
|
||||
]
|
||||
|
||||
for key in required_statement_keys:
|
||||
if key not in statement_node:
|
||||
raise ValueError(f"Statement 节点缺少必需字段: {key}")
|
||||
|
||||
for key in required_entity_keys:
|
||||
if key not in entity_node:
|
||||
raise ValueError(f"Entity 节点缺少必需字段: {key}")
|
||||
|
||||
# 验证实体类型:不允许融合 Person 类型的实体
|
||||
if entity_node.get('entity_type') == 'Person':
|
||||
raise ValueError(
|
||||
f"不允许融合 Person 类型的实体: entity_id={entity_node.get('entity_id')}, "
|
||||
f"entity_name={entity_node.get('entity_name')}"
|
||||
)
|
||||
|
||||
# 提取节点信息
|
||||
statement_id = statement_node['statement_id']
|
||||
statement_text = statement_node['statement_text']
|
||||
statement_activation = statement_node['statement_activation']
|
||||
statement_importance = statement_node['statement_importance']
|
||||
|
||||
entity_id = entity_node['entity_id']
|
||||
entity_name = entity_node['entity_name']
|
||||
entity_type = entity_node['entity_type']
|
||||
entity_activation = entity_node['entity_activation']
|
||||
entity_importance = entity_node['entity_importance']
|
||||
|
||||
# 生成摘要内容
|
||||
summary_text = await self._generate_summary(
|
||||
statement_text=statement_text,
|
||||
entity_name=entity_name,
|
||||
entity_type=entity_type,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 计算继承的激活值和重要性(取较高值)
|
||||
inherited_activation = max(statement_activation, entity_activation)
|
||||
inherited_importance = max(statement_importance, entity_importance)
|
||||
|
||||
# 创建 MemorySummary 节点
|
||||
current_time = datetime.now()
|
||||
current_time_iso = current_time.isoformat()
|
||||
|
||||
# 生成新的 MemorySummary ID
|
||||
import uuid
|
||||
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
# 获取 group_id(从 statement 或 entity 节点)
|
||||
group_id = statement_node.get('group_id') or entity_node.get('group_id')
|
||||
|
||||
# 使用事务创建 MemorySummary 并删除原节点
|
||||
async def merge_transaction(tx, **params):
|
||||
"""事务函数:创建摘要节点并删除原节点"""
|
||||
query = """
|
||||
// 首先检查节点是否存在
|
||||
OPTIONAL MATCH (s:Statement {id: $statement_id})
|
||||
OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id})
|
||||
|
||||
// 如果任一节点不存在,直接返回 null(不执行后续操作)
|
||||
WITH s, e
|
||||
WHERE s IS NOT NULL AND e IS NOT NULL
|
||||
|
||||
// 创建 MemorySummary 节点
|
||||
CREATE (ms:MemorySummary {
|
||||
id: $summary_id,
|
||||
summary: $summary_text,
|
||||
original_statement_id: $statement_id,
|
||||
original_entity_id: $entity_id,
|
||||
activation_value: $inherited_activation,
|
||||
importance_score: $inherited_importance,
|
||||
access_history: [$current_time],
|
||||
last_access_time: $current_time,
|
||||
access_count: 1,
|
||||
version: 1,
|
||||
group_id: $group_id,
|
||||
created_at: datetime($current_time),
|
||||
merged_at: datetime($current_time)
|
||||
})
|
||||
|
||||
// 转移 Statement 的出边到 MemorySummary(只转移目标节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (s)-[r_out]->(target)
|
||||
WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_out),
|
||||
new_rel.original_relationship_type = type(r_out),
|
||||
new_rel.merged_from_statement = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Statement 的入边到 MemorySummary(只转移源节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (source)-[r_in]->(s)
|
||||
WHERE r_in IS NOT NULL AND source IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_in),
|
||||
new_rel.original_relationship_type = type(r_in),
|
||||
new_rel.merged_from_statement = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Entity 的出边到 MemorySummary(只转移目标节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (e)-[r_out]->(target)
|
||||
WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_out),
|
||||
new_rel.original_relationship_type = type(r_out),
|
||||
new_rel.merged_from_entity = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 转移 Entity 的入边到 MemorySummary(只转移源节点仍存在的边)
|
||||
WITH ms, s, e
|
||||
CALL (ms, s, e) {
|
||||
OPTIONAL MATCH (source)-[r_in]->(e)
|
||||
WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL
|
||||
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
|
||||
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
|
||||
ON CREATE SET
|
||||
new_rel = properties(r_in),
|
||||
new_rel.original_relationship_type = type(r_in),
|
||||
new_rel.merged_from_entity = true,
|
||||
new_rel.merge_count = 1
|
||||
ON MATCH SET
|
||||
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
|
||||
)
|
||||
}
|
||||
|
||||
// 删除原始节点
|
||||
WITH ms, s, e
|
||||
DETACH DELETE s, e
|
||||
|
||||
RETURN ms.id as summary_id
|
||||
"""
|
||||
|
||||
result = await tx.run(query, **params)
|
||||
record = await result.single()
|
||||
|
||||
if not record:
|
||||
raise RuntimeError("Failed to create MemorySummary node - nodes may not exist")
|
||||
|
||||
return record['summary_id']
|
||||
|
||||
params = {
|
||||
'summary_id': summary_id,
|
||||
'summary_text': summary_text,
|
||||
'statement_id': statement_id,
|
||||
'entity_id': entity_id,
|
||||
'inherited_activation': inherited_activation,
|
||||
'inherited_importance': inherited_importance,
|
||||
'current_time': current_time_iso,
|
||||
'group_id': group_id
|
||||
}
|
||||
|
||||
try:
|
||||
created_summary_id = await self.connector.execute_write_transaction(
|
||||
merge_transaction,
|
||||
**params
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] "
|
||||
f"-> MemorySummary[{created_summary_id}], "
|
||||
f"activation={inherited_activation:.4f}, "
|
||||
f"importance={inherited_importance:.4f}"
|
||||
)
|
||||
|
||||
return created_summary_id
|
||||
|
||||
except Exception as e:
|
||||
# 记录详细的错误信息,包括异常类型和堆栈
|
||||
import traceback
|
||||
error_details = traceback.format_exc()
|
||||
logger.error(
|
||||
f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], "
|
||||
f"错误类型: {type(e).__name__}, "
|
||||
f"错误信息: {str(e)}, "
|
||||
f"详细堆栈:\n{error_details}"
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"融合节点失败: {str(e)}"
|
||||
) from e
|
||||
|
||||
# ==================== 私有辅助方法 ====================
|
||||
|
||||
async def _generate_summary(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
config_id: Optional[int] = None,
|
||||
db = None
|
||||
) -> str:
|
||||
"""
|
||||
生成摘要内容
|
||||
|
||||
优先使用 LLM 生成高质量摘要,如果 LLM 不可用或失败,
|
||||
则降级到简单文本拼接。
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
config_id: 配置ID(可选,用于获取 llm_id)
|
||||
db: 数据库会话(可选,用于获取 llm_id)
|
||||
|
||||
Returns:
|
||||
str: 生成的摘要文本(最多 200 个字符)
|
||||
"""
|
||||
# 如果配置禁用 LLM 摘要,直接使用简单拼接
|
||||
if not self.enable_llm_summary:
|
||||
logger.info("LLM 摘要生成已禁用,使用简单拼接")
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
# 尝试获取 LLM 客户端
|
||||
llm_client = None
|
||||
if config_id is not None and db is not None:
|
||||
try:
|
||||
llm_client = await self._get_llm_client(db, config_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
|
||||
|
||||
# 如果没有 LLM 客户端,直接使用简单拼接
|
||||
if llm_client is None:
|
||||
logger.info("未能获取 LLM 客户端,使用简单拼接")
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
# 尝试使用 LLM 生成摘要
|
||||
try:
|
||||
summary = await self._generate_llm_summary(
|
||||
statement_text=statement_text,
|
||||
entity_name=entity_name,
|
||||
entity_type=entity_type,
|
||||
llm_client=llm_client
|
||||
)
|
||||
|
||||
# 限制长度为 200 个字符
|
||||
if len(summary) > 200:
|
||||
summary = f"{summary[:197]}..."
|
||||
|
||||
logger.info(f"使用 LLM 生成摘要: {summary}")
|
||||
return summary
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"LLM 摘要生成失败,降级到简单拼接: {str(e)}"
|
||||
)
|
||||
return self._simple_concatenation(
|
||||
statement_text, entity_name, entity_type
|
||||
)
|
||||
|
||||
async def _get_llm_client(self, db, config_id: int):
|
||||
"""
|
||||
从数据库获取 LLM 客户端
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
LLM 客户端实例,如果无法获取则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# 从数据库读取配置
|
||||
repository = DataConfigRepository()
|
||||
db_config = repository.get_by_id(db, config_id)
|
||||
|
||||
if db_config is None or db_config.llm_id is None:
|
||||
logger.warning(f"配置 {config_id} 不存在或未设置 llm_id")
|
||||
return None
|
||||
|
||||
# 创建 LLM 客户端
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(str(db_config.llm_id))
|
||||
|
||||
logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={db_config.llm_id}")
|
||||
return llm_client
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取 LLM 客户端失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _generate_llm_summary(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str,
|
||||
llm_client
|
||||
) -> str:
|
||||
"""
|
||||
使用 LLM 生成高质量摘要
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
llm_client: LLM 客户端实例
|
||||
|
||||
Returns:
|
||||
str: LLM 生成的摘要文本
|
||||
|
||||
Raises:
|
||||
Exception: 如果 LLM 调用失败
|
||||
"""
|
||||
# 构建提示词
|
||||
prompt = f"""请为以下记忆片段生成一个简洁的摘要(不超过200个字符):
|
||||
|
||||
实体名称: {entity_name}
|
||||
实体类型: {entity_type}
|
||||
陈述内容: {statement_text}
|
||||
|
||||
要求:
|
||||
1. 摘要应该保留核心语义信息
|
||||
2. 长度不超过200个字符
|
||||
3. 使用简洁、自然的中文表达
|
||||
4. 只返回摘要文本,不要包含其他内容
|
||||
|
||||
摘要:"""
|
||||
|
||||
# 调用 LLM(直接传递 prompt 字符串)
|
||||
response = await llm_client.chat(prompt)
|
||||
|
||||
# 提取摘要文本
|
||||
if isinstance(response, str):
|
||||
summary = response.strip()
|
||||
elif hasattr(response, 'content'):
|
||||
summary = response.content.strip()
|
||||
else:
|
||||
summary = str(response).strip()
|
||||
|
||||
return summary
|
||||
|
||||
def _simple_concatenation(
|
||||
self,
|
||||
statement_text: str,
|
||||
entity_name: str,
|
||||
entity_type: str
|
||||
) -> str:
|
||||
"""
|
||||
简单文本拼接生成摘要
|
||||
|
||||
降级策略:当 LLM 不可用时使用。
|
||||
格式:[实体类型]实体名称: 陈述内容
|
||||
|
||||
Args:
|
||||
statement_text: Statement 文本内容
|
||||
entity_name: Entity 名称
|
||||
entity_type: Entity 类型
|
||||
|
||||
Returns:
|
||||
str: 拼接的摘要文本(最多 200 个字符)
|
||||
"""
|
||||
# 构建简单摘要
|
||||
summary = f"[{entity_type}]{entity_name}: {statement_text}"
|
||||
|
||||
# 限制长度为 200 个字符(注意:这里的长度是字符数,不是字节数)
|
||||
if len(summary) > 200:
|
||||
# 截断并添加省略号
|
||||
summary = f"{summary[:197]}..."
|
||||
|
||||
return summary
|
||||
|
||||
@@ -65,6 +65,15 @@ class DataConfig(Base):
|
||||
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数")
|
||||
offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数")
|
||||
|
||||
# ACT-R 遗忘引擎配置
|
||||
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5")
|
||||
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3")
|
||||
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24")
|
||||
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要,默认True")
|
||||
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100")
|
||||
max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100")
|
||||
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30")
|
||||
|
||||
# 情绪引擎配置
|
||||
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
|
||||
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")
|
||||
|
||||
@@ -106,7 +106,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
|
||||
"emotion_intensity": statement.emotion_intensity,
|
||||
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
|
||||
"emotion_subject": statement.emotion_subject,
|
||||
"emotion_target": statement.emotion_target
|
||||
"emotion_target": statement.emotion_target,
|
||||
# 添加 ACT-R 记忆激活属性
|
||||
"importance_score": statement.importance_score,
|
||||
"activation_value": statement.activation_value,
|
||||
"access_history": statement.access_history if statement.access_history else [],
|
||||
"last_access_time": statement.last_access_time,
|
||||
"access_count": statement.access_count
|
||||
}
|
||||
flattened_statements.append(flattened_statement)
|
||||
|
||||
|
||||
@@ -38,7 +38,12 @@ SET s += {
|
||||
valid_at: statement.valid_at,
|
||||
invalid_at: statement.invalid_at,
|
||||
statement_embedding: statement.statement_embedding,
|
||||
relevence_info: statement.relevence_info
|
||||
relevence_info: statement.relevence_info,
|
||||
importance_score: statement.importance_score,
|
||||
activation_value: statement.activation_value,
|
||||
access_history: statement.access_history,
|
||||
last_access_time: statement.last_access_time,
|
||||
access_count: statement.access_count
|
||||
}
|
||||
RETURN s.id AS uuid
|
||||
"""
|
||||
@@ -111,7 +116,12 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
|
||||
WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
|
||||
ELSE e.connect_strength
|
||||
END
|
||||
END
|
||||
END,
|
||||
e.importance_score = CASE WHEN entity.importance_score IS NOT NULL THEN entity.importance_score ELSE coalesce(e.importance_score, 0.5) END,
|
||||
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END,
|
||||
e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END,
|
||||
e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END,
|
||||
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END
|
||||
RETURN e.id AS uuid
|
||||
"""
|
||||
|
||||
@@ -225,6 +235,10 @@ RETURN e.id AS id,
|
||||
e.name AS name,
|
||||
e.group_id AS group_id,
|
||||
e.entity_type AS entity_type,
|
||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||
e.last_access_time AS last_access_time,
|
||||
COALESCE(e.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -243,6 +257,10 @@ RETURN s.id AS id,
|
||||
s.expired_at AS expired_at,
|
||||
s.valid_at AS valid_at,
|
||||
s.invalid_at AS invalid_at,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -258,6 +276,9 @@ RETURN c.id AS chunk_id,
|
||||
c.group_id AS group_id,
|
||||
c.content AS content,
|
||||
c.dialog_id AS dialog_id,
|
||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||
c.last_access_time AS last_access_time,
|
||||
COALESCE(c.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -278,6 +299,10 @@ RETURN s.id AS id,
|
||||
s.invalid_at AS invalid_at,
|
||||
c.id AS chunk_id_from_rel,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -305,6 +330,10 @@ RETURN e.id AS id,
|
||||
e.connect_strength AS connect_strength,
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
collect(DISTINCT c.id) AS chunk_ids,
|
||||
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(e.importance_score, 0.5) AS importance_score,
|
||||
e.last_access_time AS last_access_time,
|
||||
COALESCE(e.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -322,6 +351,9 @@ RETURN c.id AS chunk_id,
|
||||
c.sequence_number AS sequence_number,
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(c.activation_value, 0.5) AS activation_value,
|
||||
c.last_access_time AS last_access_time,
|
||||
COALESCE(c.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -419,7 +451,11 @@ RETURN s.id AS id,
|
||||
s.created_at AS created_at,
|
||||
s.valid_at AS valid_at,
|
||||
s.invalid_at AS invalid_at,
|
||||
collect(DISTINCT s.id) AS statement_ids
|
||||
collect(DISTINCT s.id) AS statement_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count
|
||||
ORDER BY datetime(s.created_at) DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
@@ -446,6 +482,10 @@ RETURN s.id AS id,
|
||||
s.invalid_at AS invalid_at,
|
||||
c.id AS chunk_id_from_rel,
|
||||
collect(DISTINCT e.id) AS entity_ids,
|
||||
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||
s.last_access_time AS last_access_time,
|
||||
COALESCE(s.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY s.created_at DESC, score DESC
|
||||
LIMIT $limit
|
||||
@@ -635,6 +675,10 @@ RETURN m.id AS id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
m.created_at AS created_at,
|
||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||
m.last_access_time AS last_access_time,
|
||||
COALESCE(m.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
@@ -653,6 +697,10 @@ RETURN m.id AS id,
|
||||
m.chunk_ids AS chunk_ids,
|
||||
m.content AS content,
|
||||
m.created_at AS created_at,
|
||||
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
|
||||
COALESCE(m.importance_score, 0.5) AS importance_score,
|
||||
m.last_access_time AS last_access_time,
|
||||
COALESCE(m.access_count, 0) AS access_count,
|
||||
score
|
||||
ORDER BY score DESC
|
||||
LIMIT $limit
|
||||
|
||||
@@ -55,6 +55,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
|
||||
if 'aliases' not in n or n['aliases'] is None:
|
||||
n['aliases'] = []
|
||||
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
return ExtractedEntityNode(**n)
|
||||
|
||||
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
# 使用新的仓储层
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -24,6 +25,157 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
SEARCH_STATEMENTS_L_VALID_AT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _update_activation_values_batch(
|
||||
connector: Neo4jConnector,
|
||||
nodes: List[Dict[str, Any]],
|
||||
node_label: str,
|
||||
group_id: Optional[str] = None,
|
||||
max_retries: int = 3
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
批量更新节点的激活值
|
||||
|
||||
为提高性能,批量更新多个节点的访问历史和激活值。
|
||||
使用重试机制处理更新失败的情况。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
nodes: 节点列表,每个节点必须包含 'id' 字段
|
||||
node_label: 节点标签(Statement, ExtractedEntity, MemorySummary)
|
||||
group_id: 组ID(可选)
|
||||
max_retries: 最大重试次数
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 成功更新的节点列表
|
||||
"""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
|
||||
# 创建计算器和管理器实例
|
||||
actr_calculator = ACTRCalculator()
|
||||
access_manager = AccessHistoryManager(
|
||||
connector=connector,
|
||||
actr_calculator=actr_calculator,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# 提取节点ID列表
|
||||
node_ids = [node.get('id') for node in nodes if node.get('id')]
|
||||
|
||||
if not node_ids:
|
||||
logger.warning(f"批量更新激活值:没有有效的节点ID")
|
||||
return nodes
|
||||
|
||||
# 批量记录访问
|
||||
try:
|
||||
updated_nodes = await access_manager.record_batch_access(
|
||||
node_ids=node_ids,
|
||||
node_label=node_label,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"批量更新激活值成功: {node_label}, "
|
||||
f"更新数量={len(updated_nodes)}/{len(node_ids)}"
|
||||
)
|
||||
|
||||
return updated_nodes
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"批量更新激活值失败: {node_label}, 错误: {str(e)}"
|
||||
)
|
||||
# 失败时返回原始节点列表
|
||||
return nodes
|
||||
|
||||
|
||||
async def _update_search_results_activation(
|
||||
connector: Neo4jConnector,
|
||||
results: Dict[str, List[Dict[str, Any]]],
|
||||
group_id: Optional[str] = None
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
更新搜索结果中所有知识节点的激活值
|
||||
|
||||
对 Statement、ExtractedEntity、MemorySummary 节点进行批量激活值更新。
|
||||
ChunkNode 和 DialogueNode 不参与激活值更新(数据层隔离)。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器
|
||||
results: 搜索结果字典,包含不同类型节点的列表
|
||||
group_id: 组ID(可选)
|
||||
|
||||
Returns:
|
||||
Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果
|
||||
"""
|
||||
# 定义需要更新激活值的节点类型
|
||||
knowledge_node_types = {
|
||||
'statements': 'Statement',
|
||||
'entities': 'ExtractedEntity',
|
||||
'summaries': 'MemorySummary'
|
||||
}
|
||||
|
||||
# 并行更新所有类型的节点
|
||||
update_tasks = []
|
||||
update_keys = []
|
||||
|
||||
for key, label in knowledge_node_types.items():
|
||||
if key in results and results[key]:
|
||||
update_tasks.append(
|
||||
_update_activation_values_batch(
|
||||
connector=connector,
|
||||
nodes=results[key],
|
||||
node_label=label,
|
||||
group_id=group_id
|
||||
)
|
||||
)
|
||||
update_keys.append(key)
|
||||
|
||||
if not update_tasks:
|
||||
return results
|
||||
|
||||
# 并行执行所有更新
|
||||
update_results = await asyncio.gather(*update_tasks, return_exceptions=True)
|
||||
|
||||
# 更新结果字典,保留原始搜索分数
|
||||
updated_results = results.copy()
|
||||
for key, update_result in zip(update_keys, update_results):
|
||||
if not isinstance(update_result, Exception):
|
||||
# 更新成功,合并原始搜索结果和更新后的激活值数据
|
||||
# 保留原始的 score 字段(BM25/Embedding 分数)
|
||||
original_nodes = results[key]
|
||||
updated_nodes = update_result
|
||||
|
||||
# 创建 ID 到原始节点的映射(用于快速查找 score)
|
||||
original_map = {node.get('id'): node for node in original_nodes if node.get('id')}
|
||||
|
||||
# 合并数据:激活值来自更新结果,score 来自原始结果
|
||||
merged_nodes = []
|
||||
for updated_node in updated_nodes:
|
||||
node_id = updated_node.get('id')
|
||||
if node_id and node_id in original_map:
|
||||
# 保留原始的 score 字段
|
||||
original_score = original_map[node_id].get('score')
|
||||
if original_score is not None:
|
||||
updated_node['score'] = original_score
|
||||
merged_nodes.append(updated_node)
|
||||
|
||||
updated_results[key] = merged_nodes
|
||||
else:
|
||||
# 更新失败,记录错误但保留原始结果
|
||||
logger.warning(
|
||||
f"更新 {key} 激活值失败: {str(update_result)}"
|
||||
)
|
||||
|
||||
return updated_results
|
||||
|
||||
|
||||
async def search_graph(
|
||||
connector: Neo4jConnector,
|
||||
@@ -36,6 +188,7 @@ async def search_graph(
|
||||
Search across Statements, Entities, Chunks, and Summaries using a free-text query.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Statements: matches s.statement CONTAINS q
|
||||
- Entities: matches e.name CONTAINS q
|
||||
@@ -50,7 +203,7 @@ async def search_graph(
|
||||
include: List of categories to search (default: all)
|
||||
|
||||
Returns:
|
||||
Dictionary with search results per category
|
||||
Dictionary with search results per category (with updated activation values)
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
@@ -106,6 +259,13 @@ async def search_graph(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -121,6 +281,7 @@ async def search_graph_by_embedding(
|
||||
Embedding-based semantic search across Statements, Chunks, and Entities.
|
||||
|
||||
OPTIMIZED: Runs all queries in parallel using asyncio.gather()
|
||||
INTEGRATED: Updates activation values for knowledge nodes before returning results
|
||||
|
||||
- Computes query embedding with the provided embedder_client
|
||||
- Ranks by cosine similarity in Cypher
|
||||
@@ -203,6 +364,16 @@ async def search_graph_by_embedding(
|
||||
else:
|
||||
results[key] = result
|
||||
|
||||
# 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary)
|
||||
update_start = time.time()
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
update_time = time.time() - update_start
|
||||
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
|
||||
|
||||
return results
|
||||
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
|
||||
connector: Neo4jConnector,
|
||||
@@ -304,6 +475,8 @@ async def search_graph_by_keyword_temporal(
|
||||
) -> Dict[str, List[Any]]:
|
||||
"""
|
||||
Temporal keyword search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements containing query_text created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -326,7 +499,15 @@ async def search_graph_by_keyword_temporal(
|
||||
)
|
||||
print(f"查询结果为:\n{statements}")
|
||||
|
||||
return {"statements": statements}
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_temporal(
|
||||
@@ -342,6 +523,8 @@ async def search_graph_by_temporal(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created between start_date and end_date
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -362,7 +545,16 @@ async def search_graph_by_temporal(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
async def search_graph_by_dialog_id(
|
||||
@@ -419,6 +611,8 @@ async def search_graph_by_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -436,7 +630,16 @@ async def search_graph_by_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_by_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -448,6 +651,8 @@ async def search_graph_by_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -465,7 +670,16 @@ async def search_graph_by_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -477,6 +691,8 @@ async def search_graph_g_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -494,7 +710,16 @@ async def search_graph_g_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_g_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -506,6 +731,8 @@ async def search_graph_g_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -523,7 +750,16 @@ async def search_graph_g_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_created_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -535,6 +771,8 @@ async def search_graph_l_created_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements created at created_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -552,7 +790,16 @@ async def search_graph_l_created_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
async def search_graph_l_valid_at(
|
||||
connector: Neo4jConnector,
|
||||
@@ -564,6 +811,8 @@ async def search_graph_l_valid_at(
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""
|
||||
Temporal search across Statements.
|
||||
|
||||
INTEGRATED: Updates activation values for Statement nodes before returning results
|
||||
|
||||
- Matches statements valid at valid_at
|
||||
- Optionally filters by group_id, apply_id, user_id
|
||||
@@ -581,4 +830,13 @@ async def search_graph_l_valid_at(
|
||||
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
|
||||
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
|
||||
print(f"查询结果为:\n{statements}")
|
||||
return {"statements": statements}
|
||||
|
||||
# 更新 Statement 节点的激活值
|
||||
results = {"statements": statements}
|
||||
results = await _update_search_results_activation(
|
||||
connector=connector,
|
||||
results=results,
|
||||
group_id=group_id
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@@ -8,7 +8,6 @@ Classes:
|
||||
Neo4jConnector: Neo4j数据库连接器,提供异步查询接口
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, List, Dict
|
||||
|
||||
from neo4j import AsyncGraphDatabase, basic_auth
|
||||
@@ -85,6 +84,63 @@ class Neo4jConnector:
|
||||
records, summary, keys = result
|
||||
return [record.data() for record in records]
|
||||
|
||||
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在写事务中执行操作
|
||||
|
||||
提供显式事务支持,确保操作的原子性。
|
||||
如果事务函数抛出异常,所有更改将自动回滚。
|
||||
|
||||
Args:
|
||||
transaction_func: 事务函数,接收 tx 参数并执行查询
|
||||
**kwargs: 传递给事务函数的额外参数
|
||||
|
||||
Returns:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def create_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "CREATE (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_write_transaction(
|
||||
... create_node, name="Alice"
|
||||
... )
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_write(transaction_func, **kwargs)
|
||||
|
||||
async def execute_read_transaction(self, transaction_func, **kwargs: Any) -> Any:
|
||||
"""在读事务中执行操作
|
||||
|
||||
提供显式事务支持用于读操作。
|
||||
|
||||
Args:
|
||||
transaction_func: 事务函数,接收 tx 参数并执行查询
|
||||
**kwargs: 传递给事务函数的额外参数
|
||||
|
||||
Returns:
|
||||
Any: 事务函数的返回值
|
||||
|
||||
Example:
|
||||
>>> async def get_node(tx, name):
|
||||
... result = await tx.run(
|
||||
... "MATCH (n:Person {name: $name}) RETURN n",
|
||||
... name=name
|
||||
... )
|
||||
... return await result.single()
|
||||
>>>
|
||||
>>> connector = Neo4jConnector()
|
||||
>>> result = await connector.execute_read_transaction(
|
||||
... get_node, name="Alice"
|
||||
... )
|
||||
"""
|
||||
async with self.driver.session(database="neo4j") as session:
|
||||
return await session.execute_read(transaction_func, **kwargs)
|
||||
|
||||
async def delete_group(self, group_id: str):
|
||||
"""删除指定组的所有数据
|
||||
|
||||
|
||||
@@ -75,6 +75,13 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
|
||||
n['emotion_subject'] = n.get('emotion_subject')
|
||||
n['emotion_target'] = n.get('emotion_target')
|
||||
|
||||
# 处理 ACT-R 属性 - 确保字段存在且有默认值
|
||||
n['importance_score'] = n.get('importance_score', 0.5)
|
||||
n['activation_value'] = n.get('activation_value')
|
||||
n['access_history'] = n.get('access_history', [])
|
||||
n['last_access_time'] = n.get('last_access_time')
|
||||
n['access_count'] = n.get('access_count', 0)
|
||||
|
||||
return StatementNode(**n)
|
||||
|
||||
async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:
|
||||
|
||||
@@ -399,3 +399,104 @@ class GenerateCacheRequest(BaseModel):
|
||||
None,
|
||||
description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 遗忘引擎相关 Schema
|
||||
# ============================================================================
|
||||
|
||||
class ForgettingTriggerRequest(BaseModel):
|
||||
"""手动触发遗忘周期请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
group_id: Optional[str] = Field(None, description="组ID(可选,用于过滤特定组的节点)")
|
||||
max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
|
||||
min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(可选,用于指定遗忘引擎配置)") # TODO 后续group_id更换成enduser_id,自动与config_id关联 ,要删除此行
|
||||
|
||||
|
||||
class ForgettingConfigResponse(BaseModel):
|
||||
"""遗忘引擎配置响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
decay_constant: float = Field(..., description="衰减常数 d")
|
||||
lambda_time: float = Field(..., description="时间衰减参数")
|
||||
lambda_mem: float = Field(..., description="记忆衰减参数")
|
||||
forgetting_rate: float = Field(..., description="遗忘速率(根据 lambda_time / lambda_mem 计算得出)")
|
||||
offset: float = Field(..., description="偏移量")
|
||||
max_history_length: int = Field(..., description="访问历史最大长度")
|
||||
forgetting_threshold: float = Field(..., description="遗忘阈值")
|
||||
min_days_since_access: int = Field(..., description="最小未访问天数")
|
||||
enable_llm_summary: bool = Field(..., description="是否使用 LLM 生成摘要")
|
||||
max_merge_batch_size: int = Field(..., description="单次最大融合节点对数")
|
||||
forgetting_interval_hours: int = Field(..., description="遗忘周期间隔(小时)")
|
||||
|
||||
|
||||
class ForgettingConfigUpdateRequest(BaseModel):
|
||||
"""遗忘引擎配置更新请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
config_id: int = Field(..., description="配置ID")
|
||||
decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
|
||||
lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
|
||||
lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
|
||||
offset: Optional[float] = Field(None, ge=0.0, le=1.0, description="偏移量")
|
||||
max_history_length: Optional[int] = Field(None, ge=10, le=1000, description="访问历史最大长度")
|
||||
forgetting_threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="遗忘阈值")
|
||||
min_days_since_access: Optional[int] = Field(None, ge=1, le=365, description="最小未访问天数")
|
||||
enable_llm_summary: Optional[bool] = Field(None, description="是否使用 LLM 生成摘要")
|
||||
max_merge_batch_size: Optional[int] = Field(None, ge=1, le=1000, description="单次最大融合节点对数")
|
||||
forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="遗忘周期间隔(小时)")
|
||||
|
||||
|
||||
class ForgettingStatsResponse(BaseModel):
|
||||
"""遗忘引擎统计信息响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
consistency_check: Optional[Dict[str, Any]] = Field(None, description="数据一致性检查结果")
|
||||
nodes_merged_total: int = Field(..., description="累计融合节点对数")
|
||||
recent_cycles: List[Dict[str, Any]] = Field(..., description="最近的遗忘周期记录")
|
||||
timestamp: str = Field(..., description="统计时间(ISO格式)")
|
||||
|
||||
|
||||
class ForgettingReportResponse(BaseModel):
|
||||
"""遗忘周期报告响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
merged_count: int = Field(..., description="融合的节点对数量")
|
||||
nodes_before: int = Field(..., description="遗忘前的节点总数")
|
||||
nodes_after: int = Field(..., description="遗忘后的节点总数")
|
||||
reduction_rate: float = Field(..., description="节点减少率(0-1)")
|
||||
duration_seconds: float = Field(..., description="执行耗时(秒)")
|
||||
start_time: str = Field(..., description="开始时间(ISO格式)")
|
||||
end_time: str = Field(..., description="结束时间(ISO格式)")
|
||||
failed_count: int = Field(..., description="失败的融合数量")
|
||||
success_rate: float = Field(..., description="成功率(0-1)")
|
||||
|
||||
|
||||
class ForgettingCurvePoint(BaseModel):
|
||||
"""遗忘曲线数据点模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
day: int = Field(..., description="天数")
|
||||
activation: float = Field(..., description="激活值")
|
||||
retention_rate: float = Field(..., description="保持率(与激活值相同)")
|
||||
|
||||
|
||||
class ForgettingCurveRequest(BaseModel):
|
||||
"""遗忘曲线请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
|
||||
days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
|
||||
config_id: Optional[int] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")
|
||||
|
||||
|
||||
class ForgettingCurveResponse(BaseModel):
|
||||
"""遗忘曲线响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表")
|
||||
config: Dict[str, Any] = Field(..., description="使用的配置参数")
|
||||
|
||||
460
api/app/services/memory_forget_service.py
Normal file
460
api/app/services/memory_forget_service.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
遗忘引擎服务层模块
|
||||
|
||||
本模块提供遗忘引擎的业务逻辑实现,包括:
|
||||
1. 遗忘周期执行
|
||||
2. 配置管理
|
||||
3. 统计信息查询
|
||||
4. 遗忘曲线生成
|
||||
|
||||
所有业务逻辑从控制器层分离到此服务层。
|
||||
"""
|
||||
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import ForgettingScheduler
|
||||
from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
load_actr_config_from_db,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
class MemoryForgetService:
|
||||
"""遗忘引擎服务类"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.config_repository = DataConfigRepository()
|
||||
|
||||
def _get_neo4j_connector(self) -> Neo4jConnector:
|
||||
"""
|
||||
获取 Neo4j 连接器实例
|
||||
|
||||
Returns:
|
||||
Neo4jConnector: Neo4j 连接器实例
|
||||
"""
|
||||
# 这里应该从配置或依赖注入获取连接器
|
||||
# 暂时创建新实例(实际应该使用单例或连接池)
|
||||
return Neo4jConnector()
|
||||
|
||||
async def _get_forgetting_components(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
|
||||
"""
|
||||
获取遗忘引擎组件(计算器、策略、调度器)
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID(可选)
|
||||
|
||||
Returns:
|
||||
tuple: (actr_calculator, forgetting_strategy, forgetting_scheduler, config)
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
# 创建 ACT-R 计算器
|
||||
actr_calculator = ACTRCalculator(
|
||||
decay_constant=config['decay_constant'],
|
||||
forgetting_rate=config['forgetting_rate'],
|
||||
offset=config['offset'],
|
||||
max_history_length=config['max_history_length']
|
||||
)
|
||||
|
||||
# 获取 Neo4j 连接器
|
||||
connector = self._get_neo4j_connector()
|
||||
|
||||
# 创建遗忘策略执行器
|
||||
forgetting_strategy = ForgettingStrategy(
|
||||
connector=connector,
|
||||
actr_calculator=actr_calculator,
|
||||
forgetting_threshold=config['forgetting_threshold'],
|
||||
enable_llm_summary=config['enable_llm_summary']
|
||||
)
|
||||
|
||||
# 创建遗忘调度器
|
||||
forgetting_scheduler = ForgettingScheduler(
|
||||
forgetting_strategy=forgetting_strategy,
|
||||
connector=connector
|
||||
)
|
||||
|
||||
return actr_calculator, forgetting_strategy, forgetting_scheduler, config
|
||||
|
||||
async def _get_knowledge_stats(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
forgetting_threshold: float = 0.3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取知识层统计信息
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
group_id: 组ID(可选)
|
||||
forgetting_threshold: 遗忘阈值
|
||||
|
||||
Returns:
|
||||
dict: 统计信息字典
|
||||
"""
|
||||
# 构建查询
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
|
||||
query += """
|
||||
WITH n,
|
||||
CASE
|
||||
WHEN n:Statement THEN 'statement'
|
||||
WHEN n:ExtractedEntity THEN 'entity'
|
||||
WHEN n:MemorySummary THEN 'summary'
|
||||
END as node_type
|
||||
RETURN
|
||||
count(n) as total_nodes,
|
||||
sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
|
||||
sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
|
||||
sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
|
||||
avg(n.activation_value) as average_activation,
|
||||
sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
if results:
|
||||
result = results[0]
|
||||
return {
|
||||
'total_nodes': result['total_nodes'] or 0,
|
||||
'statement_count': result['statement_count'] or 0,
|
||||
'entity_count': result['entity_count'] or 0,
|
||||
'summary_count': result['summary_count'] or 0,
|
||||
'average_activation': result['average_activation'],
|
||||
'low_activation_nodes': result['low_activation_nodes'] or 0
|
||||
}
|
||||
|
||||
return {
|
||||
'total_nodes': 0,
|
||||
'statement_count': 0,
|
||||
'entity_count': 0,
|
||||
'summary_count': 0,
|
||||
'average_activation': None,
|
||||
'low_activation_nodes': 0
|
||||
}
|
||||
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: Optional[str] = None,
|
||||
max_merge_batch_size: Optional[int] = None,
|
||||
min_days_since_access: Optional[int] = None,
|
||||
config_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
|
||||
执行一次完整的遗忘周期,识别并融合低激活值节点。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(可选)
|
||||
max_merge_batch_size: 最大融合批次大小(可选)
|
||||
min_days_since_access: 最小未访问天数(可选)
|
||||
config_id: 配置ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 遗忘报告
|
||||
"""
|
||||
# 获取遗忘引擎组件
|
||||
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
# 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取)
|
||||
report = await forgetting_scheduler.run_forgetting_cycle(
|
||||
group_id=group_id,
|
||||
max_merge_batch_size=max_merge_batch_size,
|
||||
min_days_since_access=min_days_since_access,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"遗忘周期完成: 融合 {report['merged_count']} 对节点, "
|
||||
f"失败 {report['failed_count']} 对, "
|
||||
f"耗时 {report['duration_seconds']:.2f} 秒"
|
||||
)
|
||||
|
||||
return report
|
||||
|
||||
def read_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
|
||||
读取指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
|
||||
Returns:
|
||||
dict: 配置信息字典
|
||||
"""
|
||||
# 加载配置
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
|
||||
# 添加 config_id 到返回结果
|
||||
config['config_id'] = config_id
|
||||
|
||||
api_logger.info(f"成功读取遗忘引擎配置: config_id={config_id}")
|
||||
|
||||
return config
|
||||
|
||||
def update_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int,
|
||||
update_fields: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
更新遗忘引擎配置
|
||||
|
||||
更新指定配置ID的遗忘引擎参数。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
config_id: 配置ID
|
||||
update_fields: 要更新的字段字典
|
||||
|
||||
Returns:
|
||||
dict: 更新后的配置信息
|
||||
|
||||
Raises:
|
||||
ValueError: 配置不存在
|
||||
"""
|
||||
# 检查配置是否存在
|
||||
db_config = self.config_repository.get_by_id(db, config_id)
|
||||
if db_config is None:
|
||||
raise ValueError(f"配置不存在: {config_id}")
|
||||
|
||||
# 执行更新
|
||||
if update_fields:
|
||||
for key, value in update_fields.items():
|
||||
if hasattr(db_config, key):
|
||||
setattr(db_config, key, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_config)
|
||||
|
||||
api_logger.info(
|
||||
f"成功更新遗忘引擎配置: config_id={config_id}, "
|
||||
f"更新字段: {list(update_fields.keys())}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"没有字段需要更新: config_id={config_id}")
|
||||
|
||||
# 重新加载配置并返回
|
||||
config = load_actr_config_from_db(db, config_id)
|
||||
config['config_id'] = config_id
|
||||
|
||||
return config
|
||||
|
||||
async def get_forgetting_stats(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
|
||||
返回知识层节点统计、激活值分布等信息。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(可选)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
|
||||
Returns:
|
||||
dict: 统计信息字典
|
||||
"""
|
||||
# 获取遗忘引擎组件
|
||||
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
connector = forgetting_scheduler.connector
|
||||
forgetting_threshold = config['forgetting_threshold']
|
||||
|
||||
# 收集激活值指标
|
||||
activation_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
activation_query += " AND n.group_id = $group_id"
|
||||
|
||||
activation_query += """
|
||||
RETURN
|
||||
count(n) as total_nodes,
|
||||
sum(CASE WHEN n.activation_value IS NOT NULL THEN 1 ELSE 0 END) as nodes_with_activation,
|
||||
sum(CASE WHEN n.activation_value IS NULL THEN 1 ELSE 0 END) as nodes_without_activation,
|
||||
avg(n.activation_value) as average_activation,
|
||||
sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
|
||||
activation_results = await connector.execute_query(activation_query, **params)
|
||||
|
||||
if activation_results:
|
||||
result = activation_results[0]
|
||||
activation_metrics = {
|
||||
'total_nodes': result['total_nodes'] or 0,
|
||||
'nodes_with_activation': result['nodes_with_activation'] or 0,
|
||||
'nodes_without_activation': result['nodes_without_activation'] or 0,
|
||||
'average_activation_value': result['average_activation'],
|
||||
'low_activation_nodes': result['low_activation_nodes'] or 0,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
else:
|
||||
activation_metrics = {
|
||||
'total_nodes': 0,
|
||||
'nodes_with_activation': 0,
|
||||
'nodes_without_activation': 0,
|
||||
'average_activation_value': None,
|
||||
'low_activation_nodes': 0,
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 收集节点类型分布
|
||||
distribution_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
distribution_query += " AND n.group_id = $group_id"
|
||||
|
||||
distribution_query += """
|
||||
WITH n,
|
||||
CASE
|
||||
WHEN n:Statement THEN 'statement'
|
||||
WHEN n:ExtractedEntity THEN 'entity'
|
||||
WHEN n:MemorySummary THEN 'summary'
|
||||
WHEN n:Chunk THEN 'chunk'
|
||||
END as node_type
|
||||
RETURN
|
||||
sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
|
||||
sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
|
||||
sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
|
||||
sum(CASE WHEN node_type = 'chunk' THEN 1 ELSE 0 END) as chunk_count
|
||||
"""
|
||||
|
||||
dist_params = {}
|
||||
if group_id:
|
||||
dist_params['group_id'] = group_id
|
||||
|
||||
distribution_results = await connector.execute_query(distribution_query, **dist_params)
|
||||
|
||||
if distribution_results:
|
||||
result = distribution_results[0]
|
||||
node_distribution = {
|
||||
'statement_count': result['statement_count'] or 0,
|
||||
'entity_count': result['entity_count'] or 0,
|
||||
'summary_count': result['summary_count'] or 0,
|
||||
'chunk_count': result['chunk_count'] or 0
|
||||
}
|
||||
else:
|
||||
node_distribution = {
|
||||
'statement_count': 0,
|
||||
'entity_count': 0,
|
||||
'summary_count': 0,
|
||||
'chunk_count': 0
|
||||
}
|
||||
|
||||
# 构建统计信息(不包含监控历史数据)
|
||||
stats = {
|
||||
'activation_metrics': activation_metrics,
|
||||
'node_distribution': node_distribution,
|
||||
'consistency_check': None, # 不再提供一致性检查
|
||||
'nodes_merged_total': 0, # 不再跟踪累计融合数
|
||||
'recent_cycles': [], # 不再提供历史记录
|
||||
'timestamp': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}"
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
async def get_forgetting_curve(
|
||||
self,
|
||||
db: Session,
|
||||
importance_score: float,
|
||||
days: int,
|
||||
config_id: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
生成遗忘曲线数据用于可视化,模拟记忆激活值随时间的衰减。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
importance_score: 重要性分数(0-1)
|
||||
days: 模拟天数
|
||||
config_id: 配置ID(可选)
|
||||
|
||||
Returns:
|
||||
dict: 包含曲线数据和配置的字典
|
||||
"""
|
||||
# 获取 ACT-R 计算器
|
||||
actr_calculator, _, _, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
# 生成遗忘曲线数据
|
||||
initial_time = datetime.now()
|
||||
curve_data = actr_calculator.get_forgetting_curve(
|
||||
initial_time=initial_time,
|
||||
importance_score=importance_score,
|
||||
days=days
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功生成遗忘曲线数据: {len(curve_data)} 个数据点"
|
||||
)
|
||||
|
||||
return {
|
||||
'curve_data': curve_data,
|
||||
'config': {
|
||||
'decay_constant': config['decay_constant'],
|
||||
'forgetting_rate': config['forgetting_rate'],
|
||||
'offset': config['offset'],
|
||||
'importance_score': importance_score,
|
||||
'days': days
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
@@ -26,7 +25,6 @@ from app.schemas.memory_storage_schema import (
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
ConfigUpdateForget,
|
||||
)
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.utils.sse_utils import format_sse_message
|
||||
@@ -159,11 +157,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
return {"affected": 1}
|
||||
|
||||
# --- Forget config params ---
|
||||
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置
|
||||
config = DataConfigRepository.update_forget(self.db, update)
|
||||
if not config:
|
||||
raise ValueError("未找到配置")
|
||||
return {"affected": 1}
|
||||
# 遗忘引擎配置方法已迁移到 memory_forget_service.py
|
||||
# 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
|
||||
|
||||
# --- Read ---
|
||||
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
|
||||
@@ -172,12 +167,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
raise ValueError("未找到配置")
|
||||
return result
|
||||
|
||||
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取遗忘配置参数
|
||||
result = DataConfigRepository.get_forget_config(self.db, key.config_id)
|
||||
if not result:
|
||||
raise ValueError("未找到配置")
|
||||
return result
|
||||
|
||||
# --- Read All ---
|
||||
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
|
||||
configs = DataConfigRepository.get_all(self.db, workspace_id)
|
||||
|
||||
@@ -412,7 +412,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
@@ -499,7 +499,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Log but continue - will fail later with proper error
|
||||
pass
|
||||
|
||||
@@ -1064,4 +1064,74 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
|
||||
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""定时任务:运行遗忘周期
|
||||
|
||||
定期执行遗忘周期,识别并融合低激活值的知识节点。
|
||||
|
||||
Args:
|
||||
config_id: 配置ID(可选,如果为None则使用默认配置)
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")
|
||||
|
||||
forget_service = MemoryForgetService()
|
||||
|
||||
# 运行遗忘周期
|
||||
report = await forget_service.trigger_forgetting(
|
||||
db=db,
|
||||
group_id=None, # 处理所有组
|
||||
config_id=config_id
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
|
||||
api_logger.info(
|
||||
f"遗忘周期定时任务完成: "
|
||||
f"融合 {report['merged_count']} 对节点, "
|
||||
f"失败 {report['failed_count']} 对, "
|
||||
f"耗时 {duration:.2f} 秒"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "遗忘周期执行成功",
|
||||
"report": report,
|
||||
"duration_seconds": duration
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
duration = time.time() - start_time
|
||||
api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)
|
||||
|
||||
return {
|
||||
"status": "FAILED",
|
||||
"message": f"遗忘周期执行失败: {str(e)}",
|
||||
"duration_seconds": duration
|
||||
}
|
||||
|
||||
# 运行异步函数
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(_run())
|
||||
return result
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
Reference in New Issue
Block a user