275 lines
8.3 KiB
Python
275 lines
8.3 KiB
Python
|
#!/usr/bin/env python
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
|
|||
|
"""
|
|||
|
数据库迁移辅助模块
|
|||
|
提供常用的迁移操作和错误处理
|
|||
|
"""
|
|||
|
|
|||
|
from alembic import op
|
|||
|
import sqlalchemy as sa
|
|||
|
from sqlalchemy import text
|
|||
|
from sqlalchemy.engine import reflection
|
|||
|
import logging
|
|||
|
import uuid
|
|||
|
|
|||
|
# 配置日志
|
|||
|
logging.basicConfig(level=logging.INFO)
|
|||
|
logger = logging.getLogger('alembic.migration')
|
|||
|
|
|||
|
|
|||
|
def safe_execute(func, *args, **kwargs):
|
|||
|
"""
|
|||
|
安全执行函数,捕获异常并记录日志
|
|||
|
|
|||
|
Args:
|
|||
|
func: 要执行的函数
|
|||
|
*args: 位置参数
|
|||
|
**kwargs: 关键字参数
|
|||
|
|
|||
|
Returns:
|
|||
|
执行结果或 None(如果发生异常)
|
|||
|
"""
|
|||
|
try:
|
|||
|
return func(*args, **kwargs)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"执行 {func.__name__} 时发生错误: {e}")
|
|||
|
return None
|
|||
|
|
|||
|
|
|||
|
def column_exists(table_name, column_name):
|
|||
|
"""
|
|||
|
检查列是否存在
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
column_name: 列名
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 列是否存在
|
|||
|
"""
|
|||
|
conn = op.get_bind()
|
|||
|
inspector = reflection.Inspector.from_engine(conn)
|
|||
|
columns = inspector.get_columns(table_name)
|
|||
|
return column_name in [column['name'] for column in columns]
|
|||
|
|
|||
|
|
|||
|
def index_exists(table_name, index_name):
|
|||
|
"""
|
|||
|
检查索引是否存在
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
index_name: 索引名
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 索引是否存在
|
|||
|
"""
|
|||
|
conn = op.get_bind()
|
|||
|
inspector = reflection.Inspector.from_engine(conn)
|
|||
|
indexes = inspector.get_indexes(table_name)
|
|||
|
return index_name in [index['name'] for index in indexes]
|
|||
|
|
|||
|
|
|||
|
def table_exists(table_name):
|
|||
|
"""
|
|||
|
检查表是否存在
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 表是否存在
|
|||
|
"""
|
|||
|
conn = op.get_bind()
|
|||
|
inspector = reflection.Inspector.from_engine(conn)
|
|||
|
return table_name in inspector.get_table_names()
|
|||
|
|
|||
|
|
|||
|
def safe_add_column(table_name, column_name, column_type, **kwargs):
|
|||
|
"""
|
|||
|
安全添加列,如果列已存在则跳过
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
column_name: 列名
|
|||
|
column_type: 列类型
|
|||
|
**kwargs: 其他参数,如 nullable, default, comment 等
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 是否成功添加列
|
|||
|
"""
|
|||
|
if not column_exists(table_name, column_name):
|
|||
|
logger.info(f"添加列 {column_name} 到表 {table_name}")
|
|||
|
op.add_column(table_name, sa.Column(column_name, column_type, **kwargs))
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.info(f"列 {column_name} 已存在于表 {table_name},跳过添加")
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def safe_drop_column(table_name, column_name):
|
|||
|
"""
|
|||
|
安全删除列,如果列不存在则跳过
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
column_name: 列名
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 是否成功删除列
|
|||
|
"""
|
|||
|
if column_exists(table_name, column_name):
|
|||
|
logger.info(f"从表 {table_name} 中删除列 {column_name}")
|
|||
|
op.drop_column(table_name, column_name)
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过删除")
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def safe_create_index(index_name, table_name, columns, unique=False):
|
|||
|
"""
|
|||
|
安全创建索引,如果索引已存在则跳过
|
|||
|
|
|||
|
Args:
|
|||
|
index_name: 索引名
|
|||
|
table_name: 表名
|
|||
|
columns: 列名列表
|
|||
|
unique: 是否唯一索引
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 是否成功创建索引
|
|||
|
"""
|
|||
|
if not index_exists(table_name, index_name):
|
|||
|
logger.info(f"在表 {table_name} 上创建索引 {index_name}")
|
|||
|
op.create_index(index_name, table_name, columns, unique=unique)
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.info(f"索引 {index_name} 已存在于表 {table_name},跳过创建")
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def safe_drop_index(index_name, table_name):
|
|||
|
"""
|
|||
|
安全删除索引,如果索引不存在则跳过
|
|||
|
|
|||
|
Args:
|
|||
|
index_name: 索引名
|
|||
|
table_name: 表名
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 是否成功删除索引
|
|||
|
"""
|
|||
|
if index_exists(table_name, index_name):
|
|||
|
logger.info(f"删除表 {table_name} 上的索引 {index_name}")
|
|||
|
op.drop_index(index_name, table_name)
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.info(f"索引 {index_name} 不存在于表 {table_name},跳过删除")
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def safe_alter_column(table_name, column_name, **kwargs):
|
|||
|
"""
|
|||
|
安全修改列,处理 MySQL 特有的问题
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
column_name: 列名
|
|||
|
**kwargs: 其他参数,如 nullable, type_, existing_type 等
|
|||
|
|
|||
|
Returns:
|
|||
|
bool: 是否成功修改列
|
|||
|
"""
|
|||
|
if column_exists(table_name, column_name):
|
|||
|
# 获取数据库连接
|
|||
|
conn = op.get_bind()
|
|||
|
|
|||
|
# 检查数据库类型
|
|||
|
is_mysql = conn.dialect.name == 'mysql'
|
|||
|
|
|||
|
if is_mysql and 'nullable' in kwargs and not kwargs.get('existing_type') and not kwargs.get('type_'):
|
|||
|
# MySQL 需要指定列类型
|
|||
|
inspector = reflection.Inspector.from_engine(conn)
|
|||
|
columns = inspector.get_columns(table_name)
|
|||
|
column_info = next((col for col in columns if col['name'] == column_name), None)
|
|||
|
|
|||
|
if column_info:
|
|||
|
# 获取列的当前类型
|
|||
|
column_type = column_info['type']
|
|||
|
nullable = kwargs.get('nullable')
|
|||
|
|
|||
|
logger.info(f"使用 SQL 语句修改列 {column_name} 的 nullable 属性为 {nullable}")
|
|||
|
|
|||
|
# 构建 SQL 语句
|
|||
|
null_str = "NULL" if nullable else "NOT NULL"
|
|||
|
comment = column_info.get('comment', '')
|
|||
|
comment_str = f" COMMENT '{comment}'" if comment else ""
|
|||
|
|
|||
|
# 根据列类型构建类型字符串
|
|||
|
if isinstance(column_type, sa.String):
|
|||
|
type_str = f"VARCHAR({column_type.length})"
|
|||
|
elif isinstance(column_type, sa.Integer):
|
|||
|
type_str = "INTEGER"
|
|||
|
elif isinstance(column_type, sa.Boolean):
|
|||
|
type_str = "BOOLEAN"
|
|||
|
elif isinstance(column_type, sa.DateTime):
|
|||
|
type_str = "DATETIME"
|
|||
|
elif isinstance(column_type, sa.Text):
|
|||
|
type_str = "TEXT"
|
|||
|
else:
|
|||
|
# 默认使用 VARCHAR(255)
|
|||
|
type_str = "VARCHAR(255)"
|
|||
|
|
|||
|
# 执行 SQL 语句
|
|||
|
sql = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {type_str} {null_str}{comment_str}"
|
|||
|
conn.execute(text(sql))
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.error(f"无法获取列 {column_name} 的信息")
|
|||
|
return False
|
|||
|
else:
|
|||
|
# 使用 alembic 的 alter_column
|
|||
|
logger.info(f"修改表 {table_name} 中的列 {column_name}")
|
|||
|
op.alter_column(table_name, column_name, **kwargs)
|
|||
|
return True
|
|||
|
else:
|
|||
|
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过修改")
|
|||
|
return False
|
|||
|
|
|||
|
|
|||
|
def generate_uuid_for_null_values(table_name, column_name):
|
|||
|
"""
|
|||
|
为表中指定列的 NULL 值生成 UUID
|
|||
|
|
|||
|
Args:
|
|||
|
table_name: 表名
|
|||
|
column_name: 列名
|
|||
|
|
|||
|
Returns:
|
|||
|
int: 更新的记录数
|
|||
|
"""
|
|||
|
if column_exists(table_name, column_name):
|
|||
|
conn = op.get_bind()
|
|||
|
|
|||
|
# 获取所有 NULL 值的记录
|
|||
|
result = conn.execute(text(f"SELECT id FROM {table_name} WHERE {column_name} IS NULL"))
|
|||
|
records = result.fetchall()
|
|||
|
|
|||
|
if records:
|
|||
|
logger.info(f"为表 {table_name} 中 {len(records)} 条记录的 {column_name} 列生成 UUID")
|
|||
|
|
|||
|
# 为每条记录生成 UUID 并更新
|
|||
|
for record in records:
|
|||
|
record_id = record[0]
|
|||
|
record_uuid = str(uuid.uuid4())
|
|||
|
conn.execute(text(f"UPDATE {table_name} SET {column_name} = '{record_uuid}' WHERE id = {record_id}"))
|
|||
|
|
|||
|
return len(records)
|
|||
|
else:
|
|||
|
logger.info(f"表 {table_name} 中没有 {column_name} 列为 NULL 的记录")
|
|||
|
return 0
|
|||
|
else:
|
|||
|
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过生成 UUID")
|
|||
|
return 0
|