VWED_server/migrations/migration_helpers.py
2025-04-30 16:57:46 +08:00

275 lines
8.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库迁移辅助模块
提供常用的迁移操作和错误处理
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy import text
from sqlalchemy.engine import reflection
import logging
import uuid
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('alembic.migration')
def safe_execute(func, *args, **kwargs):
"""
安全执行函数,捕获异常并记录日志
Args:
func: 要执行的函数
*args: 位置参数
**kwargs: 关键字参数
Returns:
执行结果或 None如果发生异常
"""
try:
return func(*args, **kwargs)
except Exception as e:
logger.error(f"执行 {func.__name__} 时发生错误: {e}")
return None
def column_exists(table_name, column_name):
"""
检查列是否存在
Args:
table_name: 表名
column_name: 列名
Returns:
bool: 列是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
columns = inspector.get_columns(table_name)
return column_name in [column['name'] for column in columns]
def index_exists(table_name, index_name):
"""
检查索引是否存在
Args:
table_name: 表名
index_name: 索引名
Returns:
bool: 索引是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
indexes = inspector.get_indexes(table_name)
return index_name in [index['name'] for index in indexes]
def table_exists(table_name):
"""
检查表是否存在
Args:
table_name: 表名
Returns:
bool: 表是否存在
"""
conn = op.get_bind()
inspector = reflection.Inspector.from_engine(conn)
return table_name in inspector.get_table_names()
def safe_add_column(table_name, column_name, column_type, **kwargs):
"""
安全添加列,如果列已存在则跳过
Args:
table_name: 表名
column_name: 列名
column_type: 列类型
**kwargs: 其他参数,如 nullable, default, comment 等
Returns:
bool: 是否成功添加列
"""
if not column_exists(table_name, column_name):
logger.info(f"添加列 {column_name} 到表 {table_name}")
op.add_column(table_name, sa.Column(column_name, column_type, **kwargs))
return True
else:
logger.info(f"{column_name} 已存在于表 {table_name},跳过添加")
return False
def safe_drop_column(table_name, column_name):
"""
安全删除列,如果列不存在则跳过
Args:
table_name: 表名
column_name: 列名
Returns:
bool: 是否成功删除列
"""
if column_exists(table_name, column_name):
logger.info(f"从表 {table_name} 中删除列 {column_name}")
op.drop_column(table_name, column_name)
return True
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过删除")
return False
def safe_create_index(index_name, table_name, columns, unique=False):
"""
安全创建索引,如果索引已存在则跳过
Args:
index_name: 索引名
table_name: 表名
columns: 列名列表
unique: 是否唯一索引
Returns:
bool: 是否成功创建索引
"""
if not index_exists(table_name, index_name):
logger.info(f"在表 {table_name} 上创建索引 {index_name}")
op.create_index(index_name, table_name, columns, unique=unique)
return True
else:
logger.info(f"索引 {index_name} 已存在于表 {table_name},跳过创建")
return False
def safe_drop_index(index_name, table_name):
"""
安全删除索引,如果索引不存在则跳过
Args:
index_name: 索引名
table_name: 表名
Returns:
bool: 是否成功删除索引
"""
if index_exists(table_name, index_name):
logger.info(f"删除表 {table_name} 上的索引 {index_name}")
op.drop_index(index_name, table_name)
return True
else:
logger.info(f"索引 {index_name} 不存在于表 {table_name},跳过删除")
return False
def safe_alter_column(table_name, column_name, **kwargs):
"""
安全修改列,处理 MySQL 特有的问题
Args:
table_name: 表名
column_name: 列名
**kwargs: 其他参数,如 nullable, type_, existing_type 等
Returns:
bool: 是否成功修改列
"""
if column_exists(table_name, column_name):
# 获取数据库连接
conn = op.get_bind()
# 检查数据库类型
is_mysql = conn.dialect.name == 'mysql'
if is_mysql and 'nullable' in kwargs and not kwargs.get('existing_type') and not kwargs.get('type_'):
# MySQL 需要指定列类型
inspector = reflection.Inspector.from_engine(conn)
columns = inspector.get_columns(table_name)
column_info = next((col for col in columns if col['name'] == column_name), None)
if column_info:
# 获取列的当前类型
column_type = column_info['type']
nullable = kwargs.get('nullable')
logger.info(f"使用 SQL 语句修改列 {column_name} 的 nullable 属性为 {nullable}")
# 构建 SQL 语句
null_str = "NULL" if nullable else "NOT NULL"
comment = column_info.get('comment', '')
comment_str = f" COMMENT '{comment}'" if comment else ""
# 根据列类型构建类型字符串
if isinstance(column_type, sa.String):
type_str = f"VARCHAR({column_type.length})"
elif isinstance(column_type, sa.Integer):
type_str = "INTEGER"
elif isinstance(column_type, sa.Boolean):
type_str = "BOOLEAN"
elif isinstance(column_type, sa.DateTime):
type_str = "DATETIME"
elif isinstance(column_type, sa.Text):
type_str = "TEXT"
else:
# 默认使用 VARCHAR(255)
type_str = "VARCHAR(255)"
# 执行 SQL 语句
sql = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {type_str} {null_str}{comment_str}"
conn.execute(text(sql))
return True
else:
logger.error(f"无法获取列 {column_name} 的信息")
return False
else:
# 使用 alembic 的 alter_column
logger.info(f"修改表 {table_name} 中的列 {column_name}")
op.alter_column(table_name, column_name, **kwargs)
return True
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过修改")
return False
def generate_uuid_for_null_values(table_name, column_name):
"""
为表中指定列的 NULL 值生成 UUID
Args:
table_name: 表名
column_name: 列名
Returns:
int: 更新的记录数
"""
if column_exists(table_name, column_name):
conn = op.get_bind()
# 获取所有 NULL 值的记录
result = conn.execute(text(f"SELECT id FROM {table_name} WHERE {column_name} IS NULL"))
records = result.fetchall()
if records:
logger.info(f"为表 {table_name}{len(records)} 条记录的 {column_name} 列生成 UUID")
# 为每条记录生成 UUID 并更新
for record in records:
record_id = record[0]
record_uuid = str(uuid.uuid4())
conn.execute(text(f"UPDATE {table_name} SET {column_name} = '{record_uuid}' WHERE id = {record_id}"))
return len(records)
else:
logger.info(f"{table_name} 中没有 {column_name} 列为 NULL 的记录")
return 0
else:
logger.info(f"{column_name} 不存在于表 {table_name},跳过生成 UUID")
return 0