tianfeng_task_modules/scripts/run_migration.py
2025-03-17 14:58:05 +08:00

341 lines
11 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
数据库迁移执行脚本
用法:
python scripts/run_migration.py [--revision 版本号]
示例:
# 升级到最新版本
python scripts/run_migration.py
# 升级到指定版本
python scripts/run_migration.py --revision 001
# 降级到指定版本
python scripts/run_migration.py --revision 001 --downgrade
# 生成新的迁移脚本
python scripts/run_migration.py --generate "添加新字段"
# 为特定表生成迁移脚本
python scripts/run_migration.py --generate "为用户表添加邮箱字段" --table users
# 为多个表生成迁移脚本
python scripts/run_migration.py --generate "添加审计字段" --table users,orders,products
"""
import argparse
import os
import sys
from pathlib import Path
import subprocess
import shutil
import tempfile
import locale
import datetime
# 将项目根目录添加到 Python 路径
root_dir = Path(__file__).parent.parent
sys.path.insert(0, str(root_dir))
# 导入项目日志模块
from utils.logger import get_logger, setup_logger
# 设置日志
setup_logger()
logger = get_logger('migration')
def create_ascii_config(original_config_path):
"""
创建一个 ASCII 编码的配置文件副本
Args:
original_config_path (Path): 原始配置文件路径
Returns:
Path: 临时配置文件路径
"""
# 如果临时文件已存在,直接使用
temp_config_path = original_config_path.with_suffix('.ini.tmp')
if temp_config_path.exists():
logger.info(f"使用已存在的临时配置文件: {temp_config_path}")
return temp_config_path
# 创建临时文件
logger.info(f"创建 ASCII 编码的临时配置文件: {temp_config_path}")
# 复制配置文件内容,去除中文注释
with open(temp_config_path, 'w', encoding='ascii', errors='ignore') as temp_file:
temp_file.write("""
# Alembic Configuration File
# This file contains basic configuration for Alembic
[alembic]
# Path to migration scripts
script_location = migrations
# Template uses jinja2 format
output_encoding = utf-8
# Database connection configuration
# In practice, this value will be overridden by the configuration in env.py
sqlalchemy.url = driver://user:pass@localhost/dbname
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""")
return temp_config_path
def run_migration_command(command, args):
"""
执行迁移命令
Args:
command (str): 命令名称,如 'upgrade', 'downgrade', 'revision'
args (argparse.Namespace): 命令行参数
Returns:
bool: 是否成功执行命令
"""
# 获取项目根目录
root_dir = Path(__file__).parent.parent
# 原始配置文件路径
original_config_path = root_dir / "migrations" / "alembic.ini"
# 创建 ASCII 编码的临时配置文件
temp_config_path = create_ascii_config(original_config_path)
# 构建命令
cmd = ["alembic", "-c", str(temp_config_path)]
if args.verbose:
cmd.append("--verbose")
# 打印所有参数,帮助调试
logger.info(f"命令参数: {vars(args)}")
if command == 'upgrade':
cmd.extend(["upgrade", args.revision or "head"])
elif command == 'downgrade':
cmd.extend(["downgrade", args.revision or "base"])
elif command == 'revision':
logger.info(f"生成迁移脚本,描述: {args.message}")
cmd.extend(["revision", "--autogenerate", "-m", args.message])
# 如果指定了表名,创建一个临时的 env.py 文件,只包含指定的表
if args.table:
# 将表名列表转换为 Python 列表字符串
tables = [f"'{table.strip()}'" for table in args.table.split(',')]
tables_str = f"[{', '.join(tables)}]"
# 创建临时环境变量,传递表名列表
os.environ["ALEMBIC_TABLES"] = args.table
logger.info(f"指定迁移表: {args.table}")
if args.branch:
cmd.extend(["--branch", args.branch])
elif command == 'history':
cmd.append("history")
if args.verbose:
cmd.append("-v")
elif command == 'current':
cmd.append("current")
elif command == 'show':
cmd.extend(["show", args.revision or "head"])
elif command == 'list_tables':
# 这不是 alembic 命令,而是我们自定义的命令
return list_database_tables()
else:
logger.error(f"未知命令: {command}")
return False
# 执行命令
logger.info(f"执行命令: {' '.join(cmd)}")
logger.info(f"工作目录: {root_dir}")
logger.info(f"Python 编码: {sys.getdefaultencoding()}")
logger.info(f"文件系统编码: {sys.getfilesystemencoding()}")
logger.info(f"系统默认编码: {locale.getpreferredencoding()}")
# 设置环境变量,确保使用 UTF-8 编码
env = os.environ.copy()
env["PYTHONIOENCODING"] = "utf-8"
# 创建命令输出日志文件 - 使用项目日志目录
from config.settings import LogConfig
log_config = LogConfig.as_dict()
log_dir = Path(os.path.dirname(log_config["file"]))
log_dir.mkdir(exist_ok=True)
# 使用日期和命令类型命名日志文件
today = datetime.datetime.now().strftime('%Y%m%d')
log_file = log_dir / f"migration_{command}_{today}.log"
try:
# 使用直接执行命令的方式,避免 subprocess 的编码问题
logger.info(f"正在执行迁移,日志将写入 {log_file}...")
# 使用 subprocess.Popen 而不是 subprocess.run
with open(log_file, "a", encoding="utf-8") as f_out:
# 添加分隔线和时间戳
f_out.write(f"\n\n{'='*80}\n")
f_out.write(f"执行时间: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f_out.write(f"执行命令: {' '.join(cmd)}\n")
f_out.write(f"{'='*80}\n\n")
process = subprocess.Popen(
cmd,
cwd=str(root_dir),
env=env,
stdout=f_out,
stderr=subprocess.STDOUT,
text=True,
encoding="utf-8",
errors="replace" # 使用 replace 策略处理无法解码的字符
)
process.wait()
# 读取日志文件
with open(log_file, "r", encoding="utf-8") as f:
# 只读取最后 50 行,避免日志过长
lines = f.readlines()
output = ''.join(lines[-50:]) if len(lines) > 50 else ''.join(lines)
logger.info(f"命令执行状态: {'成功' if process.returncode == 0 else '失败'}")
logger.info(f"输出(最后部分):\n{output}")
# 清理环境变量
if 'ALEMBIC_TABLES' in os.environ:
del os.environ['ALEMBIC_TABLES']
return process.returncode == 0
except Exception as e:
logger.error(f"执行命令时发生错误: {e}", exc_info=True)
# 清理环境变量
if 'ALEMBIC_TABLES' in os.environ:
del os.environ['ALEMBIC_TABLES']
return False
def list_database_tables():
"""
列出数据库中的所有表
Returns:
bool: 是否成功执行命令
"""
try:
# 导入数据库配置
from config.database import DBConfig
# 获取数据库连接
engine = DBConfig.engine
# 获取所有表名
from sqlalchemy import inspect
inspector = inspect(engine)
tables = inspector.get_table_names()
logger.info("数据库中的表:")
for i, table in enumerate(tables, 1):
logger.info(f"{i}. {table}")
return True
except Exception as e:
logger.error(f"获取数据库表时发生错误: {e}", exc_info=True)
return False
def run_migration(args):
"""
执行数据库迁移
Args:
args (argparse.Namespace): 命令行参数
Returns:
bool: 是否成功执行迁移
"""
# 打印所有参数,帮助调试
logger.info(f"运行迁移,参数: {vars(args)}")
if args.list_tables:
# 列出数据库中的所有表
return run_migration_command('list_tables', args)
elif args.message:
# 生成新的迁移脚本
logger.info(f"生成新的迁移脚本,描述: {args.message}")
return run_migration_command('revision', args)
elif args.history:
# 显示迁移历史
return run_migration_command('history', args)
elif args.current:
# 显示当前版本
return run_migration_command('current', args)
elif args.show:
# 显示指定版本的详细信息
return run_migration_command('show', args)
elif args.downgrade:
# 降级到指定版本
return run_migration_command('downgrade', args)
else:
# 升级到指定版本
return run_migration_command('upgrade', args)
def main():
parser = argparse.ArgumentParser(description="执行数据库迁移")
parser.add_argument("--revision", help="版本号,为空表示升级到最新版本")
parser.add_argument("--downgrade", action="store_true", help="是否降级")
parser.add_argument("--verbose", "-v", action="store_true", help="显示详细日志")
parser.add_argument("--generate", "--gen", "-m", dest="message", help="生成新的迁移脚本,需要提供描述信息")
parser.add_argument("--branch", "-b", help="分支名称,用于生成迁移脚本")
parser.add_argument("--history", action="store_true", help="显示迁移历史")
parser.add_argument("--current", action="store_true", help="显示当前版本")
parser.add_argument("--show", action="store_true", help="显示指定版本的详细信息")
parser.add_argument("--table", "-t", help="指定要迁移的表名,多个表用逗号分隔")
parser.add_argument("--list-tables", action="store_true", help="列出数据库中的所有表")
args = parser.parse_args()
success = run_migration(args)
sys.exit(0 if success else 1)
if __name__ == "__main__":
main()