"""202604291755 Revision ID: 37e2a73b28c4 Revises: e2d60c6d1a1a Create Date: 2026-04-29 18:52:35.686290 """ from typing import Dict, List, Sequence, Union from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision: str = '37e2a73b28c4' down_revision: Union[str, None] = 'e2d60c6d1a1a' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None BATCH_SIZE = 500 def _chunked(values: List[str], size: int) -> List[List[str]]: return [values[index:index + size] for index in range(0, len(values), size)] def _load_neo4j_end_user_ids(connection) -> List[str]: """加载所有需要从 Neo4j 同步 memory_count 的宿主。 RAG 工作空间的记忆数量以 documents.chunk_num 为准,不写入 end_users.memory_count。 """ rows = connection.execute(sa.text(""" SELECT eu.id::text AS end_user_id FROM end_users eu JOIN workspaces w ON eu.workspace_id = w.id WHERE w.storage_type IS NULL OR w.storage_type <> 'rag' """)).all() return [row[0] for row in rows] async def _fetch_neo4j_counts(end_user_ids: List[str]) -> Dict[str, int]: if not end_user_ids: return {} from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector connector = Neo4jConnector() try: result = await connector.execute_query( MemoryConfigRepository.SEARCH_FOR_ALL_BATCH, end_user_ids=end_user_ids, ) finally: await connector.close() counts = {str(row["user_id"]): int(row["total"]) for row in result} for end_user_id in end_user_ids: counts.setdefault(end_user_id, 0) return counts def _update_memory_counts(connection, counts: Dict[str, int]) -> int: updated = 0 for end_user_id, memory_count in counts.items(): result = connection.execute( sa.text(""" UPDATE end_users SET memory_count = :memory_count WHERE id = CAST(:end_user_id AS uuid) """), { "end_user_id": end_user_id, "memory_count": memory_count, }, ) updated += result.rowcount or 0 return updated def _sync_memory_count_from_neo4j() -> None: """迁移时初始化 Neo4j 模式宿主的 memory_count。 """ import asyncio print("[memory_count] 开始同步 Neo4j 模式宿主 memory_count") connection = op.get_bind() target_ids = _load_neo4j_end_user_ids(connection) if not target_ids: print("[memory_count] 没有需要同步的 Neo4j 模式宿主") return print( f"[memory_count] 待同步宿主数量: {len(target_ids)}, " f"batch_size={BATCH_SIZE}" ) total_updated = 0 batches = _chunked(target_ids, BATCH_SIZE) for batch_index, batch_ids in enumerate(batches, start=1): print( f"[memory_count] 正在查询 Neo4j: " f"batch={batch_index}/{len(batches)}, size={len(batch_ids)}" ) counts = asyncio.run(_fetch_neo4j_counts(batch_ids)) total_updated += _update_memory_counts(connection, counts) print( f"[memory_count] 已写入 PostgreSQL: " f"updated={total_updated}/{len(target_ids)}" ) print( f"[memory_count] Neo4j 模式宿主同步完成: " f"total={len(target_ids)}, updated={total_updated}" ) def upgrade() -> None: op.add_column( 'end_users', sa.Column( 'memory_count', sa.Integer(), server_default='0', nullable=False, comment='记忆节点总数', ), ) _sync_memory_count_from_neo4j() op.create_index( op.f('ix_end_users_memory_count'), 'end_users', ['memory_count'], unique=False, ) def downgrade() -> None: op.drop_index(op.f('ix_end_users_memory_count'), table_name='end_users') op.drop_column('end_users', 'memory_count')