#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
通用数据库迁移脚本生成工具

用法:
    python scripts/generate_migration.py --table 表名 --field 字段名 --type 字段类型 [--nullable] [--default 默认值] [--comment 注释] [--unique] [--index]

示例:
    python scripts/generate_migration.py --table tasks --field task_id --type "String(36)" --comment "任务UUID,用于外部引用"

注意: 含括号的字段类型(如 String(36))在 shell 中必须加引号;--default 的值按原样作为 SQL 表达式使用,字符串默认值需自带 SQL 引号。
"""

import argparse
import datetime
import os
import re
import sys
import uuid
from pathlib import Path

# Put the project root (the parent of this script's directory) first on the
# import path so project modules resolve when the script is run directly.
# NOTE(review): nothing in this script imports project modules — confirm
# whether this is still needed.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Matches ``revision = 'abc'`` as well as the annotated form
# ``revision: str = 'abc'`` produced by newer Alembic templates.
_REVISION_RE = re.compile(r"""^revision\s*(?::[^=]*)?=\s*['"]([^'"]+)['"]""")
# Captures the right-hand side of ``down_revision = ...``; it may be None,
# a single quoted ID, or a tuple of quoted IDs (merge migrations).
_DOWN_REVISION_RE = re.compile(r"^down_revision\s*(?::[^=]*)?=\s*(.*)$")
# Extracts every quoted revision ID from a down_revision right-hand side.
_QUOTED_ID_RE = re.compile(r"""['"]([^'"]+)['"]""")


def _find_head_revision(versions_dir):
    """Return the head revision ID of the migrations in *versions_dir*.

    A head is a revision that no other migration lists as its
    ``down_revision``. Alembic's own revision IDs are random hex strings, so
    the lexicographically largest ID is not necessarily the newest one.

    Args:
        versions_dir (Path): directory holding the migration modules.

    Returns:
        str | None: the head revision ID, or None when no migration exists.

    Raises:
        RuntimeError: when the history has several heads (or none, e.g. a
            cycle); the branches must be merged before adding a migration.
    """
    revisions = set()
    parents = set()
    for file in versions_dir.glob("*.py"):
        if file.name.startswith("__"):
            continue
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                match = _REVISION_RE.match(line)
                if match:
                    revisions.add(match.group(1))
                    continue
                match = _DOWN_REVISION_RE.match(line)
                if match:
                    parents.update(_QUOTED_ID_RE.findall(match.group(1)))

    if not revisions:
        return None
    heads = sorted(revisions - parents)
    if len(heads) != 1:
        raise RuntimeError(
            f"找到 {len(heads)} 个 head 版本 {heads},"
            "请先执行 `alembic merge heads` 合并后再生成迁移"
        )
    return heads[0]


def _build_column_def(field_name, field_type, default, comment):
    """Return the ``sa.Column(...)`` source text passed to ``op.add_column``.

    The column is always added as nullable: existing rows have no value yet,
    so adding a NOT NULL column without a server default fails on non-empty
    tables. NOT NULL is applied after the backfill instead. ``repr`` is used
    so quotes inside *default*/*comment* still yield valid Python.
    """
    parts = [repr(field_name), f"sa.{field_type}", "nullable=True"]
    if default is not None:
        parts.append(f"server_default=sa.text({default!r})")
    if comment:
        parts.append(f"comment={comment!r}")
    return f"sa.Column({', '.join(parts)})"


def _build_backfill(table_name, field_name, field_type):
    """Return the generated body of the ``for record in records:`` loop.

    Updates use ``sa.text`` with bound parameters: SQLAlchemy 2.x rejects
    plain strings in ``Connection.execute``, and binding avoids quoting bugs.
    """
    execute_lines = [
        "conn.execute(",
        f'    sa.text("UPDATE {table_name} SET {field_name} = :value WHERE id = :id"),',
        '    {"value": value, "id": record[0]},',
        ")",
    ]
    if "String" in field_type and "uuid" in field_name.lower():
        lines = ["value = str(uuid.uuid4())", *execute_lines]
    elif "String" in field_type:
        lines = [
            "# 为字符串字段生成默认值,根据实际情况修改",
            'value = f"default_{record[0]}"',
            *execute_lines,
        ]
    elif "Integer" in field_type:
        lines = ["# 为整数字段生成默认值,根据实际情况修改", "value = 0", *execute_lines]
    elif "Boolean" in field_type:
        lines = ["# 为布尔字段生成默认值,根据实际情况修改", "value = False", *execute_lines]
    else:
        # Unknown type: emit a commented-out template for the user to fill in.
        lines = [
            "# 为字段生成默认值,根据实际情况修改",
            "# value = ...",
            *("# " + line for line in execute_lines),
            "pass",
        ]
    # The loop body sits two levels deep (def upgrade -> for).
    return "".join(f"        {line}\n" for line in lines)


def generate_migration_script(table_name, field_name, field_type, nullable=False, default=None,
                              comment=None, unique=False, index=False, versions_dir=None):
    """Generate an Alembic migration that adds one column to a table.

    The generated ``upgrade()`` adds the column as nullable, backfills rows
    whose value is NULL, optionally creates an index, and finally applies
    NOT NULL when *nullable* is false. ``downgrade()`` drops the index (if
    any) and the column.

    Args:
        table_name (str): table to alter.
        field_name (str): name of the new column.
        field_type (str): SQLAlchemy type expression without the ``sa.``
            prefix, e.g. ``String(36)``, ``Integer``, ``Boolean``.
        nullable (bool): keep the column nullable; when False a NOT NULL
            constraint is added after the backfill.
        default (str | None): raw SQL expression used as ``server_default``
            (wrapped in ``sa.text``); string values need their own SQL quotes.
        comment (str | None): column comment.
        unique (bool): create a unique index on the column.
        index (bool): create an index on the column (unique if *unique*).
        versions_dir (str | Path | None): Alembic ``versions`` directory;
            defaults to ``<project root>/migrations/versions``.

    Returns:
        Path: path of the written migration file.

    Raises:
        FileNotFoundError: if *versions_dir* does not exist.
        RuntimeError: if the existing migrations have more than one head.
    """
    # One timestamp for both the revision ID and the Create Date header.
    now = datetime.datetime.now()
    revision_id = now.strftime("%Y%m%d_%H%M%S")
    create_date = now.strftime("%Y-%m-%d %H:%M:%S")

    if versions_dir is None:
        versions_dir = Path(__file__).parent.parent / "migrations" / "versions"
    versions_dir = Path(versions_dir)
    if not versions_dir.is_dir():
        raise FileNotFoundError(f"迁移目录不存在: {versions_dir}")
    script_path = versions_dir / f"{revision_id}_add_{field_name}_to_{table_name}.py"

    head = _find_head_revision(versions_dir)
    column_def = _build_column_def(field_name, field_type, default, comment)
    index_name = f"ix_{table_name}_{field_name}"

    script_content = f'''"""添加 {field_name} 字段到 {table_name} 表

Revision ID: {revision_id}
Revises: {head}
Create Date: {create_date}

"""
from alembic import op
import sqlalchemy as sa
import uuid


# revision identifiers, used by Alembic.
revision = '{revision_id}'
down_revision = {head!r}
branch_labels = None
depends_on = None


def upgrade():
    # 添加 {field_name} 字段到 {table_name} 表
    op.add_column('{table_name}', {column_def})

    # 为已有记录生成值
    conn = op.get_bind()

    # 获取所有记录
    records = conn.execute(
        sa.text("SELECT id FROM {table_name} WHERE {field_name} IS NULL")
    ).fetchall()

    # 为每个记录生成值并更新
    for record in records:
'''
    script_content += _build_backfill(table_name, field_name, field_type)

    if unique or index:
        # bool() renders True/False; the old str(...).lower() emitted invalid `true`.
        script_content += f'''
    # 添加索引
    op.create_index('{index_name}', '{table_name}', ['{field_name}'], unique={bool(unique)})
'''

    if not nullable:
        # existing_type is required by MySQL for nullability-only ALTERs.
        script_content += f'''
    # 添加非空约束
    op.alter_column('{table_name}', '{field_name}', existing_type=sa.{field_type}, nullable=False)
'''

    script_content += f'''

def downgrade():
    # 删除 {field_name} 字段
'''
    if unique or index:
        script_content += f"    op.drop_index('{index_name}', table_name='{table_name}')\n"
    script_content += f"    op.drop_column('{table_name}', '{field_name}')\n"

    script_path.write_text(script_content, encoding="utf-8")
    print(f"迁移脚本已生成: {script_path}")
    return script_path


def main():
    """Command-line entry point: parse the options and write the migration."""
    cli = argparse.ArgumentParser(description="生成数据库迁移脚本")

    # Mandatory identifiers of the column being added.
    for flag, text in (
        ("--table", "表名"),
        ("--field", "字段名"),
        ("--type", "字段类型,如 String(36), Integer, Boolean 等"),
    ):
        cli.add_argument(flag, required=True, help=text)

    cli.add_argument("--nullable", action="store_true", help="是否可为空")
    cli.add_argument("--default", help="默认值")
    cli.add_argument("--comment", help="注释")

    # On/off switches controlling index creation.
    for flag, text in (("--unique", "是否唯一"), ("--index", "是否创建索引")):
        cli.add_argument(flag, action="store_true", help=text)

    opts = cli.parse_args()
    generate_migration_script(
        table_name=opts.table,
        field_name=opts.field,
        field_type=opts.type,
        nullable=opts.nullable,
        default=opts.default,
        comment=opts.comment,
        unique=opts.unique,
        index=opts.index,
    )


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()