feat(memory): add end user memory count filtering

- Sync memory_count after Neo4j write and forgetting cycle
- Filter Neo4j end user list by memory_count > 0
- Filter RAG end user list by Memory knowledge chunk count
This commit is contained in:
miao
2026-04-29 14:21:14 +08:00
parent d30b9224ab
commit a7d3930f4d
6 changed files with 270 additions and 73 deletions

View File

@@ -1,5 +1,5 @@
from sqlalchemy.orm import Session
from sqlalchemy import desc, nullslast, or_, and_, cast, String
from sqlalchemy import desc, nullslast, or_, and_, cast, String, func
from typing import List, Optional, Dict, Any
import uuid
from fastapi import HTTPException
@@ -102,6 +102,7 @@ def get_workspace_end_users_paginated(
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args:
@@ -120,7 +121,8 @@ def get_workspace_end_users_paginated(
try:
# 构建基础查询
base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
EndUserModel.workspace_id == workspace_id,
EndUserModel.memory_count > 0 , # 只查询有记忆的宿主
)
# 构建搜索条件过滤空字符串和None
@@ -169,6 +171,104 @@ def get_workspace_end_users_paginated(
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_end_users_paginated_rag(
    db: Session,
    workspace_id: uuid.UUID,
    current_user: User,
    page: int,
    pagesize: int,
    keyword: Optional[str] = None,
) -> Dict[str, Any]:
    """Return one page of a workspace's end users that have RAG memory.

    The RAG memory count of an end user is the sum of ``Document.chunk_num``
    over documents whose ``file_name`` equals ``"<end_user_id>.txt"``. Only
    documents with ``status == 1`` that belong to a knowledge base of this
    workspace with ``status == 1`` and ``permission_id == "Memory"`` are
    counted. End users whose total is 0 (or who have no matching document)
    are excluded in SQL, so ``total`` and the page slice are both computed
    over the same filtered set.

    Results are ordered by ``created_at`` descending with NULLs last, then by
    ``id`` descending.

    Args:
        db: SQLAlchemy session used to run the queries.
        workspace_id: Workspace whose end users are listed.
        current_user: Operator issuing the request; only ``username`` is read,
            for logging.
        page: 1-based page number; the offset is ``(page - 1) * pagesize``.
        pagesize: Maximum number of rows returned.
        keyword: Optional search term. It is stripped, then matched
            case-insensitively as a substring of ``other_name``; the end user
            id is matched only when ``other_name`` is NULL or empty.
            ``None``, empty and whitespace-only values disable the search.

    Returns:
        ``{"items": [...], "total": int}`` where each item is
        ``{"end_user": EndUserSchema, "memory_count": int}``.

    Raises:
        HTTPException: Propagated unchanged, without an error log entry.
        Exception: Any other failure is logged and re-raised.
    """
    business_logger.info(
        f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
        f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
    )

    try:
        # Resolved at call time rather than at module import time.
        from app.models.document_model import Document
        from app.models.knowledge_model import Knowledge

        # Per-file chunk totals across the workspace's "Memory" knowledge bases.
        chunk_total = func.coalesce(func.sum(Document.chunk_num), 0)
        memory_chunks = (
            db.query(
                Document.file_name.label("file_name"),
                chunk_total.label("memory_count"),
            )
            .join(Knowledge, Document.kb_id == Knowledge.id)
            .filter(
                Knowledge.workspace_id == workspace_id,
                Knowledge.status == 1,
                Knowledge.permission_id == "Memory",
                Document.status == 1,
            )
            .group_by(Document.file_name)
            .subquery()
        )

        # The end user id rendered as text: used both to derive the memory
        # file name and as the fallback search target.
        end_user_id_text = cast(EndUserModel.id, String)
        memory_file_name = func.concat(end_user_id_text, ".txt")

        # Inner join: end users without a matching memory file drop out here.
        query = (
            db.query(
                EndUserModel,
                memory_chunks.c.memory_count.label("memory_count"),
            )
            .join(memory_chunks, memory_chunks.c.file_name == memory_file_name)
            .filter(
                EndUserModel.workspace_id == workspace_id,
                memory_chunks.c.memory_count > 0,
            )
        )

        search_term = keyword.strip() if keyword else None
        if search_term:
            like_pattern = f"%{search_term}%"
            name_is_blank = or_(
                EndUserModel.other_name.is_(None),
                EndUserModel.other_name == "",
            )
            # Search by name first; fall back to the id only for unnamed users.
            query = query.filter(
                or_(
                    EndUserModel.other_name.ilike(like_pattern),
                    and_(name_is_blank, end_user_id_text.ilike(like_pattern)),
                )
            )

        total = query.count()
        if not total:
            business_logger.info("RAG 模式下没有符合条件的宿主")
            return {"items": [], "total": 0}

        offset = (page - 1) * pagesize
        page_rows = (
            query.order_by(
                nullslast(desc(EndUserModel.created_at)),
                desc(EndUserModel.id),  # tie-breaker after created_at
            )
            .offset(offset)
            .limit(pagesize)
            .all()
        )

        items = [
            {
                "end_user": EndUserSchema.model_validate(end_user_row),
                "memory_count": int(row_chunk_total or 0),
            }
            for end_user_row, row_chunk_total in page_rows
        ]

        business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total}")
        return {"items": items, "total": total}

    except HTTPException:
        # HTTP errors pass through untouched and are not logged here.
        raise
    except Exception as exc:
        business_logger.error(
            f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {exc!s}"
        )
        raise
def get_workspace_memory_increment(
db: Session,