SDXL 用 LoRA 訓練 (on Google Colab) (ブログ)
作成 : Masashi Okumura (@ClassCat)
作成日時 : 10/21/2023
* サンプルコードの動作確認はしておりますが、動作環境によりコードの追加変更が必要な場合はあるかもしれません。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
SDXL 用 LoRA 訓練 (on Google Colab)
Stable Diffusion のポピュラーな訓練スクリプト・ツール kohya-ss/sd-scripts が SDXL をサポートしました。Google Colab 上で動作確認できましたので、具体的な手順を簡単に記述しておきます。
注意点としては準備する訓練データ画像の解像度は 1024×1024 が軸になりますので、GPU メモリを大きく消費することです。取り敢えず動作確認をするだけであれば、解像度を下げれば Tesla T4 でもギリギリ動作しますが、やはり A100 の利用が望ましいです。
1. 概要
SD 1.5 で kohya-ss/sd-scripts による LoRA 訓練の経験があれば、基本的な流れに違いはありません。スクリプト “sdxl_train_network.py” を使用すれば良いです。
2. kohya-ss/sd-scripts
最初に kohya-ss/sd-scripts を “git clone” して依存関係をインストールしましょう :
!git clone https://github.com/kohya-ss/sd-scripts
%cd /content/sd-scripts
!pip install --upgrade -q -r requirements.txt
!pip install -q xformers==0.0.22 triton
!pip install -q bitsandbytes
3. データセット
データセットは通常通りに (画像ファイル、キャプションファイル) のペアを用意します。またデータセットについて記述した toml 設定ファイルも用意します。
ここでは以下のようなディレクトリ構成を想定します :
/content/LoRA
├── config
│ └── dataset_config.toml
└── train_data
├── 0001.png
├── 0001.txt
├── 0xxx.png
└── 0xxx.txt
toml 設定ファイルの内容は以下のようなものです :
[general]
shuffle_caption = true
enable_bucket = true
caption_extension = ".txt"
[[datasets]]
resolution = 1024
min_bucket_reso = 768
max_bucket_reso = 1280
[[datasets.subsets]]
image_dir = "/content/LoRA/train_data"
class_tokens = "mysample"
num_repeats = 1
Hugging Face Accelerate
Accelerate 用の設定ファイルも /content/LoRA/config に配置することにします :
import os
accelerate_config = "/content/LoRA/config/accelerate_config.yaml"
from accelerate.utils import write_basic_config
if not os.path.exists(accelerate_config):
write_basic_config(save_location=accelerate_config)
/content/LoRA
├── config
│ ├── accelerate_config.yaml
│ └── dataset_config.toml
└── train_data
├── 0001.png
├── 0001.txt
├── 0002.png
├── 0002.txt
...
SDXL の取得配置
SDXL は /content/pretrained_model に配置しました :
!mkdir -p /content/pretrained_model
!wget https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors \
-O /content/pretrained_model/sd_xl_base_1.0.safetensors
訓練の実行
必要最小限の引数とともに訓練を実行します。出力先は /content/LoRA/output です。
バッチサイズは 1 にしてありますが、GPU メモリに応じて増やすことができます。
%%time
%cd /content/sd-scripts
!accelerate launch \
--config_file=/content/LoRA/config/accelerate_config.yaml \
sdxl_train_network.py \
--dataset_config=/content/LoRA/config/dataset_config.toml \
--pretrained_model_name_or_path=/content/pretrained_model/sd_xl_base_1.0.safetensors \
--network_module=networks.lora \
--mixed_precision="fp16" \
--network_dim=8 --network_alpha=4 \
--learning_rate=1e-4 --optimizer_type="AdamW8bit" \
--output_dir=/content/LoRA/output --output_name="mysample" \
--save_precision="fp16" --save_every_n_epochs=1 --save_model_as=safetensors \
--max_train_epochs=5 --train_batch_size=1 \
--xformers --no_half_vae
最終的なディレクトリ構成は以下のようなものです :
/content/LoRA
├── config
│ ├── accelerate_config.yaml
│ └── dataset_config.toml
├── output
│ ├── mysample-000001.safetensors
│ ├── mysample-000002.safetensors
│ ├── mysample-000003.safetensors
│ ├── mysample-000004.safetensors
│ └── mysample.safetensors
└── train_data
├── 0001.png
├── 0001.txt
├── 0002.png
├── 0002.txt
├── 0100.png
└── 0100.txt
/content/pretrained_model
└── sd_xl_base_1.0.safetensors
/content/sd-scripts
以上