tianfeng_task_modules/scripts/generate_migration.py
2025-03-17 14:58:05 +08:00

192 lines
6.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
通用数据库迁移脚本生成工具
用法:
python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] [--unique] [--index]
示例:
python scripts/generate_migration.py --table tasks --field task_id --type String(36) --comment "任务UUID用于外部引用"
"""
import argparse
import datetime
import os
import re
import sys
import uuid
from pathlib import Path
# Prepend the project root (two directory levels above this file) to sys.path
# so that modules living under the project root are importable from here.
sys.path.insert(0, os.fspath(Path(__file__).parent.parent))
# Matches top-level `revision = ...` / `down_revision = ...` assignments, including
# the annotated form (`revision: str = '...'`) emitted by newer Alembic templates.
_REVISION_ASSIGN_RE = re.compile(r"^(revision|down_revision)\b[^=]*=(.*)$")
# Extracts every quoted revision id from the right-hand side of such an assignment.
_QUOTED_RE = re.compile(r"['\"]([^'\"]+)['\"]")
# Table/column names are spliced into generated Python and raw SQL, so only
# plain identifiers are accepted.
_IDENTIFIER_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def _find_head_revision(versions_dir):
    """Return the current head revision id found in *versions_dir*, or None if empty.

    A head is a revision that no other script names in its ``down_revision``.
    Sorting revision ids does not find the latest one (Alembic-generated ids are
    typically random hex strings), so the graph is inspected instead. If several
    heads exist, the lexicographically last one is used and a warning is printed.
    """
    revisions = []
    referenced = set()
    for path in sorted(versions_dir.glob("*.py")):
        if path.name.startswith("__"):
            continue
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                match = _REVISION_ASSIGN_RE.match(line)
                if not match:
                    continue
                ids = _QUOTED_RE.findall(match.group(2))
                if match.group(1) == "revision":
                    revisions.extend(ids[:1])
                else:
                    # down_revision may be None, a single id, or a tuple of ids (merge).
                    referenced.update(ids)
    if not revisions:
        return None
    heads = sorted(set(revisions) - referenced)
    if not heads:
        # Every revision is referenced (cyclic/corrupt history): fall back to sorting.
        return sorted(revisions)[-1]
    if len(heads) > 1:
        print(f"警告: 检测到多个 head revision {heads}, 使用 {heads[-1]} 作为 down_revision", file=sys.stderr)
    return heads[-1]


def _render_column(field_name, field_type, default, comment):
    """Return the ``sa.Column(...)`` source passed to the generated ``op.add_column``.

    The column is always added as nullable; NOT NULL is applied afterwards with
    ``op.alter_column`` once existing rows are back-filled. Adding a NOT NULL
    column without a default fails on non-empty tables (e.g. PostgreSQL).
    """
    args = [repr(field_name), f"sa.{field_type}"]
    if default is not None:
        # repr() keeps quotes inside SQL literals (e.g. 'active') from breaking the source.
        args.append(f"server_default=sa.text({default!r})")
    if comment:
        args.append(f"comment={comment!r}")
    return f"sa.Column({', '.join(args)})"


def _render_backfill(table_name, field_name, field_type):
    """Return the indented body of the generated ``for record in records:`` loop.

    Values are sent as bound parameters of ``sa.text`` instead of being formatted
    into the SQL; plain SQL strings are not accepted by SQLAlchemy 2.x ``execute``.
    """
    update_sql = f"UPDATE {table_name} SET {field_name} = :value WHERE id = :id"

    def execute_line(value_name):
        # One `conn.execute(...)` statement binding *value_name* and the row id.
        return f'conn.execute(sa.text({update_sql!r}), {{"value": {value_name}, "id": record[0]}})'

    if "String" in field_type and "uuid" in field_name.lower():
        body = ["record_uuid = str(uuid.uuid4())", execute_line("record_uuid")]
    elif "String" in field_type:
        body = [
            "# 为字符串字段生成默认值,根据实际情况修改",
            'default_value = f"default_{record[0]}"',
            execute_line("default_value"),
        ]
    elif "Integer" in field_type:
        body = [
            "# 为整数字段生成默认值,根据实际情况修改",
            "default_value = 0",
            execute_line("default_value"),
        ]
    elif "Boolean" in field_type:
        body = [
            "# 为布尔字段生成默认值,根据实际情况修改",
            "default_value = False",
            execute_line("default_value"),
        ]
    else:
        body = [
            "# 为字段生成默认值,根据实际情况修改",
            "# default_value = ...",
            "# " + execute_line("default_value"),
            "pass",
        ]
    return "".join(f"        {line}\n" for line in body)


def _render_not_null(table_name, field_name, field_type, default, comment):
    """Return the generated ``op.alter_column`` call that makes the column NOT NULL.

    The ``existing_*`` arguments restate the column definition because Alembic
    needs them to emit MySQL's ``ALTER TABLE ... MODIFY``, which rewrites the
    whole column (type, default and comment).
    """
    args = [repr(table_name), repr(field_name), "nullable=False", f"existing_type=sa.{field_type}"]
    if default is not None:
        args.append(f"existing_server_default=sa.text({default!r})")
    if comment:
        args.append(f"existing_comment={comment!r}")
    return f"    op.alter_column({', '.join(args)})\n"


def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None,
                              comment=None, unique=False, index=False, *, versions_dir=None):
    """Generate an Alembic migration script that adds one column to a table.

    The generated ``upgrade()`` adds the column as nullable, back-fills existing
    rows (the generated SQL assumes the table has an ``id`` column — adjust if
    not), optionally creates an index, and finally applies NOT NULL if requested.

    Args:
        table_name (str): Table to alter; must be a plain identifier.
        field_name (str): Column to add; must be a plain identifier.
        field_type (str): SQLAlchemy type expression without the ``sa.`` prefix,
            e.g. ``String(36)``, ``Integer``, ``Boolean``.
        nullable (bool): Keep the column nullable instead of adding NOT NULL.
        default (str): Raw SQL expression used as ``server_default``.
        comment (str): Column comment.
        unique (bool): Create a unique index on the column.
        index (bool): Create an index on the column (unique if *unique*).
        versions_dir (str | Path): Directory holding the migration scripts;
            defaults to ``<project root>/migrations/versions``.

    Returns:
        Path: Path of the written migration script.

    Raises:
        ValueError: If *table_name* or *field_name* is not a plain identifier.
        FileNotFoundError: If the versions directory does not exist.
    """
    for label, name in (("表名", table_name), ("字段名", field_name)):
        if not _IDENTIFIER_RE.fullmatch(name or ""):
            raise ValueError(f"非法的{label}: {name!r}")

    if versions_dir is None:
        versions_dir = Path(__file__).parent.parent / "migrations" / "versions"
    versions_dir = Path(versions_dir)
    if not versions_dir.is_dir():
        raise FileNotFoundError(f"迁移目录不存在: {versions_dir}")

    # A single timestamp keeps the revision id, file name and Create Date in agreement.
    now = datetime.datetime.now()
    revision_id = now.strftime("%Y%m%d_%H%M%S")
    script_path = versions_dir / f"{revision_id}_add_{field_name}_to_{table_name}.py"
    down_revision = _find_head_revision(versions_dir)
    index_name = f"ix_{table_name}_{field_name}"

    script = f'''"""添加 {field_name} 字段到 {table_name}

