Merge #85 into develop from feature/actr-forget

[feature]actr-记忆遗忘需求开发

* feature/actr-forget: (12 commits squashed)

  - [feature]
    1.Extended fields of the date_config table;
    2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - [feature]
    1.Extended fields of the date_config table;
    2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget

  - [fix]Eliminate the interference caused by redundant code

  - [feature]
    1.Extended fields of the date_config table;
    2.New activation value calculation has been added, and the ACTR parameter has been introduced in Neo4j.

  - [feature]1.Create a forgetting strategy executor;2.Create the forgetting scheduler

  - [feature]Introduce activation values for retrieval, and develop a two-stage retrieval reordering process

  - Merge branch 'feature/actr-forget' of codeup.aliyun.com:redbearai/python/redbear-mem-open into feature/actr-forget

Signed-off-by: 乐力齐 <accounts_690c7b0af9007d7e338af636@mail.teambition.com>
Reviewed-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>
Merged-by: aliyun6762716068 <accounts_68cb7c6b61f5dcc4200d6251@mail.teambition.com>

CR-link: https://codeup.aliyun.com/redbearai/python/redbear-mem-open/change/85
This commit is contained in:
乐力齐
2026-01-05 04:30:36 +00:00
committed by 孙科
parent d299c39c55
commit e8a5cfe7e3
24 changed files with 4178 additions and 287 deletions

View File

@@ -106,28 +106,32 @@ class SearchService:
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Args:
group_id: Group identifier for filtering results
group_id: Group identifier for filtering
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: MemoryConfig object for embedding model. Falls back to self.memory_config if not provided.
limit: Max results per category (default: 15)
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: BM25 weight (default: 0.6)
activation_boost_factor: Activation impact on memory strength (default: 0.8)
output_path: JSON output path (default: "search_results.json")
return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
@@ -151,6 +155,7 @@ class SearchService:
output_path=output_path,
memory_config=config,
rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
)
# Extract results based on search type and include parameter

View File

