Merge #85 into develop from feature/actr-forget

[feature]actr-记忆遗忘需求开发

* feature/actr-forget: (12 commits squashed)

  - [feature]
    1. Extended the fields of the data_config table;
    2. Added a new activation value calculation and introduced the ACT-R parameters in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - [feature]
    1. Extended the fields of the data_config table;
    2. Added a new activation value calculation and introduced the ACT-R parameters in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget

  - [fix]Eliminate the interference caused by redundant code

  - [feature]
    1. Extended the fields of the data_config table;
    2. Added a new activation value calculation and introduced the ACT-R parameters in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget

Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85
This commit is contained in:
乐力齐
2026-01-05 04:30:36 +00:00
committed by 孙科
parent d299c39c55
commit e8a5cfe7e3
24 changed files with 4178 additions and 287 deletions

View File

@@ -85,6 +85,8 @@ health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# 构建定时任务配置 # 构建定时任务配置
beat_schedule_config = { beat_schedule_config = {
@@ -103,6 +105,13 @@ beat_schedule_config = {
"schedule": memory_cache_regeneration_schedule, "schedule": memory_cache_regeneration_schedule,
"args": (), "args": (),
}, },
"run-forgetting-cycle": {
"task": "app.tasks.run_forgetting_cycle_task",
"schedule": forgetting_cycle_schedule,
"kwargs": {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
} }
# 如果配置了默认工作空间ID则添加记忆总量统计任务 # 如果配置了默认工作空间ID则添加记忆总量统计任务

View File

@@ -33,7 +33,7 @@ from . import (
emotion_config_controller, emotion_config_controller,
prompt_optimizer_controller, prompt_optimizer_controller,
tool_controller, tool_controller,
home_page_controller, memory_forget_controller,
) )
from . import user_memory_controllers from . import user_memory_controllers
@@ -71,6 +71,6 @@ manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(prompt_optimizer_controller.router)
manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(memory_reflection_controller.router)
manager_router.include_router(tool_controller.router) manager_router.include_router(tool_controller.router)
manager_router.include_router(home_page_controller.router) manager_router.include_router(memory_forget_controller.router)
__all__ = ["manager_router"] __all__ = ["manager_router"]

View File

@@ -0,0 +1,324 @@
"""
Forgetting engine controller module.
This module exposes the REST API of the forgetting engine, including:
1. Manually triggering a forgetting cycle
2. Reading and updating the configuration
3. Retrieving statistics
4. Retrieving forgetting-curve data
All endpoints require user authentication and are automatically associated with the current workspace.
"""
from typing import Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ForgettingTriggerRequest,
ForgettingConfigResponse,
ForgettingConfigUpdateRequest,
ForgettingStatsResponse,
ForgettingReportResponse,
ForgettingCurveRequest,
ForgettingCurveResponse,
ForgettingCurvePoint,
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
# Obtain the logger dedicated to the API layer
api_logger = get_api_logger()
router = APIRouter(
prefix="/memory/forget",
tags=["Memory Forgetting Engine"],
dependencies=[Depends(get_current_user)] # every route requires authentication
)
# Initialise the service; this single module-level instance is shared by all endpoints below
forget_service = MemoryForgetService()
# ==================== API endpoints ====================
@router.post("/trigger", response_model=ApiResponse)
async def trigger_forgetting_cycle(
    payload: ForgettingTriggerRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Manually trigger a forgetting cycle.

    Forwards the request to ``forget_service.trigger_forgetting_cycle`` and
    wraps the returned report in a ``ForgettingReportResponse``. Per the
    original description, one cycle identifies and merges low-activation
    nodes.

    Args:
        payload: Trigger request parameters (``group_id``,
            ``max_merge_batch_size``, ``min_days_since_access``,
            ``config_id``).
        current_user: The authenticated user; a current workspace must be
            selected.
        db: Database session.

    Returns:
        ApiResponse: A success response carrying the forgetting report, or a
        failure response when no workspace is selected, when the service
        raises ``RuntimeError`` (reported as an invalid-parameter error), or
        when any other exception occurs (reported as an internal error).
    """
    workspace_id = current_user.current_workspace_id

    # Guard clause: the user has to pick a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: "
        f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, "
        f"min_days={payload.min_days_since_access}"
    )
    api_logger.info(request_summary)

    try:
        # Run the cycle in the service layer.
        cycle_report = await forget_service.trigger_forgetting_cycle(
            db=db,
            group_id=payload.group_id,
            max_merge_batch_size=payload.max_merge_batch_size,
            min_days_since_access=payload.min_days_since_access,
            config_id=payload.config_id,
        )
        # Validate the report through the response schema before serialising;
        # kept inside the try so schema errors are reported as internal errors.
        return success(
            data=ForgettingReportResponse(**cycle_report).model_dump(),
            msg="遗忘周期执行成功",
        )
    except RuntimeError as e:
        # The service refused to run the cycle.
        api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}")
        return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError")
    except Exception as e:
        api_logger.error(f"触发遗忘周期失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e))
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
    config_id: int,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Read the forgetting engine configuration.

    Loads the forgetting engine parameters stored under the given
    configuration ID via ``forget_service.read_forgetting_config``.

    Args:
        config_id: Configuration ID.
        current_user: The authenticated user; a current workspace must be
            selected.
        db: Database session.

    Returns:
        ApiResponse: A success response carrying the configuration, or a
        failure response when no workspace is selected, when the service
        raises ``ValueError`` (reported as "configuration does not exist"),
        or when any other exception occurs (reported as an internal error).
    """
    workspace_id = current_user.current_workspace_id

    # Guard clause: the user has to pick a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}"
    )
    api_logger.info(request_summary)

    try:
        # Fetch the configuration from the service layer.
        stored_config = forget_service.read_forgetting_config(db=db, config_id=config_id)
        # Validate through the response schema before serialising.
        return success(
            data=ForgettingConfigResponse(**stored_config).model_dump(),
            msg="查询成功",
        )
    except ValueError as e:
        api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}")
        return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e))
    except Exception as e:
        api_logger.error(f"读取遗忘引擎配置失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse)
async def update_forgetting_config(
    payload: ForgettingConfigUpdateRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Update the forgetting engine configuration.

    Applies every non-``None`` field of the payload (except ``config_id``)
    to the configuration identified by ``payload.config_id`` via
    ``forget_service.update_forgetting_config``.

    Args:
        payload: Configuration update request.
        current_user: The authenticated user; a current workspace must be
            selected.
        db: Database session.

    Returns:
        ApiResponse: A success response carrying the updated configuration,
        or a failure response when no workspace is selected, when the
        service raises ``ValueError`` (reported as an invalid-parameter
        error), or when any other exception occurs (the session is rolled
        back and an internal error is reported).
    """
    workspace_id = current_user.current_workspace_id

    # Guard clause: the user has to pick a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}"
    )
    api_logger.info(request_summary)

    try:
        # Collect the fields to update: drop None values, then remove the
        # identifier itself because it selects the row rather than being a
        # field to change.
        changed_fields = payload.model_dump(exclude_none=True)
        changed_fields.pop('config_id', None)

        # Persist the changes through the service layer.
        updated_config = forget_service.update_forgetting_config(
            db=db,
            config_id=payload.config_id,
            update_fields=changed_fields,
        )
        # Validate through the response schema before serialising.
        return success(
            data=ForgettingConfigResponse(**updated_config).model_dump(),
            msg="更新成功",
        )
    except ValueError as e:
        api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}")
        return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
    except Exception as e:
        # Undo any partially applied changes before reporting the failure.
        db.rollback()
        api_logger.error(f"更新遗忘引擎配置失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
    group_id: Optional[str] = None,
    config_id: Optional[int] = None,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Retrieve forgetting engine statistics.

    Per the original description, the result covers knowledge-layer node
    statistics, the activation value distribution and related information;
    the data itself comes from ``forget_service.get_forgetting_stats``.

    Args:
        group_id: Group ID (optional).
        config_id: Configuration ID (optional), used to obtain the
            forgetting threshold.
        current_user: The authenticated user; a current workspace must be
            selected.
        db: Database session.

    Returns:
        ApiResponse: A success response carrying the statistics, or a
        failure response when no workspace is selected or when any
        exception occurs (reported as an internal error).
    """
    workspace_id = current_user.current_workspace_id

    # Guard clause: the user has to pick a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
        f"group_id={group_id}, config_id={config_id}"
    )
    api_logger.info(request_summary)

    try:
        # Gather the statistics in the service layer.
        engine_stats = await forget_service.get_forgetting_stats(
            db=db,
            group_id=group_id,
            config_id=config_id,
        )
        # Validate through the response schema before serialising.
        return success(
            data=ForgettingStatsResponse(**engine_stats).model_dump(),
            msg="查询成功",
        )
    except Exception as e:
        api_logger.error(f"获取遗忘引擎统计失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve(
    request: ForgettingCurveRequest,
    current_user: User = Depends(get_current_user),
    db: Session = Depends(get_db)
):
    """
    Retrieve forgetting-curve data.

    Per the original description, the curve is intended for visualisation
    and simulates how a memory's activation value decays over time; the
    data itself comes from ``forget_service.get_forgetting_curve``.

    Args:
        request: Forgetting-curve request parameters (``importance_score``,
            ``days``, ``config_id``).
        current_user: The authenticated user; a current workspace must be
            selected.
        db: Database session.

    Returns:
        ApiResponse: A success response carrying the curve points and the
        configuration used, or a failure response when no workspace is
        selected or when any exception occurs (reported as an internal
        error).
    """
    workspace_id = current_user.current_workspace_id

    # Guard clause: the user has to pick a workspace first.
    if workspace_id is None:
        api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")
        return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")

    request_summary = (
        f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: "
        f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}"
    )
    api_logger.info(request_summary)

    try:
        # Generate the curve in the service layer.
        curve_result = await forget_service.get_forgetting_curve(
            db=db,
            importance_score=request.importance_score,
            days=request.days,
            config_id=request.config_id,
        )

        # Convert every raw point mapping into its response schema object.
        points = []
        for raw_point in curve_result['curve_data']:
            points.append(ForgettingCurvePoint(**raw_point))

        # Assemble and serialise the response payload.
        curve_response = ForgettingCurveResponse(
            curve_data=points,
            config=curve_result['config'],
        )
        return success(data=curve_response.model_dump(), msg="查询成功")
    except Exception as e:
        api_logger.error(f"获取遗忘曲线失败: {str(e)}")
        return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e))

View File

@@ -1,4 +1,3 @@
import datetime
import os import os
import uuid import uuid
from typing import Optional from typing import Optional
@@ -9,12 +8,7 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success from app.core.response_utils import fail, success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user from app.dependencies import get_current_user
from app.models.end_user_model import EndUser
from app.models.user_model import User from app.models.user_model import User
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
)
from app.schemas.memory_storage_schema import ( from app.schemas.memory_storage_schema import (
ConfigKey, ConfigKey,
ConfigParamsCreate, ConfigParamsCreate,
@@ -22,8 +16,6 @@ from app.schemas.memory_storage_schema import (
ConfigPilotRun, ConfigPilotRun,
ConfigUpdate, ConfigUpdate,
ConfigUpdateExtracted, ConfigUpdateExtracted,
ConfigUpdateForget,
GenerateCacheRequest,
) )
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services.memory_storage_service import ( from app.services.memory_storage_service import (
@@ -238,28 +230,8 @@ def update_config_extracted(
# --- Forget config params --- # --- Forget config params ---
@router.post("/update_config_forget", response_model=ApiResponse) # 更新遗忘引擎配置参数(固定路径) # 遗忘引擎配置接口已迁移到 memory_forget_controller.py
def update_config_forget( # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
payload: ConfigUpdateForget,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}")
try:
svc = DataConfigService(db)
result = svc.update_forget(payload)
return success(data=result, msg="更新成功")
except Exception as e:
api_logger.error(f"Update config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e))
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
@@ -283,28 +255,6 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_forget(
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}")
try:
svc = DataConfigService(db)
result = svc.get_forget(ConfigKey(config_id=config_id))
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Read config forget failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 @router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -106,28 +106,32 @@ class SearchService:
limit: int = 15, limit: int = 15,
search_type: str = "hybrid", search_type: str = "hybrid",
include: Optional[List[str]] = None, include: Optional[List[str]] = None,
rerank_alpha: float = 0.4, rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config: "MemoryConfig" = None, memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Args: Args:
group_id: Group identifier for filtering results group_id: Group identifier for filtering
question: Search query text question: Search query text
limit: Maximum number of results to return (default: 5) limit: Max results per category (default: 15)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) rerank_alpha: BM25 weight (default: 0.6)
output_path: Path to save search results (default: "search_results.json") activation_boost_factor: Activation impact on memory strength (default: 0.8)
return_raw_results: If True, also return the raw search results as third element (default: False) output_path: JSON output path (default: "search_results.json")
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided. return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
Returns: Returns:
Tuple of (clean_content, cleaned_query, raw_results) Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries"]
@@ -151,6 +155,7 @@ class SearchService:
output_path=output_path, output_path=output_path,
memory_config=config, memory_config=config,
rerank_alpha=rerank_alpha, rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
) )
# Extract results based on search type and include parameter # Extract results based on search type and include parameter

View File

@@ -228,6 +228,13 @@ class StatementNode(Node):
chunk_embedding: Optional embedding vector for the parent chunk chunk_embedding: Optional embedding vector for the parent chunk
connect_strength: Classification of connection strength ('Strong' or 'Weak') connect_strength: Classification of connection strength ('Strong' or 'Weak')
config_id: Configuration ID used to process this statement config_id: Configuration ID used to process this statement
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
""" """
# Core fields (ordered as requested) # Core fields (ordered as requested)
chunk_id: str = Field(..., description="ID of the parent chunk") chunk_id: str = Field(..., description="ID of the parent chunk")
@@ -269,6 +276,33 @@ class StatementNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
@@ -351,6 +385,13 @@ class ExtractedEntityNode(Node):
fact_summary: Summary of facts about this entity fact_summary: Summary of facts about this entity
connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both') connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
config_id: Configuration ID used to process this entity (integer or string) config_id: Configuration ID used to process this entity (integer or string)
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
""" """
entity_idx: int = Field(..., description="Unique identifier for the entity") entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from") statement_id: str = Field(..., description="Statement this entity was extracted from")
@@ -365,6 +406,33 @@ class ExtractedEntityNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
@field_validator('aliases', mode='before') @field_validator('aliases', mode='before')
@classmethod @classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
@@ -401,6 +469,16 @@ class MemorySummaryNode(Node):
summary_embedding: Optional embedding vector for the summary summary_embedding: Optional embedding vector for the summary
metadata: Additional metadata for the summary metadata: Additional metadata for the summary
config_id: Configuration ID used to process this summary config_id: Configuration ID used to process this summary
original_statement_id: ID of the original statement that was merged (for ACT-R forgetting)
original_entity_id: ID of the original entity that was merged (for ACT-R forgetting)
merged_at: Timestamp when the nodes were merged
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes
access_history: List of ISO timestamp strings recording each access (reset on creation)
last_access_time: ISO timestamp of the most recent access (set to creation time)
access_count: Total number of times this node has been accessed (reset to 1 on creation)
""" """
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary") summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
dialog_id: str = Field(..., description="ID of the parent dialog") dialog_id: str = Field(..., description="ID of the parent dialog")
@@ -409,3 +487,44 @@ class MemorySummaryNode(Node):
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field(
None,
description="ID of the original statement that was merged (for traceability)"
)
original_entity_id: Optional[str] = Field(
None,
description="ID of the original entity that was merged (for traceability)"
)
merged_at: Optional[datetime] = Field(
None,
description="Timestamp when the nodes were merged"
)
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), inherited from merged nodes"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access (reset on creation)"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access (set to creation time)"
)
access_count: int = Field(
default=1,
ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)"
)

View File