Revision ID: {revision_id}
Revises: {down_revision}
Create Date: {now:%Y-%m-%d %H:%M:%S}
"""
from alembic import op
import sqlalchemy as sa
import uuid

# revision identifiers, used by Alembic.
revision = {revision_id!r}
down_revision = {down_revision!r}
branch_labels = None
depends_on = None


def upgrade():
    # 添加 {field_name} 字段到 {table_name}
    op.add_column({table_name!r}, {_render_column(field_name, field_type, default, comment)})
    # 为已有记录生成值
    conn = op.get_bind()
    # 获取所有记录
    records = conn.execute(
        sa.text("SELECT id FROM {table_name} WHERE {field_name} IS NULL")
    ).fetchall()
    # 为每个记录生成值并更新
    for record in records:
{_render_backfill(table_name, field_name, field_type)}'''

    if unique or index:
        # bool() renders Python's True/False; the old str(...).lower() emitted undefined true/false.
        script += (
            "\n    # 添加索引\n"
            f"    op.create_index({index_name!r}, {table_name!r}, [{field_name!r}], unique={bool(unique)})\n"
        )
    if not nullable:
        script += "\n    # 添加非空约束\n" + _render_not_null(table_name, field_name, field_type, default, comment)

    script += f"\n\ndef downgrade():\n    # 删除 {field_name} 字段\n"
    if unique or index:
        script += f"    op.drop_index({index_name!r}, table_name={table_name!r})\n"
    script += f"    op.drop_column({table_name!r}, {field_name!r})\n"

    with open(script_path, "w", encoding="utf-8") as f:
        f.write(script)
    print(f"迁移脚本已生成: {script_path}")
    return script_path
def main():
    """Command-line entry point: parse the options and write the migration script."""
    parser = argparse.ArgumentParser(description="生成数据库迁移脚本")
    # The three options that identify the column to add are mandatory.
    for flag, help_text in (
        ("--table", "表名"),
        ("--field", "字段名"),
        ("--type", "字段类型,如 String(36), Integer, Boolean 等"),
    ):
        parser.add_argument(flag, required=True, help=help_text)
    parser.add_argument("--nullable", action="store_true", help="是否可为空")
    parser.add_argument("--default", help="默认值")
    parser.add_argument("--comment", help="注释")
    parser.add_argument("--unique", action="store_true", help="是否唯一")
    parser.add_argument("--index", action="store_true", help="是否创建索引")
    opts = parser.parse_args()
    generate_migration_script(
        table_name=opts.table,
        field_name=opts.field,
        field_type=opts.type,
        nullable=opts.nullable,
        default=opts.default,
        comment=opts.comment,
        unique=opts.unique,
        index=opts.index,
    )


if __name__ == "__main__":
    main()