From e8a5cfe7e310d8b4e380dde33f32322caa13e204 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=90=E5=8A=9B=E9=BD=90?= Date: Mon, 5 Jan 2026 04:30:36 +0000 Subject: [PATCH] Merge #85 into develop from feature/actr-forget MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [feature]actr-记忆遗忘需求开发 * feature/actr-forget: (12 commits squashed) - [feature] 1.Extended fields of the data_config table; 2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j. - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process - [feature] 1.Extended fields of the data_config table; 2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j. - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget - [fix]Eliminate the interference caused by redundant code - [feature] 1.Extended fields of the data_config table; 2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j. 
- [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget Signed-off-by: 乐力齐 Reviewed-by: aliyun6762716068 Merged-by: aliyun6762716068 CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85 --- api/app/celery_app.py | 9 + api/app/controllers/__init__.py | 4 +- .../controllers/memory_forget_controller.py | 324 ++++++++ .../controllers/memory_storage_controller.py | 54 +- .../mcp_server/services/search_service.py | 29 +- api/app/core/memory/models/graph_models.py | 119 +++ api/app/core/memory/src/search.py | 612 +++++++++++----- .../forgetting_engine/__init__.py | 36 +- .../access_history_manager.py | 691 ++++++++++++++++++ .../forgetting_engine/actr_calculator.py | 359 +++++++++ .../forgetting_engine/config_utils.py | 195 +++++ .../forgetting_engine/forgetting_scheduler.py | 351 +++++++++ .../forgetting_engine/forgetting_strategy.py | 611 ++++++++++++++++ api/app/models/data_config_model.py | 9 + api/app/repositories/neo4j/add_nodes.py | 8 +- api/app/repositories/neo4j/cypher_queries.py | 54 +- .../repositories/neo4j/entity_repository.py | 7 + api/app/repositories/neo4j/graph_search.py | 276 ++++++- api/app/repositories/neo4j/neo4j_connector.py | 58 +- .../neo4j/statement_repository.py | 7 + api/app/schemas/memory_storage_schema.py | 101 +++ api/app/services/memory_forget_service.py | 460 ++++++++++++ api/app/services/memory_storage_service.py | 15 +- api/app/tasks.py | 76 +- 24 files changed, 4178 insertions(+), 287 deletions(-) create mode 100644 api/app/controllers/memory_forget_controller.py create mode 100644 api/app/core/memory/storage_services/forgetting_engine/access_history_manager.py create mode 100644 api/app/core/memory/storage_services/forgetting_engine/actr_calculator.py create mode 
100644 api/app/core/memory/storage_services/forgetting_engine/config_utils.py create mode 100644 api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py create mode 100644 api/app/core/memory/storage_services/forgetting_engine/forgetting_strategy.py create mode 100644 api/app/services/memory_forget_service.py diff --git a/api/app/celery_app.py b/api/app/celery_app.py index 22bb73ae..85ad0643 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -85,6 +85,8 @@ health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS) memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME +forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期 + # 构建定时任务配置 beat_schedule_config = { @@ -103,6 +105,13 @@ beat_schedule_config = { "schedule": memory_cache_regeneration_schedule, "args": (), }, + "run-forgetting-cycle": { + "task": "app.tasks.run_forgetting_cycle_task", + "schedule": forgetting_cycle_schedule, + "kwargs": { + "config_id": None, # 使用默认配置,可以通过环境变量配置 + }, + }, } # 如果配置了默认工作空间ID,则添加记忆总量统计任务 diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 2cddfb30..e786fb65 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -33,7 +33,7 @@ from . import ( emotion_config_controller, prompt_optimizer_controller, tool_controller, - home_page_controller, + memory_forget_controller, ) from . 
import user_memory_controllers @@ -71,6 +71,6 @@ manager_router.include_router(emotion_config_controller.router) manager_router.include_router(prompt_optimizer_controller.router) manager_router.include_router(memory_reflection_controller.router) manager_router.include_router(tool_controller.router) -manager_router.include_router(home_page_controller.router) +manager_router.include_router(memory_forget_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/memory_forget_controller.py b/api/app/controllers/memory_forget_controller.py new file mode 100644 index 00000000..d4a76f6f --- /dev/null +++ b/api/app/controllers/memory_forget_controller.py @@ -0,0 +1,324 @@ +""" +遗忘引擎控制器模块 + +本模块提供遗忘引擎的 REST API 接口,包括: +1. 手动触发遗忘周期 +2. 获取和更新配置 +3. 获取统计信息 +4. 获取遗忘曲线数据 + +所有接口都需要用户认证,并自动关联到当前工作空间。 +""" + +from typing import Optional + +from fastapi import APIRouter, Depends +from sqlalchemy.orm import Session + +from app.core.error_codes import BizCode +from app.core.logging_config import get_api_logger +from app.core.response_utils import fail, success +from app.db import get_db +from app.dependencies import get_current_user +from app.models.user_model import User +from app.schemas.memory_storage_schema import ( + ForgettingTriggerRequest, + ForgettingConfigResponse, + ForgettingConfigUpdateRequest, + ForgettingStatsResponse, + ForgettingReportResponse, + ForgettingCurveRequest, + ForgettingCurveResponse, + ForgettingCurvePoint, +) +from app.schemas.response_schema import ApiResponse +from app.services.memory_forget_service import MemoryForgetService + + +# 获取API专用日志器 +api_logger = get_api_logger() + +router = APIRouter( + prefix="/memory/forget", + tags=["Memory Forgetting Engine"], + dependencies=[Depends(get_current_user)] # 所有路由都需要认证 +) + +# 初始化服务 +forget_service = MemoryForgetService() + + +# ==================== API 端点 ==================== + +@router.post("/trigger", response_model=ApiResponse) +async def trigger_forgetting_cycle( + payload: 
ForgettingTriggerRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 手动触发遗忘周期 + + 执行一次完整的遗忘周期,识别并融合低激活值节点。 + + Args: + payload: 触发请求参数 + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含遗忘报告的响应 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试触发遗忘周期但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求触发遗忘周期: " + f"group_id={payload.group_id}, max_batch={payload.max_merge_batch_size}, " + f"min_days={payload.min_days_since_access}" + ) + + try: + # 调用服务层执行遗忘周期 + report = await forget_service.trigger_forgetting_cycle( + db=db, + group_id=payload.group_id, + max_merge_batch_size=payload.max_merge_batch_size, + min_days_since_access=payload.min_days_since_access, + config_id=payload.config_id + ) + + # 构建响应 + response_data = ForgettingReportResponse(**report) + + return success(data=response_data.model_dump(), msg="遗忘周期执行成功") + + except RuntimeError as e: + api_logger.warning(f"遗忘周期执行被拒绝: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "RuntimeError") + + except Exception as e: + api_logger.error(f"触发遗忘周期失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "触发遗忘周期失败", str(e)) + + +@router.get("/read_config", response_model=ApiResponse) +async def read_forgetting_config( + config_id: int, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 获取遗忘引擎配置 + + 读取指定配置ID的遗忘引擎参数。 + + Args: + config_id: 配置ID + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含配置信息的响应 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + 
api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}" + ) + + try: + # 调用服务层读取配置 + config = forget_service.read_forgetting_config(db=db, config_id=config_id) + + # 构建响应 + response_data = ForgettingConfigResponse(**config) + + return success(data=response_data.model_dump(), msg="查询成功") + + except ValueError as e: + api_logger.warning(f"配置不存在: config_id={config_id}, 错误: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, f"配置不存在: {config_id}", str(e)) + + except Exception as e: + api_logger.error(f"读取遗忘引擎配置失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e)) + + +@router.post("/update_config", response_model=ApiResponse) +async def update_forgetting_config( + payload: ForgettingConfigUpdateRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 更新遗忘引擎配置 + + 更新指定配置ID的遗忘引擎参数。 + + Args: + payload: 配置更新请求 + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含更新结果的响应 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}" + ) + + try: + # 构建更新字段字典(排除 None 值和 config_id) + update_data = { + key: value + for key, value in payload.model_dump(exclude_none=True).items() + if key != 'config_id' + } + + # 调用服务层更新配置 + config = forget_service.update_forgetting_config( + db=db, + config_id=payload.config_id, + update_fields=update_data + ) + + # 构建响应 + response_data = ForgettingConfigResponse(**config) + + return success(data=response_data.model_dump(), msg="更新成功") + + except ValueError as e: + api_logger.warning(f"配置不存在: config_id={payload.config_id}, 错误: {str(e)}") + return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError") + + except Exception as e: + 
db.rollback() + api_logger.error(f"更新遗忘引擎配置失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e)) + + +@router.get("/stats", response_model=ApiResponse) +async def get_forgetting_stats( + group_id: Optional[str] = None, + config_id: Optional[int] = None, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 获取遗忘引擎统计信息 + + 返回知识层节点统计、激活值分布等信息。 + + Args: + group_id: 组ID(可选) + config_id: 配置ID(可选,用于获取遗忘阈值) + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含统计信息的响应 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: " + f"group_id={group_id}, config_id={config_id}" + ) + + try: + # 调用服务层获取统计信息 + stats = await forget_service.get_forgetting_stats( + db=db, + group_id=group_id, + config_id=config_id + ) + + # 构建响应 + response_data = ForgettingStatsResponse(**stats) + + return success(data=response_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"获取遗忘引擎统计失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) + + +@router.post("/forgetting_curve", response_model=ApiResponse) +async def get_forgetting_curve( + request: ForgettingCurveRequest, + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db) +): + """ + 获取遗忘曲线数据 + + 生成遗忘曲线数据用于可视化,模拟记忆激活值随时间的衰减。 + + Args: + request: 遗忘曲线请求参数 + current_user: 当前用户 + db: 数据库会话 + + Returns: + ApiResponse: 包含遗忘曲线数据的响应 + """ + workspace_id = current_user.current_workspace_id + + # 检查用户是否已选择工作空间 + if workspace_id is None: + api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间") + return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") + + api_logger.info( + f"用户 
{current_user.username} 在工作空间 {workspace_id} 请求获取遗忘曲线: " + f"importance_score={request.importance_score}, days={request.days}, config_id={request.config_id}" + ) + + try: + # 调用服务层生成遗忘曲线 + result = await forget_service.get_forgetting_curve( + db=db, + importance_score=request.importance_score, + days=request.days, + config_id=request.config_id + ) + + # 转换为响应格式 + curve_points = [ + ForgettingCurvePoint(**point) + for point in result['curve_data'] + ] + + # 构建响应 + response_data = ForgettingCurveResponse( + curve_data=curve_points, + config=result['config'] + ) + + return success(data=response_data.model_dump(), msg="查询成功") + + except Exception as e: + api_logger.error(f"获取遗忘曲线失败: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取遗忘曲线失败", str(e)) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index bb53b833..c58ecd6d 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -1,4 +1,3 @@ -import datetime import os import uuid from typing import Optional @@ -9,12 +8,7 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user -from app.models.end_user_model import EndUser from app.models.user_model import User -from app.schemas.end_user_schema import ( - EndUserProfileResponse, - EndUserProfileUpdate, -) from app.schemas.memory_storage_schema import ( ConfigKey, ConfigParamsCreate, @@ -22,8 +16,6 @@ from app.schemas.memory_storage_schema import ( ConfigPilotRun, ConfigUpdate, ConfigUpdateExtracted, - ConfigUpdateForget, - GenerateCacheRequest, ) from app.schemas.response_schema import ApiResponse from app.services.memory_storage_service import ( @@ -238,28 +230,8 @@ def update_config_extracted( # --- Forget config params --- -@router.post("/update_config_forget", response_model=ApiResponse) # 
更新遗忘引擎配置参数(固定路径) -def update_config_forget( - payload: ConfigUpdateForget, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试更新遗忘引擎配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新遗忘引擎配置: {payload.config_id}") - try: - svc = DataConfigService(db) - result = svc.update_forget(payload) - return success(data=result, msg="更新成功") - except Exception as e: - api_logger.error(f"Update config forget failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "更新遗忘引擎配置失败", str(e)) - +# 遗忘引擎配置接口已迁移到 memory_forget_controller.py +# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 def read_config_extracted( @@ -283,28 +255,6 @@ def read_config_extracted( api_logger.error(f"Read config extracted failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) -@router.get("/read_config_forget", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 -def read_config_forget( - config_id: str, - current_user: User = Depends(get_current_user), - db: Session = Depends(get_db), - ) -> dict: - workspace_id = current_user.current_workspace_id - - # 检查用户是否已选择工作空间 - if workspace_id is None: - api_logger.warning(f"用户 {current_user.username} 尝试读取遗忘引擎配置但未选择工作空间") - return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") - - api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取遗忘引擎配置: {config_id}") - try: - svc = DataConfigService(db) - result = svc.get_forget(ConfigKey(config_id=config_id)) - return success(data=result, msg="查询成功") - except Exception as e: - api_logger.error(f"Read config forget failed: 
{str(e)}") - return fail(BizCode.INTERNAL_ERROR, "查询遗忘引擎配置失败", str(e)) - @router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表 def read_all_config( current_user: User = Depends(get_current_user), diff --git a/api/app/core/memory/agent/mcp_server/services/search_service.py b/api/app/core/memory/agent/mcp_server/services/search_service.py index be96bb64..47295f87 100644 --- a/api/app/core/memory/agent/mcp_server/services/search_service.py +++ b/api/app/core/memory/agent/mcp_server/services/search_service.py @@ -106,28 +106,32 @@ class SearchService: limit: int = 15, search_type: str = "hybrid", include: Optional[List[str]] = None, - rerank_alpha: float = 0.4, + rerank_alpha: float = 0.6, + activation_boost_factor: float = 0.8, output_path: str = "search_results.json", return_raw_results: bool = False, memory_config: "MemoryConfig" = None, ) -> Tuple[str, str, Optional[dict]]: """ - Execute hybrid search and return clean content. + Execute hybrid search with two-stage ranking. + + Stage 1: Filter by content relevance (BM25 + Embedding) + Stage 2: Rerank by activation values (ACTR) Args: - group_id: Group identifier for filtering results + group_id: Group identifier for filtering question: Search query text - limit: Maximum number of results to return (default: 5) - search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid") - include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"]) - rerank_alpha: Weight for BM25 scores in reranking (default: 0.4) - output_path: Path to save search results (default: "search_results.json") - return_raw_results: If True, also return the raw search results as third element (default: False) - memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided. 
+ limit: Max results per category (default: 15) + search_type: "hybrid", "keyword", or "embedding" (default: "hybrid") + include: Result types (default: ["statements", "chunks", "entities", "summaries"]) + rerank_alpha: BM25 weight (default: 0.6) + activation_boost_factor: Activation impact on memory strength (default: 0.8) + output_path: JSON output path (default: "search_results.json") + return_raw_results: Return full metadata (default: False) + memory_config: MemoryConfig for embedding model Returns: - Tuple of (clean_content, cleaned_query, raw_results) - raw_results is None if return_raw_results=False + Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results) """ if include is None: include = ["statements", "chunks", "entities", "summaries"] @@ -151,6 +155,7 @@ class SearchService: output_path=output_path, memory_config=config, rerank_alpha=rerank_alpha, + activation_boost_factor=activation_boost_factor, ) # Extract results based on search type and include parameter diff --git a/api/app/core/memory/models/graph_models.py b/api/app/core/memory/models/graph_models.py index 5977a2d7..4d4221a3 100644 --- a/api/app/core/memory/models/graph_models.py +++ b/api/app/core/memory/models/graph_models.py @@ -228,6 +228,13 @@ class StatementNode(Node): chunk_embedding: Optional embedding vector for the parent chunk connect_strength: Classification of connection strength ('Strong' or 'Weak') config_id: Configuration ID used to process this statement + + # ACT-R Memory Activation Properties + importance_score: Importance score for memory activation (0.0-1.0), default 0.5 + activation_value: Current activation value calculated by ACT-R engine (0.0-1.0) + access_history: List of ISO timestamp strings recording each access + last_access_time: ISO timestamp of the most recent access + access_count: Total number of times this node has been accessed """ # Core fields (ordered as requested) chunk_id: str = Field(..., description="ID of the parent chunk") @@ 
-269,6 +276,33 @@ class StatementNode(Node): connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)") + # ACT-R Memory Activation Properties + importance_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Importance score for memory activation (0.0-1.0), default 0.5" + ) + activation_value: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="Current activation value calculated by ACT-R engine (0.0-1.0)" + ) + access_history: List[str] = Field( + default_factory=list, + description="List of ISO timestamp strings recording each access" + ) + last_access_time: Optional[str] = Field( + None, + description="ISO timestamp of the most recent access" + ) + access_count: int = Field( + default=0, + ge=0, + description="Total number of times this node has been accessed" + ) + @field_validator('valid_at', 'invalid_at', mode='before') @classmethod def validate_datetime(cls, v): @@ -351,6 +385,13 @@ class ExtractedEntityNode(Node): fact_summary: Summary of facts about this entity connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both') config_id: Configuration ID used to process this entity (integer or string) + + # ACT-R Memory Activation Properties + importance_score: Importance score for memory activation (0.0-1.0), default 0.5 + activation_value: Current activation value calculated by ACT-R engine (0.0-1.0) + access_history: List of ISO timestamp strings recording each access + last_access_time: ISO timestamp of the most recent access + access_count: Total number of times this node has been accessed """ entity_idx: int = Field(..., description="Unique identifier for the entity") statement_id: str = Field(..., description="Statement this entity was extracted from") @@ -365,6 +406,33 @@ class ExtractedEntityNode(Node): connect_strength: str = 
Field(..., description="Strong VS Weak about this entity") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)") + # ACT-R Memory Activation Properties + importance_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Importance score for memory activation (0.0-1.0), default 0.5" + ) + activation_value: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="Current activation value calculated by ACT-R engine (0.0-1.0)" + ) + access_history: List[str] = Field( + default_factory=list, + description="List of ISO timestamp strings recording each access" + ) + last_access_time: Optional[str] = Field( + None, + description="ISO timestamp of the most recent access" + ) + access_count: int = Field( + default=0, + ge=0, + description="Total number of times this node has been accessed" + ) + @field_validator('aliases', mode='before') @classmethod def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 @@ -401,6 +469,16 @@ class MemorySummaryNode(Node): summary_embedding: Optional embedding vector for the summary metadata: Additional metadata for the summary config_id: Configuration ID used to process this summary + original_statement_id: ID of the original statement that was merged (for ACT-R forgetting) + original_entity_id: ID of the original entity that was merged (for ACT-R forgetting) + merged_at: Timestamp when the nodes were merged + + # ACT-R Memory Activation Properties + importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes + activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes + access_history: List of ISO timestamp strings recording each access (reset on creation) + last_access_time: ISO timestamp of the most recent access (set to creation time) + access_count: Total number of times this node has been accessed (reset to 1 on creation) """ summary_id: str = 
Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary") dialog_id: str = Field(..., description="ID of the parent dialog") @@ -409,3 +487,44 @@ class MemorySummaryNode(Node): summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)") + + # ACT-R Forgetting Engine Properties + original_statement_id: Optional[str] = Field( + None, + description="ID of the original statement that was merged (for traceability)" + ) + original_entity_id: Optional[str] = Field( + None, + description="ID of the original entity that was merged (for traceability)" + ) + merged_at: Optional[datetime] = Field( + None, + description="Timestamp when the nodes were merged" + ) + + # ACT-R Memory Activation Properties + importance_score: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Importance score for memory activation (0.0-1.0), inherited from merged nodes" + ) + activation_value: Optional[float] = Field( + None, + ge=0.0, + le=1.0, + description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes" + ) + access_history: List[str] = Field( + default_factory=list, + description="List of ISO timestamp strings recording each access (reset on creation)" + ) + last_access_time: Optional[str] = Field( + None, + description="ISO timestamp of the most recent access (set to creation time)" + ) + access_count: int = Field( + default=1, + ge=0, + description="Total number of times this node has been accessed (reset to 1 on creation)" + ) diff --git a/api/app/core/memory/src/search.py b/api/app/core/memory/src/search.py index 9353f00e..11df8166 100644 --- a/api/app/core/memory/src/search.py +++ b/api/app/core/memory/src/search.py @@ -69,6 +69,12 @@ def 
normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") for item in results: if score_field in item: score = item.get(score_field) + # 对于 activation_value,None 值保持为 None,不使用回退值 + # 这样可以区分有激活值和无激活值的节点 + if score_field == "activation_value" and score is None: + scores.append(None) # 保持 None,稍后特殊处理 + continue + if score is not None and isinstance(score, (int, float)): scores.append(float(score)) else: @@ -76,205 +82,433 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score") if not scores: return results - - if len(scores) == 1: - # Single score, set to 1.0 + + # 过滤掉 None 值,只对有效分数进行归一化 + valid_scores = [s for s in scores if s is not None] + + if not valid_scores: + # 所有分数都是 None,不进行归一化 for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 + if score_field in item or score_field == "activation_value": + item[f"normalized_{score_field}"] = None return results - # Calculate mean and standard deviation - mean_score = sum(scores) / len(scores) - variance = sum((score - mean_score) ** 2 for score in scores) / len(scores) + if len(valid_scores) == 1: # Single valid score, set to 1.0 + for item, score in zip(results, scores): + if score_field in item or score_field == "activation_value": + if score is None: + item[f"normalized_{score_field}"] = None + else: + item[f"normalized_{score_field}"] = 1.0 + return results + + # Calculate mean and standard deviation (only for valid scores) + mean_score = sum(valid_scores) / len(valid_scores) + variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores) std_dev = math.sqrt(variance) if std_dev == 0: - # All scores are the same, set them to 1.0 - for item in results: - if score_field in item: - item[f"normalized_{score_field}"] = 1.0 + # All valid scores are the same, set them to 1.0 + for item, score in zip(results, scores): + if score_field in item or score_field == "activation_value": + if score is None: + 
item[f"normalized_{score_field}"] = None + else: + item[f"normalized_{score_field}"] = 1.0 else: - for item in results: - if score_field in item: - score = item[score_field] - # Handle None or non-numeric scores - if score is None or not isinstance(score, (int, float)): - score = 0.0 - # Calculate z-score - z_score = (score - mean_score) / std_dev - # Transform to positive range using sigmoid function - normalized = 1 / (1 + math.exp(-z_score)) - item[f"normalized_{score_field}"] = normalized + for item, score in zip(results, scores): + if score_field in item or score_field == "activation_value": + if score is None: + # 保持 None,不进行归一化 + item[f"normalized_{score_field}"] = None + else: + # Calculate z-score + z_score = (score - mean_score) / std_dev + # Transform to positive range using sigmoid function + normalized = 1 / (1 + math.exp(-z_score)) + item[f"normalized_{score_field}"] = normalized return results -def rerank_hybrid_results( - keyword_results: Dict[str, List[Dict[str, Any]]], - embedding_results: Dict[str, List[Dict[str, Any]]], - alpha: float = 0.6, - limit: int = 10 -) -> Dict[str, List[Dict[str, Any]]]: - """ - Rerank hybrid search results by combining BM25 and embedding scores. +# ============================================================================ +# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考 +# ============================================================================ - Args: - keyword_results: Results from keyword/BM25 search - embedding_results: Results from embedding search - alpha: Weight for BM25 scores (1-alpha for embedding scores) - limit: Maximum number of results to return per category +# def rerank_hybrid_results( +# keyword_results: Dict[str, List[Dict[str, Any]]], +# embedding_results: Dict[str, List[Dict[str, Any]]], +# alpha: float = 0.6, +# limit: int = 10 +# ) -> Dict[str, List[Dict[str, Any]]]: +# """ +# Rerank hybrid search results by combining BM25 and embedding scores. 
+# +# 已废弃:此函数功能已被 rerank_with_activation 完全替代 +# +# Args: +# keyword_results: Results from keyword/BM25 search +# embedding_results: Results from embedding search +# alpha: Weight for BM25 scores (1-alpha for embedding scores) +# limit: Maximum number of results to return per category +# +# Returns: +# Reranked results with combined scores +# """ +# reranked = {} +# +# for category in ["statements", "chunks", "entities","summaries"]: +# keyword_items = keyword_results.get(category, []) +# embedding_items = embedding_results.get(category, []) +# +# # Normalize scores within each search type +# keyword_items = normalize_scores(keyword_items, "score") +# embedding_items = normalize_scores(embedding_items, "score") +# +# # Create a combined pool of unique items +# combined_items = {} +# +# # Add keyword results with BM25 scores +# for item in keyword_items: +# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") +# if item_id: +# combined_items[item_id] = item.copy() +# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) +# combined_items[item_id]["embedding_score"] = 0 # Default +# +# # Add or update with embedding results +# for item in embedding_items: +# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") +# if item_id: +# if item_id in combined_items: +# # Update existing item with embedding score +# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) +# else: +# # New item from embedding search only +# combined_items[item_id] = item.copy() +# combined_items[item_id]["bm25_score"] = 0 # Default +# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) +# +# # Calculate combined scores and rank +# for item_id, item in combined_items.items(): +# bm25_score = item.get("bm25_score", 0) +# embedding_score = item.get("embedding_score", 0) +# +# # Combined score: weighted average of normalized scores +# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score +# 
item["combined_score"] = combined_score +# +# # Keep original score for reference +# if "score" not in item and bm25_score > 0: +# item["score"] = bm25_score +# elif "score" not in item and embedding_score > 0: +# item["score"] = embedding_score +# +# # Sort by combined score and limit results +# sorted_items = sorted( +# combined_items.values(), +# key=lambda x: x.get("combined_score", 0), +# reverse=True +# )[:limit] +# +# reranked[category] = sorted_items +# +# return reranked - Returns: - Reranked results with combined scores - """ - reranked = {} +# def rerank_with_forgetting_curve( +# keyword_results: Dict[str, List[Dict[str, Any]]], +# embedding_results: Dict[str, List[Dict[str, Any]]], +# alpha: float = 0.6, +# limit: int = 10, +# forgetting_config: ForgettingEngineConfig | None = None, +# now: datetime | None = None, +# ) -> Dict[str, List[Dict[str, Any]]]: +# """ +# Rerank hybrid results with a forgetting curve applied to combined scores. +# +# 已废弃:此函数功能已被 rerank_with_activation 完全替代 +# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度) +# +# The forgetting curve reduces scores for older memories or weaker connections. 
+# +# Args: +# keyword_results: Results from keyword/BM25 search +# embedding_results: Results from embedding search +# alpha: Weight for BM25 scores (1-alpha for embedding scores) +# limit: Maximum number of results to return per category +# forgetting_config: Configuration for the forgetting engine +# now: Optional current time override for testing +# +# Returns: +# Reranked results with combined and final scores (after forgetting) +# """ +# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig()) +# now_dt = now or datetime.now() +# +# reranked: Dict[str, List[Dict[str, Any]]] = {} +# +# for category in ["statements", "chunks", "entities","summaries"]: +# keyword_items = keyword_results.get(category, []) +# embedding_items = embedding_results.get(category, []) +# +# # Normalize scores within each search type +# keyword_items = normalize_scores(keyword_items, "score") +# embedding_items = normalize_scores(embedding_items, "score") +# +# combined_items: Dict[str, Dict[str, Any]] = {} +# +# # Combine two result sets by ID +# for src_items, is_embedding in ( +# (keyword_items, False), (embedding_items, True) +# ): +# for item in src_items: +# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") +# if not item_id: +# continue +# existing = combined_items.get(item_id) +# if not existing: +# combined_items[item_id] = item.copy() +# combined_items[item_id]["bm25_score"] = 0 +# combined_items[item_id]["embedding_score"] = 0 +# # Update normalized score from the right source +# if is_embedding: +# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) +# else: +# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) +# +# # Calculate scores and apply forgetting weights +# for item_id, item in combined_items.items(): +# bm25_score = float(item.get("bm25_score", 0) or 0) +# embedding_score = float(item.get("embedding_score", 0) or 0) +# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score 
+# +# # Estimate time elapsed in days +# dt = _parse_datetime(item.get("created_at")) +# if dt is None: +# time_elapsed_days = 0.0 +# else: +# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0) +# +# # Memory strength (currently set to default value) +# memory_strength = 1.0 +# forgetting_weight = engine.calculate_weight( +# time_elapsed=time_elapsed_days, memory_strength=memory_strength +# ) +# final_score = combined_score * forgetting_weight +# item["combined_score"] = final_score +# +# sorted_items = sorted( +# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True +# )[:limit] +# +# reranked[category] = sorted_items +# +# return reranked - for category in ["statements", "chunks", "entities","summaries"]: - keyword_items = keyword_results.get(category, []) - embedding_items = embedding_results.get(category, []) - # Normalize scores within each search type - keyword_items = normalize_scores(keyword_items, "score") - embedding_items = normalize_scores(embedding_items, "score") - - # Create a combined pool of unique items - combined_items = {} - - # Add keyword results with BM25 scores - for item in keyword_items: - item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") - if item_id: - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0) - combined_items[item_id]["embedding_score"] = 0 # Default - - # Add or update with embedding results - for item in embedding_items: - item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") - if item_id: - if item_id in combined_items: - # Update existing item with embedding score - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - else: - # New item from embedding search only - combined_items[item_id] = item.copy() - combined_items[item_id]["bm25_score"] = 0 # Default - combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0) - - # Calculate combined 
def rerank_with_activation(
    keyword_results: Dict[str, List[Dict[str, Any]]],
    embedding_results: Dict[str, List[Dict[str, Any]]],
    alpha: float = 0.6,
    limit: int = 10,
    forgetting_config: ForgettingEngineConfig | None = None,
    activation_boost_factor: float = 0.8,
    now: datetime | None = None,
) -> Dict[str, List[Dict[str, Any]]]:
    """
    Two-stage rerank: filter by content relevance, then order by activation.

    Stage 1: content_score = alpha * BM25 + (1 - alpha) * Embedding; the best
             ``limit * 3`` items per category become candidates.
    Stage 2: candidates that carry an activation value are ordered by their
             normalized activation; candidates without one are only used to
             fill the remaining slots, in content-relevance order.

    Score fields written on every returned item:
        - bm25_score: normalized BM25 score
        - embedding_score: normalized embedding score
        - content_score: alpha * bm25_score + (1 - alpha) * embedding_score
        - base_score: stage-1 score (equal to content_score)
        - activation_score: normalized activation, or ``None`` when the item
          has no numeric ``activation_value``
        - final_score: activation_score for items with an activation value,
          base_score otherwise
        - forgetting_weight: only when ``forgetting_config`` is given and the
          item has an activation value

    Args:
        keyword_results: BM25 search results, keyed by category.
        embedding_results: Embedding search results, keyed by category.
        alpha: BM25 weight; the embedding weight is ``1 - alpha``.
        limit: Maximum number of results per category.
        forgetting_config: Optional forgetting-engine configuration. It only
            contributes the ``forgetting_weight`` field; the ordering of the
            returned items does not depend on it.
        activation_boost_factor: Influence of the activation value on the
            memory strength passed to the forgetting engine.
        now: Reference time for the elapsed-time computation (defaults to
            ``datetime.now()``).

    Returns:
        Reranked results per category with the score metadata listed above.

    Raises:
        ValueError: If ``alpha`` is outside ``[0, 1]``.
    """
    if not (0 <= alpha <= 1):
        raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")

    def _item_id(entry: Dict[str, Any]) -> Any:
        # Result rows identify themselves by one of these keys.
        return entry.get("id") or entry.get("uuid") or entry.get("chunk_id")

    def _has_activation(value: Any) -> bool:
        # Same criterion the forgetting step uses for the raw activation value.
        return value is not None and isinstance(value, (int, float))

    engine = ForgettingEngine(forgetting_config) if forgetting_config else None
    now_dt = now or datetime.now()

    reranked: Dict[str, List[Dict[str, Any]]] = {}

    for category in ["statements", "chunks", "entities", "summaries"]:
        # Step 1: normalize scores within each search type.
        keyword_items = normalize_scores(keyword_results.get(category, []), "score")
        embedding_items = normalize_scores(embedding_results.get(category, []), "score")

        # Step 2: merge both result sets by ID.
        combined_items: Dict[str, Dict[str, Any]] = {}

        for item in keyword_items:
            item_id = _item_id(item)
            if not item_id:
                continue
            merged = item.copy()
            merged["bm25_score"] = item.get("normalized_score", 0)
            merged["embedding_score"] = 0  # default until an embedding hit is seen
            combined_items[item_id] = merged

        for item in embedding_items:
            item_id = _item_id(item)
            if not item_id:
                continue
            if item_id in combined_items:
                combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
            else:
                merged = item.copy()
                merged["bm25_score"] = 0  # embedding-only hit
                merged["embedding_score"] = item.get("normalized_score", 0)
                combined_items[item_id] = merged

        # Step 3: normalize activation values across the merged pool and copy
        # the normalized value back by ID.
        for item in normalize_scores(list(combined_items.values()), "activation_value"):
            item_id = _item_id(item)
            if item_id and item_id in combined_items:
                combined_items[item_id]["normalized_activation_value"] = item.get(
                    "normalized_activation_value", 0
                )

        # Step 4: compute stage-1 scores and record the activation score.
        for item in combined_items.values():
            bm25_norm = float(item.get("bm25_score", 0) or 0)
            emb_norm = float(item.get("embedding_score", 0) or 0)
            content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
            base_score = content_score

            activation_val = item.get("activation_value")
            has_activation = _has_activation(activation_val)

            # Fix: only items that really carry an activation value get a
            # numeric activation_score. Previously every item received a float
            # (0.0 when missing), so the "without activation" fallback below
            # could never be taken.
            if has_activation:
                item["activation_score"] = float(
                    item.get("normalized_activation_value", 0) or 0
                )
            else:
                item["activation_score"] = None
            item["content_score"] = content_score
            item["base_score"] = base_score

            # Step 5 (optional): forgetting curve, applied only to items with
            # an activation value.
            if engine and has_activation:
                importance = float(item.get("importance_score", 0.5) or 0.5)
                # memory strength = importance * (1 + activation * boost factor)
                memory_strength = importance * (
                    1 + float(activation_val) * activation_boost_factor
                )

                dt = _parse_datetime(item.get("created_at"))
                if dt is None:
                    time_elapsed_days = 0.0
                else:
                    time_elapsed_days = max(
                        0.0, (now_dt - dt).total_seconds() / 86400.0
                    )

                forgetting_weight = engine.calculate_weight(
                    time_elapsed=time_elapsed_days,
                    memory_strength=memory_strength,
                )
                item["forgetting_weight"] = forgetting_weight
                # Interim value; returned items get their final_score
                # reassigned after the two-stage sort below.
                item["final_score"] = base_score * forgetting_weight
            else:
                item["final_score"] = base_score

        # Step 6, stage 1: keep the top candidates by content relevance.
        first_stage_limit = limit * 3
        first_stage_sorted = sorted(
            combined_items.values(),
            key=lambda x: float(x.get("base_score", 0) or 0),
            reverse=True,
        )[:first_stage_limit]

        # Step 6, stage 2: activation-bearing candidates first, ordered by
        # activation; the rest keep their stage-1 order and fill the gap.
        items_with_activation = [
            item for item in first_stage_sorted
            if item.get("activation_score") is not None
        ]
        items_without_activation = [
            item for item in first_stage_sorted
            if item.get("activation_score") is None
        ]

        sorted_with_activation = sorted(
            items_with_activation,
            key=lambda x: float(x.get("activation_score", 0) or 0),
            reverse=True,
        )

        if len(sorted_with_activation) < limit:
            needed = limit - len(sorted_with_activation)
            sorted_items = sorted_with_activation + items_without_activation[:needed]
        else:
            sorted_items = sorted_with_activation[:limit]

        # final_score reflects what the item was actually ranked by.
        for item in sorted_items:
            activation_score = item.get("activation_score")
            if activation_score is not None:
                item["final_score"] = activation_score
            else:
                item["final_score"] = item.get("base_score", 0)

        reranked[category] = sorted_items

    return reranked
