Stable Diffusion WebUI (on Colab) : 🤗 Diffusers による LoRA 訓練 (ブログ)
作成 : Masashi Okumura (@ClassCat)
作成日時 : 04/12/2023
* サンプルコードの動作確認はしておりますが、動作環境の違いやアップグレード等によりコードの修正が必要となるケースはあるかもしれません。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
Stable Diffusion WebUI (on Colab) : 🤗 Diffusers による LoRA 訓練
前回 は Stable Diffusion WebUI における LoRA の利用方法について簡単に説明しました。LoRA は軽量で訓練にかかる時間やリソースも小さくてすみますので、自前の画像で訓練することも容易です。
この記事では、Google Colab 上で LoRA を訓練する方法について説明します。Stable Diffusion WebUI 用の LoRA の訓練は Kohya S. 氏が作成されたスクリプトをベースに遂行することが多いのですが、ここでは (🤗 Diffusers のドキュメントを数多く扱ってきたので) 🤗 Diffusers を使用して訓練と簡単な動作確認を行ない、その上で生成されたチェックポイントを WebUI で利用してみます。
前提条件 :
◇ 🤗 Diffusers で訓練と動作確認を行なうためには、Diffusers の最小限の知識があれば十分です。以下は良い入門です :
また、本記事は以下の 🤗 Diffusers の LoRA ドキュメントを分かりやすく解説したものでもあります :
◇ 作成したチェックポイントを Stable Diffusion WebUI で試すには、別途 WebUI の動作環境が必要です。その方法については以下の 1 と 3 を参照してください :
- PyTorch 2.0 : Google Colab で Stable Diffusion WebUI 入門
- Stable Diffusion WebUI (on Colab) : HuggingFace モデル / VAE の導入
- Stable Diffusion WebUI (on Colab) : LoRA の利用
データセット
LoRA の訓練と動作確認を試すという目的では、学習させたいコンセプトの解像度 512×512 の画像が数枚あれば十分ですが、ここでは Kohya S. 氏が提供している LoRA学習用サンプルデータ を試用します。これは鳥獣戯画をモチーフにしたカエルの画像で訓練用に 15 枚と正則化用に 50 枚の画像が収められています。以下は訓練用画像の例です :
※ 正則化用画像は過学習を防ぐ目的で使用されますが、今回は単純化のために使用しません。
訓練用画像の配置
それでは Google Colab 環境で実作業に入ります。GPU を有効にしておいてください。Tesla T4 で十分です :
!nvidia-smi
Tue Apr 11 22:08:33 2023 +-----------------------------------------------------------------------------+ | NVIDIA-SMI 525.85.12 Driver Version: 525.85.12 CUDA Version: 12.0 | |-------------------------------+----------------------+----------------------+ | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC | | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. | | | | MIG M. | |===============================+======================+======================| | 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 | | N/A 46C P8 12W / 70W | 3MiB / 15360MiB | 0% Default | | | | N/A | +-------------------------------+----------------------+----------------------+ +-----------------------------------------------------------------------------+ | Processes: | | GPU GI CI PID Type Process name GPU Memory | | ID ID Usage | |=============================================================================| | No running processes found | +-----------------------------------------------------------------------------+
訓練用画像や訓練による様々な生成ファイルを収める、作業用ディレクトリを /content/lora とします。訓練用画像用のディレクトリ /content/lora/train_data を作成した上で :
!mkdir -p /content/lora/train_data
Google ドライブ等を介して上記の訓練用画像、つまり (lora_train_sample_pack.zip を解凍して得られる) ‘train/20_usu frog’ にある png ファイルを次のように配置しましょう :
!ls -l /content/lora/train_data
total 844 -rw-r--r-- 1 root root 46780 Sep 26 2022 _c_choju2_0011_s1024_choju2_0011_6.png -rw-r--r-- 1 root root 58232 Sep 26 2022 '[_c_]chojuganso0001_s1024_chojuganso0001_0.png' -rw-r--r-- 1 root root 58188 Sep 26 2022 _c_chojuganso0002_s1024_chojuganso0002_0.png -rw-r--r-- 1 root root 49067 Sep 26 2022 _c_chojuganso0003_s1024_chojuganso0003_0.png -rw-r--r-- 1 root root 63014 Sep 26 2022 _c_chojuganso0005_s1024_chojuganso0005_0.png -rw-r--r-- 1 root root 70118 Sep 26 2022 _c_chojuganso0009_s1024_chojuganso0009_0.png -rw-r--r-- 1 root root 57497 Sep 26 2022 _c_chojuori0007_s1024_chojuori0007_0.png -rw-r--r-- 1 root root 60775 Sep 26 2022 _c_chojuori0015_s1024_chojuori0015_0.png -rw-r--r-- 1 root root 57468 Sep 26 2022 _c_chojuori0018_s1024_chojuori0018_0.png -rw-r--r-- 1 root root 60281 Sep 26 2022 _c_chojuori0021_s1024_chojuori0021_0.png -rw-r--r-- 1 root root 55912 Sep 26 2022 _c_chojuori0023_s1024_chojuori0023_1.png -rw-r--r-- 1 root root 43152 Sep 26 2022 _c_chojuori0024_s1024_chojuori0024_1.png -rw-r--r-- 1 root root 52116 Sep 26 2022 _c_chojuori0036_s1024_chojuori0036_0.png -rw-r--r-- 1 root root 42102 Sep 26 2022 _c_chojuori0037_s1024_chojuori0037_0.png -rw-r--r-- 1 root root 55867 Sep 26 2022 _c_chojuori0038_s1024_chojuori0038_0.png
環境構築
最新安定版 diffusers v0.14.0 で環境構築します :
!pip install diffusers==0.14.0 transformers accelerate safetensors
diffusers の examples にあるスクリプト diffusers/examples/dreambooth/train_dreambooth_lora.py を利用して訓練しますので、レポジトリも “git clone” しておきます。”v0.14.0″ のタグ指定を忘れずに :
!git clone https://github.com/huggingface/diffusers.git -b v0.14.0
Accelerate による訓練
🤗 Accelerate で train_dreambooth_lora.py を起動して LoRA を訓練します。
最初に accelerate 用の設定ファイルを作成する必要がありますが、それは以下で lora ディレクトリ内に自動生成できます :
from accelerate.utils import write_basic_config
write_basic_config(mixed_precision="fp16", save_location="/content/lora/default_config.yaml")
PosixPath('/content/lora/default_config.yaml')
一応、内容を確認しておきましょう :
!cat /content/lora/default_config.yaml
{ "compute_environment": "LOCAL_MACHINE", "distributed_type": "NO", "downcast_bf16": false, "machine_rank": 0, "main_training_function": "main", "mixed_precision": "fp16", "num_machines": 1, "num_processes": 1, "rdzv_backend": "static", "same_network": false, "tpu_use_cluster": false, "tpu_use_sudo": false, "use_cpu": false }
訓練開始
これで準備ができましたので、訓練を開始しましょう。幾つかの注意点は :
- 事前訓練済みモデルは “runwayml/stable-diffusion-v1-5” を指定していますが、好みのモデルが利用できます。
- プロンプト用のキーワードは “gigafrog” (戯画カエル) としましたが、これも何でも良いです。
Tesla T4 でおよそ 10 分ほどで訓練が完了します :
%%time
!accelerate launch --config_file="/content/lora/default_config.yaml" \
/content/diffusers/examples/dreambooth/train_dreambooth_lora.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
--instance_data_dir="/content/lora/train_data" \
--output_dir="/content/lora/output" \
--instance_prompt="gigafrog" \
--resolution=512 \
--train_batch_size=1 \
--sample_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=200 \
--learning_rate=1e-4 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=1000 \
--seed="0"
訓練が完了すれば、”lora/output” ディレクトリにチェック・ポイントやログファイルが生成されています :
!ls -lF /content/lora/output
total 3236 drwxr-xr-x 2 root root 4096 Apr 11 22:04 checkpoint-1000/ drwxr-xr-x 2 root root 4096 Apr 11 21:58 checkpoint-200/ drwxr-xr-x 2 root root 4096 Apr 11 21:59 checkpoint-400/ drwxr-xr-x 2 root root 4096 Apr 11 22:01 checkpoint-600/ drwxr-xr-x 2 root root 4096 Apr 11 22:02 checkpoint-800/ drwxr-xr-x 3 root root 4096 Apr 11 21:57 logs/ -rw-r--r-- 1 root root 3287771 Apr 11 22:04 pytorch_lora_weights.bin
Diffusers で動作確認
作成されたチェック・ポイントを Diffusers で確認するのは簡単です。
まずは LoRA を使用しないで、指定したプロンプト “gigafrog” で画像生成を試してみましょう。この場合にはもちろん鳥獣戯画風味ではないカエルの画像が生成されます。
スケジューラは UniPCMultistepScheduler に変更しています :
import torch
from diffusers import StableDiffusionPipeline, UniPCMultistepScheduler
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
torch_dtype=torch.float16,
).to('cuda')
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
from PIL import Image
def image_grid(imgs, rows, cols):
assert len(imgs) == rows * cols
w, h = imgs[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
grid_w, grid_h = grid.size
for i, img in enumerate(imgs):
grid.paste(img, box=(i % cols * w, i // cols * h))
return grid
generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]
images = pipe(prompt="gigafrog, masterpiece, best quality",
negative_prompt="worst quality, low quality",
generator=generator,
num_inference_steps=20,
height=512,
width=512,
num_images_per_prompt=4,
guidance_scale=7.0).images
image_grid(images, 2, 2)
次に LoRA のチェック・ポイントを取り込んだ上で、同様に画像生成してみます :
pipe.unet.load_attn_procs("/content/lora/output/pytorch_lora_weights.bin")
generator = [torch.Generator(device="cpu").manual_seed(i) for i in range(4)]
images = pipe(prompt="gigafrog, masterpiece, best quality",
negative_prompt="worst quality, low quality",
generator=generator,
num_inference_steps=20,
height=512,
width=512,
num_images_per_prompt=4,
guidance_scale=7.0).images
image_grid(images, 2, 2)
Looks Good ! 😎
Stable Diffusion WebUI で動作確認
Diffusers で LoRA を訓練して動作確認するのは簡単なのですが、Stable Diffusion WebUI にそのまま持ち込もうとするとチェックポイントの互換性の問題が出てしまいますので、少し工夫が必要です。
例えば以下のように変換してやれば上手くいきます :
import os;
import re;
import torch;
from safetensors.torch import save_file;
loraName = "gigafrog"
lora_output_dir = '/content/lora/output'
for root, dirs, files in os.walk(lora_output_dir):
for dir in dirs:
ckptIndex = re.search('^checkpoint\-(\d+)$', dir);
if ckptIndex:
newDict = dict();
checkpoint = torch.load(os.path.join(lora_output_dir, dir, 'custom_checkpoint_0.pkl'));
for idx, key in enumerate(checkpoint):
newKey = re.sub('\.processor\.', '_', key);
newKey = re.sub('mid_block\.', 'mid_block_', newKey);
newKey = re.sub('_lora.up.', '.lora_up.', newKey);
newKey = re.sub('_lora.down.', '.lora_down.', newKey);
newKey = re.sub('\.(\d+)\.', '_\\1_', newKey);
newKey = re.sub('to_out', 'to_out_0', newKey);
newKey = 'lora_unet_'+newKey;
newDict[newKey] = checkpoint[key];
newLoraName = lora_output_dir + '/' + loraName + '-' + ckptIndex.group(1) + '.safetensors';
print("Saving " + newLoraName);
save_file(newDict, newLoraName);
Saving /content/lora/output/gigafrog-1000.safetensors Saving /content/lora/output/gigafrog-200.safetensors Saving /content/lora/output/gigafrog-400.safetensors Saving /content/lora/output/gigafrog-600.safetensors Saving /content/lora/output/gigafrog-800.safetensors
このコードは diffusers の以下の issue を参考にしました :
lora/output ディレクトリにそれぞれのチェックポイントに対応した safetensors ファイルが生成されています :
!ls -lF /content/lora/output
total 18996 drwxr-xr-x 2 root root 4096 Apr 11 22:04 checkpoint-1000/ drwxr-xr-x 2 root root 4096 Apr 11 21:58 checkpoint-200/ drwxr-xr-x 2 root root 4096 Apr 11 21:59 checkpoint-400/ drwxr-xr-x 2 root root 4096 Apr 11 22:01 checkpoint-600/ drwxr-xr-x 2 root root 4096 Apr 11 22:02 checkpoint-800/ -rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-1000.safetensors -rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-200.safetensors -rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-400.safetensors -rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-600.safetensors -rw-r--r-- 1 root root 3227438 Apr 11 22:18 gigafrog-800.safetensors drwxr-xr-x 3 root root 4096 Apr 11 21:57 logs/ -rw-r--r-- 1 root root 3287771 Apr 11 22:04 pytorch_lora_weights.bin
生成された gigafrog-1000.safetensors ファイルを WebUI の models/Lora ディレクトリに配備すれば利用可能となります :
以上