VWED_server/migrations/migration_helpers.py

275 lines
8.3 KiB
Python
Raw Normal View History

2025-04-30 16:57:46 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库迁移辅助模块
提供常用的迁移操作和错误处理
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
from sqlalchemy.engine import reflection
import logging
import uuid
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('alembic.migration')
def safe_execute(func, *args, **kwargs):
"""
安全执行函数捕获异常并记录日志
Args:
func: 要执行的函数
*args: 位置参数
**kwargs: 关键字参数
Returns:
执行结果或 None如果发生异常
"""
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(f"执行 {func.__name__} 时发生错误: {e}")
return None
def column_exists(table_name, column_name):
"""
检查列是否存在
Args:
table_name: 表名
column_name: 列名
Returns:
bool: 列是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
columns = inspector.get_columns(table_name)
return column_name in [column['name'] for column in columns]
def index_exists(table_name, index_name):
"""
检查索引是否存在
Args:
table_name: 表名
index_name: 索引名
Returns:
bool: 索引是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
indexes = inspector.get_indexes(table_name)
return index_name in [index['name'] for index in indexes]
def table_exists(table_name):
"""
检查表是否存在
Args:
table_name: 表名
Returns:
bool: 表是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
return table_name in inspector.get_table_names()
def safe_add_column(table_name, column_name, column_type, **kwargs):
"""
安全添加列如果列已存在则跳过
Args:
table_name: 表名
column_name: 列名
column_type: 列类型
**kwargs: 其他参数 nullable, default, comment
Returns:
bool: 是否成功添加列
"""
if not column_exists(table_name, column_name):
logger.info(f"添加列 {column_name} 到表 {table_name}")
op.add_column(table_name, sa.Column(column_name, column_type, **kwargs))
return True
else:
logger.info(f"{column_name} 已存在于表 {table_name},跳过添加")
return False
def safe_drop_column(table_name, column_name):
"""
安全删除列如果列不存在则跳过
Args:
table_name: 表名
column_name: 列名
Returns:
bool: 是否成功删除列
"""
if column_exists(table_name, column_name):
logger.info(f"从表 {table_name} 中删除列 {column_name}")
op.drop_column(table_name, column_name)
return True
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过删除")
return False
def safe_create_index(index_name, table_name, columns, unique=False):
"""
安全创建索引如果索引已存在则跳过
Args:
index_name: 索引名
table_name: 表名
columns: 列名列表
unique: 是否唯一索引
Returns:
bool: 是否成功创建索引
"""
if not index_exists(table_name, index_name):
logger.info(f"在表 {table_name} 上创建索引 {index_name}")
op.create_index(index_name, table_name, columns, unique=unique)
return True
else:
logger.info(f"索引 {index_name} 已存在于表 {table_name},跳过创建")
return False
def safe_drop_index(index_name, table_name):
"""
安全删除索引如果索引不存在则跳过
Args:
index_name: 索引名
table_name: 表名
Returns:
bool: 是否成功删除索引
"""
if index_exists(table_name, index_name):
logger.info(f"删除表 {table_name} 上的索引 {index_name}")
op.drop_index(index_name, table_name)
return True
else:
logger.info(f"索引 {index_name} 不存在于表 {table_name},跳过删除")
return False
def safe_alter_column(table_name, column_name, **kwargs):
"""
安全修改列处理 MySQL 特有的问题
Args:
table_name: 表名
column_name: 列名
**kwargs: 其他参数 nullable, type_, existing_type
Returns:
bool: 是否成功修改列
"""
if column_exists(table_name, column_name):
# 获取数据库连接
conn = op.get_bind()
# 检查数据库类型
is_mysql = conn.dialect.name == 'mysql'
if is_mysql and 'nullable' in kwargs and not kwargs.get('existing_type') and not kwargs.get('type_'):
# MySQL 需要指定列类型
inspector = reflection.Inspector.from_engine(conn)
columns = inspector.get_columns(table_name)
column_info = next((col for col in columns if col['name'] == column_name), None)
if column_info:
# 获取列的当前类型
column_type = column_info['type']
nullable = kwargs.get('nullable')
logger.info(f"使用 SQL 语句修改列 {column_name} 的 nullable 属性为 {nullable}")
# 构建 SQL 语句
null_str = "NULL" if nullable else "NOT NULL"
comment = column_info.get('comment', '')
comment_str = f" COMMENT '{comment}'" if comment else ""
# 根据列类型构建类型字符串
if isinstance(column_type, sa.String):
type_str = f"VARCHAR({column_type.length})"
elif isinstance(column_type, sa.Integer):
type_str = "INTEGER"
elif isinstance(column_type, sa.Boolean):
type_str = "BOOLEAN"
elif isinstance(column_type, sa.DateTime):
type_str = "DATETIME"
elif isinstance(column_type, sa.Text):
type_str = "TEXT"
else:
# 默认使用 VARCHAR(255)
type_str = "VARCHAR(255)"
# 执行 SQL 语句
sql = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {type_str} {null_str}{comment_str}"
conn.execute(text(sql))
return True
else:
logger.error(f"无法获取列 {column_name} 的信息")
return False
else:
# 使用 alembic 的 alter_column
logger.info(f"修改表 {table_name} 中的列 {column_name}")
op.alter_column(table_name, column_name, **kwargs)
return True
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过修改")
return False
def generate_uuid_for_null_values(table_name, column_name):
"""
为表中指定列的 NULL 值生成 UUID
Args:
table_name: 表名
column_name: 列名
Returns:
int: 更新的记录数
"""
if column_exists(table_name, column_name):
conn = op.get_bind()
# 获取所有 NULL 值的记录
result = conn.execute(text(f"SELECT id FROM {table_name} WHERE {column_name} IS NULL"))
records = result.fetchall()
if records:
logger.info(f"为表 {table_name}{len(records)} 条记录的 {column_name} 列生成 UUID")
# 为每条记录生成 UUID 并更新
for record in records:
record_id = record[0]
record_uuid = str(uuid.uuid4())
conn.execute(text(f"UPDATE {table_name} SET {column_name} = '{record_uuid}' WHERE id = {record_id}"))
return len(records)
else:
logger.info(f"{table_name} 中没有 {column_name} 列为 NULL 的记录")
return 0
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过生成 UUID")
return 0