@@ -69,6 +69,12 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
for item in results: for item in results:
if score_field in item: if score_field in item:
score = item.get(score_field) score = item.get(score_field)
# 对于 activation_valueNone 值保持为 None不使用回退值
# 这样可以区分有激活值和无激活值的节点
if score_field == "activation_value" and score is None:
scores.append(None) # 保持 None稍后特殊处理
continue
if score is not None and isinstance(score, (int, float)): if score is not None and isinstance(score, (int, float)):
scores.append(float(score)) scores.append(float(score))
else: else:
@@ -76,205 +82,433 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores: if not scores:
return results return results
if len(scores) == 1: # 过滤掉 None 值,只对有效分数进行归一化
# Single score, set to 1.0 valid_scores = [s for s in scores if s is not None]
if not valid_scores:
# 所有分数都是 None不进行归一化
for item in results: for item in results:
if score_field in item: if score_field in item or score_field == "activation_value":
item[f"normalized_{score_field}"] = 1.0 item[f"normalized_{score_field}"] = None
return results return results
# Calculate mean and standard deviation if len(valid_scores) == 1: # Single valid score, set to 1.0
mean_score = sum(scores) / len(scores) for item, score in zip(results, scores):
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores) if score_field in item or score_field == "activation_value":
if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
return results
# Calculate mean and standard deviation (only for valid scores)
mean_score = sum(valid_scores) / len(valid_scores)
variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores)
std_dev = math.sqrt(variance) std_dev = math.sqrt(variance)
if std_dev == 0: if std_dev == 0:
# All scores are the same, set them to 1.0 # All valid scores are the same, set them to 1.0
for item in results: for item, score in zip(results, scores):
if score_field in item: if score_field in item or score_field == "activation_value":
item[f"normalized_{score_field}"] = 1.0 if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
else: else:
for item in results: for item, score in zip(results, scores):
if score_field in item: if score_field in item or score_field == "activation_value":
score = item[score_field] if score is None:
# Handle None or non-numeric scores # 保持 None不进行归一化
if score is None or not isinstance(score, (int, float)): item[f"normalized_{score_field}"] = None
score = 0.0 else:
# Calculate z-score # Calculate z-score
z_score = (score - mean_score) / std_dev z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function # Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score)) normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized item[f"normalized_{score_field}"] = normalized
return results return results
def rerank_hybrid_results( # ============================================================================
keyword_results: Dict[str, List[Dict[str, Any]]], # 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
embedding_results: Dict[str, List[Dict[str, Any]]], # ============================================================================
alpha: float = 0.6,
limit: int = 10
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid search results by combining BM25 and embedding scores.
Args: # def rerank_hybrid_results(
keyword_results: Results from keyword/BM25 search # keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Results from embedding search # embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: Weight for BM25 scores (1-alpha for embedding scores) # alpha: float = 0.6,
limit: Maximum number of results to return per category # limit: int = 10
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid search results by combining BM25 and embedding scores.
#
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
#
# Returns:
# Reranked results with combined scores
# """
# reranked = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# # Create a combined pool of unique items
# combined_items = {}
#
# # Add keyword results with BM25 scores
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0 # Default
#
# # Add or update with embedding results
# for item in embedding_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# if item_id in combined_items:
# # Update existing item with embedding score
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# # New item from embedding search only
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0 # Default
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
#
# # Calculate combined scores and rank
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
#
# # Combined score: weighted average of normalized scores
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
#
# # Keep original score for reference
# if "score" not in item and bm25_score > 0:
# item["score"] = bm25_score
# elif "score" not in item and embedding_score > 0:
# item["score"] = embedding_score
#
# # Sort by combined score and limit results
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
Returns: # def rerank_with_forgetting_curve(
Reranked results with combined scores # keyword_results: Dict[str, List[Dict[str, Any]]],
""" # embedding_results: Dict[str, List[Dict[str, Any]]],
reranked = {} # alpha: float = 0.6,
# limit: int = 10,
# forgetting_config: ForgettingEngineConfig | None = None,
# now: datetime | None = None,
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid results with a forgetting curve applied to combined scores.
#
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
#
# The forgetting curve reduces scores for older memories or weaker connections.
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
# forgetting_config: Configuration for the forgetting engine
# now: Optional current time override for testing
#
# Returns:
# Reranked results with combined and final scores (after forgetting)
# """
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
# now_dt = now or datetime.now()
#
# reranked: Dict[str, List[Dict[str, Any]]] = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# combined_items: Dict[str, Dict[str, Any]] = {}
#
# # Combine two result sets by ID
# for src_items, is_embedding in (
# (keyword_items, False), (embedding_items, True)
# ):
# for item in src_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if not item_id:
# continue
# existing = combined_items.get(item_id)
# if not existing:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
# # Update normalized score from the right source
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
#
# # Calculate scores and apply forgetting weights
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
#
# # Estimate time elapsed in days
# dt = _parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
#
# # Memory strength (currently set to default value)
# memory_strength = 1.0
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
# )
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
#
# sorted_items = sorted(
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
def rerank_with_activation(
    keyword_results: Dict[str, List[Dict[str, Any]]],
    embedding_results: Dict[str, List[Dict[str, Any]]],
    alpha: float = 0.6,
    limit: int = 10,
    forgetting_config: ForgettingEngineConfig | None = None,
    activation_boost_factor: float = 0.8,
    now: datetime | None = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Two-stage rerank: shortlist by content relevance, then order by activation.

    Stage 1 merges keyword and embedding hits by ID, scores each merged item
    with ``content_score = alpha * bm25 + (1 - alpha) * embedding`` and keeps
    the top ``limit * 3`` candidates.
    Stage 2 orders the candidates that carry an activation value by their
    normalized activation and keeps the top ``limit``. Candidates without an
    activation value are only used to fill the remaining slots, in stage-1
    (content relevance) order.

    Score fields written on every returned item:
        - bm25_score: normalized BM25 score (0 when the item was not a keyword hit)
        - embedding_score: normalized embedding score (0 when not an embedding hit)
        - content_score: ``alpha * bm25_score + (1 - alpha) * embedding_score``
        - base_score: stage-1 score (equal to ``content_score``)
        - activation_score: normalized ``activation_value``, or ``None`` when
          the item has no usable activation value
        - forgetting_weight: only when ``forgetting_config`` is given and the
          item has a numeric ``activation_value``
        - final_score: ``activation_score`` for items that have one,
          otherwise ``base_score``

    Args:
        keyword_results: BM25 search results, keyed by category.
        embedding_results: Embedding search results, keyed by category.
        alpha: Weight of the BM25 score; ``1 - alpha`` weights the embedding score.
        limit: Maximum number of results returned per category.
        forgetting_config: When truthy, a ``ForgettingEngine`` is built from it
            and a ``forgetting_weight`` is attached to items with an activation
            value. The weight is informational: it does not affect ordering
            or ``final_score``.
        activation_boost_factor: Factor applied to ``activation_value`` when
            deriving the memory strength passed to the forgetting engine.
        now: Reference time for the elapsed-days computation (defaults to
            ``datetime.now()``).

    Returns:
        Mapping of category name to the reranked, truncated item list.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    if not (0 <= alpha <= 1):
        raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")

    engine = None
    if forgetting_config:
        engine = ForgettingEngine(forgetting_config)
    now_dt = now or datetime.now()

    reranked: Dict[str, List[Dict[str, Any]]] = {}

    for category in ["statements", "chunks", "entities", "summaries"]:
        # Step 1: normalize scores within each search type.
        keyword_items = normalize_scores(keyword_results.get(category, []), "score")
        embedding_items = normalize_scores(embedding_results.get(category, []), "score")

        # Step 2: merge both result sets by ID.
        combined_items: Dict[str, Dict[str, Any]] = {}

        for item in keyword_items:
            item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
            if not item_id:
                continue
            merged = item.copy()
            merged["bm25_score"] = item.get("normalized_score", 0)
            merged["embedding_score"] = 0  # default until an embedding hit is seen
            combined_items[item_id] = merged

        for item in embedding_items:
            item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
            if not item_id:
                continue
            if item_id in combined_items:
                combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
            else:
                merged = item.copy()
                merged["bm25_score"] = 0  # embedding-only hit
                merged["embedding_score"] = item.get("normalized_score", 0)
                combined_items[item_id] = merged

        # Step 3: normalize activation values across the merged pool.
        items_list = normalize_scores(list(combined_items.values()), "activation_value")
        for item in items_list:
            item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
            if item_id and item_id in combined_items:
                # Keep None (no activation) distinct from a real 0 score.
                combined_items[item_id]["normalized_activation_value"] = item.get(
                    "normalized_activation_value"
                )

        # Step 4: compute the stage-1 score and record the activation score.
        for item in combined_items.values():
            bm25_norm = float(item.get("bm25_score", 0) or 0)
            emb_norm = float(item.get("embedding_score", 0) or 0)
            content_score = alpha * bm25_norm + (1 - alpha) * emb_norm

            # Do NOT coerce a missing activation to 0.0: stage 2 relies on
            # None to tell "no activation value" apart from "low activation".
            raw_activation = item.get("normalized_activation_value")
            if isinstance(raw_activation, (int, float)):
                item["activation_score"] = float(raw_activation)
            else:
                item["activation_score"] = None

            item["content_score"] = content_score
            item["base_score"] = content_score

            # Step 5 (optional): attach a forgetting weight for items that
            # have a numeric activation value.
            activation_val = item.get("activation_value")
            if engine and isinstance(activation_val, (int, float)):
                importance = float(item.get("importance_score", 0.5) or 0.5)
                memory_strength = importance * (
                    1 + float(activation_val) * activation_boost_factor
                )

                dt = _parse_datetime(item.get("created_at"))
                if dt is None:
                    time_elapsed_days = 0.0
                else:
                    time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)

                item["forgetting_weight"] = engine.calculate_weight(
                    time_elapsed=time_elapsed_days,
                    memory_strength=memory_strength,
                )

        # Step 6, stage 1: shortlist by content relevance (3x oversampling).
        first_stage_limit = limit * 3
        first_stage_sorted = sorted(
            combined_items.values(),
            key=lambda x: float(x.get("base_score", 0) or 0),
            reverse=True,
        )[:first_stage_limit]

        # Step 6, stage 2: activation-bearing candidates first, ordered by
        # activation; the rest keep their stage-1 order and only fill gaps.
        items_with_activation = [
            c for c in first_stage_sorted if c["activation_score"] is not None
        ]
        items_without_activation = [
            c for c in first_stage_sorted if c["activation_score"] is None
        ]
        sorted_with_activation = sorted(
            items_with_activation,
            key=lambda x: x["activation_score"],
            reverse=True,
        )

        if len(sorted_with_activation) < limit:
            needed = limit - len(sorted_with_activation)
            sorted_items = sorted_with_activation + items_without_activation[:needed]
        else:
            sorted_items = sorted_with_activation[:limit]

        # final_score reflects the key each item was actually ranked by.
        for item in sorted_items:
            activation_score = item["activation_score"]
            if activation_score is not None:
                item["final_score"] = activation_score
            else:
                item["final_score"] = item.get("base_score", 0)

        reranked[category] = sorted_items

    return reranked
@@ -560,6 +794,7 @@ async def run_hybrid_search(
output_path: str | None, output_path: str | None,
memory_config: "MemoryConfig", memory_config: "MemoryConfig",
rerank_alpha: float = 0.6, rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False, use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False, use_llm_rerank: bool = False,
): ):
@@ -685,30 +920,28 @@ async def run_hybrid_search(
"search_timestamp": datetime.now().isoformat() "search_timestamp": datetime.now().isoformat()
} }
# Apply reranking (optionally with forgetting curve) # Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time() rerank_start = time.time()
if use_forgetting_rerank: logger.info("Using two-stage reranking with ACTR activation")
# Load forgetting parameters from pipeline config
try: # 加载遗忘引擎配置
pc = get_pipeline_config(memory_config) try:
forgetting_cfg = pc.forgetting_engine pc = get_pipeline_config(memory_config)
except Exception as e: forgetting_cfg = pc.forgetting_engine
logger.debug(f"Failed to load forgetting config, using defaults: {e}") except Exception as e:
forgetting_cfg = ForgettingEngineConfig() logger.debug(f"Failed to load forgetting config, using defaults: {e}")
reranked_results = rerank_with_forgetting_curve( forgetting_cfg = ForgettingEngineConfig()
keyword_results=keyword_results,
embedding_results=embedding_results, # 统一使用激活度重排序(两阶段:检索 + ACTR计算
alpha=rerank_alpha, reranked_results = rerank_with_activation(
limit=limit, keyword_results=keyword_results,
forgetting_config=forgetting_cfg, embedding_results=embedding_results,
) alpha=rerank_alpha,
else: limit=limit,
reranked_results = rerank_hybrid_results( forgetting_config=forgetting_cfg,
keyword_results=keyword_results, activation_boost_factor=activation_boost_factor,
embedding_results=embedding_results, )
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
limit=limit
)
rerank_latency = time.time() - rerank_start rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4) latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"Reranking completed in {rerank_latency:.4f}s") logger.info(f"Reranking completed in {rerank_latency:.4f}s")
@@ -737,6 +970,7 @@ async def run_hybrid_search(
"search_query": query_text, "search_query": query_text,
"search_timestamp": datetime.now().isoformat(), "search_timestamp": datetime.now().isoformat(),
"reranking_alpha": rerank_alpha, "reranking_alpha": rerank_alpha,
"activation_boost_factor": activation_boost_factor,
"forgetting_rerank": use_forgetting_rerank, "forgetting_rerank": use_forgetting_rerank,
"llm_rerank": llm_rerank_applied, "llm_rerank": llm_rerank_applied,
} }

View File

@@ -1,8 +1,40 @@
"""遗忘引擎模块 """遗忘引擎模块
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。 该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线和 ACT-R 认知架构理论
""" """
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
calculate_activation,
generate_forgetting_curve
)
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
ConsistencyCheckResult
)
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import (
ForgettingStrategy
)
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import (
ForgettingScheduler
)
from app.core.memory.storage_services.forgetting_engine.config_utils import (
calculate_forgetting_rate,
load_actr_config_from_db,
create_actr_calculator_from_config
)
__all__ = ["ForgettingEngine"] __all__ = [
"ForgettingEngine",
"ACTRCalculator",
"calculate_activation",
"generate_forgetting_curve",
"AccessHistoryManager",
"ConsistencyCheckResult",
"ForgettingStrategy",
"ForgettingScheduler",
"calculate_forgetting_rate",
"load_actr_config_from_db",
"create_actr_calculator_from_config"
]

View File

@@ -0,0 +1,691 @@
"""
访问历史管理器模块
本模块实现访问历史的追踪、更新和一致性保证。
负责在知识节点被访问时原子性地更新激活值相关的所有字段。
Classes:
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
"""
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from enum import Enum
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
class ConsistencyCheckResult(Enum):
    """Outcome of a consistency check on a node's access-tracking fields."""

    # All checked fields agree with each other.
    CONSISTENT = "consistent"
    # access_history[-1] differs from last_access_time.
    INCONSISTENT_HISTORY_TIME = "inconsistent_history_time"
    # len(access_history) differs from access_count.
    INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count"
    # The node has access history but no activation value.
    MISSING_ACTIVATION = "missing_activation"
    # The activation value lies outside the accepted range.
    INVALID_ACTIVATION_RANGE = "invalid_activation_range"
class AccessHistoryManager:
"""
访问历史管理器
负责追踪知识节点的访问历史,并在访问时原子性地更新所有相关字段:
- activation_value: 激活值
- access_history: 访问历史时间戳数组
- last_access_time: 最后访问时间
- access_count: 访问次数
特性:
- 原子性更新使用Neo4j事务确保所有字段同时更新或回滚
- 并发安全:使用乐观锁机制防止并发冲突
- 一致性保证:提供一致性检查和自动修复功能
- 智能修剪:自动修剪过长的访问历史
Attributes:
connector: Neo4j连接器实例
actr_calculator: ACT-R激活值计算器实例
max_retries: 并发冲突时的最大重试次数
"""
def __init__(
    self,
    connector: Neo4jConnector,
    actr_calculator: ACTRCalculator,
    max_retries: int = 3
):
    """
    Store the collaborators used by the manager.

    Args:
        connector: Neo4j connector used to run queries.
        actr_calculator: ACT-R calculator used to compute activation values
            and trim access histories.
        max_retries: Maximum number of attempts made by ``record_access``
            before giving up (default 3).
    """
    self.connector, self.actr_calculator, self.max_retries = (
        connector,
        actr_calculator,
        max_retries,
    )
async def record_access(
    self,
    node_id: str,
    node_label: str,
    group_id: Optional[str] = None,
    current_time: Optional[datetime] = None
) -> Dict[str, Any]:
    """
    Record one access to a node and persist the recomputed tracking fields.

    Each attempt fetches the node, computes the new access history and
    activation value via ``_calculate_update``, and writes them through
    ``_atomic_update``. If any of those steps raises, the whole attempt is
    retried, up to ``max_retries`` attempts (at least one attempt is always
    made).

    NOTE(review): the atomicity / optimistic-locking guarantees depend on
    ``_atomic_update``; confirm them against that method.

    Args:
        node_id: ID of the node.
        node_label: Node label; one of ``Statement``, ``ExtractedEntity``,
            ``MemorySummary``.
        group_id: Optional group ID used to filter the node lookup.
        current_time: Access time (defaults to ``datetime.now()``).

    Returns:
        Dict[str, Any]: The updated node data as returned by ``_atomic_update``.

    Raises:
        ValueError: If ``node_label`` is invalid, or if the node does not
            exist. A missing node is reported immediately; it is not retried.
        RuntimeError: If every attempt failed; the last error is chained
            as ``__cause__``.
    """
    if current_time is None:
        current_time = datetime.now()
    current_time_iso = current_time.isoformat()

    valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
    if node_label not in valid_labels:
        raise ValueError(
            f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
        )

    # Always try at least once, even if max_retries was configured as <= 0;
    # otherwise the loop would be skipped and None silently returned.
    attempts = max(1, self.max_retries)

    for attempt in range(attempts):
        try:
            # Step 1: read the current node state.
            node_data = await self._fetch_node(node_id, node_label, group_id)

            if node_data:
                # Step 2: compute the new access history and activation value.
                update_data = await self._calculate_update(
                    node_data=node_data,
                    current_time=current_time,
                    current_time_iso=current_time_iso
                )

                # Step 3: persist all fields together.
                updated_node = await self._atomic_update(
                    node_id=node_id,
                    node_label=node_label,
                    update_data=update_data,
                    group_id=group_id
                )

                logger.info(
                    f"成功记录访问: {node_label}[{node_id}], "
                    f"activation={update_data['activation_value']:.4f}, "
                    f"access_count={update_data['access_count']}"
                )

                return updated_node

        except Exception as e:
            if attempt < attempts - 1:
                logger.warning(
                    f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}"
                )
                continue

            logger.error(
                f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
                f"错误: {str(e)}"
            )
            raise RuntimeError(
                f"Failed to record access after {self.max_retries} attempts: {str(e)}"
            ) from e
        else:
            # Reached only when the fetch succeeded but found nothing (the
            # success path returns above). Raised outside the try body so it
            # is neither retried nor rewrapped as RuntimeError.
            raise ValueError(
                f"Node not found: {node_label} with id={node_id}"
            )
