140 lines
4.0 KiB
Python
140 lines
4.0 KiB
Python
"""202604291755
|
|
|
|
Revision ID: 37e2a73b28c4
|
|
Revises: e2d60c6d1a1a
|
|
Create Date: 2026-04-29 18:52:35.686290
|
|
|
|
"""
|
|
from typing import Dict, List, Sequence, Union
|
|
|
|
from alembic import op
|
|
import sqlalchemy as sa
|
|
|
|
# Alembic revision metadata: this migration's id and the revision it follows.
revision: str = '37e2a73b28c4'
down_revision: Union[str, None] = 'e2d60c6d1a1a'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# How many end-user ids are sent to Neo4j in one count query during the backfill.
BATCH_SIZE: int = 500
|
|
|
|
def _chunked(values: List[str], size: int) -> List[List[str]]:
    """Split *values* into consecutive slices of at most *size* items.

    The final slice may be shorter than *size*; an empty input yields ``[]``.

    Raises:
        ValueError: if *size* is smaller than 1. Without this check a zero
            size surfaces as an opaque ``range`` error, and a negative size
            silently produces no batches at all, so nothing would be synced.
    """
    if size < 1:
        raise ValueError(f"size must be a positive integer, got {size!r}")
    return [values[start:start + size] for start in range(0, len(values), size)]
|
|
|
|
|
|
def _load_neo4j_end_user_ids(connection) -> List[str]:
    """Return the ids of every end user whose memory_count is synced from Neo4j.

    End users in RAG workspaces are excluded: their memory count is taken
    from documents.chunk_num and is never written to end_users.memory_count.
    Workspaces whose storage_type is NULL are treated as non-RAG and included.

    Args:
        connection: bound connection used to run the query.

    Returns:
        End-user ids as text, as rendered by PostgreSQL (``eu.id::text``).
    """
    query = sa.text("""
        SELECT eu.id::text AS end_user_id
        FROM end_users eu
        JOIN workspaces w ON eu.workspace_id = w.id
        WHERE w.storage_type IS NULL OR w.storage_type <> 'rag'
    """)
    result_rows = connection.execute(query).all()
    # Each row carries exactly one column (end_user_id).
    return [end_user_id for (end_user_id,) in result_rows]
|
|
|
|
|
|
async def _fetch_neo4j_counts(end_user_ids: List[str]) -> Dict[str, int]:
    """Query Neo4j once for the memory node count of each id in *end_user_ids*.

    Args:
        end_user_ids: ids passed as the ``end_user_ids`` parameter of
            ``MemoryConfigRepository.SEARCH_FOR_ALL_BATCH``.

    Returns:
        Mapping of end-user id -> count. Ids for which the query returned no
        row are reported as 0. An empty input returns ``{}`` without opening
        a Neo4j connection.
    """
    if not end_user_ids:
        return {}

    # Imported here rather than at module level, so the application packages
    # are only loaded when there is actually something to query.
    from app.repositories.memory_config_repository import MemoryConfigRepository
    from app.repositories.neo4j.neo4j_connector import Neo4jConnector

    connector = Neo4jConnector()
    try:
        records = await connector.execute_query(
            MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
            end_user_ids=end_user_ids,
        )
    finally:
        # Release the connector even when the query raises.
        await connector.close()

    # NOTE(review): assumes each record exposes "user_id" and "total";
    # confirm against SEARCH_FOR_ALL_BATCH.
    counts = {str(record["user_id"]): int(record["total"]) for record in records}
    # Ids without a Neo4j row get 0, appended after the returned ones.
    counts.update({uid: 0 for uid in end_user_ids if uid not in counts})
    return counts
