GaLoreを使って0.01Bモデル(EN)を作ってみる(モデルが保存できない)

初めに

LoRAよりもメモリ効率がよく学習ができる手法であるGaLoreで試してみます

論文のabstractの日本語訳は以下です(claude 3 opus を使用)

大規模言語モデル(LLM)の学習では、重みと最適化器の状態のサイズが増大するため、メモリに関する大きな課題があります。低ランク適応(LoRA)などの一般的なメモリ削減手法では、各層の凍結された事前学習済みの重みに、学習可能な低ランク行列を追加することで、学習可能なパラメータと最適化器の状態を削減します。しかし、このようなアプローチは、パラメータ探索を低ランク部分空間に制限し、学習のダイナミクスを変更するため、事前学習と微調整の両段階で、通常、フルランクの重みを用いた学習よりも性能が低下します。さらに、フルランクのウォームスタートが必要になる場合もあります。本研究では、フルパラメータ学習を可能にしながら、LoRAなどの一般的な低ランク適応手法よりもメモリ効率の良い学習戦略である、勾配低ランク射影(GaLore)を提案します。我々のアプローチは、最適化器の状態のメモリ使用量を最大で65.5%削減しながら、LLaMA 1Bと7Bのアーキテクチャを用いて最大19.7Bトークンを含むC4データセットで事前学習し、GLUEタスクでRoBERTaを微調整する際の効率とパフォーマンスを維持します。我々の8ビットGaLoreは、最適化器のメモリをさらに最大82.5%、BF16ベースラインと比較して学習メモリ全体を63.3%削減します。特筆すべきは、モデル並列化、チェックポイント、オフロード戦略を使用せずに、24GBメモリ(NVIDIA RTX 4090など)を搭載した民生用GPUで7Bモデルを事前学習できることを初めて実証したことです。

arxiv.org

環境環境

  • L4 GPU
  • ubuntu22.04

準備

まずはcloneします

git clone https://github.com/jiaweizzhao/GaLore

環境を作成します

python -m venv venv
source venv/bin/activate

ライブラリを入れます

pip install -e .

学習

まずは、パラメータを設定するために GaLore/configs/llama_10m.json に0.01B用のconfig.jsonを作成します。以下のパラメータは 0.02B及び0.04Bを参考にして0.01B用に書き換えています

{
    "architectures": [
        "LLaMAForCausalLM"
    ],
    "bos_token_id": 0,
    "eos_token_id": 1,
    "hidden_act": "silu",
    "hidden_size": 128,
    "intermediate_size": 344,
    "initializer_range": 0.02,
    "max_sequence_length": 1024,
    "model_type": "llama",
    "num_attention_heads": 2,
    "num_hidden_layers": 2,
    "pad_token_id": -1,
    "rms_norm_eps": 1e-06,
    "transformers_version": "4.28.1",
    "use_cache": true,
    "vocab_size": 32000
}

以下で学習をします。この時になるべく24GBのVRAMを使えるようにbatchサイズを調節しています(使用は13GBです)

torchrun --standalone --nproc_per_node 1 torchrun_main.py \
    --model_config configs/llama_10m.json \
    --lr 0.005 \
    --galore_scale 0.25 \
    --rank 1024 \
    --update_proj_gap 500 \
    --batch_size 256 \
    --total_batch_size 4096 \
    --activation_checkpointing \
    --num_training_steps 10000 \
    --warmup_steps 15000 \
    --weight_decay 0 \
    --grad_clipping 1.0 \
    --dtype bfloat16 \
    --eval_every 1000 \
    --single_gpu \
    --optimizer galore_adamw8bit_per_laye