[add] migration script
This commit is contained in:
139
api/migrations/versions/37e2a73b28c4_202604291755.py
Normal file
139
api/migrations/versions/37e2a73b28c4_202604291755.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""202604291755
|
||||||
|
|
||||||
|
Revision ID: 37e2a73b28c4
|
||||||
|
Revises: e2d60c6d1a1a
|
||||||
|
Create Date: 2026-04-29 18:52:35.686290
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Dict, List, Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '37e2a73b28c4'
|
||||||
|
down_revision: Union[str, None] = 'e2d60c6d1a1a'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
BATCH_SIZE = 500
|
||||||
|
|
||||||
|
def _chunked(values: List[str], size: int) -> List[List[str]]:
|
||||||
|
return [values[index:index + size] for index in range(0, len(values), size)]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_neo4j_end_user_ids(connection) -> List[str]:
|
||||||
|
"""加载所有需要从 Neo4j 同步 memory_count 的宿主。
|
||||||
|
|
||||||
|
RAG 工作空间的记忆数量以 documents.chunk_num 为准,不写入 end_users.memory_count。
|
||||||
|
"""
|
||||||
|
rows = connection.execute(sa.text("""
|
||||||
|
SELECT eu.id::text AS end_user_id
|
||||||
|
FROM end_users eu
|
||||||
|
JOIN workspaces w ON eu.workspace_id = w.id
|
||||||
|
WHERE w.storage_type IS NULL OR w.storage_type <> 'rag'
|
||||||
|
""")).all()
|
||||||
|
return [row[0] for row in rows]
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_neo4j_counts(end_user_ids: List[str]) -> Dict[str, int]:
|
||||||
|
if not end_user_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
result = await connector.execute_query(
|
||||||
|
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||||||
|
end_user_ids=end_user_ids,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
counts = {str(row["user_id"]): int(row["total"]) for row in result}
|
||||||
|
for end_user_id in end_user_ids:
|
||||||
|
counts.setdefault(end_user_id, 0)
|
||||||
|
return counts
|
||||||
|
|
||||||
|
|
||||||
|
def _update_memory_counts(connection, counts: Dict[str, int]) -> int:
|
||||||
|
updated = 0
|
||||||
|
for end_user_id, memory_count in counts.items():
|
||||||
|
result = connection.execute(
|
||||||
|
sa.text("""
|
||||||
|
UPDATE end_users
|
||||||
|
SET memory_count = :memory_count
|
||||||
|
WHERE id = CAST(:end_user_id AS uuid)
|
||||||
|
"""),
|
||||||
|
{
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"memory_count": memory_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
updated += result.rowcount or 0
|
||||||
|
return updated
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_memory_count_from_neo4j() -> None:
|
||||||
|
"""迁移时初始化 Neo4j 模式宿主的 memory_count。
|
||||||
|
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
print("[memory_count] 开始同步 Neo4j 模式宿主 memory_count")
|
||||||
|
connection = op.get_bind()
|
||||||
|
target_ids = _load_neo4j_end_user_ids(connection)
|
||||||
|
if not target_ids:
|
||||||
|
print("[memory_count] 没有需要同步的 Neo4j 模式宿主")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[memory_count] 待同步宿主数量: {len(target_ids)}, "
|
||||||
|
f"batch_size={BATCH_SIZE}"
|
||||||
|
)
|
||||||
|
|
||||||
|
total_updated = 0
|
||||||
|
batches = _chunked(target_ids, BATCH_SIZE)
|
||||||
|
for batch_index, batch_ids in enumerate(batches, start=1):
|
||||||
|
print(
|
||||||
|
f"[memory_count] 正在查询 Neo4j: "
|
||||||
|
f"batch={batch_index}/{len(batches)}, size={len(batch_ids)}"
|
||||||
|
)
|
||||||
|
counts = asyncio.run(_fetch_neo4j_counts(batch_ids))
|
||||||
|
total_updated += _update_memory_counts(connection, counts)
|
||||||
|
print(
|
||||||
|
f"[memory_count] 已写入 PostgreSQL: "
|
||||||
|
f"updated={total_updated}/{len(target_ids)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[memory_count] Neo4j 模式宿主同步完成: "
|
||||||
|
f"total={len(target_ids)}, updated={total_updated}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
'end_users',
|
||||||
|
sa.Column(
|
||||||
|
'memory_count',
|
||||||
|
sa.Integer(),
|
||||||
|
server_default='0',
|
||||||
|
nullable=False,
|
||||||
|
comment='记忆节点总数',
|
||||||
|
),
|
||||||
|
)
|
||||||
|
_sync_memory_count_from_neo4j()
|
||||||
|
op.create_index(
|
||||||
|
op.f('ix_end_users_memory_count'),
|
||||||
|
'end_users',
|
||||||
|
['memory_count'],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f('ix_end_users_memory_count'), table_name='end_users')
|
||||||
|
op.drop_column('end_users', 'memory_count')
|
||||||
Reference in New Issue
Block a user