diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 525fe1eb..56276339 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -1,4 +1,4 @@ -import asyncio + import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ from app.dependencies import get_current_user from app.models.user_model import User from app.schemas.response_schema import ApiResponse -from app.services import memory_dashboard_service, memory_storage_service, workspace_service +from app.services import memory_dashboard_service, workspace_service from app.services.memory_agent_service import get_end_users_connected_configs_batch from app.services.app_statistics_service import AppStatisticsService from app.core.logging_config import get_api_logger @@ -48,7 +48,7 @@ def get_workspace_total_end_users( @router.get("/end_users", response_model=ApiResponse) -async def get_workspace_end_users( +def get_workspace_end_users( workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"), page: int = Query(1, ge=1, description="页码,从1开始"), @@ -58,6 +58,15 @@ async def get_workspace_end_users( ): """ 获取工作空间的宿主列表(分页查询,支持模糊搜索) + + 新增:记忆数量过滤: + Neo4j 模式: + - 使用 end_users.memory_count 过滤 memory_count > 0 的宿主 + - memory_num.total 直接取 end_user.memory_count + + RAG 模式: + - 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主 + - memory_num.total 取聚合后的 chunk 总数 返回工作空间下的宿主列表,支持分页查询和模糊搜索。 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。 @@ -80,17 +89,29 @@ async def get_workspace_end_users( current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}") - # 获取分页的 end_users 
- end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( - db=db, - workspace_id=workspace_id, - current_user=current_user, - page=page, - pagesize=pagesize, - keyword=keyword - ) + if current_workspace_type == "rag": + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag( + db=db, + workspace_id=workspace_id, + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword, + ) + raw_items = end_users_result.get("items", []) + end_users = [item["end_user"] for item in raw_items] + else: + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( + db=db, + workspace_id=workspace_id, + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword, + ) + raw_items = end_users_result.get("items", []) + end_users = raw_items - end_users = end_users_result.get("items", []) total = end_users_result.get("total", 0) if not end_users: @@ -101,50 +122,19 @@ async def get_workspace_end_users( "page": page, "pagesize": pagesize, "total": total, - "hasnext": (page * pagesize) < total - } + "hasnext": (page * pagesize) < total, + }, }, msg="宿主列表获取成功") end_user_ids = [str(user.id) for user in end_users] - # 并发执行两个独立的查询任务 - async def get_memory_configs(): - """获取记忆配置(在线程池中执行同步查询)""" - try: - return await asyncio.to_thread( - get_end_users_connected_configs_batch, - end_user_ids, db - ) - except Exception as e: - api_logger.error(f"批量获取记忆配置失败: {str(e)}") - return {} + try: + memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db) + except Exception as e: + api_logger.error(f"批量获取记忆配置失败: {str(e)}") + memory_configs_map = {} - async def get_memory_nums(): - """获取记忆数量""" - if current_workspace_type == "rag": - # RAG 模式:批量查询 - try: - chunk_map = await asyncio.to_thread( - memory_dashboard_service.get_users_total_chunk_batch, - end_user_ids, db, current_user - ) - return {uid: {"total": count} for uid, count in chunk_map.items()} - except Exception as e: - 
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") - return {uid: {"total": 0} for uid in end_user_ids} - - elif current_workspace_type == "neo4j": - # Neo4j 模式:批量查询(简化版本,只返回total) - try: - batch_result = await memory_storage_service.search_all_batch(end_user_ids) - return {uid: {"total": count} for uid, count in batch_result.items()} - except Exception as e: - api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}") - return {uid: {"total": 0} for uid in end_user_ids} - - return {uid: {"total": 0} for uid in end_user_ids} - - # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 + # 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据 try: from app.celery_app import celery_app as _celery_app _celery_app.send_task( @@ -159,27 +149,26 @@ async def get_workspace_end_users( except Exception as e: api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}") - # 并发执行配置查询和记忆数量查询 - memory_configs_map, memory_nums_map = await asyncio.gather( - get_memory_configs(), - get_memory_nums() - ) - - # 构建结果列表 items = [] - for end_user in end_users: + for index, end_user in enumerate(end_users): user_id = str(end_user.id) config_info = memory_configs_map.get(user_id, {}) + + if current_workspace_type == "rag": + memory_total = int(raw_items[index].get("memory_count", 0) or 0) + else: + memory_total = int(getattr(end_user, "memory_count", 0) or 0) + items.append({ - 'end_user': { - 'id': user_id, - 'other_name': end_user.other_name + "end_user": { + "id": user_id, + "other_name": end_user.other_name, }, - 'memory_num': memory_nums_map.get(user_id, {"total": 0}), - 'memory_config': { + "memory_num": {"total": memory_total}, + "memory_config": { "memory_config_id": config_info.get("memory_config_id"), - "memory_config_name": config_info.get("memory_config_name") - } + "memory_config_name": config_info.get("memory_config_name"), + }, }) # 触发社区聚类补全任务(异步,不阻塞接口响应) @@ -407,6 +396,7 @@ def get_current_user_rag_total_num( total_chunk = 
memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user) return success(data=total_chunk, msg="宿主RAG知识数据获取成功") + @router.get("/rag_content", response_model=ApiResponse) def get_rag_content( end_user_id: str = Query(..., description="宿主ID"), diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 3b0ea1ee..1dcc73b2 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -20,6 +20,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem memory_summary_generation from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.log.logging_utils import log_time +from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j from app.db import get_db_context from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges from app.repositories.neo4j.add_nodes import add_memory_summary_nodes @@ -313,6 +314,28 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + # 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count + if end_user_id: + try: + memory_count_connector = Neo4jConnector() + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + memory_count_connector, + ) + finally: + await memory_count_connector.close() + + logger.info( + f"[MemoryCount] 写入后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) + # Close LLM/Embedder underlying httpx clients to prevent # 'RuntimeError: Event loop is closed' during garbage collection for client_obj in (llm_client, embedder_client): @@ -331,3 +354,4 @@ async def write( logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: 
{total_time:.2f} seconds") + diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py index 072d587c..39c9eed6 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py @@ -20,6 +20,7 @@ from uuid import UUID from datetime import datetime from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy +from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -145,7 +146,22 @@ class ForgettingScheduler: } logger.info("没有可遗忘的节点对,遗忘周期结束") - + # 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count + if end_user_id: + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + self.connector, + ) + logger.info( + f"[MemoryCount] 遗忘后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) return report # 步骤3:按激活值排序(激活值最低的优先) @@ -302,7 +318,22 @@ class ForgettingScheduler: f"({reduction_rate:.2%}), " f"耗时 {duration:.2f} 秒" ) - + # 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count + if end_user_id: + try: + node_count = await sync_end_user_memory_count_from_neo4j( + end_user_id, + self.connector, + ) + logger.info( + f"[MemoryCount] 遗忘后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) return report except Exception as e: diff --git a/api/app/core/memory/utils/memory_count_utils.py b/api/app/core/memory/utils/memory_count_utils.py new file mode 100644 index 00000000..316cb635 --- /dev/null +++ 
async def sync_end_user_memory_count_from_neo4j(
    end_user_id: str,
    connector: "Neo4jConnector",
) -> int:
    """
    Sync one end user's Neo4j memory node count to PostgreSQL.

    Runs ``MemoryConfigRepository.SEARCH_FOR_ALL_BATCH`` for the single user,
    reads the ``total`` field of the first returned row, and stores it in
    ``end_users.memory_count``.

    The caller owns the Neo4j connector lifecycle; this function never
    closes it.

    Args:
        end_user_id: End user id; must be parseable as a UUID.
        connector: Open Neo4j connector used to run the count query.

    Returns:
        The node count written to PostgreSQL, or 0 when ``end_user_id`` is
        empty (in which case nothing is queried or written).

    Raises:
        ValueError: If ``end_user_id`` is not a valid UUID string. This is
            raised before any Neo4j or PostgreSQL I/O happens.
    """
    # Function-scope import: the module's own import block does not bring
    # asyncio into scope.
    import asyncio

    if not end_user_id:
        return 0

    # Validate up front so a malformed id fails fast instead of after a
    # wasted Neo4j round-trip.
    user_id_text = str(end_user_id)
    user_uuid = UUID(user_id_text)

    result = await connector.execute_query(
        MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
        end_user_ids=[user_id_text],
    )
    # No row means no nodes; a NULL aggregate is treated as 0 as well rather
    # than crashing in int(None).
    node_count = int(result[0]["total"] or 0) if result else 0

    def _persist_count() -> None:
        """Write the count with a short-lived session owned by this call."""
        with get_db_context() as db:
            db.query(EndUser).filter(
                EndUser.id == user_uuid
            ).update(
                {"memory_count": node_count},
                synchronize_session=False,
            )
            db.commit()

    # The SQLAlchemy session is synchronous; run it in a worker thread so the
    # blocking UPDATE/COMMIT does not stall the event loop of the caller.
    await asyncio.to_thread(_persist_count)

    return node_count