async def record_batch_access(
    self,
    node_ids: List[str],
    node_label: str,
    group_id: Optional[str] = None,
    current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
    """
    Record an access for each of several nodes of the same label.

    Nodes are processed one after another through ``record_access``, all
    with the same timestamp. A failure for one node is logged and counted
    but does not stop the remaining nodes.

    Args:
        node_ids: IDs of the nodes to update.
        node_label: Label shared by all nodes.
        group_id: Optional group ID passed through to ``record_access``.
        current_time: Access time for the whole batch (defaults to
            ``datetime.now()``, evaluated once).

    Returns:
        List[Dict[str, Any]]: Updated node data for the nodes that succeeded.
    """
    batch_time = datetime.now() if current_time is None else current_time

    succeeded: List[Dict[str, Any]] = []
    failures = 0

    for node_id in node_ids:
        try:
            node = await self.record_access(
                node_id=node_id,
                node_label=node_label,
                group_id=group_id,
                current_time=batch_time
            )
        except Exception as e:
            failures += 1
            logger.warning(
                f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
            )
        else:
            succeeded.append(node)

    logger.info(
        f"批量访问记录完成: 成功 {len(succeeded)}/{len(node_ids)}, "
        f"失败 {failures}"
    )

    return succeeded
async def check_consistency(
    self,
    node_id: str,
    node_label: str,
    group_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
    """
    Check the access-tracking fields of one node against each other.

    Rules, evaluated in order (the first violation is returned):
        1. ``access_history[-1] == last_access_time``
        2. ``len(access_history) == access_count``
        3. A node with access history must have an activation value.
        4. The activation value must lie in ``[offset, 1.0]``, where
           ``offset`` is read from the ACT-R calculator.

    Properties that are absent on the node come back from ``_fetch_node``
    as ``None``; a ``None`` history is treated as empty and a ``None``
    count as 0.

    Args:
        node_id: ID of the node.
        node_label: Node label.
        group_id: Optional group ID used to filter the node lookup.

    Returns:
        Tuple[ConsistencyCheckResult, Optional[str]]: The check result and,
        when inconsistent, a description of the violation. A node that
        cannot be found is reported as ``CONSISTENT`` with no message.
    """
    node_data = await self._fetch_node(node_id, node_label, group_id)

    if not node_data:
        return ConsistencyCheckResult.CONSISTENT, None

    # `.get(key, default)` would not help here: the keys are always present
    # in the fetched row, just possibly with a None value.
    access_history = node_data.get('access_history') or []
    last_access_time = node_data.get('last_access_time')
    access_count = node_data.get('access_count') or 0
    activation_value = node_data.get('activation_value')

    # Rule 1: the newest history entry must equal last_access_time. A
    # missing last_access_time with a non-empty history is a violation too.
    if access_history and access_history[-1] != last_access_time:
        return (
            ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
            f"access_history[-1]={access_history[-1]} != "
            f"last_access_time={last_access_time}"
        )

    # Rule 2: the history length must equal access_count.
    if len(access_history) != access_count:
        return (
            ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
            f"len(access_history)={len(access_history)} != "
            f"access_count={access_count}"
        )

    # Rule 3: access history implies an activation value.
    if access_history and activation_value is None:
        return (
            ConsistencyCheckResult.MISSING_ACTIVATION,
            "Node has access_history but activation_value is None"
        )

    # Rule 4: activation value range.
    if activation_value is not None:
        offset = self.actr_calculator.offset
        if not (offset <= activation_value <= 1.0):
            return (
                ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
                f"activation_value={activation_value} out of range "
                f"[{offset}, 1.0]"
            )

    return ConsistencyCheckResult.CONSISTENT, None
async def check_batch_consistency(
    self,
    node_label: str,
    group_id: Optional[str] = None,
    limit: int = 1000
) -> Dict[str, Any]:
    """
    Run ``check_consistency`` over nodes of one label and summarize the result.

    Only nodes whose ``access_history`` is not null are selected, optionally
    restricted to ``group_id`` and capped at ``limit`` nodes. Each selected
    node is then checked individually.

    Args:
        node_label: Node label to scan.
        group_id: Optional group ID filter.
        limit: Maximum number of nodes to check.

    Returns:
        Dict[str, Any]: Report with the keys
            - total_checked: number of nodes checked
            - consistent_count: number of consistent nodes
            - inconsistent_count: number of inconsistent nodes
            - inconsistencies: list of ``{node_id, result, message}`` dicts
            - consistency_rate: consistent / total (1.0 when nothing was checked)
    """
    scoped = bool(group_id)

    # The label is interpolated into the query text; IDs and limits are
    # passed as query parameters.
    query = (
        f"""
        MATCH (n:{node_label})
        WHERE n.access_history IS NOT NULL
        """
        + (" AND n.group_id = $group_id" if scoped else "")
        + """
        RETURN n.id as id
        LIMIT $limit
        """
    )

    params: Dict[str, Any] = {"limit": limit}
    if scoped:
        params["group_id"] = group_id

    rows = await self.connector.execute_query(query, **params)
    node_ids = [row['id'] for row in rows]

    inconsistencies = []
    for node_id in node_ids:
        verdict, detail = await self.check_consistency(
            node_id=node_id,
            node_label=node_label,
            group_id=group_id
        )
        if verdict != ConsistencyCheckResult.CONSISTENT:
            inconsistencies.append({
                'node_id': node_id,
                'result': verdict.value,
                'message': detail
            })

    total_checked = len(node_ids)
    inconsistent_count = len(inconsistencies)
    # Every checked node is either consistent or listed above.
    consistent_count = total_checked - inconsistent_count
    consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0

    report = {
        'total_checked': total_checked,
        'consistent_count': consistent_count,
        'inconsistent_count': inconsistent_count,
        'inconsistencies': inconsistencies,
        'consistency_rate': consistency_rate
    }

    logger.info(
        f"一致性检查完成: {node_label}, "
        f"一致率={consistency_rate:.2%}, "
        f"不一致节点={inconsistent_count}/{total_checked}"
    )

    return report
async def repair_inconsistency(
    self,
    node_id: str,
    node_label: str,
    group_id: Optional[str] = None
) -> bool:
    """
    Rewrite a node's tracking fields from its access history.

    When ``check_consistency`` reports a problem, the node is rewritten so
    that ``access_history`` is the source of truth:
        - ``last_access_time`` becomes ``access_history[-1]`` (if any history)
        - ``access_count`` becomes ``len(access_history)``
        - ``activation_value`` is recomputed from the history (if any history)

    A missing (``None``) history is treated as empty, and a missing
    ``importance_score`` falls back to 0.5.

    Args:
        node_id: ID of the node.
        node_label: Node label.
        group_id: Optional group ID used to filter the node.

    Returns:
        bool: True if the node was already consistent or was repaired;
        False if the node was not found or any step raised (the error is
        logged, not propagated).
    """
    try:
        result, _ = await self.check_consistency(
            node_id=node_id,
            node_label=node_label,
            group_id=group_id
        )

        if result == ConsistencyCheckResult.CONSISTENT:
            logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
            return True

        node_data = await self._fetch_node(node_id, node_label, group_id)
        if not node_data:
            logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
            return False

        # The fetched row always carries these keys, possibly as None, so a
        # `.get(key, default)` fallback would never apply.
        access_history = node_data.get('access_history') or []
        importance_score = node_data.get('importance_score')
        if importance_score is None:
            importance_score = 0.5

        repair_data: Dict[str, Any] = {}

        if access_history:
            repair_data['last_access_time'] = access_history[-1]

        repair_data['access_count'] = len(access_history)

        if access_history:
            access_history_dt = [
                datetime.fromisoformat(ts) for ts in access_history
            ]
            repair_data['activation_value'] = self.actr_calculator.calculate_memory_activation(
                access_history=access_history_dt,
                current_time=datetime.now(),
                last_access_time=access_history_dt[-1],
                importance_score=importance_score
            )

        query = f"""
        MATCH (n:{node_label} {{id: $node_id}})
        """
        if group_id:
            query += " WHERE n.group_id = $group_id"
        query += """
        SET n += $repair_data
        RETURN n
        """

        params = {
            'node_id': node_id,
            'repair_data': repair_data
        }
        if group_id:
            params['group_id'] = group_id

        await self.connector.execute_query(query, **params)

        logger.info(
            f"成功修复节点不一致: {node_label}[{node_id}], "
            f"问题类型={result.value}"
        )

        return True

    except Exception as e:
        # Best-effort repair: report failure to the caller instead of raising.
        logger.error(
            f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
        )
        return False
# ==================== 私有辅助方法 ====================
async def _fetch_node(
    self,
    node_id: str,
    node_label: str,
    group_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """
    Load the access-tracking fields of one node.

    Args:
        node_id: ID of the node.
        node_label: Node label; interpolated into the query text.
        group_id: Optional group ID; when truthy the node must also match it.

    Returns:
        Optional[Dict[str, Any]]: The first result row (id, importance_score,
        activation_value, access_history, last_access_time, access_count),
        or None when the query returns nothing.
    """
    scoped = bool(group_id)

    query = (
        f"""
        MATCH (n:{node_label} {{id: $node_id}})
        """
        + (" WHERE n.group_id = $group_id" if scoped else "")
        + """
        RETURN n.id as id,
        n.importance_score as importance_score,
        n.activation_value as activation_value,
        n.access_history as access_history,
        n.last_access_time as last_access_time,
        n.access_count as access_count
        """
    )

    params: Dict[str, Any] = {'node_id': node_id}
    if scoped:
        params['group_id'] = group_id

    rows = await self.connector.execute_query(query, **params)

    return rows[0] if rows else None
async def _calculate_update(
    self,
    node_data: Dict[str, Any],
    current_time: datetime,
    current_time_iso: str
) -> Dict[str, Any]:
    """
    Compute the field values to write when a node is accessed now.

    The new access is appended to the stored history, the history is
    trimmed by the ACT-R calculator, and the activation value is
    recomputed with the current time as the last access.

    Args:
        node_data: Current node data. ``access_history`` is expected to be a
            list of ISO-format timestamp strings and ``importance_score`` a
            float; either may be missing or None.
        current_time: Current time as a datetime object.
        current_time_iso: Current time as an ISO-format string.

    Returns:
        Dict[str, Any]: ``activation_value``, ``access_history`` (ISO
        strings, oldest first), ``last_access_time`` and ``access_count``.
    """
    # A key can be present with a None value (a row from `_fetch_node` always
    # carries every column), in which case dict.get's default is not applied.
    access_history = node_data.get('access_history') or []
    importance_score = node_data.get('importance_score')
    if importance_score is None:
        importance_score = 0.5

    # Append the new access time.
    new_access_history = list(access_history) + [current_time_iso]

    # Trim the access history (if it has grown too long).
    access_history_dt = [
        datetime.fromisoformat(ts) for ts in new_access_history
    ]
    trimmed = self.actr_calculator.trim_access_history(
        access_history=access_history_dt,
        current_time=current_time
    )
    # trim_access_history returns newest-first when it actually trims, while
    # the stored history is appended at the tail and read elsewhere via
    # access_history[-1]; keep it chronological so the last entry is always
    # the most recent access.
    trimmed_history_dt = sorted(trimmed)
    trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]

    # Compute the new activation value.
    activation_value = self.actr_calculator.calculate_memory_activation(
        access_history=trimmed_history_dt,
        current_time=current_time,
        last_access_time=current_time,  # the last access is "now"
        importance_score=importance_score
    )

    # Every field that has to be updated together.
    return {
        'activation_value': activation_value,
        'access_history': trimmed_history,
        'last_access_time': current_time_iso,
        'access_count': len(trimmed_history)
    }
async def _atomic_update(
    self,
    node_id: str,
    node_label: str,
    update_data: Dict[str, Any],
    group_id: Optional[str] = None
) -> Dict[str, Any]:
    """
    Update a node's activation fields in one write transaction, guarded by
    a version check (optimistic locking).

    Inside the transaction the node's current ``version`` is read, and the
    update is applied only if the node still carries that version; the
    version is then incremented. All fields are written by a single SET.

    Args:
        node_id: Node id.
        node_label: Node label (interpolated into the Cypher text).
        update_data: Must contain ``activation_value``, ``access_history``,
            ``last_access_time`` and ``access_count``.
        group_id: Optional group id used to scope the match.

    Returns:
        Dict[str, Any]: The updated node's fields as returned by the query.

    Raises:
        RuntimeError: If the node is not found, the version check fails, or
            the transaction fails for any other reason.
    """
    async def update_transaction(tx, node_id, node_label, update_data, group_id):
        match_clause = f"""
        MATCH (n:{node_label} {{id: $node_id}})
        """
        # Predicates shared by the read and the update. They are collected in
        # a list so that "WHERE"/"AND" are emitted correctly whether or not a
        # group filter is present.
        scope_conditions = ["n.group_id = $group_id"] if group_id else []
        scope_params = {'node_id': node_id}
        if group_id:
            scope_params['group_id'] = group_id

        # Step 1: read the current node and its version.
        read_query = match_clause
        if scope_conditions:
            read_query += " WHERE " + " AND ".join(scope_conditions)
        read_query += """
        RETURN n.id as id,
               n.version as version,
               n.activation_value as activation_value,
               n.access_history as access_history,
               n.last_access_time as last_access_time,
               n.access_count as access_count,
               n.importance_score as importance_score
        """
        read_result = await tx.run(read_query, **scope_params)
        current_node = await read_result.single()
        if not current_node:
            raise RuntimeError(f"Node not found: {node_label}[{node_id}]")

        # A missing or null version is treated as 0.
        current_version = current_node.get('version', 0) or 0
        new_version = current_version + 1

        # Step 2: update only if the version still matches.
        if current_version > 0:
            version_condition = "n.version = $current_version"
        else:
            # The node has never been versioned: accept null or 0.
            version_condition = "(n.version IS NULL OR n.version = 0)"
        update_conditions = scope_conditions + [version_condition]

        update_query = match_clause
        update_query += " WHERE " + " AND ".join(update_conditions)
        update_query += """
        SET n.activation_value = $activation_value,
            n.access_history = $access_history,
            n.last_access_time = $last_access_time,
            n.access_count = $access_count,
            n.version = $new_version
        RETURN n.id as id,
               n.activation_value as activation_value,
               n.access_history as access_history,
               n.last_access_time as last_access_time,
               n.access_count as access_count,
               n.importance_score as importance_score,
               n.version as version
        """
        update_params = {
            **scope_params,
            'current_version': current_version,
            'new_version': new_version,
            'activation_value': update_data['activation_value'],
            'access_history': update_data['access_history'],
            'last_access_time': update_data['last_access_time'],
            'access_count': update_data['access_count']
        }
        update_result = await tx.run(update_query, **update_params)
        updated_node = await update_result.single()
        if not updated_node:
            # No row matched the version predicate.
            raise RuntimeError(
                f"Version conflict detected for {node_label}[{node_id}]. "
                f"Expected version {current_version}, but node was modified by another transaction."
            )
        return dict(updated_node)

    # Run the transaction; every failure is surfaced as RuntimeError.
    try:
        result = await self.connector.execute_write_transaction(
            update_transaction,
            node_id=node_id,
            node_label=node_label,
            update_data=update_data,
            group_id=group_id
        )
        return result
    except Exception as e:
        logger.error(
            f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
        )
        raise RuntimeError(
            f"Failed to atomically update node: {str(e)}"
        ) from e

View File

@@ -0,0 +1,359 @@
"""
ACT-R Memory Activation Calculator
This module implements the unified Memory Activation model based on ACT-R
(Adaptive Control of Thought-Rational) cognitive architecture theory.
The calculator integrates BLA (Base-Level Activation) computation into the
Memory Activation formula, providing a single coherent model for memory strength
calculation that reflects both recency and frequency of access.
Formula: R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
Where:
- R(i): Memory activation value (0 to 1)
- offset: Minimum retention rate (prevents complete forgetting)
- λ: Forgetting rate (lambda_time / lambda_mem)
- t: Time since last access
- I: Importance score (0 to 1)
- t_k: Time since k-th access
- d: Decay constant (typically 0.5)
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
"""
import math
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
class ACTRCalculator:
    """
    Unified ACT-R Memory Activation Calculator.

    Combines recency and frequency effects into a single activation value:
    R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d))).

    Attributes:
        decay_constant: Decay parameter d (typically 0.5)
        forgetting_rate: Lambda parameter λ controlling forgetting speed
        offset: Minimum retention rate (baseline memory strength)
        max_history_length: Maximum number of access records to keep
    """

    def __init__(
        self,
        decay_constant: float = 0.5,
        forgetting_rate: float = 0.3,
        offset: float = 0.1,
        max_history_length: int = 100
    ):
        """
        Initialize the ACT-R calculator.

        Args:
            decay_constant: Decay parameter d (default 0.5)
            forgetting_rate: Forgetting rate λ (default 0.3)
            offset: Minimum retention rate (default 0.1)
            max_history_length: Maximum access history length (default 100)
        """
        self.decay_constant = decay_constant
        self.forgetting_rate = forgetting_rate
        self.offset = offset
        self.max_history_length = max_history_length

    @staticmethod
    def _to_datetime(value):
        """Return *value* as a datetime, parsing it when it is an ISO-format string."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    def calculate_memory_activation(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5
    ) -> float:
        """
        Calculate memory activation value using the unified Memory Activation formula.

        Computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d))), where
        t is the time since the last access and t_k the time since the k-th
        access, both in days and both floored at 0.0001.

        Args:
            access_history: Access timestamps, as datetime objects or
                ISO-format strings (strings are parsed with
                ``datetime.fromisoformat``)
            current_time: Current time for calculation (datetime or ISO string)
            last_access_time: Time of most recent access (datetime or ISO string)
            importance_score: Importance weight (0 to 1, default 0.5)

        Returns:
            float: Memory activation value between offset and 1.0

        Raises:
            ValueError: If access_history is empty, importance_score is
                outside [0, 1], or a timestamp string is not valid ISO format
        """
        if not access_history:
            raise ValueError("access_history cannot be empty")
        if not (0.0 <= importance_score <= 1.0):
            raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")

        # The documented contract allows ISO strings; normalise once up front
        # so the arithmetic below only ever sees datetime objects.
        now = self._to_datetime(current_time)
        last_access = self._to_datetime(last_access_time)
        accesses = [self._to_datetime(ts) for ts in access_history]

        # Time since last access (in days); floored to avoid a zero numerator
        # collapsing every recent memory to exactly 1.0 / division issues.
        time_since_last = max((now - last_access).total_seconds() / 86400.0, 0.0001)

        # BLA component: Σ(I·t_k^(-d)), with each t_k floored like above so
        # that the negative power never divides by zero.
        bla_sum = sum(
            importance_score
            * (max((now - access_time).total_seconds() / 86400.0, 0.0001)
               ** (-self.decay_constant))
            for access_time in accesses
        )
        # importance_score == 0 (or numerical issues) would make the divisor 0.
        if bla_sum <= 0:
            bla_sum = 0.0001

        # R(i) = offset + (1-offset) * exp(-λ*t / BLA)
        exponent = -self.forgetting_rate * time_since_last / bla_sum
        # Clamp the exponent to avoid numerical overflow/underflow.
        exponent = max(min(exponent, 100), -100)
        activation = self.offset + (1 - self.offset) * math.exp(exponent)
        # Keep the result inside [offset, 1.0].
        return max(self.offset, min(1.0, activation))

    def trim_access_history(
        self,
        access_history: List[datetime],
        current_time: datetime
    ) -> List[datetime]:
        """
        Trim access history to prevent unbounded growth.

        Strategy:
        - If the history has at most max_history_length entries it is
          returned unchanged (same list object, original order).
        - Otherwise the most recent ``max_history_length // 2`` entries are
          kept and the remaining slots are filled by sampling the older
          entries at an even stride. The trimmed result is ordered
          newest-first.

        Args:
            access_history: List of access timestamps (sorted or unsorted)
            current_time: Current time (accepted for interface symmetry; not
                used by the current strategy)

        Returns:
            List[datetime]: Trimmed access history
        """
        if len(access_history) <= self.max_history_length:
            return access_history

        # Newest first.
        sorted_history = sorted(access_history, reverse=True)

        # Keep the most recent half of the budget verbatim.
        keep_recent_count = self.max_history_length // 2
        recent_records = sorted_history[:keep_recent_count]

        # Fill the rest of the budget from the older records.
        older_records = sorted_history[keep_recent_count:]
        sample_count = self.max_history_length - keep_recent_count
        if len(older_records) <= sample_count:
            # Everything fits; keep them all.
            sampled_older = older_records
        else:
            # Even-stride sample; index 0 (the newest of the old) is always kept.
            step = len(older_records) / sample_count
            sampled_older = [
                older_records[int(i * step)]
                for i in range(sample_count)
            ]

        trimmed_history = recent_records + sampled_older
        return sorted(trimmed_history, reverse=True)

    # Use cases: predict activation to schedule reviews, compare the effect of
    # different configurations, and pick a suitable decay constant d.
    def get_forgetting_curve(
        self,
        initial_time: datetime,
        importance_score: float = 0.5,
        days: int = 60
    ) -> List[Dict[str, Any]]:
        """
        Generate forgetting curve data for visualization.

        Simulates how activation decays day by day for a memory that was
        accessed exactly once, at ``initial_time``.

        Args:
            initial_time: Time of initial memory creation/access
            importance_score: Importance weight (0 to 1, default 0.5)
            days: Number of days to simulate (default 60)

        Returns:
            List of dictionaries with keys:
            - 'day': Day number (0 to days)
            - 'activation': Memory activation value
            - 'retention_rate': Same as activation (for compatibility)
        """
        curve_data = []
        access_history = [initial_time]
        for day in range(days + 1):
            current_time = initial_time + timedelta(days=day)
            try:
                activation = self.calculate_memory_activation(
                    access_history=access_history,
                    current_time=current_time,
                    last_access_time=initial_time,
                    importance_score=importance_score
                )
            except ValueError:
                # Invalid input (e.g. importance out of range): fall back to
                # the baseline so the curve can still be drawn.
                activation = self.offset
            curve_data.append({
                'day': day,
                'activation': activation,
                'retention_rate': activation  # Alias for compatibility
            })
        return curve_data

    def calculate_forgetting_score(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5
    ) -> float:
        """
        Calculate forgetting score (inverse of activation).

        Forgetting score = 1 - activation value; a higher score means the
        memory is more likely to be forgotten.

        Args:
            access_history: List of access timestamps
            current_time: Current time for calculation
            last_access_time: Time of most recent access
            importance_score: Importance weight (0 to 1, default 0.5)

        Returns:
            float: Forgetting score between 0 and (1 - offset)
        """
        activation = self.calculate_memory_activation(
            access_history=access_history,
            current_time=current_time,
            last_access_time=last_access_time,
            importance_score=importance_score
        )
        return 1.0 - activation

    def should_forget(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5,
        threshold: float = 0.3
    ) -> bool:
        """
        Determine if a memory should be forgotten based on activation threshold.

        Args:
            access_history: List of access timestamps
            current_time: Current time for calculation
            last_access_time: Time of most recent access
            importance_score: Importance weight (0 to 1, default 0.5)
            threshold: Activation threshold below which memory should be forgotten

        Returns:
            bool: True if activation < threshold (should forget), False otherwise
        """
        activation = self.calculate_memory_activation(
            access_history=access_history,
            current_time=current_time,
            last_access_time=last_access_time,
            importance_score=importance_score
        )
        return activation < threshold
# Convenience functions for quick calculations
def calculate_activation(
    access_history: List[datetime],
    current_time: datetime,
    last_access_time: datetime,
    importance_score: float = 0.5,
    decay_constant: float = 0.5,
    forgetting_rate: float = 0.3,
    offset: float = 0.1
) -> float:
    """
    One-shot activation calculation with a throwaway ACTRCalculator.

    Builds a calculator from the three model parameters (leaving its
    history-length limit at the class default) and delegates to
    ``calculate_memory_activation``.

    Args:
        access_history: List of access timestamps
        current_time: Current time for calculation
        last_access_time: Time of most recent access
        importance_score: Importance weight (0 to 1, default 0.5)
        decay_constant: Decay parameter d (default 0.5)
        forgetting_rate: Forgetting rate λ (default 0.3)
        offset: Minimum retention rate (default 0.1)

    Returns:
        float: Memory activation value between offset and 1.0
    """
    model_params = {
        'decay_constant': decay_constant,
        'forgetting_rate': forgetting_rate,
        'offset': offset,
    }
    return ACTRCalculator(**model_params).calculate_memory_activation(
        access_history=access_history,
        current_time=current_time,
        last_access_time=last_access_time,
        importance_score=importance_score,
    )
def generate_forgetting_curve(
    initial_time: datetime,
    importance_score: float = 0.5,
    days: int = 60,
    decay_constant: float = 0.5,
    forgetting_rate: float = 0.3,
    offset: float = 0.1
) -> List[Dict[str, Any]]:
    """
    One-shot forgetting-curve generation with a throwaway ACTRCalculator.

    Builds a calculator from the three model parameters and delegates to
    ``get_forgetting_curve``.

    Args:
        initial_time: Time of initial memory creation/access
        importance_score: Importance weight (0 to 1, default 0.5)
        days: Number of days to simulate (default 60)
        decay_constant: Decay parameter d (default 0.5)
        forgetting_rate: Forgetting rate λ (default 0.3)
        offset: Minimum retention rate (default 0.1)

    Returns:
        List of dictionaries with forgetting curve data
    """
    model_params = {
        'decay_constant': decay_constant,
        'forgetting_rate': forgetting_rate,
        'offset': offset,
    }
    return ACTRCalculator(**model_params).get_forgetting_curve(
        initial_time=initial_time,
        importance_score=importance_score,
        days=days,
    )

View File

@@ -0,0 +1,195 @@
"""
遗忘引擎配置工具模块
本模块提供从数据库加载配置并创建遗忘引擎组件的辅助函数。
Functions:
calculate_forgetting_rate: 计算遗忘速率lambda_time / lambda_mem
load_actr_config_from_db: 从数据库加载 ACT-R 配置参数
create_actr_calculator_from_config: 从配置创建 ACTRCalculator 实例
"""
import logging
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
    """
    Combine the two lambda parameters into a single forgetting rate.

    Formula: forgetting_rate = lambda_time / lambda_mem

    The quotient is the unified forgetting-rate parameter consumed by the
    ACT-R activation calculation. The result is also written to the debug
    log.

    Args:
        lambda_time: Time decay parameter (expected in 0-1; not enforced)
        lambda_mem: Memory decay parameter (expected in 0-1; must be non-zero)

    Returns:
        float: The forgetting rate

    Raises:
        ValueError: If lambda_mem is 0

    Examples:
        >>> calculate_forgetting_rate(0.5, 0.5)
        1.0
        >>> calculate_forgetting_rate(0.3, 0.5)
        0.6
    """
    if lambda_mem == 0:
        raise ValueError("lambda_mem 不能为 0")

    rate = lambda_time / lambda_mem
    debug_message = (
        f"计算遗忘速率: lambda_time={lambda_time}, "
        f"lambda_mem={lambda_mem}, "
        f"forgetting_rate={rate:.4f}"
    )
    logger.debug(debug_message)
    return rate
def load_actr_config_from_db(
    db: Session,
    config_id: Optional[int] = None
) -> Dict[str, Any]:
    """
    Load the ACT-R configuration parameters for one config row.

    The row is fetched through ``DataConfigRepository.get_by_id`` and its
    values are returned together with the derived ``forgetting_rate``
    (lambda_time / lambda_mem). Stored values are used as they are; no
    fallback defaults are applied here.

    Args:
        db: Database session
        config_id: Configuration ID. Although the parameter defaults to
            None, a value is required: None raises ValueError.

    Returns:
        Dict[str, Any]: Configuration dictionary with the keys:
            - decay_constant: decay constant d
            - lambda_time: time decay parameter
            - lambda_mem: memory decay parameter
            - forgetting_rate: derived as lambda_time / lambda_mem
            - offset: offset
            - max_history_length: maximum access-history length
            - forgetting_threshold: forgetting threshold
            - min_days_since_access: minimum days without access
            - enable_llm_summary: whether to generate summaries with an LLM
            - max_merge_batch_size: maximum node pairs merged per run
            - forgetting_interval_hours: interval between forgetting cycles
        ``llm_id`` is deliberately not part of the returned dictionary.

    Raises:
        ValueError: If config_id is None, if no row exists for config_id,
            or if lambda_mem is 0.
    """
    # A config id is mandatory.
    if config_id is None:
        logger.error("未指定 config_id无法加载配置")
        raise ValueError("config_id 不能为空,必须指定一个有效的配置 ID")

    # Attribute names, in the order they are read from the row.
    column_names = (
        'lambda_time',
        'lambda_mem',
        'decay_constant',
        'offset',
        'max_history_length',
        'forgetting_threshold',
        'min_days_since_access',
        'enable_llm_summary',
        'max_merge_batch_size',
        'forgetting_interval_hours',
    )

    try:
        db_config = DataConfigRepository().get_by_id(db, config_id)
        if db_config is None:
            logger.error(f"配置不存在: config_id={config_id}")
            raise ValueError(f"配置不存在: config_id={config_id}")

        # Read the stored parameters (database defaults are trusted).
        stored = {name: getattr(db_config, name) for name in column_names}

        # Derived parameter.
        forgetting_rate = calculate_forgetting_rate(
            stored['lambda_time'], stored['lambda_mem']
        )

        config = {
            'decay_constant': stored['decay_constant'],
            'lambda_time': stored['lambda_time'],
            'lambda_mem': stored['lambda_mem'],
            'forgetting_rate': forgetting_rate,
            'offset': stored['offset'],
            'max_history_length': stored['max_history_length'],
            'forgetting_threshold': stored['forgetting_threshold'],
            'min_days_since_access': stored['min_days_since_access'],
            'enable_llm_summary': stored['enable_llm_summary'],
            'max_merge_batch_size': stored['max_merge_batch_size'],
            'forgetting_interval_hours': stored['forgetting_interval_hours']
            # llm_id is intentionally left out; it is for internal use only.
        }
        logger.info(
            f"成功加载 ACT-R 配置: config_id={config_id}, "
            f"forgetting_rate={forgetting_rate:.4f}"
        )
        return config
    except Exception as e:
        # Log every failure (including the ValueErrors raised above) and
        # let it propagate unchanged.
        logger.error(f"加载 ACT-R 配置失败: config_id={config_id}, 错误: {str(e)}")
        raise
def create_actr_calculator_from_config(
    db: Session,
    config_id: Optional[int] = None
) -> ACTRCalculator:
    """
    Build an ACTRCalculator from a stored configuration row.

    The parameters are loaded with ``load_actr_config_from_db`` and the
    calculator is constructed from ``decay_constant``, ``forgetting_rate``,
    ``offset`` and ``max_history_length``.

    Args:
        db: Database session
        config_id: Configuration ID, forwarded to
            ``load_actr_config_from_db`` (which raises when it is None)

    Returns:
        ACTRCalculator: Calculator configured from the stored parameters

    Raises:
        ValueError: Propagated from ``load_actr_config_from_db`` (missing or
            unknown config_id)

    Examples:
        >>> from sqlalchemy.orm import Session
        >>> db = Session()
        >>> calculator = create_actr_calculator_from_config(db, config_id=1)
        >>> # use the calculator
        >>> activation = calculator.calculate_memory_activation(...)
    """
    settings = load_actr_config_from_db(db, config_id)

    decay_constant = settings['decay_constant']
    forgetting_rate = settings['forgetting_rate']
    offset = settings['offset']

    calculator = ACTRCalculator(
        decay_constant=decay_constant,
        forgetting_rate=forgetting_rate,
        offset=offset,
        max_history_length=settings['max_history_length']
    )

    creation_message = (
        f"创建 ACTRCalculator: config_id={config_id}, "
        f"decay_constant={decay_constant}, "
        f"forgetting_rate={forgetting_rate:.4f}, "
        f"offset={offset}"
    )
    logger.info(creation_message)
    return calculator

View File

@@ -0,0 +1,351 @@
"""
遗忘调度器模块
本模块实现遗忘周期的调度和管理,负责:
1. 手动触发遗忘周期
2. 批量处理可遗忘节点(限制批量大小)
3. 按激活值优先级排序(激活值最低的优先)
4. 进度跟踪和日志记录
5. 生成遗忘报告
注意:定期调度功能已迁移到 Celery Beat见 app/tasks.py 中的 run_forgetting_cycle_task
Classes:
ForgettingScheduler: 遗忘调度器,提供遗忘周期管理功能
"""
import logging
from typing import Dict, Any, Optional
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class ForgettingScheduler:
    """
    Forgetting scheduler.

    Runs a forgetting cycle: finds forgettable Statement/Entity pairs,
    orders them by average activation (lowest first), removes pairs that
    reuse an already selected node, limits the batch, merges each pair via
    the forgetting strategy, and produces a report.

    Periodic scheduling is not handled here (per the module notes it lives
    in a Celery Beat task).

    Attributes:
        forgetting_strategy: Forgetting strategy executor instance
        connector: Neo4j connector instance
        is_running: Whether a forgetting cycle is currently running
    """

    def __init__(
        self,
        forgetting_strategy: ForgettingStrategy,
        connector: Neo4jConnector
    ):
        """
        Initialize the forgetting scheduler.

        Args:
            forgetting_strategy: Forgetting strategy executor instance
            connector: Neo4j connector instance
        """
        self.forgetting_strategy = forgetting_strategy
        self.connector = connector
        self.is_running = False
        logger.info("初始化遗忘调度器")

    async def run_forgetting_cycle(
        self,
        group_id: Optional[str] = None,
        max_merge_batch_size: int = 100,
        min_days_since_access: int = 30,
        config_id: Optional[int] = None,
        db = None
    ) -> Dict[str, Any]:
        """
        Run one complete forgetting cycle.

        Args:
            group_id: Optional group id used to filter nodes
            max_merge_batch_size: Maximum number of node pairs merged in
                this run (default 100); values below 0 are treated as 0
            min_days_since_access: Minimum days since last access (default 30)
            config_id: Optional configuration id, forwarded to the strategy
            db: Optional database session, forwarded to the strategy

        Returns:
            Dict[str, Any]: Forgetting report with the keys:
                - merged_count: number of merged node pairs
                - nodes_before: node count before the cycle
                - nodes_after: node count after the cycle
                - reduction_rate: (nodes_before - nodes_after) / nodes_before
                - duration_seconds: elapsed time in seconds
                - start_time: start time (ISO format)
                - end_time: end time (ISO format)
                - failed_count: number of failed merges
                - success_rate: merged_count / attempted pairs (1.0 if none)

        Raises:
            RuntimeError: If a forgetting cycle is already running
        """
        # Only one cycle at a time per scheduler instance.
        if self.is_running:
            raise RuntimeError("遗忘周期已在运行中,请等待当前周期完成")
        self.is_running = True

        start_time = datetime.now()
        logger.info(
            f"开始遗忘周期: group_id={group_id}, "
            f"max_batch={max_merge_batch_size}, "
            f"min_days={min_days_since_access}"
        )
        try:
            # Step 1: count nodes before forgetting.
            nodes_before = await self._count_knowledge_nodes(group_id)
            logger.info(f"遗忘前节点总数: {nodes_before}")

            # Step 2: find forgettable node pairs.
            forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
                group_id=group_id,
                min_days_since_access=min_days_since_access
            )
            total_forgettable = len(forgettable_pairs)
            logger.info(f"识别到 {total_forgettable} 个可遗忘节点对")

            if total_forgettable == 0:
                logger.info("没有可遗忘的节点对,遗忘周期结束")
                return self._build_report(
                    merged_count=0,
                    failed_count=0,
                    attempted_count=0,
                    nodes_before=nodes_before,
                    nodes_after=nodes_before,
                    start_time=start_time,
                    end_time=datetime.now()
                )

            # Step 3: lowest average activation first.
            sorted_pairs = sorted(
                forgettable_pairs,
                key=lambda pair: pair['avg_activation']
            )

            # Step 4: drop pairs that reuse an already selected node BEFORE
            # applying the batch limit. Limiting first would let duplicates
            # (e.g. one entity linked to many statements) eat the whole batch
            # and leave eligible pairs unprocessed.
            unique_pairs, skipped_count = self._select_unique_pairs(sorted_pairs)
            logger.info(
                f"预过滤完成:原始 {total_forgettable} 对,去重后 {len(unique_pairs)} 对,"
                f"跳过 {skipped_count} 对重复节点"
            )

            # Step 5: apply the batch limit. A negative limit would slice
            # from the tail, so clamp it to zero.
            pairs_to_process = unique_pairs[:max(0, max_merge_batch_size)]
            actual_batch_size = len(pairs_to_process)
            logger.info(
                f"将处理 {actual_batch_size} 个节点对 "
                f"(限制: {max_merge_batch_size})"
            )

            # Step 6: merge the pairs, logging progress roughly every 10%.
            merged_count = 0
            failed_count = 0
            progress_interval = max(1, actual_batch_size // 10)
            for idx, pair in enumerate(pairs_to_process, start=1):
                statement_id = pair['statement_id']
                entity_id = pair['entity_id']
                try:
                    # Built inside the try so that a malformed pair (missing
                    # key) counts as a failed merge instead of aborting.
                    statement_node = {
                        'statement_id': statement_id,
                        'statement_text': pair['statement_text'],
                        'statement_activation': pair['statement_activation'],
                        'statement_importance': pair['statement_importance'],
                        'group_id': group_id
                    }
                    entity_node = {
                        'entity_id': entity_id,
                        'entity_name': pair['entity_name'],
                        'entity_type': pair['entity_type'],
                        'entity_activation': pair['entity_activation'],
                        'entity_importance': pair['entity_importance'],
                        'group_id': group_id
                    }
                    await self.forgetting_strategy.merge_nodes_to_summary(
                        statement_node=statement_node,
                        entity_node=entity_node,
                        config_id=config_id,
                        db=db
                    )
                    merged_count += 1
                except Exception as e:
                    # Best effort: one failed pair must not stop the batch.
                    failed_count += 1
                    if "nodes may not exist" in str(e):
                        logger.warning(
                            f"节点对 ({idx}/{actual_batch_size}) 的节点不存在(可能已被其他操作删除): "
                            f"Statement[{statement_id}] + Entity[{entity_id}]"
                        )
                    else:
                        logger.error(
                            f"融合节点对失败 ({idx}/{actual_batch_size}): "
                            f"Statement[{statement_id}] + Entity[{entity_id}], "
                            f"错误: {str(e)}"
                        )

                # Progress is reported after success and failure alike, so a
                # failing pair at a milestone does not hide the update.
                if idx % progress_interval == 0 or idx == actual_batch_size:
                    progress_pct = (idx / actual_batch_size) * 100
                    logger.info(
                        f"遗忘进度: {idx}/{actual_batch_size} "
                        f"({progress_pct:.1f}%), "
                        f"已融合: {merged_count}, 失败: {failed_count}"
                    )

            # Step 7: count nodes after forgetting.
            nodes_after = await self._count_knowledge_nodes(group_id)
            logger.info(f"遗忘后节点总数: {nodes_after}")

            # Step 8: build the report.
            report = self._build_report(
                merged_count=merged_count,
                failed_count=failed_count,
                attempted_count=actual_batch_size,
                nodes_before=nodes_before,
                nodes_after=nodes_after,
                start_time=start_time,
                end_time=datetime.now()
            )
            reduction_rate = report['reduction_rate']
            duration = report['duration_seconds']
            logger.info(
                f"遗忘周期完成: "
                f"融合 {merged_count} 对节点, "
                f"失败 {failed_count} 对, "
                f"节点减少 {nodes_before - nodes_after}"
                f"({reduction_rate:.2%}), "
                f"耗时 {duration:.2f}"
            )
            return report
        except Exception as e:
            logger.error(f"遗忘周期执行失败: {str(e)}")
            raise
        finally:
            # Always release the flag, including on failure.
            self.is_running = False

    # ==================== Private helpers ====================

    @staticmethod
    def _select_unique_pairs(sorted_pairs) -> tuple:
        """
        Keep, in order, the pairs whose statement and entity have not
        already been used by an earlier kept pair.

        Returns:
            tuple: (list of kept pairs, number of skipped pairs)
        """
        seen_statement_ids = set()
        seen_entity_ids = set()
        unique_pairs = []
        skipped_count = 0
        for pair in sorted_pairs:
            statement_id = pair['statement_id']
            entity_id = pair['entity_id']
            if statement_id in seen_statement_ids or entity_id in seen_entity_ids:
                skipped_count += 1
                logger.debug(
                    f"预过滤:跳过重复节点对 Statement[{statement_id}] + Entity[{entity_id}]"
                )
                continue
            seen_statement_ids.add(statement_id)
            seen_entity_ids.add(entity_id)
            unique_pairs.append(pair)
        return unique_pairs, skipped_count

    @staticmethod
    def _build_report(
        merged_count: int,
        failed_count: int,
        attempted_count: int,
        nodes_before: int,
        nodes_after: int,
        start_time: datetime,
        end_time: datetime
    ) -> Dict[str, Any]:
        """Assemble the forgetting report dictionary from the cycle's counters."""
        # Guard both ratios against a zero denominator.
        if nodes_before > 0:
            reduction_rate = (nodes_before - nodes_after) / nodes_before
        else:
            reduction_rate = 0.0
        if attempted_count > 0:
            success_rate = merged_count / attempted_count
        else:
            success_rate = 1.0
        return {
            'merged_count': merged_count,
            'nodes_before': nodes_before,
            'nodes_after': nodes_after,
            'reduction_rate': reduction_rate,
            'duration_seconds': (end_time - start_time).total_seconds(),
            'start_time': start_time.isoformat(),
            'end_time': end_time.isoformat(),
            'failed_count': failed_count,
            'success_rate': success_rate
        }

    async def _count_knowledge_nodes(
        self,
        group_id: Optional[str] = None
    ) -> int:
        """
        Count Statement, ExtractedEntity and MemorySummary nodes.

        Args:
            group_id: Optional group id; when truthy only nodes whose
                ``group_id`` equals it are counted

        Returns:
            int: The ``total`` column of the first result row, or 0 when the
            query returns no rows
        """
        query = """
        MATCH (n)
        WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
        """
        if group_id:
            query += " AND n.group_id = $group_id"
        query += """
        RETURN count(n) as total
        """
        params = {}
        if group_id:
            params['group_id'] = group_id
        results = await self.connector.execute_query(query, **params)
        if results:
            return results[0]['total']
        return 0

