Stable Diffusion WebUI (on Colab) : 🤗 Diffusers による LoRA 訓練

Stable Diffusion WebUI (on Colab) : 🤗 Diffusers による LoRA 訓練 (ブログ)

作成 : Masashi Okumura (@ClassCat)
作成日時 : 04/12/2023

* サンプルコードの動作確認はしておりますが、動作環境の違いやアップグレード等によりコードの修正が必要となるケースはあるかもしれません。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Website: www.classcat.com  ;   ClassCatJP

 

 

Stable Diffusion WebUI (on Colab) : 🤗 Diffusers による LoRA 訓練

前回 は Stable Diffusion WebUI における LoRA の利用方法について簡単に説明しました。LoRA は軽量で訓練にかかる時間やリソースも小さくてすみますので、自前の画像で訓練することも容易です。

この記事では、Google Colab 上で LoRA を訓練する方法について説明します。Stable Diffusion WebUI 用の LoRA の訓練は Kohya S. 氏が作成されたスクリプトをベースに遂行することが多いのですが、ここでは (🤗 Diffusers のドキュメントを数多く扱ってきたので) 🤗 Diffusers を使用して訓練と簡単な動作確認を行ない、その上で生成されたチェックポイントを WebUI で利用してみます。

 
前提条件 :

◇ 🤗 Diffusers で訓練と動作確認を行なうためには、Diffusers の最小限の知識があれば十分です。以下は良い入門です :

また、本記事は以下の 🤗 Diffusers の LoRA ドキュメントを分かりやすく解説したものでもあります :

◇ 作成したチェックポイントを Stable Diffusion WebUI で試すには、別途 WebUI の動作環境が必要です。その方法については以下の 1 と 3 を参照してください :

  1. PyTorch 2.0 : Google Colab で Stable Diffusion WebUI 入門
  2. Stable Diffusion WebUI (on Colab) : HuggingFace モデル / VAE の導入
  3. Stable Diffusion WebUI (on Colab) : LoRA の利用

 

データセット

LoRA の訓練と動作確認を試すという目的では、学習させたいコンセプトの解像度 512×512 の画像が数枚あれば十分ですが、ここでは Kohya S. 氏が提供している LoRA学習用サンプルデータ を試用します。これは鳥獣戯画をモチーフにしたカエルの画像で訓練用に 15 枚と正則化用に 50 枚の画像が収められています。以下は訓練用画像の例です :
※ 正則化用画像は過学習を防ぐ目的で使用されますが、今回は単純化のために使用しません。

 

訓練用画像の配置

それでは Google Colab 環境で実作業に入ります。GPU を有効にしておいてください。Tesla T4 で十分です :

!nvidia-smi
Tue Apr 11 22:08:33 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   46C    P8    12W /  70W |      3MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

訓練用画像や訓練による様々な生成ファイルを収める、作業用ディレクトリを /content/lora とします。訓練用画像用のディレクトリ /content/lora/train_data を作成した上で :

!mkdir -p /content/lora/train_data

Google ドライブ等を介して上記の訓練用画像、つまり (lora_train_sample_pack.zip を解凍して得られる) ‘train/20_usu frog’ にある png ファイルを次のように配置しましょう :

