Fix/memory bug fix (#171)
This commit is contained in:
@@ -12,6 +12,7 @@
|
||||
|
||||
from typing import Optional, Dict, Any, Tuple
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -23,7 +24,7 @@ from app.core.memory.storage_services.forgetting_engine.config_utils import (
|
||||
load_actr_config_from_db,
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.forgetting_cycle_history_repository import ForgettingCycleHistoryRepository
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ class MemoryForgetService:
|
||||
|
||||
def __init__(self):
|
||||
"""初始化服务"""
|
||||
self.config_repository = DataConfigRepository()
|
||||
self.config_repository = MemoryConfigRepository()
|
||||
self.history_repository = ForgettingCycleHistoryRepository()
|
||||
|
||||
def _get_neo4j_connector(self) -> Neo4jConnector:
|
||||
@@ -87,7 +88,7 @@ class MemoryForgetService:
|
||||
async def _get_forgetting_components(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Tuple[ACTRCalculator, ForgettingStrategy, ForgettingScheduler, Dict[str, Any]]:
|
||||
"""
|
||||
获取遗忘引擎组件(计算器、策略、调度器)
|
||||
@@ -132,7 +133,7 @@ class MemoryForgetService:
|
||||
async def _get_knowledge_stats(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
group_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
forgetting_threshold: float = 0.3
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -140,7 +141,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 终端用户ID(可选)
|
||||
forgetting_threshold: 遗忘阈值
|
||||
|
||||
Returns:
|
||||
@@ -152,8 +153,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
query += """
|
||||
WITH n,
|
||||
@@ -172,8 +173,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
@@ -200,7 +201,7 @@ class MemoryForgetService:
|
||||
async def _get_pending_forgetting_nodes(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int,
|
||||
limit: int = 20
|
||||
@@ -212,7 +213,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
group_id: 组ID
|
||||
end_user_id: 终端用户ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
limit: 返回节点数量限制
|
||||
@@ -229,7 +230,7 @@ class MemoryForgetService:
|
||||
query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.group_id = $group_id
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
@@ -250,7 +251,7 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {
|
||||
'group_id': group_id,
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str,
|
||||
'limit': limit
|
||||
@@ -291,10 +292,10 @@ class MemoryForgetService:
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
max_merge_batch_size: Optional[int] = None,
|
||||
min_days_since_access: Optional[int] = None,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
手动触发遗忘周期
|
||||
@@ -303,10 +304,10 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(即终端用户ID,必填)
|
||||
end_user_id: 终端用户ID(必填)
|
||||
max_merge_batch_size: 最大融合批次大小(可选)
|
||||
min_days_since_access: 最小未访问天数(可选)
|
||||
config_id: 配置ID(必填,由控制器层通过 group_id 获取)
|
||||
config_id: 配置ID(必填,由控制器层通过 end_user_id 获取)
|
||||
|
||||
Returns:
|
||||
dict: 遗忘报告
|
||||
@@ -319,7 +320,7 @@ class MemoryForgetService:
|
||||
|
||||
# 运行遗忘周期(LLM 客户端将在需要时由 forgetting_strategy 内部获取)
|
||||
report = await forgetting_scheduler.run_forgetting_cycle(
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
max_merge_batch_size=max_merge_batch_size,
|
||||
min_days_since_access=min_days_since_access,
|
||||
config_id=config_id,
|
||||
@@ -338,7 +339,7 @@ class MemoryForgetService:
|
||||
stats_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
AND n.group_id = $group_id
|
||||
AND n.end_user_id = $end_user_id
|
||||
RETURN
|
||||
count(n) as total_nodes,
|
||||
avg(n.activation_value) as average_activation,
|
||||
@@ -347,7 +348,7 @@ class MemoryForgetService:
|
||||
|
||||
stats_results = await connector.execute_query(
|
||||
stats_query,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
threshold=config['forgetting_threshold']
|
||||
)
|
||||
|
||||
@@ -364,7 +365,7 @@ class MemoryForgetService:
|
||||
# 保存历史记录到数据库
|
||||
self.history_repository.create(
|
||||
db=db,
|
||||
end_user_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
execution_time=execution_time,
|
||||
merged_count=report['merged_count'],
|
||||
failed_count=report['failed_count'],
|
||||
@@ -376,7 +377,7 @@ class MemoryForgetService:
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"已保存遗忘周期历史记录: end_user_id={group_id}, "
|
||||
f"已保存遗忘周期历史记录: end_user_id={end_user_id}, "
|
||||
f"merged_count={report['merged_count']}"
|
||||
)
|
||||
|
||||
@@ -389,7 +390,7 @@ class MemoryForgetService:
|
||||
def read_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int
|
||||
config_id: UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎配置
|
||||
@@ -416,7 +417,7 @@ class MemoryForgetService:
|
||||
def update_forgetting_config(
|
||||
self,
|
||||
db: Session,
|
||||
config_id: int,
|
||||
config_id: UUID,
|
||||
update_fields: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -465,8 +466,8 @@ class MemoryForgetService:
|
||||
async def get_forgetting_stats(
|
||||
self,
|
||||
db: Session,
|
||||
group_id: Optional[str] = None,
|
||||
config_id: Optional[int] = None
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘引擎统计信息
|
||||
@@ -475,7 +476,7 @@ class MemoryForgetService:
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
group_id: 组ID(可选)
|
||||
end_user_id: 终端用户ID(可选)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
|
||||
Returns:
|
||||
@@ -493,8 +494,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
activation_query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
activation_query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
activation_query += """
|
||||
RETURN
|
||||
@@ -506,8 +507,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
params = {'threshold': forgetting_threshold}
|
||||
if group_id:
|
||||
params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
params['end_user_id'] = end_user_id
|
||||
|
||||
activation_results = await connector.execute_query(activation_query, **params)
|
||||
|
||||
@@ -539,8 +540,8 @@ class MemoryForgetService:
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary OR n:Chunk)
|
||||
"""
|
||||
|
||||
if group_id:
|
||||
distribution_query += " AND n.group_id = $group_id"
|
||||
if end_user_id:
|
||||
distribution_query += " AND n.end_user_id = $end_user_id"
|
||||
|
||||
distribution_query += """
|
||||
WITH n,
|
||||
@@ -558,8 +559,8 @@ class MemoryForgetService:
|
||||
"""
|
||||
|
||||
dist_params = {}
|
||||
if group_id:
|
||||
dist_params['group_id'] = group_id
|
||||
if end_user_id:
|
||||
dist_params['end_user_id'] = end_user_id
|
||||
|
||||
distribution_results = await connector.execute_query(distribution_query, **dist_params)
|
||||
|
||||
@@ -582,11 +583,11 @@ class MemoryForgetService:
|
||||
# 获取最近7个日期的历史趋势数据(每天取最后一次执行)
|
||||
recent_trends = []
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 查询所有历史记录
|
||||
history_records = self.history_repository.get_recent_by_end_user(
|
||||
db=db,
|
||||
end_user_id=group_id
|
||||
end_user_id=end_user_id
|
||||
)
|
||||
|
||||
# 按日期分组(一天可能有多次执行,取最后一次)
|
||||
@@ -632,7 +633,7 @@ class MemoryForgetService:
|
||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||
pending_nodes = []
|
||||
try:
|
||||
if group_id:
|
||||
if end_user_id:
|
||||
# 验证 min_days_since_access 配置值
|
||||
min_days = config.get('min_days_since_access')
|
||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||
@@ -643,7 +644,7 @@ class MemoryForgetService:
|
||||
|
||||
pending_nodes = await self._get_pending_forgetting_nodes(
|
||||
connector=connector,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
limit=20
|
||||
@@ -677,7 +678,7 @@ class MemoryForgetService:
|
||||
db: Session,
|
||||
importance_score: float,
|
||||
days: int,
|
||||
config_id: Optional[int] = None
|
||||
config_id: Optional[UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取遗忘曲线数据
|
||||
|
||||
Reference in New Issue
Block a user