feat(memory): add end user memory count filtering

- Sync memory_count after Neo4j write and forgetting cycle
- Filter Neo4j end user list by memory_count > 0
- Filter RAG end user list by Memory knowledge chunk count
This commit is contained in:
miao
2026-04-29 14:21:14 +08:00
parent d30b9224ab
commit a7d3930f4d
6 changed files with 270 additions and 73 deletions

View File

@@ -1,4 +1,4 @@
import asyncio
import uuid import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -10,7 +10,7 @@ from app.dependencies import get_current_user
from app.models.user_model import User from app.models.user_model import User
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service from app.services import memory_dashboard_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger from app.core.logging_config import get_api_logger
@@ -48,7 +48,7 @@ def get_workspace_total_end_users(
@router.get("/end_users", response_model=ApiResponse) @router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users( def get_workspace_end_users(
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID可选默认当前用户工作空间"), workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID可选默认当前用户工作空间"),
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id"), keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id"),
page: int = Query(1, ge=1, description="页码从1开始"), page: int = Query(1, ge=1, description="页码从1开始"),
@@ -58,6 +58,15 @@ async def get_workspace_end_users(
): ):
""" """
获取工作空间的宿主列表(分页查询,支持模糊搜索) 获取工作空间的宿主列表(分页查询,支持模糊搜索)
新增:记忆数量过滤:
Neo4j 模式:
- 使用 end_users.memory_count 过滤 memory_count > 0 的宿主
- memory_num.total 直接取 end_user.memory_count
RAG 模式:
- 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主
- memory_num.total 取聚合后的 chunk 总数
返回工作空间下的宿主列表,支持分页查询和模糊搜索。 返回工作空间下的宿主列表,支持分页查询和模糊搜索。
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
@@ -80,17 +89,29 @@ async def get_workspace_end_users(
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}") api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
# 获取分页的 end_users if current_workspace_type == "rag":
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
current_user=current_user, current_user=current_user,
page=page, page=page,
pagesize=pagesize, pagesize=pagesize,
keyword=keyword keyword=keyword,
) )
raw_items = end_users_result.get("items", [])
end_users = [item["end_user"] for item in raw_items]
else:
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
db=db,
workspace_id=workspace_id,
current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword,
)
raw_items = end_users_result.get("items", [])
end_users = raw_items
end_users = end_users_result.get("items", [])
total = end_users_result.get("total", 0) total = end_users_result.get("total", 0)
if not end_users: if not end_users:
@@ -101,50 +122,19 @@ async def get_workspace_end_users(
"page": page, "page": page,
"pagesize": pagesize, "pagesize": pagesize,
"total": total, "total": total,
"hasnext": (page * pagesize) < total "hasnext": (page * pagesize) < total,
} },
}, msg="宿主列表获取成功") }, msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users] end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务 try:
async def get_memory_configs(): memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
"""获取记忆配置(在线程池中执行同步查询)""" except Exception as e:
try: api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return await asyncio.to_thread( memory_configs_map = {}
get_end_users_connected_configs_batch,
end_user_ids, db
)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return {}
async def get_memory_nums(): # 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式批量查询简化版本只返回total
try:
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
return {uid: {"total": count} for uid, count in batch_result.items()}
except Exception as e:
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try: try:
from app.celery_app import celery_app as _celery_app from app.celery_app import celery_app as _celery_app
_celery_app.send_task( _celery_app.send_task(
@@ -159,27 +149,26 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}") api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果列表
items = [] items = []
for end_user in end_users: for index, end_user in enumerate(end_users):
user_id = str(end_user.id) user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {}) config_info = memory_configs_map.get(user_id, {})
if current_workspace_type == "rag":
memory_total = int(raw_items[index].get("memory_count", 0) or 0)
else:
memory_total = int(getattr(end_user, "memory_count", 0) or 0)
items.append({ items.append({
'end_user': { "end_user": {
'id': user_id, "id": user_id,
'other_name': end_user.other_name "other_name": end_user.other_name,
}, },
'memory_num': memory_nums_map.get(user_id, {"total": 0}), "memory_num": {"total": memory_total},
'memory_config': { "memory_config": {
"memory_config_id": config_info.get("memory_config_id"), "memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name") "memory_config_name": config_info.get("memory_config_name"),
} },
}) })
# 触发社区聚类补全任务(异步,不阻塞接口响应) # 触发社区聚类补全任务(异步,不阻塞接口响应)
@@ -407,6 +396,7 @@ def get_current_user_rag_total_num(
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user) total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
return success(data=total_chunk, msg="宿主RAG知识数据获取成功") return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
@router.get("/rag_content", response_model=ApiResponse) @router.get("/rag_content", response_model=ApiResponse)
def get_rag_content( def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"), end_user_id: str = Query(..., description="宿主ID"),

View File

@@ -313,6 +313,9 @@ async def write(
except Exception as cache_err: except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
#同步neo4j记忆节点总数到pgsqlend_user表的memory_count字段
await _sync_memory_count_after_write(end_user_id)
# Close LLM/Embedder underlying httpx clients to prevent # Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection # 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client): for client_obj in (llm_client, embedder_client):
@@ -331,3 +334,49 @@ async def write(
logger.info("=== Pipeline Complete ===") logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds") logger.info(f"Total execution time: {total_time:.2f} seconds")
async def _sync_memory_count_after_write(end_user_id: str) -> None:
"""
记忆写入完成后,查 Neo4j 全量节点数,绝对值同步到 PostgreSQL end_user 表的 memory_count 字段
不使用增量累加:
- Neo4j 写入使用 MERGE 语义,节点列表长度不等于新增节点数。
- 重试或重复写入可能匹配已有节点。
- 绝对值覆盖可以避免越加越大的计数漂移。
"""
if not end_user_id:
return
try:
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
connector = Neo4jConnector()
try:
result = await connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
finally:
await connector.close()
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == uuid.UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
logger.info(
f"[MemoryCount] 写入后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)

View File

@@ -145,7 +145,8 @@ class ForgettingScheduler:
} }
logger.info("没有可遗忘的节点对,遗忘周期结束") logger.info("没有可遗忘的节点对,遗忘周期结束")
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
await self._sync_memory_count_to_mysql(end_user_id)
return report return report
# 步骤3按激活值排序激活值最低的优先 # 步骤3按激活值排序激活值最低的优先
@@ -302,7 +303,8 @@ class ForgettingScheduler:
f"({reduction_rate:.2%}), " f"({reduction_rate:.2%}), "
f"耗时 {duration:.2f}" f"耗时 {duration:.2f}"
) )
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
await self._sync_memory_count_to_mysql(end_user_id)
return report return report
except Exception as e: except Exception as e:
@@ -350,3 +352,48 @@ class ForgettingScheduler:
if results: if results:
return results[0]['total'] return results[0]['total']
return 0 return 0
async def _sync_memory_count_to_mysql(
self,
end_user_id: Optional[str] = None,
) -> None:
"""
遗忘周期结束后,用 SEARCH_FOR_ALL_BATCH 口径查全量节点数,
同步到 PostgreSQL end_users.memory_count。
不复用 _count_knowledge_nodes
- _count_knowledge_nodes 只统计 Statement、ExtractedEntity、MemorySummary。
- 宿主列表需要统计该 end_user_id 下全部 Neo4j 节点。
"""
if not end_user_id:
return
try:
from app.db import get_db_context
from app.models.end_user_model import EndUser
from app.repositories.memory_config_repository import MemoryConfigRepository
result = await self.connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=[end_user_id],
)
node_count = int(result[0]["total"]) if result else 0
with get_db_context() as db:
db.query(EndUser).filter(
EndUser.id == UUID(end_user_id)
).update(
{"memory_count": node_count},
synchronize_session=False,
)
db.commit()
logger.info(
f"[MemoryCount] 遗忘后同步 memory_count: "
f"end_user_id={end_user_id}, count={node_count}"
)
except Exception as e:
logger.warning(
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
exc_info=True,
)

