HuggingFace ブログ : JAX / Flax で 🧨 Stable Diffusion !

HuggingFace ブログ : JAX / Flax で 🧨 Stable Diffusion ! (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/11/2022

* 本ページは、HuggingFace Blog の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

HuggingFace ブログ : JAX / Flax で 🧨 Stable Diffusion !

🤗 Hugging Face Diffusers はバージョン 0.5.1 から Flax をサポートします!これは Colab, Kaggle or Google Cloud Platform などで利用可能な Google TPU 上で超高速な推論を可能にします。

この記事は JAX / Flax を使用して推論を実行する方法を示します。Stable Diffusion がどのように動作するかの詳細を望むか、あるいはそれを GPU で実行したい場合には、この Colab ノートブック を参照してください。

まず、TPU バックエンドを使用していることを確認してください。このノートブックを Colab で実行している場合には、上のメニューで「ランタイム」を選択し、オプション「ランタイムのタイプを変更」を選択してから、「ハードウェア アクセラレータ」設定で TPU を選択します。

JAX は TPU 専用ではないことに注意してください、しかしそれはそのハードウェア上で輝きます、各 TPU サーバが並列に動作する 8 TPU アクセラレータを持つからです。

 

セットアップ

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"
    Found 8 JAX devices of type TPU v2.

diffusers がインストールされていることを確認します。

!pip install diffusers==0.5.1

そしてすべての依存関係をインポートします。

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

 

モデルのロード

モデルを使用する前に、重みをダウンロードして使用するためにモデルライセンスを承認する必要があります。

ライセンスは、そのようなパワフルな機械学習システムの潜在的な弊害を軽減するように設計されています。私たちはユーザがライセンス全体を注意深く読むことを要求します。ここに要約を提供します :

  1. 違法あるいは有害な出力やコンテンツを意図的に生成または共有するためにモデルを使用できません、

  2. 私たちは貴方が生成した出力の権利を主張しません、貴方はそれらを自由に使用できて、そしてライセンスの条件に反するべきではないそれらの使用について責任を負います、そして

  3. 重みを再配布してモデルを商用 and/or サービスとして利用しても良いです。それを行なう場合、ライセンスのものと同じ使用制限を含めて、そして CreativeML OpenRAIL-M のコピーをすべてのユーザと共有しなければならないことに留意してください。

Flax の重みは Stable Diffusion レポジトリの一部として Hugging Face ハブで利用可能です。Stable Diffusion モデルは CreateML OpenRail-M ライセンスのもとで配布されています。それはオープンライセンスで、貴方が生成した出力の権利は主張せず、違法 or 有害なコンテンツを意図的に生成することを禁止するものです。モデルカード は詳細を提供していますので、それを読んでライセンスを承認するか注意深く考える時間を少し取ってください。それを行なう場合、コードを動作させるには貴方はハブの登録ユーザであり、アクセストークンを使用する必要があります。貴方のアクセストークンを提供する 2 つのオプションがあります :

  1. ターミナルで huggingface-cli login コマンドラインツールを使用してプロンプトに従って貴方のトークンをペーストします。それは貴方のコンピュータのファイル內でセーブされます。

  2. または、ノートブックで notebook_login() を使用します、これも同じことを行います。

このコンピュータで前に既に認証していない限りは、次のセルは login インターフェイスを提示します、貴方のアクセストークンをペーストする必要があります。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPU デバイスは bfloat16 をサポートします、効率的な半精度浮動小数点型です。それをテストのために使用しますが、代わりに完全精度を使用するために float32 も使用できます。

dtype = jnp.bfloat16

Flax は関数型フレームワークですので、モデルはステートレスでパラメータはそれらの外側でストアされます。事前訓練済み Flax パイプラインのロードはパイプライン自体とモデル重み (or パラメータ) の両者を返します。重みの bf16 バージョンを使用しています、これは型の警告に繋がりますが、安全に無視できます。

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

 

推論

TPU は通常は並列に動作する 8 デバイスを持ちますので、デバイスの数だけプロンプトを複製します。それから 8 デバイス上で同時に推論を実行します、各々が一つの画像を生成することを担います。このように、単一の画像を生成する 1 チップのためにかかる時間と同じ量で 8 画像を取得できます。

プロンプトの複製後、パイプラインの prepare_inputs 関数を呼び出してトークン化されたテキストの id を取得します。トークン化テキストの長さは、基礎となる CLIP テキストモデルの設定に従って 77 トークンに設定されています、

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
    (8, 77)

 

レプリケーションと並列化

モデルパラメータと入力は 8 つの並列デバイスに渡り複製されなければなりません。パラメータ辞書は flax.jax_utils.replicate を使用して複製され、これは辞書を traverse して重みの shape を変更し、それらが 8 回繰り返されるようにします。配列は shard を使用して複製されます。

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
   (8, 1, 77)

その shape はつまり、8 デバイスの各々は入力として shape (1, 77) の jnp 配列を受け取るということです。従って 1 はデバイス毎のバッチサイズです。十分なメモリを持つ TPU では、(チップ毎に) 複数の画像を同時に生成したい場合、それは 1 より大きいかもしれません。

画像を生成する準備が殆どできました!生成関数に渡す乱数 generator を作成する必要があるだけです。これは Flax では標準的な手順で、これは乱数について非常に厳しく意固地なものです – 乱数を扱うすべての関数は generator を受け取ることが想定されています。これは、複数の分散デバイスに渡り訓練しているときでさえも、再現性を保証します。

下のヘルパー関数は乱数 generator を初期化するためにシードを使用します。同じシードを使用する限り、正確に同じ結果を取得します。後でノートブックで結果を調べるとき、自由に異なるシードを使用してください。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

rng を取得してからそれを 8 回「分割」して各デバイスが異なる generator を受け取るようにします。従って、各デバイスは異なる画像を作成して、全過程が再現可能です。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX コードは非常に高速に実行できる効率的な表現にコンパイルできます。けれども、続く呼び出しですべての入力が同じ shape を持つことを保証する必要があります ; そうでなければ、JAX はコードを再コンパイルする必要があり、最適化された速度を活用することができません。

Flax パイプラインは引数として jit = True 引数を渡す場合、コードをコンパイルすることができます。それはまたモデルが 8 つの利用可能なデバイスで並列に動作することも保証します。

以下のセルを最初に実行したときはコンパイルに長い時間かかりますが、続く呼び出しは (異なる入力でさえも) 遥かに高速です。例えば、私がテストしたとき TPU v2-8 でコンパイルするのに 1 分以上かかりましたが、その先の推論実行のためには 7s かかるだけです。

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
    CPU times: user 464 ms, sys: 105 ms, total: 569 ms
    Wall time: 7.07 s

返された配列は shape (8, 1, 512, 512, 3) を持ちます。2 番目の次元を取り除くためにそれを reshape して 512 × 512 × 3 の 8 画像を得てからそれらを PIL に変換します。

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

 

可視化

画像をグリッドで表示するヘルパー関数を作成しましょう。

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)

 

異なるプロンプトの使用

すべてのデバイスで同じプロンプトを複製する必要はありません。
望むことは何でもできます : 2 つのプロンプトを 4 回生成したり、一度に 8 つの異なるプロンプトを生成することさえ可能です。Let’s do that!

入力準備のコードを便利な関数にリファクタリングします :

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)

 

以上