#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
通用数据库迁移脚本生成工具

用法:
    python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] [--unique] [--index]

示例:
    python scripts/generate_migration.py --table tasks --field task_id --type "String(36)" --comment "任务UUID,用于外部引用"
"""

import argparse
import datetime
import os
import re
import sys
import uuid
from pathlib import Path

# Prepend the project root (the directory above this script's folder) to the
# module search path.
# NOTE(review): nothing visible in this script imports project modules;
# confirm this is still needed.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))


# Module-level `revision = '...'` or `revision: str = '...'` line of a migration file.
_REVISION_RE = re.compile(r"""^revision\b(?:\s*:[^=]*)?\s*=\s*['"]([^'"]+)['"]""")
# Module-level `down_revision = ...` line; the value may be None, a string or a tuple.
_DOWN_REVISION_RE = re.compile(r"^down_revision\b(?:\s*:[^=]*)?\s*=\s*(.*)$")
# Quoted revision ids inside a `down_revision` value.
_QUOTED_RE = re.compile(r"""['"]([^'"]+)['"]""")


def _find_head_revision(versions_dir):
    """Scan migration files and return ``(known_revision_ids, head_revision_or_None)``.

    The head is the revision that no other file lists as its ``down_revision``.
    If several heads exist, a warning is printed and the lexicographically
    largest head is used. If no head can be determined (e.g. a cycle), the
    largest revision id overall is used, which matches the old heuristic.
    Files whose name starts with ``__`` are ignored.
    """
    revisions = set()
    referenced = set()
    for path in versions_dir.glob("*.py"):
        if path.name.startswith("__"):
            continue
        revision = None
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                if revision is None:
                    match = _REVISION_RE.match(line)
                    if match:
                        revision = match.group(1)
                        continue
                match = _DOWN_REVISION_RE.match(line)
                if match:
                    # Handles None, a single id, or a tuple of ids (merge revisions).
                    referenced.update(_QUOTED_RE.findall(match.group(1)))
        if revision is not None:
            revisions.add(revision)

    if not revisions:
        return revisions, None
    heads = sorted(revisions - referenced)
    if len(heads) > 1:
        print(f"警告: 检测到多个 head 版本 {heads},将使用 {heads[-1]}", file=sys.stderr)
    return revisions, (heads or sorted(revisions))[-1]


def _build_column_def(field_name, field_type, default, comment):
    """Render the ``sa.Column(...)`` expression passed to ``op.add_column``.

    The column is always added as nullable; NOT NULL, if requested, is applied
    after the backfill. String values are rendered with ``repr`` so quotes in
    a comment or default cannot break the generated file.
    """
    parts = [repr(field_name), f"sa.{field_type}", "nullable=True"]
    if default is not None:
        parts.append(f"server_default=sa.text({default!r})")
    if comment:
        parts.append(f"comment={comment!r}")
    return f"sa.Column({', '.join(parts)})"


def _build_backfill(field_name, field_type):
    """Return the generated for-loop body (8-space indented) that fills one NULL row.

    The generated code uses the ``conn``, ``update_stmt`` and ``record`` names
    defined by the surrounding generated ``upgrade()``.
    """
    update_line = '        conn.execute(update_stmt, {"value": default_value, "id": record[0]})\n'
    if "String" in field_type and "uuid" in field_name.lower():
        return (
            "        record_uuid = str(uuid.uuid4())\n"
            '        conn.execute(update_stmt, {"value": record_uuid, "id": record[0]})\n'
        )
    if "String" in field_type:
        return (
            "        # 为字符串字段生成默认值,根据实际情况修改\n"
            '        default_value = f"default_{record[0]}"\n'
            + update_line
        )
    if "Integer" in field_type:
        return (
            "        # 为整数字段生成默认值,根据实际情况修改\n"
            "        default_value = 0\n"
            + update_line
        )
    if "Boolean" in field_type:
        return (
            "        # 为布尔字段生成默认值,根据实际情况修改\n"
            "        default_value = False\n"
            + update_line
        )
    return (
        "        # 为字段生成默认值,根据实际情况修改\n"
        "        # default_value = ...\n"
        "        # " + update_line.lstrip()
        + "        pass\n"
    )


def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None,
                              comment=None, unique=False, index=False, versions_dir=None):
    """Generate an Alembic migration that adds one column to an existing table.

    The generated ``upgrade()`` does the following, in order:

    1. Adds the column as nullable.
    2. Backfills rows whose value is NULL.
    3. Optionally creates an index.
    4. Applies NOT NULL when ``nullable`` is false.

    Adding the column as nullable first is what lets the backfill succeed on
    tables that already contain rows. The generated backfill selects and
    updates rows by an ``id`` column, so the table is assumed to have one.

    NOTE(review): changing nullability via ``op.alter_column`` is not supported
    natively on SQLite (Alembic batch mode is needed there). Confirm the
    target database.

    Args:
        table_name (str): Table name.
        field_name (str): Column name.
        field_type (str): SQLAlchemy type expression without the ``sa.`` prefix,
            e.g. ``String(36)``, ``Integer``, ``Boolean``.
        nullable (bool): Whether the column may stay NULL.
        default (str): Raw SQL expression used as ``server_default``. String
            literals must include their own SQL quotes, e.g. ``"'abc'"``.
        comment (str): Column comment.
        unique (bool): Create a unique index on the column.
        index (bool): Create an index on the column.
        versions_dir (str | os.PathLike | None): Directory holding the migration
            files. Defaults to ``<project root>/migrations/versions``.

    Returns:
        pathlib.Path: Path of the generated migration file.

    Raises:
        FileExistsError: If the computed revision id or file already exists,
            e.g. when the tool is run twice within the same second.
    """
    if versions_dir is None:
        versions_dir = Path(__file__).parent.parent / "migrations" / "versions"
    versions_dir = Path(versions_dir)

    # Read the clock once so the revision id and the "Create Date" header agree.
    now = datetime.datetime.now()
    revision_id = now.strftime("%Y%m%d_%H%M%S")
    create_date = now.strftime("%Y-%m-%d %H:%M:%S")
    script_path = versions_dir / f"{revision_id}_add_{field_name}_to_{table_name}.py"

    known_revisions, down_revision = _find_head_revision(versions_dir)
    if revision_id in known_revisions:
        # Without this check, a second run in the same second would produce a
        # migration whose down_revision is itself.
        raise FileExistsError(f"迁移版本 {revision_id} 已存在,请稍后重试")

    field_def = _build_column_def(field_name, field_type, default, comment)
    index_name = f"ix_{table_name}_{field_name}"
    revises = down_revision if down_revision is not None else "None"

    parts = [f'''"""添加 {field_name} 字段到 {table_name} 表