View File

@@ -1,7 +1,7 @@
import datetime import datetime
import uuid import uuid
from sqlalchemy import Column, DateTime, ForeignKey, String, Text from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
@@ -38,6 +38,15 @@ class EndUser(Base):
comment="关联的记忆配置ID" comment="关联的记忆配置ID"
) )
memory_count = Column(
Integer,
nullable=False,
default=0,
server_default="0",
index=True,
comment="记忆节点总数",
)
# 用户摘要四个维度 - User Summary Four Dimensions # 用户摘要四个维度 - User Summary Four Dimensions
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)") user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
personality_traits = Column(Text, nullable=True, comment="性格特点") personality_traits = Column(Text, nullable=True, comment="性格特点")

View File

@@ -19,4 +19,6 @@ class EndUser(BaseModel):
# 用户摘要和洞察更新时间 # 用户摘要和洞察更新时间
user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None) user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None)
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
#用户记忆节点总数Neo4j模式
memory_count: int = Field(description="记忆节点总数", default=0)

View File

@@ -1,5 +1,5 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc, nullslast, or_, and_, cast, String from sqlalchemy import desc, nullslast, or_, and_, cast, String, func
from typing import List, Optional, Dict, Any from typing import List, Optional, Dict, Any
import uuid import uuid
from fastapi import HTTPException from fastapi import HTTPException
@@ -102,6 +102,7 @@ def get_workspace_end_users_paginated(
"""获取工作空间的宿主列表(分页版本,支持模糊搜索) """获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后) 返回结果按 created_at 从新到旧排序NULL 值排在最后)
固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段 支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args: Args:
@@ -120,7 +121,8 @@ def get_workspace_end_users_paginated(
try: try:
# 构建基础查询 # 构建基础查询
base_query = db.query(EndUserModel).filter( base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id EndUserModel.workspace_id == workspace_id,
EndUserModel.memory_count > 0 , # 只查询有记忆的宿主
) )
# 构建搜索条件过滤空字符串和None # 构建搜索条件过滤空字符串和None
@@ -169,6 +171,104 @@ def get_workspace_end_users_paginated(
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}") business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise raise
def get_workspace_end_users_paginated_rag(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
page: int,
pagesize: int,
keyword: Optional[str] = None,
) -> Dict[str, Any]:
"""RAG 模式宿主列表分页。
RAG 记忆数量以 documents.chunk_num 为准:
- file_name = end_user_id + ".txt"
- 只统计当前 workspace 下 permission_id="Memory" 的用户记忆知识库
- 在 SQL 层过滤 chunk 总数为 0 的宿主,保证分页准确
"""
business_logger.info(
f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
)
try:
from app.models.document_model import Document
from app.models.knowledge_model import Knowledge
chunk_subquery = (
db.query(
Document.file_name.label("file_name"),
func.coalesce(func.sum(Document.chunk_num), 0).label("memory_count"),
)
.join(Knowledge, Document.kb_id == Knowledge.id)
.filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id == "Memory",
Document.status == 1,
)
.group_by(Document.file_name)
.subquery()
)
base_query = (
db.query(
EndUserModel,
chunk_subquery.c.memory_count.label("memory_count"),
)
.join(
chunk_subquery,
chunk_subquery.c.file_name == func.concat(cast(EndUserModel.id, String), ".txt"),
)
.filter(
EndUserModel.workspace_id == workspace_id,
chunk_subquery.c.memory_count > 0,
)
)
keyword = keyword.strip() if keyword else None
if keyword:
keyword_pattern = f"%{keyword}%"
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
and_(
or_(
EndUserModel.other_name.is_(None),
EndUserModel.other_name == "",
),
cast(EndUserModel.id, String).ilike(keyword_pattern),
),
)
)
total = base_query.count()
if total == 0:
business_logger.info("RAG 模式下没有符合条件的宿主")
return {"items": [], "total": 0}
rows = base_query.order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id),
).offset((page - 1) * pagesize).limit(pagesize).all()
items = []
for end_user_orm, memory_count in rows:
items.append({
"end_user": EndUserSchema.model_validate(end_user_orm),
"memory_count": int(memory_count or 0),
})
business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total}")
return {"items": items, "total": total}
except HTTPException:
raise
except Exception as e:
business_logger.error(
f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}"
)
raise
def get_workspace_memory_increment( def get_workspace_memory_increment(
db: Session, db: Session,