!ls -l /content/lora/train_data
total 844
-rw-r--r-- 1 root root 46780 Sep 26  2022  _c_choju2_0011_s1024_choju2_0011_6.png
-rw-r--r-- 1 root root 58232 Sep 26  2022 '[_c_]chojuganso0001_s1024_chojuganso0001_0.png'
-rw-r--r-- 1 root root 58188 Sep 26  2022  _c_chojuganso0002_s1024_chojuganso0002_0.png
-rw-r--r-- 1 root root 49067 Sep 26  2022  _c_chojuganso0003_s1024_chojuganso0003_0.png
-rw-r--r-- 1 root root 63014 Sep 26  2022  _c_chojuganso0005_s1024_chojuganso0005_0.png
-rw-r--r-- 1 root root 70118 Sep 26  2022  _c_chojuganso0009_s1024_chojuganso0009_0.png
-rw-r--r-- 1 root root 57497 Sep 26  2022  _c_chojuori0007_s1024_chojuori0007_0.png
-rw-r--r-- 1 root root 60775 Sep 26  2022  _c_chojuori0015_s1024_chojuori0015_0.png
-rw-r--r-- 1 root root 57468 Sep 26  2022  _c_chojuori0018_s1024_chojuori0018_0.png
-rw-r--r-- 1 root root 60281 Sep 26  2022  _c_chojuori0021_s1024_chojuori0021_0.png
-rw-r--r-- 1 root root 55912 Sep 26  2022  _c_chojuori0023_s1024_chojuori0023_1.png
-rw-r--r-- 1 root root 43152 Sep 26  2022  _c_chojuori0024_s1024_chojuori0024_1.png
-rw-r--r-- 1 root root 52116 Sep 26  2022  _c_chojuori0036_s1024_chojuori0036_0.png
-rw-r--r-- 1 root root 42102 Sep 26  2022  _c_chojuori0037_s1024_chojuori0037_0.png
-rw-r--r-- 1 root root 55867 Sep 26  2022  _c_chojuori0038_s1024_chojuori0038_0.png

 

環境構築

最新安定版 diffusers v0.14.0 で環境構築します :

!pip install diffusers==0.14.0 transformers accelerate safetensors

diffusers の examples にあるスクリプト diffusers/examples/dreambooth/train_dreambooth_lora.py を利用して訓練しますので、レポジトリも “git clone” しておきます。”v0.14.0″ のタグ指定を忘れずに :

!git clone https://github.com/huggingface/diffusers.git -b v0.14.0

 

Accelerate による訓練

🤗 Accelerate で train_dreambooth_lora.py を起動して LoRA を訓練します。

最初に accelerate 用の設定ファイルを作成する必要がありますが、それは以下で lora ディレクトリ内に自動生成できます :

from accelerate.utils import write_basic_config

write_basic_config(mixed_precision="fp16", save_location="/content/lora/default_config.yaml")
PosixPath('/content/lora/default_config.yaml')

一応、内容を確認しておきましょう :

!cat /content/lora/default_config.yaml
{
  "compute_environment": "LOCAL_MACHINE",
  "distributed_type": "NO",
  "downcast_bf16": false,
  "machine_rank": 0,
  "main_training_function": "main",
  "mixed_precision": "fp16",
  "num_machines": 1,
  "num_processes": 1,
  "rdzv_backend": "static",
  "same_network": false,
  "tpu_use_cluster": false,
  "tpu_use_sudo": false,
  "use_cpu": false
}

 

訓練開始

これで準備ができましたので、訓練を開始しましょう。幾つかの注意点は :

  • 事前訓練済みモデルは “runwayml/stable-diffusion-v1-5” を指定していますが、好みのモデルが利用できます。
  • プロンプト用のキーワードは “gigafrog” (戯画カエル) としましたが、これも何でも良いです。

Tesla T4 でおよそ 10 分ほどで訓練が完了します :

%%time

!accelerate launch --config_file="/content/lora/default_config.yaml" \
  /content/diffusers/examples/dreambooth/train_dreambooth_lora.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"  \
  --instance_data_dir="/content/lora/train_data" \
  --output_dir="/content/lora/output" \
  --instance_prompt="gigafrog" \
  --resolution=512 \
  --train_batch_size=1 \
  --sample_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=200 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=1000 \
  --seed="0"

訓練が完了すれば、”lora/output” ディレクトリにチェック・ポイントやログファイルが生成されています :

!ls -lF /content/lora/output
total 3236
drwxr-xr-x 2 root root    4096 Apr 11 22:04 checkpoint-1000/
drwxr-xr-x 2 root root    4096 Apr 11 21:58 checkpoint-200/
drwxr-xr-x 2 root root    4096 Apr 11 21:59 checkpoint-400/
drwxr-xr-x 2 root root    4096 Apr 11 22:01 checkpoint-600/
drwxr-xr-x 2 root root    4096 Apr 11 22:02 checkpoint-800/
drwxr-xr-x 3 root root    4096 Apr 11 21:57 logs/
-rw-r--r-- 1 root root 3287771 Apr 11 22:04 pytorch_lora_weights.bin

 

Diffusers で動作確認

作成されたチェック・ポイントを Diffusers で確認するのは簡単です。