forgetting_config=forgetting_cfg, + activation_boost_factor=activation_boost_factor, + ) + rerank_latency = time.time() - rerank_start latency_metrics["reranking_latency"] = round(rerank_latency, 4) logger.info(f"Reranking completed in {rerank_latency:.4f}s") @@ -737,6 +970,7 @@ async def run_hybrid_search( "search_query": query_text, "search_timestamp": datetime.now().isoformat(), "reranking_alpha": rerank_alpha, + "activation_boost_factor": activation_boost_factor, "forgetting_rerank": use_forgetting_rerank, "llm_rerank": llm_rerank_applied, } diff --git a/api/app/core/memory/storage_services/forgetting_engine/__init__.py b/api/app/core/memory/storage_services/forgetting_engine/__init__.py index db5c0769..794fc46f 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/__init__.py +++ b/api/app/core/memory/storage_services/forgetting_engine/__init__.py @@ -1,8 +1,40 @@ """遗忘引擎模块 -该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。 +该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线和 ACT-R 认知架构理论。 """ from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine +from app.core.memory.storage_services.forgetting_engine.actr_calculator import ( + ACTRCalculator, + calculate_activation, + generate_forgetting_curve +) +from app.core.memory.storage_services.forgetting_engine.access_history_manager import ( + AccessHistoryManager, + ConsistencyCheckResult +) +from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ( + ForgettingStrategy +) +from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import ( + ForgettingScheduler +) +from app.core.memory.storage_services.forgetting_engine.config_utils import ( + calculate_forgetting_rate, + load_actr_config_from_db, + create_actr_calculator_from_config +) -__all__ = ["ForgettingEngine"] +__all__ = [ + "ForgettingEngine", + "ACTRCalculator", + "calculate_activation", + "generate_forgetting_curve", + "AccessHistoryManager", + "ConsistencyCheckResult", + "ForgettingStrategy", + 
logger = logging.getLogger(__name__)


class ConsistencyCheckResult(Enum):
    """Outcome of a consistency check on a node's access-tracking fields."""

    CONSISTENT = "consistent"  # all checks passed
    INCONSISTENT_HISTORY_TIME = "inconsistent_history_time"  # access_history[-1] != last_access_time
    INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count"  # len(access_history) != access_count
    MISSING_ACTIVATION = "missing_activation"  # access history present, activation value absent
    INVALID_ACTIVATION_RANGE = "invalid_activation_range"  # activation outside [offset, 1.0]


class AccessHistoryManager:
    """
    Tracks node accesses and keeps the activation-related fields in step.

    On every access the following node properties are written together:
    ``activation_value``, ``access_history`` (ISO timestamp strings),
    ``last_access_time``, ``access_count`` and a ``version`` counter used for
    optimistic locking. Consistency checks and a repair routine for these
    fields are also provided.

    Attributes:
        connector: Neo4j connector; must provide ``execute_query`` and
            ``execute_write_transaction``.
        actr_calculator: Calculator providing ``offset``,
            ``trim_access_history`` and ``calculate_memory_activation``.
        max_retries: Maximum number of attempts made by ``record_access``.
    """

    def __init__(
        self,
        connector: "Neo4jConnector",
        actr_calculator: "ACTRCalculator",
        max_retries: int = 3
    ):
        """
        Store the collaborators.

        Args:
            connector: Neo4j connector instance.
            actr_calculator: ACT-R activation calculator instance.
            max_retries: Maximum attempts on failure/conflict (default 3).
        """
        self.connector = connector
        self.actr_calculator = actr_calculator
        self.max_retries = max_retries

    async def record_access(
        self,
        node_id: str,
        node_label: str,
        group_id: Optional[str] = None,
        current_time: Optional[datetime] = None
    ) -> Dict[str, Any]:
        """
        Record one access to a node and update all related fields together.

        The node is fetched, the new history/activation is computed, and the
        result is written in a single write transaction guarded by the node's
        ``version``. Failures are retried up to ``max_retries`` times.

        Args:
            node_id: Node ID.
            node_label: One of ``Statement``, ``ExtractedEntity``,
                ``MemorySummary``.
            group_id: Optional group filter.
            current_time: Access time (defaults to ``datetime.now()``).

        Returns:
            The updated node fields: id, activation_value, access_history,
            last_access_time, access_count, importance_score, version.

        Raises:
            ValueError: If the label is invalid or the node does not exist.
            RuntimeError: If every attempt failed.
        """
        if current_time is None:
            current_time = datetime.now()

        current_time_iso = current_time.isoformat()

        valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
        if node_label not in valid_labels:
            raise ValueError(
                f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
            )

        # At least one attempt, so the method never falls through returning None.
        attempts = max(1, self.max_retries)

        for attempt in range(attempts):
            try:
                # Step 1: read the current node state (including its version).
                node_data = await self._fetch_node(node_id, node_label, group_id)

                if node_data:
                    # Step 2: compute the new history and activation value.
                    update_data = await self._calculate_update(
                        node_data=node_data,
                        current_time=current_time,
                        current_time_iso=current_time_iso
                    )

                    # Step 3: write, guarded by the version observed in step 1
                    # so a concurrent writer triggers a conflict and a retry.
                    updated_node = await self._atomic_update(
                        node_id=node_id,
                        node_label=node_label,
                        update_data=update_data,
                        group_id=group_id,
                        expected_version=node_data.get('version') or 0
                    )

                    logger.info(
                        f"成功记录访问: {node_label}[{node_id}], "
                        f"activation={update_data['activation_value']:.4f}, "
                        f"access_count={update_data['access_count']}"
                    )

                    return updated_node

            except Exception as e:
                if attempt < attempts - 1:
                    logger.warning(
                        f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}): {str(e)}"
                    )
                    continue
                logger.error(
                    f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
                    f"错误: {str(e)}"
                )
                raise RuntimeError(
                    f"Failed to record access after {self.max_retries} attempts: {str(e)}"
                ) from e
            else:
                # Reached only when the fetch succeeded but found nothing.
                # A missing node is not transient: report it as documented
                # instead of retrying and wrapping it in RuntimeError.
                raise ValueError(
                    f"Node not found: {node_label} with id={node_id}"
                )

    async def record_batch_access(
        self,
        node_ids: List[str],
        node_label: str,
        group_id: Optional[str] = None,
        current_time: Optional[datetime] = None
    ) -> List[Dict[str, Any]]:
        """
        Record an access for several nodes of the same label.

        Nodes are processed one after another; a failure on one node is
        logged and does not stop the others.

        Args:
            node_ids: Node IDs.
            node_label: Label shared by all nodes.
            group_id: Optional group filter.
            current_time: Access time shared by all nodes (defaults to now).

        Returns:
            The updated node data of the nodes that succeeded.
        """
        if current_time is None:
            current_time = datetime.now()

        results = []
        failed_count = 0

        for node_id in node_ids:
            try:
                updated_node = await self.record_access(
                    node_id=node_id,
                    node_label=node_label,
                    group_id=group_id,
                    current_time=current_time
                )
                results.append(updated_node)
            except Exception as e:
                failed_count += 1
                logger.warning(
                    f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
                )

        logger.info(
            f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
            f"失败 {failed_count}"
        )

        return results

    async def check_consistency(
        self,
        node_id: str,
        node_label: str,
        group_id: Optional[str] = None
    ) -> Tuple[ConsistencyCheckResult, Optional[str]]:
        """
        Check the access-tracking fields of one node.

        Rules, evaluated in order:
            1. access_history[-1] == last_access_time
            2. len(access_history) == access_count
            3. a node with access history has an activation value
            4. the activation value lies in [offset, 1.0]

        A node that does not exist is reported as consistent.

        Args:
            node_id: Node ID.
            node_label: Node label.
            group_id: Optional group filter.

        Returns:
            ``(result, message)``; ``message`` is ``None`` when consistent.
        """
        node_data = await self._fetch_node(node_id, node_label, group_id)

        if not node_data:
            return ConsistencyCheckResult.CONSISTENT, None

        # Unset properties come back as None; treat them as "never accessed".
        access_history = node_data.get('access_history') or []
        last_access_time = node_data.get('last_access_time')
        access_count = node_data.get('access_count') or 0
        activation_value = node_data.get('activation_value')

        # Check 1
        if access_history and last_access_time:
            if access_history[-1] != last_access_time:
                return (
                    ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
                    f"access_history[-1]={access_history[-1]} != "
                    f"last_access_time={last_access_time}"
                )

        # Check 2
        if len(access_history) != access_count:
            return (
                ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
                f"len(access_history)={len(access_history)} != "
                f"access_count={access_count}"
            )

        # Check 3
        if access_history and activation_value is None:
            return (
                ConsistencyCheckResult.MISSING_ACTIVATION,
                "Node has access_history but activation_value is None"
            )

        # Check 4
        if activation_value is not None:
            offset = self.actr_calculator.offset
            if not (offset <= activation_value <= 1.0):
                return (
                    ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
                    f"activation_value={activation_value} out of range "
                    f"[{offset}, 1.0]"
                )

        return ConsistencyCheckResult.CONSISTENT, None

    async def check_batch_consistency(
        self,
        node_label: str,
        group_id: Optional[str] = None,
        limit: int = 1000
    ) -> Dict[str, Any]:
        """
        Check up to ``limit`` nodes of one label that have an access history.

        Args:
            node_label: Node label.
            group_id: Optional group filter.
            limit: Maximum number of nodes to check.

        Returns:
            Report with total_checked, consistent_count, inconsistent_count,
            inconsistencies (list of node_id/result/message) and
            consistency_rate (1.0 when nothing was checked).
        """
        # NOTE(review): node_label is interpolated into the query text (Cypher
        # cannot parameterize labels); callers must pass trusted labels only.
        query = f"""
        MATCH (n:{node_label})
        WHERE n.access_history IS NOT NULL
        """
        if group_id:
            query += " AND n.group_id = $group_id"
        query += """
        RETURN n.id as id
        LIMIT $limit
        """

        params = {"limit": limit}
        if group_id:
            params["group_id"] = group_id

        results = await self.connector.execute_query(query, **params)
        node_ids = [r['id'] for r in results]

        inconsistencies = []
        consistent_count = 0

        for node_id in node_ids:
            result, message = await self.check_consistency(
                node_id=node_id,
                node_label=node_label,
                group_id=group_id
            )

            if result == ConsistencyCheckResult.CONSISTENT:
                consistent_count += 1
            else:
                inconsistencies.append({
                    'node_id': node_id,
                    'result': result.value,
                    'message': message
                })

        total_checked = len(node_ids)
        inconsistent_count = len(inconsistencies)
        consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0

        report = {
            'total_checked': total_checked,
            'consistent_count': consistent_count,
            'inconsistent_count': inconsistent_count,
            'inconsistencies': inconsistencies,
            'consistency_rate': consistency_rate
        }

        logger.info(
            f"一致性检查完成: {node_label}, "
            f"一致率={consistency_rate:.2%}, "
            f"不一致节点={inconsistent_count}/{total_checked}"
        )

        return report

    async def repair_inconsistency(
        self,
        node_id: str,
        node_label: str,
        group_id: Optional[str] = None
    ) -> bool:
        """
        Repair an inconsistent node using ``access_history`` as the source of truth.

        ``last_access_time`` is set to the last history entry, ``access_count``
        to the history length, and ``activation_value`` is recomputed when a
        history exists.

        Args:
            node_id: Node ID.
            node_label: Node label.
            group_id: Optional group filter.

        Returns:
            True if the node was already consistent or the repair query ran;
            False if the node is missing or any error occurred.
        """
        try:
            result, message = await self.check_consistency(
                node_id=node_id,
                node_label=node_label,
                group_id=group_id
            )

            if result == ConsistencyCheckResult.CONSISTENT:
                logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
                return True

            node_data = await self._fetch_node(node_id, node_label, group_id)
            if not node_data:
                logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
                return False

            # Unset properties come back as None; fall back to safe defaults.
            access_history = node_data.get('access_history') or []
            importance_score = node_data.get('importance_score')
            if importance_score is None:
                importance_score = 0.5

            repair_data = {}

            if access_history:
                repair_data['last_access_time'] = access_history[-1]

            repair_data['access_count'] = len(access_history)

            if access_history:
                current_time = datetime.now()
                last_access_dt = datetime.fromisoformat(access_history[-1])
                access_history_dt = [
                    datetime.fromisoformat(ts) for ts in access_history
                ]

                activation_value = self.actr_calculator.calculate_memory_activation(
                    access_history=access_history_dt,
                    current_time=current_time,
                    last_access_time=last_access_dt,
                    importance_score=importance_score
                )
                repair_data['activation_value'] = activation_value

            query = f"""
            MATCH (n:{node_label} {{id: $node_id}})
            """
            if group_id:
                query += " WHERE n.group_id = $group_id"
            query += """
            SET n += $repair_data
            RETURN n
            """

            params = {
                'node_id': node_id,
                'repair_data': repair_data
            }
            if group_id:
                params['group_id'] = group_id

            await self.connector.execute_query(query, **params)

            logger.info(
                f"成功修复节点不一致: {node_label}[{node_id}], "
                f"问题类型={result.value}"
            )
            return True

        except Exception as e:
            logger.error(
                f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
            )
            return False

    # ==================== private helpers ====================

    async def _fetch_node(
        self,
        node_id: str,
        node_label: str,
        group_id: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """
        Read the access-tracking fields of one node.

        Args:
            node_id: Node ID.
            node_label: Node label.
            group_id: Optional group filter.

        Returns:
            The first result row (id, importance_score, activation_value,
            access_history, last_access_time, access_count, version), or
            ``None`` when no node matched.
        """
        query = f"""
        MATCH (n:{node_label} {{id: $node_id}})
        """
        if group_id:
            query += " WHERE n.group_id = $group_id"
        query += """
        RETURN n.id as id,
               n.importance_score as importance_score,
               n.activation_value as activation_value,
               n.access_history as access_history,
               n.last_access_time as last_access_time,
               n.access_count as access_count,
               n.version as version
        """

        params = {'node_id': node_id}
        if group_id:
            params['group_id'] = group_id

        results = await self.connector.execute_query(query, **params)

        if results:
            return results[0]
        return None

    async def _calculate_update(
        self,
        node_data: Dict[str, Any],
        current_time: datetime,
        current_time_iso: str
    ) -> Dict[str, Any]:
        """
        Compute the field values to write for an access at ``current_time``.

        Args:
            node_data: Current node data as returned by ``_fetch_node``.
            current_time: Access time.
            current_time_iso: The same time as an ISO string.

        Returns:
            Dict with activation_value, access_history (trimmed, ISO strings),
            last_access_time and access_count.
        """
        # First access: the properties are unset and come back as None.
        access_history = node_data.get('access_history') or []
        importance_score = node_data.get('importance_score')
        if importance_score is None:
            importance_score = 0.5

        new_access_history = list(access_history) + [current_time_iso]

        # Let the calculator trim the history if it has grown too long.
        access_history_dt = [
            datetime.fromisoformat(ts) for ts in new_access_history
        ]
        trimmed_history_dt = self.actr_calculator.trim_access_history(
            access_history=access_history_dt,
            current_time=current_time
        )
        trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]

        activation_value = self.actr_calculator.calculate_memory_activation(
            access_history=trimmed_history_dt,
            current_time=current_time,
            last_access_time=current_time,  # the access being recorded
            importance_score=importance_score
        )

        return {
            'activation_value': activation_value,
            'access_history': trimmed_history,
            'last_access_time': current_time_iso,
            'access_count': len(trimmed_history)
        }

    async def _atomic_update(
        self,
        node_id: str,
        node_label: str,
        update_data: Dict[str, Any],
        group_id: Optional[str] = None,
        expected_version: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Write the computed fields in one write transaction, guarded by ``version``.

        Args:
            node_id: Node ID.
            node_label: Node label.
            update_data: Output of ``_calculate_update``.
            group_id: Optional group filter.
            expected_version: Version the caller based ``update_data`` on.
                When given and different from the stored version, the update
                is rejected as a conflict. ``None`` skips this comparison.

        Returns:
            The updated node fields, including the new ``version``.

        Raises:
            RuntimeError: If the node is missing, a version conflict occurred
                or the transaction failed.
        """
        async def update_transaction(tx, node_id, node_label, update_data, group_id):
            # Step 1: read the stored version.
            read_query = f"""
            MATCH (n:{node_label} {{id: $node_id}})
            """
            if group_id:
                read_query += " WHERE n.group_id = $group_id"
            read_query += """
            RETURN n.id as id,
                   n.version as version,
                   n.activation_value as activation_value,
                   n.access_history as access_history,
                   n.last_access_time as last_access_time,
                   n.access_count as access_count,
                   n.importance_score as importance_score
            """

            read_params = {'node_id': node_id}
            if group_id:
                read_params['group_id'] = group_id

            read_result = await tx.run(read_query, **read_params)
            current_node = await read_result.single()

            if not current_node:
                raise RuntimeError(f"Node not found: {node_label}[{node_id}]")

            # A node that was never updated has no version; treat it as 0.
            current_version = current_node.get('version', 0) or 0

            # The update data was computed from an earlier read; refuse to
            # write it if the node has changed since then.
            if expected_version is not None and current_version != expected_version:
                raise RuntimeError(
                    f"Version conflict detected for {node_label}[{node_id}]. "
                    f"Expected version {expected_version}, but node was modified by another transaction."
                )

            new_version = current_version + 1

            # Step 2: conditional write. All predicates are joined under one
            # WHERE so the query is valid with or without a group filter
            # (previously the version predicate was appended with AND even
            # when no WHERE clause existed).
            conditions = []
            if group_id:
                conditions.append("n.group_id = $group_id")
            if current_version > 0:
                conditions.append("n.version = $current_version")
            else:
                conditions.append("(n.version IS NULL OR n.version = 0)")

            update_query = f"""
            MATCH (n:{node_label} {{id: $node_id}})
            WHERE {' AND '.join(conditions)}
            SET n.activation_value = $activation_value,
                n.access_history = $access_history,
                n.last_access_time = $last_access_time,
                n.access_count = $access_count,
                n.version = $new_version
            RETURN n.id as id,
                   n.activation_value as activation_value,
                   n.access_history as access_history,
                   n.last_access_time as last_access_time,
                   n.access_count as access_count,
                   n.importance_score as importance_score,
                   n.version as version
            """

            update_params = {
                'node_id': node_id,
                'current_version': current_version,
                'new_version': new_version,
                'activation_value': update_data['activation_value'],
                'access_history': update_data['access_history'],
                'last_access_time': update_data['last_access_time'],
                'access_count': update_data['access_count']
            }
            if group_id:
                update_params['group_id'] = group_id

            update_result = await tx.run(update_query, **update_params)
            updated_node = await update_result.single()

            if not updated_node:
                raise RuntimeError(
                    f"Version conflict detected for {node_label}[{node_id}]. "
                    f"Expected version {current_version}, but node was modified by another transaction."
                )

            return dict(updated_node)

        try:
            result = await self.connector.execute_write_transaction(
                update_transaction,
                node_id=node_id,
                node_label=node_label,
                update_data=update_data,
                group_id=group_id
            )
            return result
        except Exception as e:
            logger.error(
                f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
            )
            raise RuntimeError(
                f"Failed to atomically update node: {str(e)}"
            ) from e
class ACTRCalculator:
    """
    Unified ACT-R Memory Activation Calculator.

    Combines recency and frequency of access into one activation value:

        R(i) = offset + (1 - offset) * exp(-λ * t / Σ(I * t_k^(-d)))

    Attributes:
        decay_constant: Decay parameter d (typically 0.5)
        forgetting_rate: Lambda parameter λ controlling forgetting speed
        offset: Minimum retention rate (baseline memory strength)
        max_history_length: Maximum number of access records to keep
    """

    # Floor (in days) for every elapsed-time term; keeps t^(-d) finite and
    # treats timestamps at or after "now" as "just accessed".
    _MIN_ELAPSED_DAYS = 0.0001
    _SECONDS_PER_DAY = 86400.0

    def __init__(
        self,
        decay_constant: float = 0.5,
        forgetting_rate: float = 0.3,
        offset: float = 0.1,
        max_history_length: int = 100
    ):
        """
        Initialize the ACT-R calculator.

        Args:
            decay_constant: Decay parameter d (default 0.5)
            forgetting_rate: Forgetting rate λ (default 0.3)
            offset: Minimum retention rate (default 0.1)
            max_history_length: Maximum access history length (default 100)
        """
        self.decay_constant = decay_constant
        self.forgetting_rate = forgetting_rate
        self.offset = offset
        self.max_history_length = max_history_length

    @staticmethod
    def _as_datetime(value):
        """Return *value* as a datetime, parsing ISO-format strings."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value

    def _elapsed_days(self, current_time: datetime, earlier_time: datetime) -> float:
        """Return days between two instants, floored at ``_MIN_ELAPSED_DAYS``."""
        days = (current_time - earlier_time).total_seconds() / self._SECONDS_PER_DAY
        return max(days, self._MIN_ELAPSED_DAYS)

    def calculate_memory_activation(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5
    ) -> float:
        """
        Calculate the memory activation value.

        Computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d))), where
        t is the time since ``last_access_time`` and t_k the time since each
        entry of ``access_history``, both in days.

        Args:
            access_history: Access timestamps, as datetime objects or
                ISO-format strings.
            current_time: Current time for calculation (datetime or ISO string)
            last_access_time: Time of most recent access (datetime or ISO string)
            importance_score: Importance weight (0 to 1, default 0.5)

        Returns:
            float: Memory activation value between offset and 1.0

        Raises:
            ValueError: If access_history is empty, importance_score is
                outside [0, 1], or a timestamp string is not valid ISO format.

        Note:
            All timestamps must be consistently naive or consistently
            timezone-aware; Python refuses to subtract one from the other.
        """
        if not access_history:
            raise ValueError("access_history cannot be empty")

        if not (0.0 <= importance_score <= 1.0):
            raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")

        now = self._as_datetime(current_time)
        time_since_last = self._elapsed_days(now, self._as_datetime(last_access_time))

        # BLA component: Σ(I · t_k^(-d)) over every recorded access.
        bla_sum = sum(
            importance_score
            * self._elapsed_days(now, self._as_datetime(access_time)) ** (-self.decay_constant)
            for access_time in access_history
        )

        # importance_score == 0 makes the sum 0; avoid dividing by zero.
        if bla_sum <= 0:
            bla_sum = 0.0001

        exponent = -self.forgetting_rate * time_since_last / bla_sum

        # Clamp the exponent so math.exp can neither overflow nor underflow.
        exponent = max(min(exponent, 100), -100)

        activation = self.offset + (1 - self.offset) * math.exp(exponent)

        # Keep the result inside [offset, 1.0].
        return max(self.offset, min(1.0, activation))

    def trim_access_history(
        self,
        access_history: List[datetime],
        current_time: datetime
    ) -> List[datetime]:
        """
        Trim access history to at most ``max_history_length`` entries.

        Strategy:
        - At or under the limit: return the list unchanged (same object,
          original order).
        - Over the limit: keep the most recent ``max_history_length // 2``
          entries and sample the rest evenly from the older entries.

        Args:
            access_history: List of access timestamps (sorted or unsorted)
            current_time: Accepted for interface compatibility; not used.

        Returns:
            List[datetime]: Trimmed history. When trimming happens the result
            is sorted newest first. A non-positive ``max_history_length``
            yields an empty list.
        """
        limit = self.max_history_length
        if len(access_history) <= limit:
            return access_history

        if limit <= 0:
            # Nothing may be kept; the sampling below would divide by zero.
            return []

        newest_first = sorted(access_history, reverse=True)

        keep_recent_count = limit // 2
        sample_count = limit - keep_recent_count

        recent_records = newest_first[:keep_recent_count]
        older_records = newest_first[keep_recent_count:]

        # len(access_history) > limit guarantees more older records than
        # sample slots, so every sampled index is distinct and in range.
        step = len(older_records) / sample_count
        sampled_older = [older_records[int(i * step)] for i in range(sample_count)]

        return sorted(recent_records + sampled_older, reverse=True)

    def get_forgetting_curve(
        self,
        initial_time: datetime,
        importance_score: float = 0.5,
        days: int = 60
    ) -> List[Dict[str, Any]]:
        """
        Generate forgetting curve data for visualization.

        Simulates how activation decays, day by day, for a memory accessed
        once at ``initial_time``. Useful for predicting activation (to decide
        when to review) and for comparing configurations, e.g. picking d.

        Args:
            initial_time: Time of initial memory creation/access
            importance_score: Importance weight (0 to 1, default 0.5)
            days: Number of days to simulate (default 60)

        Returns:
            List of dictionaries with keys:
                - 'day': Day number (0 to days)
                - 'activation': Memory activation value
                - 'retention_rate': Same as activation (for compatibility)
        """
        curve_data = []
        access_history = [initial_time]

        for day in range(days + 1):
            current_time = initial_time + timedelta(days=day)

            try:
                activation = self.calculate_memory_activation(
                    access_history=access_history,
                    current_time=current_time,
                    last_access_time=initial_time,
                    importance_score=importance_score
                )
            except ValueError:
                # Deliberate best effort: an invalid importance_score gives a
                # flat curve at the baseline instead of an error.
                activation = self.offset

            curve_data.append({
                'day': day,
                'activation': activation,
                'retention_rate': activation  # Alias for compatibility
            })

        return curve_data

    def calculate_forgetting_score(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5
    ) -> float:
        """
        Calculate forgetting score (1 - activation).

        Higher score means more likely to be forgotten.

        Args:
            access_history: List of access timestamps
            current_time: Current time for calculation
            last_access_time: Time of most recent access
            importance_score: Importance weight (0 to 1, default 0.5)

        Returns:
            float: Forgetting score between 0 and (1 - offset)
        """
        activation = self.calculate_memory_activation(
            access_history=access_history,
            current_time=current_time,
            last_access_time=last_access_time,
            importance_score=importance_score
        )
        return 1.0 - activation

    def should_forget(
        self,
        access_history: List[datetime],
        current_time: datetime,
        last_access_time: datetime,
        importance_score: float = 0.5,
        threshold: float = 0.3
    ) -> bool:
        """
        Determine if a memory should be forgotten based on activation threshold.

        Args:
            access_history: List of access timestamps
            current_time: Current time for calculation
            last_access_time: Time of most recent access
            importance_score: Importance weight (0 to 1, default 0.5)
            threshold: Activation threshold below which memory should be forgotten

        Returns:
            bool: True if activation < threshold (should forget), False otherwise
        """
        activation = self.calculate_memory_activation(
            access_history=access_history,
            current_time=current_time,
            last_access_time=last_access_time,
            importance_score=importance_score
        )
        return activation < threshold


# Convenience functions for quick calculations
def calculate_activation(
    access_history: List[datetime],
    current_time: datetime,
    last_access_time: datetime,
    importance_score: float = 0.5,
    decay_constant: float = 0.5,
    forgetting_rate: float = 0.3,
    offset: float = 0.1
) -> float:
    """
    Calculate activation without creating a calculator instance yourself.

    Args:
        access_history: List of access timestamps
        current_time: Current time for calculation
        last_access_time: Time of most recent access
        importance_score: Importance weight (0 to 1, default 0.5)
        decay_constant: Decay parameter d (default 0.5)
        forgetting_rate: Forgetting rate λ (default 0.3)
        offset: Minimum retention rate (default 0.1)

    Returns:
        float: Memory activation value between offset and 1.0
    """
    calculator = ACTRCalculator(
        decay_constant=decay_constant,
        forgetting_rate=forgetting_rate,
        offset=offset
    )
    return calculator.calculate_memory_activation(
        access_history=access_history,
        current_time=current_time,
        last_access_time=last_access_time,
        importance_score=importance_score
    )


def generate_forgetting_curve(
    initial_time: datetime,
    importance_score: float = 0.5,
    days: int = 60,
    decay_constant: float = 0.5,
    forgetting_rate: float = 0.3,
    offset: float = 0.1
) -> List[Dict[str, Any]]:
    """
    Generate forgetting curve data without creating a calculator yourself.

    Args:
        initial_time: Time of initial memory creation/access
        importance_score: Importance weight (0 to 1, default 0.5)
        days: Number of days to simulate (default 60)
        decay_constant: Decay parameter d (default 0.5)
        forgetting_rate: Forgetting rate λ (default 0.3)
        offset: Minimum retention rate (default 0.1)

    Returns:
        List of dictionaries with forgetting curve data
    """
    calculator = ACTRCalculator(
        decay_constant=decay_constant,
        forgetting_rate=forgetting_rate,
        offset=offset
    )
    return calculator.get_forgetting_curve(
        initial_time=initial_time,
        importance_score=importance_score,
        days=days
    )
def load_actr_config_from_db(
    db: Session,
    config_id: Optional[int] = None
) -> Dict[str, Any]:
    """
    Load the ACT-R / forgetting parameters of one configuration row.

    Reads the row through ``DataConfigRepository.get_by_id`` and adds the
    derived ``forgetting_rate`` (lambda_time / lambda_mem).

    Args:
        db: Database session.
        config_id: Configuration ID. Required in practice: passing None
            raises ValueError, no defaults are substituted.

    Returns:
        Dict[str, Any]: Dictionary with the keys
            - decay_constant: decay constant d
            - lambda_time: time decay parameter
            - lambda_mem: memory decay parameter
            - forgetting_rate: derived from lambda_time / lambda_mem
            - offset: minimum retention offset
            - max_history_length: maximum access-history length
            - forgetting_threshold: forgetting threshold
            - min_days_since_access: minimum days without access
            - enable_llm_summary: whether to generate summaries with an LLM
            - max_merge_batch_size: maximum node pairs merged per run
            - forgetting_interval_hours: interval between forgetting cycles

        ``llm_id`` is intentionally not part of the returned dictionary.

    Raises:
        ValueError: If config_id is None or no such configuration exists.
    """
    if config_id is None:
        logger.error("未指定 config_id,无法加载配置")
        raise ValueError("config_id 不能为空,必须指定一个有效的配置 ID")

    # Columns copied verbatim from the row, in the order they are read.
    column_names = (
        'lambda_time',
        'lambda_mem',
        'decay_constant',
        'offset',
        'max_history_length',
        'forgetting_threshold',
        'min_days_since_access',
        'enable_llm_summary',
        'max_merge_batch_size',
        'forgetting_interval_hours',
    )

    try:
        db_config = DataConfigRepository().get_by_id(db, config_id)

        if db_config is None:
            logger.error(f"配置不存在: config_id={config_id}")
            raise ValueError(f"配置不存在: config_id={config_id}")

        # Values are taken as stored; database defaults are trusted.
        row = {column: getattr(db_config, column) for column in column_names}

        forgetting_rate = calculate_forgetting_rate(
            row['lambda_time'], row['lambda_mem']
        )

        config = {
            'decay_constant': row['decay_constant'],
            'lambda_time': row['lambda_time'],
            'lambda_mem': row['lambda_mem'],
            'forgetting_rate': forgetting_rate,
        }
        # Remaining columns follow in the same order as column_names[3:].
        config.update((column, row[column]) for column in column_names[3:])

        logger.info(
            f"成功加载 ACT-R 配置: config_id={config_id}, "
            f"forgetting_rate={forgetting_rate:.4f}"
        )

        return config

    except Exception as e:
        # Log at this boundary and let the caller decide how to react.
        logger.error(f"加载 ACT-R 配置失败: config_id={config_id}, 错误: {str(e)}")
        raise
class ForgettingScheduler:
    """
    Forgetting scheduler.

    Runs one forgetting cycle at a time: finds forgettable Statement/Entity
    pairs, processes the lowest-activation pairs first, caps the batch size,
    logs progress roughly every 10% and returns a report.

    Periodic scheduling is not handled here (it lives in a Celery Beat task).

    Attributes:
        forgetting_strategy: Forgetting strategy executor instance
        connector: Neo4j connector instance
        is_running: Whether a forgetting cycle is currently running
    """

    def __init__(
        self,
        forgetting_strategy: ForgettingStrategy,
        connector: Neo4jConnector
    ):
        """
        Initialize the scheduler.

        Args:
            forgetting_strategy: Forgetting strategy executor instance
            connector: Neo4j connector instance
        """
        self.forgetting_strategy = forgetting_strategy
        self.connector = connector
        self.is_running = False

        logger.info("初始化遗忘调度器")

    async def run_forgetting_cycle(
        self,
        group_id: Optional[str] = None,
        max_merge_batch_size: int = 100,
        min_days_since_access: int = 30,
        config_id: Optional[int] = None,
        db = None
    ) -> Dict[str, Any]:
        """
        Run one complete forgetting cycle.

        Args:
            group_id: Optional group ID used to restrict the cycle to one group
            max_merge_batch_size: Maximum node pairs taken per run (default 100)
            min_days_since_access: Minimum days without access (default 30)
            config_id: Optional configuration ID, forwarded to the strategy
            db: Optional database session, forwarded to the strategy

        Returns:
            Dict[str, Any]: Report with the keys
                - merged_count: number of merged node pairs
                - nodes_before: node count before the cycle
                - nodes_after: node count after the cycle
                - reduction_rate: (nodes_before - nodes_after) / nodes_before
                - duration_seconds: elapsed seconds
                - start_time / end_time: ISO-format timestamps
                - failed_count: number of failed merges
                - success_rate: merged_count / pairs attempted after
                  de-duplication (1.0 when nothing was attempted)

        Raises:
            RuntimeError: If a cycle is already running on this instance
        """
        if self.is_running:
            raise RuntimeError("遗忘周期已在运行中,请等待当前周期完成")

        self.is_running = True
        start_time = datetime.now()

        logger.info(
            f"开始遗忘周期: group_id={group_id}, "
            f"max_batch={max_merge_batch_size}, "
            f"min_days={min_days_since_access}"
        )

        try:
            # Step 1: node count before the cycle.
            nodes_before = await self._count_knowledge_nodes(group_id)
            logger.info(f"遗忘前节点总数: {nodes_before}")

            # Step 2: find candidate pairs.
            forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
                group_id=group_id,
                min_days_since_access=min_days_since_access
            )

            total_forgettable = len(forgettable_pairs)
            logger.info(f"识别到 {total_forgettable} 个可遗忘节点对")

            if total_forgettable == 0:
                logger.info("没有可遗忘的节点对,遗忘周期结束")
                return self._build_report(
                    start_time=start_time,
                    nodes_before=nodes_before,
                    nodes_after=nodes_before,
                    merged_count=0,
                    failed_count=0,
                    attempted_count=0
                )

            # Step 3: lowest average activation first.
            sorted_pairs = sorted(
                forgettable_pairs,
                key=lambda x: x['avg_activation']
            )

            # Step 4: cap the batch.
            pairs_to_process = sorted_pairs[:max_merge_batch_size]
            actual_batch_size = len(pairs_to_process)

            logger.info(
                f"将处理 {actual_batch_size} 个节点对 "
                f"(限制: {max_merge_batch_size})"
            )

            # A node can appear in several pairs; only its first pair is
            # kept so no merge targets an already-merged node.
            unique_pairs, skipped_count = self._deduplicate_pairs(pairs_to_process)

            logger.info(
                f"预过滤完成:原始 {actual_batch_size} 对,去重后 {len(unique_pairs)} 对,"
                f"跳过 {skipped_count} 对重复节点"
            )

            actual_batch_size = len(unique_pairs)
            progress_interval = max(1, actual_batch_size // 10)

            # Step 5: merge pair by pair; one failure never stops the batch.
            merged_count = 0
            failed_count = 0

            for idx, pair in enumerate(unique_pairs, start=1):
                statement_id = pair['statement_id']
                entity_id = pair['entity_id']

                try:
                    await self._merge_pair(pair, group_id, config_id, db)
                    merged_count += 1
                except Exception as e:
                    failed_count += 1
                    # The strategy reports vanished nodes with this phrase.
                    if "nodes may not exist" in str(e):
                        logger.warning(
                            f"节点对 ({idx}/{actual_batch_size}) 的节点不存在(可能已被其他操作删除): "
                            f"Statement[{statement_id}] + Entity[{entity_id}]"
                        )
                    else:
                        logger.error(
                            f"融合节点对失败 ({idx}/{actual_batch_size}): "
                            f"Statement[{statement_id}] + Entity[{entity_id}], "
                            f"错误: {str(e)}"
                        )

                # Logged whether the pair succeeded or failed, so the
                # final line is always emitted.
                if idx % progress_interval == 0 or idx == actual_batch_size:
                    progress_pct = (idx / actual_batch_size) * 100
                    logger.info(
                        f"遗忘进度: {idx}/{actual_batch_size} "
                        f"({progress_pct:.1f}%), "
                        f"已融合: {merged_count}, 失败: {failed_count}"
                    )

            # Step 6: node count after the cycle.
            nodes_after = await self._count_knowledge_nodes(group_id)
            logger.info(f"遗忘后节点总数: {nodes_after}")

            # Step 7: report.
            report = self._build_report(
                start_time=start_time,
                nodes_before=nodes_before,
                nodes_after=nodes_after,
                merged_count=merged_count,
                failed_count=failed_count,
                attempted_count=actual_batch_size
            )

            logger.info(
                f"遗忘周期完成: "
                f"融合 {merged_count} 对节点, "
                f"失败 {failed_count} 对, "
                f"节点减少 {nodes_before - nodes_after} 个 "
                f"({report['reduction_rate']:.2%}), "
                f"耗时 {report['duration_seconds']:.2f} 秒"
            )

            return report

        except Exception as e:
            logger.error(f"遗忘周期执行失败: {str(e)}")
            raise

        finally:
            self.is_running = False

    # ==================== Private helpers ====================

    @staticmethod
    def _build_report(
        start_time: datetime,
        nodes_before: int,
        nodes_after: int,
        merged_count: int,
        failed_count: int,
        attempted_count: int
    ) -> Dict[str, Any]:
        """Assemble the cycle report, stamping the end time at call time."""
        end_time = datetime.now()

        if nodes_before > 0:
            reduction_rate = (nodes_before - nodes_after) / nodes_before
        else:
            reduction_rate = 0.0

        if attempted_count > 0:
            success_rate = merged_count / attempted_count
        else:
            success_rate = 1.0

        return {
            'merged_count': merged_count,
            'nodes_before': nodes_before,
            'nodes_after': nodes_after,
            'reduction_rate': reduction_rate,
            'duration_seconds': (end_time - start_time).total_seconds(),
            'start_time': start_time.isoformat(),
            'end_time': end_time.isoformat(),
            'failed_count': failed_count,
            'success_rate': success_rate
        }

    @staticmethod
    def _deduplicate_pairs(pairs):
        """Keep the first pair per statement/entity; return (kept, skipped count)."""
        seen_statement_ids = set()
        seen_entity_ids = set()
        unique_pairs = []
        skipped_count = 0

        for pair in pairs:
            statement_id = pair['statement_id']
            entity_id = pair['entity_id']

            if statement_id in seen_statement_ids or entity_id in seen_entity_ids:
                skipped_count += 1
                logger.debug(
                    f"预过滤:跳过重复节点对 Statement[{statement_id}] + Entity[{entity_id}]"
                )
                continue

            seen_statement_ids.add(statement_id)
            seen_entity_ids.add(entity_id)
            unique_pairs.append(pair)

        return unique_pairs, skipped_count

    async def _merge_pair(self, pair, group_id, config_id, db):
        """Build the node payloads for one pair and merge them via the strategy."""
        statement_node = {
            'statement_id': pair['statement_id'],
            'statement_text': pair['statement_text'],
            'statement_activation': pair['statement_activation'],
            'statement_importance': pair['statement_importance'],
            'group_id': group_id
        }

        entity_node = {
            'entity_id': pair['entity_id'],
            'entity_name': pair['entity_name'],
            'entity_type': pair['entity_type'],
            'entity_activation': pair['entity_activation'],
            'entity_importance': pair['entity_importance'],
            'group_id': group_id
        }

        await self.forgetting_strategy.merge_nodes_to_summary(
            statement_node=statement_node,
            entity_node=entity_node,
            config_id=config_id,
            db=db
        )

    async def _count_knowledge_nodes(
        self,
        group_id: Optional[str] = None
    ) -> int:
        """
        Count Statement, ExtractedEntity and MemorySummary nodes.

        Args:
            group_id: Optional group ID used to restrict the count

        Returns:
            int: Number of matching nodes (0 when the query returns no rows)
        """
        query = """
        MATCH (n)
        WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
        """

        if group_id:
            query += " AND n.group_id = $group_id"

        query += """
        RETURN count(n) as total
        """

        params = {}
        if group_id:
            params['group_id'] = group_id

        results = await self.connector.execute_query(query, **params)

        if results:
            return results[0]['total']
        return 0
def __init__(
    self,
    connector: Neo4jConnector,
    actr_calculator: ACTRCalculator,
    forgetting_threshold: float = 0.3,
    enable_llm_summary: bool = True
):
    """
    Store the collaborators and settings of the forgetting strategy.

    Args:
        connector: Neo4j connector instance
        actr_calculator: ACT-R activation calculator instance
        forgetting_threshold: Nodes with activation below this value are
            candidates for forgetting (default 0.3)
        enable_llm_summary: Whether LLM summary generation is enabled
            (default True)
    """
    self.forgetting_threshold = forgetting_threshold
    self.enable_llm_summary = enable_llm_summary
    self.actr_calculator = actr_calculator
    self.connector = connector

    logger.info(
        f"初始化遗忘策略执行器: threshold={forgetting_threshold}, "
        f"enable_llm_summary={enable_llm_summary}"
    )


async def calculate_forgetting_score(
    self,
    activation_value: float
) -> float:
    """
    Return the forgetting score, defined as ``1 - activation_value``.

    Args:
        activation_value: Activation value of the node

    Returns:
        float: Forgetting score; higher means more likely to be forgotten
    """
    forgetting_score = 1.0 - activation_value
    return forgetting_score


async def find_forgettable_nodes(
    self,
    group_id: Optional[str] = None,
    min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
    """
    Find Statement-Entity pairs that are candidates for forgetting.

    A pair qualifies when the two nodes are connected by any relationship,
    both have a non-null activation below ``self.forgetting_threshold``,
    both have ``last_access_time`` earlier than the cutoff, and the entity's
    type is not 'Person'.

    Args:
        group_id: Optional group ID; when truthy both nodes must belong to it
        min_days_since_access: Minimum days without access (default 30)

    Returns:
        List[Dict[str, Any]]: Rows ordered by ascending ``avg_activation``,
        each holding statement_id, statement_text, statement_activation,
        statement_importance, statement_last_access, entity_id, entity_name,
        entity_type, entity_activation, entity_importance,
        entity_last_access and avg_activation.
    """
    # Cutoff is sent as an ISO string.
    # NOTE(review): presumably last_access_time is stored as an ISO string
    # too, so the comparison is lexical — confirm against the writer.
    cutoff_time_iso = (
        datetime.now() - timedelta(days=min_days_since_access)
    ).isoformat()

    params = {
        'threshold': self.forgetting_threshold,
        'cutoff_time': cutoff_time_iso
    }

    match_and_filter = """
    MATCH (s:Statement)-[r]-(e:ExtractedEntity)
    WHERE s.activation_value IS NOT NULL
      AND e.activation_value IS NOT NULL
      AND s.activation_value < $threshold
      AND e.activation_value < $threshold
      AND s.last_access_time < $cutoff_time
      AND e.last_access_time < $cutoff_time
      AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
    """

    group_filter = ""
    if group_id:
        group_filter = " AND s.group_id = $group_id AND e.group_id = $group_id"
        params['group_id'] = group_id

    projection = """
    RETURN s.id as statement_id,
           s.statement as statement_text,
           s.activation_value as statement_activation,
           s.importance_score as statement_importance,
           s.last_access_time as statement_last_access,
           e.id as entity_id,
           e.name as entity_name,
           e.entity_type as entity_type,
           e.activation_value as entity_activation,
           e.importance_score as entity_importance,
           e.last_access_time as entity_last_access,
           (s.activation_value + e.activation_value) / 2.0 as avg_activation
    ORDER BY avg_activation ASC
    """

    query = match_and_filter + group_filter + projection

    results = await self.connector.execute_query(query, **params)

    logger.info(
        f"识别到 {len(results)} 个可遗忘节点对 "
        f"(threshold={self.forgetting_threshold}, "
        f"min_days={min_days_since_access})"
    )

    return results
entity_node['entity_type'] + entity_activation = entity_node['entity_activation'] + entity_importance = entity_node['entity_importance'] + + # 生成摘要内容 + summary_text = await self._generate_summary( + statement_text=statement_text, + entity_name=entity_name, + entity_type=entity_type, + config_id=config_id, + db=db + ) + + # 计算继承的激活值和重要性(取较高值) + inherited_activation = max(statement_activation, entity_activation) + inherited_importance = max(statement_importance, entity_importance) + + # 创建 MemorySummary 节点 + current_time = datetime.now() + current_time_iso = current_time.isoformat() + + # 生成新的 MemorySummary ID + import uuid + summary_id = f"summary_{uuid.uuid4().hex[:16]}" + + # 获取 group_id(从 statement 或 entity 节点) + group_id = statement_node.get('group_id') or entity_node.get('group_id') + + # 使用事务创建 MemorySummary 并删除原节点 + async def merge_transaction(tx, **params): + """事务函数:创建摘要节点并删除原节点""" + query = """ + // 首先检查节点是否存在 + OPTIONAL MATCH (s:Statement {id: $statement_id}) + OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id}) + + // 如果任一节点不存在,直接返回 null(不执行后续操作) + WITH s, e + WHERE s IS NOT NULL AND e IS NOT NULL + + // 创建 MemorySummary 节点 + CREATE (ms:MemorySummary { + id: $summary_id, + summary: $summary_text, + original_statement_id: $statement_id, + original_entity_id: $entity_id, + activation_value: $inherited_activation, + importance_score: $inherited_importance, + access_history: [$current_time], + last_access_time: $current_time, + access_count: 1, + version: 1, + group_id: $group_id, + created_at: datetime($current_time), + merged_at: datetime($current_time) + }) + + // 转移 Statement 的出边到 MemorySummary(只转移目标节点仍存在的边) + WITH ms, s, e + CALL (ms, s, e) { + OPTIONAL MATCH (s)-[r_out]->(target) + WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL + FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END | + MERGE (ms)-[new_rel:DERIVED_FROM]->(target) + ON CREATE SET + new_rel = properties(r_out), + new_rel.original_relationship_type = type(r_out), 
+ new_rel.merged_from_statement = true, + new_rel.merge_count = 1 + ON MATCH SET + new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1 + ) + } + + // 转移 Statement 的入边到 MemorySummary(只转移源节点仍存在的边) + WITH ms, s, e + CALL (ms, s, e) { + OPTIONAL MATCH (source)-[r_in]->(s) + WHERE r_in IS NOT NULL AND source IS NOT NULL + FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END | + MERGE (source)-[new_rel:DERIVED_FROM]->(ms) + ON CREATE SET + new_rel = properties(r_in), + new_rel.original_relationship_type = type(r_in), + new_rel.merged_from_statement = true, + new_rel.merge_count = 1 + ON MATCH SET + new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1 + ) + } + + // 转移 Entity 的出边到 MemorySummary(只转移目标节点仍存在的边) + WITH ms, s, e + CALL (ms, s, e) { + OPTIONAL MATCH (e)-[r_out]->(target) + WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL + FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END | + MERGE (ms)-[new_rel:DERIVED_FROM]->(target) + ON CREATE SET + new_rel = properties(r_out), + new_rel.original_relationship_type = type(r_out), + new_rel.merged_from_entity = true, + new_rel.merge_count = 1 + ON MATCH SET + new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1 + ) + } + + // 转移 Entity 的入边到 MemorySummary(只转移源节点仍存在的边) + WITH ms, s, e + CALL (ms, s, e) { + OPTIONAL MATCH (source)-[r_in]->(e) + WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL + FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END | + MERGE (source)-[new_rel:DERIVED_FROM]->(ms) + ON CREATE SET + new_rel = properties(r_in), + new_rel.original_relationship_type = type(r_in), + new_rel.merged_from_entity = true, + new_rel.merge_count = 1 + ON MATCH SET + new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1 + ) + } + + // 删除原始节点 + WITH ms, s, e + DETACH DELETE s, e + + RETURN ms.id as summary_id + """ + + result = await tx.run(query, **params) + record = await result.single() + + if not record: + raise RuntimeError("Failed to 
create MemorySummary node - nodes may not exist") + + return record['summary_id'] + + params = { + 'summary_id': summary_id, + 'summary_text': summary_text, + 'statement_id': statement_id, + 'entity_id': entity_id, + 'inherited_activation': inherited_activation, + 'inherited_importance': inherited_importance, + 'current_time': current_time_iso, + 'group_id': group_id + } + + try: + created_summary_id = await self.connector.execute_write_transaction( + merge_transaction, + **params + ) + + logger.info( + f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] " + f"-> MemorySummary[{created_summary_id}], " + f"activation={inherited_activation:.4f}, " + f"importance={inherited_importance:.4f}" + ) + + return created_summary_id + + except Exception as e: + # 记录详细的错误信息,包括异常类型和堆栈 + import traceback + error_details = traceback.format_exc() + logger.error( + f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], " + f"错误类型: {type(e).__name__}, " + f"错误信息: {str(e)}, " + f"详细堆栈:\n{error_details}" + ) + raise RuntimeError( + f"融合节点失败: {str(e)}" + ) from e + + # ==================== 私有辅助方法 ==================== + + async def _generate_summary( + self, + statement_text: str, + entity_name: str, + entity_type: str, + config_id: Optional[int] = None, + db = None + ) -> str: + """ + 生成摘要内容 + + 优先使用 LLM 生成高质量摘要,如果 LLM 不可用或失败, + 则降级到简单文本拼接。 + + Args: + statement_text: Statement 文本内容 + entity_name: Entity 名称 + entity_type: Entity 类型 + config_id: 配置ID(可选,用于获取 llm_id) + db: 数据库会话(可选,用于获取 llm_id) + + Returns: + str: 生成的摘要文本(最多 200 个字符) + """ + # 如果配置禁用 LLM 摘要,直接使用简单拼接 + if not self.enable_llm_summary: + logger.info("LLM 摘要生成已禁用,使用简单拼接") + return self._simple_concatenation( + statement_text, entity_name, entity_type + ) + + # 尝试获取 LLM 客户端 + llm_client = None + if config_id is not None and db is not None: + try: + llm_client = await self._get_llm_client(db, config_id) + except Exception as e: + logger.warning(f"获取 LLM 客户端失败: {str(e)}") + + # 如果没有 LLM 客户端,直接使用简单拼接 + if llm_client is 
None: + logger.info("未能获取 LLM 客户端,使用简单拼接") + return self._simple_concatenation( + statement_text, entity_name, entity_type + ) + + # 尝试使用 LLM 生成摘要 + try: + summary = await self._generate_llm_summary( + statement_text=statement_text, + entity_name=entity_name, + entity_type=entity_type, + llm_client=llm_client + ) + + # 限制长度为 200 个字符 + if len(summary) > 200: + summary = f"{summary[:197]}..." + + logger.info(f"使用 LLM 生成摘要: {summary}") + return summary + + except Exception as e: + logger.warning( + f"LLM 摘要生成失败,降级到简单拼接: {str(e)}" + ) + return self._simple_concatenation( + statement_text, entity_name, entity_type + ) + + async def _get_llm_client(self, db, config_id: int): + """ + 从数据库获取 LLM 客户端 + + Args: + db: 数据库会话 + config_id: 配置ID + + Returns: + LLM 客户端实例,如果无法获取则返回 None + """ + try: + from app.repositories.data_config_repository import DataConfigRepository + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # 从数据库读取配置 + repository = DataConfigRepository() + db_config = repository.get_by_id(db, config_id) + + if db_config is None or db_config.llm_id is None: + logger.warning(f"配置 {config_id} 不存在或未设置 llm_id") + return None + + # 创建 LLM 客户端 + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(str(db_config.llm_id)) + + logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={db_config.llm_id}") + return llm_client + + except Exception as e: + logger.error(f"获取 LLM 客户端失败: {str(e)}") + return None + + async def _generate_llm_summary( + self, + statement_text: str, + entity_name: str, + entity_type: str, + llm_client + ) -> str: + """ + 使用 LLM 生成高质量摘要 + + Args: + statement_text: Statement 文本内容 + entity_name: Entity 名称 + entity_type: Entity 类型 + llm_client: LLM 客户端实例 + + Returns: + str: LLM 生成的摘要文本 + + Raises: + Exception: 如果 LLM 调用失败 + """ + # 构建提示词 + prompt = f"""请为以下记忆片段生成一个简洁的摘要(不超过200个字符): + +实体名称: {entity_name} +实体类型: {entity_type} +陈述内容: {statement_text} + +要求: +1. 摘要应该保留核心语义信息 +2. 长度不超过200个字符 +3. 使用简洁、自然的中文表达 +4. 
只返回摘要文本,不要包含其他内容 + +摘要:""" + + # 调用 LLM(直接传递 prompt 字符串) + response = await llm_client.chat(prompt) + + # 提取摘要文本 + if isinstance(response, str): + summary = response.strip() + elif hasattr(response, 'content'): + summary = response.content.strip() + else: + summary = str(response).strip() + + return summary + + def _simple_concatenation( + self, + statement_text: str, + entity_name: str, + entity_type: str + ) -> str: + """ + 简单文本拼接生成摘要 + + 降级策略:当 LLM 不可用时使用。 + 格式:[实体类型]实体名称: 陈述内容 + + Args: + statement_text: Statement 文本内容 + entity_name: Entity 名称 + entity_type: Entity 类型 + + Returns: + str: 拼接的摘要文本(最多 200 个字符) + """ + # 构建简单摘要 + summary = f"[{entity_type}]{entity_name}: {statement_text}" + + # 限制长度为 200 个字符(注意:这里的长度是字符数,不是字节数) + if len(summary) > 200: + # 截断并添加省略号 + summary = f"{summary[:197]}..." + + return summary + diff --git a/api/app/models/data_config_model.py b/api/app/models/data_config_model.py index 2914432e..67d789ea 100644 --- a/api/app/models/data_config_model.py +++ b/api/app/models/data_config_model.py @@ -65,6 +65,15 @@ class DataConfig(Base): lambda_mem = Column("lambda_mem", Float, default=0.5, comment="遗忘率,0-1 小数") offset = Column("offset", Float, default=0.0, comment="偏移度,0-1 小数") + # ACT-R 遗忘引擎配置 + decay_constant = Column(Float, default=0.5, comment="ACT-R衰减常数d,默认0.5") + forgetting_threshold = Column(Float, default=0.3, comment="遗忘阈值,默认0.3") + forgetting_interval_hours = Column(Integer, default=24, comment="遗忘周期间隔(小时),默认24") + enable_llm_summary = Column(Boolean, default=True, comment="是否使用LLM生成摘要,默认True") + max_merge_batch_size = Column(Integer, default=100, comment="单次最大融合节点对数,默认100") + max_history_length = Column(Integer, default=100, comment="访问历史最大长度,默认100") + min_days_since_access = Column(Integer, default=30, comment="最小未访问天数,默认30") + # 情绪引擎配置 emotion_enabled = Column(Boolean, default=True, comment="是否启用情绪提取") emotion_model_id = Column(String, nullable=True, comment="情绪分析专用模型ID") diff --git a/api/app/repositories/neo4j/add_nodes.py 
b/api/app/repositories/neo4j/add_nodes.py index ce4a6876..79466fa0 100644 --- a/api/app/repositories/neo4j/add_nodes.py +++ b/api/app/repositories/neo4j/add_nodes.py @@ -106,7 +106,13 @@ async def add_statement_nodes(statements: List[StatementNode], connector: Neo4jC "emotion_intensity": statement.emotion_intensity, "emotion_keywords": statement.emotion_keywords if statement.emotion_keywords else [], "emotion_subject": statement.emotion_subject, - "emotion_target": statement.emotion_target + "emotion_target": statement.emotion_target, + # 添加 ACT-R 记忆激活属性 + "importance_score": statement.importance_score, + "activation_value": statement.activation_value, + "access_history": statement.access_history if statement.access_history else [], + "last_access_time": statement.last_access_time, + "access_count": statement.access_count } flattened_statements.append(flattened_statement) diff --git a/api/app/repositories/neo4j/cypher_queries.py b/api/app/repositories/neo4j/cypher_queries.py index 02e96694..259b1325 100644 --- a/api/app/repositories/neo4j/cypher_queries.py +++ b/api/app/repositories/neo4j/cypher_queries.py @@ -38,7 +38,12 @@ SET s += { valid_at: statement.valid_at, invalid_at: statement.invalid_at, statement_embedding: statement.statement_embedding, - relevence_info: statement.relevence_info + relevence_info: statement.relevence_info, + importance_score: statement.importance_score, + activation_value: statement.activation_value, + access_history: statement.access_history, + last_access_time: statement.last_access_time, + access_count: statement.access_count } RETURN s.id AS uuid """ @@ -111,7 +116,12 @@ SET e.name = CASE WHEN entity.name IS NOT NULL AND entity.name <> '' THEN entity WHEN e.connect_strength IS NULL OR e.connect_strength = '' THEN entity.connect_strength ELSE e.connect_strength END - END + END, + e.importance_score = CASE WHEN entity.importance_score IS NOT NULL THEN entity.importance_score ELSE coalesce(e.importance_score, 0.5) END, + 
e.activation_value = CASE WHEN entity.activation_value IS NOT NULL THEN entity.activation_value ELSE e.activation_value END, + e.access_history = CASE WHEN entity.access_history IS NOT NULL THEN entity.access_history ELSE coalesce(e.access_history, []) END, + e.last_access_time = CASE WHEN entity.last_access_time IS NOT NULL THEN entity.last_access_time ELSE e.last_access_time END, + e.access_count = CASE WHEN entity.access_count IS NOT NULL THEN entity.access_count ELSE coalesce(e.access_count, 0) END RETURN e.id AS uuid """ @@ -225,6 +235,10 @@ RETURN e.id AS id, e.name AS name, e.group_id AS group_id, e.entity_type AS entity_type, + COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, + COALESCE(e.importance_score, 0.5) AS importance_score, + e.last_access_time AS last_access_time, + COALESCE(e.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -243,6 +257,10 @@ RETURN s.id AS id, s.expired_at AS expired_at, s.valid_at AS valid_at, s.invalid_at AS invalid_at, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + s.last_access_time AS last_access_time, + COALESCE(s.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -258,6 +276,9 @@ RETURN c.id AS chunk_id, c.group_id AS group_id, c.content AS content, c.dialog_id AS dialog_id, + COALESCE(c.activation_value, 0.5) AS activation_value, + c.last_access_time AS last_access_time, + COALESCE(c.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -278,6 +299,10 @@ RETURN s.id AS id, s.invalid_at AS invalid_at, c.id AS chunk_id_from_rel, collect(DISTINCT e.id) AS entity_ids, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + s.last_access_time AS last_access_time, + COALESCE(s.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -305,6 
+330,10 @@ RETURN e.id AS id, e.connect_strength AS connect_strength, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT c.id) AS chunk_ids, + COALESCE(e.activation_value, e.importance_score, 0.5) AS activation_value, + COALESCE(e.importance_score, 0.5) AS importance_score, + e.last_access_time AS last_access_time, + COALESCE(e.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -322,6 +351,9 @@ RETURN c.id AS chunk_id, c.sequence_number AS sequence_number, collect(DISTINCT s.id) AS statement_ids, collect(DISTINCT e.id) AS entity_ids, + COALESCE(c.activation_value, 0.5) AS activation_value, + c.last_access_time AS last_access_time, + COALESCE(c.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit @@ -419,7 +451,11 @@ RETURN s.id AS id, s.created_at AS created_at, s.valid_at AS valid_at, s.invalid_at AS invalid_at, - collect(DISTINCT s.id) AS statement_ids + collect(DISTINCT s.id) AS statement_ids, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + s.last_access_time AS last_access_time, + COALESCE(s.access_count, 0) AS access_count ORDER BY datetime(s.created_at) DESC LIMIT $limit """ @@ -446,6 +482,10 @@ RETURN s.id AS id, s.invalid_at AS invalid_at, c.id AS chunk_id_from_rel, collect(DISTINCT e.id) AS entity_ids, + COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value, + COALESCE(s.importance_score, 0.5) AS importance_score, + s.last_access_time AS last_access_time, + COALESCE(s.access_count, 0) AS access_count, score ORDER BY s.created_at DESC, score DESC LIMIT $limit @@ -635,6 +675,10 @@ RETURN m.id AS id, m.chunk_ids AS chunk_ids, m.content AS content, m.created_at AS created_at, + COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value, + COALESCE(m.importance_score, 0.5) AS importance_score, + m.last_access_time AS last_access_time, + COALESCE(m.access_count, 0) AS access_count, score 
ORDER BY score DESC LIMIT $limit @@ -653,6 +697,10 @@ RETURN m.id AS id, m.chunk_ids AS chunk_ids, m.content AS content, m.created_at AS created_at, + COALESCE(m.activation_value, m.importance_score, 0.5) AS activation_value, + COALESCE(m.importance_score, 0.5) AS importance_score, + m.last_access_time AS last_access_time, + COALESCE(m.access_count, 0) AS access_count, score ORDER BY score DESC LIMIT $limit diff --git a/api/app/repositories/neo4j/entity_repository.py b/api/app/repositories/neo4j/entity_repository.py index 87088ade..cb18feca 100644 --- a/api/app/repositories/neo4j/entity_repository.py +++ b/api/app/repositories/neo4j/entity_repository.py @@ -55,6 +55,13 @@ class EntityRepository(BaseNeo4jRepository[ExtractedEntityNode]): if 'aliases' not in n or n['aliases'] is None: n['aliases'] = [] + # 处理 ACT-R 属性 - 确保字段存在且有默认值 + n['importance_score'] = n.get('importance_score', 0.5) + n['activation_value'] = n.get('activation_value') + n['access_history'] = n.get('access_history', []) + n['last_access_time'] = n.get('last_access_time') + n['access_count'] = n.get('access_count', 0) + return ExtractedEntityNode(**n) async def find_by_type(self, entity_type: str, limit: int = 100) -> List[ExtractedEntityNode]: diff --git a/api/app/repositories/neo4j/graph_search.py b/api/app/repositories/neo4j/graph_search.py index ab2b28ac..1549ef86 100644 --- a/api/app/repositories/neo4j/graph_search.py +++ b/api/app/repositories/neo4j/graph_search.py @@ -1,5 +1,6 @@ from typing import Any, Dict, List, Optional import asyncio +import logging # 使用新的仓储层 from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -24,6 +25,157 @@ from app.repositories.neo4j.cypher_queries import ( SEARCH_STATEMENTS_L_VALID_AT, ) +logger = logging.getLogger(__name__) + + +async def _update_activation_values_batch( + connector: Neo4jConnector, + nodes: List[Dict[str, Any]], + node_label: str, + group_id: Optional[str] = None, + max_retries: int = 3 +) -> List[Dict[str, Any]]: + """ + 
批量更新节点的激活值 + + 为提高性能,批量更新多个节点的访问历史和激活值。 + 使用重试机制处理更新失败的情况。 + + Args: + connector: Neo4j连接器 + nodes: 节点列表,每个节点必须包含 'id' 字段 + node_label: 节点标签(Statement, ExtractedEntity, MemorySummary) + group_id: 组ID(可选) + max_retries: 最大重试次数 + + Returns: + List[Dict[str, Any]]: 成功更新的节点列表 + """ + if not nodes: + return [] + + # 延迟导入以避免循环依赖 + from app.core.memory.storage_services.forgetting_engine.access_history_manager import AccessHistoryManager + from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator + + # 创建计算器和管理器实例 + actr_calculator = ACTRCalculator() + access_manager = AccessHistoryManager( + connector=connector, + actr_calculator=actr_calculator, + max_retries=max_retries + ) + + # 提取节点ID列表 + node_ids = [node.get('id') for node in nodes if node.get('id')] + + if not node_ids: + logger.warning(f"批量更新激活值:没有有效的节点ID") + return nodes + + # 批量记录访问 + try: + updated_nodes = await access_manager.record_batch_access( + node_ids=node_ids, + node_label=node_label, + group_id=group_id + ) + + logger.info( + f"批量更新激活值成功: {node_label}, " + f"更新数量={len(updated_nodes)}/{len(node_ids)}" + ) + + return updated_nodes + + except Exception as e: + logger.error( + f"批量更新激活值失败: {node_label}, 错误: {str(e)}" + ) + # 失败时返回原始节点列表 + return nodes + + +async def _update_search_results_activation( + connector: Neo4jConnector, + results: Dict[str, List[Dict[str, Any]]], + group_id: Optional[str] = None +) -> Dict[str, List[Dict[str, Any]]]: + """ + 更新搜索结果中所有知识节点的激活值 + + 对 Statement、ExtractedEntity、MemorySummary 节点进行批量激活值更新。 + ChunkNode 和 DialogueNode 不参与激活值更新(数据层隔离)。 + + Args: + connector: Neo4j连接器 + results: 搜索结果字典,包含不同类型节点的列表 + group_id: 组ID(可选) + + Returns: + Dict[str, List[Dict[str, Any]]]: 更新后的搜索结果 + """ + # 定义需要更新激活值的节点类型 + knowledge_node_types = { + 'statements': 'Statement', + 'entities': 'ExtractedEntity', + 'summaries': 'MemorySummary' + } + + # 并行更新所有类型的节点 + update_tasks = [] + update_keys = [] + + for key, label in knowledge_node_types.items(): + if key in 
results and results[key]: + update_tasks.append( + _update_activation_values_batch( + connector=connector, + nodes=results[key], + node_label=label, + group_id=group_id + ) + ) + update_keys.append(key) + + if not update_tasks: + return results + + # 并行执行所有更新 + update_results = await asyncio.gather(*update_tasks, return_exceptions=True) + + # 更新结果字典,保留原始搜索分数 + updated_results = results.copy() + for key, update_result in zip(update_keys, update_results): + if not isinstance(update_result, Exception): + # 更新成功,合并原始搜索结果和更新后的激活值数据 + # 保留原始的 score 字段(BM25/Embedding 分数) + original_nodes = results[key] + updated_nodes = update_result + + # 创建 ID 到原始节点的映射(用于快速查找 score) + original_map = {node.get('id'): node for node in original_nodes if node.get('id')} + + # 合并数据:激活值来自更新结果,score 来自原始结果 + merged_nodes = [] + for updated_node in updated_nodes: + node_id = updated_node.get('id') + if node_id and node_id in original_map: + # 保留原始的 score 字段 + original_score = original_map[node_id].get('score') + if original_score is not None: + updated_node['score'] = original_score + merged_nodes.append(updated_node) + + updated_results[key] = merged_nodes + else: + # 更新失败,记录错误但保留原始结果 + logger.warning( + f"更新 {key} 激活值失败: {str(update_result)}" + ) + + return updated_results + async def search_graph( connector: Neo4jConnector, @@ -36,6 +188,7 @@ async def search_graph( Search across Statements, Entities, Chunks, and Summaries using a free-text query. 
OPTIMIZED: Runs all queries in parallel using asyncio.gather() + INTEGRATED: Updates activation values for knowledge nodes before returning results - Statements: matches s.statement CONTAINS q - Entities: matches e.name CONTAINS q @@ -50,7 +203,7 @@ async def search_graph( include: List of categories to search (default: all) Returns: - Dictionary with search results per category + Dictionary with search results per category (with updated activation values) """ if include is None: include = ["statements", "chunks", "entities", "summaries"] @@ -106,6 +259,13 @@ async def search_graph( else: results[key] = result + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + return results @@ -121,6 +281,7 @@ async def search_graph_by_embedding( Embedding-based semantic search across Statements, Chunks, and Entities. OPTIMIZED: Runs all queries in parallel using asyncio.gather() + INTEGRATED: Updates activation values for knowledge nodes before returning results - Computes query embedding with the provided embedder_client - Ranks by cosine similarity in Cypher @@ -203,6 +364,16 @@ async def search_graph_by_embedding( else: results[key] = result + # 更新知识节点的激活值(Statement, ExtractedEntity, MemorySummary) + update_start = time.time() + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + update_time = time.time() - update_start + print(f"[PERF] Activation value updates took: {update_time:.4f}s") + return results async def get_dedup_candidates_for_entities( # 适配新版查询:使用全文索引按名称检索候选实体 connector: Neo4jConnector, @@ -304,6 +475,8 @@ async def search_graph_by_keyword_temporal( ) -> Dict[str, List[Any]]: """ Temporal keyword search across Statements. 
+ + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements containing query_text created between start_date and end_date - Optionally filters by group_id, apply_id, user_id @@ -326,7 +499,15 @@ async def search_graph_by_keyword_temporal( ) print(f"查询结果为:\n{statements}") - return {"statements": statements} + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_by_temporal( @@ -342,6 +523,8 @@ async def search_graph_by_temporal( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. + + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created between start_date and end_date - Optionally filters by group_id, apply_id, user_id @@ -362,7 +545,16 @@ async def search_graph_by_temporal( print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_TEMPORAL}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, start_date: {start_date}, end_date: {end_date}, valid_date: {valid_date}, invalid_date: {invalid_date}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_by_dialog_id( @@ -419,6 +611,8 @@ async def search_graph_by_created_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
+ + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - Optionally filters by group_id, apply_id, user_id @@ -436,7 +630,16 @@ async def search_graph_by_created_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_CREATED_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_by_valid_at( connector: Neo4jConnector, @@ -448,6 +651,8 @@ async def search_graph_by_valid_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. + + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - Optionally filters by group_id, apply_id, user_id @@ -465,7 +670,16 @@ async def search_graph_by_valid_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_BY_VALID_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_g_created_at( connector: Neo4jConnector, @@ -477,6 +691,8 @@ async def search_graph_g_created_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
+ + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - Optionally filters by group_id, apply_id, user_id @@ -494,7 +710,16 @@ async def search_graph_g_created_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_G_CREATED_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_g_valid_at( connector: Neo4jConnector, @@ -506,6 +731,8 @@ async def search_graph_g_valid_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. + + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - Optionally filters by group_id, apply_id, user_id @@ -523,7 +750,16 @@ async def search_graph_g_valid_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_G_VALID_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_l_created_at( connector: Neo4jConnector, @@ -535,6 +771,8 @@ async def search_graph_l_created_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. 
+ + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements created at created_at - Optionally filters by group_id, apply_id, user_id @@ -552,7 +790,16 @@ async def search_graph_l_created_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_L_CREATED_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, created_at: {created_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results async def search_graph_l_valid_at( connector: Neo4jConnector, @@ -564,6 +811,8 @@ async def search_graph_l_valid_at( ) -> Dict[str, List[Dict[str, Any]]]: """ Temporal search across Statements. + + INTEGRATED: Updates activation values for Statement nodes before returning results - Matches statements valid at valid_at - Optionally filters by group_id, apply_id, user_id @@ -581,4 +830,13 @@ async def search_graph_l_valid_at( print(f"查询语句为:\n{SEARCH_STATEMENTS_L_VALID_AT}") print(f"查询参数为:\n{{group_id: {group_id}, apply_id: {apply_id}, user_id: {user_id}, valid_at: {valid_at}, limit: {limit}}}") print(f"查询结果为:\n{statements}") - return {"statements": statements} + + # 更新 Statement 节点的激活值 + results = {"statements": statements} + results = await _update_search_results_activation( + connector=connector, + results=results, + group_id=group_id + ) + + return results diff --git a/api/app/repositories/neo4j/neo4j_connector.py b/api/app/repositories/neo4j/neo4j_connector.py index 642661d4..7c4b43b5 100644 --- a/api/app/repositories/neo4j/neo4j_connector.py +++ b/api/app/repositories/neo4j/neo4j_connector.py @@ -8,7 +8,6 @@ Classes: Neo4jConnector: Neo4j数据库连接器,提供异步查询接口 """ -import os from typing import Any, List, Dict from neo4j import AsyncGraphDatabase, basic_auth @@ -85,6 +84,63 @@ class 
Neo4jConnector: records, summary, keys = result return [record.data() for record in records] + async def execute_write_transaction(self, transaction_func, **kwargs: Any) -> Any: + """在写事务中执行操作 + + 提供显式事务支持,确保操作的原子性。 + 如果事务函数抛出异常,所有更改将自动回滚。 + + Args: + transaction_func: 事务函数,接收 tx 参数并执行查询 + **kwargs: 传递给事务函数的额外参数 + + Returns: + Any: 事务函数的返回值 + + Example: + >>> async def create_node(tx, name): + ... result = await tx.run( + ... "CREATE (n:Person {name: $name}) RETURN n", + ... name=name + ... ) + ... return await result.single() + >>> + >>> connector = Neo4jConnector() + >>> result = await connector.execute_write_transaction( + ... create_node, name="Alice" + ... ) + """ + async with self.driver.session(database="neo4j") as session: + return await session.execute_write(transaction_func, **kwargs) + + async def execute_read_transaction(self, transaction_func, **kwargs: Any) -> Any: + """在读事务中执行操作 + + 提供显式事务支持用于读操作。 + + Args: + transaction_func: 事务函数,接收 tx 参数并执行查询 + **kwargs: 传递给事务函数的额外参数 + + Returns: + Any: 事务函数的返回值 + + Example: + >>> async def get_node(tx, name): + ... result = await tx.run( + ... "MATCH (n:Person {name: $name}) RETURN n", + ... name=name + ... ) + ... return await result.single() + >>> + >>> connector = Neo4jConnector() + >>> result = await connector.execute_read_transaction( + ... get_node, name="Alice" + ... 
) + """ + async with self.driver.session(database="neo4j") as session: + return await session.execute_read(transaction_func, **kwargs) + async def delete_group(self, group_id: str): """删除指定组的所有数据 diff --git a/api/app/repositories/neo4j/statement_repository.py b/api/app/repositories/neo4j/statement_repository.py index 34858444..22343e10 100644 --- a/api/app/repositories/neo4j/statement_repository.py +++ b/api/app/repositories/neo4j/statement_repository.py @@ -75,6 +75,13 @@ class StatementRepository(BaseNeo4jRepository[StatementNode]): n['emotion_subject'] = n.get('emotion_subject') n['emotion_target'] = n.get('emotion_target') + # 处理 ACT-R 属性 - 确保字段存在且有默认值 + n['importance_score'] = n.get('importance_score', 0.5) + n['activation_value'] = n.get('activation_value') + n['access_history'] = n.get('access_history', []) + n['last_access_time'] = n.get('last_access_time') + n['access_count'] = n.get('access_count', 0) + return StatementNode(**n) async def find_by_chunk_id(self, chunk_id: str) -> List[StatementNode]: diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 33d0d097..24747c34 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -399,3 +399,104 @@ class GenerateCacheRequest(BaseModel): None, description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成" ) + + +# ============================================================================ +# 遗忘引擎相关 Schema +# ============================================================================ + +class ForgettingTriggerRequest(BaseModel): + """手动触发遗忘周期请求模型""" + model_config = ConfigDict(populate_by_name=True, extra="forbid") + + group_id: Optional[str] = Field(None, description="组ID(可选,用于过滤特定组的节点)") + max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)") + min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)") + config_id: Optional[int] = Field(None, 
class ForgettingTriggerRequest(BaseModel):
    """Request model for manually triggering a forgetting cycle."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Optional group filter; when omitted no group restriction is requested.
    group_id: Optional[str] = Field(None, description="组ID(可选,用于过滤特定组的节点)")
    # Upper bound on node pairs merged in one run (1..1000, default 100).
    max_merge_batch_size: int = Field(100, ge=1, le=1000, description="单次最大融合节点对数(默认100)")
    # Minimum number of days without access (1..365, default 30).
    min_days_since_access: int = Field(30, ge=1, le=365, description="最小未访问天数(默认30天)")
    # TODO: once group_id is replaced by enduser_id (linked to config_id automatically), remove this field.
    config_id: Optional[int] = Field(None, description="配置ID(可选,用于指定遗忘引擎配置)")


class ForgettingConfigResponse(BaseModel):
    """Response model describing a forgetting-engine configuration."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    config_id: int = Field(..., description="配置ID")
    decay_constant: float = Field(..., description="衰减常数 d")
    lambda_time: float = Field(..., description="时间衰减参数")
    lambda_mem: float = Field(..., description="记忆衰减参数")
    # Per the field description: derived from lambda_time / lambda_mem.
    forgetting_rate: float = Field(..., description="遗忘速率(根据 lambda_time / lambda_mem 计算得出)")
    offset: float = Field(..., description="偏移量")
    max_history_length: int = Field(..., description="访问历史最大长度")
    forgetting_threshold: float = Field(..., description="遗忘阈值")
    min_days_since_access: int = Field(..., description="最小未访问天数")
    enable_llm_summary: bool = Field(..., description="是否使用 LLM 生成摘要")
    max_merge_batch_size: int = Field(..., description="单次最大融合节点对数")
    forgetting_interval_hours: int = Field(..., description="遗忘周期间隔(小时)")


class ForgettingConfigUpdateRequest(BaseModel):
    """Request model for updating a forgetting-engine configuration.

    Only ``config_id`` is required; every other field defaults to ``None``.
    """
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    config_id: int = Field(..., description="配置ID")
    # The float parameters below are all validated to the range [0.0, 1.0].
    decay_constant: Optional[float] = Field(None, ge=0.0, le=1.0, description="衰减常数 d")
    lambda_time: Optional[float] = Field(None, ge=0.0, le=1.0, description="时间衰减参数")
    lambda_mem: Optional[float] = Field(None, ge=0.0, le=1.0, description="记忆衰减参数")
    offset: Optional[float] = Field(None, ge=0.0, le=1.0, description="偏移量")
    # Access-history length is limited to 10..1000 entries.
    max_history_length: Optional[int] = Field(None, ge=10, le=1000, description="访问历史最大长度")
    forgetting_threshold: Optional[float] = Field(None, ge=0.0, le=1.0, description="遗忘阈值")
    min_days_since_access: Optional[int] = Field(None, ge=1, le=365, description="最小未访问天数")
    enable_llm_summary: Optional[bool] = Field(None, description="是否使用 LLM 生成摘要")
    max_merge_batch_size: Optional[int] = Field(None, ge=1, le=1000, description="单次最大融合节点对数")
    # Interval between cycles in hours, 1..168 (168 hours = 7 days).
    forgetting_interval_hours: Optional[int] = Field(None, ge=1, le=168, description="遗忘周期间隔(小时)")


class ForgettingStatsResponse(BaseModel):
    """Response model for forgetting-engine statistics."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")
    activation_metrics: Dict[str, Any] = Field(..., description="激活值相关指标")
    node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
    consistency_check: Optional[Dict[str, Any]] = Field(None, description="数据一致性检查结果")
    nodes_merged_total: int = Field(..., description="累计融合节点对数")
    recent_cycles: List[Dict[str, Any]] = Field(..., description="最近的遗忘周期记录")
    # ISO-format timestamp string, per the field description.
    timestamp: str = Field(..., description="统计时间(ISO格式)")


class ForgettingReportResponse(BaseModel):
    """Response model for the report produced by one forgetting cycle."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    merged_count: int = Field(..., description="融合的节点对数量")
    nodes_before: int = Field(..., description="遗忘前的节点总数")
    nodes_after: int = Field(..., description="遗忘后的节点总数")
    # Described as a ratio in 0-1; no range validation is declared on the field.
    reduction_rate: float = Field(..., description="节点减少率(0-1)")
    duration_seconds: float = Field(..., description="执行耗时(秒)")
    start_time: str = Field(..., description="开始时间(ISO格式)")
    end_time: str = Field(..., description="结束时间(ISO格式)")
    failed_count: int = Field(..., description="失败的融合数量")
    # Described as a ratio in 0-1; no range validation is declared on the field.
    success_rate: float = Field(..., description="成功率(0-1)")


class ForgettingCurvePoint(BaseModel):
    """A single data point on a forgetting curve."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    day: int = Field(..., description="天数")
    activation: float = Field(..., description="激活值")
    # Per the field description this carries the same value as ``activation``.
    retention_rate: float = Field(..., description="保持率(与激活值相同)")


class ForgettingCurveRequest(BaseModel):
    """Request model for generating forgetting-curve data."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    # Importance score in [0.0, 1.0], default 0.5.
    importance_score: float = Field(0.5, ge=0.0, le=1.0, description="重要性分数(0-1)")
    # Number of days to simulate, 1..365, default 60.
    days: int = Field(60, ge=1, le=365, description="模拟天数(默认60天)")
    # Optional configuration ID; the description says None means the default configuration.
    config_id: Optional[int] = Field(None, description="配置ID(可选,如果为None则使用默认配置)")


class ForgettingCurveResponse(BaseModel):
    """Response model carrying forgetting-curve data and the parameters used."""
    # Accept population by field name and reject any unknown fields.
    model_config = ConfigDict(populate_by_name=True, extra="forbid")

    curve_data: List[ForgettingCurvePoint] = Field(..., description="遗忘曲线数据点列表")
    config: Dict[str, Any] = Field(..., description="使用的配置参数")
"""
Service layer for the forgetting engine.

Business logic implemented here:
1. Running a forgetting cycle
2. Configuration management
3. Statistics queries
4. Forgetting-curve generation

All business logic lives in this service layer, separate from the controllers.
"""

from typing import Optional, Dict, Any, Tuple
from datetime import datetime

from sqlalchemy.orm import Session

from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import ForgettingScheduler
from app.core.memory.storage_services.forgetting_engine.config_utils import (
    load_actr_config_from_db,
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.data_config_repository import DataConfigRepository


# API-specific logger
api_logger = get_api_logger()


class MemoryForgetService:
    """Service exposing the forgetting-engine operations to controllers and tasks."""

    def __init__(self):
        """Create the service and its configuration repository."""
        self.config_repository = DataConfigRepository()

    def _get_neo4j_connector(self) -> Neo4jConnector:
        """
        Return a Neo4j connector instance.

        Returns:
            Neo4jConnector: A newly constructed connector.
        """
        # NOTE(review): a new connector is created on every call and nothing in
        # this module closes it. It should come from configuration / dependency
        # injection (singleton or pool) - confirm the connector's lifecycle API
        # and release it once callers are done.
        return Neo4jConnector()

    @staticmethod
    def _build_actr_calculator(config: Dict[str, Any]) -> ACTRCalculator:
        """Build the ACT-R calculator from a loaded configuration dict."""
        return ACTRCalculator(
            decay_constant=config['decay_constant'],
            forgetting_rate=config['forgetting_rate'],
            offset=config['offset'],
            max_history_length=config['max_history_length']
        )

    async def _get_forgetting_components(
        self,
        db: Session,
        config_id: Optional[int] = None
    ) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
        """
        Build the forgetting-engine components (calculator, strategy, scheduler).

        Args:
            db: Database session.
            config_id: Configuration ID (optional).

        Returns:
            tuple: (actr_calculator, forgetting_strategy, forgetting_scheduler, config)
        """
        config = load_actr_config_from_db(db, config_id)
        actr_calculator = self._build_actr_calculator(config)

        # Strategy and scheduler share one connector.
        connector = self._get_neo4j_connector()

        forgetting_strategy = ForgettingStrategy(
            connector=connector,
            actr_calculator=actr_calculator,
            forgetting_threshold=config['forgetting_threshold'],
            enable_llm_summary=config['enable_llm_summary']
        )

        forgetting_scheduler = ForgettingScheduler(
            forgetting_strategy=forgetting_strategy,
            connector=connector
        )

        return actr_calculator, forgetting_strategy, forgetting_scheduler, config

    async def _get_knowledge_stats(
        self,
        connector: Neo4jConnector,
        group_id: Optional[str] = None,
        forgetting_threshold: float = 0.3
    ) -> Dict[str, Any]:
        """
        Collect statistics for the knowledge layer (Statement, ExtractedEntity,
        MemorySummary nodes).

        Args:
            connector: Neo4j connector.
            group_id: Group ID (optional); restricts the query to that group.
            forgetting_threshold: Nodes whose activation value is below this
                threshold are counted as low-activation nodes.

        Returns:
            dict: Node counts per type, average activation and the number of
            low-activation nodes.
        """
        query = """
        MATCH (n)
        WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
        """

        if group_id:
            query += " AND n.group_id = $group_id"

        query += """
        WITH n,
             CASE
                 WHEN n:Statement THEN 'statement'
                 WHEN n:ExtractedEntity THEN 'entity'
                 WHEN n:MemorySummary THEN 'summary'
             END as node_type
        RETURN
            count(n) as total_nodes,
            sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
            sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
            sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
            avg(n.activation_value) as average_activation,
            sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
        """

        params = {'threshold': forgetting_threshold}
        if group_id:
            params['group_id'] = group_id

        results = await connector.execute_query(query, **params)
        # No rows -> every counter falls back to 0 and the average to None.
        row = results[0] if results else {}

        return {
            'total_nodes': row.get('total_nodes') or 0,
            'statement_count': row.get('statement_count') or 0,
            'entity_count': row.get('entity_count') or 0,
            'summary_count': row.get('summary_count') or 0,
            'average_activation': row.get('average_activation'),
            'low_activation_nodes': row.get('low_activation_nodes') or 0
        }

    async def trigger_forgetting_cycle(
        self,
        db: Session,
        group_id: Optional[str] = None,
        max_merge_batch_size: Optional[int] = None,
        min_days_since_access: Optional[int] = None,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Manually trigger one forgetting cycle.

        Builds the engine components and delegates to the scheduler's
        ``run_forgetting_cycle``.

        Args:
            db: Database session.
            group_id: Group ID (optional).
            max_merge_batch_size: Maximum merge batch size (optional).
            min_days_since_access: Minimum days since last access (optional).
            config_id: Configuration ID (optional).

        Returns:
            dict: The forgetting report returned by the scheduler.
        """
        _, _, forgetting_scheduler, _ = await self._get_forgetting_components(db, config_id)

        # The LLM client, when needed, is obtained inside forgetting_strategy.
        report = await forgetting_scheduler.run_forgetting_cycle(
            group_id=group_id,
            max_merge_batch_size=max_merge_batch_size,
            min_days_since_access=min_days_since_access,
            config_id=config_id,
            db=db
        )

        api_logger.info(
            f"遗忘周期完成: 融合 {report['merged_count']} 对节点, "
            f"失败 {report['failed_count']} 对, "
            f"耗时 {report['duration_seconds']:.2f} 秒"
        )

        return report

    def read_forgetting_config(
        self,
        db: Session,
        config_id: int
    ) -> Dict[str, Any]:
        """
        Read the forgetting-engine parameters for a configuration ID.

        Args:
            db: Database session.
            config_id: Configuration ID.

        Returns:
            dict: The loaded configuration with ``config_id`` added.
        """
        config = load_actr_config_from_db(db, config_id)

        # Echo the requested ID back in the result.
        config['config_id'] = config_id

        api_logger.info(f"成功读取遗忘引擎配置: config_id={config_id}")

        return config

    def update_forgetting_config(
        self,
        db: Session,
        config_id: int,
        update_fields: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Update the forgetting-engine parameters of a configuration.

        Only keys that exist as attributes on the stored configuration are
        applied; other keys are ignored and reported in a warning.

        Args:
            db: Database session.
            config_id: Configuration ID.
            update_fields: Mapping of field name to new value.

        Returns:
            dict: The configuration reloaded after the update.

        Raises:
            ValueError: If the configuration does not exist.
        """
        db_config = self.config_repository.get_by_id(db, config_id)
        if db_config is None:
            raise ValueError(f"配置不存在: {config_id}")

        if update_fields:
            applied = []
            ignored = []
            for key, value in update_fields.items():
                if hasattr(db_config, key):
                    setattr(db_config, key, value)
                    applied.append(key)
                else:
                    ignored.append(key)

            try:
                db.commit()
            except Exception:
                # Leave the session usable for the caller before propagating.
                db.rollback()
                raise
            db.refresh(db_config)

            if ignored:
                api_logger.warning(
                    f"忽略未知的配置字段: config_id={config_id}, 字段: {ignored}"
                )
            # Report only the fields that were actually written.
            api_logger.info(
                f"成功更新遗忘引擎配置: config_id={config_id}, "
                f"更新字段: {applied}"
            )
        else:
            api_logger.info(f"没有字段需要更新: config_id={config_id}")

        # Reload so the caller sees the stored values.
        config = load_actr_config_from_db(db, config_id)
        config['config_id'] = config_id

        return config

    async def _collect_activation_metrics(
        self,
        connector: Neo4jConnector,
        group_id: Optional[str],
        forgetting_threshold: float
    ) -> Dict[str, Any]:
        """Query activation-value metrics over Statement/ExtractedEntity/MemorySummary/Chunk nodes."""
        query = """
        MATCH (n)
        WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
        """

        if group_id:
            query += " AND n.group_id = $group_id"

        query += """
        RETURN
            count(n) as total_nodes,
            sum(CASE WHEN n.activation_value IS NOT NULL THEN 1 ELSE 0 END) as nodes_with_activation,
            sum(CASE WHEN n.activation_value IS NULL THEN 1 ELSE 0 END) as nodes_without_activation,
            avg(n.activation_value) as average_activation,
            sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
        """

        params = {'threshold': forgetting_threshold}
        if group_id:
            params['group_id'] = group_id

        results = await connector.execute_query(query, **params)
        # No rows -> counters fall back to 0 and the average to None.
        row = results[0] if results else {}

        return {
            'total_nodes': row.get('total_nodes') or 0,
            'nodes_with_activation': row.get('nodes_with_activation') or 0,
            'nodes_without_activation': row.get('nodes_without_activation') or 0,
            'average_activation_value': row.get('average_activation'),
            'low_activation_nodes': row.get('low_activation_nodes') or 0,
            'timestamp': datetime.now().isoformat()
        }

    async def _collect_node_distribution(
        self,
        connector: Neo4jConnector,
        group_id: Optional[str]
    ) -> Dict[str, int]:
        """Query the per-type node counts (statement, entity, summary, chunk)."""
        query = """
        MATCH (n)
        WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
        """

        if group_id:
            query += " AND n.group_id = $group_id"

        query += """
        WITH n,
             CASE
                 WHEN n:Statement THEN 'statement'
                 WHEN n:ExtractedEntity THEN 'entity'
                 WHEN n:MemorySummary THEN 'summary'
                 WHEN n:Chunk THEN 'chunk'
             END as node_type
        RETURN
            sum(CASE WHEN node_type = 'statement' THEN 1 ELSE 0 END) as statement_count,
            sum(CASE WHEN node_type = 'entity' THEN 1 ELSE 0 END) as entity_count,
            sum(CASE WHEN node_type = 'summary' THEN 1 ELSE 0 END) as summary_count,
            sum(CASE WHEN node_type = 'chunk' THEN 1 ELSE 0 END) as chunk_count
        """

        params = {}
        if group_id:
            params['group_id'] = group_id

        results = await connector.execute_query(query, **params)
        row = results[0] if results else {}

        return {
            'statement_count': row.get('statement_count') or 0,
            'entity_count': row.get('entity_count') or 0,
            'summary_count': row.get('summary_count') or 0,
            'chunk_count': row.get('chunk_count') or 0
        }

    async def get_forgetting_stats(
        self,
        db: Session,
        group_id: Optional[str] = None,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Return forgetting-engine statistics.

        Combines activation-value metrics with the node type distribution.

        Args:
            db: Database session.
            group_id: Group ID (optional).
            config_id: Configuration ID (optional; supplies the forgetting threshold).

        Returns:
            dict: Statistics dictionary.
        """
        _, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)

        connector = forgetting_scheduler.connector
        forgetting_threshold = config['forgetting_threshold']

        activation_metrics = await self._collect_activation_metrics(
            connector, group_id, forgetting_threshold
        )
        node_distribution = await self._collect_node_distribution(connector, group_id)

        # Monitoring history is not tracked, so the history fields are fixed placeholders.
        stats = {
            'activation_metrics': activation_metrics,
            'node_distribution': node_distribution,
            'consistency_check': None,   # consistency check no longer provided
            'nodes_merged_total': 0,     # cumulative merge count no longer tracked
            'recent_cycles': [],         # cycle history no longer provided
            'timestamp': datetime.now().isoformat()
        }

        api_logger.info(
            f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
            f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}"
        )

        return stats

    async def get_forgetting_curve(
        self,
        db: Session,
        importance_score: float,
        days: int,
        config_id: Optional[int] = None
    ) -> Dict[str, Any]:
        """
        Generate forgetting-curve data for visualisation.

        Args:
            db: Database session.
            importance_score: Importance score (0-1).
            days: Number of days to simulate.
            config_id: Configuration ID (optional).

        Returns:
            dict: ``curve_data`` plus the ``config`` parameters used.
        """
        # Only the calculator is needed here, so the Neo4j connector, strategy
        # and scheduler are deliberately not constructed.
        config = load_actr_config_from_db(db, config_id)
        actr_calculator = self._build_actr_calculator(config)

        curve_data = actr_calculator.get_forgetting_curve(
            initial_time=datetime.now(),
            importance_score=importance_score,
            days=days
        )

        api_logger.info(
            f"成功生成遗忘曲线数据: {len(curve_data)} 个数据点"
        )

        return {
            'curve_data': curve_data,
            'config': {
                'decay_constant': config['decay_constant'],
                'forgetting_rate': config['forgetting_rate'],
                'offset': config['offset'],
                'importance_score': importance_score,
                'days': days
            }
        }
ConfigPilotRun, ConfigUpdate, ConfigUpdateExtracted, - ConfigUpdateForget, ) from app.services.memory_config_service import MemoryConfigService from app.utils.sse_utils import format_sse_message @@ -159,11 +157,8 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) return {"affected": 1} # --- Forget config params --- - def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置 - config = DataConfigRepository.update_forget(self.db, update) - if not config: - raise ValueError("未找到配置") - return {"affected": 1} + # 遗忘引擎配置方法已迁移到 memory_forget_service.py + # 使用新方法: MemoryForgetService.read_forgetting_config() 和 MemoryForgetService.update_forgetting_config() # --- Read --- def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数 @@ -172,12 +167,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) raise ValueError("未找到配置") return result - def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取遗忘配置参数 - result = DataConfigRepository.get_forget_config(self.db, key.config_id) - if not result: - raise ValueError("未找到配置") - return result - # --- Read All --- def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 configs = DataConfigRepository.get_all(self.db, workspace_id) diff --git a/api/app/tasks.py b/api/app/tasks.py index 15b03ae7..28a882b7 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -412,7 +412,7 @@ def read_message_task(self, group_id: str, message: str, history: List[Dict[str, actual_config_id = connected_config.get("memory_config_id") finally: db.close() - except Exception as e: + except Exception: # Log but continue - will fail later with proper error pass @@ -499,7 +499,7 @@ def write_message_task(self, group_id: str, message: str, config_id: str,storage actual_config_id = connected_config.get("memory_config_id") finally: db.close() - except Exception as e: + except Exception: # Log but continue - will fail later with proper error pass @@ -1064,4 +1064,74 @@ def workspace_reflection_task(self) -> 
@celery_app.task(name="app.tasks.run_forgetting_cycle_task", bind=True)
def run_forgetting_cycle_task(self, config_id: Optional[int] = None) -> Dict[str, Any]:
    """Scheduled task: run one forgetting cycle.

    Runs ``MemoryForgetService.trigger_forgetting_cycle`` with no group
    filter on a dedicated event loop and reports the outcome as a dict.

    Args:
        config_id: Configuration ID (optional; ``None`` is passed through to
            the service unchanged).

    Returns:
        A dict with ``status`` (``"SUCCESS"`` or ``"FAILED"``), ``message`` and
        ``duration_seconds``; on success it also carries the cycle ``report``.
        Exceptions raised while running the cycle are caught and reported as
        ``FAILED`` rather than propagated.
    """
    start_time = time.time()

    async def _run() -> Dict[str, Any]:
        # Imported inside the coroutine, as in the original code.
        from app.core.logging_config import get_api_logger
        from app.services.memory_forget_service import MemoryForgetService

        api_logger = get_api_logger()

        with get_db_context() as db:
            try:
                api_logger.info(f"开始执行遗忘周期定时任务,config_id: {config_id}")

                forget_service = MemoryForgetService()

                # MemoryForgetService defines trigger_forgetting_cycle(); it has
                # no trigger_forgetting() method.
                report = await forget_service.trigger_forgetting_cycle(
                    db=db,
                    group_id=None,  # process all groups
                    config_id=config_id
                )

                duration = time.time() - start_time

                api_logger.info(
                    f"遗忘周期定时任务完成: "
                    f"融合 {report['merged_count']} 对节点, "
                    f"失败 {report['failed_count']} 对, "
                    f"耗时 {duration:.2f} 秒"
                )

                return {
                    "status": "SUCCESS",
                    "message": "遗忘周期执行成功",
                    "report": report,
                    "duration_seconds": duration
                }

            except Exception as e:
                duration = time.time() - start_time
                api_logger.error(f"遗忘周期定时任务失败: {str(e)}", exc_info=True)

                return {
                    "status": "FAILED",
                    "message": f"遗忘周期执行失败: {str(e)}",
                    "duration_seconds": duration
                }

    # Drive the coroutine on a fresh event loop owned by this task invocation.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(_run())
    finally:
        loop.close()
        # Do not leave a closed loop registered as the thread's current loop.
        asyncio.set_event_loop(None)