View File

@@ -0,0 +1,611 @@
"""
遗忘策略执行器模块
本模块实现基于 ACT-R 激活值的遗忘策略,负责:
1. 识别低激活值的节点对Statement-Entity
2. 将低激活值节点融合为 MemorySummary 节点
3. 使用 LLM 生成高质量摘要(可选)
4. 保留溯源信息并删除原始节点
Classes:
ForgettingStrategy: 遗忘策略执行器,提供节点识别和融合功能
"""
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
class ForgettingStrategy:
"""
遗忘策略执行器
基于 ACT-R 激活值识别和融合低价值记忆节点。
实现了完整的遗忘周期:识别 → 融合 → 删除。
核心功能:
1. 识别可遗忘节点:激活值低于阈值且长期未访问的 Statement-Entity 对
2. 节点融合:创建 MemorySummary 节点,继承较高的激活值和重要性
3. LLM 摘要生成:使用 LLM 生成语义摘要(可降级到简单拼接)
4. 溯源保留:记录原始节点 ID保持可追溯性
Attributes:
connector: Neo4j 连接器实例
actr_calculator: ACT-R 激活值计算器实例
forgetting_threshold: 遗忘阈值(激活值低于此值的节点可被遗忘)
"""
def __init__(
    self,
    connector: Neo4jConnector,
    actr_calculator: ACTRCalculator,
    forgetting_threshold: float = 0.3,
    enable_llm_summary: bool = True
):
    """
    Store the collaborators and settings of the forgetting strategy.

    Args:
        connector: Neo4j connector instance
        actr_calculator: ACT-R activation calculator instance
        forgetting_threshold: Forgetting threshold (default 0.3)
        enable_llm_summary: Whether LLM summary generation is enabled
            (default True)
    """
    # Settings
    self.forgetting_threshold = forgetting_threshold
    self.enable_llm_summary = enable_llm_summary
    # Collaborators
    self.actr_calculator = actr_calculator
    self.connector = connector

    init_message = (
        f"初始化遗忘策略执行器: threshold={forgetting_threshold}, "
        f"enable_llm_summary={enable_llm_summary}"
    )
    logger.info(init_message)