まずは LoRA を使用しないで、指定したプロンプト “gigafrog” で画像生成を試してみましょう。この場合にはもちろん鳥獣戯画風味ではないカエルの画像が生成されます。

スケジューラは UniPCMultistepScheduler に変更しています :

import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
    ).to('cuda')

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
from PIL import Image

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]

images = pipe(prompt="gigafrog, masterpiece, best quality",
              negative_prompt="worst quality, low quality",
              generator=generator,
              num_inference_steps=20,
              height=512,
              width=512,
              num_images_per_prompt=4,
              guidance_scale=7.0).images

image_grid(images, 2, 2)

次に LoRA のチェック・ポイントを取り込んだ上で、同様に画像生成してみます :

pipe.unet.load_attn_procs("/content/lora/output/pytorch_lora_weights.bin")
generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]

images = pipe(prompt="gigafrog, masterpiece, best quality",
              negative_prompt="worst quality, low quality",
              generator=generator,
              num_inference_steps=20,
              height=512,
              width=512,
              num_images_per_prompt=4,
              guidance_scale=7.0).images

image_grid(images, 2, 2)

Looks Good ! 😎

 

Stable Diffusion WebUI で動作確認

Diffusers で LoRA を訓練して動作確認するのは簡単なのですが、Stable Diffusion WebUI にそのまま持ち込もうとするとチェックポイントの互換性の問題が出てしまいますので、少し工夫が必要です。

例えば以下のように変換してやれば上手くいきます :

import os;
import re;
import torch;
from safetensors.torch import save_file;

loraName = "gigafrog"

lora_output_dir = '/content/lora/output'

for root, dirs, files in os.walk(lora_output_dir):
  for dir in dirs:
    ckptIndex = re.search('^checkpoint\-(\d+)$', dir);
 
    if ckptIndex:
      newDict = dict();
      checkpoint = torch.load(os.path.join(lora_output_dir, dir, 'custom_checkpoint_0.pkl'));
      for idx, key in enumerate(checkpoint):

        newKey = re.sub('\.processor\.', '_', key);
        newKey = re.sub('mid_block\.', 'mid_block_', newKey);
        newKey = re.sub('_lora.up.', '.lora_up.', newKey);
        newKey = re.sub('_lora.down.', '.lora_down.', newKey);
        newKey = re.sub('\.(\d+)\.', '_\\1_', newKey);
        newKey = re.sub('to_out', 'to_out_0', newKey);
        newKey = 'lora_unet_'+newKey;

        newDict[newKey] = checkpoint[key];

      newLoraName = lora_output_dir + '/' + loraName + '-' + ckptIndex.group(1) + '.safetensors';
      print("Saving " + newLoraName);
      save_file(newDict, newLoraName);
Saving /content/lora/output/gigafrog-1000.safetensors
Saving /content/lora/output/gigafrog-200.safetensors
Saving /content/lora/output/gigafrog-400.safetensors
Saving /content/lora/output/gigafrog-600.safetensors
Saving /content/lora/output/gigafrog-800.safetensors

このコードは diffusers の以下の issue を参考にしました :

lora/output ディレクトリにそれぞれのチェックポイントに対応した safetensors ファイルが生成されています :

!ls -lF /content/lora/output
total 18996
drwxr-xr-x 2 root root    4096 Apr 11 22:04 checkpoint-1000/
drwxr-xr-x 2 root root    4096 Apr 11 21:58 checkpoint-200/
drwxr-xr-x 2 root root    4096 Apr 11 21:59 checkpoint-400/
drwxr-xr-x 2 root root    4096 Apr 11 22:01 checkpoint-600/
drwxr-xr-x 2 root root    4096 Apr 11 22:02 checkpoint-800/
-rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-1000.safetensors
-rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-200.safetensors
-rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-400.safetensors
-rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-600.safetensors
-rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-800.safetensors
drwxr-xr-x 3 root root    4096 Apr 11 21:57 logs/
-rw-r--r-- 1 root root 3287771 Apr 11 22:04 pytorch_lora_weights.bin

 
生成された gigafrog-1000.safetensors ファイルを WebUI の models/Lora ディレクトリに配備すれば利用可能となります :

 

以上