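"""202601281340

Revision ID: 915bed077f8d
Revises: 75f0ec80e50b
Create Date: 2026-01-28 13:38:49.471560

Adds the ``model_bases`` table and new ``model_configs`` columns, and replaces
the one-to-many ``model_api_keys.model_config_id`` column with the many-to-many
table ``model_config_api_key_association``. ``model_api_keys`` is copied to a
backup table before the upgrade starts.
"""
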
from typing import Sequence, Union
|
||
|
||
from alembic import op
|
||
import sqlalchemy as sa
|
||
from sqlalchemy.dialects import postgresql
|
||
|
||
# Alembic reads these module-level names to place this migration in the
# revision graph.
revision: str = '915bed077f8d'
down_revision: Union[str, None] = '75f0ec80e50b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

# Table that backup_model_api_keys() fills with a copy of model_api_keys at the
# start of upgrade(); downgrade() leaves it in place.
# NOTE(review): the date suffix (20260123) differs from this revision's create
# date (2026-01-28). If an earlier migration used the same name, its backup is
# dropped and replaced when this upgrade runs — confirm that is intended.
BACKUP_TABLE_NAME: str = 'model_api_keys_backup_20260123'


def get_temp_models():
    """Build throwaway table definitions for the data queries in this migration.

    The tables live on a private ``MetaData`` and are never created here. They
    declare only the columns the migration reads or writes, rather than
    importing the application's ORM models.

    Returns:
        A 3-tuple of ``sqlalchemy.Table`` objects:
        ``(model_api_keys, model_config_api_key_association, backup_table)``.
    """
    metadata = sa.MetaData()

    # model_api_keys: only the columns needed to move the config link.
    model_api_key = sa.Table(
        'model_api_keys', metadata,
        sa.Column('id', sa.UUID(), primary_key=True),
        sa.Column('model_config_id', sa.UUID(), nullable=True),
    )

    # Same columns as the association table created in upgrade().
    association = sa.Table(
        'model_config_api_key_association', metadata,
        sa.Column('model_config_id', sa.UUID(), nullable=False),
        sa.Column('api_key_id', sa.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
    )

    # Snapshot made by backup_model_api_keys() via
    # "CREATE TABLE ... AS SELECT * FROM model_api_keys" at the very start of
    # upgrade(). That is before upgrade() adds model_api_keys.description, so
    # the backup has no description column and none is declared here.
    # NOTE(review): the other column types (e.g. usage_count / priority as
    # String) are not verified against the real pre-upgrade schema.
    backup = sa.Table(
        BACKUP_TABLE_NAME, metadata,
        sa.Column('id', sa.UUID(), primary_key=True),
        sa.Column('model_name', sa.String(), nullable=False),
        sa.Column('provider', sa.String(), nullable=False),
        sa.Column('api_key', sa.String(), nullable=False),
        sa.Column('api_base', sa.String(), nullable=True),
        sa.Column('config', sa.JSON(), nullable=True),
        sa.Column('usage_count', sa.String(), default="0"),
        sa.Column('last_used_at', sa.DateTime(), nullable=True),
        sa.Column('priority', sa.String(), default="1"),
        sa.Column('model_config_id', sa.UUID(), nullable=True),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.Column('updated_at', sa.DateTime(), nullable=True),
        sa.Column('is_active', sa.Boolean(), default=True),
    )

    return model_api_key, association, backup


def backup_model_api_keys():
    """Copy the structure and rows of ``model_api_keys`` into ``BACKUP_TABLE_NAME``.

    Any existing table with the backup name is dropped first, so the snapshot
    always reflects the current contents of ``model_api_keys``.
    ``CREATE TABLE ... AS SELECT`` copies columns and rows only; constraints,
    indexes and defaults are not copied. The SQL is PostgreSQL-flavoured.
    Prints the row counts of the original and backup tables.
    """
    connection = op.get_bind()

    # DROP ... IF EXISTS is already a no-op when the table is missing, so no
    # separate information_schema lookup is needed. That lookup also ignored
    # the schema, so it could match a same-named table elsewhere.
    op.execute(f"DROP TABLE IF EXISTS {BACKUP_TABLE_NAME};")

    # Copy structure and data in a single statement.
    op.execute(f"""
        CREATE TABLE {BACKUP_TABLE_NAME} AS
        SELECT * FROM model_api_keys;
    """)

    # Report row counts so the operator can compare them.
    backup_count = connection.execute(sa.text(f"SELECT COUNT(*) FROM {BACKUP_TABLE_NAME}")).scalar()
    original_count = connection.execute(sa.text("SELECT COUNT(*) FROM model_api_keys")).scalar()

    print(
        f"已备份model_api_keys表到 {BACKUP_TABLE_NAME} \n"
        f" 原表数据行数:{original_count} | 备份表数据行数:{backup_count}"
    )


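# def restore_model_api_keys_from_backup():
#     """Restore model_api_keys rows from the backup table (optional; intended
#     for manual recovery if a downgrade goes wrong)."""
#     # 1. Empty the original table first (use with care!)
#     # op.execute("TRUNCATE TABLE model_api_keys;")
#
#     # 2. Copy rows back from the backup table.
#     # Columns are listed explicitly for two reasons:
#     # - downgrade() re-adds model_config_id at the END of model_api_keys, so
#     #   its column order no longer matches the backup, and a bare
#     #   "INSERT ... SELECT *" would put values into the wrong columns.
#     # - The backup is taken before upgrade() adds model_api_keys.description,
#     #   and downgrade() drops that column again, so description must not be
#     #   referenced.
#     # NOTE(review): the list assumes the backup has exactly these columns;
#     # confirm against the real table before running.
#     op.execute(f"""
#         INSERT INTO model_api_keys (
#             id, model_name, provider, api_key, api_base, config,
#             usage_count, last_used_at, priority, model_config_id,
#             created_at, updated_at, is_active
#         )
#         SELECT
#             id, model_name, provider, api_key, api_base, config,
#             usage_count, last_used_at, priority, model_config_id,
#             created_at, updated_at, is_active
#         FROM {BACKUP_TABLE_NAME}
#         ON CONFLICT (id) DO UPDATE SET
#             model_name = EXCLUDED.model_name,
#             provider = EXCLUDED.provider,
#             api_key = EXCLUDED.api_key,
#             api_base = EXCLUDED.api_base,
#             config = EXCLUDED.config,
#             usage_count = EXCLUDED.usage_count,
#             last_used_at = EXCLUDED.last_used_at,
#             priority = EXCLUDED.priority,
#             model_config_id = EXCLUDED.model_config_id,
#             created_at = EXCLUDED.created_at,
#             updated_at = EXCLUDED.updated_at,
#             is_active = EXCLUDED.is_active;
#     """)
#     print(f"✅ 已从 {BACKUP_TABLE_NAME} 恢复model_api_keys表数据")

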
def upgrade() -> None:
    """Apply this revision.

    1. Snapshot ``model_api_keys`` into ``BACKUP_TABLE_NAME``.
    2. Create ``model_bases`` and ``model_config_api_key_association``, and add
       the new columns, index and foreign key to ``model_configs`` and
       ``model_api_keys``.
    3. Copy every existing ``model_api_keys.model_config_id`` link into the
       association table, then drop that column and its foreign key.
    """
    backup_model_api_keys()

    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('model_bases',
        sa.Column('id', sa.UUID(), nullable=False),
        sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'),
        sa.Column('name', sa.String(), nullable=False, comment='模型唯一标识(如gpt-3.5-turbo)'),
        sa.Column('type', sa.String(), nullable=False, comment='模型类型'),
        sa.Column('provider', sa.String(), nullable=False),
        sa.Column('description', sa.Text(), nullable=True, comment='模型描述'),
        sa.Column('is_deprecated', sa.Boolean(), nullable=False, comment='是否弃用'),
        sa.Column('is_official', sa.Boolean(), nullable=True, comment='是否供应商官方模型(区分自定义)'),
        sa.Column('tags', sa.ARRAY(sa.String()), nullable=False, comment="模型标签(如['聊天', '创作'])"),
        sa.Column('add_count', sa.Integer(), nullable=False, comment='模型被用户添加的次数'),
        sa.PrimaryKeyConstraint('id'),
        sa.UniqueConstraint('name', 'provider', name='uk_model_name_provider')
    )
    op.create_index(op.f('ix_model_bases_id'), 'model_bases', ['id'], unique=False)
    op.create_index(op.f('ix_model_bases_provider'), 'model_bases', ['provider'], unique=False)
    op.create_index(op.f('ix_model_bases_type'), 'model_bases', ['type'], unique=False)
    op.create_table('model_config_api_key_association',
        sa.Column('model_config_id', sa.UUID(), nullable=False),
        sa.Column('api_key_id', sa.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(['api_key_id'], ['model_api_keys.id'], ),
        sa.ForeignKeyConstraint(['model_config_id'], ['model_configs.id'], ),
        sa.PrimaryKeyConstraint('model_config_id', 'api_key_id')
    )
    op.add_column('model_api_keys', sa.Column('description', sa.String(), nullable=True, comment='备注'))
    op.add_column('model_configs', sa.Column('model_id', sa.UUID(), nullable=True, comment='基础模型ID'))
    op.add_column('model_configs', sa.Column('logo', sa.String(length=255), nullable=True, comment='模型logo图片URL'))
    op.add_column('model_configs', sa.Column('provider', sa.String(), server_default='composite', nullable=False, comment='供应商'))
    op.add_column('model_configs', sa.Column('is_composite', sa.Boolean(), server_default='true', nullable=False, comment='是否为组合模型'))
    op.add_column('model_configs', sa.Column('load_balance_strategy', sa.String(), nullable=True, comment='负载均衡策略'))
    op.create_index(op.f('ix_model_configs_model_id'), 'model_configs', ['model_id'], unique=False)
    op.create_foreign_key("model_configs_model_id_fkey", 'model_configs', 'model_bases', ['model_id'], ['id'])

    # Move each API key's model_config_id into the association table with one
    # server-side INSERT ... SELECT, instead of fetching every row into Python
    # and inserting it back. created_at is not set, so backfilled rows get NULL.
    # (config_id, key_id) pairs are unique because model_api_keys.id is unique.
    connection = op.get_bind()
    ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()

    existing_links = (
        sa.select(ModelApiKey.c.model_config_id, ModelApiKey.c.id)
        .where(ModelApiKey.c.model_config_id.isnot(None))
    )
    connection.execute(
        ModelConfigApiKeyAssociation.insert().from_select(
            ['model_config_id', 'api_key_id'], existing_links
        )
    )

    # The link now lives only in the association table.
    op.drop_constraint(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', type_='foreignkey')
    op.drop_column('model_api_keys', 'model_config_id')
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert this revision and restore the one-to-many API key -> config link.

    1. Remove the columns, index and foreign key added to ``model_configs``.
    2. Re-add ``model_api_keys.model_config_id`` and fill it from the
       association table, choosing exactly one config per API key.
    3. Give API keys with no association a placeholder config id, then make the
       column NOT NULL and restore its foreign key.
    4. Drop ``model_api_keys.description``, the association table and
       ``model_bases``. The backup table is left in place.

    Raises:
        RuntimeError: if some API keys have no associated model config and the
            placeholder ``model_configs`` row does not exist. In that case the
            restored foreign key could not be created.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint("model_configs_model_id_fkey", 'model_configs', type_='foreignkey')
    op.drop_index(op.f('ix_model_configs_model_id'), table_name='model_configs')
    op.drop_column('model_configs', 'load_balance_strategy')
    op.drop_column('model_configs', 'is_composite')
    op.drop_column('model_configs', 'provider')
    op.drop_column('model_configs', 'logo')
    op.drop_column('model_configs', 'model_id')
    op.add_column('model_api_keys', sa.Column('model_config_id', sa.UUID(), autoincrement=False, nullable=True, comment='模型配置ID'))

    connection = op.get_bind()
    ModelApiKey, ModelConfigApiKeyAssociation, _ = get_temp_models()
    link = ModelConfigApiKeyAssociation.c

    # Choose one model_config_id per API key with PostgreSQL DISTINCT ON.
    # Without ORDER BY, DISTINCT ON returns an arbitrary row per key, so order
    # explicitly. upgrade() does not set created_at on the rows it backfills
    # (NULL), so NULLS FIRST prefers the pre-upgrade link — assuming links
    # created later by the application do set created_at. model_config_id
    # breaks any remaining ties.
    first_link = (
        sa.select(link.api_key_id, link.model_config_id)
        .distinct(link.api_key_id)
        .order_by(
            link.api_key_id,
            link.created_at.asc().nulls_first(),
            link.model_config_id,
        )
        .subquery('first_link')
    )
    # One UPDATE ... FROM instead of one UPDATE round-trip per API key.
    connection.execute(
        sa.update(ModelApiKey)
        .where(ModelApiKey.c.id == first_link.c.api_key_id)
        .values(model_config_id=first_link.c.model_config_id)
    )

    # API keys without any association still need a model_config_id before the
    # column becomes NOT NULL, so they get a placeholder id. The foreign key
    # created below validates existing rows, so the placeholder must exist in
    # model_configs. Check that up front and fail with an explicit message
    # instead of an opaque foreign-key violation.
    placeholder_config_id = '00000000-0000-0000-0000-000000000000'
    orphan_count = connection.execute(
        sa.text("SELECT COUNT(*) FROM model_api_keys WHERE model_config_id IS NULL")
    ).scalar()
    if orphan_count:
        placeholder_exists = connection.execute(
            sa.text("SELECT EXISTS (SELECT 1 FROM model_configs WHERE id = CAST(:config_id AS uuid))"),
            {"config_id": placeholder_config_id},
        ).scalar()
        if not placeholder_exists:
            raise RuntimeError(
                f"Cannot downgrade: {orphan_count} row(s) in model_api_keys have no "
                f"associated model config, and no model_configs row with id "
                f"{placeholder_config_id} exists, so the restored foreign key "
                f"model_api_keys.model_config_id -> model_configs.id would be violated. "
                f"Associate or delete those API keys (or create the placeholder "
                f"model config) and retry."
            )
        connection.execute(
            sa.text("UPDATE model_api_keys SET model_config_id = CAST(:config_id AS uuid) WHERE model_config_id IS NULL"),
            {"config_id": placeholder_config_id},
        )

    op.alter_column('model_api_keys', 'model_config_id', nullable=False)
    op.create_foreign_key(op.f('model_api_keys_model_config_id_fkey'), 'model_api_keys', 'model_configs', ['model_config_id'], ['id'])
    op.drop_column('model_api_keys', 'description')
    op.drop_table('model_config_api_key_association')

    op.drop_index(op.f('ix_model_bases_type'), table_name='model_bases')
    op.drop_index(op.f('ix_model_bases_provider'), table_name='model_bases')
    op.drop_index(op.f('ix_model_bases_id'), table_name='model_bases')
    op.drop_table('model_bases')
    # ### end Alembic commands ###

    # Report completion only after every step has run. The backup table is
    # kept on purpose. restore_model_api_keys_from_backup() exists in this file
    # only as commented-out code, so restoring from the backup is a manual step.
    print(
        f"回滚完成!备份表 {BACKUP_TABLE_NAME} 仍保留,如需手动恢复可参考本迁移文件中已注释的 restore_model_api_keys_from_backup() 函数")