|
|
|
|
|
|
def _update_memory_counts(connection, counts: Dict[str, int]) -> int:
    """Write each end user's memory_count back to PostgreSQL.

    Args:
        connection: bound connection used to run the UPDATE statements.
        counts: mapping of end-user id (UUID rendered as text) -> memory count.

    Returns:
        Total number of end_users rows the UPDATEs reported as modified.
    """
    # The statement is identical for every row, so build it once.
    statement = sa.text("""
        UPDATE end_users
        SET memory_count = :memory_count
        WHERE id = CAST(:end_user_id AS uuid)
    """)
    updated = 0
    for end_user_id, memory_count in counts.items():
        result = connection.execute(
            statement,
            {
                "end_user_id": end_user_id,
                "memory_count": memory_count,
            },
        )
        # DB-API allows rowcount to be -1 when it cannot be determined (and a
        # driver may return None); never let that reduce the running total.
        updated += max(result.rowcount or 0, 0)
    return updated
|
|
|
|
|
|
def _run_coroutine_blocking(coro):
    """Run *coro* to completion from synchronous code and return its result.

    ``asyncio.run`` raises RuntimeError when an event loop is already running
    in the current thread (e.g. when env.py runs migrations through an async
    engine's ``run_sync``). In that case the coroutine is run on a fresh event
    loop in a worker thread, and this call blocks until it finishes.
    Otherwise this is exactly ``asyncio.run(coro)``.
    """
    import asyncio
    from concurrent.futures import ThreadPoolExecutor

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the common, synchronous env.py case.
        return asyncio.run(coro)
    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()


def _sync_memory_count_from_neo4j() -> None:
    """Backfill end_users.memory_count from Neo4j for non-RAG end users.

    Loads the target end-user ids from PostgreSQL and splits them into
    batches of ``BATCH_SIZE``. For each batch it queries Neo4j for the counts
    and writes them back before moving on, printing progress along the way.
    """
    print("[memory_count] 开始同步 Neo4j 模式宿主 memory_count")
    connection = op.get_bind()
    target_ids = _load_neo4j_end_user_ids(connection)
    if not target_ids:
        print("[memory_count] 没有需要同步的 Neo4j 模式宿主")
        return

    print(
        f"[memory_count] 待同步宿主数量: {len(target_ids)}, "
        f"batch_size={BATCH_SIZE}"
    )

    total_updated = 0
    batches = _chunked(target_ids, BATCH_SIZE)
    for batch_index, batch_ids in enumerate(batches, start=1):
        print(
            f"[memory_count] 正在查询 Neo4j: "
            f"batch={batch_index}/{len(batches)}, size={len(batch_ids)}"
        )
        counts = _run_coroutine_blocking(_fetch_neo4j_counts(batch_ids))
        # PostgreSQL writes stay on the migration's own connection/thread.
        total_updated += _update_memory_counts(connection, counts)
        print(
            f"[memory_count] 已写入 PostgreSQL: "
            f"updated={total_updated}/{len(target_ids)}"
        )

    print(
        f"[memory_count] Neo4j 模式宿主同步完成: "
        f"total={len(target_ids)}, updated={total_updated}"
    )
|
|
|
|
|
|
def upgrade() -> None:
    """Add end_users.memory_count, backfill it from Neo4j, then index it."""
    from alembic import context

    op.add_column(
        'end_users',
        sa.Column(
            'memory_count',
            sa.Integer(),
            server_default='0',
            nullable=False,
            comment='记忆节点总数',
        ),
    )
    # The backfill SELECTs rows from PostgreSQL and queries Neo4j, so it needs
    # a live connection. In offline (--sql) mode op.get_bind() is a mock that
    # cannot return rows; only the DDL is emitted there, and rows keep the
    # server default of 0.
    if not context.is_offline_mode():
        _sync_memory_count_from_neo4j()
    # Creating the index after the backfill means the per-row UPDATEs above
    # do not have to maintain it.
    op.create_index(
        op.f('ix_end_users_memory_count'),
        'end_users',
        ['memory_count'],
        unique=False,
    )
|
|
|
|
|
|
def downgrade() -> None:
    """Reverse upgrade(): drop the memory_count index, then the column itself."""
    table = 'end_users'
    # Mirror upgrade() in reverse: the index first, then the column it covers.
    op.drop_index(op.f('ix_end_users_memory_count'), table_name=table)
    op.drop_column(table, 'memory_count')
|