[fix] Set the page for the nodes to be forgotten
This commit is contained in:
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
|
||||
@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
|
||||
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
||||
|
||||
|
||||
class PageInfo(BaseModel):
|
||||
"""分页信息模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
page: int = Field(..., description="当前页码(从1开始)")
|
||||
pagesize: int = Field(..., description="每页数量")
|
||||
total: int = Field(..., description="总记录数")
|
||||
hasnext: bool = Field(..., description="是否有下一页")
|
||||
|
||||
|
||||
class PendingNodesResponse(BaseModel):
|
||||
"""待遗忘节点列表响应模型(独立分页接口)"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
|
||||
page: PageInfo = Field(..., description="分页信息")
|
||||
|
||||
|
||||
class ForgettingStatsResponse(BaseModel):
|
||||
"""遗忘引擎统计信息响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
|
||||
|
||||
@@ -203,29 +203,36 @@ class MemoryForgetService:
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int
|
||||
) -> list[Dict[str, Any]]:
|
||||
min_days_since_access: int,
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
||||
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
end_user_id: 组ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
|
||||
page: 页码(可选,从1开始)
|
||||
pagesize: 每页数量(可选)
|
||||
|
||||
Returns:
|
||||
list: 待遗忘节点列表
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息(分页时)
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||
|
||||
query = """
|
||||
|
||||
# 基础查询(用于获取总数)
|
||||
count_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
@@ -233,10 +240,22 @@ class MemoryForgetService:
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
# 数据查询
|
||||
data_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
elementId(n) as node_id,
|
||||
labels(n)[0] as node_type,
|
||||
CASE
|
||||
CASE
|
||||
WHEN n:Statement THEN n.statement
|
||||
WHEN n:ExtractedEntity THEN n.name
|
||||
WHEN n:MemorySummary THEN n.content
|
||||
@@ -246,15 +265,31 @@ class MemoryForgetService:
|
||||
n.last_access_time as last_access_time
|
||||
ORDER BY n.activation_value ASC
|
||||
"""
|
||||
|
||||
|
||||
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
data_query += " SKIP $skip LIMIT $limit"
|
||||
|
||||
params = {
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str
|
||||
}
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
|
||||
# 获取总数(分页时需要)
|
||||
total = 0
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
count_results = await connector.execute_query(count_query, **params)
|
||||
if count_results:
|
||||
total = count_results[0]['total']
|
||||
|
||||
# 添加分页参数
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
params['skip'] = (page - 1) * pagesize
|
||||
params['limit'] = pagesize
|
||||
|
||||
results = await connector.execute_query(data_query, **params)
|
||||
|
||||
pending_nodes = []
|
||||
for result in results:
|
||||
# 将节点类型标签转换为小写
|
||||
@@ -263,7 +298,7 @@ class MemoryForgetService:
|
||||
node_type_label = 'entity'
|
||||
elif node_type_label == 'memorysummary':
|
||||
node_type_label = 'summary'
|
||||
|
||||
|
||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||
last_access_time = result['last_access_time']
|
||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||
@@ -274,7 +309,7 @@ class MemoryForgetService:
|
||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||
else:
|
||||
last_access_timestamp = 0
|
||||
|
||||
|
||||
pending_nodes.append({
|
||||
'node_id': str(result['node_id']),
|
||||
'node_type': node_type_label,
|
||||
@@ -282,8 +317,20 @@ class MemoryForgetService:
|
||||
'activation_value': result['activation_value'],
|
||||
'last_access_time': last_access_timestamp
|
||||
})
|
||||
|
||||
return pending_nodes
|
||||
|
||||
# 构建返回结果
|
||||
result: Dict[str, Any] = {'items': pending_nodes}
|
||||
|
||||
# 如果启用分页,添加分页信息
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
result['page'] = {
|
||||
'page': page,
|
||||
'pagesize': pagesize,
|
||||
'total': total,
|
||||
'hasnext': (page * pagesize) < total
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
@@ -656,24 +703,79 @@ class MemoryForgetService:
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 构建统计信息
|
||||
|
||||
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||
stats = {
|
||||
'activation_metrics': activation_metrics,
|
||||
'node_distribution': node_distribution,
|
||||
'recent_trends': recent_trends,
|
||||
'pending_nodes': pending_nodes,
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
||||
f"trend_days={len(recent_trends)}"
|
||||
)
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def get_pending_nodes(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
config_id: Optional[UUID] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 组ID(必填)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
|
||||
Returns:
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息
|
||||
"""
|
||||
# 获取遗忘引擎组件
|
||||
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
connector = forgetting_scheduler.connector
|
||||
forgetting_threshold = config['forgetting_threshold']
|
||||
|
||||
# 验证 min_days_since_access 配置值
|
||||
min_days = config.get('min_days_since_access')
|
||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||
api_logger.warning(
|
||||
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
|
||||
)
|
||||
min_days = 7
|
||||
|
||||
# 调用内部方法获取分页数据
|
||||
pending_nodes_result = await self._get_pending_forgetting_nodes(
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
|
||||
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
|
||||
)
|
||||
|
||||
return pending_nodes_result
|
||||
|
||||
async def get_forgetting_curve(
|
||||
self,
|
||||
db: Session,
|
||||
|
||||
Reference in New Issue
Block a user