@@ -228,6 +228,13 @@ class StatementNode(Node):
chunk_embedding: Optional embedding vector for the parent chunk
connect_strength: Classification of connection strength ('Strong' or 'Weak')
config_id: Configuration ID used to process this statement
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
"""
# Core fields (ordered as requested)
chunk_id: str = Field(..., description="ID of the parent chunk")
@@ -269,6 +276,33 @@ class StatementNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
@field_validator('valid_at', 'invalid_at', mode='before')
@classmethod
def validate_datetime(cls, v):
@@ -351,6 +385,13 @@ class ExtractedEntityNode(Node):
fact_summary: Summary of facts about this entity
connect_strength: Classification of connection strength ('Strong', 'Weak', or 'Both')
config_id: Configuration ID used to process this entity (integer or string)
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), default 0.5
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0)
access_history: List of ISO timestamp strings recording each access
last_access_time: ISO timestamp of the most recent access
access_count: Total number of times this node has been accessed
"""
entity_idx: int = Field(..., description="Unique identifier for the entity")
statement_id: str = Field(..., description="Statement this entity was extracted from")
@@ -365,6 +406,33 @@ class ExtractedEntityNode(Node):
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), default 0.5"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0)"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access"
)
access_count: int = Field(
default=0,
ge=0,
description="Total number of times this node has been accessed"
)
@field_validator('aliases', mode='before')
@classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
@@ -401,6 +469,16 @@ class MemorySummaryNode(Node):
summary_embedding: Optional embedding vector for the summary
metadata: Additional metadata for the summary
config_id: Configuration ID used to process this summary
original_statement_id: ID of the original statement that was merged (for ACT-R forgetting)
original_entity_id: ID of the original entity that was merged (for ACT-R forgetting)
merged_at: Timestamp when the nodes were merged
# ACT-R Memory Activation Properties
importance_score: Importance score for memory activation (0.0-1.0), inherited from merged nodes
activation_value: Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes
access_history: List of ISO timestamp strings recording each access (reset on creation)
last_access_time: ISO timestamp of the most recent access (set to creation time)
access_count: Total number of times this node has been accessed (reset to 1 on creation)
"""
summary_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for the summary")
dialog_id: str = Field(..., description="ID of the parent dialog")
@@ -409,3 +487,44 @@ class MemorySummaryNode(Node):
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field(
None,
description="ID of the original statement that was merged (for traceability)"
)
original_entity_id: Optional[str] = Field(
None,
description="ID of the original entity that was merged (for traceability)"
)
merged_at: Optional[datetime] = Field(
None,
description="Timestamp when the nodes were merged"
)
# ACT-R Memory Activation Properties
importance_score: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Importance score for memory activation (0.0-1.0), inherited from merged nodes"
)
activation_value: Optional[float] = Field(
None,
ge=0.0,
le=1.0,
description="Current activation value calculated by ACT-R engine (0.0-1.0), inherited from merged nodes"
)
access_history: List[str] = Field(
default_factory=list,
description="List of ISO timestamp strings recording each access (reset on creation)"
)
last_access_time: Optional[str] = Field(
None,
description="ISO timestamp of the most recent access (set to creation time)"
)
access_count: int = Field(
default=1,
ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)"
)

View File

@@ -69,6 +69,12 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
for item in results:
if score_field in item:
score = item.get(score_field)
# 对于 activation_valueNone 值保持为 None不使用回退值
# 这样可以区分有激活值和无激活值的节点
if score_field == "activation_value" and score is None:
scores.append(None) # 保持 None稍后特殊处理
continue
if score is not None and isinstance(score, (int, float)):
scores.append(float(score))
else:
@@ -76,205 +82,433 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
if not scores:
return results
if len(scores) == 1:
# Single score, set to 1.0
# 过滤掉 None 值,只对有效分数进行归一化
valid_scores = [s for s in scores if s is not None]
if not valid_scores:
# 所有分数都是 None不进行归一化
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
if score_field in item or score_field == "activation_value":
item[f"normalized_{score_field}"] = None
return results
# Calculate mean and standard deviation
mean_score = sum(scores) / len(scores)
variance = sum((score - mean_score) ** 2 for score in scores) / len(scores)
if len(valid_scores) == 1: # Single valid score, set to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
return results
# Calculate mean and standard deviation (only for valid scores)
mean_score = sum(valid_scores) / len(valid_scores)
variance = sum((score - mean_score) ** 2 for score in valid_scores) / len(valid_scores)
std_dev = math.sqrt(variance)
if std_dev == 0:
# All scores are the same, set them to 1.0
for item in results:
if score_field in item:
item[f"normalized_{score_field}"] = 1.0
# All valid scores are the same, set them to 1.0
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
item[f"normalized_{score_field}"] = None
else:
item[f"normalized_{score_field}"] = 1.0
else:
for item in results:
if score_field in item:
score = item[score_field]
# Handle None or non-numeric scores
if score is None or not isinstance(score, (int, float)):
score = 0.0
# Calculate z-score
z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
for item, score in zip(results, scores):
if score_field in item or score_field == "activation_value":
if score is None:
# 保持 None不进行归一化
item[f"normalized_{score_field}"] = None
else:
# Calculate z-score
z_score = (score - mean_score) / std_dev
# Transform to positive range using sigmoid function
normalized = 1 / (1 + math.exp(-z_score))
item[f"normalized_{score_field}"] = normalized
return results
def rerank_hybrid_results(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid search results by combining BM25 and embedding scores.
# ============================================================================
# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
# ============================================================================
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
# def rerank_hybrid_results(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid search results by combining BM25 and embedding scores.
#
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
#
# Returns:
# Reranked results with combined scores
# """
# reranked = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# # Create a combined pool of unique items
# combined_items = {}
#
# # Add keyword results with BM25 scores
# for item in keyword_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# combined_items[item_id]["embedding_score"] = 0 # Default
#
# # Add or update with embedding results
# for item in embedding_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if item_id:
# if item_id in combined_items:
# # Update existing item with embedding score
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# # New item from embedding search only
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0 # Default
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
#
# # Calculate combined scores and rank
# for item_id, item in combined_items.items():
# bm25_score = item.get("bm25_score", 0)
# embedding_score = item.get("embedding_score", 0)
#
# # Combined score: weighted average of normalized scores
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# item["combined_score"] = combined_score
#
# # Keep original score for reference
# if "score" not in item and bm25_score > 0:
# item["score"] = bm25_score
# elif "score" not in item and embedding_score > 0:
# item["score"] = embedding_score
#
# # Sort by combined score and limit results
# sorted_items = sorted(
# combined_items.values(),
# key=lambda x: x.get("combined_score", 0),
# reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
Returns:
Reranked results with combined scores
"""
reranked = {}
# def rerank_with_forgetting_curve(
# keyword_results: Dict[str, List[Dict[str, Any]]],
# embedding_results: Dict[str, List[Dict[str, Any]]],
# alpha: float = 0.6,
# limit: int = 10,
# forgetting_config: ForgettingEngineConfig | None = None,
# now: datetime | None = None,
# ) -> Dict[str, List[Dict[str, Any]]]:
# """
# Rerank hybrid results with a forgetting curve applied to combined scores.
#
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
#
# The forgetting curve reduces scores for older memories or weaker connections.
#
# Args:
# keyword_results: Results from keyword/BM25 search
# embedding_results: Results from embedding search
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
# limit: Maximum number of results to return per category
# forgetting_config: Configuration for the forgetting engine
# now: Optional current time override for testing
#
# Returns:
# Reranked results with combined and final scores (after forgetting)
# """
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
# now_dt = now or datetime.now()
#
# reranked: Dict[str, List[Dict[str, Any]]] = {}
#
# for category in ["statements", "chunks", "entities","summaries"]:
# keyword_items = keyword_results.get(category, [])
# embedding_items = embedding_results.get(category, [])
#
# # Normalize scores within each search type
# keyword_items = normalize_scores(keyword_items, "score")
# embedding_items = normalize_scores(embedding_items, "score")
#
# combined_items: Dict[str, Dict[str, Any]] = {}
#
# # Combine two result sets by ID
# for src_items, is_embedding in (
# (keyword_items, False), (embedding_items, True)
# ):
# for item in src_items:
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
# if not item_id:
# continue
# existing = combined_items.get(item_id)
# if not existing:
# combined_items[item_id] = item.copy()
# combined_items[item_id]["bm25_score"] = 0
# combined_items[item_id]["embedding_score"] = 0
# # Update normalized score from the right source
# if is_embedding:
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# else:
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
#
# # Calculate scores and apply forgetting weights
# for item_id, item in combined_items.items():
# bm25_score = float(item.get("bm25_score", 0) or 0)
# embedding_score = float(item.get("embedding_score", 0) or 0)
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
#
# # Estimate time elapsed in days
# dt = _parse_datetime(item.get("created_at"))
# if dt is None:
# time_elapsed_days = 0.0
# else:
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
#
# # Memory strength (currently set to default value)
# memory_strength = 1.0
# forgetting_weight = engine.calculate_weight(
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
# )
# final_score = combined_score * forgetting_weight
# item["combined_score"] = final_score
#
# sorted_items = sorted(
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
# )[:limit]
#
# reranked[category] = sorted_items
#
# return reranked
for category in ["statements", "chunks", "entities","summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Normalize scores within each search type
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# Create a combined pool of unique items
combined_items = {}
# Add keyword results with BM25 scores
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # Default
# Add or update with embedding results
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id:
if item_id in combined_items:
# Update existing item with embedding score
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
# New item from embedding search only
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # Default
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# Calculate combined scores and rank
for item_id, item in combined_items.items():
bm25_score = item.get("bm25_score", 0)
embedding_score = item.get("embedding_score", 0)
# Combined score: weighted average of normalized scores
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
item["combined_score"] = combined_score
# Keep original score for reference
if "score" not in item and bm25_score > 0:
item["score"] = bm25_score
elif "score" not in item and embedding_score > 0:
item["score"] = embedding_score
# Sort by combined score and limit results
sorted_items = sorted(
combined_items.values(),
key=lambda x: x.get("combined_score", 0),
reverse=True
)[:limit]
reranked[category] = sorted_items
return reranked
def rerank_with_forgetting_curve(
def rerank_with_activation(
keyword_results: Dict[str, List[Dict[str, Any]]],
embedding_results: Dict[str, List[Dict[str, Any]]],
alpha: float = 0.6,
limit: int = 10,
forgetting_config: ForgettingEngineConfig | None = None,
activation_boost_factor: float = 0.8,
now: datetime | None = None,
) -> Dict[str, List[Dict[str, Any]]]:
"""
Rerank hybrid results with a forgetting curve applied to combined scores.
The forgetting curve reduces scores for older memories or weaker connections.
Args:
keyword_results: Results from keyword/BM25 search
embedding_results: Results from embedding search
alpha: Weight for BM25 scores (1-alpha for embedding scores)
limit: Maximum number of results to return per category
forgetting_config: Configuration for the forgetting engine
now: Optional current time override for testing
Returns:
Reranked results with combined and final scores (after forgetting)
两阶段排序:先按内容相关性筛选,再按激活值排序。
阶段1: content_score = alpha*BM25 + (1-alpha)*Embedding取 Top-(limit*3)
阶段2: 在候选中按 activation_score 排序,取 Top-limit
无激活值的节点用于补充不足
返回结果中的评分字段说明:
- bm25_score: BM25 归一化分数
- embedding_score: Embedding 归一化分数
- content_score: 内容相关性 = alpha*bm25 + (1-alpha)*embedding
- activation_score: ACTR 激活值归一化分数
- base_score: 第一阶段基础分数(等于 content_score
- final_score: 最终排序依据
* 有激活值的节点:final_score = activation_score
* 无激活值的节点final_score = base_score
参数:
keyword_results: BM25 检索结果
embedding_results: 向量嵌入检索结果
alpha: BM25 权重 (默认: 0.6)
limit: 每类最大结果数
forgetting_config: 遗忘引擎配置(当前未使用)
activation_boost_factor: 激活度对记忆强度的影响系数 (默认: 0.8)
now: 当前时间(用于遗忘计算)
返回:
带评分元数据的重排序结果,按 final_score 排序
"""
engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
# 验证权重范围
if not (0 <= alpha <= 1):
raise ValueError(f"alpha 必须在 [0, 1] 范围内,当前值: {alpha}")
# 初始化遗忘引擎(如果需要)
engine = None
if forgetting_config:
engine = ForgettingEngine(forgetting_config)
now_dt = now or datetime.now()
reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities","summaries"]:
for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, [])
# Normalize scores within each search type
# 步骤 1: 归一化分数
keyword_items = normalize_scores(keyword_items, "score")
embedding_items = normalize_scores(embedding_items, "score")
# 步骤 2: 按 ID 合并结果
combined_items: Dict[str, Dict[str, Any]] = {}
# Combine two result sets by ID
for src_items, is_embedding in (
(keyword_items, False), (embedding_items, True)
):
for item in src_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
existing = combined_items.get(item_id)
if not existing:
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0
combined_items[item_id]["embedding_score"] = 0
# Update normalized score from the right source
if is_embedding:
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
# Calculate scores and apply forgetting weights
for item_id, item in combined_items.items():
bm25_score = float(item.get("bm25_score", 0) or 0)
embedding_score = float(item.get("embedding_score", 0) or 0)
combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
# Estimate time elapsed in days
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
# 添加关键词结果
for item in keyword_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
combined_items[item_id]["embedding_score"] = 0 # 默认值
# 添加或更新向量嵌入结果
for item in embedding_items:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if not item_id:
continue
if item_id in combined_items:
# 更新现有项的嵌入分数
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# Memory strength (currently set to default value)
memory_strength = 1.0
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days, memory_strength=memory_strength
)
# print(f"Forgetting weight for {item_id}: {forgetting_weight}")
# print(f"Time elapsed days for {item_id}: {time_elapsed_days}")
final_score = combined_score * forgetting_weight
item["combined_score"] = final_score
sorted_items = sorted(
combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
)[:limit]
# 仅来自嵌入搜索的新项
combined_items[item_id] = item.copy()
combined_items[item_id]["bm25_score"] = 0 # 默认值
combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
# 步骤 3: 归一化激活度分数
# 为所有项准备激活度值列表
items_list = list(combined_items.values())
items_list = normalize_scores(items_list, "activation_value")
# 更新 combined_items 中的归一化激活度分数
for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
# 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0)
act_norm = float(item.get("normalized_activation_value", 0) or 0)
# 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用
item["activation_score"] = act_norm
item["content_score"] = content_score
item["base_score"] = base_score
# 步骤 5: 应用遗忘曲线(可选)
if engine:
# 计算受激活度影响的记忆强度
importance = float(item.get("importance_score", 0.5) or 0.5)
# 获取 activation_value
activation_val = item.get("activation_value")
# 只对有激活值的节点应用遗忘曲线
if activation_val is not None and isinstance(activation_val, (int, float)):
activation_val = float(activation_val)
# 计算记忆强度importance_score × (1 + activation_value × boost_factor)
memory_strength = importance * (1 + activation_val * activation_boost_factor)
# 计算经过的时间(天数)
dt = _parse_datetime(item.get("created_at"))
if dt is None:
time_elapsed_days = 0.0
else:
time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
# 获取遗忘权重
forgetting_weight = engine.calculate_weight(
time_elapsed=time_elapsed_days,
memory_strength=memory_strength
)
# 应用到基础分数
item["forgetting_weight"] = forgetting_weight
item["final_score"] = base_score * forgetting_weight
else:
# 无激活值的节点不应用遗忘曲线,保持原始分数
item["final_score"] = base_score
else:
# 不使用遗忘曲线
item["final_score"] = base_score
# 步骤 6: 两阶段排序和限制
# 第一阶段按内容相关性base_score排序取 Top-K
first_stage_limit = limit * 3 # 可配置取3倍候选
first_stage_sorted = sorted(
combined_items.values(),
key=lambda x: float(x.get("base_score", 0) or 0), # 按内容分数排序
reverse=True
)[:first_stage_limit]
# 第二阶段:分离有激活值和无激活值的节点
items_with_activation = []
items_without_activation = []
for item in first_stage_sorted:
activation_score = item.get("activation_score")
# 检查是否有有效的激活值(不是 None
if activation_score is not None and isinstance(activation_score, (int, float)):
items_with_activation.append(item)
else:
items_without_activation.append(item)
# 优先按激活值排序有激活值的节点
sorted_with_activation = sorted(
items_with_activation,
key=lambda x: float(x.get("activation_score", 0) or 0),
reverse=True
)
# 如果有激活值的节点不足 limit用无激活值的节点补充
if len(sorted_with_activation) < limit:
needed = limit - len(sorted_with_activation)
# 无激活值的节点保持第一阶段的内容相关性排序
sorted_items = sorted_with_activation + items_without_activation[:needed]
else:
sorted_items = sorted_with_activation[:limit]
# 两阶段排序完成,更新 final_score 以反映实际排序依据
# Stage 1: 按 content_score 筛选候选(已完成)
# Stage 2: 按 activation_score 排序(已完成)
#
# final_score 语义:反映节点在最终结果中的排序依据
# - 有激活值的节点final_score = activation_score第二阶段排序依据
# - 无激活值的节点final_score = base_score保持内容相关性分数
for item in sorted_items:
activation_score = item.get("activation_score")
if activation_score is not None and isinstance(activation_score, (int, float)):
# 有激活值:使用激活度作为最终分数
item["final_score"] = activation_score
else:
# 无激活值:使用内容相关性分数
item["final_score"] = item.get("base_score", 0)
reranked[category] = sorted_items
return reranked
@@ -560,6 +794,7 @@ async def run_hybrid_search(
output_path: str | None,
memory_config: "MemoryConfig",
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
):
@@ -685,30 +920,28 @@ async def run_hybrid_search(
"search_timestamp": datetime.now().isoformat()
}
# Apply reranking (optionally with forgetting curve)
# Apply two-stage reranking with ACTR activation calculation
rerank_start = time.time()
if use_forgetting_rerank:
# Load forgetting parameters from pipeline config
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
reranked_results = rerank_with_forgetting_curve(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha,
limit=limit,
forgetting_config=forgetting_cfg,
)
else:
reranked_results = rerank_hybrid_results(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha, # Configurable weight for BM25 vs embedding
limit=limit
)
logger.info("Using two-stage reranking with ACTR activation")
# 加载遗忘引擎配置
try:
pc = get_pipeline_config(memory_config)
forgetting_cfg = pc.forgetting_engine
except Exception as e:
logger.debug(f"Failed to load forgetting config, using defaults: {e}")
forgetting_cfg = ForgettingEngineConfig()
# 统一使用激活度重排序(两阶段:检索 + ACTR计算
reranked_results = rerank_with_activation(
keyword_results=keyword_results,
embedding_results=embedding_results,
alpha=rerank_alpha,
limit=limit,
forgetting_config=forgetting_cfg,
activation_boost_factor=activation_boost_factor,
)
rerank_latency = time.time() - rerank_start
latency_metrics["reranking_latency"] = round(rerank_latency, 4)
logger.info(f"Reranking completed in {rerank_latency:.4f}s")
@@ -737,6 +970,7 @@ async def run_hybrid_search(
"search_query": query_text,
"search_timestamp": datetime.now().isoformat(),
"reranking_alpha": rerank_alpha,
"activation_boost_factor": activation_boost_factor,
"forgetting_rerank": use_forgetting_rerank,
"llm_rerank": llm_rerank_applied,
}

