Merge pull request #723 from SuanmoSuanyangTechnology/fix/forgetting-task

[fix] Remove the limit on the number of output items.
This commit is contained in:
Ke Sun
2026-03-30 15:37:09 +08:00
committed by GitHub
3 changed files with 242 additions and 35 deletions

View File

@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
ForgettingCurveRequest, ForgettingCurveRequest,
ForgettingCurveResponse, ForgettingCurveResponse,
ForgettingCurvePoint, ForgettingCurvePoint,
PendingNodesResponse,
) )
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.get("/pending-nodes", response_model=ApiResponse)
async def get_pending_nodes(
end_user_id: str,
page: int = 1,
pagesize: int = 10,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
此接口独立分页,与 /stats 接口分离。
Args:
end_user_id: 组ID即 end_user_id必填
page: 页码从1开始默认1
pagesize: 每页数量默认10
current_user: 当前用户
db: 数据库会话
Returns:
ApiResponse: 包含待遗忘节点列表和分页信息的响应
Examples:
- 第1页每页10条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
- 第2页每页20条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
Notes:
- page 从1开始pagesize 必须大于0
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 验证 end_user_id 必填
if not end_user_id:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
# 通过 end_user_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
# 验证分页参数
if page < 1:
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
if pagesize < 1:
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
)
try:
# 调用服务层获取待遗忘节点列表
result = await forget_service.get_pending_nodes(
db=db,
end_user_id=end_user_id,
config_id=config_id,
page=page,
pagesize=pagesize
)
# 构建响应
response_data = PendingNodesResponse(**result)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse) @router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve( async def get_forgetting_curve(
request: ForgettingCurveRequest, request: ForgettingCurveRequest,

View File

@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
last_access_time: int = Field(..., description="最后访问时间Unix时间戳") last_access_time: int = Field(..., description="最后访问时间Unix时间戳")
class PageInfo(BaseModel):
"""分页信息模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
page: int = Field(..., description="当前页码从1开始")
pagesize: int = Field(..., description="每页数量")
total: int = Field(..., description="总记录数")
hasnext: bool = Field(..., description="是否有下一页")
class PendingNodesResponse(BaseModel):
"""待遗忘节点列表响应模型(独立分页接口)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
page: PageInfo = Field(..., description="分页信息")
class ForgettingStatsResponse(BaseModel): class ForgettingStatsResponse(BaseModel):
"""遗忘引擎统计信息响应模型""" """遗忘引擎统计信息响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行") description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -204,30 +204,35 @@ class MemoryForgetService:
end_user_id: str, end_user_id: str,
forgetting_threshold: float, forgetting_threshold: float,
min_days_since_access: int, min_days_since_access: int,
limit: int = 20 page: Optional[int] = None,
) -> list[Dict[str, Any]]: pagesize: Optional[int] = None
) -> Dict[str, Any]:
""" """
获取待遗忘节点列表 获取待遗忘节点列表
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数) 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
Args: Args:
connector: Neo4j 连接器 connector: Neo4j 连接器
end_user_id: 组ID end_user_id: 组ID
forgetting_threshold: 遗忘阈值 forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数 min_days_since_access: 最小未访问天数
limit: 返回节点数量限制 page: 页码可选从1开始
pagesize: 每页数量(可选)
Returns: Returns:
list: 待遗忘节点列表 dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息(分页时)
""" """
from datetime import timedelta from datetime import timedelta
# 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区) # 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区)
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access) min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
query = """ # 基础查询(用于获取总数)
count_query = """
MATCH (n) MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id AND n.end_user_id = $end_user_id
@@ -235,10 +240,22 @@ class MemoryForgetService:
AND n.activation_value < $threshold AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str) AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN RETURN count(n) as total
"""
# 数据查询
data_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
elementId(n) as node_id, elementId(n) as node_id,
labels(n)[0] as node_type, labels(n)[0] as node_type,
CASE CASE
WHEN n:Statement THEN n.statement WHEN n:Statement THEN n.statement
WHEN n:ExtractedEntity THEN n.name WHEN n:ExtractedEntity THEN n.name
WHEN n:MemorySummary THEN n.content WHEN n:MemorySummary THEN n.content
@@ -247,18 +264,32 @@ class MemoryForgetService:
n.activation_value as activation_value, n.activation_value as activation_value,
n.last_access_time as last_access_time n.last_access_time as last_access_time
ORDER BY n.activation_value ASC ORDER BY n.activation_value ASC
LIMIT $limit
""" """
# 如果启用分页,添加 SKIP 和 LIMIT
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
data_query += " SKIP $skip LIMIT $limit"
params = { params = {
'end_user_id': end_user_id, 'end_user_id': end_user_id,
'threshold': forgetting_threshold, 'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str, 'min_access_time_str': min_access_time_str
'limit': limit
} }
results = await connector.execute_query(query, **params) # 获取总数(分页时需要)
total = 0
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
count_results = await connector.execute_query(count_query, **params)
if count_results:
total = count_results[0]['total']
# 添加分页参数
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
params['skip'] = (page - 1) * pagesize
params['limit'] = pagesize
results = await connector.execute_query(data_query, **params)
pending_nodes = [] pending_nodes = []
for result in results: for result in results:
# 将节点类型标签转换为小写 # 将节点类型标签转换为小写
@@ -267,7 +298,7 @@ class MemoryForgetService:
node_type_label = 'entity' node_type_label = 'entity'
elif node_type_label == 'memorysummary': elif node_type_label == 'memorysummary':
node_type_label = 'summary' node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳(毫秒) # 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time'] last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time) last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
@@ -278,7 +309,7 @@ class MemoryForgetService:
last_access_timestamp = int(last_access_dt.timestamp() * 1000) last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else: else:
last_access_timestamp = 0 last_access_timestamp = 0
pending_nodes.append({ pending_nodes.append({
'node_id': str(result['node_id']), 'node_id': str(result['node_id']),
'node_type': node_type_label, 'node_type': node_type_label,
@@ -286,8 +317,20 @@ class MemoryForgetService:
'activation_value': result['activation_value'], 'activation_value': result['activation_value'],
'last_access_time': last_access_timestamp 'last_access_time': last_access_timestamp
}) })
return pending_nodes # 构建返回结果
result: Dict[str, Any] = {'items': pending_nodes}
# 如果启用分页,添加分页信息
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
result['page'] = {
'page': page,
'pagesize': pagesize,
'total': total,
'hasnext': (page * pagesize) < total
}
return result
async def trigger_forgetting_cycle( async def trigger_forgetting_cycle(
self, self,
@@ -636,7 +679,7 @@ class MemoryForgetService:
api_logger.error(f"获取历史趋势数据失败: {str(e)}") api_logger.error(f"获取历史趋势数据失败: {str(e)}")
# 失败时返回空列表,不影响主流程 # 失败时返回空列表,不影响主流程
# 获取待遗忘节点列表前20个满足遗忘条件的节点 # 获取待遗忘节点列表
pending_nodes = [] pending_nodes = []
try: try:
if end_user_id: if end_user_id:
@@ -652,8 +695,7 @@ class MemoryForgetService:
connector=connector, connector=connector,
end_user_id=end_user_id, end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold, forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days), min_days_since_access=int(min_days)
limit=20
) )
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
@@ -661,24 +703,79 @@ class MemoryForgetService:
except Exception as e: except Exception as e:
api_logger.error(f"获取待遗忘节点失败: {str(e)}") api_logger.error(f"获取待遗忘节点失败: {str(e)}")
# 失败时返回空列表,不影响主流程 # 失败时返回空列表,不影响主流程
# 构建统计信息 # 构建统计信息(不包含 pending_nodes已分离到独立接口
stats = { stats = {
'activation_metrics': activation_metrics, 'activation_metrics': activation_metrics,
'node_distribution': node_distribution, 'node_distribution': node_distribution,
'recent_trends': recent_trends, 'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp() * 1000) 'timestamp': int(datetime.now().timestamp() * 1000)
} }
api_logger.info( api_logger.info(
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" f"trend_days={len(recent_trends)}"
) )
return stats return stats
async def get_pending_nodes(
self,
db: Session,
end_user_id: str,
config_id: Optional[UUID] = None,
page: int = 1,
pagesize: int = 10
) -> Dict[str, Any]:
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
Args:
db: 数据库会话
end_user_id: 组ID必填
config_id: 配置ID可选用于获取遗忘阈值
page: 页码从1开始默认1
pagesize: 每页数量默认10
Returns:
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息
"""
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
connector = forgetting_scheduler.connector
forgetting_threshold = config['forgetting_threshold']
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
api_logger.warning(
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
)
min_days = 7
# 调用内部方法获取分页数据
pending_nodes_result = await self._get_pending_forgetting_nodes(
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
page=page,
pagesize=pagesize
)
api_logger.info(
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
)
return pending_nodes_result
async def get_forgetting_curve( async def get_forgetting_curve(
self, self,
db: Session, db: Session,