Feature/actr forget (#55)

* [changes]Request to remove 'config_id' has been received.

* [add]Add the access history record table

* [changes]Request to remove 'config_id' has been received.

* [add]Add the access history record table

* [add]Obtain the record of the forgetting trend

* [changes] Apply the modifications suggested by the AI review.
This commit is contained in:
乐力齐
2026-01-08 15:15:13 +08:00
committed by GitHub
parent 7871663cae
commit a4af0f7432
8 changed files with 542 additions and 33 deletions

View File

@@ -11,7 +11,7 @@
"""
from typing import Optional, Dict, Any, Tuple
from datetime import datetime
from datetime import datetime, timezone
from sqlalchemy.orm import Session
@@ -24,18 +24,54 @@ from app.core.memory.storage_services.forgetting_engine.config_utils import (
)
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.forgetting_cycle_history_repository import ForgettingCycleHistoryRepository
# 获取API专用日志器
api_logger = get_api_logger()
def convert_neo4j_datetime_to_python(value: Any) -> Optional[datetime]:
    """
    Convert a Neo4j temporal value into a Python ``datetime``.

    Accepted inputs:
      * ``None`` (returned as ``None``);
      * any object exposing ``to_native()`` (how Neo4j temporal objects are
        recognised here);
      * a Python ``datetime`` (returned unchanged);
      * an ISO 8601 string, or any other object whose ``str()`` form is one.

    A trailing ``Z`` is rewritten to ``+00:00`` before parsing so that
    ``datetime.fromisoformat`` accepts it on interpreters older than 3.11.

    Args:
        value: The value to convert.

    Returns:
        The converted ``datetime``, or ``None`` when ``value`` is ``None`` or
        cannot be converted (a warning is logged in the latter case). No
        timezone is attached to naive results; callers decide how to
        interpret them.
    """
    # Local import keeps this block self-contained; the module only imports
    # ``datetime`` and ``timezone`` at the top level.
    from datetime import date

    if value is None:
        return None

    try:
        # Neo4j temporal object
        if hasattr(value, 'to_native'):
            native = value.to_native()
            if isinstance(native, datetime):
                return native
            # A date-only temporal converts to ``datetime.date``; promote it
            # to midnight so callers always receive a real ``datetime`` and
            # can safely use ``.tzinfo`` / ``.timestamp()``.
            if isinstance(native, date):
                return datetime(native.year, native.month, native.day)
            # Anything else (e.g. a time-of-day without a date) cannot be
            # represented as a datetime; fall through to the warning path.
            raise TypeError(
                f"to_native() returned {type(native).__name__}, expected datetime"
            )

        # Python datetime object
        if isinstance(value, datetime):
            return value

        # ISO 8601 string, or any other type via its string form
        text = value if isinstance(value, str) else str(value)
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        return datetime.fromisoformat(text)
    except Exception as e:
        # Best-effort conversion: a bad value must not break the caller.
        api_logger.warning(f"转换时间失败: {value} (类型: {type(value).__name__}), 错误: {e}")
        return None
class MemoryForgetService:
"""遗忘引擎服务类"""
def __init__(self):
    """Initialize the service by creating the repositories it depends on."""
    # Data-configuration repository.
    self.config_repository = DataConfigRepository()
    # Repository used to store and read forgetting-cycle history records.
    self.history_repository = ForgettingCycleHistoryRepository()
def _get_neo4j_connector(self) -> Neo4jConnector:
"""
@@ -161,10 +197,101 @@ class MemoryForgetService:
'low_activation_nodes': 0
}
async def _get_pending_forgetting_nodes(
    self,
    connector: Neo4jConnector,
    group_id: str,
    forgetting_threshold: float,
    min_days_since_access: int,
    limit: int = 20
) -> list[Dict[str, Any]]:
    """
    Fetch the nodes that currently qualify for forgetting.

    A node qualifies when it is a Statement, ExtractedEntity or MemorySummary
    belonging to ``group_id``, its activation value is set and below
    ``forgetting_threshold``, and its last access time is set and older than
    ``min_days_since_access`` days before now (UTC). Results are ordered by
    ascending activation value.

    Args:
        connector: Neo4j connector used to run the query.
        group_id: Group ID the nodes must belong to.
        forgetting_threshold: Activation values below this are candidates.
        min_days_since_access: Minimum number of days since the last access.
        limit: Maximum number of nodes to return.

    Returns:
        A list of dicts with the keys ``node_id``, ``node_type``,
        ``content_summary``, ``activation_value`` and ``last_access_time``
        (Unix timestamp in seconds, ``0`` when the stored value cannot be
        converted).
    """
    from datetime import timedelta

    # Cut-off instant as an ISO 8601 string in UTC.
    cutoff = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
    cutoff_iso = cutoff.strftime('%Y-%m-%dT%H:%M:%S.%fZ')

    query = """
    MATCH (n)
    WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
    AND n.group_id = $group_id
    AND n.activation_value IS NOT NULL
    AND n.activation_value < $threshold
    AND n.last_access_time IS NOT NULL
    AND datetime(n.last_access_time) < datetime($min_access_time_str)
    RETURN
    elementId(n) as node_id,
    labels(n)[0] as node_type,
    CASE
    WHEN n:Statement THEN n.statement
    WHEN n:ExtractedEntity THEN n.name
    WHEN n:MemorySummary THEN n.content
    ELSE ''
    END as content_summary,
    n.activation_value as activation_value,
    n.last_access_time as last_access_time
    ORDER BY n.activation_value ASC
    LIMIT $limit
    """

    rows = await connector.execute_query(
        query,
        group_id=group_id,
        threshold=forgetting_threshold,
        min_access_time_str=cutoff_iso,
        limit=limit,
    )

    # Lower-cased labels that are reported under a shorter name.
    short_names = {
        'extractedentity': 'entity',
        'memorysummary': 'summary',
    }

    def _epoch_seconds(raw: Any) -> int:
        """Return ``raw`` as a Unix timestamp in seconds, or 0 if unusable."""
        moment = convert_neo4j_datetime_to_python(raw)
        if not moment:
            return 0
        # Naive datetimes are taken to be UTC so the timestamp does not
        # depend on the server's local timezone.
        if moment.tzinfo is None:
            moment = moment.replace(tzinfo=timezone.utc)
        return int(moment.timestamp())

    pending_nodes = []
    for row in rows:
        label = row['node_type'].lower()
        node_type = short_names.get(label, label)
        accessed_at = _epoch_seconds(row['last_access_time'])

        pending_nodes.append({
            'node_id': str(row['node_id']),
            'node_type': node_type,
            'content_summary': row['content_summary'] or '',
            'activation_value': row['activation_value'],
            'last_access_time': accessed_at
        })

    return pending_nodes
