#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
通用数据库迁移脚本生成工具

用法:
    python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] [--unique] [--index]

示例:
    python scripts/generate_migration.py --table tasks --field task_id --type "String(36)" --comment "任务UUID,用于外部引用"
"""
import argparse
import ast
import datetime
import os
import sys
import uuid
from pathlib import Path
# Put the project root (two levels above this file) at the front of sys.path so
# project-local packages can be imported when the script is run directly.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
def _read_revision_info(path):
    """Read the module-level ``revision`` / ``down_revision`` of one Alembic script.

    Both plain (``revision = 'abc'``) and annotated (``revision: str = 'abc'``)
    assignments are recognised. ``down_revision`` may be ``None``, a single id
    string, or a tuple/list of ids (merge revisions).

    Args:
        path (Path): migration script to inspect.

    Returns:
        tuple: ``(revision, down_revisions)``. ``revision`` is the id string, or
        ``None`` if it is missing or not a string literal. ``down_revisions`` is
        the set of parent revision ids. A file that does not parse as Python
        yields ``(None, set())``.
    """
    try:
        tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
    except SyntaxError:
        return None, set()

    values = {}
    for node in tree.body:
        if isinstance(node, ast.Assign):
            targets = node.targets
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            targets = [node.target]
        else:
            continue
        for target in targets:
            if isinstance(target, ast.Name) and target.id in ("revision", "down_revision"):
                try:
                    values[target.id] = ast.literal_eval(node.value)
                except (ValueError, TypeError):
                    # Not a plain literal (e.g. a computed value); ignore it.
                    pass

    revision = values.get("revision")
    if not isinstance(revision, str):
        revision = None

    parents = values.get("down_revision")
    if isinstance(parents, str):
        down_revisions = {parents}
    elif isinstance(parents, (tuple, list, set, frozenset)):
        down_revisions = {p for p in parents if isinstance(p, str)}
    else:
        down_revisions = set()
    return revision, down_revisions


def _scan_revisions(versions_dir):
    """Collect every revision id in *versions_dir* and every id referenced as a parent.

    Files whose name starts with ``__`` (e.g. ``__init__.py``) are skipped.

    Returns:
        tuple: ``(revisions, referenced)``, two sets of revision id strings.
    """
    revisions, referenced = set(), set()
    for path in sorted(versions_dir.glob("*.py")):
        if path.name.startswith("__"):
            continue
        revision, parents = _read_revision_info(path)
        if revision is not None:
            revisions.add(revision)
        referenced |= parents
    return revisions, referenced


def _find_head_revision(revisions, referenced):
    """Return the single head revision, i.e. the one no other revision names as parent.

    Returns ``None`` when there are no revisions yet, so the new migration
    becomes the root.

    Raises:
        ValueError: if the history has several heads (branches that must be
            merged first) or no head at all (a cycle).
    """
    heads = sorted(revisions - referenced)
    if len(heads) > 1:
        raise ValueError(
            f"Multiple Alembic heads found ({', '.join(heads)}); "
            "merge them before generating a new migration."
        )
    if revisions and not heads:
        raise ValueError("No Alembic head found; the revision history contains a cycle.")
    return heads[0] if heads else None


def _build_column_def(field_name, field_type, default=None, comment=None):
    """Return the ``sa.Column(...)`` source passed to ``op.add_column``.

    The column is always created nullable. Adding a NOT NULL column without a
    default to a non-empty table is typically rejected by the database, so any
    NOT NULL constraint is applied later with ``alter_column``, after
    back-filling. String values are emitted with ``repr`` so quotes inside
    *default* or *comment* cannot break the generated Python.
    """
    parts = [repr(field_name), f"sa.{field_type}", "nullable=True"]
    if default is not None:
        parts.append(f"server_default=sa.text({str(default)!r})")
    if comment:
        parts.append(f"comment={str(comment)!r}")
    return f"sa.Column({', '.join(parts)})"


def _build_backfill(table_name, field_name, field_type):
    """Return the loop-body lines (indented 8 spaces) that fill existing NULL rows.

    The value is chosen from *field_type* / *field_name*. It is a UUID string,
    ``default_<id>``, ``0``, ``False``, or a commented-out placeholder. Values
    are passed as bound parameters through ``sa.text``, because plain SQL
    strings are not accepted by ``Connection.execute`` in SQLAlchemy 2.0.
    """
    update_sql = f"UPDATE {table_name} SET {field_name} = :value WHERE id = :id"

    def _update_call(value_name):
        # One parameterised UPDATE per row, keyed on the ``id`` column.
        return f'conn.execute(sa.text("{update_sql}"), {{"value": {value_name}, "id": record[0]}})'

    if "String" in field_type and "uuid" in field_name.lower():
        body = ["record_uuid = str(uuid.uuid4())", _update_call("record_uuid")]
    elif "String" in field_type:
        body = [
            "# 为字符串字段生成默认值,根据实际情况修改",
            'default_value = f"default_{record[0]}"',
            _update_call("default_value"),
        ]
    elif "Integer" in field_type:
        body = [
            "# 为整数字段生成默认值,根据实际情况修改",
            "default_value = 0",
            _update_call("default_value"),
        ]
    elif "Boolean" in field_type:
        body = [
            "# 为布尔字段生成默认值,根据实际情况修改",
            "default_value = False",
            _update_call("default_value"),
        ]
    else:
        body = [
            "# 为字段生成默认值,根据实际情况修改",
            "# default_value = ...",
            "# " + _update_call("default_value"),
            "pass",
        ]
    return [" " * 8 + line for line in body]


def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None,
                              comment=None, unique=False, index=False, versions_dir=None):
    """Generate an Alembic migration script that adds one column to an existing table.

    The generated ``upgrade()`` does the following:

    - adds the column (always nullable at first);
    - back-fills rows whose value is NULL;
    - optionally creates an index;
    - when *nullable* is False, applies the NOT NULL constraint.

    ``downgrade()`` drops the index (if one was created) and the column.

    Args:
        table_name (str): table name.
        field_name (str): column name.
        field_type (str): SQLAlchemy type expression without the ``sa.`` prefix,
            e.g. ``String(36)``, ``Integer``, ``Boolean``.
        nullable (bool): whether the column may be NULL.
        default (str): server default, emitted as ``server_default=sa.text(default)``.
        comment (str): column comment.
        unique (bool): create the index as UNIQUE (implies an index).
        index (bool): create an index on the column.
        versions_dir (str | Path | None): directory holding the migration
            scripts. Defaults to ``<project root>/migrations/versions``.

    Returns:
        Path: path of the written migration script.

    Raises:
        ValueError: if the existing migrations have several heads or a cycle.
    """
    if versions_dir is None:
        versions_dir = Path(__file__).parent.parent / "migrations" / "versions"
    versions_dir = Path(versions_dir)

    # Scan before writing, so the new script is not part of the history it extends.
    revisions, referenced = _scan_revisions(versions_dir)
    head = _find_head_revision(revisions, referenced)

    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    # Revision ids are timestamps. If one was already generated in the same
    # second, add a counter; otherwise the new revision would duplicate (and revise) itself.
    revision_id = timestamp
    counter = 1
    while revision_id in revisions:
        revision_id = f"{timestamp}_{counter}"
        counter += 1

    script_path = versions_dir / f"{revision_id}_add_{field_name}_to_{table_name}.py"
    index_name = f"ix_{table_name}_{field_name}"
    column_def = _build_column_def(field_name, field_type, default, comment)

    lines = [
        f'"""添加 {field_name} 字段到 {table_name} 表',
        "",
        f"Revision ID: {revision_id}",
        f"Revises: {head or 'None'}",
        f"Create Date: {now:%Y-%m-%d %H:%M:%S}",
        "",
        '"""',
        "from alembic import op",
        "import sqlalchemy as sa",
        "import uuid",
        "",
        "",
        "# revision identifiers, used by Alembic.",
        f"revision = {revision_id!r}",
        f"down_revision = {head!r}",
        "branch_labels = None",
        "depends_on = None",
        "",
        "",
        "def upgrade():",
        f"    # 添加 {field_name} 字段到 {table_name} 表",
        f"    op.add_column({table_name!r}, {column_def})",
        "",
        "    # 为已有记录生成值",
        "    conn = op.get_bind()",
        "",
        "    # 获取所有记录",
        f'    records = conn.execute(sa.text("SELECT id FROM {table_name} WHERE {field_name} IS NULL")).fetchall()',
        "",
        "    # 为每个记录生成值并更新",
        "    for record in records:",
    ]
    lines.extend(_build_backfill(table_name, field_name, field_type))

    if unique or index:
        lines += [
            "",
            "    # 添加索引",
            f"    op.create_index({index_name!r}, {table_name!r}, [{field_name!r}], unique={bool(unique)!r})",
        ]

    if not nullable:
        # existing_type is required by some backends (e.g. MySQL) when altering nullability.
        lines += [
            "",
            "    # 添加非空约束",
            f"    op.alter_column({table_name!r}, {field_name!r}, existing_type=sa.{field_type}, nullable=False)",
        ]

    lines += ["", "", "def downgrade():", f"    # 删除 {field_name} 字段"]
    if unique or index:
        lines.append(f"    op.drop_index({index_name!r}, table_name={table_name!r})")
    lines.append(f"    op.drop_column({table_name!r}, {field_name!r})")

    script_path.write_text("\n".join(lines) + "\n", encoding="utf-8")

    print(f"迁移脚本已生成: {script_path}")
    return script_path
def main():
    """Parse command-line options and generate the matching migration script."""
    parser = argparse.ArgumentParser(description="生成数据库迁移脚本")

    # Mandatory options identifying the table, the column and its type.
    for flag, help_text in (
        ("--table", "表名"),
        ("--field", "字段名"),
        ("--type", "字段类型,如 String(36), Integer, Boolean 等"),
    ):
        parser.add_argument(flag, required=True, help=help_text)

    parser.add_argument("--nullable", action="store_true", help="是否可为空")
    parser.add_argument("--default", help="默认值")
    parser.add_argument("--comment", help="注释")

    # Index-related switches.
    for flag, help_text in (("--unique", "是否唯一"), ("--index", "是否创建索引")):
        parser.add_argument(flag, action="store_true", help=help_text)

    opts = parser.parse_args()
    generate_migration_script(
        table_name=opts.table,
        field_name=opts.field,
        field_type=opts.type,
        nullable=opts.nullable,
        default=opts.default,
        comment=opts.comment,
        unique=opts.unique,
        index=opts.index,
    )


if __name__ == "__main__":
    main()