FLUX.1 : AI Toolkit による FLUX.1 LoRA トレーニング (2) RunPod で実践

AI Toolkit by Ostris の基本を抑えたので、今回は RunPod で LoRA トレーニングを実践してみます。
Google Colab でもまったく同様の手順でトレーニングできますが、A100 GPU でないと動作しないのが難点です。

FLUX.1 : AI Toolkit による FLUX.1 LoRA トレーニング (2) RunPod で実践

作成 : Masashi Okumura (@ClassCat)
作成日時 : 08/29/2024

* 本記事は github ostris/ai-toolkit の以下のページを参考にしています :

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

 

FLUX.1 : AI Toolkit による FLUX.1 LoRA トレーニング (2) RunPod で実践

前回 AI Toolkit by Ostris の基本 を抑えたので、今回は RunPod で LoRA トレーニングを実践してみます。

Google Colab でもまったく同様の手順でトレーニングできますが、A100 GPU でないと動作しないのが難点です。

 

RunPod

AI Toolkit by Ostris の README で推奨されている、ポッド仕様の例は以下です :

  • テンプレート: runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04

  • 1x A40 (48 GB VRAM)
  • 19 vCPU 100 GB RAM
  • ~120 GB Disk
  • ~120 GB Pod Volume

  • Start Jupyter Notebook

GPU は最低 24GB VRAM あれば何でも良いです。
ディスク・ボリュームは単にトレーニングして推論するだけなら 120GB も必要ないです

 

トレーニングデータ

トレーニングデータは、定番ですが、東北ずんこ・ずんだもんプロジェクトAI画像モデル用学習データ を使用させていただきました。

記載されている使用規約は :

  • 公式が配布する 東北3姉妹のイラストは非商用、もしくは東北企業であれば自由に使える そうです。
  • また、AI画像モデル用学習データについては :

    ずんずんPJ公式イラストを学習したデータのLoRAの配布はなどは大丈夫です(2次創作のデータが入る場合は2次創作者への配慮[2次創作者のガイドライン確認など]が別途必要になります)

 

実践

上記の仕様でポッドを起動し、Jupyter ノートブックでアクセスします。

 

ポッドの確認

最初にポッドの仕様を簡単に確認しましょう :

!nvidia-smi
Tue Aug 27 09:44:14 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.161.08             Driver Version: 535.161.08   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A40                     On  | 00000000:53:00.0 Off |                    0 |
|  0%   29C    P8              21W / 300W |      0MiB / 46068MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|  No running processes found                                                           |
+---------------------------------------------------------------------------------------+
!free -h
               total        used        free      shared  buff/cache   available
Mem:           503Gi        73Gi        28Gi       5.2Gi       401Gi       421Gi
Swap:             0B          0B          0B

 

トレーニングデータの準備

次にトレーニングデータを準備します。前述の「東北ずんこ・ずんだもんプロジェクト」の AI 画像モデル用学習データから 東北イタコ のデータを使用しましたが、もちろんお好みでどうぞ。

zip ファイルをアップロードして解凍します :

!unzip -q /workspace/itako.zip

 

環境構築

AI Toolkit by Ostris のレポジトリを複製して、インストールします :

!git clone https://github.com/ostris/ai-toolkit.git
%cd /workspace/ai-toolkit
!git submodule update --init --recursive
!pip install -U pip
!pip install -r requirements.txt

config ファイルはサンプルの train_lora_flux_24gb.yaml をコピーして使用します :

!cp -p /workspace/ai-toolkit/config/examples/train_lora_flux_24gb.yaml \
  /workspace/ai-toolkit/config/itako.yaml

書き換えが必須なのは以下の 2 箇所だけです :

  • トレーニングデータのパス
          datasets:
            # datasets are a folder of images. captions need to be txt files with the same name as the image
            # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
            # images will automatically be resized and bucketed into the resolution specified
            # on windows, escape back slashes with another backslash so
            # "C:\\path\\to\\images\\folder"
            - folder_path: "/path/to/images/folder"
    

     

  • トレーニング中にサンプル画像を生成するためのプロンプト
          sample:
            sampler: "flowmatch" # must match train.noise_scheduler
            sample_every: 250 # sample every this many steps
            width: 1024
            height: 1024
            prompts:
              # you can add [trigger] to the prompts here and it will be replaced with the trigger word
    #          - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
              - "woman with red hair, playing chess at the park, bomb going off in the background"
              - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
              - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
              - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
              - "a bear building a log cabin in the snow covered mountains"
              - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
              - "hipster man with a beard, building a chair, in a wood shop"
              - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
              - "a man holding a sign that says, 'this is a sign'"
              - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
    

それから、Hugging Face のアクセストークンを含む .env ファイルをルートに配置しておくと、”black-forest-labs/FLUX.1-dev” を自動的にダウンロードしてくれます :

import os
hf_token = os.getenv('HF_TOKEN')

with open('/workspace/ai-toolkit/.env', 'w') as f:
    f.write(f"HF_TOKEN={hf_token}")

※ HF_TOKEN はポッド起動時に設定しておいた環境変数です。

 

トレーニング

それではトレーニングしてみましょう。デフォルトの 2,000 ステップ実行しましたが、1,000 〜 1,500 ステップくらいでも十分です :

%%time

%cd /workspace/ai-toolkit

!python run.py /workspace/ai-toolkit/config/itako.yaml

 
0 ステップ

 
500 ステップ

 
1,000 ステップ

 
2,000 ステップ

 

推論

トレーニングが完了したら、サンプリングしてみます :

import torch
from diffusers import  FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    token=hf_token,
)
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("/workspace/ai-toolkit/output/my_first_flux_lora_v1/my_first_flux_lora_v1.safetensors")
%%time

prompt = "itako, 1girl, solo, kimono, hagoromo, tabi, animal ear fluff, hair ornament, dancing"

out = pipe(
    prompt=prompt,
    guidance_scale=3.5,
    width=1024,
    height=1024,
    num_inference_steps=40,
).images[0]

out.save("/workspace/image.png")

 

以上