#!/usr/bin/env python # -*- coding: utf-8 -*- """ 数据库迁移辅助模块 提供常用的迁移操作和错误处理 """ from alembic import op import sqlalchemy as sa from sqlalchemy import text from sqlalchemy.engine import reflection import logging import uuid # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger('alembic.migration') def safe_execute(func, *args, **kwargs): """ 安全执行函数,捕获异常并记录日志 Args: func: 要执行的函数 *args: 位置参数 **kwargs: 关键字参数 Returns: 执行结果或 None(如果发生异常) """ try: return func(*args, **kwargs) except Exception as e: logger.error(f"执行 {func.__name__} 时发生错误: {e}") return None def column_exists(table_name, column_name): """ 检查列是否存在 Args: table_name: 表名 column_name: 列名 Returns: bool: 列是否存在 """ conn = op.get_bind() inspector = reflection.Inspector.from_engine(conn) columns = inspector.get_columns(table_name) return column_name in [column['name'] for column in columns] def index_exists(table_name, index_name): """ 检查索引是否存在 Args: table_name: 表名 index_name: 索引名 Returns: bool: 索引是否存在 """ conn = op.get_bind() inspector = reflection.Inspector.from_engine(conn) indexes = inspector.get_indexes(table_name) return index_name in [index['name'] for index in indexes] def table_exists(table_name): """ 检查表是否存在 Args: table_name: 表名 Returns: bool: 表是否存在 """ conn = op.get_bind() inspector = reflection.Inspector.from_engine(conn) return table_name in inspector.get_table_names() def safe_add_column(table_name, column_name, column_type, **kwargs): """ 安全添加列,如果列已存在则跳过 Args: table_name: 表名 column_name: 列名 column_type: 列类型 **kwargs: 其他参数,如 nullable, default, comment 等 Returns: bool: 是否成功添加列 """ if not column_exists(table_name, column_name): logger.info(f"添加列 {column_name} 到表 {table_name}") op.add_column(table_name, sa.Column(column_name, column_type, **kwargs)) return True else: logger.info(f"列 {column_name} 已存在于表 {table_name},跳过添加") return False def safe_drop_column(table_name, column_name): """ 安全删除列,如果列不存在则跳过 Args: table_name: 表名 column_name: 列名 Returns: bool: 是否成功删除列 """ if column_exists(table_name, column_name): logger.info(f"从表 {table_name} 中删除列 {column_name}") op.drop_column(table_name, column_name) return True else: logger.info(f"列 {column_name} 不存在于表 {table_name},跳过删除") return False def safe_create_index(index_name, table_name, columns, unique=False): """ 安全创建索引,如果索引已存在则跳过 Args: index_name: 索引名 table_name: 表名 columns: 列名列表 unique: 是否唯一索引 Returns: bool: 是否成功创建索引 """ if not index_exists(table_name, index_name): logger.info(f"在表 {table_name} 上创建索引 {index_name}") op.create_index(index_name, table_name, columns, unique=unique) return True else: logger.info(f"索引 {index_name} 已存在于表 {table_name},跳过创建") return False def safe_drop_index(index_name, table_name): """ 安全删除索引,如果索引不存在则跳过 Args: index_name: 索引名 table_name: 表名 Returns: bool: 是否成功删除索引 """ if index_exists(table_name, index_name): logger.info(f"删除表 {table_name} 上的索引 {index_name}") op.drop_index(index_name, table_name) return True else: logger.info(f"索引 {index_name} 不存在于表 {table_name},跳过删除") return False def safe_alter_column(table_name, column_name, **kwargs): """ 安全修改列,处理 MySQL 特有的问题 Args: table_name: 表名 column_name: 列名 **kwargs: 其他参数,如 nullable, type_, existing_type 等 Returns: bool: 是否成功修改列 """ if column_exists(table_name, column_name): # 获取数据库连接 conn = op.get_bind() # 检查数据库类型 is_mysql = conn.dialect.name == 'mysql' if is_mysql and 'nullable' in kwargs and not kwargs.get('existing_type') and not kwargs.get('type_'): # MySQL 需要指定列类型 inspector = reflection.Inspector.from_engine(conn) columns = inspector.get_columns(table_name) column_info = next((col for col in columns if col['name'] == column_name), None) if column_info: # 获取列的当前类型 column_type = column_info['type'] nullable = kwargs.get('nullable') logger.info(f"使用 SQL 语句修改列 {column_name} 的 nullable 属性为 {nullable}") # 构建 SQL 语句 null_str = "NULL" if nullable else "NOT NULL" comment = column_info.get('comment', '') comment_str = f" COMMENT '{comment}'" if comment else "" # 根据列类型构建类型字符串 if isinstance(column_type, sa.String): type_str = f"VARCHAR({column_type.length})" elif isinstance(column_type, sa.Integer): type_str = "INTEGER" elif isinstance(column_type, sa.Boolean): type_str = "BOOLEAN" elif isinstance(column_type, sa.DateTime): type_str = "DATETIME" elif isinstance(column_type, sa.Text): type_str = "TEXT" else: # 默认使用 VARCHAR(255) type_str = "VARCHAR(255)" # 执行 SQL 语句 sql = f"ALTER TABLE {table_name} MODIFY COLUMN {column_name} {type_str} {null_str}{comment_str}" conn.execute(text(sql)) return True else: logger.error(f"无法获取列 {column_name} 的信息") return False else: # 使用 alembic 的 alter_column logger.info(f"修改表 {table_name} 中的列 {column_name}") op.alter_column(table_name, column_name, **kwargs) return True else: logger.info(f"列 {column_name} 不存在于表 {table_name},跳过修改") return False def generate_uuid_for_null_values(table_name, column_name): """ 为表中指定列的 NULL 值生成 UUID Args: table_name: 表名 column_name: 列名 Returns: int: 更新的记录数 """ if column_exists(table_name, column_name): conn = op.get_bind() # 获取所有 NULL 值的记录 result = conn.execute(text(f"SELECT id FROM {table_name} WHERE {column_name} IS NULL")) records = result.fetchall() if records: logger.info(f"为表 {table_name} 中 {len(records)} 条记录的 {column_name} 列生成 UUID") # 为每条记录生成 UUID 并更新 for record in records: record_id = record[0] record_uuid = str(uuid.uuid4()) conn.execute(text(f"UPDATE {table_name} SET {column_name} = '{record_uuid}' WHERE id = {record_id}")) return len(records) else: logger.info(f"表 {table_name} 中没有 {column_name} 列为 NULL 的记录") return 0 else: logger.info(f"列 {column_name} 不存在于表 {table_name},跳过生成 UUID") return 0