View File

@@ -1,8 +1,40 @@
"""遗忘引擎模块
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线。
该模块实现记忆的遗忘机制,基于改进的艾宾浩斯遗忘曲线和 ACT-R 认知架构理论
"""
from app.core.memory.storage_services.forgetting_engine.forgetting_engine import ForgettingEngine
from app.core.memory.storage_services.forgetting_engine.actr_calculator import (
ACTRCalculator,
calculate_activation,
generate_forgetting_curve
)
from app.core.memory.storage_services.forgetting_engine.access_history_manager import (
AccessHistoryManager,
ConsistencyCheckResult
)
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import (
ForgettingStrategy
)
from app.core.memory.storage_services.forgetting_engine.forgetting_scheduler import (
ForgettingScheduler
)
from app.core.memory.storage_services.forgetting_engine.config_utils import (
calculate_forgetting_rate,
load_actr_config_from_db,
create_actr_calculator_from_config
)
__all__ = ["ForgettingEngine"]
__all__ = [
"ForgettingEngine",
"ACTRCalculator",
"calculate_activation",
"generate_forgetting_curve",
"AccessHistoryManager",
"ConsistencyCheckResult",
"ForgettingStrategy",
"ForgettingScheduler",
"calculate_forgetting_rate",
"load_actr_config_from_db",
"create_actr_calculator_from_config"
]

View File

