218 lines
7.9 KiB
Python
218 lines
7.9 KiB
Python
"""
|
|
数据库迁移工具
|
|
用于自动检测模型变更并应用到数据库
|
|
"""
|
|
import logging
|
|
import sqlalchemy as sa
|
|
from sqlalchemy import inspect, text, Enum
|
|
from sqlalchemy.exc import OperationalError, ProgrammingError
|
|
from data.session import get_db
|
|
import importlib
|
|
import pkgutil
|
|
import data.models
|
|
import enum
|
|
import json
|
|
from utils.logger import get_logger
|
|
|
|
# 设置日志
|
|
logger = get_logger("utils.db_migration")
|
|
|
|
class DBMigration:
|
|
"""数据库迁移工具类"""
|
|
|
|
def __init__(self):
|
|
"""初始化迁移工具"""
|
|
self.db = next(get_db())
|
|
self.engine = self.db.bind
|
|
self.inspector = inspect(self.engine)
|
|
self.metadata = sa.MetaData()
|
|
self.metadata.reflect(bind=self.engine)
|
|
|
|
def get_all_models(self):
|
|
"""获取所有模型类"""
|
|
models = []
|
|
# 导入所有模型模块
|
|
for _, name, _ in pkgutil.iter_modules(data.models.__path__):
|
|
try:
|
|
module = importlib.import_module(f"data.models.{name}")
|
|
# 获取模块中的所有类
|
|
for attr_name in dir(module):
|
|
attr = getattr(module, attr_name)
|
|
# 检查是否为模型类
|
|
if isinstance(attr, type) and hasattr(attr, "__tablename__") and attr.__name__ != "BaseModel":
|
|
models.append(attr)
|
|
except ImportError as e:
|
|
logger.error(f"导入模块 data.models.{name} 失败: {str(e)}")
|
|
|
|
return models
|
|
|
|
def get_table_columns(self, table_name):
|
|
"""获取表的所有列信息"""
|
|
columns = {}
|
|
try:
|
|
for column in self.inspector.get_columns(table_name):
|
|
columns[column["name"]] = column
|
|
except Exception as e:
|
|
logger.error(f"获取表 {table_name} 的列信息失败: {str(e)}")
|
|
|
|
return columns
|
|
|
|
def get_model_columns(self, model):
|
|
"""获取模型的所有列信息"""
|
|
columns = {}
|
|
for name, column in model.__table__.columns.items():
|
|
columns[name] = column
|
|
|
|
return columns
|
|
|
|
def check_table_exists(self, table_name):
|
|
"""检查表是否存在"""
|
|
return self.inspector.has_table(table_name)
|
|
|
|
def create_table(self, model):
|
|
"""创建表"""
|
|
try:
|
|
model.__table__.create(self.engine)
|
|
logger.info(f"创建表 {model.__tablename__} 成功")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"创建表 {model.__tablename__} 失败: {str(e)}")
|
|
return False
|
|
|
|
def get_column_type_sql(self, column):
|
|
"""获取列类型的SQL表示"""
|
|
try:
|
|
# 处理枚举类型
|
|
if isinstance(column.type, Enum):
|
|
# MySQL中枚举类型的处理
|
|
if self.engine.dialect.name == 'mysql':
|
|
enum_values = [f"'{val}'" for val in column.type.enums]
|
|
return f"ENUM({', '.join(enum_values)})"
|
|
# PostgreSQL中枚举类型的处理
|
|
elif self.engine.dialect.name == 'postgresql':
|
|
# 获取枚举类型名称
|
|
enum_name = column.type.name or f"{column.table.name}_{column.name}_enum"
|
|
# 创建枚举类型
|
|
enum_values = [f"'{val}'" for val in column.type.enums]
|
|
self.db.execute(text(f"CREATE TYPE IF NOT EXISTS {enum_name} AS ENUM ({', '.join(enum_values)})"))
|
|
self.db.commit()
|
|
return enum_name
|
|
|
|
# 其他类型直接使用SQLAlchemy的编译功能
|
|
return column.type.compile(self.engine.dialect)
|
|
except Exception as e:
|
|
logger.error(f"获取列类型失败: {str(e)}")
|
|
# 默认返回VARCHAR类型
|
|
return "VARCHAR(255)"
|
|
|
|
def get_column_default_sql(self, column):
|
|
"""获取列默认值的SQL表示"""
|
|
try:
|
|
default = column.default
|
|
if default is None:
|
|
return ""
|
|
|
|
# 处理服务器端默认值
|
|
if default.is_sequence or default.is_callable:
|
|
return "" # 服务器端默认值不需要在ADD COLUMN中指定
|
|
|
|
# 处理标量默认值
|
|
if default.is_scalar:
|
|
if isinstance(default.arg, bool):
|
|
return f" DEFAULT {1 if default.arg else 0}" if self.engine.dialect.name == 'mysql' else f" DEFAULT {str(default.arg).lower()}"
|
|
elif isinstance(default.arg, (int, float)):
|
|
return f" DEFAULT {default.arg}"
|
|
elif isinstance(default.arg, str):
|
|
return f" DEFAULT '{default.arg}'"
|
|
elif isinstance(default.arg, enum.Enum):
|
|
return f" DEFAULT '{default.arg.value}'"
|
|
elif default.arg is None:
|
|
return " DEFAULT NULL"
|
|
|
|
# 处理JSON默认值
|
|
if hasattr(column.type, 'python_type') and column.type.python_type == dict:
|
|
if default.arg is None:
|
|
return " DEFAULT '{}'"
|
|
return f" DEFAULT '{json.dumps(default.arg)}'"
|
|
|
|
return ""
|
|
except Exception as e:
|
|
logger.error(f"获取列默认值失败: {str(e)}")
|
|
return ""
|
|
|
|
def add_column(self, table_name, column_name, column):
|
|
"""添加列"""
|
|
try:
|
|
# 获取列类型
|
|
column_type = self.get_column_type_sql(column)
|
|
|
|
# 获取列默认值
|
|
default_value = self.get_column_default_sql(column)
|
|
|
|
# 获取列是否可为空
|
|
nullable = "" if column.nullable else " NOT NULL"
|
|
|
|
# 构建SQL语句
|
|
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}{nullable}{default_value}"
|
|
|
|
# 执行SQL
|
|
logger.info(f"执行SQL: {sql}")
|
|
self.db.execute(text(sql))
|
|
self.db.commit()
|
|
|
|
logger.info(f"添加列 {table_name}.{column_name} 成功")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"添加列 {table_name}.{column_name} 失败: {str(e)}")
|
|
self.db.rollback()
|
|
return False
|
|
|
|
def migrate(self):
|
|
"""执行迁移"""
|
|
models = self.get_all_models()
|
|
logger.info(f"发现 {len(models)} 个模型")
|
|
|
|
for model in models:
|
|
table_name = model.__tablename__
|
|
|
|
# 检查表是否存在
|
|
if not self.check_table_exists(table_name):
|
|
logger.info(f"表 {table_name} 不存在,准备创建")
|
|
self.create_table(model)
|
|
continue
|
|
|
|
# 获取表的列信息
|
|
table_columns = self.get_table_columns(table_name)
|
|
|
|
# 获取模型的列信息
|
|
model_columns = self.get_model_columns(model)
|
|
|
|
# 检查是否有新增的列
|
|
for column_name, column in model_columns.items():
|
|
if column_name not in table_columns:
|
|
logger.info(f"发现新增列 {table_name}.{column_name},准备添加")
|
|
self.add_column(table_name, column_name, column)
|
|
|
|
logger.info("数据库迁移完成")
|
|
|
|
def run_migration():
|
|
"""运行数据库迁移"""
|
|
try:
|
|
logger.info("开始数据库迁移")
|
|
migration = DBMigration()
|
|
migration.migrate()
|
|
logger.info("数据库迁移成功")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"数据库迁移失败: {str(e)}")
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
# 设置日志格式
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
)
|
|
|
|
# 运行迁移
|
|
run_migration() |