async def calculate_forgetting_score(
    self,
    activation_value: float
) -> float:
    """
    Turn an activation value into a forgetting score.

    The score is the complement of the activation (1 - activation), so a
    higher score means the node is a stronger candidate for forgetting.
    No separate importance term is applied here.

    Args:
        activation_value: The node's activation value (expected in 0-1)

    Returns:
        float: The forgetting score, 1.0 minus the activation value
    """
    forgetting_score = 1.0 - activation_value
    return forgetting_score
async def find_forgettable_nodes(
    self,
    group_id: Optional[str] = None,
    min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
    """
    Find Statement-Entity pairs that are candidates for forgetting.

    A pair qualifies when:
    1. both nodes have an activation value below the forgetting threshold;
    2. both nodes' ``last_access_time`` is earlier than now minus
       ``min_days_since_access`` days (compared against an ISO-format
       string);
    3. a relationship (in either direction) connects the two nodes;
    4. the entity's type is unset or is not 'Person'.

    Args:
        group_id: Optional group id; when truthy both nodes must carry it
        min_days_since_access: Minimum days since last access (default 30)

    Returns:
        List[Dict[str, Any]]: Rows ordered by ascending average activation,
        each with the columns:
            - statement_id, statement_text, statement_activation,
              statement_importance, statement_last_access
            - entity_id, entity_name, entity_type, entity_activation,
              entity_importance, entity_last_access
            - avg_activation: mean of the two activation values
    """
    # Anything accessed at or after this instant is too recent to forget.
    cutoff_time_iso = (
        datetime.now() - timedelta(days=min_days_since_access)
    ).isoformat()

    group_filter = (
        " AND s.group_id = $group_id AND e.group_id = $group_id"
        if group_id else ""
    )
    query = (
        """
        MATCH (s:Statement)-[r]-(e:ExtractedEntity)
        WHERE s.activation_value IS NOT NULL
          AND e.activation_value IS NOT NULL
          AND s.activation_value < $threshold
          AND e.activation_value < $threshold
          AND s.last_access_time < $cutoff_time
          AND e.last_access_time < $cutoff_time
          AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
        """
        + group_filter
        + """
        RETURN s.id as statement_id,
               s.statement as statement_text,
               s.activation_value as statement_activation,
               s.importance_score as statement_importance,
               s.last_access_time as statement_last_access,
               e.id as entity_id,
               e.name as entity_name,
               e.entity_type as entity_type,
               e.activation_value as entity_activation,
               e.importance_score as entity_importance,
               e.last_access_time as entity_last_access,
               (s.activation_value + e.activation_value) / 2.0 as avg_activation
        ORDER BY avg_activation ASC
        """
    )

    bound = {
        'threshold': self.forgetting_threshold,
        'cutoff_time': cutoff_time_iso
    }
    if group_id:
        bound['group_id'] = group_id

    candidate_pairs = await self.connector.execute_query(query, **bound)
    logger.info(
        f"识别到 {len(candidate_pairs)} 个可遗忘节点对 "
        f"(threshold={self.forgetting_threshold}, "
        f"min_days={min_days_since_access})"
    )
    return candidate_pairs
async def merge_nodes_to_summary(
self,
statement_node: Dict[str, Any],
entity_node: Dict[str, Any],
config_id: Optional[int] = None,
db = None
) -> str:
"""
将 Statement 和 Entity 节点融合为 MemorySummary 节点
融合过程:
1. 生成摘要内容(使用 LLM 或简单拼接)
2. 创建 MemorySummary 节点,继承较高的激活值和重要性分数
3. 删除原始 Statement 和 Entity 节点
4. 保留溯源信息original_statement_id, original_entity_id
Args:
statement_node: Statement 节点数据,必须包含:
- statement_id: 节点 ID
- statement_text: 文本内容
- statement_activation: 激活值
- statement_importance: 重要性分数
entity_node: Entity 节点数据,必须包含:
- entity_id: 节点 ID
- entity_name: 实体名称
- entity_type: 实体类型
- entity_activation: 激活值
- entity_importance: 重要性分数
config_id: 配置ID可选用于获取 llm_id
db: 数据库会话(可选,用于获取 llm_id
Returns:
str: 创建的 MemorySummary 节点 ID
Raises:
ValueError: 如果节点数据不完整
RuntimeError: 如果融合操作失败
"""
# 验证输入数据
required_statement_keys = [
'statement_id', 'statement_text',
'statement_activation', 'statement_importance'
]
required_entity_keys = [
'entity_id', 'entity_name', 'entity_type',
'entity_activation', 'entity_importance'
]
for key in required_statement_keys:
if key not in statement_node:
raise ValueError(f"Statement 节点缺少必需字段: {key}")
for key in required_entity_keys:
if key not in entity_node:
raise ValueError(f"Entity 节点缺少必需字段: {key}")
# 验证实体类型:不允许融合 Person 类型的实体
if entity_node.get('entity_type') == 'Person':
raise ValueError(
f"不允许融合 Person 类型的实体: entity_id={entity_node.get('entity_id')}, "
f"entity_name={entity_node.get('entity_name')}"
)
# 提取节点信息
statement_id = statement_node['statement_id']
statement_text = statement_node['statement_text']
statement_activation = statement_node['statement_activation']
statement_importance = statement_node['statement_importance']
entity_id = entity_node['entity_id']
entity_name = entity_node['entity_name']
entity_type = entity_node['entity_type']
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 生成摘要内容
summary_text = await self._generate_summary(
statement_text=statement_text,
entity_name=entity_name,
entity_type=entity_type,
config_id=config_id,
db=db
)
# 计算继承的激活值和重要性(取较高值)
inherited_activation = max(statement_activation, entity_activation)
inherited_importance = max(statement_importance, entity_importance)
# 创建 MemorySummary 节点
current_time = datetime.now()
current_time_iso = current_time.isoformat()
# 生成新的 MemorySummary ID
import uuid
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 使用事务创建 MemorySummary 并删除原节点
async def merge_transaction(tx, **params):
"""事务函数:创建摘要节点并删除原节点"""
query = """
// 首先检查节点是否存在
OPTIONAL MATCH (s:Statement {id: $statement_id})
OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id})
// 如果任一节点不存在,直接返回 null不执行后续操作
WITH s, e
WHERE s IS NOT NULL AND e IS NOT NULL
// 创建 MemorySummary 节点
CREATE (ms:MemorySummary {
id: $summary_id,
summary: $summary_text,
original_statement_id: $statement_id,
original_entity_id: $entity_id,
activation_value: $inherited_activation,
importance_score: $inherited_importance,
access_history: [$current_time],
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
// 转移 Statement 的出边到 MemorySummary只转移目标节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (s)-[r_out]->(target)
WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Statement 的入边到 MemorySummary只转移源节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(s)
WHERE r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Entity 的出边到 MemorySummary只转移目标节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (e)-[r_out]->(target)
WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Entity 的入边到 MemorySummary只转移源节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(e)
WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 删除原始节点
WITH ms, s, e
DETACH DELETE s, e
RETURN ms.id as summary_id
"""
result = await tx.run(query, **params)
record = await result.single()
if not record:
raise RuntimeError("Failed to create MemorySummary node - nodes may not exist")
return record['summary_id']
params = {
'summary_id': summary_id,
'summary_text': summary_text,
'statement_id': statement_id,
'entity_id': entity_id,
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
}
try:
created_summary_id = await self.connector.execute_write_transaction(
merge_transaction,
**params
)
logger.info(
f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] "
f"-> MemorySummary[{created_summary_id}], "
f"activation={inherited_activation:.4f}, "
f"importance={inherited_importance:.4f}"
)
return created_summary_id
except Exception as e:
# 记录详细的错误信息,包括异常类型和堆栈
import traceback
error_details = traceback.format_exc()
logger.error(
f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], "
f"错误类型: {type(e).__name__}, "
f"错误信息: {str(e)}, "
f"详细堆栈:\n{error_details}"
)
raise RuntimeError(
f"融合节点失败: {str(e)}"
) from e
# ==================== 私有辅助方法 ====================
async def _generate_summary(
    self,
    statement_text: str,
    entity_name: str,
    entity_type: str,
    config_id: Optional[int] = None,
    db = None
) -> str:
    """
    Produce the summary text for a Statement/Entity pair.

    An LLM-written summary is preferred. Plain concatenation (see
    ``_simple_concatenation``) is used instead when LLM summaries are
    disabled on this instance, when no LLM client can be obtained, or when
    the LLM call raises.

    Args:
        statement_text: Text of the Statement node.
        entity_name: Name of the Entity node.
        entity_type: Type of the Entity node.
        config_id: Configuration ID used to look up the LLM (optional).
        db: Database session used to look up the LLM (optional).

    Returns:
        str: The summary text; LLM output is capped at 200 characters.
    """
    def fallback() -> str:
        # Single place for the degraded, non-LLM path.
        return self._simple_concatenation(
            statement_text, entity_name, entity_type
        )

    if not self.enable_llm_summary:
        logger.info("LLM 摘要生成已禁用,使用简单拼接")
        return fallback()

    # A client can only be resolved when both the config ID and a session
    # were supplied; lookup errors are logged and treated as "no client".
    client = None
    if not (config_id is None or db is None):
        try:
            client = await self._get_llm_client(db, config_id)
        except Exception as exc:
            logger.warning(f"获取 LLM 客户端失败: {str(exc)}")

    if client is None:
        logger.info("未能获取 LLM 客户端,使用简单拼接")
        return fallback()

    try:
        text = await self._generate_llm_summary(
            statement_text=statement_text,
            entity_name=entity_name,
            entity_type=entity_type,
            llm_client=client
        )
        # Cap at 200 characters: 197 kept plus a three-dot ellipsis.
        if len(text) > 200:
            text = f"{text[:197]}..."
        logger.info(f"使用 LLM 生成摘要: {text}")
    except Exception as exc:
        logger.warning(
            f"LLM 摘要生成失败,降级到简单拼接: {str(exc)}"
        )
        return fallback()
    return text
