HuggingFace Diffusers 0.12 : 訓練 : Stable Diffusion テキスト-to-画像変換再調整 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/19/2023 (v0.12.1)
* 本ページは、HuggingFace Diffusers の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
HuggingFace Diffusers 0.12 : 訓練 : Stable Diffusion テキスト-to-画像変換再調整
train_text_to_image.py スクリプトは貴方自身のデータセットで stable diffusion モデルを再調整する方法を示します。
The text-to-image fine-tuning script is experimental. It’s easy to overfit and run into issues like catastrophic forgetting. We recommend to explore different hyperparameters to get the best results on your dataset.
ローカルで実行する
依存関係をインストールする
スクリプトを実行する前に、ライブラリの訓練依存関係を確実にインストールしてください :
pip install git+https://github.com/huggingface/diffusers.git
pip install -U -r requirements.txt
そして以下で 🤗 Accelerate 環境を初期化します :
accelerate config
重みをダウンロード、あるいは使用する前にモデルライセンスを承認する必要があります。この例ではモデルバージョン v1-4 を使用していますので、そのカード にアクセスしてライセンスを読んで (同意するならば) チェックボックスをチェックする必要があります。
貴方は 🤗 Hugging Face の登録ユーザである必要があり、コードが動作するためにはアクセストークンを使用する必要もあります。アクセストークンの詳細は、ドキュメントのこのセクション を参照してください。
トークンを認証するには以下のコマンドを実行します :
huggingface-cli login
既にレポジトリを複製したのであれば、これらのステップに進む必要はありません。代わりに、ローカルのチェックアウトへのパスを訓練スクリプトに渡すことができて、それはそこからロードされます。
再調整のためのハードウェア要件
gradient_checkpointing と mixed_precision を使用すると、モデルを単一 24GB GPU 上で再調整することができるはずです。より大きい batch_size と高速な訓練のためには、30GB 以上の GPU メモリを持つ GPU を利用するのが良いです。TPU や GPU で再調整するために JAX / Flax を使用することもできます、詳細は 以下 を見てください。
再調整サンプル
以下のスクリプトは Hugging Face ハブで利用可能な Justin Pinkneys のキャプション付き Pokemon データセット を使用して再調整の実行を起動します。
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
貴方自身の訓練ファイルで実行するには、データセットにより要求される形式に従ってデータセットを準備する必要があります。データセットをハブにアップロードすることもできますし、ファイルを含むローカルフォルダを準備することもできます。このドキュメント はそれを行う方法を説明しています。
カスタム・ローディング・ロジックを使用したい場合にはスクリプトを変更する必要があります。コードの適切な場所にポインタを残しました 🙂
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"
export OUTPUT_DIR="path_to_save_model"
accelerate launch train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$TRAIN_DIR \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--mixed_precision="fp16" \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR}
訓練が終了すれば、モデルはコマンドで指定された OUTPUT_DIR にセーブされます。推論のために再調整済みモデルをロードするには、パスを StableDiffusionPipeline に渡すだけです :
from diffusers import StableDiffusionPipeline
model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")
image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
Flax / JAX 再調整
@duongna211 のおかげで、Flax を使用して Stable Diffusion を再調整することができます!これは TPU ハードウェア上で非常に効率的ですが、GPU でも素晴らしく動作します。このように Flax 訓練スクリプト を使用できます :
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export dataset_name="lambdalabs/pokemon-blip-captions"
python train_text_to_image_flax.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$dataset_name \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--output_dir="sd-pokemon-model"
以上