275 lines
8.3 KiB
Python
275 lines
8.3 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
"""
|
||
数据库迁移辅助模块
|
||
提供常用的迁移操作和错误处理
|
||
"""
|
||
|
||
from alembic import op
|
||
import sqlalchemy as sa
|
||
from sqlalchemy import text
|
||
from sqlalchemy.engine import reflection
|
||
import logging
|
||
import uuid
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger('alembic.migration')
|
||
|
||
|
||
def safe_execute(func, *args, **kwargs):
|
||
"""
|
||
安全执行函数,捕获异常并记录日志
|
||
|
||
Args:
|
||
func: 要执行的函数
|
||
*args: 位置参数
|
||
**kwargs: 关键字参数
|
||
|
||
Returns:
|
||
执行结果或 None(如果发生异常)
|
||
"""
|
||
try:
|
||
return func(*args, **kwargs)
|
||
except Exception as e:
|
||
logger.error(f"执行 {func.__name__} 时发生错误: {e}")
|
||
return None
|
||
|
||
|
||
def column_exists(table_name, column_name):
|
||
"""
|
||
检查列是否存在
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
|
||
Returns:
|
||
bool: 列是否存在
|
||
"""
|
||
conn = op.get_bind()
|
||
inspector = reflection.Inspector.from_engine(conn)
|
||
columns = inspector.get_columns(table_name)
|
||
return column_name in [column['name'] for column in columns]
|
||
|
||
|
||
def index_exists(table_name, index_name):
|
||
"""
|
||
检查索引是否存在
|
||
|
||
Args:
|
||
table_name: 表名
|
||
index_name: 索引名
|
||
|
||
Returns:
|
||
bool: 索引是否存在
|
||
"""
|
||
conn = op.get_bind()
|
||
inspector = reflection.Inspector.from_engine(conn)
|
||
indexes = inspector.get_indexes(table_name)
|
||
return index_name in [index['name'] for index in indexes]
|
||
|
||
|
||
def table_exists(table_name):
|
||
"""
|
||
检查表是否存在
|
||
|
||
Args:
|
||
table_name: 表名
|
||
|
||
Returns:
|
||
bool: 表是否存在
|
||
"""
|
||
conn = op.get_bind()
|
||
inspector = reflection.Inspector.from_engine(conn)
|
||
return table_name in inspector.get_table_names()
|
||
|
||
|
||
def safe_add_column(table_name, column_name, column_type, **kwargs):
|
||
"""
|
||
安全添加列,如果列已存在则跳过
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
column_type: 列类型
|
||
**kwargs: 其他参数,如 nullable, default, comment 等
|
||
|
||
Returns:
|
||
bool: 是否成功添加列
|
||
"""
|
||
if not column_exists(table_name, column_name):
|
||
logger.info(f"添加列 {column_name} 到表 {table_name}")
|
||
op.add_column(table_name, sa.Column(column_name, column_type, **kwargs))
|
||
return True
|
||
else:
|
||
logger.info(f"列 {column_name} 已存在于表 {table_name},跳过添加")
|
||
return False
|
||
|
||
|
||
def safe_drop_column(table_name, column_name):
|
||
"""
|
||
安全删除列,如果列不存在则跳过
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
|
||
Returns:
|
||
bool: 是否成功删除列
|
||
"""
|
||
if column_exists(table_name, column_name):
|
||
logger.info(f"从表 {table_name} 中删除列 {column_name}")
|
||
op.drop_column(table_name, column_name)
|
||
return True
|
||
else:
|
||
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过删除")
|
||
return False
|
||
|
||
|
||
def safe_create_index(index_name, table_name, columns, unique=False):
|
||
"""
|
||
安全创建索引,如果索引已存在则跳过
|
||
|
||
Args:
|
||
index_name: 索引名
|
||
table_name: 表名
|
||
columns: 列名列表
|
||
unique: 是否唯一索引
|
||
|
||
Returns:
|
||
bool: 是否成功创建索引
|
||
"""
|
||
if not index_exists(table_name, index_name):
|
||
logger.info(f"在表 {table_name} 上创建索引 {index_name}")
|
||
op.create_index(index_name, table_name, columns, unique=unique)
|
||
return True
|
||
else:
|
||
logger.info(f"索引 {index_name} 已存在于表 {table_name},跳过创建")
|
||
return False
|
||
|
||
|
||
def safe_drop_index(index_name, table_name):
|
||
"""
|
||
安全删除索引,如果索引不存在则跳过
|
||
|
||
Args:
|
||
index_name: 索引名
|
||
table_name: 表名
|
||
|
||
Returns:
|
||
bool: 是否成功删除索引
|
||
"""
|
||
if index_exists(table_name, index_name):
|
||
logger.info(f"删除表 {table_name} 上的索引 {index_name}")
|
||
op.drop_index(index_name, table_name)
|
||
return True
|
||
else:
|
||
logger.info(f"索引 {index_name} 不存在于表 {table_name},跳过删除")
|
||
return False
|
||
|
||
|
||
def safe_alter_column(table_name, column_name, **kwargs):
|
||
"""
|
||
安全修改列,处理 MySQL 特有的问题
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
**kwargs: 其他参数,如 nullable, type_, existing_type 等
|
||
|
||
Returns:
|
||
bool: 是否成功修改列
|
||
"""
|
||
if column_exists(table_name, column_name):
|
||
# 获取数据库连接
|
||
conn = op.get_bind()
|
||
|
||
# 检查数据库类型
|
||
is_mysql = conn.dialect.name == 'mysql'
|
||
|
||
if is_mysql and 'nullable' in kwargs and not kwargs.get('existing_type') and not kwargs.get('type_'):
|
||
# MySQL 需要指定列类型
|
||
inspector = reflection.Inspector.from_engine(conn)
|
||
columns = inspector.get_columns(table_name)
|
||
column_info = next((col for col in columns if col['name'] == column_name), None)
|
||
|
||
if column_info:
|
||
# 获取列的当前类型
|
||
column_type = column_info['type']
|
||
nullable = kwargs.get('nullable')
|
||
|
||
logger.info(f"使用 SQL 语句修改列 {column_name} 的 nullable 属性为 {nullable}")
|
||
|
||
# 构建 SQL 语句
|
||
null_str = "NULL" if nullable else "NOT NULL"
|
||
comment = column_info.get('comment', '')
|
||
comment_str = f" COMMENT '{comment}'" if comment else ""
|
||
|
||
# 根据列类型构建类型字符串
|
||
if isinstance(column_type, sa.String):
|
||
type_str = f"VARCHAR({column_type.length})"
|
||
elif isinstance(column_type, sa.Integer):
|
||
type_str = "INTEGER"
|
||
elif isinstance(column_type, sa.Boolean):
|
||
type_str = "BOOLEAN"
|
||
elif isinstance(column_type, sa.DateTime):
|
||
type_str = "DATETIME"
|
||
elif isinstance(column_type, sa.Text):
|
||
type_str = "TEXT"
|
||
else:
|
||
# 默认使用 VARCHAR(255)
|
||
type_str = "VARCHAR(255)"
|
||
|
||
# 执行 SQL 语句
|
||
sql = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {type_str} {null_str}{comment_str}"
|
||
conn.execute(text(sql))
|
||
return True
|
||
else:
|
||
logger.error(f"无法获取列 {column_name} 的信息")
|
||
return False
|
||
else:
|
||
# 使用 alembic 的 alter_column
|
||
logger.info(f"修改表 {table_name} 中的列 {column_name}")
|
||
op.alter_column(table_name, column_name, **kwargs)
|
||
return True
|
||
else:
|
||
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过修改")
|
||
return False
|
||
|
||
|
||
def generate_uuid_for_null_values(table_name, column_name):
|
||
"""
|
||
为表中指定列的 NULL 值生成 UUID
|
||
|
||
Args:
|
||
table_name: 表名
|
||
column_name: 列名
|
||
|
||
Returns:
|
||
int: 更新的记录数
|
||
"""
|
||
if column_exists(table_name, column_name):
|
||
conn = op.get_bind()
|
||
|
||
# 获取所有 NULL 值的记录
|
||
result = conn.execute(text(f"SELECT id FROM {table_name} WHERE {column_name} IS NULL"))
|
||
records = result.fetchall()
|
||
|
||
if records:
|
||
logger.info(f"为表 {table_name} 中 {len(records)} 条记录的 {column_name} 列生成 UUID")
|
||
|
||
# 为每条记录生成 UUID 并更新
|
||
for record in records:
|
||
record_id = record[0]
|
||
record_uuid = str(uuid.uuid4())
|
||
conn.execute(text(f"UPDATE {table_name} SET {column_name} = '{record_uuid}' WHERE id = {record_id}"))
|
||
|
||
return len(records)
|
||
else:
|
||
logger.info(f"表 {table_name} 中没有 {column_name} 列为 NULL 的记录")
|
||
return 0
|
||
else:
|
||
logger.info(f"列 {column_name} 不存在于表 {table_name},跳过生成 UUID")
|
||
return 0 |