b/api/app/schemas/end_user_schema.py @@ -19,4 +19,6 @@ class EndUser(BaseModel): # 用户摘要和洞察更新时间 user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None) - memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) \ No newline at end of file + memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) + #用户记忆节点总数(Neo4j模式) + memory_count: int = Field(description="记忆节点总数", default=0) \ No newline at end of file diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index aaf9ac6d..6d0f0a73 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from sqlalchemy import desc, nullslast, or_, and_, cast, String +from sqlalchemy import desc, nullslast, or_, cast, String, func from typing import List, Optional, Dict, Any import uuid from fastapi import HTTPException @@ -102,6 +102,7 @@ def get_workspace_end_users_paginated( """获取工作空间的宿主列表(分页版本,支持模糊搜索) 返回结果按 created_at 从新到旧排序(NULL 值排在最后) + 固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。 支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段 Args: @@ -120,7 +121,8 @@ def get_workspace_end_users_paginated( try: # 构建基础查询 base_query = db.query(EndUserModel).filter( - EndUserModel.workspace_id == workspace_id + EndUserModel.workspace_id == workspace_id, + EndUserModel.memory_count > 0 , # 只查询有记忆的宿主 ) # 构建搜索条件(过滤空字符串和None) @@ -128,20 +130,13 @@ def get_workspace_end_users_paginated( if keyword: keyword_pattern = f"%{keyword}%" - # other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效 base_query = base_query.filter( or_( EndUserModel.other_name.ilike(keyword_pattern), - and_( - or_( - EndUserModel.other_name.is_(None), - EndUserModel.other_name == "", - ), - cast(EndUserModel.id, String).ilike(keyword_pattern), - ), + cast(EndUserModel.id, String).ilike(keyword_pattern), ) ) 
def get_workspace_end_users_paginated_rag(
    db: Session,
    workspace_id: uuid.UUID,
    current_user: User,
    page: int,
    pagesize: int,
    keyword: Optional[str] = None,
) -> Dict[str, Any]:
    """Return one page of end users for a RAG workspace, with chunk totals.

    The memory count of a user is the sum of ``Document.chunk_num`` over
    documents whose ``file_name`` equals ``"<end_user_id>.txt"``. Only
    documents with ``status == 1`` that belong to a knowledge base of this
    workspace with ``status == 1`` and ``permission_id == "Memory"`` are
    counted. Users whose total is not greater than 0 (or who have no such
    document at all) are excluded in SQL, so ``total`` and the page contents
    are computed over the same filtered set.

    Args:
        db: SQLAlchemy session used for all queries.
        workspace_id: Workspace whose end users are listed.
        current_user: Operator; only ``username`` is read, for logging.
        page: 1-based page number.
        pagesize: Number of rows per page.
        keyword: Optional search term; after stripping, it is matched
            case-insensitively as a substring of ``other_name`` or of the
            id rendered as text.

    Returns:
        ``{"items": [...], "total": int}`` where each item is
        ``{"end_user": EndUserSchema, "memory_count": int}``.

    Raises:
        Exception: Any failure is re-raised; non-HTTP errors are logged first.
    """
    business_logger.info(
        f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
        f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
    )

    try:
        from app.models.document_model import Document
        from app.models.knowledge_model import Knowledge

        # Per-file chunk totals, limited to active "Memory" knowledge bases
        # of this workspace and active documents.
        chunk_total = func.coalesce(func.sum(Document.chunk_num), 0)
        per_file_chunks = (
            db.query(
                Document.file_name.label("file_name"),
                chunk_total.label("memory_count"),
            )
            .join(Knowledge, Document.kb_id == Knowledge.id)
            .filter(
                Knowledge.workspace_id == workspace_id,
                Knowledge.status == 1,
                Knowledge.permission_id == "Memory",
                Document.status == 1,
            )
            .group_by(Document.file_name)
            .subquery()
        )

        # Inner join: a user without a matching "<id>.txt" document drops out.
        id_as_text = cast(EndUserModel.id, String)
        expected_file_name = func.concat(id_as_text, ".txt")
        users_query = (
            db.query(
                EndUserModel,
                per_file_chunks.c.memory_count.label("memory_count"),
            )
            .join(
                per_file_chunks,
                per_file_chunks.c.file_name == expected_file_name,
            )
            .filter(
                EndUserModel.workspace_id == workspace_id,
                per_file_chunks.c.memory_count > 0,
            )
        )

        # Blank / whitespace-only keywords are treated as "no search".
        search_term = keyword.strip() if keyword else None
        if search_term:
            like_pattern = f"%{search_term}%"
            matches_name = EndUserModel.other_name.ilike(like_pattern)
            matches_id = cast(EndUserModel.id, String).ilike(like_pattern)
            users_query = users_query.filter(or_(matches_name, matches_id))

        total = users_query.count()
        if total == 0:
            business_logger.info("RAG 模式下没有符合条件的宿主")
            return {"items": [], "total": 0}

        # Newest first, NULL created_at last; id breaks ties for a stable order.
        offset = (page - 1) * pagesize
        ordered_query = users_query.order_by(
            nullslast(desc(EndUserModel.created_at)),
            desc(EndUserModel.id),
        )
        page_rows = ordered_query.offset(offset).limit(pagesize).all()

        items = [
            {
                "end_user": EndUserSchema.model_validate(user_row),
                "memory_count": int(chunk_count or 0),
            }
            for user_row, chunk_count in page_rows
        ]

        business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total} 条")
        return {"items": items, "total": total}

    except HTTPException:
        # HTTP errors pass through untouched, without the error log below.
        raise
    except Exception as e:
        business_logger.error(
            f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}"
        )
        raise