feat(memory): add end user memory count filtering
- Sync memory_count after Neo4j write and forgetting cycle
- Filter Neo4j end user list by memory_count > 0
- Filter RAG end user list by Memory knowledge chunk count
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
import asyncio
|
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -10,7 +10,7 @@ from app.dependencies import get_current_user
|
|||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
from app.services import memory_dashboard_service, workspace_service
|
||||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||||
from app.services.app_statistics_service import AppStatisticsService
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
@@ -48,7 +48,7 @@ def get_workspace_total_end_users(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/end_users", response_model=ApiResponse)
|
@router.get("/end_users", response_model=ApiResponse)
|
||||||
async def get_workspace_end_users(
|
def get_workspace_end_users(
|
||||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||||
@@ -58,6 +58,15 @@ async def get_workspace_end_users(
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||||
|
|
||||||
|
新增:记忆数量过滤:
|
||||||
|
Neo4j 模式:
|
||||||
|
- 使用 end_users.memory_count 过滤 memory_count > 0 的宿主
|
||||||
|
- memory_num.total 直接取 end_user.memory_count
|
||||||
|
|
||||||
|
RAG 模式:
|
||||||
|
- 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主
|
||||||
|
- memory_num.total 取聚合后的 chunk 总数
|
||||||
|
|
||||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||||
@@ -80,17 +89,29 @@ async def get_workspace_end_users(
|
|||||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||||
|
|
||||||
# 获取分页的 end_users
|
if current_workspace_type == "rag":
|
||||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user,
|
current_user=current_user,
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
keyword=keyword
|
keyword=keyword,
|
||||||
)
|
)
|
||||||
|
raw_items = end_users_result.get("items", [])
|
||||||
|
end_users = [item["end_user"] for item in raw_items]
|
||||||
|
else:
|
||||||
|
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||||
|
db=db,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
current_user=current_user,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
keyword=keyword,
|
||||||
|
)
|
||||||
|
raw_items = end_users_result.get("items", [])
|
||||||
|
end_users = raw_items
|
||||||
|
|
||||||
end_users = end_users_result.get("items", [])
|
|
||||||
total = end_users_result.get("total", 0)
|
total = end_users_result.get("total", 0)
|
||||||
|
|
||||||
if not end_users:
|
if not end_users:
|
||||||
@@ -101,50 +122,19 @@ async def get_workspace_end_users(
|
|||||||
"page": page,
|
"page": page,
|
||||||
"pagesize": pagesize,
|
"pagesize": pagesize,
|
||||||
"total": total,
|
"total": total,
|
||||||
"hasnext": (page * pagesize) < total
|
"hasnext": (page * pagesize) < total,
|
||||||
}
|
},
|
||||||
}, msg="宿主列表获取成功")
|
}, msg="宿主列表获取成功")
|
||||||
|
|
||||||
end_user_ids = [str(user.id) for user in end_users]
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
|
||||||
# 并发执行两个独立的查询任务
|
try:
|
||||||
async def get_memory_configs():
|
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
except Exception as e:
|
||||||
try:
|
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||||
return await asyncio.to_thread(
|
memory_configs_map = {}
|
||||||
get_end_users_connected_configs_batch,
|
|
||||||
end_user_ids, db
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
async def get_memory_nums():
|
# 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据
|
||||||
"""获取记忆数量"""
|
|
||||||
if current_workspace_type == "rag":
|
|
||||||
# RAG 模式:批量查询
|
|
||||||
try:
|
|
||||||
chunk_map = await asyncio.to_thread(
|
|
||||||
memory_dashboard_service.get_users_total_chunk_batch,
|
|
||||||
end_user_ids, db, current_user
|
|
||||||
)
|
|
||||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
|
||||||
|
|
||||||
elif current_workspace_type == "neo4j":
|
|
||||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
|
||||||
try:
|
|
||||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
|
||||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
|
||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
|
||||||
|
|
||||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
|
||||||
try:
|
try:
|
||||||
from app.celery_app import celery_app as _celery_app
|
from app.celery_app import celery_app as _celery_app
|
||||||
_celery_app.send_task(
|
_celery_app.send_task(
|
||||||
@@ -159,27 +149,26 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||||
|
|
||||||
# 并发执行配置查询和记忆数量查询
|
|
||||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
|
||||||
get_memory_configs(),
|
|
||||||
get_memory_nums()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 构建结果列表
|
|
||||||
items = []
|
items = []
|
||||||
for end_user in end_users:
|
for index, end_user in enumerate(end_users):
|
||||||
user_id = str(end_user.id)
|
user_id = str(end_user.id)
|
||||||
config_info = memory_configs_map.get(user_id, {})
|
config_info = memory_configs_map.get(user_id, {})
|
||||||
|
|
||||||
|
if current_workspace_type == "rag":
|
||||||
|
memory_total = int(raw_items[index].get("memory_count", 0) or 0)
|
||||||
|
else:
|
||||||
|
memory_total = int(getattr(end_user, "memory_count", 0) or 0)
|
||||||
|
|
||||||
items.append({
|
items.append({
|
||||||
'end_user': {
|
"end_user": {
|
||||||
'id': user_id,
|
"id": user_id,
|
||||||
'other_name': end_user.other_name
|
"other_name": end_user.other_name,
|
||||||
},
|
},
|
||||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
"memory_num": {"total": memory_total},
|
||||||
'memory_config': {
|
"memory_config": {
|
||||||
"memory_config_id": config_info.get("memory_config_id"),
|
"memory_config_id": config_info.get("memory_config_id"),
|
||||||
"memory_config_name": config_info.get("memory_config_name")
|
"memory_config_name": config_info.get("memory_config_name"),
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
@@ -407,6 +396,7 @@ def get_current_user_rag_total_num(
|
|||||||
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
|
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
|
||||||
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
|
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/rag_content", response_model=ApiResponse)
|
@router.get("/rag_content", response_model=ApiResponse)
|
||||||
def get_rag_content(
|
def get_rag_content(
|
||||||
end_user_id: str = Query(..., description="宿主ID"),
|
end_user_id: str = Query(..., description="宿主ID"),
|
||||||
|
|||||||
@@ -313,6 +313,9 @@ async def write(
|
|||||||
except Exception as cache_err:
|
except Exception as cache_err:
|
||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
|
#同步neo4j记忆节点总数到pgsql,end_user表的memory_count字段
|
||||||
|
await _sync_memory_count_after_write(end_user_id)
|
||||||
|
|
||||||
# Close LLM/Embedder underlying httpx clients to prevent
|
# Close LLM/Embedder underlying httpx clients to prevent
|
||||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||||
for client_obj in (llm_client, embedder_client):
|
for client_obj in (llm_client, embedder_client):
|
||||||
@@ -331,3 +334,49 @@ async def write(
|
|||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
async def _sync_memory_count_after_write(end_user_id: str) -> None:
    """Mirror the end user's Neo4j node total into ``EndUser.memory_count``.

    Called once a memory write has finished. The count is written as an
    absolute value rather than accumulated incrementally, because (per the
    original author's notes):

    - Neo4j writes use MERGE semantics, so the length of the written node
      list is not the number of newly created nodes.
    - Retries or repeated writes may match nodes that already exist.
    - Overwriting with the absolute value avoids a counter that drifts
      upwards over time.

    The sync is best-effort: any failure is logged as a warning and swallowed
    so that it never affects the caller.

    Args:
        end_user_id: String form of the end user's UUID. A falsy value makes
            the function return immediately without doing anything.
    """
    if not end_user_id:
        return

    try:
        from app.models.end_user_model import EndUser
        from app.repositories.memory_config_repository import MemoryConfigRepository

        # A dedicated connector is opened for this query and always closed,
        # even when the query itself fails.
        neo4j = Neo4jConnector()
        try:
            rows = await neo4j.execute_query(
                MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
                end_user_ids=[end_user_id],
            )
            if rows:
                node_count = int(rows[0]["total"])
            else:
                # No row returned for this user: treat as zero nodes.
                node_count = 0
        finally:
            await neo4j.close()

        with get_db_context() as session:
            matching_users = session.query(EndUser)
            matching_users = matching_users.filter(
                EndUser.id == uuid.UUID(end_user_id)
            )
            matching_users.update(
                {"memory_count": node_count},
                synchronize_session=False,
            )
            session.commit()

        logger.info(
            f"[MemoryCount] 写入后同步 memory_count: end_user_id={end_user_id}, count={node_count}"
        )
    except Exception as e:
        # Deliberately broad: the count sync must never break the write flow.
        logger.warning(
            f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
            exc_info=True,
        )
|
||||||
@@ -145,7 +145,8 @@ class ForgettingScheduler:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||||
|
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
|
||||||
|
await self._sync_memory_count_to_mysql(end_user_id)
|
||||||
return report
|
return report
|
||||||
|
|
||||||
# 步骤3:按激活值排序(激活值最低的优先)
|
# 步骤3:按激活值排序(激活值最低的优先)
|
||||||
@@ -302,7 +303,8 @@ class ForgettingScheduler:
|
|||||||
f"({reduction_rate:.2%}), "
|
f"({reduction_rate:.2%}), "
|
||||||
f"耗时 {duration:.2f} 秒"
|
f"耗时 {duration:.2f} 秒"
|
||||||
)
|
)
|
||||||
|
# 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段
|
||||||
|
await self._sync_memory_count_to_mysql(end_user_id)
|
||||||
return report
|
return report
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -350,3 +352,48 @@ class ForgettingScheduler:
|
|||||||
if results:
|
if results:
|
||||||
return results[0]['total']
|
return results[0]['total']
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
async def _sync_memory_count_to_mysql(
    self,
    end_user_id: Optional[str] = None,
) -> None:
    """Write the end user's total Neo4j node count to ``EndUser.memory_count``.

    Run at the end of a forgetting cycle. The total is obtained with the
    ``SEARCH_FOR_ALL_BATCH`` query so that it uses the same counting rule as
    the end-user list.

    NOTE(review): despite the ``_to_mysql`` suffix, the original author's
    notes describe the target as the PostgreSQL ``end_users.memory_count``
    column; the name is kept for caller compatibility.

    ``_count_knowledge_nodes`` is intentionally not reused (per the original
    author's notes):

    - It only counts Statement, ExtractedEntity and MemorySummary nodes.
    - The end-user list needs the count of all Neo4j nodes that belong to
      the given ``end_user_id``.

    The sync is best-effort: any failure is logged as a warning and swallowed.

    Args:
        end_user_id: String form of the end user's UUID. When falsy
            (``None`` or empty) the method returns without doing anything.
    """
    if not end_user_id:
        return

    try:
        from app.db import get_db_context
        from app.models.end_user_model import EndUser
        from app.repositories.memory_config_repository import MemoryConfigRepository

        # Reuses the scheduler's own connector; it is not closed here.
        rows = await self.connector.execute_query(
            MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
            end_user_ids=[end_user_id],
        )
        if rows:
            node_count = int(rows[0]["total"])
        else:
            # No row returned for this user: treat as zero nodes.
            node_count = 0

        with get_db_context() as session:
            matching_users = session.query(EndUser)
            matching_users = matching_users.filter(
                EndUser.id == UUID(end_user_id)
            )
            matching_users.update(
                {"memory_count": node_count},
                synchronize_session=False,
            )
            session.commit()

        logger.info(
            f"[MemoryCount] 遗忘后同步 memory_count: end_user_id={end_user_id}, count={node_count}"
        )
    except Exception as e:
        # Deliberately broad: the count sync must never break the forgetting cycle.
        logger.warning(
            f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
            exc_info=True,
        )
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from sqlalchemy import Column, DateTime, ForeignKey, String, Text
|
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
|
||||||
@@ -38,6 +38,15 @@ class EndUser(Base):
|
|||||||
comment="关联的记忆配置ID"
|
comment="关联的记忆配置ID"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
memory_count = Column(
|
||||||
|
Integer,
|
||||||
|
nullable=False,
|
||||||
|
default=0,
|
||||||
|
server_default="0",
|
||||||
|
index=True,
|
||||||
|
comment="记忆节点总数",
|
||||||
|
)
|
||||||
|
|
||||||
# 用户摘要四个维度 - User Summary Four Dimensions
|
# 用户摘要四个维度 - User Summary Four Dimensions
|
||||||
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
|
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
|
||||||
personality_traits = Column(Text, nullable=True, comment="性格特点")
|
personality_traits = Column(Text, nullable=True, comment="性格特点")
|
||||||
|
|||||||
@@ -19,4 +19,6 @@ class EndUser(BaseModel):
|
|||||||
|
|
||||||
# 用户摘要和洞察更新时间
|
# 用户摘要和洞察更新时间
|
||||||
user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None)
|
user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None)
|
||||||
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
|
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
|
||||||
|
#用户记忆节点总数(Neo4j模式)
|
||||||
|
memory_count: int = Field(description="记忆节点总数", default=0)
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
from sqlalchemy import desc, nullslast, or_, and_, cast, String, func
|
||||||
from typing import List, Optional, Dict, Any
|
from typing import List, Optional, Dict, Any
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@@ -102,6 +102,7 @@ def get_workspace_end_users_paginated(
|
|||||||
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
||||||
|
|
||||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||||
|
固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。
|
||||||
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -120,7 +121,8 @@ def get_workspace_end_users_paginated(
|
|||||||
try:
|
try:
|
||||||
# 构建基础查询
|
# 构建基础查询
|
||||||
base_query = db.query(EndUserModel).filter(
|
base_query = db.query(EndUserModel).filter(
|
||||||
EndUserModel.workspace_id == workspace_id
|
EndUserModel.workspace_id == workspace_id,
|
||||||
|
EndUserModel.memory_count > 0 , # 只查询有记忆的宿主
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建搜索条件(过滤空字符串和None)
|
# 构建搜索条件(过滤空字符串和None)
|
||||||
@@ -169,6 +171,104 @@ def get_workspace_end_users_paginated(
|
|||||||
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_workspace_end_users_paginated_rag(
    db: Session,
    workspace_id: uuid.UUID,
    current_user: User,
    page: int,
    pagesize: int,
    keyword: Optional[str] = None,
) -> Dict[str, Any]:
    """Return one page of a RAG workspace's end users that have memory chunks.

    In RAG mode the memory count of an end user is the sum of
    ``Document.chunk_num`` over that user's memory documents:

    - A document belongs to an end user when its ``file_name`` equals
      ``<end_user_id>.txt``.
    - Only documents with ``status == 1`` are counted, and only those in
      knowledge bases of this workspace with ``status == 1`` and
      ``permission_id == "Memory"``.
    - End users whose chunk total is 0 are filtered out in SQL, so ``total``
      and the page boundaries are computed over the filtered set.

    Rows are ordered by ``created_at`` descending (NULLs last), then by
    ``id`` descending.

    Args:
        db: SQLAlchemy session used for all queries.
        workspace_id: Workspace whose end users are listed.
        current_user: Operator; only ``username`` is read, for logging.
        page: 1-based page number.
        pagesize: Maximum number of rows per page.
        keyword: Optional search text, stripped before use. It is matched
            case-insensitively against ``other_name``; the id is matched
            only for end users whose ``other_name`` is NULL or empty.

    Returns:
        ``{"items": [...], "total": int}`` where each item is
        ``{"end_user": EndUserSchema, "memory_count": int}``.

    Raises:
        HTTPException: Re-raised unchanged.
        Exception: Any other error is logged and re-raised.
    """
    business_logger.info(
        f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
        f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
    )

    try:
        from app.models.document_model import Document
        from app.models.knowledge_model import Knowledge

        # Per-file chunk totals, restricted to this workspace's active
        # "Memory" knowledge bases and active documents.
        chunk_sum = func.coalesce(func.sum(Document.chunk_num), 0)
        per_file = db.query(
            Document.file_name.label("file_name"),
            chunk_sum.label("memory_count"),
        )
        per_file = per_file.join(Knowledge, Document.kb_id == Knowledge.id)
        per_file = per_file.filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1,
            Knowledge.permission_id == "Memory",
            Document.status == 1,
        )
        chunk_totals = per_file.group_by(Document.file_name).subquery()

        # Inner join: end users without a matching "<id>.txt" document drop out.
        owner_file_name = func.concat(cast(EndUserModel.id, String), ".txt")
        query = (
            db.query(
                EndUserModel,
                chunk_totals.c.memory_count.label("memory_count"),
            )
            .join(chunk_totals, chunk_totals.c.file_name == owner_file_name)
            .filter(
                EndUserModel.workspace_id == workspace_id,
                chunk_totals.c.memory_count > 0,
            )
        )

        needle = keyword.strip() if keyword else None
        if needle:
            like_pattern = f"%{needle}%"
            name_matches = EndUserModel.other_name.ilike(like_pattern)
            name_is_blank = or_(
                EndUserModel.other_name.is_(None),
                EndUserModel.other_name == "",
            )
            id_matches = cast(EndUserModel.id, String).ilike(like_pattern)
            # The id is only searched when there is no usable display name.
            query = query.filter(
                or_(name_matches, and_(name_is_blank, id_matches))
            )

        total = query.count()
        if not total:
            business_logger.info("RAG 模式下没有符合条件的宿主")
            return {"items": [], "total": 0}

        ordered = query.order_by(
            nullslast(desc(EndUserModel.created_at)),
            desc(EndUserModel.id),
        )
        skip = (page - 1) * pagesize
        page_rows = ordered.offset(skip).limit(pagesize).all()

        items = [
            {
                "end_user": EndUserSchema.model_validate(orm_user),
                "memory_count": int(chunk_count or 0),
            }
            for orm_user, chunk_count in page_rows
        ]

        business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total} 条")
        return {"items": items, "total": total}

    except HTTPException:
        raise
    except Exception as e:
        business_logger.error(
            f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}"
        )
        raise
|
||||||
|
|
||||||
def get_workspace_memory_increment(
|
def get_workspace_memory_increment(
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
Reference in New Issue
Block a user