Merge pull request #723 from SuanmoSuanyangTechnology/fix/forgetting-task
[fix] Remove the limit on the number of output items.
This commit is contained in:
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ForgettingCurveRequest,
|
ForgettingCurveRequest,
|
||||||
ForgettingCurveResponse,
|
ForgettingCurveResponse,
|
||||||
ForgettingCurvePoint,
|
ForgettingCurvePoint,
|
||||||
|
PendingNodesResponse,
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||||
|
async def get_pending_nodes(
|
||||||
|
end_user_id: str,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
此接口独立分页,与 /stats 接口分离。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 组ID(即 end_user_id,必填)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
current_user: 当前用户
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||||
|
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- page 从1开始,pagesize 必须大于0
|
||||||
|
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 验证 end_user_id 必填
|
||||||
|
if not end_user_id:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||||
|
|
||||||
|
# 通过 end_user_id 获取关联的 config_id
|
||||||
|
try:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
|
if config_id is None:
|
||||||
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
|
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
|
except ValueError as e:
|
||||||
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||||
|
|
||||||
|
# 验证分页参数
|
||||||
|
if page < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||||
|
if pagesize < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||||
|
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用服务层获取待遗忘节点列表
|
||||||
|
result = await forget_service.get_pending_nodes(
|
||||||
|
db=db,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
config_id=config_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response_data = PendingNodesResponse(**result)
|
||||||
|
|
||||||
|
return success(data=response_data.model_dump(), msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
request: ForgettingCurveRequest,
|
request: ForgettingCurveRequest,
|
||||||
|
|||||||
@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
|
|||||||
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
||||||
|
|
||||||
|
|
||||||
|
class PageInfo(BaseModel):
|
||||||
|
"""分页信息模型"""
|
||||||
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
page: int = Field(..., description="当前页码(从1开始)")
|
||||||
|
pagesize: int = Field(..., description="每页数量")
|
||||||
|
total: int = Field(..., description="总记录数")
|
||||||
|
hasnext: bool = Field(..., description="是否有下一页")
|
||||||
|
|
||||||
|
|
||||||
|
class PendingNodesResponse(BaseModel):
|
||||||
|
"""待遗忘节点列表响应模型(独立分页接口)"""
|
||||||
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
|
||||||
|
page: PageInfo = Field(..., description="分页信息")
|
||||||
|
|
||||||
|
|
||||||
class ForgettingStatsResponse(BaseModel):
|
class ForgettingStatsResponse(BaseModel):
|
||||||
"""遗忘引擎统计信息响应模型"""
|
"""遗忘引擎统计信息响应模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
|
|||||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
|
||||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -204,30 +204,35 @@ class MemoryForgetService:
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
forgetting_threshold: float,
|
forgetting_threshold: float,
|
||||||
min_days_since_access: int,
|
min_days_since_access: int,
|
||||||
limit: int = 20
|
page: Optional[int] = None,
|
||||||
) -> list[Dict[str, Any]]:
|
pagesize: Optional[int] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取待遗忘节点列表
|
获取待遗忘节点列表
|
||||||
|
|
||||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j 连接器
|
connector: Neo4j 连接器
|
||||||
end_user_id: 组ID
|
end_user_id: 组ID
|
||||||
forgetting_threshold: 遗忘阈值
|
forgetting_threshold: 遗忘阈值
|
||||||
min_days_since_access: 最小未访问天数
|
min_days_since_access: 最小未访问天数
|
||||||
limit: 返回节点数量限制
|
page: 页码(可选,从1开始)
|
||||||
|
pagesize: 每页数量(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: 待遗忘节点列表
|
dict: 包含待遗忘节点列表和分页信息的字典
|
||||||
|
- items: 待遗忘节点列表
|
||||||
|
- page: 分页信息(分页时)
|
||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||||
|
|
||||||
query = """
|
# 基础查询(用于获取总数)
|
||||||
|
count_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
AND n.end_user_id = $end_user_id
|
AND n.end_user_id = $end_user_id
|
||||||
@@ -235,10 +240,22 @@ class MemoryForgetService:
|
|||||||
AND n.activation_value < $threshold
|
AND n.activation_value < $threshold
|
||||||
AND n.last_access_time IS NOT NULL
|
AND n.last_access_time IS NOT NULL
|
||||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||||
RETURN
|
RETURN count(n) as total
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 数据查询
|
||||||
|
data_query = """
|
||||||
|
MATCH (n)
|
||||||
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
|
AND n.end_user_id = $end_user_id
|
||||||
|
AND n.activation_value IS NOT NULL
|
||||||
|
AND n.activation_value < $threshold
|
||||||
|
AND n.last_access_time IS NOT NULL
|
||||||
|
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||||
|
RETURN
|
||||||
elementId(n) as node_id,
|
elementId(n) as node_id,
|
||||||
labels(n)[0] as node_type,
|
labels(n)[0] as node_type,
|
||||||
CASE
|
CASE
|
||||||
WHEN n:Statement THEN n.statement
|
WHEN n:Statement THEN n.statement
|
||||||
WHEN n:ExtractedEntity THEN n.name
|
WHEN n:ExtractedEntity THEN n.name
|
||||||
WHEN n:MemorySummary THEN n.content
|
WHEN n:MemorySummary THEN n.content
|
||||||
@@ -247,18 +264,32 @@ class MemoryForgetService:
|
|||||||
n.activation_value as activation_value,
|
n.activation_value as activation_value,
|
||||||
n.last_access_time as last_access_time
|
n.last_access_time as last_access_time
|
||||||
ORDER BY n.activation_value ASC
|
ORDER BY n.activation_value ASC
|
||||||
LIMIT $limit
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
data_query += " SKIP $skip LIMIT $limit"
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'end_user_id': end_user_id,
|
'end_user_id': end_user_id,
|
||||||
'threshold': forgetting_threshold,
|
'threshold': forgetting_threshold,
|
||||||
'min_access_time_str': min_access_time_str,
|
'min_access_time_str': min_access_time_str
|
||||||
'limit': limit
|
|
||||||
}
|
}
|
||||||
|
|
||||||
results = await connector.execute_query(query, **params)
|
# 获取总数(分页时需要)
|
||||||
|
total = 0
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
count_results = await connector.execute_query(count_query, **params)
|
||||||
|
if count_results:
|
||||||
|
total = count_results[0]['total']
|
||||||
|
|
||||||
|
# 添加分页参数
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
params['skip'] = (page - 1) * pagesize
|
||||||
|
params['limit'] = pagesize
|
||||||
|
|
||||||
|
results = await connector.execute_query(data_query, **params)
|
||||||
|
|
||||||
pending_nodes = []
|
pending_nodes = []
|
||||||
for result in results:
|
for result in results:
|
||||||
# 将节点类型标签转换为小写
|
# 将节点类型标签转换为小写
|
||||||
@@ -267,7 +298,7 @@ class MemoryForgetService:
|
|||||||
node_type_label = 'entity'
|
node_type_label = 'entity'
|
||||||
elif node_type_label == 'memorysummary':
|
elif node_type_label == 'memorysummary':
|
||||||
node_type_label = 'summary'
|
node_type_label = 'summary'
|
||||||
|
|
||||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||||
last_access_time = result['last_access_time']
|
last_access_time = result['last_access_time']
|
||||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||||
@@ -278,7 +309,7 @@ class MemoryForgetService:
|
|||||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||||
else:
|
else:
|
||||||
last_access_timestamp = 0
|
last_access_timestamp = 0
|
||||||
|
|
||||||
pending_nodes.append({
|
pending_nodes.append({
|
||||||
'node_id': str(result['node_id']),
|
'node_id': str(result['node_id']),
|
||||||
'node_type': node_type_label,
|
'node_type': node_type_label,
|
||||||
@@ -286,8 +317,20 @@ class MemoryForgetService:
|
|||||||
'activation_value': result['activation_value'],
|
'activation_value': result['activation_value'],
|
||||||
'last_access_time': last_access_timestamp
|
'last_access_time': last_access_timestamp
|
||||||
})
|
})
|
||||||
|
|
||||||
return pending_nodes
|
# 构建返回结果
|
||||||
|
result: Dict[str, Any] = {'items': pending_nodes}
|
||||||
|
|
||||||
|
# 如果启用分页,添加分页信息
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
result['page'] = {
|
||||||
|
'page': page,
|
||||||
|
'pagesize': pagesize,
|
||||||
|
'total': total,
|
||||||
|
'hasnext': (page * pagesize) < total
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def trigger_forgetting_cycle(
|
async def trigger_forgetting_cycle(
|
||||||
self,
|
self,
|
||||||
@@ -636,7 +679,7 @@ class MemoryForgetService:
|
|||||||
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
||||||
# 失败时返回空列表,不影响主流程
|
# 失败时返回空列表,不影响主流程
|
||||||
|
|
||||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
# 获取待遗忘节点列表
|
||||||
pending_nodes = []
|
pending_nodes = []
|
||||||
try:
|
try:
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
@@ -652,8 +695,7 @@ class MemoryForgetService:
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
forgetting_threshold=forgetting_threshold,
|
forgetting_threshold=forgetting_threshold,
|
||||||
min_days_since_access=int(min_days),
|
min_days_since_access=int(min_days)
|
||||||
limit=20
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
||||||
@@ -661,24 +703,79 @@ class MemoryForgetService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||||
# 失败时返回空列表,不影响主流程
|
# 失败时返回空列表,不影响主流程
|
||||||
|
|
||||||
# 构建统计信息
|
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||||
stats = {
|
stats = {
|
||||||
'activation_metrics': activation_metrics,
|
'activation_metrics': activation_metrics,
|
||||||
'node_distribution': node_distribution,
|
'node_distribution': node_distribution,
|
||||||
'recent_trends': recent_trends,
|
'recent_trends': recent_trends,
|
||||||
'pending_nodes': pending_nodes,
|
|
||||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
f"trend_days={len(recent_trends)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
async def get_pending_nodes(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
end_user_id: str,
|
||||||
|
config_id: Optional[UUID] = None,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
end_user_id: 组ID(必填)
|
||||||
|
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含待遗忘节点列表和分页信息的字典
|
||||||
|
- items: 待遗忘节点列表
|
||||||
|
- page: 分页信息
|
||||||
|
"""
|
||||||
|
# 获取遗忘引擎组件
|
||||||
|
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||||
|
|
||||||
|
connector = forgetting_scheduler.connector
|
||||||
|
forgetting_threshold = config['forgetting_threshold']
|
||||||
|
|
||||||
|
# 验证 min_days_since_access 配置值
|
||||||
|
min_days = config.get('min_days_since_access')
|
||||||
|
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||||
|
api_logger.warning(
|
||||||
|
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
|
||||||
|
)
|
||||||
|
min_days = 7
|
||||||
|
|
||||||
|
# 调用内部方法获取分页数据
|
||||||
|
pending_nodes_result = await self._get_pending_forgetting_nodes(
|
||||||
|
connector=connector,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
forgetting_threshold=forgetting_threshold,
|
||||||
|
min_days_since_access=int(min_days),
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
|
||||||
|
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pending_nodes_result
|
||||||
|
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
Reference in New Issue
Block a user