@@ -0,0 +1,691 @@
"""
访问历史管理器模块
本模块实现访问历史的追踪、更新和一致性保证。
负责在知识节点被访问时原子性地更新激活值相关的所有字段。
Classes:
AccessHistoryManager: 访问历史管理器,提供并发安全的访问记录和一致性检查
"""
import logging
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from enum import Enum
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
class ConsistencyCheckResult(Enum):
"""一致性检查结果枚举"""
CONSISTENT = "consistent" # 数据一致
INCONSISTENT_HISTORY_TIME = "inconsistent_history_time" # access_history[-1] != last_access_time
INCONSISTENT_HISTORY_COUNT = "inconsistent_history_count" # len(access_history) != access_count
MISSING_ACTIVATION = "missing_activation" # 有访问历史但无激活值
INVALID_ACTIVATION_RANGE = "invalid_activation_range" # 激活值超出有效范围
class AccessHistoryManager:
"""
访问历史管理器
负责追踪知识节点的访问历史,并在访问时原子性地更新所有相关字段:
- activation_value: 激活值
- access_history: 访问历史时间戳数组
- last_access_time: 最后访问时间
- access_count: 访问次数
特性:
- 原子性更新使用Neo4j事务确保所有字段同时更新或回滚
- 并发安全:使用乐观锁机制防止并发冲突
- 一致性保证:提供一致性检查和自动修复功能
- 智能修剪:自动修剪过长的访问历史
Attributes:
connector: Neo4j连接器实例
actr_calculator: ACT-R激活值计算器实例
max_retries: 并发冲突时的最大重试次数
"""
def __init__(
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
max_retries: int = 3
):
"""
初始化访问历史管理器
Args:
connector: Neo4j连接器实例
actr_calculator: ACT-R激活值计算器实例
max_retries: 并发冲突时的最大重试次数默认3次
"""
self.connector = connector
self.actr_calculator = actr_calculator
self.max_retries = max_retries
async def record_access(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> Dict[str, Any]:
"""
记录节点访问并原子性更新所有相关字段
这是核心方法,实现了:
1. 首次访问初始化access_history计算初始激活值
2. 后续访问:追加访问历史,重新计算激活值
3. 历史修剪:当历史过长时自动修剪
4. 原子性:所有字段在单个事务中更新
5. 并发安全:使用乐观锁重试机制
Args:
node_id: 节点ID
node_label: 节点标签Statement, ExtractedEntity, MemorySummary
group_id: 组ID可选用于过滤
current_time: 当前时间(可选,默认使用系统时间)
Returns:
Dict[str, Any]: 更新后的节点数据,包含:
- id: 节点ID
- activation_value: 更新后的激活值
- access_history: 更新后的访问历史
- last_access_time: 最后访问时间
- access_count: 访问次数
- importance_score: 重要性分数
Raises:
ValueError: 如果节点不存在或节点标签无效
RuntimeError: 如果重试次数耗尽仍然失败
"""
if current_time is None:
current_time = datetime.now()
current_time_iso = current_time.isoformat()
# 验证节点标签
valid_labels = ["Statement", "ExtractedEntity", "MemorySummary"]
if node_label not in valid_labels:
raise ValueError(
f"Invalid node_label: {node_label}. Must be one of {valid_labels}"
)
# 使用乐观锁重试机制处理并发冲突
for attempt in range(self.max_retries):
try:
# 步骤1读取当前节点状态
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
raise ValueError(
f"Node not found: {node_label} with id={node_id}"
)
# 步骤2计算新的访问历史和激活值
update_data = await self._calculate_update(
node_data=node_data,
current_time=current_time,
current_time_iso=current_time_iso
)
# 步骤3原子性更新节点使用事务
updated_node = await self._atomic_update(
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
)
logger.info(
f"成功记录访问: {node_label}[{node_id}], "
f"activation={update_data['activation_value']:.4f}, "
f"access_count={update_data['access_count']}"
)
return updated_node
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"访问记录失败(尝试 {attempt + 1}/{self.max_retries}: {str(e)}"
)
continue
else:
logger.error(
f"访问记录失败,重试次数耗尽: {node_label}[{node_id}], "
f"错误: {str(e)}"
)
raise RuntimeError(
f"Failed to record access after {self.max_retries} attempts: {str(e)}"
)
async def record_batch_access(
self,
node_ids: List[str],
node_label: str,
group_id: Optional[str] = None,
current_time: Optional[datetime] = None
) -> List[Dict[str, Any]]:
"""
批量记录多个节点的访问
为提高性能,批量更新多个节点的访问历史。
每个节点独立更新,失败的节点不影响其他节点。
Args:
node_ids: 节点ID列表
node_label: 节点标签(所有节点必须是同一类型)
group_id: 组ID可选
current_time: 当前时间(可选)
Returns:
List[Dict[str, Any]]: 成功更新的节点列表
"""
if current_time is None:
current_time = datetime.now()
results = []
failed_count = 0
for node_id in node_ids:
try:
updated_node = await self.record_access(
node_id=node_id,
node_label=node_label,
group_id=group_id,
current_time=current_time
)
results.append(updated_node)
except Exception as e:
failed_count += 1
logger.warning(
f"批量访问记录失败: {node_label}[{node_id}], 错误: {str(e)}"
)
logger.info(
f"批量访问记录完成: 成功 {len(results)}/{len(node_ids)}, "
f"失败 {failed_count}"
)
return results
async def check_consistency(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> Tuple[ConsistencyCheckResult, Optional[str]]:
"""
检查节点数据的一致性
验证以下一致性规则:
1. access_history[-1] == last_access_time
2. len(access_history) == access_count
3. 如果有访问历史,必须有激活值
4. 激活值必须在有效范围内 [offset, 1.0]
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
Returns:
Tuple[ConsistencyCheckResult, Optional[str]]:
- 一致性检查结果枚举
- 错误描述(如果不一致)
"""
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
return ConsistencyCheckResult.CONSISTENT, None
access_history = node_data.get('access_history', [])
last_access_time = node_data.get('last_access_time')
access_count = node_data.get('access_count', 0)
activation_value = node_data.get('activation_value')
# 检查1access_history[-1] == last_access_time
if access_history and last_access_time:
if access_history[-1] != last_access_time:
return (
ConsistencyCheckResult.INCONSISTENT_HISTORY_TIME,
f"access_history[-1]={access_history[-1]} != "
f"last_access_time={last_access_time}"
)
# 检查2len(access_history) == access_count
if len(access_history) != access_count:
return (
ConsistencyCheckResult.INCONSISTENT_HISTORY_COUNT,
f"len(access_history)={len(access_history)} != "
f"access_count={access_count}"
)
# 检查3有访问历史必须有激活值
if access_history and activation_value is None:
return (
ConsistencyCheckResult.MISSING_ACTIVATION,
"Node has access_history but activation_value is None"
)
# 检查4激活值范围
if activation_value is not None:
offset = self.actr_calculator.offset
if not (offset <= activation_value <= 1.0):
return (
ConsistencyCheckResult.INVALID_ACTIVATION_RANGE,
f"activation_value={activation_value} out of range "
f"[{offset}, 1.0]"
)
return ConsistencyCheckResult.CONSISTENT, None
async def check_batch_consistency(
self,
node_label: str,
group_id: Optional[str] = None,
limit: int = 1000
) -> Dict[str, Any]:
"""
批量检查多个节点的一致性
Args:
node_label: 节点标签
group_id: 组ID可选
limit: 检查的最大节点数
Returns:
Dict[str, Any]: 一致性检查报告,包含:
- total_checked: 检查的节点总数
- consistent_count: 一致的节点数
- inconsistent_count: 不一致的节点数
- inconsistencies: 不一致节点的详细信息列表
- consistency_rate: 一致性率0-1
"""
# 查询所有相关节点
query = f"""
MATCH (n:{node_label})
WHERE n.access_history IS NOT NULL
"""
if group_id:
query += " AND n.group_id = $group_id"
query += """
RETURN n.id as id
LIMIT $limit
"""
params = {"limit": limit}
if group_id:
params["group_id"] = group_id
results = await self.connector.execute_query(query, **params)
node_ids = [r['id'] for r in results]
# 检查每个节点
inconsistencies = []
consistent_count = 0
for node_id in node_ids:
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
)
if result == ConsistencyCheckResult.CONSISTENT:
consistent_count += 1
else:
inconsistencies.append({
'node_id': node_id,
'result': result.value,
'message': message
})
total_checked = len(node_ids)
inconsistent_count = len(inconsistencies)
consistency_rate = consistent_count / total_checked if total_checked > 0 else 1.0
report = {
'total_checked': total_checked,
'consistent_count': consistent_count,
'inconsistent_count': inconsistent_count,
'inconsistencies': inconsistencies,
'consistency_rate': consistency_rate
}
logger.info(
f"一致性检查完成: {node_label}, "
f"一致率={consistency_rate:.2%}, "
f"不一致节点={inconsistent_count}/{total_checked}"
)
return report
async def repair_inconsistency(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> bool:
"""
自动修复节点的数据不一致问题
修复策略:
1. 如果access_history[-1] != last_access_time使用access_history[-1]
2. 如果len(access_history) != access_count使用len(access_history)
3. 如果有历史但无激活值:重新计算激活值
4. 如果激活值超出范围:重新计算激活值
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
Returns:
bool: 修复成功返回True否则返回False
"""
try:
# 检查一致性
result, message = await self.check_consistency(
node_id=node_id,
node_label=node_label,
group_id=group_id
)
if result == ConsistencyCheckResult.CONSISTENT:
logger.info(f"节点数据一致,无需修复: {node_label}[{node_id}]")
return True
# 获取节点数据
node_data = await self._fetch_node(node_id, node_label, group_id)
if not node_data:
logger.error(f"节点不存在,无法修复: {node_label}[{node_id}]")
return False
access_history = node_data.get('access_history', [])
importance_score = node_data.get('importance_score', 0.5)
# 准备修复数据
repair_data = {}
# 修复last_access_time
if access_history:
repair_data['last_access_time'] = access_history[-1]
# 修复access_count
repair_data['access_count'] = len(access_history)
# 修复activation_value
if access_history:
current_time = datetime.now()
last_access_dt = datetime.fromisoformat(access_history[-1])
access_history_dt = [
datetime.fromisoformat(ts) for ts in access_history
]
activation_value = self.actr_calculator.calculate_memory_activation(
access_history=access_history_dt,
current_time=current_time,
last_access_time=last_access_dt,
importance_score=importance_score
)
repair_data['activation_value'] = activation_value
# 执行修复
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
query += """
SET n += $repair_data
RETURN n
"""
params = {
'node_id': node_id,
'repair_data': repair_data
}
if group_id:
params['group_id'] = group_id
await self.connector.execute_query(query, **params)
logger.info(
f"成功修复节点不一致: {node_label}[{node_id}], "
f"问题类型={result.value}"
)
return True
except Exception as e:
logger.error(
f"修复节点失败: {node_label}[{node_id}], 错误: {str(e)}"
)
return False
# ==================== 私有辅助方法 ====================
async def _fetch_node(
self,
node_id: str,
node_label: str,
group_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
获取节点数据
Args:
node_id: 节点ID
node_label: 节点标签
group_id: 组ID可选
Returns:
Optional[Dict[str, Any]]: 节点数据如果不存在返回None
"""
query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
query += " WHERE n.group_id = $group_id"
query += """
RETURN n.id as id,
n.importance_score as importance_score,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count
"""
params = {'node_id': node_id}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
if results:
return results[0]
return None
async def _calculate_update(
self,
node_data: Dict[str, Any],
current_time: datetime,
current_time_iso: str
) -> Dict[str, Any]:
"""
计算更新数据
Args:
node_data: 当前节点数据
current_time: 当前时间datetime对象
current_time_iso: 当前时间ISO格式字符串
Returns:
Dict[str, Any]: 更新数据,包含所有需要更新的字段
"""
access_history = node_data.get('access_history', [])
importance_score = node_data.get('importance_score', 0.5)
# 追加新的访问时间
new_access_history = access_history + [current_time_iso]
# 修剪访问历史(如果过长)
access_history_dt = [
datetime.fromisoformat(ts) for ts in new_access_history
]
trimmed_history_dt = self.actr_calculator.trim_access_history(
access_history=access_history_dt,
current_time=current_time
)
trimmed_history = [ts.isoformat() for ts in trimmed_history_dt]
# 计算新的激活值
activation_value = self.actr_calculator.calculate_memory_activation(
access_history=trimmed_history_dt,
current_time=current_time,
last_access_time=current_time, # 最后访问时间就是当前时间
importance_score=importance_score
)
# 返回所有需要更新的字段
return {
'activation_value': activation_value,
'access_history': trimmed_history,
'last_access_time': current_time_iso,
'access_count': len(trimmed_history)
}
async def _atomic_update(
self,
node_id: str,
node_label: str,
update_data: Dict[str, Any],
group_id: Optional[str] = None
) -> Dict[str, Any]:
"""
原子性更新节点(使用乐观锁)
使用Neo4j事务和版本号确保所有字段同时更新或回滚。
实现乐观锁机制防止并发冲突。
Args:
node_id: 节点ID
node_label: 节点标签
update_data: 更新数据
group_id: 组ID可选
Returns:
Dict[str, Any]: 更新后的节点数据
Raises:
RuntimeError: 如果更新失败或发生版本冲突
"""
# 定义事务函数
async def update_transaction(tx, node_id, node_label, update_data, group_id):
# 步骤1读取当前节点并获取版本号
read_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
read_query += " WHERE n.group_id = $group_id"
read_query += """
RETURN n.id as id,
n.version as version,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score
"""
read_params = {'node_id': node_id}
if group_id:
read_params['group_id'] = group_id
read_result = await tx.run(read_query, **read_params)
current_node = await read_result.single()
if not current_node:
raise RuntimeError(f"Node not found: {node_label}[{node_id}]")
# 获取当前版本号如果不存在则为0
current_version = current_node.get('version', 0) or 0
new_version = current_version + 1
# 步骤2使用乐观锁更新节点
# 只有当版本号匹配时才更新
update_query = f"""
MATCH (n:{node_label} {{id: $node_id}})
"""
if group_id:
update_query += " WHERE n.group_id = $group_id"
# 添加版本检查
if current_version > 0:
update_query += " AND n.version = $current_version"
else:
# 如果节点没有版本号,检查是否为首次更新
update_query += " AND (n.version IS NULL OR n.version = 0)"
update_query += """
SET n.activation_value = $activation_value,
n.access_history = $access_history,
n.last_access_time = $last_access_time,
n.access_count = $access_count,
n.version = $new_version
RETURN n.id as id,
n.activation_value as activation_value,
n.access_history as access_history,
n.last_access_time as last_access_time,
n.access_count as access_count,
n.importance_score as importance_score,
n.version as version
"""
update_params = {
'node_id': node_id,
'current_version': current_version,
'new_version': new_version,
'activation_value': update_data['activation_value'],
'access_history': update_data['access_history'],
'last_access_time': update_data['last_access_time'],
'access_count': update_data['access_count']
}
if group_id:
update_params['group_id'] = group_id
update_result = await tx.run(update_query, **update_params)
updated_node = await update_result.single()
if not updated_node:
raise RuntimeError(
f"Version conflict detected for {node_label}[{node_id}]. "
f"Expected version {current_version}, but node was modified by another transaction."
)
return dict(updated_node)
# 执行事务
try:
result = await self.connector.execute_write_transaction(
update_transaction,
node_id=node_id,
node_label=node_label,
update_data=update_data,
group_id=group_id
)
return result
except Exception as e:
logger.error(
f"原子性更新失败: {node_label}[{node_id}], 错误: {str(e)}"
)
raise RuntimeError(
f"Failed to atomically update node: {str(e)}"
) from e

View File

@@ -0,0 +1,359 @@
"""
ACT-R Memory Activation Calculator
This module implements the unified Memory Activation model based on ACT-R
(Adaptive Control of Thought-Rational) cognitive architecture theory.
The calculator integrates BLA (Base-Level Activation) computation into the
Memory Activation formula, providing a single coherent model for memory strength
calculation that reflects both recency and frequency of access.
Formula: R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
Where:
- R(i): Memory activation value (0 to 1)
- offset: Minimum retention rate (prevents complete forgetting)
- λ: Forgetting rate (lambda_time / lambda_mem)
- t: Time since last access
- I: Importance score (0 to 1)
- t_k: Time since k-th access
- d: Decay constant (typically 0.5)
Reference: Anderson, J. R. (2007). How Can the Human Mind Occur in the Physical Universe?
"""
import math
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
class ACTRCalculator:
"""
Unified ACT-R Memory Activation Calculator.
This calculator implements the Memory Activation model that combines
recency and frequency effects into a single activation value computation.
It replaces the separate BLA calculation with an integrated approach.
Attributes:
decay_constant: Decay parameter d (typically 0.5)
forgetting_rate: Lambda parameter λ controlling forgetting speed
offset: Minimum retention rate (baseline memory strength)
max_history_length: Maximum number of access records to keep
"""
def __init__(
self,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1,
max_history_length: int = 100
):
"""
Initialize the ACT-R calculator.
Args:
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
max_history_length: Maximum access history length (default 100)
"""
self.decay_constant = decay_constant
self.forgetting_rate = forgetting_rate
self.offset = offset
self.max_history_length = max_history_length
def calculate_memory_activation(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5
) -> float:
"""
Calculate memory activation value using the unified Memory Activation formula.
This method computes R(i) = offset + (1-offset) * exp(-λ*t / Σ(I·t_k^(-d)))
The formula integrates:
- Recency effect: Recent accesses contribute more (via t)
- Frequency effect: Multiple accesses strengthen memory (via Σ)
- Importance weighting: Important memories decay slower (via I)
Args:
access_history: List of access timestamps (ISO format or datetime objects)
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
Returns:
float: Memory activation value between offset and 1.0
Raises:
ValueError: If access_history is empty or contains invalid data
"""
if not access_history:
raise ValueError("access_history cannot be empty")
if not (0.0 <= importance_score <= 1.0):
raise ValueError(f"importance_score must be between 0 and 1, got {importance_score}")
# Calculate time since last access (in days)
time_since_last = (current_time - last_access_time).total_seconds() / 86400.0
time_since_last = max(time_since_last, 0.0001) # Avoid division by zero
# Calculate BLA component: Σ(I·t_k^(-d))
bla_sum = 0.0
for access_time in access_history:
# Calculate time since this access (in days)
time_diff = (current_time - access_time).total_seconds() / 86400.0
time_diff = max(time_diff, 0.0001) # Avoid division by zero
# Add weighted power-law term: I * t_k^(-d)
bla_sum += importance_score * (time_diff ** (-self.decay_constant))
# Avoid division by zero in case of numerical issues
if bla_sum <= 0:
bla_sum = 0.0001
# Calculate Memory Activation: R(i) = offset + (1-offset) * exp(-λ*t / BLA)
exponent = -self.forgetting_rate * time_since_last / bla_sum
# Clamp exponent to avoid numerical overflow/underflow
exponent = max(min(exponent, 100), -100)
activation = self.offset + (1 - self.offset) * math.exp(exponent)
# Ensure activation is within valid range [offset, 1.0]
return max(self.offset, min(1.0, activation))
def trim_access_history(
self,
access_history: List[datetime],
current_time: datetime
) -> List[datetime]:
"""
Intelligently trim access history to prevent unbounded growth.
Strategy:
- Keep all records if under max_history_length
- If over limit, keep most recent 50% and sample from older records
- Preserves both recent accesses (high importance) and historical pattern
Args:
access_history: List of access timestamps (sorted or unsorted)
current_time: Current time for calculation
Returns:
List[datetime]: Trimmed access history
"""
if len(access_history) <= self.max_history_length:
return access_history
# Sort by time (most recent first)
sorted_history = sorted(access_history, reverse=True)
# Calculate split point (keep most recent 50%)
keep_recent_count = self.max_history_length // 2
# Keep most recent 50%
recent_records = sorted_history[:keep_recent_count]
# Sample from older records
older_records = sorted_history[keep_recent_count:]
sample_count = self.max_history_length - keep_recent_count
if len(older_records) <= sample_count:
# If older records fit, keep them all
sampled_older = older_records
else:
# Sample evenly from older records
step = len(older_records) / sample_count
sampled_older = [
older_records[int(i * step)]
for i in range(sample_count)
]
# Combine and return
trimmed_history = recent_records + sampled_older
return sorted(trimmed_history, reverse=True)
def get_forgetting_curve( # 预测激活值决定复习测试不同配置效果选择合适的d
self,
initial_time: datetime,
importance_score: float = 0.5,
days: int = 60
) -> List[Dict[str, Any]]:
"""
Generate forgetting curve data for visualization.
This method simulates how memory activation decays over time
for a single initial access, useful for understanding and
visualizing the forgetting behavior.
Args:
initial_time: Time of initial memory creation/access
importance_score: Importance weight (0 to 1, default 0.5)
days: Number of days to simulate (default 60)
Returns:
List of dictionaries with keys:
- 'day': Day number (0 to days)
- 'activation': Memory activation value
- 'retention_rate': Same as activation (for compatibility)
"""
curve_data = []
access_history = [initial_time]
for day in range(days + 1):
current_time = initial_time + timedelta(days=day)
try:
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=initial_time,
importance_score=importance_score
)
except ValueError:
# Handle edge cases
activation = self.offset
curve_data.append({
'day': day,
'activation': activation,
'retention_rate': activation # Alias for compatibility
})
return curve_data
def calculate_forgetting_score(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5
) -> float:
"""
Calculate forgetting score (inverse of activation).
Forgetting score = 1 - activation value
Higher score means more likely to be forgotten.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
Returns:
float: Forgetting score between 0 and (1 - offset)
"""
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
return 1.0 - activation
def should_forget(
self,
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5,
threshold: float = 0.3
) -> bool:
"""
Determine if a memory should be forgotten based on activation threshold.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
threshold: Activation threshold below which memory should be forgotten
Returns:
bool: True if activation < threshold (should forget), False otherwise
"""
activation = self.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
return activation < threshold
# Convenience functions for quick calculations
def calculate_activation(
access_history: List[datetime],
current_time: datetime,
last_access_time: datetime,
importance_score: float = 0.5,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1
) -> float:
"""
Quick function to calculate activation without creating a calculator instance.
Args:
access_history: List of access timestamps
current_time: Current time for calculation
last_access_time: Time of most recent access
importance_score: Importance weight (0 to 1, default 0.5)
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
Returns:
float: Memory activation value between offset and 1.0
"""
calculator = ACTRCalculator(
decay_constant=decay_constant,
forgetting_rate=forgetting_rate,
offset=offset
)
return calculator.calculate_memory_activation(
access_history=access_history,
current_time=current_time,
last_access_time=last_access_time,
importance_score=importance_score
)
def generate_forgetting_curve(
initial_time: datetime,
importance_score: float = 0.5,
days: int = 60,
decay_constant: float = 0.5,
forgetting_rate: float = 0.3,
offset: float = 0.1
) -> List[Dict[str, Any]]:
"""
Quick function to generate forgetting curve data.
Args:
initial_time: Time of initial memory creation/access
importance_score: Importance weight (0 to 1, default 0.5)
days: Number of days to simulate (default 60)
decay_constant: Decay parameter d (default 0.5)
forgetting_rate: Forgetting rate λ (default 0.3)
offset: Minimum retention rate (default 0.1)
Returns:
List of dictionaries with forgetting curve data
"""
calculator = ACTRCalculator(
decay_constant=decay_constant,
forgetting_rate=forgetting_rate,
offset=offset
)
return calculator.get_forgetting_curve(
initial_time=initial_time,
importance_score=importance_score,
days=days
)

View File

@@ -0,0 +1,195 @@
"""
遗忘引擎配置工具模块
本模块提供从数据库加载配置并创建遗忘引擎组件的辅助函数。
Functions:
calculate_forgetting_rate: 计算遗忘速率lambda_time / lambda_mem
load_actr_config_from_db: 从数据库加载 ACT-R 配置参数
create_actr_calculator_from_config: 从配置创建 ACTRCalculator 实例
"""
import logging
from typing import Optional, Dict, Any
from sqlalchemy.orm import Session
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
def calculate_forgetting_rate(lambda_time: float, lambda_mem: float) -> float:
"""
计算遗忘速率
公式forgetting_rate = lambda_time / lambda_mem
这个计算将两个独立的 lambda 参数组合成一个统一的遗忘速率参数,
用于 ACT-R 激活值计算。
Args:
lambda_time: 时间衰减参数0-1
lambda_mem: 记忆衰减参数0-1
Returns:
float: 遗忘速率
Raises:
ValueError: 如果 lambda_mem 为 0
Examples:
>>> calculate_forgetting_rate(0.5, 0.5)
1.0
>>> calculate_forgetting_rate(0.3, 0.5)
0.6
"""
if lambda_mem == 0:
raise ValueError("lambda_mem 不能为 0")
forgetting_rate = lambda_time / lambda_mem
logger.debug(
f"计算遗忘速率: lambda_time={lambda_time}, "
f"lambda_mem={lambda_mem}, "
f"forgetting_rate={forgetting_rate:.4f}"
)
return forgetting_rate
def load_actr_config_from_db(
db: Session,
config_id: Optional[int] = None
) -> Dict[str, Any]:
"""
从数据库加载 ACT-R 配置参数
从 PostgreSQL 的 data_config 表读取配置参数,
并计算派生参数(如 forgetting_rate
Args:
db: 数据库会话
config_id: 配置 ID可选如果为 None 则使用默认值)
Returns:
Dict[str, Any]: 配置参数字典,包含:
- decay_constant: 衰减常数 d
- lambda_time: 时间衰减参数
- lambda_mem: 记忆衰减参数
- forgetting_rate: 遗忘速率(根据 lambda_time / lambda_mem 计算得出)
- offset: 偏移量
- max_history_length: 访问历史最大长度
- forgetting_threshold: 遗忘阈值
- min_days_since_access: 最小未访问天数
- enable_llm_summary: 是否使用 LLM 生成摘要
- max_merge_batch_size: 单次最大融合节点对数
- forgetting_interval_hours: 遗忘周期间隔
注意llm_id 不包含在返回的配置中,需要时由 forgetting_strategy 直接从数据库读取
Raises:
ValueError: 如果指定的 config_id 不存在
"""
# 必须指定 config_id
if config_id is None:
logger.error("未指定 config_id无法加载配置")
raise ValueError("config_id 不能为空,必须指定一个有效的配置 ID")
# 从数据库加载配置
try:
repository = DataConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None:
logger.error(f"配置不存在: config_id={config_id}")
raise ValueError(f"配置不存在: config_id={config_id}")
# 读取配置参数(信任数据库默认值)
lambda_time = db_config.lambda_time
lambda_mem = db_config.lambda_mem
decay_constant = db_config.decay_constant
offset = db_config.offset
max_history_length = db_config.max_history_length
forgetting_threshold = db_config.forgetting_threshold
min_days_since_access = db_config.min_days_since_access
enable_llm_summary = db_config.enable_llm_summary
max_merge_batch_size = db_config.max_merge_batch_size
forgetting_interval_hours = db_config.forgetting_interval_hours
# 计算 forgetting_rate
forgetting_rate = calculate_forgetting_rate(lambda_time, lambda_mem)
config = {
'decay_constant': decay_constant,
'lambda_time': lambda_time,
'lambda_mem': lambda_mem,
'forgetting_rate': forgetting_rate,
'offset': offset,
'max_history_length': max_history_length,
'forgetting_threshold': forgetting_threshold,
'min_days_since_access': min_days_since_access,
'enable_llm_summary': enable_llm_summary,
'max_merge_batch_size': max_merge_batch_size,
'forgetting_interval_hours': forgetting_interval_hours
# 注意llm_id 不包含在配置响应中,仅在内部使用
}
logger.info(
f"成功加载 ACT-R 配置: config_id={config_id}, "
f"forgetting_rate={forgetting_rate:.4f}"
)
return config
except Exception as e:
logger.error(f"加载 ACT-R 配置失败: config_id={config_id}, 错误: {str(e)}")
raise
def create_actr_calculator_from_config(
db: Session,
config_id: Optional[int] = None
) -> ACTRCalculator:
"""
从数据库配置创建 ACTRCalculator 实例
这是创建 ACTRCalculator 的推荐方式,确保使用数据库中的配置参数。
Args:
db: 数据库会话
config_id: 配置 ID可选如果为 None 则使用默认值)
Returns:
ACTRCalculator: 配置好的 ACT-R 计算器实例
Raises:
ValueError: 如果指定的 config_id 不存在
Examples:
>>> from sqlalchemy.orm import Session
>>> db = Session()
>>> calculator = create_actr_calculator_from_config(db, config_id=1)
>>> # 使用计算器
>>> activation = calculator.calculate_memory_activation(...)
"""
# 加载配置
config = load_actr_config_from_db(db, config_id)
# 创建计算器
calculator = ACTRCalculator(
decay_constant=config['decay_constant'],
forgetting_rate=config['forgetting_rate'],
offset=config['offset'],
max_history_length=config['max_history_length']
)
logger.info(
f"创建 ACTRCalculator: config_id={config_id}, "
f"decay_constant={config['decay_constant']}, "
f"forgetting_rate={config['forgetting_rate']:.4f}, "
f"offset={config['offset']}"
)
return calculator

View File

@@ -0,0 +1,351 @@
"""
遗忘调度器模块
本模块实现遗忘周期的调度和管理,负责:
1. 手动触发遗忘周期
2. 批量处理可遗忘节点(限制批量大小)
3. 按激活值优先级排序(激活值最低的优先)
4. 进度跟踪和日志记录
5. 生成遗忘报告
注意:定期调度功能已迁移到 Celery Beat见 app/tasks.py 中的 run_forgetting_cycle_task
Classes:
ForgettingScheduler: 遗忘调度器,提供遗忘周期管理功能
"""
import logging
from typing import Dict, Any, Optional
from datetime import datetime
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
class ForgettingScheduler:
"""
遗忘调度器
管理遗忘周期的执行,实现批量处理、优先级排序和进度跟踪功能。
核心功能:
1. 运行遗忘周期:识别可遗忘节点并批量融合
2. 优先级排序:优先处理激活值最低的节点对
3. 批量限制:限制单次处理的节点对数量
4. 进度跟踪:每完成 10% 记录一次日志
5. 遗忘报告:生成详细的执行报告
注意:定期调度功能已迁移到 Celery Beat 定时任务
Attributes:
forgetting_strategy: 遗忘策略执行器实例
connector: Neo4j 连接器实例
is_running: 是否正在运行遗忘周期
"""
def __init__(
self,
forgetting_strategy: ForgettingStrategy,
connector: Neo4jConnector
):
"""
初始化遗忘调度器
Args:
forgetting_strategy: 遗忘策略执行器实例
connector: Neo4j 连接器实例
"""
self.forgetting_strategy = forgetting_strategy
self.connector = connector
self.is_running = False
logger.info("初始化遗忘调度器")
async def run_forgetting_cycle(
self,
group_id: Optional[str] = None,
max_merge_batch_size: int = 100,
min_days_since_access: int = 30,
config_id: Optional[int] = None,
db = None
) -> Dict[str, Any]:
"""
运行一次完整的遗忘周期
Args:
group_id: 组 ID可选用于过滤特定组的节点
max_merge_batch_size: 单次最大融合节点对数(默认 100
min_days_since_access: 最小未访问天数(默认 30 天)
config_id: 配置ID可选用于获取 llm_id
db: 数据库会话(可选,用于获取 llm_id
Returns:
Dict[str, Any]: 遗忘报告,包含:
- merged_count: 融合的节点对数量
- nodes_before: 遗忘前的节点总数
- nodes_after: 遗忘后的节点总数
- reduction_rate: 节点减少率0-1
- duration_seconds: 执行耗时(秒)
- start_time: 开始时间ISO 格式)
- end_time: 结束时间ISO 格式)
- failed_count: 失败的融合数量
- success_rate: 成功率0-1
Raises:
RuntimeError: 如果已有遗忘周期正在运行
"""
# 检查是否已有遗忘周期在运行
if self.is_running:
raise RuntimeError("遗忘周期已在运行中,请等待当前周期完成")
self.is_running = True
start_time = datetime.now()
start_time_iso = start_time.isoformat()
logger.info(
f"开始遗忘周期: group_id={group_id}, "
f"max_batch={max_merge_batch_size}, "
f"min_days={min_days_since_access}"
)
try:
# 步骤1统计遗忘前的节点数量
nodes_before = await self._count_knowledge_nodes(group_id)
logger.info(f"遗忘前节点总数: {nodes_before}")
# 步骤2识别可遗忘的节点对
forgettable_pairs = await self.forgetting_strategy.find_forgettable_nodes(
group_id=group_id,
min_days_since_access=min_days_since_access
)
total_forgettable = len(forgettable_pairs)
logger.info(f"识别到 {total_forgettable} 个可遗忘节点对")
if total_forgettable == 0:
logger.info("没有可遗忘的节点对,遗忘周期结束")
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
report = {
'merged_count': 0,
'nodes_before': nodes_before,
'nodes_after': nodes_before,
'reduction_rate': 0.0,
'duration_seconds': duration,
'start_time': start_time_iso,
'end_time': end_time.isoformat(),
'failed_count': 0,
'success_rate': 1.0
}
logger.info("没有可遗忘的节点对,遗忘周期结束")
return report
# 步骤3按激活值排序激活值最低的优先
# avg_activation 已经在 find_forgettable_nodes 中计算并排序
# 这里只需要确认排序是正确的(升序)
sorted_pairs = sorted(
forgettable_pairs,
key=lambda x: x['avg_activation']
)
# 步骤4限制批量大小
pairs_to_process = sorted_pairs[:max_merge_batch_size]
actual_batch_size = len(pairs_to_process)
logger.info(
f"将处理 {actual_batch_size} 个节点对 "
f"(限制: {max_merge_batch_size})"
)
# 步骤5批量融合节点每 10% 记录进度
merged_count = 0
failed_count = 0
skipped_count = 0 # 跳过的节点对数量(节点已被处理)
progress_interval = max(1, actual_batch_size // 10) # 每 10% 记录一次
# 跟踪已处理的节点 ID避免重复处理
processed_statement_ids = set()
processed_entity_ids = set()
# 预先过滤掉重复的节点对
unique_pairs = []
for pair in pairs_to_process:
statement_id = pair['statement_id']
entity_id = pair['entity_id']
# 如果节点已被标记为处理,跳过
if statement_id in processed_statement_ids or entity_id in processed_entity_ids:
skipped_count += 1
logger.debug(
f"预过滤:跳过重复节点对 Statement[{statement_id}] + Entity[{entity_id}]"
)
continue
# 标记节点为已处理
processed_statement_ids.add(statement_id)
processed_entity_ids.add(entity_id)
unique_pairs.append(pair)
logger.info(
f"预过滤完成:原始 {actual_batch_size} 对,去重后 {len(unique_pairs)} 对,"
f"跳过 {skipped_count} 对重复节点"
)
# 更新实际处理的批次大小
actual_batch_size = len(unique_pairs)
progress_interval = max(1, actual_batch_size // 10) # 重新计算进度间隔
for idx, pair in enumerate(unique_pairs, start=1):
statement_id = pair['statement_id']
entity_id = pair['entity_id']
try:
# 准备节点数据
statement_node = {
'statement_id': statement_id,
'statement_text': pair['statement_text'],
'statement_activation': pair['statement_activation'],
'statement_importance': pair['statement_importance'],
'group_id': group_id
}
entity_node = {
'entity_id': entity_id,
'entity_name': pair['entity_name'],
'entity_type': pair['entity_type'],
'entity_activation': pair['entity_activation'],
'entity_importance': pair['entity_importance'],
'group_id': group_id
}
# 融合节点
await self.forgetting_strategy.merge_nodes_to_summary(
statement_node=statement_node,
entity_node=entity_node,
config_id=config_id,
db=db
)
merged_count += 1
# 进度跟踪:每 10% 记录一次
if actual_batch_size > 0 and (idx % progress_interval == 0 or idx == actual_batch_size):
progress_pct = (idx / actual_batch_size) * 100
logger.info(
f"遗忘进度: {idx}/{actual_batch_size} "
f"({progress_pct:.1f}%), "
f"已融合: {merged_count}, 失败: {failed_count}"
)
except Exception as e:
failed_count += 1
# 检查是否是节点不存在的错误
if "nodes may not exist" in str(e):
logger.warning(
f"节点对 ({idx}/{actual_batch_size}) 的节点不存在(可能已被其他操作删除): "
f"Statement[{statement_id}] + Entity[{entity_id}]"
)
else:
logger.error(
f"融合节点对失败 ({idx}/{actual_batch_size}): "
f"Statement[{statement_id}] + Entity[{entity_id}], "
f"错误: {str(e)}"
)
# 继续处理剩余节点
continue
# 步骤6统计遗忘后的节点数量
nodes_after = await self._count_knowledge_nodes(group_id)
logger.info(f"遗忘后节点总数: {nodes_after}")
# 步骤7生成遗忘报告
end_time = datetime.now()
duration = (end_time - start_time).total_seconds()
# 计算节点减少率
if nodes_before > 0:
reduction_rate = (nodes_before - nodes_after) / nodes_before
else:
reduction_rate = 0.0
# 计算成功率
if actual_batch_size > 0:
success_rate = merged_count / actual_batch_size
else:
success_rate = 1.0
report = {
'merged_count': merged_count,
'nodes_before': nodes_before,
'nodes_after': nodes_after,
'reduction_rate': reduction_rate,
'duration_seconds': duration,
'start_time': start_time_iso,
'end_time': end_time.isoformat(),
'failed_count': failed_count,
'success_rate': success_rate
}
logger.info(
f"遗忘周期完成: "
f"融合 {merged_count} 对节点, "
f"失败 {failed_count} 对, "
f"节点减少 {nodes_before - nodes_after}"
f"({reduction_rate:.2%}), "
f"耗时 {duration:.2f}"
)
return report
except Exception as e:
logger.error(f"遗忘周期执行失败: {str(e)}")
raise
finally:
self.is_running = False
# ==================== 私有辅助方法 ====================
async def _count_knowledge_nodes(
self,
group_id: Optional[str] = None
) -> int:
"""
统计知识层节点总数
统计 Statement、ExtractedEntity 和 MemorySummary 节点的总数。
Args:
group_id: 组 ID可选用于过滤特定组的节点
Returns:
int: 知识层节点总数
"""
query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
"""
if group_id:
query += " AND n.group_id = $group_id"
query += """
RETURN count(n) as total
"""
params = {}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
if results:
return results[0]['total']
return 0

View File

@@ -0,0 +1,611 @@
"""
遗忘策略执行器模块
本模块实现基于 ACT-R 激活值的遗忘策略,负责:
1. 识别低激活值的节点对Statement-Entity
2. 将低激活值节点融合为 MemorySummary 节点
3. 使用 LLM 生成高质量摘要(可选)
4. 保留溯源信息并删除原始节点
Classes:
ForgettingStrategy: 遗忘策略执行器,提供节点识别和融合功能
"""
import logging
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.core.memory.storage_services.forgetting_engine.actr_calculator import ACTRCalculator
logger = logging.getLogger(__name__)
class ForgettingStrategy:
"""
遗忘策略执行器
基于 ACT-R 激活值识别和融合低价值记忆节点。
实现了完整的遗忘周期:识别 → 融合 → 删除。
核心功能:
1. 识别可遗忘节点:激活值低于阈值且长期未访问的 Statement-Entity 对
2. 节点融合:创建 MemorySummary 节点,继承较高的激活值和重要性
3. LLM 摘要生成:使用 LLM 生成语义摘要(可降级到简单拼接)
4. 溯源保留:记录原始节点 ID保持可追溯性
Attributes:
connector: Neo4j 连接器实例
actr_calculator: ACT-R 激活值计算器实例
forgetting_threshold: 遗忘阈值(激活值低于此值的节点可被遗忘)
"""
def __init__(
self,
connector: Neo4jConnector,
actr_calculator: ACTRCalculator,
forgetting_threshold: float = 0.3,
enable_llm_summary: bool = True
):
"""
初始化遗忘策略执行器
Args:
connector: Neo4j 连接器实例
actr_calculator: ACT-R 激活值计算器实例
forgetting_threshold: 遗忘阈值(默认 0.3
enable_llm_summary: 是否启用 LLM 摘要生成(默认 True
"""
self.connector = connector
self.actr_calculator = actr_calculator
self.forgetting_threshold = forgetting_threshold
self.enable_llm_summary = enable_llm_summary
logger.info(
f"初始化遗忘策略执行器: threshold={forgetting_threshold}, "
f"enable_llm_summary={enable_llm_summary}"
)
async def calculate_forgetting_score(
self,
activation_value: float
) -> float:
"""
计算遗忘分数
遗忘分数 = 1 - 激活值
分数越高,越容易被遗忘。
注意:激活值已经包含了 importance_score 的权重,
因此不需要单独考虑重要性分数。
Args:
activation_value: 节点的激活值0-1
Returns:
float: 遗忘分数0-1值越高越容易被遗忘
"""
return 1.0 - activation_value
async def find_forgettable_nodes(
self,
group_id: Optional[str] = None,
min_days_since_access: int = 30
) -> List[Dict[str, Any]]:
"""
识别可遗忘的节点对
查找满足以下条件的 Statement-Entity 节点对:
1. 两个节点的激活值都低于遗忘阈值
2. 两个节点都至少 min_days_since_access 天未被访问
3. Statement 和 Entity 之间存在关系边
Args:
group_id: 组 ID可选用于过滤特定组的节点
min_days_since_access: 最小未访问天数(默认 30 天)
Returns:
List[Dict[str, Any]]: 可遗忘节点对列表,每个元素包含:
- statement_id: Statement 节点 ID
- statement_text: Statement 文本内容
- statement_activation: Statement 激活值
- statement_importance: Statement 重要性分数
- statement_last_access: Statement 最后访问时间
- entity_id: Entity 节点 ID
- entity_name: Entity 名称
- entity_type: Entity 类型
- entity_activation: Entity 激活值
- entity_importance: Entity 重要性分数
- entity_last_access: Entity 最后访问时间
- avg_activation: 平均激活值(用于排序)
"""
# 计算时间阈值
cutoff_time = datetime.now() - timedelta(days=min_days_since_access)
cutoff_time_iso = cutoff_time.isoformat()
# 构建查询
query = """
MATCH (s:Statement)-[r]-(e:ExtractedEntity)
WHERE s.activation_value IS NOT NULL
AND e.activation_value IS NOT NULL
AND s.activation_value < $threshold
AND e.activation_value < $threshold
AND s.last_access_time < $cutoff_time
AND e.last_access_time < $cutoff_time
AND (e.entity_type IS NULL OR e.entity_type <> 'Person')
"""
if group_id:
query += " AND s.group_id = $group_id AND e.group_id = $group_id"
query += """
RETURN s.id as statement_id,
s.statement as statement_text,
s.activation_value as statement_activation,
s.importance_score as statement_importance,
s.last_access_time as statement_last_access,
e.id as entity_id,
e.name as entity_name,
e.entity_type as entity_type,
e.activation_value as entity_activation,
e.importance_score as entity_importance,
e.last_access_time as entity_last_access,
(s.activation_value + e.activation_value) / 2.0 as avg_activation
ORDER BY avg_activation ASC
"""
params = {
'threshold': self.forgetting_threshold,
'cutoff_time': cutoff_time_iso
}
if group_id:
params['group_id'] = group_id
results = await self.connector.execute_query(query, **params)
logger.info(
f"识别到 {len(results)} 个可遗忘节点对 "
f"(threshold={self.forgetting_threshold}, "
f"min_days={min_days_since_access})"
)
return results
async def merge_nodes_to_summary(
self,
statement_node: Dict[str, Any],
entity_node: Dict[str, Any],
config_id: Optional[int] = None,
db = None
) -> str:
"""
将 Statement 和 Entity 节点融合为 MemorySummary 节点
融合过程:
1. 生成摘要内容(使用 LLM 或简单拼接)
2. 创建 MemorySummary 节点,继承较高的激活值和重要性分数
3. 删除原始 Statement 和 Entity 节点
4. 保留溯源信息original_statement_id, original_entity_id
Args:
statement_node: Statement 节点数据,必须包含:
- statement_id: 节点 ID
- statement_text: 文本内容
- statement_activation: 激活值
- statement_importance: 重要性分数
entity_node: Entity 节点数据,必须包含:
- entity_id: 节点 ID
- entity_name: 实体名称
- entity_type: 实体类型
- entity_activation: 激活值
- entity_importance: 重要性分数
config_id: 配置ID可选用于获取 llm_id
db: 数据库会话(可选,用于获取 llm_id
Returns:
str: 创建的 MemorySummary 节点 ID
Raises:
ValueError: 如果节点数据不完整
RuntimeError: 如果融合操作失败
"""
# 验证输入数据
required_statement_keys = [
'statement_id', 'statement_text',
'statement_activation', 'statement_importance'
]
required_entity_keys = [
'entity_id', 'entity_name', 'entity_type',
'entity_activation', 'entity_importance'
]
for key in required_statement_keys:
if key not in statement_node:
raise ValueError(f"Statement 节点缺少必需字段: {key}")
for key in required_entity_keys:
if key not in entity_node:
raise ValueError(f"Entity 节点缺少必需字段: {key}")
# 验证实体类型:不允许融合 Person 类型的实体
if entity_node.get('entity_type') == 'Person':
raise ValueError(
f"不允许融合 Person 类型的实体: entity_id={entity_node.get('entity_id')}, "
f"entity_name={entity_node.get('entity_name')}"
)
# 提取节点信息
statement_id = statement_node['statement_id']
statement_text = statement_node['statement_text']
statement_activation = statement_node['statement_activation']
statement_importance = statement_node['statement_importance']
entity_id = entity_node['entity_id']
entity_name = entity_node['entity_name']
entity_type = entity_node['entity_type']
entity_activation = entity_node['entity_activation']
entity_importance = entity_node['entity_importance']
# 生成摘要内容
summary_text = await self._generate_summary(
statement_text=statement_text,
entity_name=entity_name,
entity_type=entity_type,
config_id=config_id,
db=db
)
# 计算继承的激活值和重要性(取较高值)
inherited_activation = max(statement_activation, entity_activation)
inherited_importance = max(statement_importance, entity_importance)
# 创建 MemorySummary 节点
current_time = datetime.now()
current_time_iso = current_time.isoformat()
# 生成新的 MemorySummary ID
import uuid
summary_id = f"summary_{uuid.uuid4().hex[:16]}"
# 获取 group_id从 statement 或 entity 节点)
group_id = statement_node.get('group_id') or entity_node.get('group_id')
# 使用事务创建 MemorySummary 并删除原节点
async def merge_transaction(tx, **params):
"""事务函数:创建摘要节点并删除原节点"""
query = """
// 首先检查节点是否存在
OPTIONAL MATCH (s:Statement {id: $statement_id})
OPTIONAL MATCH (e:ExtractedEntity {id: $entity_id})
// 如果任一节点不存在,直接返回 null不执行后续操作
WITH s, e
WHERE s IS NOT NULL AND e IS NOT NULL
// 创建 MemorySummary 节点
CREATE (ms:MemorySummary {
id: $summary_id,
summary: $summary_text,
original_statement_id: $statement_id,
original_entity_id: $entity_id,
activation_value: $inherited_activation,
importance_score: $inherited_importance,
access_history: [$current_time],
last_access_time: $current_time,
access_count: 1,
version: 1,
group_id: $group_id,
created_at: datetime($current_time),
merged_at: datetime($current_time)
})
// 转移 Statement 的出边到 MemorySummary只转移目标节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (s)-[r_out]->(target)
WHERE target <> e AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Statement 的入边到 MemorySummary只转移源节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(s)
WHERE r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_statement = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Entity 的出边到 MemorySummary只转移目标节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (e)-[r_out]->(target)
WHERE target <> s AND r_out IS NOT NULL AND target IS NOT NULL
FOREACH (_ IN CASE WHEN target IS NOT NULL THEN [1] ELSE [] END |
MERGE (ms)-[new_rel:DERIVED_FROM]->(target)
ON CREATE SET
new_rel = properties(r_out),
new_rel.original_relationship_type = type(r_out),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 转移 Entity 的入边到 MemorySummary只转移源节点仍存在的边
WITH ms, s, e
CALL (ms, s, e) {
OPTIONAL MATCH (source)-[r_in]->(e)
WHERE source <> s AND r_in IS NOT NULL AND source IS NOT NULL
FOREACH (_ IN CASE WHEN source IS NOT NULL THEN [1] ELSE [] END |
MERGE (source)-[new_rel:DERIVED_FROM]->(ms)
ON CREATE SET
new_rel = properties(r_in),
new_rel.original_relationship_type = type(r_in),
new_rel.merged_from_entity = true,
new_rel.merge_count = 1
ON MATCH SET
new_rel.merge_count = coalesce(new_rel.merge_count, 0) + 1
)
}
// 删除原始节点
WITH ms, s, e
DETACH DELETE s, e
RETURN ms.id as summary_id
"""
result = await tx.run(query, **params)
record = await result.single()
if not record:
raise RuntimeError("Failed to create MemorySummary node - nodes may not exist")
return record['summary_id']
params = {
'summary_id': summary_id,
'summary_text': summary_text,
'statement_id': statement_id,
'entity_id': entity_id,
'inherited_activation': inherited_activation,
'inherited_importance': inherited_importance,
'current_time': current_time_iso,
'group_id': group_id
}
try:
created_summary_id = await self.connector.execute_write_transaction(
merge_transaction,
**params
)
logger.info(
f"成功融合节点: Statement[{statement_id}] + Entity[{entity_id}] "
f"-> MemorySummary[{created_summary_id}], "
f"activation={inherited_activation:.4f}, "
f"importance={inherited_importance:.4f}"
)
return created_summary_id
except Exception as e:
# 记录详细的错误信息,包括异常类型和堆栈
import traceback
error_details = traceback.format_exc()
logger.error(
f"融合节点失败: Statement[{statement_id}] + Entity[{entity_id}], "
f"错误类型: {type(e).__name__}, "
f"错误信息: {str(e)}, "
f"详细堆栈:\n{error_details}"
)
raise RuntimeError(
f"融合节点失败: {str(e)}"
) from e
# ==================== 私有辅助方法 ====================
async def _generate_summary(
self,
statement_text: str,
entity_name: str,
entity_type: str,
config_id: Optional[int] = None,
db = None
) -> str:
"""
生成摘要内容
优先使用 LLM 生成高质量摘要,如果 LLM 不可用或失败,
则降级到简单文本拼接。
Args:
statement_text: Statement 文本内容
entity_name: Entity 名称
entity_type: Entity 类型
config_id: 配置ID可选用于获取 llm_id
db: 数据库会话(可选,用于获取 llm_id
Returns:
str: 生成的摘要文本(最多 200 个字符)
"""
# 如果配置禁用 LLM 摘要,直接使用简单拼接
if not self.enable_llm_summary:
logger.info("LLM 摘要生成已禁用,使用简单拼接")
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
# 尝试获取 LLM 客户端
llm_client = None
if config_id is not None and db is not None:
try:
llm_client = await self._get_llm_client(db, config_id)
except Exception as e:
logger.warning(f"获取 LLM 客户端失败: {str(e)}")
# 如果没有 LLM 客户端,直接使用简单拼接
if llm_client is None:
logger.info("未能获取 LLM 客户端,使用简单拼接")
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
# 尝试使用 LLM 生成摘要
try:
summary = await self._generate_llm_summary(
statement_text=statement_text,
entity_name=entity_name,
entity_type=entity_type,
llm_client=llm_client
)
# 限制长度为 200 个字符
if len(summary) > 200:
summary = f"{summary[:197]}..."
logger.info(f"使用 LLM 生成摘要: {summary}")
return summary
except Exception as e:
logger.warning(
f"LLM 摘要生成失败,降级到简单拼接: {str(e)}"
)
return self._simple_concatenation(
statement_text, entity_name, entity_type
)
async def _get_llm_client(self, db, config_id: int):
"""
从数据库获取 LLM 客户端
Args:
db: 数据库会话
config_id: 配置ID
Returns:
LLM 客户端实例,如果无法获取则返回 None
"""
try:
from app.repositories.data_config_repository import DataConfigRepository
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 从数据库读取配置
repository = DataConfigRepository()
db_config = repository.get_by_id(db, config_id)
if db_config is None or db_config.llm_id is None:
logger.warning(f"配置 {config_id} 不存在或未设置 llm_id")
return None
# 创建 LLM 客户端
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(str(db_config.llm_id))
logger.info(f"成功获取 LLM 客户端: config_id={config_id}, llm_id={db_config.llm_id}")
return llm_client
except Exception as e:
logger.error(f"获取 LLM 客户端失败: {str(e)}")
return None
async def _generate_llm_summary(
self,
statement_text: str,
entity_name: str,
entity_type: str,
llm_client
) -> str:
"""
使用 LLM 生成高质量摘要
Args:
statement_text: Statement 文本内容
entity_name: Entity 名称
entity_type: Entity 类型
llm_client: LLM 客户端实例
Returns:
str: LLM 生成的摘要文本
Raises:
Exception: 如果 LLM 调用失败
"""
# 构建提示词
prompt = f"""请为以下记忆片段生成一个简洁的摘要不超过200个字符
实体名称: {entity_name}
实体类型: {entity_type}
陈述内容: {statement_text}
要求:
1. 摘要应该保留核心语义信息
2. 长度不超过200个字符
3. 使用简洁、自然的中文表达
4. 只返回摘要文本,不要包含其他内容
摘要:"""
# 调用 LLM直接传递 prompt 字符串)
response = await llm_client.chat(prompt)
# 提取摘要文本
if isinstance(response, str):
summary = response.strip()
elif hasattr(response, 'content'):
summary = response.content.strip()
else:
summary = str(response).strip()
return summary
def _simple_concatenation(
self,
statement_text: str,
entity_name: str,
entity_type: str
) -> str:
"""
简单文本拼接生成摘要
降级策略:当 LLM 不可用时使用。
格式:[实体类型]实体名称: 陈述内容
Args:
statement_text: Statement 文本内容
entity_name: Entity 名称
entity_type: Entity 类型
Returns:
str: 拼接的摘要文本(最多 200 个字符)
"""
# 构建简单摘要
summary = f"[{entity_type}]{entity_name}: {statement_text}"
# 限制长度为 200 个字符(注意:这里的长度是字符数,不是字节数)
if len(summary) > 200:
# 截断并添加省略号
summary = f"{summary[:197]}..."
return summary