diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 525fe1eb..56276339 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -1,4 +1,4 @@ -import asyncio + import uuid from fastapi import APIRouter, Depends, HTTPException, status, Query from pydantic import BaseModel, Field @@ -10,7 +10,7 @@ from app.dependencies import get_current_user from app.models.user_model import User from app.schemas.response_schema import ApiResponse -from app.services import memory_dashboard_service, memory_storage_service, workspace_service +from app.services import memory_dashboard_service, workspace_service from app.services.memory_agent_service import get_end_users_connected_configs_batch from app.services.app_statistics_service import AppStatisticsService from app.core.logging_config import get_api_logger @@ -48,7 +48,7 @@ def get_workspace_total_end_users( @router.get("/end_users", response_model=ApiResponse) -async def get_workspace_end_users( +def get_workspace_end_users( workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"), keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"), page: int = Query(1, ge=1, description="页码,从1开始"), @@ -58,6 +58,15 @@ async def get_workspace_end_users( ): """ 获取工作空间的宿主列表(分页查询,支持模糊搜索) + + 新增:记忆数量过滤: + Neo4j 模式: + - 使用 end_users.memory_count 过滤 memory_count > 0 的宿主 + - memory_num.total 直接取 end_user.memory_count + + RAG 模式: + - 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主 + - memory_num.total 取聚合后的 chunk 总数 返回工作空间下的宿主列表,支持分页查询和模糊搜索。 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。 @@ -80,17 +89,29 @@ async def get_workspace_end_users( current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}") - # 获取分页的 end_users - end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( - db=db, - workspace_id=workspace_id, - current_user=current_user, - page=page, - pagesize=pagesize, - keyword=keyword - ) + if current_workspace_type == "rag": + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag( + db=db, + workspace_id=workspace_id, + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword, + ) + raw_items = end_users_result.get("items", []) + end_users = [item["end_user"] for item in raw_items] + else: + end_users_result = memory_dashboard_service.get_workspace_end_users_paginated( + db=db, + workspace_id=workspace_id, + current_user=current_user, + page=page, + pagesize=pagesize, + keyword=keyword, + ) + raw_items = end_users_result.get("items", []) + end_users = raw_items - end_users = end_users_result.get("items", []) total = end_users_result.get("total", 0) if not end_users: @@ -101,50 +122,19 @@ async def get_workspace_end_users( "page": page, "pagesize": pagesize, "total": total, - "hasnext": (page * pagesize) < total - } + "hasnext": (page * pagesize) < total, + }, }, msg="宿主列表获取成功") end_user_ids = [str(user.id) for user in end_users] - # 并发执行两个独立的查询任务 - async def get_memory_configs(): - """获取记忆配置(在线程池中执行同步查询)""" - try: - return await asyncio.to_thread( - get_end_users_connected_configs_batch, - end_user_ids, db - ) - except Exception as e: - api_logger.error(f"批量获取记忆配置失败: {str(e)}") - return {} + try: + memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db) + except Exception as e: + api_logger.error(f"批量获取记忆配置失败: {str(e)}") + memory_configs_map = {} - async def get_memory_nums(): - """获取记忆数量""" - if current_workspace_type == "rag": - # RAG 模式:批量查询 - try: - chunk_map = await asyncio.to_thread( - memory_dashboard_service.get_users_total_chunk_batch, - end_user_ids, db, current_user - ) - return {uid: {"total": count} for uid, count in chunk_map.items()} - except Exception as e: - api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") - return {uid: {"total": 0} for uid in end_user_ids} - - elif current_workspace_type == "neo4j": - # Neo4j 模式:批量查询(简化版本,只返回total) - try: - batch_result = await memory_storage_service.search_all_batch(end_user_ids) - return {uid: {"total": count} for uid, count in batch_result.items()} - except Exception as e: - api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}") - return {uid: {"total": 0} for uid in end_user_ids} - - return {uid: {"total": 0} for uid in end_user_ids} - - # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 + # 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据 try: from app.celery_app import celery_app as _celery_app _celery_app.send_task( @@ -159,27 +149,26 @@ async def get_workspace_end_users( except Exception as e: api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}") - # 并发执行配置查询和记忆数量查询 - memory_configs_map, memory_nums_map = await asyncio.gather( - get_memory_configs(), - get_memory_nums() - ) - - # 构建结果列表 items = [] - for end_user in end_users: + for index, end_user in enumerate(end_users): user_id = str(end_user.id) config_info = memory_configs_map.get(user_id, {}) + + if current_workspace_type == "rag": + memory_total = int(raw_items[index].get("memory_count", 0) or 0) + else: + memory_total = int(getattr(end_user, "memory_count", 0) or 0) + items.append({ - 'end_user': { - 'id': user_id, - 'other_name': end_user.other_name + "end_user": { + "id": user_id, + "other_name": end_user.other_name, }, - 'memory_num': memory_nums_map.get(user_id, {"total": 0}), - 'memory_config': { + "memory_num": {"total": memory_total}, + "memory_config": { "memory_config_id": config_info.get("memory_config_id"), - "memory_config_name": config_info.get("memory_config_name") - } + "memory_config_name": config_info.get("memory_config_name"), + }, }) # 触发社区聚类补全任务(异步,不阻塞接口响应) @@ -407,6 +396,7 @@ def get_current_user_rag_total_num( total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user) return success(data=total_chunk, msg="宿主RAG知识数据获取成功") + @router.get("/rag_content", response_model=ApiResponse) def get_rag_content( end_user_id: str = Query(..., description="宿主ID"), diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 3b0ea1ee..9b0be9c8 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -313,6 +313,9 @@ async def write( except Exception as cache_err: logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + #同步neo4j记忆节点总数到pgsql,end_user表的memory_count字段 + await _sync_memory_count_after_write(end_user_id) + # Close LLM/Embedder underlying httpx clients to prevent # 'RuntimeError: Event loop is closed' during garbage collection for client_obj in (llm_client, embedder_client): @@ -331,3 +334,49 @@ async def write( logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") + + +async def _sync_memory_count_after_write(end_user_id: str) -> None: + """ + 记忆写入完成后,查 Neo4j 全量节点数,绝对值同步到 PostgreSQL end_user 表的 memory_count 字段 + + 不使用增量累加: + - Neo4j 写入使用 MERGE 语义,节点列表长度不等于新增节点数。 + - 重试或重复写入可能匹配已有节点。 + - 绝对值覆盖可以避免越加越大的计数漂移。 + """ + if not end_user_id: + return + + try: + from app.models.end_user_model import EndUser + from app.repositories.memory_config_repository import MemoryConfigRepository + + connector = Neo4jConnector() + try: + result = await connector.execute_query( + MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, + end_user_ids=[end_user_id], + ) + node_count = int(result[0]["total"]) if result else 0 + finally: + await connector.close() + + with get_db_context() as db: + db.query(EndUser).filter( + EndUser.id == uuid.UUID(end_user_id) + ).update( + {"memory_count": node_count}, + synchronize_session=False, + ) + db.commit() + + logger.info( + f"[MemoryCount] 写入后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) \ No newline at end of file diff --git a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py index 072d587c..acd436c7 100644 --- a/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py +++ b/api/app/core/memory/storage_services/forgetting_engine/forgetting_scheduler.py @@ -145,7 +145,8 @@ class ForgettingScheduler: } logger.info("没有可遗忘的节点对,遗忘周期结束") - + # 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段 + await self._sync_memory_count_to_mysql(end_user_id) return report # 步骤3:按激活值排序(激活值最低的优先) @@ -302,7 +303,8 @@ class ForgettingScheduler: f"({reduction_rate:.2%}), " f"耗时 {duration:.2f} 秒" ) - + # 同步 Neo4j 记忆节点总数到 PostgreSQL的 end_user 表的 memory_count 字段 + await self._sync_memory_count_to_mysql(end_user_id) return report except Exception as e: @@ -350,3 +352,48 @@ class ForgettingScheduler: if results: return results[0]['total'] return 0 + + async def _sync_memory_count_to_mysql( + self, + end_user_id: Optional[str] = None, + ) -> None: + """ + 遗忘周期结束后,用 SEARCH_FOR_ALL_BATCH 口径查全量节点数, + 同步到 PostgreSQL end_users.memory_count。 + + 不复用 _count_knowledge_nodes: + - _count_knowledge_nodes 只统计 Statement、ExtractedEntity、MemorySummary。 + - 宿主列表需要统计该 end_user_id 下全部 Neo4j 节点。 + """ + if not end_user_id: + return + + try: + from app.db import get_db_context + from app.models.end_user_model import EndUser + from app.repositories.memory_config_repository import MemoryConfigRepository + + result = await self.connector.execute_query( + MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, + end_user_ids=[end_user_id], + ) + node_count = int(result[0]["total"]) if result else 0 + + with get_db_context() as db: + db.query(EndUser).filter( + EndUser.id == UUID(end_user_id) + ).update( + {"memory_count": node_count}, + synchronize_session=False, + ) + db.commit() + + logger.info( + f"[MemoryCount] 遗忘后同步 memory_count: " + f"end_user_id={end_user_id}, count={node_count}" + ) + except Exception as e: + logger.warning( + f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}", + exc_info=True, + ) diff --git a/api/app/models/end_user_model.py b/api/app/models/end_user_model.py index ff46786a..952d58eb 100644 --- a/api/app/models/end_user_model.py +++ b/api/app/models/end_user_model.py @@ -1,7 +1,7 @@ import datetime import uuid -from sqlalchemy import Column, DateTime, ForeignKey, String, Text +from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship @@ -38,6 +38,15 @@ class EndUser(Base): comment="关联的记忆配置ID" ) + memory_count = Column( + Integer, + nullable=False, + default=0, + server_default="0", + index=True, + comment="记忆节点总数", + ) + # 用户摘要四个维度 - User Summary Four Dimensions user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)") personality_traits = Column(Text, nullable=True, comment="性格特点") diff --git a/api/app/schemas/end_user_schema.py b/api/app/schemas/end_user_schema.py index c2498203..94d2e5dd 100644 --- a/api/app/schemas/end_user_schema.py +++ b/api/app/schemas/end_user_schema.py @@ -19,4 +19,6 @@ class EndUser(BaseModel): # 用户摘要和洞察更新时间 user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None) - memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) \ No newline at end of file + memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None) + #用户记忆节点总数(Neo4j模式) + memory_count: int = Field(description="记忆节点总数", default=0) \ No newline at end of file diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index aaf9ac6d..6ce793a1 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -1,5 +1,5 @@ from sqlalchemy.orm import Session -from sqlalchemy import desc, nullslast, or_, and_, cast, String +from sqlalchemy import desc, nullslast, or_, and_, cast, String, func from typing import List, Optional, Dict, Any import uuid from fastapi import HTTPException @@ -102,6 +102,7 @@ def get_workspace_end_users_paginated( """获取工作空间的宿主列表(分页版本,支持模糊搜索) 返回结果按 created_at 从新到旧排序(NULL 值排在最后) + 固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。 支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段 Args: @@ -120,7 +121,8 @@ def get_workspace_end_users_paginated( try: # 构建基础查询 base_query = db.query(EndUserModel).filter( - EndUserModel.workspace_id == workspace_id + EndUserModel.workspace_id == workspace_id, + EndUserModel.memory_count > 0 , # 只查询有记忆的宿主 ) # 构建搜索条件(过滤空字符串和None) @@ -169,6 +171,104 @@ def get_workspace_end_users_paginated( business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}") raise +def get_workspace_end_users_paginated_rag( + db: Session, + workspace_id: uuid.UUID, + current_user: User, + page: int, + pagesize: int, + keyword: Optional[str] = None, +) -> Dict[str, Any]: + """RAG 模式宿主列表分页。 + + RAG 记忆数量以 documents.chunk_num 为准: + - file_name = end_user_id + ".txt" + - 只统计当前 workspace 下 permission_id="Memory" 的用户记忆知识库 + - 在 SQL 层过滤 chunk 总数为 0 的宿主,保证分页准确 + """ + business_logger.info( + f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, " + f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}" + ) + + try: + from app.models.document_model import Document + from app.models.knowledge_model import Knowledge + + chunk_subquery = ( + db.query( + Document.file_name.label("file_name"), + func.coalesce(func.sum(Document.chunk_num), 0).label("memory_count"), + ) + .join(Knowledge, Document.kb_id == Knowledge.id) + .filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id == "Memory", + Document.status == 1, + ) + .group_by(Document.file_name) + .subquery() + ) + + base_query = ( + db.query( + EndUserModel, + chunk_subquery.c.memory_count.label("memory_count"), + ) + .join( + chunk_subquery, + chunk_subquery.c.file_name == func.concat(cast(EndUserModel.id, String), ".txt"), + ) + .filter( + EndUserModel.workspace_id == workspace_id, + chunk_subquery.c.memory_count > 0, + ) + ) + + keyword = keyword.strip() if keyword else None + if keyword: + keyword_pattern = f"%{keyword}%" + base_query = base_query.filter( + or_( + EndUserModel.other_name.ilike(keyword_pattern), + and_( + or_( + EndUserModel.other_name.is_(None), + EndUserModel.other_name == "", + ), + cast(EndUserModel.id, String).ilike(keyword_pattern), + ), + ) + ) + + total = base_query.count() + if total == 0: + business_logger.info("RAG 模式下没有符合条件的宿主") + return {"items": [], "total": 0} + + rows = base_query.order_by( + nullslast(desc(EndUserModel.created_at)), + desc(EndUserModel.id), + ).offset((page - 1) * pagesize).limit(pagesize).all() + + items = [] + for end_user_orm, memory_count in rows: + items.append({ + "end_user": EndUserSchema.model_validate(end_user_orm), + "memory_count": int(memory_count or 0), + }) + + business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total} 条") + return {"items": items, "total": total} + + except HTTPException: + raise + except Exception as e: + business_logger.error( + f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}" + ) + raise def get_workspace_memory_increment( db: Session,