[add] migrations script

This commit is contained in:
Mark
2026-01-28 15:24:55 +08:00
parent 7e56c09620
commit 44bf1eeae2

View File

@@ -0,0 +1,224 @@
"""202601281340
Revision ID: 915bed077f8d
Revises: 75f0ec80e50b
Create Date: 2026-01-28 13:38:49.471560
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '915bed077f8d'
down_revision: Union[str, None] = '75f0ec80e50b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
BACKUP_TABLE_NAME = 'model_api_keys_backup_20260123'
def get_temp_models():
"""创建临时模型,用于迁移过程中查询数据"""
metadata = sa.MetaData()
# 临时ModelApiKey表仅包含需要的字段
ModelApiKey = sa.Table(
'model_api_keys', metadata,
sa.Column('id', sa.UUID(), primary_key=True),
sa.Column('model_config_id', sa.UUID(), nullable=True),
)
# 临时关联表(和升级脚本创建的表结构一致)
ModelConfigApiKeyAssociation = sa.Table(
'model_config_api_key_association', metadata,
sa.Column('model_config_id', sa.UUID(), nullable=False),
sa.Column('api_key_id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
)
ModelApiKeyBackup = sa.Table(
BACKUP_TABLE_NAME, metadata,
sa.Column('id', sa.UUID(), primary_key=True),
sa.Column('model_name', sa.String(), nullable=False),
sa.Column('description', sa.String(), nullable=True),
sa.Column('provider', sa.String(), nullable=False),
sa.Column('api_key', sa.String(), nullable=False),
sa.Column('api_base', sa.String(), nullable=True),
sa.Column('config', sa.JSON(), nullable=True),
sa.Column('usage_count', sa.String(), default="0"),
sa.Column('last_used_at', sa.DateTime(), nullable=True),
sa.Column('priority', sa.String(), default="1"),
sa.Column('model_config_id', sa.UUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.Column('updated_at', sa.DateTime(), nullable=True),
sa.Column('is_active', sa.Boolean(), default=True),
)
return ModelApiKey, ModelConfigApiKeyAssociation, ModelApiKeyBackup
def backup_model_api_keys():
"""备份model_api_keys表的结构和数据"""
connection = op.get_bind()
# 检查备份表是否已存在
result = connection.execute(sa.text(f"""
SELECT EXISTS (
SELECT FROM information_schema.tables
WHERE table_name = '{BACKUP_TABLE_NAME}'
);
""")).scalar()
if result:
# 备份表已存在,先删除再重建(确保结构一致)
op.execute(f"DROP TABLE IF EXISTS {BACKUP_TABLE_NAME};")
# 直接复制表结构和数据PostgreSQL专用一步完成
op.execute(f"""
CREATE TABLE {BACKUP_TABLE_NAME} AS
SELECT * FROM model_api_keys;
""")
# 统计行数
backup_count = connection.execute(sa.text(f"SELECT COUNT(*) FROM {BACKUP_TABLE_NAME}")).scalar()
original_count = connection.execute(sa.text("SELECT COUNT(*) FROM model_api_keys")).scalar()
print(
f"已备份model_api_keys表到 {BACKUP_TABLE_NAME} \n"
f" 原表数据行数:{original_count} | 备份表数据行数:{backup_count}"
)
# def restore_model_api_keys_from_backup():
# """从备份表恢复model_api_keys数据可选用于回滚失败时手动恢复"""
# # 1. 清空原表(谨慎使用!)
# # op.execute("TRUNCATE TABLE model_api_keys;")
#
# # 2. 从备份表恢复数据
# op.execute(f"""
# INSERT INTO model_api_keys
# SELECT * FROM {BACKUP_TABLE_NAME}
# ON CONFLICT (id) DO UPDATE SET
# model_name = EXCLUDED.model_name,
# description = EXCLUDED.description,
# provider = EXCLUDED.provider,
# api_key = EXCLUDED.api_key,
# api_base = EXCLUDED.api_base,
# config = EXCLUDED.config,
# usage_count = EXCLUDED.usage_count,
# last_used_at = EXCLUDED.last_used_at,
# priority = EXCLUDED.priority,
# model_config_id = EXCLUDED.model_config_id,
# created_at = EXCLUDED.created_at,
# updated_at = EXCLUDED.updated_at,
# is_active = EXCLUDED.is_active;
# """)
# print(f"✅ 已从 {BACKUP_TABLE_NAME} 恢复model_api_keys表数据")
def upgrade() -> None:
backup_model_api_keys()
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('model_bases',
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'),
sa.Column('name', sa.String(), nullable=False, comment='模型唯一标识如gpt-3.5-turbo'),
sa.Column('type', sa.String(), nullable=False, comment='模型类型'),
sa.Column('provider', sa.String(), nullable=False),
sa.Column('description', sa.Text(), nullable=True, comment='模型描述'),
sa.Column('is_deprecated', sa.Boolean(), nullable=False, comment='是否弃用'),
sa.Column('is_official', sa.Boolean(), nullable=True, comment='是否供应商官方模型(区分自定义)'),
sa.Column('tags', sa.ARRAY(sa.String()), nullable=False, comment="模型标签(如['聊天', '创作']"),
sa.Column('add_count', sa.Integer(), nullable=False, comment='模型被用户添加的次数'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name', 'provider', name='uk_model_name_provider')
)
op.create_index(op.f('ix_model_bases_id'), 'model_bases', ['id'], unique=False)
op.create_index(op.f('ix_model_bases_provider'), 'model_bases', ['provider'], unique=False)
op.create_index(op.f('ix_model_bases_type'), 'model_bases', ['type'], unique=False)
op.create_table('model_config_api_key_association',
sa.Column('model_config_id', sa.UUID(), nullable=False),
sa.Column('api_key_id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['api_key_id'], ['model_api_keys.id'], ),
sa.ForeignKeyConstraint(['model_config_id'], ['model_configs.id'], ),
sa.PrimaryKeyConstraint('model_config_id', 'api_key_id')
)
op.add_column('model_api_keys', sa.Column('description', sa.String(), nullable=True, comment='备注'))
op.add_column('model_configs', sa.Column('model_id', sa.UUID(), nullable=True, comment='基础模型ID'))
op.add_column('model_configs', sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'))
op.add_column('model_configs', sa.Column('provider', sa.String(), server_default='composite', nullable=False, comment='供应商'))
op.add_column('model_configs', sa.Column('is_composite', sa.Boolean(), server_default='true', nullable=False, comment='是否为组合模型'))
op.add_column('model_configs', sa.Column('load_balance_strategy', sa.String(), nullable=True, comment='负载均衡策略'))
op.create_index(op.f('ix_model_configs_model_id'), 'model_configs', ['model_id'], unique=False)
op.create_foreign_key("model_configs_model_id_fkey", 'model_configs', 'model_bases', ['model_id'], ['id'])
connection = op.get_bind()
ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()
# 查询所有有model_config_id的API Key
api_keys = connection.execute(
sa.select(ModelApiKey.c.id, ModelApiKey.c.model_config_id)
.where(ModelApiKey.c.model_config_id.isnot(None))
).fetchall()
# 批量插入到多对多表
if api_keys:
association_data = [
{
'model_config_id': row.model_config_id,
'api_key_id': row.id
}
for row in api_keys
]
connection.execute(ModelConfigApiKeyAssociation.insert(), association_data)
op.drop_constraint(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', type_='foreignkey')
op.drop_column('model_api_keys', 'model_config_id')
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("model_configs_model_id_fkey", 'model_configs', type_='foreignkey')
op.drop_index(op.f('ix_model_configs_model_id'), table_name='model_configs')
op.drop_column('model_configs', 'load_balance_strategy')
op.drop_column('model_configs', 'is_composite')
op.drop_column('model_configs', 'provider')
op.drop_column('model_configs', 'logo')
op.drop_column('model_configs', 'model_id')
op.add_column('model_api_keys', sa.Column('model_config_id', sa.UUID(), autoincrement=False, nullable=True, comment='模型配置ID'))
connection = op.get_bind()
ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()
# 查询多对多表中的关联数据取每个API Key的第一个关联的model_config_id
association_data = connection.execute(
sa.select(
ModelConfigApiKeyAssociation.c.api_key_id,
ModelConfigApiKeyAssociation.c.model_config_id
).distinct(ModelConfigApiKeyAssociation.c.api_key_id)
).fetchall()
# 批量更新model_api_keys表
if association_data:
for api_key_id, model_config_id in association_data:
connection.execute(
sa.update(ModelApiKey)
.where(ModelApiKey.c.id == api_key_id)
.values(model_config_id=model_config_id)
)
op.execute(
"UPDATE model_api_keys SET model_config_id = '00000000-0000-0000-0000-000000000000' WHERE model_config_id IS NULL")
op.alter_column('model_api_keys', 'model_config_id', nullable=False)
op.create_foreign_key(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', 'model_configs', ['model_config_id'], ['id'])
op.drop_column('model_api_keys', 'description')
op.drop_table('model_config_api_key_association')
# ### 可选:回滚时恢复备份(如需)###
# restore_model_api_keys_from_backup()
print(
f"回滚完成!备份表 {BACKUP_TABLE_NAME} 仍保留,如需手动恢复可执行 restore_model_api_keys_from_backup() 函数")
op.drop_index(op.f('ix_model_bases_type'), table_name='model_bases')
op.drop_index(op.f('ix_model_bases_provider'), table_name='model_bases')
op.drop_index(op.f('ix_model_bases_id'), table_name='model_bases')
op.drop_table('model_bases')
# ### end Alembic commands ###