async def trigger_forgetting_cycle(
self,
db: Session,
group_id: Optional[str] = None,
group_id: str,
max_merge_batch_size: Optional[int] = None,
min_days_since_access: Optional[int] = None,
config_id: Optional[int] = None
@@ -176,10 +303,10 @@ class MemoryForgetService:
Args:
db: 数据库会话
group_id: 组ID可选
group_id: 组ID即终端用户ID必填
max_merge_batch_size: 最大融合批次大小(可选)
min_days_since_access: 最小未访问天数(可选)
config_id: 配置ID可选
config_id: 配置ID必填,由控制器层通过 group_id 获取
Returns:
dict: 遗忘报告
@@ -187,6 +314,9 @@ class MemoryForgetService:
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
# 记录执行开始时间
execution_time = datetime.now()
# 运行遗忘周期LLM 客户端将在需要时由 forgetting_strategy 内部获取)
report = await forgetting_scheduler.run_forgetting_cycle(
group_id=group_id,
@@ -202,6 +332,58 @@ class MemoryForgetService:
f"耗时 {report['duration_seconds']:.2f}"
)
# 获取当前的激活值统计(用于记录历史)
try:
connector = forgetting_scheduler.connector
stats_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
AND n.group_id = $group_id
RETURN
count(n) as total_nodes,
avg(n.activation_value) as average_activation,
sum(CASE WHEN n.activation_value IS NOT NULL AND n.activation_value < $threshold THEN 1 ELSE 0 END) as low_activation_nodes
"""
stats_results = await connector.execute_query(
stats_query,
group_id=group_id,
threshold=config['forgetting_threshold']
)
if stats_results:
stats = stats_results[0]
total_nodes = stats['total_nodes'] or 0
average_activation = stats['average_activation']
low_activation_nodes = stats['low_activation_nodes'] or 0
else:
total_nodes = 0
average_activation = None
low_activation_nodes = 0
# 保存历史记录到数据库
self.history_repository.create(
db=db,
end_user_id=group_id,
execution_time=execution_time,
merged_count=report['merged_count'],
failed_count=report['failed_count'],
average_activation_value=average_activation,
total_nodes=total_nodes,
low_activation_nodes=low_activation_nodes,
duration_seconds=report['duration_seconds'],
trigger_type='manual'
)
api_logger.info(
f"已保存遗忘周期历史记录: end_user_id={group_id}, "
f"merged_count={report['merged_count']}"
)
except Exception as e:
# 记录历史失败不应影响主流程
api_logger.error(f"保存遗忘周期历史记录失败: {str(e)}")
return report
def read_forgetting_config(
@@ -337,7 +519,8 @@ class MemoryForgetService:
'nodes_without_activation': result['nodes_without_activation'] or 0,
'average_activation_value': result['average_activation'],
'low_activation_nodes': result['low_activation_nodes'] or 0,
'timestamp': datetime.now().isoformat()
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
}
else:
activation_metrics = {
@@ -346,7 +529,8 @@ class MemoryForgetService:
'nodes_without_activation': 0,
'average_activation_value': None,
'low_activation_nodes': 0,
'timestamp': datetime.now().isoformat()
'forgetting_threshold': forgetting_threshold,
'timestamp': int(datetime.now().timestamp())
}
# 收集节点类型分布
@@ -395,19 +579,95 @@ class MemoryForgetService:
'chunk_count': 0
}
# 构建统计信息(不包含监控历史数据
# 获取最近7个日期的历史趋势数据每天取最后一次执行
recent_trends = []
try:
if group_id:
# 查询所有历史记录
history_records = self.history_repository.get_recent_by_end_user(
db=db,
end_user_id=group_id
)
# 按日期分组(一天可能有多次执行,取最后一次)
from collections import OrderedDict
daily_records = OrderedDict()
# 遍历记录(已按时间降序),每个日期只保留第一次遇到的(即最后一次执行)
for record in history_records:
# 提取日期(格式: "1/1", "1/2"- 跨平台兼容
month = record.execution_time.month
day = record.execution_time.day
date_str = f"{month}/{day}"
# 如果这个日期还没有记录,添加它(这是该日期最后一次执行)
if date_str not in daily_records:
daily_records[date_str] = record
# 如果已经有7个不同的日期停止
if len(daily_records) >= 7:
break
# 构建趋势数据点(按时间从旧到新排序)
sorted_dates = sorted(
daily_records.items(),
key=lambda x: x[1].execution_time
)
for date_str, record in sorted_dates:
recent_trends.append({
'date': date_str,
'merged_count': record.merged_count,
'average_activation': record.average_activation_value,
'total_nodes': record.total_nodes,
'execution_time': int(record.execution_time.timestamp())
})
api_logger.info(f"成功获取最近 {len(recent_trends)} 个日期的历史趋势数据")
except Exception as e:
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 获取待遗忘节点列表前20个满足遗忘条件的节点
pending_nodes = []
try:
if group_id:
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
api_logger.warning(
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
)
min_days = 7
pending_nodes = await self._get_pending_forgetting_nodes(
connector=connector,
group_id=group_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
limit=20
)
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
except Exception as e:
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
# 失败时返回空列表,不影响主流程
# 构建统计信息
stats = {
'activation_metrics': activation_metrics,
'node_distribution': node_distribution,
'consistency_check': None, # 不再提供一致性检查
'nodes_merged_total': 0, # 不再跟踪累计融合数
'recent_cycles': [], # 不再提供历史记录
'timestamp': datetime.now().isoformat()
'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp())
}
api_logger.info(
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}"
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
)
return stats

View File

@@ -89,11 +89,15 @@ class DataConfigService: # 数据配置服务类PostgreSQL
value = item[field]
dt = None
# 如果是 datetime 对象,直接使用
if isinstance(value, datetime):
# 处理不同类型的时间值
if hasattr(value, 'to_native'):
# Neo4j DateTime 对象
dt = value.to_native()
elif isinstance(value, datetime):
# Python datetime 对象
dt = value
# 如果是字符串,先解析
elif isinstance(value, str):
# 字符串格式
try:
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
except Exception:

View File

@@ -160,11 +160,20 @@ class MemoryInsightHelper:
month_counts = Counter()
valid_dates_count = 0
for record in records:
creation_time_str = record.get("creation_time")
if not creation_time_str:
creation_time = record.get("creation_time")
if not creation_time:
continue
try:
dt_object = datetime.fromisoformat(creation_time_str.replace("Z", "+00:00"))
# 处理 Neo4j DateTime 对象或字符串
if hasattr(creation_time, 'to_native'):
dt_object = creation_time.to_native()
elif isinstance(creation_time, str):
dt_object = datetime.fromisoformat(creation_time.replace("Z", "+00:00"))
elif isinstance(creation_time, datetime):
dt_object = creation_time
else:
dt_object = datetime.fromisoformat(str(creation_time).replace("Z", "+00:00"))
month_counts[dt_object.month] += 1
valid_dates_count += 1
except (ValueError, TypeError, AttributeError):
@@ -225,8 +234,33 @@ class MemoryInsightHelper:
)
start_year, end_year = "N/A", "N/A"
if time_records and time_records[0]["start_time"]:
start_year = datetime.fromisoformat(time_records[0]["start_time"].replace("Z", "+00:00")).year
end_year = datetime.fromisoformat(time_records[0]["end_time"].replace("Z", "+00:00")).year
start_time = time_records[0]["start_time"]
end_time = time_records[0]["end_time"]
# 处理 Neo4j DateTime 对象或字符串
try:
if hasattr(start_time, 'to_native'):
start_year = start_time.to_native().year
elif isinstance(start_time, str):
start_year = datetime.fromisoformat(start_time.replace("Z", "+00:00")).year
elif isinstance(start_time, datetime):
start_year = start_time.year
else:
start_year = datetime.fromisoformat(str(start_time).replace("Z", "+00:00")).year
except Exception:
start_year = "N/A"
try:
if hasattr(end_time, 'to_native'):
end_year = end_time.to_native().year
elif isinstance(end_time, str):
end_year = datetime.fromisoformat(end_time.replace("Z", "+00:00")).year
elif isinstance(end_time, datetime):
end_year = end_time.year
else:
end_year = datetime.fromisoformat(str(end_time).replace("Z", "+00:00")).year
except Exception:
end_year = "N/A"
return {
"user_id": most_connected_user,