tianfeng_task_modules/utils/db_migration.py
2025-03-17 18:31:20 +08:00

218 lines
7.9 KiB
Python

"""
数据库迁移工具
用于自动检测模型变更并应用到数据库
"""
import logging
import sqlalchemy as sa
from sqlalchemy import inspect, text, Enum
from sqlalchemy.exc import OperationalError, ProgrammingError
from data.session import get_db
from config.database import DBConfig
import importlib
import pkgutil
import data.models
import enum
import json
# 设置日志
logger = logging.getLogger(__name__)
class DBMigration:
"""数据库迁移工具类"""
def __init__(self):
"""初始化迁移工具"""
self.db = next(get_db())
self.engine = self.db.bind
self.inspector = inspect(self.engine)
self.metadata = sa.MetaData()
self.metadata.reflect(bind=self.engine)
def get_all_models(self):
"""获取所有模型类"""
models = []
# 导入所有模型模块
for _, name, _ in pkgutil.iter_modules(data.models.__path__):
try:
module = importlib.import_module(f"data.models.{name}")
# 获取模块中的所有类
for attr_name in dir(module):
attr = getattr(module, attr_name)
# 检查是否为模型类
if isinstance(attr, type) and hasattr(attr, "__tablename__") and attr.__name__ != "BaseModel":
models.append(attr)
except ImportError as e:
logger.error(f"导入模块 data.models.{name} 失败: {str(e)}")
return models
def get_table_columns(self, table_name):
"""获取表的所有列信息"""
columns = {}
try:
for column in self.inspector.get_columns(table_name):
columns[column["name"]] = column
except Exception as e:
logger.error(f"获取表 {table_name} 的列信息失败: {str(e)}")
return columns
def get_model_columns(self, model):
"""获取模型的所有列信息"""
columns = {}
for name, column in model.__table__.columns.items():
columns[name] = column
return columns
def check_table_exists(self, table_name):
"""检查表是否存在"""
return self.inspector.has_table(table_name)
def create_table(self, model):
"""创建表"""
try:
model.__table__.create(self.engine)
logger.info(f"创建表 {model.__tablename__} 成功")
return True
except Exception as e:
logger.error(f"创建表 {model.__tablename__} 失败: {str(e)}")
return False
def get_column_type_sql(self, column):
"""获取列类型的SQL表示"""
try:
# 处理枚举类型
if isinstance(column.type, Enum):
# MySQL中枚举类型的处理
if self.engine.dialect.name == 'mysql':
enum_values = [f"'{val}'" for val in column.type.enums]
return f"ENUM({', '.join(enum_values)})"
# PostgreSQL中枚举类型的处理
elif self.engine.dialect.name == 'postgresql':
# 获取枚举类型名称
enum_name = column.type.name or f"{column.table.name}_{column.name}_enum"
# 创建枚举类型
enum_values = [f"'{val}'" for val in column.type.enums]
self.db.execute(text(f"CREATE TYPE IF NOT EXISTS {enum_name} AS ENUM ({', '.join(enum_values)})"))
self.db.commit()
return enum_name
# 其他类型直接使用SQLAlchemy的编译功能
return column.type.compile(self.engine.dialect)
except Exception as e:
logger.error(f"获取列类型失败: {str(e)}")
# 默认返回VARCHAR类型
return "VARCHAR(255)"
def get_column_default_sql(self, column):
"""获取列默认值的SQL表示"""
try:
default = column.default
if default is None:
return ""
# 处理服务器端默认值
if default.is_sequence or default.is_callable:
return "" # 服务器端默认值不需要在ADD COLUMN中指定
# 处理标量默认值
if default.is_scalar:
if isinstance(default.arg, bool):
return f" DEFAULT {1 if default.arg else 0}" if self.engine.dialect.name == 'mysql' else f" DEFAULT {str(default.arg).lower()}"
elif isinstance(default.arg, (int, float)):
return f" DEFAULT {default.arg}"
elif isinstance(default.arg, str):
return f" DEFAULT '{default.arg}'"
elif isinstance(default.arg, enum.Enum):
return f" DEFAULT '{default.arg.value}'"
elif default.arg is None:
return " DEFAULT NULL"
# 处理JSON默认值
if hasattr(column.type, 'python_type') and column.type.python_type == dict:
if default.arg is None:
return " DEFAULT '{}'"
return f" DEFAULT '{json.dumps(default.arg)}'"
return ""
except Exception as e:
logger.error(f"获取列默认值失败: {str(e)}")
return ""
def add_column(self, table_name, column_name, column):
"""添加列"""
try:
# 获取列类型
column_type = self.get_column_type_sql(column)
# 获取列默认值
default_value = self.get_column_default_sql(column)
# 获取列是否可为空
nullable = "" if column.nullable else " NOT NULL"
# 构建SQL语句
sql = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}{nullable}{default_value}"
# 执行SQL
logger.info(f"执行SQL: {sql}")
self.db.execute(text(sql))
self.db.commit()
logger.info(f"添加列 {table_name}.{column_name} 成功")
return True
except Exception as e:
logger.error(f"添加列 {table_name}.{column_name} 失败: {str(e)}")
self.db.rollback()
return False
def migrate(self):
"""执行迁移"""
models = self.get_all_models()
logger.info(f"发现 {len(models)} 个模型")
for model in models:
table_name = model.__tablename__
# 检查表是否存在
if not self.check_table_exists(table_name):
logger.info(f"{table_name} 不存在,准备创建")
self.create_table(model)
continue
# 获取表的列信息
table_columns = self.get_table_columns(table_name)
# 获取模型的列信息
model_columns = self.get_model_columns(model)
# 检查是否有新增的列
for column_name, column in model_columns.items():
if column_name not in table_columns:
logger.info(f"发现新增列 {table_name}.{column_name},准备添加")
self.add_column(table_name, column_name, column)
logger.info("数据库迁移完成")
def run_migration():
"""运行数据库迁移"""
try:
logger.info("开始数据库迁移")
migration = DBMigration()
migration.migrate()
logger.info("数据库迁移成功")
return True
except Exception as e:
logger.error(f"数据库迁移失败: {str(e)}")
return False
if __name__ == "__main__":
# 设置日志格式
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# 运行迁移
run_migration()