#!/usr/bin/env python # -*- coding: utf-8 -*- """ 通用数据库迁移脚本生成工具 用法: python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] 示例: python scripts/generate_migration.py --table tasks --field task_id --type String(36) --comment "任务UUID,用于外部引用" """ import argparse import os import sys import uuid import datetime from pathlib import Path # 将项目根目录添加到 Python 路径 sys.path.insert(0, str(Path(__file__).parent.parent)) def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None, comment=None, unique=False, index=False): """ 生成迁移脚本 Args: table_name (str): 表名 field_name (str): 字段名 field_type (str): 字段类型,如 String(36), Integer, Boolean 等 nullable (bool): 是否可为空 default (str): 默认值 comment (str): 注释 unique (bool): 是否唯一 index (bool): 是否创建索引 """ # 生成迁移脚本文件名 timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") revision_id = timestamp script_name = f"{timestamp}_add_{field_name}_to_{table_name}.py" script_path = Path(__file__).parent.parent / "migrations" / "versions" / script_name # 获取最新的 revision versions_dir = Path(__file__).parent.parent / "migrations" / "versions" revisions = [] for file in versions_dir.glob("*.py"): if file.name.startswith("__"): continue with open(file, "r", encoding="utf-8") as f: for line in f: if line.startswith("revision = "): rev = line.split("=")[1].strip().strip("'\"") revisions.append(rev) break down_revision = "None" if revisions: down_revision = f"'{sorted(revisions)[-1]}'" # 生成字段定义 field_def = f"sa.Column('{field_name}', sa.{field_type}" if not nullable: field_def += ", nullable=False" if default is not None: field_def += f", server_default=sa.text('{default}')" if comment: field_def += f", comment='{comment}'" field_def += ")" # 生成迁移脚本内容 script_content = f"""\"\"\"添加 {field_name} 字段到 {table_name} 表 Revision ID: {revision_id} Revises: {down_revision.strip("'")} Create Date: {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")} \"\"\" from alembic import op import sqlalchemy as sa import uuid # revision identifiers, used by Alembic. revision = '{revision_id}' down_revision = {down_revision} branch_labels = None depends_on = None def upgrade(): # 添加 {field_name} 字段到 {table_name} 表 op.add_column('{table_name}', {field_def}) # 为已有记录生成值 conn = op.get_bind() # 获取所有记录 records = conn.execute("SELECT id FROM {table_name} WHERE {field_name} IS NULL").fetchall() # 为每个记录生成值并更新 for record in records: """ # 根据字段类型生成不同的更新语句 if "String" in field_type and "uuid" in field_name.lower(): script_content += f""" record_uuid = str(uuid.uuid4()) conn.execute(f"UPDATE {table_name} SET {field_name} = '{{record_uuid}}' WHERE id = {{record[0]}}") """ elif "String" in field_type: script_content += f""" # 为字符串字段生成默认值,根据实际情况修改 default_value = f"default_{{record[0]}}" conn.execute(f"UPDATE {table_name} SET {field_name} = '{{default_value}}' WHERE id = {{record[0]}}") """ elif "Integer" in field_type: script_content += f""" # 为整数字段生成默认值,根据实际情况修改 default_value = 0 conn.execute(f"UPDATE {table_name} SET {field_name} = {{default_value}} WHERE id = {{record[0]}}") """ elif "Boolean" in field_type: script_content += f""" # 为布尔字段生成默认值,根据实际情况修改 default_value = False conn.execute(f"UPDATE {table_name} SET {field_name} = {{default_value}} WHERE id = {{record[0]}}") """ else: script_content += f""" # 为字段生成默认值,根据实际情况修改 # default_value = ... # conn.execute(f"UPDATE {table_name} SET {field_name} = {{default_value}} WHERE id = {{record[0]}}") pass """ # 添加索引和约束 if unique or index: script_content += f""" # 添加索引 op.create_index('ix_{table_name}_{field_name}', '{table_name}', ['{field_name}'], unique={str(unique).lower()}) """ if not nullable: script_content += f""" # 添加非空约束 op.alter_column('{table_name}', '{field_name}', nullable=False) """ # 添加降级函数 script_content += f""" def downgrade(): # 删除 {field_name} 字段 """ if unique or index: script_content += f""" op.drop_index('ix_{table_name}_{field_name}', table_name='{table_name}') """ script_content += f""" op.drop_column('{table_name}', '{field_name}') """ # 写入文件 with open(script_path, "w", encoding="utf-8") as f: f.write(script_content) print(f"迁移脚本已生成: {script_path}") def main(): parser = argparse.ArgumentParser(description="生成数据库迁移脚本") parser.add_argument("--table", required=True, help="表名") parser.add_argument("--field", required=True, help="字段名") parser.add_argument("--type", required=True, help="字段类型,如 String(36), Integer, Boolean 等") parser.add_argument("--nullable", action="store_true", help="是否可为空") parser.add_argument("--default", help="默认值") parser.add_argument("--comment", help="注释") parser.add_argument("--unique", action="store_true", help="是否唯一") parser.add_argument("--index", action="store_true", help="是否创建索引") args = parser.parse_args() generate_migration_script( args.table, args.field, args.type, args.nullable, args.default, args.comment, args.unique, args.index ) if __name__ == "__main__": main()