tianfeng_task_modules/scripts/generate_migration.py

192 lines
6.3 KiB
Python
Raw Normal View History

2025-03-17 14:58:05 +08:00
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
通用数据库迁移脚本生成工具
用法:
python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] [--unique] [--index]
示例:
python scripts/generate_migration.py --table tasks --field task_id --type String(36) --comment "任务UUID用于外部引用"
"""
import argparse
import ast
import datetime
import os
import sys
import uuid
from pathlib import Path
# Put the project root (the parent of this script's directory) first on the
# import path, so the script can be run directly as `python scripts/...`.
# NOTE(review): nothing visible in this file imports project modules; confirm
# whether this is still needed.
sys.path[:0] = [str(Path(__file__).parent.parent)]
def _read_revision_ids(script_file):
    """Return ``(revision, down_revisions)`` declared at the top level of an Alembic script.

    Both the plain form (``revision = 'abc'``) and the annotated form used by
    newer Alembic templates (``revision: str = 'abc'``) are recognised. Values
    are read with ``ast.literal_eval`` rather than by splitting text.

    Args:
        script_file (Path): Migration script to inspect.

    Returns:
        tuple: ``revision`` (str, or None when absent or not a string literal)
        and ``down_revisions`` (tuple of str; empty for a base revision,
        several entries for a merge revision).

    Raises:
        SyntaxError: If the file is not valid Python.
    """
    # utf-8-sig also strips a leading BOM, which ast.parse would reject.
    source = script_file.read_text(encoding="utf-8-sig")
    tree = ast.parse(source, filename=str(script_file))
    found = {}
    for node in tree.body:
        if isinstance(node, ast.Assign) and len(node.targets) == 1:
            target, value = node.targets[0], node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            target, value = node.target, node.value
        else:
            continue
        if isinstance(target, ast.Name) and target.id in ("revision", "down_revision"):
            try:
                found[target.id] = ast.literal_eval(value)
            except (ValueError, TypeError):
                # Not a literal (computed at import time); cannot be read statically.
                continue

    revision = found.get("revision")
    down = found.get("down_revision")
    if isinstance(down, str):
        down_revisions = (down,)
    elif isinstance(down, (tuple, list)):
        down_revisions = tuple(item for item in down if isinstance(item, str))
    else:
        down_revisions = ()
    return (revision if isinstance(revision, str) else None), down_revisions


def _find_head_revision(versions_dir):
    """Return the id of the current head revision in *versions_dir*, or None if empty.

    The head is the revision that no other script names as its
    ``down_revision``. Sorting ids is not a substitute: ids created by
    ``alembic revision`` are random hex strings and do not sort
    chronologically. If several heads exist, the largest id is used and a
    warning is printed, because the new script can only follow one of them.
    """
    revisions = set()
    referenced = set()
    for script_file in versions_dir.glob("*.py"):
        if script_file.name.startswith("__"):
            continue
        try:
            revision, down_revisions = _read_revision_ids(script_file)
        except SyntaxError as err:
            print(f"警告: 无法解析迁移脚本 {script_file}: {err}", file=sys.stderr)
            continue
        if revision is not None:
            revisions.add(revision)
        referenced.update(down_revisions)

    heads = revisions - referenced
    # All revisions are considered only when the history is malformed (no head at all).
    candidates = sorted(heads or revisions)
    if not candidates:
        return None
    if len(heads) > 1:
        print(f"警告: 存在多个 head revision {candidates},使用 {candidates[-1]}", file=sys.stderr)
    return candidates[-1]


def _build_column_def(field_name, field_type, default, comment):
    """Return the ``sa.Column(...)`` expression used by the generated ``add_column``.

    The column is always created nullable. NOT NULL is applied separately,
    after existing rows have been backfilled. User-supplied strings are
    embedded with ``repr`` so quotes cannot break the generated Python.
    """
    parts = [repr(field_name), f"sa.{field_type}"]
    if default is not None:
        # server_default is raw SQL text; string defaults must carry their own SQL quotes.
        parts.append(f"server_default=sa.text({default!r})")
    if comment:
        parts.append(f"comment={comment!r}")
    return f"sa.Column({', '.join(parts)})"


def _backfill_value(field_name, field_type):
    """Return ``(comment_line, value_expression)`` for backfilling existing rows, or None.

    ``value_expression`` is Python source evaluated once per row inside the
    generated loop (``record[0]`` is the row id). ``comment_line`` may be None.
    A return value of None means no default is known for *field_type*, and the
    author is left to fill in the loop body.
    """
    if "String" in field_type and "uuid" in field_name.lower():
        return None, "str(uuid.uuid4())"
    if "String" in field_type:
        return "# 为字符串字段生成默认值,根据实际情况修改", 'f"default_{record[0]}"'
    if "Integer" in field_type:
        return "# 为整数字段生成默认值,根据实际情况修改", "0"
    if "Boolean" in field_type:
        return "# 为布尔字段生成默认值,根据实际情况修改", "False"
    return None


def _render_upgrade(table_name, field_name, field_type, nullable, default, comment, unique, index):
    """Return the source lines of the generated ``upgrade()`` function."""
    index_name = f"ix_{table_name}_{field_name}"
    lines = [
        "def upgrade():",
        f"    # 添加 {field_name} 字段到 {table_name} 表",
        f"    op.add_column({table_name!r}, {_build_column_def(field_name, field_type, default, comment)})",
        "",
        "    # 为已有记录生成值",
        "    conn = op.get_bind()",
        "    # 获取所有记录",
        # sa.text() is required: SQLAlchemy 2.0 no longer executes plain SQL strings.
        f'    records = conn.execute(sa.text("SELECT id FROM {table_name} WHERE {field_name} IS NULL")).fetchall()',
        # Row values are bound parameters, so the driver handles quoting and typing.
        f'    update_stmt = sa.text("UPDATE {table_name} SET {field_name} = :value WHERE id = :id")',
        "    # 为每个记录生成值并更新",
        "    for record in records:",
    ]
    backfill = _backfill_value(field_name, field_type)
    if backfill is None:
        lines += [
            "        # 为字段生成默认值,根据实际情况修改",
            "        # default_value = ...",
            '        # conn.execute(update_stmt, {"value": default_value, "id": record[0]})',
            "        pass",
        ]
    else:
        note, value_expr = backfill
        if note:
            lines.append(f"        {note}")
        lines += [
            f"        value = {value_expr}",
            '        conn.execute(update_stmt, {"value": value, "id": record[0]})',
        ]
    if unique or index:
        lines += [
            "",
            "    # 添加索引",
            # bool() renders True/False; lower-case true/false would be a NameError.
            f"    op.create_index({index_name!r}, {table_name!r}, [{field_name!r}], unique={bool(unique)})",
        ]
    if not nullable:
        lines += [
            "",
            "    # 添加非空约束",
            # existing_type is required by MySQL when only nullability changes.
            f"    op.alter_column({table_name!r}, {field_name!r}, existing_type=sa.{field_type}, nullable=False)",
        ]
    return lines


def _render_downgrade(table_name, field_name, unique, index):
    """Return the source lines of the generated ``downgrade()`` function."""
    lines = ["def downgrade():", f"    # 删除 {field_name} 字段"]
    if unique or index:
        index_name = f"ix_{table_name}_{field_name}"
        lines.append(f"    op.drop_index({index_name!r}, table_name={table_name!r})")
    lines.append(f"    op.drop_column({table_name!r}, {field_name!r})")
    return lines


def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None,
                              comment=None, unique=False, index=False, versions_dir=None):
    """Generate an Alembic migration that adds one column to an existing table.

    The generated ``upgrade()`` does the following, in order:
      1. adds the column as nullable;
      2. backfills rows whose value is NULL;
      3. optionally creates an index;
      4. applies NOT NULL when *nullable* is false.
    ``downgrade()`` drops the index (if any) and the column.

    Args:
        table_name (str): Table name.
        field_name (str): Column name.
        field_type (str): SQLAlchemy type without the ``sa.`` prefix,
            e.g. ``String(36)``, ``Integer``, ``Boolean``.
        nullable (bool): Whether the column may remain NULL.
        default (str): Raw SQL for ``server_default`` (wrapped in ``sa.text``).
        comment (str): Column comment.
        unique (bool): Create a unique index on the column.
        index (bool): Create an index on the column (unique only if *unique*).
        versions_dir (str | Path | None): Directory holding the migration
            scripts. Defaults to ``<project root>/migrations/versions``.

    Returns:
        Path: Path of the written migration script.
    """
    if versions_dir is None:
        versions_dir = Path(__file__).parent.parent / "migrations" / "versions"
    versions_dir = Path(versions_dir)

    # One timestamp serves as both the revision id and the Create Date.
    now = datetime.datetime.now()
    revision_id = now.strftime("%Y%m%d_%H%M%S")
    script_path = versions_dir / f"{revision_id}_add_{field_name}_to_{table_name}.py"
    # Resolved before writing, so the new script is not counted as a head.
    down_revision = _find_head_revision(versions_dir)

    header = [
        f'"""添加 {field_name} 字段到 {table_name}',
        "",
        f"Revision ID: {revision_id}",
        f"Revises: {down_revision}",
        f"Create Date: {now:%Y-%m-%d %H:%M:%S}",
        "",
        '"""',
        "from alembic import op",
        "import sqlalchemy as sa",
        "import uuid",
        "",
        "",
        "# revision identifiers, used by Alembic.",
        f"revision = {revision_id!r}",
        f"down_revision = {down_revision!r}",
        "branch_labels = None",
        "depends_on = None",
        "",
        "",
    ]
    upgrade = _render_upgrade(table_name, field_name, field_type, nullable, default, comment, unique, index)
    downgrade = _render_downgrade(table_name, field_name, unique, index)
    script_content = "\n".join(header + upgrade + ["", ""] + downgrade) + "\n"

    with open(script_path, "w", encoding="utf-8") as f:
        f.write(script_content)
    print(f"迁移脚本已生成: {script_path}")
    return script_path
def main():
    """Command-line entry point: parse options and write one add-column migration."""
    cli = argparse.ArgumentParser(description="生成数据库迁移脚本")

    # The three mandatory options, declared in their original order.
    required_options = (
        ("--table", "表名"),
        ("--field", "字段名"),
        ("--type", "字段类型,如 String(36), Integer, Boolean 等"),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, required=True, help=help_text)

    cli.add_argument("--nullable", action="store_true", help="是否可为空")
    cli.add_argument("--default", help="默认值")
    cli.add_argument("--comment", help="注释")
    cli.add_argument("--unique", action="store_true", help="是否唯一")
    cli.add_argument("--index", action="store_true", help="是否创建索引")
    opts = cli.parse_args()

    generate_migration_script(
        table_name=opts.table,
        field_name=opts.field,
        field_type=opts.type,
        nullable=opts.nullable,
        default=opts.default,
        comment=opts.comment,
        unique=opts.unique,
        index=opts.index,
    )


if __name__ == "__main__":
    main()