async def _get_llm_client(self, db, config_id: int):
    """
    Resolve the LLM client configured for ``config_id``.

    Best-effort: every failure (including a failed import of the project
    modules used here) is logged and reported as ``None`` instead of being
    raised.

    Args:
        db: Database session.
        config_id: Configuration ID whose ``llm_id`` selects the client.

    Returns:
        The LLM client instance, or ``None`` if it cannot be obtained.
    """
    try:
        # Imported inside the try so an import failure is handled like any
        # other lookup failure.
        from app.repositories.data_config_repository import DataConfigRepository
        from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

        config_row = DataConfigRepository().get_by_id(db, config_id)
        has_llm = config_row is not None and config_row.llm_id is not None
        if not has_llm:
            logger.warning(f"配置 {config_id} 不存在或未设置 llm_id")
            return None

        # The factory is only built once a usable llm_id is known.
        client = MemoryClientFactory(db).get_llm_client(str(config_row.llm_id))
        logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={config_row.llm_id}")
        return client
    except Exception as exc:
        logger.error(f"获取 LLM 客户端失败: {str(exc)}")
        return None
async def _generate_llm_summary(
    self,
    statement_text: str,
    entity_name: str,
    entity_type: str,
    llm_client
) -> str:
    """
    Ask the LLM for a concise summary of a statement/entity pair.

    The reply is accepted either as a plain string or as an object exposing
    a ``content`` attribute; any other object is converted with ``str()``.
    A reply that is empty after stripping whitespace is rejected, so that
    ``_generate_summary`` (which catches exceptions from this method) falls
    back to concatenation instead of storing a blank summary.

    Args:
        statement_text: Text of the Statement node.
        entity_name: Name of the Entity node.
        entity_type: Type of the Entity node.
        llm_client: Client object exposing an awaitable ``chat(prompt)``.

    Returns:
        str: The stripped, non-empty summary text.

    Raises:
        TypeError: If the reply's ``content`` is not a string.
        ValueError: If the reply is empty or whitespace only.
        Exception: Anything raised by ``llm_client.chat`` propagates.
    """
    # The prompt is sent to the model verbatim, so its lines are kept
    # flush-left inside the literal.
    prompt = f"""请为以下记忆片段生成一个简洁的摘要不超过200个字符
实体名称: {entity_name}
实体类型: {entity_type}
陈述内容: {statement_text}
要求:
1. 摘要应该保留核心语义信息
2. 长度不超过200个字符
3. 使用简洁、自然的中文表达
4. 只返回摘要文本,不要包含其他内容
摘要:"""
    # The prompt string is passed directly to the client.
    response = await llm_client.chat(prompt)

    if isinstance(response, str):
        raw_text = response
    elif hasattr(response, 'content'):
        raw_text = response.content
    else:
        raw_text = str(response)

    # A message object whose content is None or a structured payload
    # cannot be used as summary text; fail with an explicit error.
    if not isinstance(raw_text, str):
        raise TypeError(
            f"Unexpected LLM response content type: {type(raw_text).__name__}"
        )

    summary = raw_text.strip()
    if not summary:
        raise ValueError("LLM returned an empty summary")
    return summary
def _simple_concatenation(
    self,
    statement_text: str,
    entity_name: str,
    entity_type: str
) -> str:
    """
    Build a summary by plain text concatenation.

    This is the degraded path used when no LLM summary is produced.
    Layout: ``[entity_type]entity_name: statement_text``.

    Args:
        statement_text: Text of the Statement node.
        entity_name: Name of the Entity node.
        entity_type: Type of the Entity node.

    Returns:
        str: The concatenated summary, at most 200 characters long.
    """
    combined = "[{}]{}: {}".format(entity_type, entity_name, statement_text)
    # The limit counts characters (code points), not encoded bytes.
    if len(combined) <= 200:
        return combined
    # Over the limit: keep 197 characters and mark the cut with an ellipsis.
    return combined[:197] + "..."

View File

@@ -65,6 +65,15 @@ class DataConfig(Base):
lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数") lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率0-1 小数")
offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数") offset = Column("offset", Float, default=0.0, comment="偏移度0-1 小数")
# ACT-R 遗忘引擎配置
decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d默认0.5")
forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值默认0.3")
forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔小时默认24")
enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要默认True")
max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数默认100")
max_history_length = Column(Integer, default=100, comment="访问历史最大长度默认100")
min_days_since_access = Column(Integer, default=30, comment="最小未访问天数默认30")
# 情绪引擎配置 # 情绪引擎配置
emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取") emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取")
emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID") emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID")

View File

@@ -106,7 +106,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC
"emotion_intensity": statement.emotion_intensity, "emotion_intensity": statement.emotion_intensity,
"emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [], "emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [],
"emotion_subject": statement.emotion_subject, "emotion_subject": statement.emotion_subject,
"emotion_target": statement.emotion_target "emotion_target": statement.emotion_target,
# 添加 ACT-R 记忆激活属性
"importance_score": statement.importance_score,
"activation_value": statement.activation_value,
"access_history": statement.access_history if statement.access_history else [],
"last_access_time": statement.last_access_time,
"access_count": statement.access_count
} }
flattened_statements.append(flattened_statement) flattened_statements.append(flattened_statement)

View File

@@ -38,7 +38,12 @@ SET s += {
valid_at: statement.valid_at, valid_at: statement.valid_at,
invalid_at: statement.invalid_at, invalid_at: statement.invalid_at,
statement_embedding: statement.statement_embedding, statement_embedding: statement.statement_embedding,
relevence_info: statement.relevence_info relevence_info: statement.relevence_info,
importance_score: statement.importance_score,
activation_value: statement.activation_value,
access_history: statement.access_history,
last_access_time: statement.last_access_time,
access_count: statement.access_count
} }
RETURN s.id AS uuid RETURN s.id AS uuid
""" """
@@ -111,7 +116,12 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity
WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength
ELSE e.connect_strength ELSE e.connect_strength
END END
END END,
e.importance_score = CASE WHEN entity.importance_score IS NOT NULL THEN entity.importance_score ELSE coalesce(e.importance_score, 0.5) END,
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END,
e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END,
e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END,
e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END
RETURN e.id AS uuid RETURN e.id AS uuid
""" """
@@ -225,6 +235,10 @@ RETURN e.id AS id,
e.name AS name, e.name AS name,
e.group_id AS group_id, e.group_id AS group_id,
e.entity_type AS entity_type, e.entity_type AS entity_type,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -243,6 +257,10 @@ RETURN s.id AS id,
s.expired_at AS expired_at, s.expired_at AS expired_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
s.invalid_at AS invalid_at, s.invalid_at AS invalid_at,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -258,6 +276,9 @@ RETURN c.id AS chunk_id,
c.group_id AS group_id, c.group_id AS group_id,
c.content AS content, c.content AS content,
c.dialog_id AS dialog_id, c.dialog_id AS dialog_id,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -278,6 +299,10 @@ RETURN s.id AS id,
s.invalid_at AS invalid_at, s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel, c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids, collect(DISTINCT e.id) AS entity_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -305,6 +330,10 @@ RETURN e.id AS id,
e.connect_strength AS connect_strength, e.connect_strength AS connect_strength,
collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT c.id) AS chunk_ids, collect(DISTINCT c.id) AS chunk_ids,
COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value,
COALESCE(e.importance_score, 0.5) AS importance_score,
e.last_access_time AS last_access_time,
COALESCE(e.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -322,6 +351,9 @@ RETURN c.id AS chunk_id,
c.sequence_number AS sequence_number, c.sequence_number AS sequence_number,
collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT s.id) AS statement_ids,
collect(DISTINCT e.id) AS entity_ids, collect(DISTINCT e.id) AS entity_ids,
COALESCE(c.activation_value, 0.5) AS activation_value,
c.last_access_time AS last_access_time,
COALESCE(c.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -419,7 +451,11 @@ RETURN s.id AS id,
s.created_at AS created_at, s.created_at AS created_at,
s.valid_at AS valid_at, s.valid_at AS valid_at,
s.invalid_at AS invalid_at, s.invalid_at AS invalid_at,
collect(DISTINCT s.id) AS statement_ids collect(DISTINCT s.id) AS statement_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count
ORDER BY datetime(s.created_at) DESC ORDER BY datetime(s.created_at) DESC
LIMIT $limit LIMIT $limit
""" """
@@ -446,6 +482,10 @@ RETURN s.id AS id,
s.invalid_at AS invalid_at, s.invalid_at AS invalid_at,
c.id AS chunk_id_from_rel, c.id AS chunk_id_from_rel,
collect(DISTINCT e.id) AS entity_ids, collect(DISTINCT e.id) AS entity_ids,
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
COALESCE(s.importance_score, 0.5) AS importance_score,
s.last_access_time AS last_access_time,
COALESCE(s.access_count, 0) AS access_count,
score score
ORDER BY s.created_at DESC, score DESC ORDER BY s.created_at DESC, score DESC
LIMIT $limit LIMIT $limit
@@ -635,6 +675,10 @@ RETURN m.id AS id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
m.created_at AS created_at, m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit
@@ -653,6 +697,10 @@ RETURN m.id AS id,
m.chunk_ids AS chunk_ids, m.chunk_ids AS chunk_ids,
m.content AS content, m.content AS content,
m.created_at AS created_at, m.created_at AS created_at,
COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value,
COALESCE(m.importance_score, 0.5) AS importance_score,
m.last_access_time AS last_access_time,
COALESCE(m.access_count, 0) AS access_count,
score score
ORDER BY score DESC ORDER BY score DESC
LIMIT $limit LIMIT $limit

View File

@@ -55,6 +55,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]):
if 'aliases' not in n or n['aliases'] is None: if 'aliases' not in n or n['aliases'] is None:
n['aliases'] = [] n['aliases'] = []
# 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', [])
n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0)
return ExtractedEntityNode(**n) return ExtractedEntityNode(**n)
async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]: async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]:

View File

@@ -1,5 +1,6 @@
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import asyncio import asyncio
import logging
# 使用新的仓储层 # 使用新的仓储层
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -24,6 +25,157 @@ from app.repositories.neo4j.cypher_queries import (
SEARCH_STATEMENTS_L_VALID_AT, SEARCH_STATEMENTS_L_VALID_AT,
) )
logger = logging.getLogger(__name__)
async def _update_activation_values_batch(
    connector: Neo4jConnector,
    nodes: List[Dict[str, Any]],
    node_label: str,
    group_id: Optional[str] = None,
    max_retries: int = 3
) -> List[Dict[str, Any]]:
    """
    Record one access for a batch of nodes through AccessHistoryManager.

    All node IDs are submitted in a single ``record_batch_access`` call.
    ``max_retries`` is handed to the manager.

    Args:
        connector: Neo4j connector.
        nodes: Node dicts; entries without a truthy ``'id'`` are ignored.
        node_label: Node label (Statement, ExtractedEntity, MemorySummary).
        group_id: Group ID (optional).
        max_retries: Maximum retry count given to AccessHistoryManager.

    Returns:
        List[Dict[str, Any]]: ``[]`` for empty input; the value returned by
        ``record_batch_access`` on success; the original ``nodes`` list when
        no node carries an ID or when the batch call raises.
    """
    if not nodes:
        return []

    # Imported lazily to avoid a circular dependency.
    from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager
    from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator

    calculator = ACTRCalculator()
    history_manager = AccessHistoryManager(
        connector=connector,
        actr_calculator=calculator,
        max_retries=max_retries
    )

    # Collect the IDs that can actually be addressed.
    valid_ids = []
    for item in nodes:
        item_id = item.get('id')
        if item_id:
            valid_ids.append(item_id)

    if not valid_ids:
        logger.warning("批量更新激活值没有有效的节点ID")
        return nodes

    try:
        refreshed = await history_manager.record_batch_access(
            node_ids=valid_ids,
            node_label=node_label,
            group_id=group_id
        )
        logger.info(
            f"批量更新激活值成功: {node_label}, 更新数量={len(refreshed)}/{len(valid_ids)}"
        )
        return refreshed
    except Exception as exc:
        logger.error(
            f"批量更新激活值失败: {node_label}, 错误: {str(exc)}"
        )
        # Best effort: hand back the caller's nodes untouched on failure.
        return nodes
