Merge pull request #458 from SuanmoSuanyangTechnology/fix/RAG-memory

Fix/rag memory
This commit is contained in:
Ke Sun
2026-03-04 19:09:03 +08:00
committed by GitHub
3 changed files with 88 additions and 5 deletions

View File

@@ -606,8 +606,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量

View File

@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
except Exception as e:
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
raise
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中permission_id='Memory'用户知识库的chunk_num总和
"""
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
try:
from sqlalchemy import func
result = db.query(func.sum(Knowledge.chunk_num)).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id == "Memory"
).scalar()
total = result if result is not None else 0
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
return total
except Exception as e:
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
raise
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
"""
根据workspace_id查询knowledges表中排除用户知识库permission_id!='Memory')的数量
"""
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
try:
count = db.query(Knowledge).filter(
Knowledge.workspace_id == workspace_id,
Knowledge.status == 1,
Knowledge.permission_id != "Memory"
).count()
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
return count
except Exception as e:
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
raise

View File

@@ -390,19 +390,59 @@ def get_rag_total_kb(
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库permission_id!='Memory'的数量
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
return total_kb
except Exception as e:
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_user_kb_total_chunk(
db: Session,
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id从documents表统计所有用户知识库的chunk总数。
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from app.models.end_user_model import EndUser
from app.models.app_model import App
from sqlalchemy import func
# 通过 App 关联取该 workspace 下所有 end_user_id
end_user_ids = [
str(eid) for (eid,) in db.query(EndUser.id)
.join(App, EndUser.app_id == App.id)
.filter(App.workspace_id == workspace_id)
.all()
]
if not end_user_ids:
return 0
file_names = [f"{uid}.txt" for uid in end_user_ids]
result = db.query(func.sum(Document.chunk_num)).filter(
Document.file_name.in_(file_names)
).scalar()
total_chunk = int(result or 0)
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
return total_chunk
except Exception as e:
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_current_user_total_chunk(
end_user_id: str,
db: Session,