""" 数据库迁移工具 用于自动检测模型变更并应用到数据库 """ import logging import sqlalchemy as sa from sqlalchemy import inspect, text, Enum from sqlalchemy.exc import OperationalError, ProgrammingError from data.session import get_db from config.database import DBConfig import importlib import pkgutil import data.models import enum import json # 设置日志 logger = logging.getLogger(__name__) class DBMigration: """数据库迁移工具类""" def __init__(self): """初始化迁移工具""" self.db = next(get_db()) self.engine = self.db.bind self.inspector = inspect(self.engine) self.metadata = sa.MetaData() self.metadata.reflect(bind=self.engine) def get_all_models(self): """获取所有模型类""" models = [] # 导入所有模型模块 for _, name, _ in pkgutil.iter_modules(data.models.__path__): try: module = importlib.import_module(f"data.models.{name}") # 获取模块中的所有类 for attr_name in dir(module): attr = getattr(module, attr_name) # 检查是否为模型类 if isinstance(attr, type) and hasattr(attr, "__tablename__") and attr.__name__ != "BaseModel": models.append(attr) except ImportError as e: logger.error(f"导入模块 data.models.{name} 失败: {str(e)}") return models def get_table_columns(self, table_name): """获取表的所有列信息""" columns = {} try: for column in self.inspector.get_columns(table_name): columns[column["name"]] = column except Exception as e: logger.error(f"获取表 {table_name} 的列信息失败: {str(e)}") return columns def get_model_columns(self, model): """获取模型的所有列信息""" columns = {} for name, column in model.__table__.columns.items(): columns[name] = column return columns def check_table_exists(self, table_name): """检查表是否存在""" return self.inspector.has_table(table_name) def create_table(self, model): """创建表""" try: model.__table__.create(self.engine) logger.info(f"创建表 {model.__tablename__} 成功") return True except Exception as e: logger.error(f"创建表 {model.__tablename__} 失败: {str(e)}") return False def get_column_type_sql(self, column): """获取列类型的SQL表示""" try: # 处理枚举类型 if isinstance(column.type, Enum): # MySQL中枚举类型的处理 if self.engine.dialect.name == 'mysql': enum_values = [f"'{val}'" for val in column.type.enums] return f"ENUM({', '.join(enum_values)})" # PostgreSQL中枚举类型的处理 elif self.engine.dialect.name == 'postgresql': # 获取枚举类型名称 enum_name = column.type.name or f"{column.table.name}_{column.name}_enum" # 创建枚举类型 enum_values = [f"'{val}'" for val in column.type.enums] self.db.execute(text(f"CREATE TYPE IF NOT EXISTS {enum_name} AS ENUM ({', '.join(enum_values)})")) self.db.commit() return enum_name # 其他类型直接使用SQLAlchemy的编译功能 return column.type.compile(self.engine.dialect) except Exception as e: logger.error(f"获取列类型失败: {str(e)}") # 默认返回VARCHAR类型 return "VARCHAR(255)" def get_column_default_sql(self, column): """获取列默认值的SQL表示""" try: default = column.default if default is None: return "" # 处理服务器端默认值 if default.is_sequence or default.is_callable: return "" # 服务器端默认值不需要在ADD COLUMN中指定 # 处理标量默认值 if default.is_scalar: if isinstance(default.arg, bool): return f" DEFAULT {1 if default.arg else 0}" if self.engine.dialect.name == 'mysql' else f" DEFAULT {str(default.arg).lower()}" elif isinstance(default.arg, (int, float)): return f" DEFAULT {default.arg}" elif isinstance(default.arg, str): return f" DEFAULT '{default.arg}'" elif isinstance(default.arg, enum.Enum): return f" DEFAULT '{default.arg.value}'" elif default.arg is None: return " DEFAULT NULL" # 处理JSON默认值 if hasattr(column.type, 'python_type') and column.type.python_type == dict: if default.arg is None: return " DEFAULT '{}'" return f" DEFAULT '{json.dumps(default.arg)}'" return "" except Exception as e: logger.error(f"获取列默认值失败: {str(e)}") return "" def add_column(self, table_name, column_name, column): """添加列""" try: # 获取列类型 column_type = self.get_column_type_sql(column) # 获取列默认值 default_value = self.get_column_default_sql(column) # 获取列是否可为空 nullable = "" if column.nullable else " NOT NULL" # 构建SQL语句 sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}{nullable}{default_value}" # 执行SQL logger.info(f"执行SQL: {sql}") self.db.execute(text(sql)) self.db.commit() logger.info(f"添加列 {table_name}.{column_name} 成功") return True except Exception as e: logger.error(f"添加列 {table_name}.{column_name} 失败: {str(e)}") self.db.rollback() return False def migrate(self): """执行迁移""" models = self.get_all_models() logger.info(f"发现 {len(models)} 个模型") for model in models: table_name = model.__tablename__ # 检查表是否存在 if not self.check_table_exists(table_name): logger.info(f"表 {table_name} 不存在,准备创建") self.create_table(model) continue # 获取表的列信息 table_columns = self.get_table_columns(table_name) # 获取模型的列信息 model_columns = self.get_model_columns(model) # 检查是否有新增的列 for column_name, column in model_columns.items(): if column_name not in table_columns: logger.info(f"发现新增列 {table_name}.{column_name},准备添加") self.add_column(table_name, column_name, column) logger.info("数据库迁移完成") def run_migration(): """运行数据库迁移""" try: logger.info("开始数据库迁移") migration = DBMigration() migration.migrate() logger.info("数据库迁移成功") return True except Exception as e: logger.error(f"数据库迁移失败: {str(e)}") return False if __name__ == "__main__": # 设置日志格式 logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) # 运行迁移 run_migration()