async def _update_search_results_activation(
    connector: Neo4jConnector,
    results: Dict[str, List[Dict[str, Any]]],
    group_id: Optional[str] = None
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Refresh activation data for the knowledge nodes in a search result.

    The ``statements``, ``entities`` and ``summaries`` categories are each
    sent to ``_update_activation_values_batch`` (labels Statement,
    ExtractedEntity, MemorySummary), concurrently. Other categories, such
    as chunks, are passed through untouched.

    The returned data is merged onto the original hits rather than
    replacing them. Every original hit is kept, in its original order.
    Non-null fields from the matching updated node (matched by ``id``) are
    overlaid on the hit. The hit's own non-null ``score`` always wins, so
    the search score is preserved. A hit with no matching update, and a
    whole category whose update fails, are returned as they were.

    Args:
        connector: Neo4j connector.
        results: Search results keyed by category.
        group_id: Group ID (optional).

    Returns:
        Dict[str, List[Dict[str, Any]]]: A shallow copy of ``results`` with
        merged node lists, or ``results`` itself when there is nothing to
        update. Input node dicts are not mutated.
    """
    # Result categories that take part in activation updates, with the
    # node label each one maps to.
    knowledge_node_types = {
        'statements': 'Statement',
        'entities': 'ExtractedEntity',
        'summaries': 'MemorySummary'
    }

    pending = [
        (key, label)
        for key, label in knowledge_node_types.items()
        if results.get(key)
    ]
    if not pending:
        return results

    # Run all category updates concurrently; failures come back as values.
    outcomes = await asyncio.gather(
        *(
            _update_activation_values_batch(
                connector=connector,
                nodes=results[key],
                node_label=label,
                group_id=group_id
            )
            for key, label in pending
        ),
        return_exceptions=True
    )

    updated_results = results.copy()
    for (key, _label), outcome in zip(pending, outcomes):
        # gather() can also hand back BaseException subclasses such as
        # CancelledError; those are failures too, not node lists.
        if isinstance(outcome, BaseException):
            logger.warning(
                f"更新 {key} 激活值失败: {str(outcome)}"
            )
            continue

        updates_by_id = {
            node.get('id'): node
            for node in (outcome or [])
            if isinstance(node, dict) and node.get('id')
        }

        merged_nodes = []
        for original in results[key]:
            update = updates_by_id.get(original.get('id'))
            if update is None or update is original:
                # Not refreshed, or the updater returned the input as-is.
                merged_nodes.append(original)
                continue
            # Null values in the update must not erase search fields.
            combined = {
                **original,
                **{name: value for name, value in update.items() if value is not None}
            }
            original_score = original.get('score')
            if original_score is not None:
                combined['score'] = original_score
            merged_nodes.append(combined)
        updated_results[key] = merged_nodes

    return updated_results
async def search_graph( async def search_graph(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -36,6 +188,7 @@ async def search_graph(
Search across Statements, Entities, Chunks, and Summaries using a free-text query. Search across Statements, Entities, Chunks, and Summaries using a free-text query.
OPTIMIZED: Runs all queries in parallel using asyncio.gather() OPTIMIZED: Runs all queries in parallel using asyncio.gather()
INTEGRATED: Updates activation values for knowledge nodes before returning results
- Statements: matches s.statement CONTAINS q - Statements: matches s.statement CONTAINS q
- Entities: matches e.name CONTAINS q - Entities: matches e.name CONTAINS q
@@ -50,7 +203,7 @@ async def search_graph(
include: List of categories to search (default: all) include: List of categories to search (default: all)
Returns: Returns:
Dictionary with search results per category Dictionary with search results per category (with updated activation values)
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries"] include = ["statements", "chunks", "entities", "summaries"]
@@ -106,6 +259,13 @@ async def search_graph(
else: else:
results[key] = result results[key] = result
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results return results
@@ -121,6 +281,7 @@ async def search_graph_by_embedding(
Embedding-based semantic search across Statements, Chunks, and Entities. Embedding-based semantic search across Statements, Chunks, and Entities.
OPTIMIZED: Runs all queries in parallel using asyncio.gather() OPTIMIZED: Runs all queries in parallel using asyncio.gather()
INTEGRATED: Updates activation values for knowledge nodes before returning results
- Computes query embedding with the provided embedder_client - Computes query embedding with the provided embedder_client
- Ranks by cosine similarity in Cypher - Ranks by cosine similarity in Cypher
@@ -203,6 +364,16 @@ async def search_graph_by_embedding(
else: else:
results[key] = result results[key] = result
# 更新知识节点的激活值Statement, ExtractedEntity, MemorySummary
update_start = time.time()
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
update_time = time.time() - update_start
print(f"[PERF] Activation value updates took: {update_time:.4f}s")
return results return results
async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -304,6 +475,8 @@ async def search_graph_by_keyword_temporal(
) -> Dict[str, List[Any]]: ) -> Dict[str, List[Any]]:
""" """
Temporal keyword search across Statements. Temporal keyword search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements containing query_text created between start_date and end_date - Matches statements containing query_text created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -326,7 +499,15 @@ async def search_graph_by_keyword_temporal(
) )
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements} # 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_by_temporal( async def search_graph_by_temporal(
@@ -342,6 +523,8 @@ async def search_graph_by_temporal(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created between start_date and end_date - Matches statements created between start_date and end_date
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -362,7 +545,16 @@ async def search_graph_by_temporal(
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_by_dialog_id( async def search_graph_by_dialog_id(
@@ -419,6 +611,8 @@ async def search_graph_by_created_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -436,7 +630,16 @@ async def search_graph_by_created_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_by_valid_at( async def search_graph_by_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -448,6 +651,8 @@ async def search_graph_by_valid_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -465,7 +670,16 @@ async def search_graph_by_valid_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_g_created_at( async def search_graph_g_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -477,6 +691,8 @@ async def search_graph_g_created_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -494,7 +710,16 @@ async def search_graph_g_created_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_g_valid_at( async def search_graph_g_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -506,6 +731,8 @@ async def search_graph_g_valid_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -523,7 +750,16 @@ async def search_graph_g_valid_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_l_created_at( async def search_graph_l_created_at(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -535,6 +771,8 @@ async def search_graph_l_created_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements created at created_at - Matches statements created at created_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -552,7 +790,16 @@ async def search_graph_l_created_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results
async def search_graph_l_valid_at( async def search_graph_l_valid_at(
connector: Neo4jConnector, connector: Neo4jConnector,
@@ -564,6 +811,8 @@ async def search_graph_l_valid_at(
) -> Dict[str, List[Dict[str, Any]]]: ) -> Dict[str, List[Dict[str, Any]]]:
""" """
Temporal search across Statements. Temporal search across Statements.
INTEGRATED: Updates activation values for Statement nodes before returning results
- Matches statements valid at valid_at - Matches statements valid at valid_at
- Optionally filters by group_id, apply_id, user_id - Optionally filters by group_id, apply_id, user_id
@@ -581,4 +830,13 @@ async def search_graph_l_valid_at(
print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}")
print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}")
print(f"查询结果为:\n{statements}") print(f"查询结果为:\n{statements}")
return {"statements": statements}
# 更新 Statement 节点的激活值
results = {"statements": statements}
results = await _update_search_results_activation(
connector=connector,
results=results,
group_id=group_id
)
return results

View File

@@ -8,7 +8,6 @@ Classes:
Neo4jConnector: Neo4j数据库连接器提供异步查询接口 Neo4jConnector: Neo4j数据库连接器提供异步查询接口
""" """
import os
from typing import Any, List, Dict from typing import Any, List, Dict
from neo4j import AsyncGraphDatabase, basic_auth from neo4j import AsyncGraphDatabase, basic_auth
@@ -85,6 +84,63 @@ class Neo4jConnector:
records, summary, keys = result records, summary, keys = result
return [record.data() for record in records] return [record.data() for record in records]
async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any:
    """Run ``transaction_func`` inside a managed write transaction.

    A session is opened on the ``neo4j`` database and the callable is handed
    to ``session.execute_write``; commit/rollback handling is delegated to
    the driver's managed-transaction machinery.

    Args:
        transaction_func: Async callable that receives the transaction object
            as its first argument and performs the queries.
        **kwargs: Extra keyword arguments forwarded to ``transaction_func``.

    Returns:
        Any: Whatever ``transaction_func`` returns.

    Example:
        >>> async def create_node(tx, name):
        ...     result = await tx.run(
        ...         "CREATE (n:Person {name: $name}) RETURN n",
        ...         name=name
        ...     )
        ...     return await result.single()
        >>>
        >>> connector = Neo4jConnector()
        >>> result = await connector.execute_write_transaction(
        ...     create_node, name="Alice"
        ... )
    """
    session_cm = self.driver.session(database="neo4j")
    async with session_cm as session:
        outcome = await session.execute_write(transaction_func, **kwargs)
    return outcome
async def execute_read_transaction(self, transaction_func, **kwargs: Any) -> Any:
    """Run ``transaction_func`` inside a managed read transaction.

    A session is opened on the ``neo4j`` database and the callable is handed
    to ``session.execute_read``.

    Args:
        transaction_func: Async callable that receives the transaction object
            as its first argument and performs the queries.
        **kwargs: Extra keyword arguments forwarded to ``transaction_func``.

    Returns:
        Any: Whatever ``transaction_func`` returns.

    Example:
        >>> async def get_node(tx, name):
        ...     result = await tx.run(
        ...         "MATCH (n:Person {name: $name}) RETURN n",
        ...         name=name
        ...     )
        ...     return await result.single()
        >>>
        >>> connector = Neo4jConnector()
        >>> result = await connector.execute_read_transaction(
        ...     get_node, name="Alice"
        ... )
    """
    session_cm = self.driver.session(database="neo4j")
    async with session_cm as session:
        outcome = await session.execute_read(transaction_func, **kwargs)
    return outcome
async def delete_group(self, group_id: str): async def delete_group(self, group_id: str):
"""删除指定组的所有数据 """删除指定组的所有数据

View File

@@ -75,6 +75,13 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]):
n['emotion_subject'] = n.get('emotion_subject') n['emotion_subject'] = n.get('emotion_subject')
n['emotion_target'] = n.get('emotion_target') n['emotion_target'] = n.get('emotion_target')
# 处理 ACT-R 属性 - 确保字段存在且有默认值
n['importance_score'] = n.get('importance_score', 0.5)
n['activation_value'] = n.get('activation_value')
n['access_history'] = n.get('access_history', [])
n['last_access_time'] = n.get('last_access_time')
n['access_count'] = n.get('access_count', 0)
return StatementNode(**n) return StatementNode(**n)
async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]: async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]:

View File

@@ -399,3 +399,104 @@ class GenerateCacheRequest(BaseModel):
None, None,
description="终端用户IDUUID格式。如果提供只为该用户生成如果不提供为当前工作空间的所有用户生成" description="终端用户IDUUID格式。如果提供只为该用户生成如果不提供为当前工作空间的所有用户生成"
) )
# ============================================================================
# 遗忘引擎相关 Schema
# ============================================================================
class ForgettingTriggerRequest(BaseModel):
    """Request body for manually triggering one forgetting cycle."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Optional group filter.
    group_id: Optional[str] = Field(
        default=None,
        description="组ID可选用于过滤特定组的节点",
    )
    # Upper bound on node pairs merged in one run (1..1000, default 100).
    max_merge_batch_size: int = Field(
        default=100,
        ge=1,
        le=1000,
        description="单次最大融合节点对数默认100",
    )
    # Minimum number of days without access (1..365, default 30).
    min_days_since_access: int = Field(
        default=30,
        ge=1,
        le=365,
        description="最小未访问天数默认30天",
    )
    # TODO: once group_id is replaced by enduser_id (which is to be linked to
    # config_id automatically), remove this field.
    config_id: Optional[int] = Field(
        default=None,
        description="配置ID可选用于指定遗忘引擎配置",
    )
class ForgettingConfigResponse(BaseModel):
    """Response payload describing one forgetting-engine configuration."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Identifier of the configuration row.
    config_id: int = Field(..., description="配置ID")

    # --- ACT-R activation parameters ---
    decay_constant: float = Field(..., description="衰减常数 d")
    lambda_time: float = Field(..., description="时间衰减参数")
    lambda_mem: float = Field(..., description="记忆衰减参数")
    # Derived value; per its description it is computed from
    # lambda_time / lambda_mem.
    forgetting_rate: float = Field(
        ...,
        description="遗忘速率(根据 lambda_time / lambda_mem 计算得出)",
    )
    offset: float = Field(..., description="偏移量")
    max_history_length: int = Field(..., description="访问历史最大长度")

    # --- Forgetting strategy parameters ---
    forgetting_threshold: float = Field(..., description="遗忘阈值")
    min_days_since_access: int = Field(..., description="最小未访问天数")
    enable_llm_summary: bool = Field(..., description="是否使用 LLM 生成摘要")
    max_merge_batch_size: int = Field(..., description="单次最大融合节点对数")
    forgetting_interval_hours: int = Field(..., description="遗忘周期间隔(小时)")
class ForgettingConfigUpdateRequest(BaseModel):
    """Request body for partially updating a forgetting-engine configuration.

    Only ``config_id`` is required; every other field defaults to ``None``.
    """

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Identifier of the configuration to update.
    config_id: int = Field(..., description="配置ID")

    # --- ACT-R activation parameters (each constrained to 0.0..1.0) ---
    decay_constant: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="衰减常数 d"
    )
    lambda_time: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="时间衰减参数"
    )
    # NOTE(review): 0.0 is accepted here, while the response model describes
    # forgetting_rate as lambda_time / lambda_mem - confirm the config loader
    # guards against division by zero.
    lambda_mem: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="记忆衰减参数"
    )
    offset: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="偏移量"
    )
    # Access-history length, 10..1000.
    max_history_length: Optional[int] = Field(
        default=None, ge=10, le=1000, description="访问历史最大长度"
    )

    # --- Forgetting strategy parameters ---
    forgetting_threshold: Optional[float] = Field(
        default=None, ge=0.0, le=1.0, description="遗忘阈值"
    )
    min_days_since_access: Optional[int] = Field(
        default=None, ge=1, le=365, description="最小未访问天数"
    )
    enable_llm_summary: Optional[bool] = Field(
        default=None, description="是否使用 LLM 生成摘要"
    )
    max_merge_batch_size: Optional[int] = Field(
        default=None, ge=1, le=1000, description="单次最大融合节点对数"
    )
    # Cycle interval in hours, 1..168.
    forgetting_interval_hours: Optional[int] = Field(
        default=None, ge=1, le=168, description="遗忘周期间隔(小时)"
    )
class ForgettingStatsResponse(BaseModel):
    """Response payload carrying forgetting-engine statistics."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Activation-value metrics.
    activation_metrics: Dict[str, Any] = Field(
        ...,
        description="激活值相关指标",
    )
    # Node counts keyed by node type.
    node_distribution: Dict[str, int] = Field(
        ...,
        description="节点类型分布",
    )
    # Optional result of a data-consistency check.
    consistency_check: Optional[Dict[str, Any]] = Field(
        default=None,
        description="数据一致性检查结果",
    )
    # Cumulative number of merged node pairs.
    nodes_merged_total: int = Field(
        ...,
        description="累计融合节点对数",
    )
    # Records of recent forgetting cycles.
    recent_cycles: List[Dict[str, Any]] = Field(
        ...,
        description="最近的遗忘周期记录",
    )
    # Time the statistics were taken, as a string.
    timestamp: str = Field(
        ...,
        description="统计时间ISO格式",
    )
class ForgettingReportResponse(BaseModel):
    """Response payload reporting the outcome of one forgetting cycle."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # --- Merge outcome ---
    merged_count: int = Field(..., description="融合的节点对数量")
    failed_count: int = Field(..., description="失败的融合数量")
    success_rate: float = Field(..., description="成功率0-1")

    # --- Node totals before/after the cycle ---
    nodes_before: int = Field(..., description="遗忘前的节点总数")
    nodes_after: int = Field(..., description="遗忘后的节点总数")
    reduction_rate: float = Field(..., description="节点减少率0-1")

    # --- Timing ---
    duration_seconds: float = Field(..., description="执行耗时(秒)")
    start_time: str = Field(..., description="开始时间ISO格式")
    end_time: str = Field(..., description="结束时间ISO格式")
class ForgettingCurvePoint(BaseModel):
    """A single data point on a forgetting curve."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Day index of the point.
    day: int = Field(
        ...,
        description="天数",
    )
    # Activation value for that day.
    activation: float = Field(
        ...,
        description="激活值",
    )
    # Retention rate; described as being the same as the activation value.
    retention_rate: float = Field(
        ...,
        description="保持率(与激活值相同)",
    )
class ForgettingCurveRequest(BaseModel):
    """Request body for generating forgetting-curve data."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Importance score in 0.0..1.0 (default 0.5).
    importance_score: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="重要性分数0-1",
    )
    # Number of days to simulate, 1..365 (default 60).
    days: int = Field(
        default=60,
        ge=1,
        le=365,
        description="模拟天数默认60天",
    )
    # Optional configuration ID; None selects the default configuration
    # according to the field description.
    config_id: Optional[int] = Field(
        default=None,
        description="配置ID可选如果为None则使用默认配置",
    )
class ForgettingCurveResponse(BaseModel):
    """Response payload carrying forgetting-curve data and its parameters."""

    # Unknown fields are rejected; fields may be populated by name.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Ordered list of curve data points.
    curve_data: List[ForgettingCurvePoint] = Field(
        ...,
        description="遗忘曲线数据点列表",
    )
    # Configuration parameters used to produce the curve.
    config: Dict[str, Any] = Field(
        ...,
        description="使用的配置参数",
    )

View File

@@ -0,0 +1,460 @@
"""
遗忘引擎服务层模块
本模块提供遗忘引擎的业务逻辑实现,包括:
1. 遗忘周期执行
2. 配置管理
3. 统计信息查询
4. 遗忘曲线生成
所有业务逻辑从控制器层分离到此服务层。
"""
from typing import Optional, Dict, Any, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import ForgettingScheduler
from app.core.memory.storage_services.forgetting_engine.config_utils import (
load_actr_config_from_db,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.data_config_repository import DataConfigRepository
# 获取API专用日志器
api_logger = get_api_logger()
class MemoryForgetService:
    """Service layer for the forgetting engine.

    Covers running a forgetting cycle, reading/updating the engine
    configuration, querying node statistics from Neo4j, and generating
    forgetting-curve data.
    """

    def __init__(self):
        """Create the service with a ``DataConfigRepository`` for config rows."""
        self.config_repository = DataConfigRepository()

    def _get_neo4j_connector(self) -> Neo4jConnector:
        """Return a new ``Neo4jConnector`` instance.

        NOTE(review): a fresh connector is created on every call and nothing
        in this class closes it. It should eventually come from configuration
        or dependency injection (singleton / pool).

        Returns:
            Neo4jConnector: A newly constructed connector.
        """
        return Neo4jConnector()

    @staticmethod
    def _build_actr_calculator(config: Dict[str, Any]) -> ACTRCalculator:
        """Build an ``ACTRCalculator`` from a loaded configuration dict."""
        return ACTRCalculator(
            decay_constant=config['decay_constant'],
            forgetting_rate=config['forgetting_rate'],
            offset=config['offset'],
            max_history_length=config['max_history_length']
        )

    async def _get_forgetting_components(
        self,
        db: Session,
        config_id: Optional[int] = None
    ) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
        """Assemble the calculator, strategy and scheduler for one request.

        Args:
            db: Database session used to load the configuration.
            config_id: Optional configuration ID passed to the config loader.

        Returns:
            tuple: ``(actr_calculator, forgetting_strategy,
            forgetting_scheduler, config)``. The strategy and the scheduler
            share one connector.
        """
        config = load_actr_config_from_db(db, config_id)
        actr_calculator = self._build_actr_calculator(config)
        connector = self._get_neo4j_connector()

        forgetting_strategy = ForgettingStrategy(
            connector=connector,
            actr_calculator=actr_calculator,
            forgetting_threshold=config['forgetting_threshold'],
            enable_llm_summary=config['enable_llm_summary']
        )
        forgetting_scheduler = ForgettingScheduler(
            forgetting_strategy=forgetting_strategy,
            connector=connector
        )
        return actr_calculator, forgetting_strategy, forgetting_scheduler, config

    async def _get_knowledge_stats(
        self,
        connector: Neo4jConnector,
        group_id: Optional[str] = None,
        forgetting_threshold: float = 0.3
    ) -> Dict[str, Any]:
        """Count Statement / ExtractedEntity / MemorySummary nodes.

        Args:
            connector: Connector used to run the query.
            group_id: When truthy, restrict the count to this group.
            forgetting_threshold: Nodes whose activation value is below this
                are counted as low-activation nodes.

        Returns:
            dict: ``total_nodes``, ``statement_count``, ``entity_count``,
            ``summary_count``, ``average_activation`` (may be ``None``) and
            ``low_activation_nodes``. Counts fall back to 0 when the query
            yields no row or a null value.
        """
        query = """
            MATCH (n)
            WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
        """
        if group_id:
            query += " AND n.group_id = $group_id"
        query += """
            WITH n,
                CASE
                    WHEN n:Statement THEN 'statement'
                    WHEN n:ExtractedEntity THEN 'entity'
                    WHEN n:MemorySummary THEN 'summary'
                END as node_type
            RETURN
                count(n) as total_nodes,
                sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
                sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
                sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
                avg(n.activation_value) as average_activation,
                sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
        """

        params = {'threshold': forgetting_threshold}
        if group_id:
            params['group_id'] = group_id

        results = await connector.execute_query(query, **params)
        # An empty result is treated like a row of nulls.
        row = results[0] if results else {}
        return {
            'total_nodes': row.get('total_nodes') or 0,
            'statement_count': row.get('statement_count') or 0,
            'entity_count': row.get('entity_count') or 0,
            'summary_count': row.get('summary_count') or 0,
            'average_activation': row.get('average_activation'),
            'low_activation_nodes': row.get('low_activation_nodes') or 0
        }

    async def trigger_forgetting_cycle(
        self,
        db: Session,
        group_id: Optional[str] = None,
        max_merge_batch_size: Optional[int] = None,
        min_days_since_access: Optional[int] = None,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Manually run one forgetting cycle and return its report.

        Args:
            db: Database session.
            group_id: Optional group filter forwarded to the scheduler.
            max_merge_batch_size: Optional merge batch limit forwarded to the
                scheduler.
            min_days_since_access: Optional inactivity threshold in days,
                forwarded to the scheduler.
            config_id: Optional configuration ID.

        Returns:
            dict: The report returned by
            ``ForgettingScheduler.run_forgetting_cycle``; this method reads
            its ``merged_count``, ``failed_count`` and ``duration_seconds``
            keys for logging.
        """
        _, _, forgetting_scheduler, _ = await self._get_forgetting_components(db, config_id)

        report = await forgetting_scheduler.run_forgetting_cycle(
            group_id=group_id,
            max_merge_batch_size=max_merge_batch_size,
            min_days_since_access=min_days_since_access,
            config_id=config_id,
            db=db
        )

        api_logger.info(
            f"遗忘周期完成: 融合 {report['merged_count']} 对节点, "
            f"失败 {report['failed_count']} 对, "
            f"耗时 {report['duration_seconds']:.2f}"
        )
        return report

    def read_forgetting_config(
        self,
        db: Session,
        config_id: int
    ) -> Dict[str, Any]:
        """Load the forgetting-engine configuration for ``config_id``.

        Args:
            db: Database session.
            config_id: Configuration ID.

        Returns:
            dict: The loaded configuration with ``config_id`` added.
        """
        config = load_actr_config_from_db(db, config_id)
        config['config_id'] = config_id

        api_logger.info(f"成功读取遗忘引擎配置: config_id={config_id}")
        return config

    def update_forgetting_config(
        self,
        db: Session,
        config_id: int,
        update_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Apply ``update_fields`` to a configuration row and return the result.

        Keys that are not attributes of the stored configuration object are
        skipped and reported in a warning log entry.

        Args:
            db: Database session.
            config_id: Configuration ID.
            update_fields: Mapping of attribute name to new value.

        Returns:
            dict: The reloaded configuration with ``config_id`` added.

        Raises:
            ValueError: If no configuration exists for ``config_id``.
        """
        db_config = self.config_repository.get_by_id(db, config_id)
        if db_config is None:
            raise ValueError(f"配置不存在: {config_id}")

        if update_fields:
            applied = []
            ignored = []
            for key, value in update_fields.items():
                if hasattr(db_config, key):
                    setattr(db_config, key, value)
                    applied.append(key)
                else:
                    ignored.append(key)

            try:
                db.commit()
            except Exception:
                # Leave the session usable for the caller after a failed
                # commit, then propagate the original error.
                db.rollback()
                raise
            db.refresh(db_config)

            if ignored:
                api_logger.warning(
                    f"忽略未知的配置字段: config_id={config_id}, 字段: {ignored}"
                )
            # Log only the fields that were actually written.
            api_logger.info(
                f"成功更新遗忘引擎配置: config_id={config_id}, "
                f"更新字段: {applied}"
            )
        else:
            api_logger.info(f"没有字段需要更新: config_id={config_id}")

        config = load_actr_config_from_db(db, config_id)
        config['config_id'] = config_id
        return config

    async def get_forgetting_stats(
        self,
        db: Session,
        group_id: Optional[str] = None,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Collect activation metrics and node-type distribution from Neo4j.

        Statement, ExtractedEntity, MemorySummary and Chunk nodes are
        considered.

        Args:
            db: Database session.
            group_id: When truthy, restrict both queries to this group.
            config_id: Optional configuration ID; its ``forgetting_threshold``
                decides which nodes count as low-activation.

        Returns:
            dict: ``activation_metrics``, ``node_distribution``,
            ``consistency_check`` (always ``None``), ``nodes_merged_total``
            (always 0), ``recent_cycles`` (always empty) and ``timestamp``.
        """
        _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
        connector = forgetting_scheduler.connector
        forgetting_threshold = config['forgetting_threshold']

        # --- Activation metrics ---
        activation_query = """
            MATCH (n)
            WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
        """
        if group_id:
            activation_query += " AND n.group_id = $group_id"
        activation_query += """
            RETURN
                count(n) as total_nodes,
                sum(CASE WHEN n.activation_value IS NOT NULL THEN 1 ELSE 0 END) as nodes_with_activation,
                sum(CASE WHEN n.activation_value IS NULL THEN 1 ELSE 0 END) as nodes_without_activation,
                avg(n.activation_value) as average_activation,
                sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
        """

        params = {'threshold': forgetting_threshold}
        if group_id:
            params['group_id'] = group_id

        activation_results = await connector.execute_query(activation_query, **params)
        # An empty result is treated like a row of nulls.
        activation_row = activation_results[0] if activation_results else {}
        activation_metrics = {
            'total_nodes': activation_row.get('total_nodes') or 0,
            'nodes_with_activation': activation_row.get('nodes_with_activation') or 0,
            'nodes_without_activation': activation_row.get('nodes_without_activation') or 0,
            'average_activation_value': activation_row.get('average_activation'),
            'low_activation_nodes': activation_row.get('low_activation_nodes') or 0,
            'timestamp': datetime.now().isoformat()
        }

        # --- Node-type distribution ---
        distribution_query = """
            MATCH (n)
            WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
        """
        if group_id:
            distribution_query += " AND n.group_id = $group_id"
        distribution_query += """
            WITH n,
                CASE
                    WHEN n:Statement THEN 'statement'
                    WHEN n:ExtractedEntity THEN 'entity'
                    WHEN n:MemorySummary THEN 'summary'
                    WHEN n:Chunk THEN 'chunk'
                END as node_type
            RETURN
                sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
                sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
                sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
                sum(CASE WHEN node_type = 'chunk' THEN 1 ELSE 0 END) as chunk_count
        """

        dist_params = {}
        if group_id:
            dist_params['group_id'] = group_id

        distribution_results = await connector.execute_query(distribution_query, **dist_params)
        distribution_row = distribution_results[0] if distribution_results else {}
        node_distribution = {
            'statement_count': distribution_row.get('statement_count') or 0,
            'entity_count': distribution_row.get('entity_count') or 0,
            'summary_count': distribution_row.get('summary_count') or 0,
            'chunk_count': distribution_row.get('chunk_count') or 0
        }

        # Monitoring history is no longer tracked, so the three history
        # fields are fixed placeholders.
        stats = {
            'activation_metrics': activation_metrics,
            'node_distribution': node_distribution,
            'consistency_check': None,
            'nodes_merged_total': 0,
            'recent_cycles': [],
            'timestamp': datetime.now().isoformat()
        }

        api_logger.info(
            f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
            f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}"
        )
        return stats

    async def get_forgetting_curve(
        self,
        db: Session,
        importance_score: float,
        days: int,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """Generate forgetting-curve data for visualisation.

        Only the ACT-R calculator is needed here, so no Neo4j connector,
        strategy or scheduler is constructed.

        Args:
            db: Database session.
            importance_score: Importance score passed to the calculator.
            days: Number of days passed to the calculator.
            config_id: Optional configuration ID.

        Returns:
            dict: ``curve_data`` (whatever
            ``ACTRCalculator.get_forgetting_curve`` returns) and ``config``
            with the parameters used.
        """
        config = load_actr_config_from_db(db, config_id)
        actr_calculator = self._build_actr_calculator(config)

        initial_time = datetime.now()
        curve_data = actr_calculator.get_forgetting_curve(
            initial_time=initial_time,
            importance_score=importance_score,
            days=days
        )

        api_logger.info(
            f"成功生成遗忘曲线数据: {len(curve_data)} 个数据点"
        )

        return {
            'curve_data': curve_data,
            'config': {
                'decay_constant': config['decay_constant'],
                'forgetting_rate': config['forgetting_rate'],
                'offset': config['offset'],
                'importance_score': importance_score,
                'days': days
            }
        }

View File

@@ -8,7 +8,6 @@ import asyncio
import json import json
import os import os
import time import time
import uuid
from datetime import datetime from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
@@ -26,7 +25,6 @@ from app.schemas.memory_storage_schema import (
ConfigPilotRun, ConfigPilotRun,
ConfigUpdate, ConfigUpdate,
ConfigUpdateExtracted, ConfigUpdateExtracted,
ConfigUpdateForget,
) )
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.utils.sse_utils import format_sse_message from app.utils.sse_utils import format_sse_message
@@ -159,11 +157,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
return {"affected": 1} return {"affected": 1}
# --- Forget config params --- # --- Forget config params ---
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置 # 遗忘引擎配置方法已迁移到 memory_forget_service.py
config = DataConfigRepository.update_forget(self.db, update) # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config()
if not config:
raise ValueError("未找到配置")
return {"affected": 1}
# --- Read --- # --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
@@ -172,12 +167,6 @@ class DataConfigService: # 数据配置服务类PostgreSQL
raise ValueError("未找到配置") raise ValueError("未找到配置")
return result return result
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取遗忘配置参数
result = DataConfigRepository.get_forget_config(self.db, key.config_id)
if not result:
raise ValueError("未找到配置")
return result
# --- Read All --- # --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
configs = DataConfigRepository.get_all(self.db, workspace_id) configs = DataConfigRepository.get_all(self.db, workspace_id)

View File

@@ -412,7 +412,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str,
actual_config_id = connected_config.get("memory_config_id") actual_config_id = connected_config.get("memory_config_id")
finally: finally:
db.close() db.close()
except Exception as e: except Exception:
# Log but continue - will fail later with proper error # Log but continue - will fail later with proper error
pass pass
@@ -499,7 +499,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage
actual_config_id = connected_config.get("memory_config_id") actual_config_id = connected_config.get("memory_config_id")
finally: finally:
db.close() db.close()
except Exception as e: except Exception:
# Log but continue - will fail later with proper error # Log but continue - will fail later with proper error
pass pass
@@ -1064,4 +1064,74 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
"error": str(e), "error": str(e),
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]:
    """Scheduled task: run one forgetting cycle across all groups.

    Args:
        config_id: Optional configuration ID forwarded to the service.

    Returns:
        dict: On success ``status="SUCCESS"`` with ``message``, ``report``
        and ``duration_seconds``; on failure ``status="FAILED"`` with
        ``message`` and ``duration_seconds``. Exceptions raised inside the
        cycle are logged and reported in the result, not re-raised.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        from app.core.logging_config import get_api_logger
        from app.services.memory_forget_service import MemoryForgetService

        api_logger = get_api_logger()

        with get_db_context() as db:
            try:
                api_logger.info(f"开始执行遗忘周期定时任务config_id: {config_id}")

                forget_service = MemoryForgetService()

                # The service method is `trigger_forgetting_cycle`; the former
                # call to `trigger_forgetting` raised AttributeError on every
                # run, so the task always reported FAILED.
                report = await forget_service.trigger_forgetting_cycle(
                    db=db,
                    group_id=None,  # no group filter: process every group
                    config_id=config_id
                )

                duration = time.time() - start_time

                api_logger.info(
                    f"遗忘周期定时任务完成: "
                    f"融合 {report['merged_count']} 对节点, "
                    f"失败 {report['failed_count']} 对, "
                    f"耗时 {duration:.2f}"
                )

                return {
                    "status": "SUCCESS",
                    "message": "遗忘周期执行成功",
                    "report": report,
                    "duration_seconds": duration
                }

            except Exception as e:
                duration = time.time() - start_time
                api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)

                return {
                    "status": "FAILED",
                    "message": f"遗忘周期执行失败: {str(e)}",
                    "duration_seconds": duration
                }

    # Celery runs this task synchronously, so drive the coroutine on a
    # dedicated event loop and close it afterwards.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        result = loop.run_until_complete(_run())
        return result
    finally:
        loop.close()