Revision ID: {revision_id}
Revises: {revises}
Create Date: {create_date}

"""
from alembic import op
import sqlalchemy as sa
import uuid


# revision identifiers, used by Alembic.
revision = {revision_id!r}
down_revision = {down_revision!r}
branch_labels = None
depends_on = None


def upgrade():
    # 添加 {field_name} 字段到 {table_name} 表
    # 先以可为空的方式添加, 以便为已有记录回填数据
    op.add_column({table_name!r}, {field_def})

    # 为已有记录生成值
    conn = op.get_bind()

    # 获取所有记录
    records = conn.execute(
        sa.text("SELECT id FROM {table_name} WHERE {field_name} IS NULL")
    ).fetchall()

    # 为每个记录生成值并更新
    update_stmt = sa.text("UPDATE {table_name} SET {field_name} = :value WHERE id = :id")
    for record in records:
''']
    parts.append(_build_backfill(field_name, field_type))

    if unique or index:
        # Render a Python bool; the previous str(unique).lower() emitted `true`/`false`,
        # which are undefined names in the generated script.
        parts.append(
            "\n"
            "    # 添加索引\n"
            f"    op.create_index({index_name!r}, {table_name!r}, [{field_name!r}], "
            f"unique={bool(unique)!r})\n"
        )

    if not nullable:
        # existing_* arguments let MySQL's MODIFY COLUMN keep type/default/comment.
        alter_args = [repr(table_name), repr(field_name), f"existing_type=sa.{field_type}"]
        if default is not None:
            alter_args.append(f"existing_server_default=sa.text({default!r})")
        if comment:
            alter_args.append(f"existing_comment={comment!r}")
        alter_args.append("nullable=False")
        parts.append(
            "\n"
            "    # 添加非空约束\n"
            f"    op.alter_column({', '.join(alter_args)})\n"
        )

    parts.append(
        "\n"
        "\n"
        "def downgrade():\n"
        f"    # 删除 {field_name} 字段\n"
    )
    if unique or index:
        parts.append(f"    op.drop_index({index_name!r}, table_name={table_name!r})\n")
    parts.append(f"    op.drop_column({table_name!r}, {field_name!r})\n")

    # Mode "x" refuses to overwrite an existing migration file.
    with open(script_path, "x", encoding="utf-8") as f:
        f.write("".join(parts))

    print(f"迁移脚本已生成: {script_path}")
    return script_path


def main():
    """Command-line entry point: parse the options and generate the migration file."""
    cli = argparse.ArgumentParser(description="生成数据库迁移脚本")

    # Options that must always be supplied.
    for flag, help_text in (
        ("--table", "表名"),
        ("--field", "字段名"),
        ("--type", "字段类型,如 String(36), Integer, Boolean 等"),
    ):
        cli.add_argument(flag, required=True, help=help_text)

    # Optional column attributes.
    cli.add_argument("--nullable", action="store_true", help="是否可为空")
    cli.add_argument("--default", help="默认值")
    cli.add_argument("--comment", help="注释")
    cli.add_argument("--unique", action="store_true", help="是否唯一")
    cli.add_argument("--index", action="store_true", help="是否创建索引")

    opts = cli.parse_args()
    generate_migration_script(
        table_name=opts.table,
        field_name=opts.field,
        field_type=opts.type,
        nullable=opts.nullable,
        default=opts.default,
        comment=opts.comment,
        unique=opts.unique,
        index=opts.index,
    )


if __name__ == "__main__":
    # Run the CLI only when executed as a script; main() returns None, so exit status is 0.
    sys.exit(main())