HuggingFace ブログ : 注釈付き拡散モデル

HuggingFace ブログ : 注釈付き拡散モデル (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/06/2022

* 本ページは、HuggingFace Blog の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

HuggingFace ブログ : 注釈付き拡散モデル

このブログ記事では、ノイズ除去拡散確率モデル (DDPM, 拡散モデル, スコアベース生成モデル or 単純に オートエンコーダ としても知られています) を深く見ていきます、研究者は (非) 条件付き画像/音声/動画生成に対してそれらで顕著な結果を獲得することができています。(執筆時点の) ポピュラーな例は OpenAI による GLIDEDALL-E 2、ハイデルベルク大学による Latent Diffusion、そして Google Brain による ImageGen を含みます。

( Ho et al., 2020 ) による元の DDPM 論文を詳しく調べて、Phil Wang の 実装 に基づいて PyTorch で一歩ずつそれを実装します – それ自体は 元の TensorFlow 実装 に基づいています。生成モデリングのための拡散のアイデアは実際には既に ( Sohl-Dickstein et al., 2015 ) で導入されていることに注意してください。
けれどもそのアプローチを独立に改良したのは ( Song et al., 2019 ) (at Stanford University), それから ( Ho et al., 2020 ) (at Google Brain) までかかりました。

拡散モデルには様々な視点 (= perspectives) があることに注意してください。ここでは、離散時間 (潜在変数モデル) の視点を使用しますが、他の視点も必ず確認してください。

Alright, let’s dive in!

from IPython.display import Image
Image(filename='assets/78_annotated-diffusion/ddpm_paper.png')

まず必要なライブラリをインストールしてインポートします (PyTorch はインストールしているものとします)。

!pip install -q -U einops datasets matplotlib tqdm

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

 

What is a diffusion model?

(ノイズ除去) 拡散モデルは、それを正規化フロー, GAN や VAE のような他の生成モデルと比較した場合、それほど複雑ではありません : それらはすべて幾つかの単純な分布からのノイズをデータサンプルに変換します。これもまた、ニューラルネットワークが純粋なノイズから始めてデータを徐々にノイズ除去することを学習します。

画像についてのもう少しの詳細は、セットアップは 2 つのプロセスから成ります :

  • 私達の選択の、固定された (or 事前定義された) forword 拡散過程 $q$ は、純粋なノイズで終わるまで、画像に徐々にガウシアンノイズを追加します。

  • 学習されたreverse ノイズ除去拡散過程 $p_\theta$、そこでは純粋なノイズから始めて、実際の画像で終了するまで、ニューラルネットワークは画像を徐々にノイズ除去するように訓練されます。

$t$ でインデックスされる forward と reverse 過程の両者はある有限の時間ステップ $T$ に対して起きます (DDPM 著者は \(T=1000\) を使用)。\(t=0\) から始めてそこでは貴方のデータ分布から実画像 \(\mathbf{x}_0\) をサンプリングして (ImageNet からの猫の画像としましょう)、そして forward 過程は各時間ステップ \(t\) でガウス分布からあるノイズをサンプリングします、それが前の時間ステップの画像に追加されます。十分に大きい \(T\) と各時間ステップでノイズを追加する well-behaved なスケジュールが与えられたとき、漸進的な過程を通して \(t=T\) で 等方性 (= isotropic) ガウス分布 と呼ばれるもので終了します。

 

より数学的な形式で

これをより形式的に書き下しましょう、究極的には、ニューラルネットワークが最適化する必要がある、扱いやすい損失関数を必要とします。

\(q(\mathbf{x}_0)\) を例えば「実画像」の実データ分布とします。画像, \(\mathbf{x}_0 \sim q(\mathbf{x}_0)\) を得るためにこの分布からサンプリングできます。forward 拡散過程 \(q(\mathbf{x}_t | \mathbf{x}_{t-1})\) を以下のように定義します、これは既知の分散スケジュール \(0 < \beta_1 < \beta_2 < ... < \beta_T < 1\) に従って各時間ステップ \(t\) でガウスノイズを追加します : $$ q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I}). $$ 正規分布 (ガウス分布とも呼ばれます) は 2 つのパラメータ : 平均 \(\mu\) と分散 \(\sigma^2 \geq 0\) で定義されることを思い出してください。 基本的には、時間ステップ \(t\) の各新しい (僅かにノイズが多い) 画像は \(\mathbf{\mu}_t = \sqrt{1 - \beta_t} \mathbf{x}_{t-1}\) と \(\sigma^2_t = \beta_t\) による 条件付きガウス分布 からドローされ、これは \(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) をサンプリングしてから \(\mathbf{x}_t = \sqrt{1 – \beta_t} \mathbf{x}_{t-1} + \sqrt{\beta_t} \mathbf{\epsilon}\) を設定することにより行なうことができます。

\(\beta_t\) は各時間ステップ \(t\) で定数ではないことに注意してください (そのため添字) — 実際にいわゆる「分散スケジュール」を定義します、これは更に見るように線形, 二次, コサイン等である可能性があります。

よって \(\mathbf{x}_0\) から始めて、\(\mathbf{x}_1, …, \mathbf{x}_t, …, \mathbf{x}_T\) で終了します、ここで \(\mathbf{x}_T\) はスケジュールを適切に設定した場合純粋なガウスノイズです。

今、条件付き分布 \(p(\mathbf{x}_{t-1} | \mathbf{x}_t)\) を知っていれば、過程を逆に実行できるでしょう : あるランダムなガウスノイズ \(\mathbf{x}_T\) をサンプリングしてから、それを徐々に「ノイズ除去」して、その結果、実分布 \(\mathbf{x}_0\) からのサンプルが得られます。

けれども、私達は \(p(\mathbf{x}_{t-1} | \mathbf{x}_t)\) を知りません。この条件付き確率を計算するためにはすべての可能な画像の分布を知る必要があるため、手に負えません。そのため、この 条件付き確率分布を近似 (学習) する ためにニューラルネットワークを利用していきます、それを \(\theta\) を (勾配降下で更新される) ニューラルネットワークのパラメータとする \(p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t)\) としましょう。

Ok, 従って backward 過程の (条件付き) 確率分布を表すニューラルネットワークが必要です。この逆過程もガウス分布であると仮定する場合、任意のガウス分布は 2 つのパラメータで定義されることを思い出してください :

  • \(\mu_\theta\) でパラメータ化される平均 ;
  • \(\Sigma_\theta\) でパラメータ化される分散 ;

従って過程を次のようにパラメータ化できます :

$$ p_\theta (\mathbf{x}_{t-1} | \mathbf{x}_t) = \mathcal{N}(\mathbf{x}_{t-1}; \mu_\theta(\mathbf{x}_{t},t), \Sigma_\theta (\mathbf{x}_{t},t))$$

ここで平均と分散はまたノイズレベル \(t\) で条件付けられます。

このように、ニューラルネットワークは平均と分散を学習/表現する必要があります。けれども、DDPM の著者は 分散を固定することを決め、ニューラルネットワークにはこの条件付き確率分布の平均 \(\mu_\theta\) だけを学習させます。論文から :

まず、\(\Sigma_\theta ( \mathbf{x}_t, t) = \sigma^2_t \mathbf{I}\) を訓練されない時間依存な定数として設定します。実験では、\(\sigma^2_t = \beta_t\) と \(\sigma^2_t = \tilde{\beta}_t\) (論文参照) の両者は同様の結果を持ちました。

これはその後 Improved diffusion models 論文で改良されました、そこではニューラルネットワークはこの backwards 過程の平均に加えて分散も学習します。

そして、私たちのニューラルネットワークはこの条件付き確率分布の平均だけを学習/表現する必要があるものと仮定して、続けます。

 

(平均の再パラメータ化による) 目的関数の定義

backward 過程の平均を学習する目的関数を導出するため、著者らは \(q\) と \(p_\theta\) の組み合わせが変分オートエンコーダ (VAE) (Kingma et al., 2013) として見なせることを観察しています。このため、正解データサンプル \(\mathbf{x}_0\) に関する負の対数尤度を最小化するために 変分下限 (ELBO とも呼ばれます) が使用できます (ELBO に関する詳細については VAE 論文を参照)。この過程に対する ELBO は各時間ステップ \(t\) における損失の和, \(L = L_0 + L_1 + … + L_T\) であることがわかっています。forward \(q\) 過程と backward 過程の構成から、損失の各項 (except for \(L_0\)) は実際には 2 つのガウス分布間の KL ダイバージェンス で、これは平均に関する L2-損失として明示的に書くことができます。

Sohl-Dickstein et al. により示されたように、構築された forward 過程 \(q\) の直接的な結果は、\(\mathbf{x}_0\) で条件付けられた任意のノイズレベルで \(\mathbf{x}_t\) をサンプリングできることです (ガウス分布の和もガウス分布であるためです)。これは非常に便利です : \(\mathbf{x}_t\) をサンプリングするために \(q\) を繰り返し適用する必要はありません。\(\alpha_t := 1 – \beta_t\) そして \(\bar{\alpha}t := \Pi{s=1}^{t} \alpha_s\) として次を得ます :

$$q(\mathbf{x}_t | \mathbf{x}_0) = \cal{N}(\mathbf{x}_t; \sqrt{\bar{\alpha}_t} \mathbf{x}_0, (1- \bar{\alpha}_t) \mathbf{I})$$

この等式を「良い性質 (= nice property)」と呼ぶことにしましょう。これは、ガウスノイズをサンプリングしてそれを適切にスケールして \(\mathbf{x}_0\) に追加して \(\mathbf{x}_t\) を直接得られることを意味しています。\(\bar{\alpha}_t\) は既知の \(\beta_t\) 分散スケジュールの関数なので従ってこれもまた既知であり事前計算できることに注意してください。そしてこれは訓練の間に 損失関数 \(L\) のランダム項を最適化する ことを可能にします (あるいは換言すれば、訓練の間に \(t\) をランダムにサンプリングして \(L_t\) を最適化することを可能にします)。

Ho et al. で示された、この特性のもう一つの美点は、損失を構成する KL 項のノイズレベル \(t\) について (ネットワーク \(\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)\) を通して) 追加されたノイズをニューラルネットワークに学習 (予測) させるために (幾つかの数学の後、それについては読者は この優れたブログ記事 を参照してください) 代わりに平均を再パラメータ化できる ことです。これはニューラルネットワークが (直接的な) 平均予測器よりもむしろノイズ予測器になることを意味します。平均は次のように計算できます :

$$\mathbf{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t – \frac{\beta_t}{\sqrt{1- \bar{\alpha}t}} \mathbf{\epsilon}\theta(\mathbf{x}_t, t) \right)$$

そして最終的な目的関数 \(L_t\) は以下のようなものです (\(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) が与えられたときランダムな時間ステップに対して) :

$$| \mathbf{\epsilon} – \mathbf{\epsilon}_\theta(\mathbf{x}t, t) |^2 = | \mathbf{\epsilon} – \mathbf{\epsilon}\theta( \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{(1- \bar{\alpha}_t) } \mathbf{\epsilon}, t) |^2.$$

ここで、\(\mathbf{x}0\) は初期 (実、破損していない (= uncorrupted)) 画像で、固定された forward 過程により与えられる直接的なノイズレベル \(t\) サンプルを見ます。 \(\mathbf{\epsilon}\) は時間ステップ \(t\) でサンプリングされた純粋なノイズで、\(\mathbf{\epsilon}\theta (\mathbf{x}_t, t)\) はニューラルネットワークです。ニューラルネットワークは正解 (のガウスノイズ) と予測ガウスノイズの間の単純な平均二乗誤差 (MSE) を使用して最適化されます。

訓練アルゴリズムは今では以下のようなものになります :

換言すれば :

  • リアルの未知なそして多分複雑なデータ分布 \(q(\mathbf{x}_0)\) からランダムなサンプル \(\mathbf{x}_0\) を取ります。

  • \(1\) と \(T\) 間のノイズレベル \(t\) を一様にサンプリングします (i.e., ランダムな時間ステップ)。

  • ガウス分布からノイズをサンプリングして (上記で定義された良い性質を使用して) レベル \(t\) でこのノイズにより入力を破損させます。

  • ニューラルネットワークは破損画像 \(\mathbf{x}_t\) に基づいてこのノイズ (i.e. 既知のスケジュール \(\beta_t\) に基づいて \(\mathbf{x}_0\) 上で適用されるノイズ) を予測するように訓練されます。

実際には、このすべてはデータのバッチ上で行なわれます、ニューラルネットワークを最適化するために確率的勾配降下を使用するためです。

 

ニューラルネットワーク

ニューラルネットワークは特定の時間ステップにおけるノイズがある画像を受け取り、予測されたノイズを返す必要があります。予測ノイズは入力画像と同じサイズ/解像度を持つテンソルであることに注意してください。従って技術的には、ネットワークは同じ shape のテンソルを受け取り、出力します。このためにどのようなタイプのニューラルネットワークを利用できるでしょう?

ここで通常使用されるものは オートエンコーダ のそれに非常に類似しています、これを貴方は典型的な「深層学習へのイントロ」チュートリアルで覚えているかもしれません。オートエンコーダはエンコーダとデコーダの間の中にいわゆる「ボトルネック」層を持っています。エンコーダは最初に画像を「ボトルネック」と呼ばれる、より小さな隠れ表現にエンコードし、そしてデコーダは隠れ表現を実際の画像にデコードし戻します。これはネットワークがボトルネック層内に最も重要な情報だけを保持することを強制します。

アーキテクチャの観点からは、DDPM の著者らは (Ronneberger et al., 2015) で紹介された U-Net を追求しました (これはその時点で医用画像セグメンテーションに対して最先端の結果を獲得しました)。任意のオートエンコーダのように、このネットワークは中央のボトルネックで構成され、これはネットワークが最も重要な情報だけを学習することを確実にしています。重要なのは、それはエンコーダとデコーダの間に残差接続を導入していて、勾配フローを大幅に改良しています (inspired by ResNet in He et al., 2015)。

ご覧のように、U-Net モデルは最初に入力をサンプリングして (i.e. 入力を空間的解像度の観点からより小さくします)、その後でアップサンプリングが実行されます。

以下で、このネットワークを一歩ずつ実装します。

 

ネットワーク・ヘルパー

最初に、幾つかのヘルパー関数とクラスを定義します、これらはニューラルネットワークを実装するときに使用されます。重要なのは、Residual (残差) モジュールを定義します、これは入力を特定の関数の出力に単純に追加します (換言すれば、残差接続を特定の関数に追加します)。

up- と downsampling 演算のためにエイリアスも定義します。

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

def Upsample(dim):
    return nn.ConvTranspose2d(dim, dim, 4, 2, 1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, 1)

 

位置埋め込み

ニューラルネットワークのパラメータは時間 (ノイズレベル) に渡り共有されますので、著者らは Transformer (Vaswani et al., 2017) にインスパイアされ、\(t\) をエンコードするのに sinusoidal (正弦関数の) 位置埋め込みを使用しています。これは、バッチのすべての画像について、ニューラルネットワークにどの特定の時間ステップ (ノイズレベル) でそれが動作しているかを「知る」ようにします。

SinusoidalPositionEmbeddings モジュールは入力として shape (batch_size, 1) のテンソルを受け取り (i.e. バッチの幾つかのノイズのある画像のノイズレベル)、そしてこれを位置埋め込みの次元である dim を持つ shape (batch_size, dim) のテンソルに変換します。それから、後で見るように、これは各残差ブロックに追加されます。

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

 

ResNet/ConvNeXT ブロック

次に、U-Net モデルの中核のビルディングブロックを定義します。DDPM の著者らは Wide ResNet ブロック (Zagoruyko et al., 2016) を採用しましたが、Phil Wang はまた ConvNeXT ブロック (Liu et al., 2022) のサポートも追加することを決めました、後者は画像ドメインで素晴らしい成功を獲得したためです。最終的な U-Net アーキテクチャでいずれかを選択できます。

Update: Phil Wang decided to remove ConvNeXT blocks from his implementation as they didn’t seem to work well for him. However, we obtained nice results with them, as shown further in this blog.

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups = 8):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift = None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x

class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""
    
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.block1(x)

        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            h = rearrange(time_emb, "b c -> b c 1 1") + h

        h = self.block2(h)
        return h + self.res_conv(x)
    
class ConvNextBlock(nn.Module):
    """https://arxiv.org/abs/2201.03545"""

    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.GELU(), nn.Linear(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)

        self.net = nn.Sequential(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        h = self.ds_conv(x)

        if exists(self.mlp) and exists(time_emb):
            condition = self.mlp(time_emb)
            h = h + rearrange(condition, "b c -> b c 1 1")

        h = self.net(h)
        return h + self.res_conv(x)

 

Attention モジュール

次に、attention モジュールを定義します、DDPM の著者らはこれを畳込みブロックの間に追加しました。Attention は有名な Transformer アーキテクチャ (Vaswani et al., 2017) のビルディングブロックで、これは NLP とビジョンから タンパク質フォールディング まで AI の様々なドメインで素晴らしい成功を示しています。Phil Wang は attention の 2 つのバリエーションを採用しています : 一つは (Transformer で使用された) 通常のマルチヘッド自己アテンションで、もう一つは 線形アテンションの変種 (Shen et al., 2018) で、その時間とメモリ要件は (通常のアテンションについては二次であるのに対して) シークエンス長で線形にスケールします。

アテンション機能の広範囲な說明については、読者は Jay Allamar の 素晴らしいブログ記事 を参照してください。

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)

        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), 
                                    nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)

 

グループ正規化

DDPM の著者らは U-Net の畳込み/アテンション層をグループ正規化 (Wu et al., 2018) で交互配置 (= interleave) しています。以下で PreNorm クラスを定義します、これは更に見るように、アテンション層の前に groupnorm を適用するために使用されます。正規化を Transformer の注意層の前に適用するか後に適用するかについては 議論 があることに注意してください。

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

 

条件付き U-Net

すべてのビルディングブロック (位置埋め込み, ResNet/ConvNeXT ブロック, アテンション, とグループ正規化) を定義したので、ニューラルネットワーク全体を定義するときです。ネットワーク \(\mathbf{\epsilon}_\theta(\mathbf{x}_t, t)\) の仕事は、ノイズのある画像とノイズレベルのバッチを受け取り、入力に追加されたノイズを出力することであることを思い出してください。より正式には :

  • ネットワークは入力として shape (batch_size, num_channels, height, width) のノイズのある画像と shape (batch_size, 1) のノイズレベルのバッチを受け取り、shape (batch_size, num_channels, height, width) のテンソルを返します。

ネットワークは以下のように構築されます :

  • 最初に、ノイズのある画像のバッチに畳み込み層が適用されて、そしてノイズレベルに対して位置埋め込みが計算されます。

  • 次に、ダウンサンプリング・ステージのシークエンスが適用されます。各ダウンサンプリング・ステージは 2 つの ResNet/ConvNeXT ブロック + groupnorm + アテンション + 残差接続 + ダウンサンプリング演算から構成されます。

  • ネットワークの真ん中では、再度 ResNet or ConvNeXT ブロックが適用され、アテンションで交互配置されます。

  • 次に、アップサンプリング・ステージのシークエンスが適用されます。各アップサンプリング・ステージは 2 つの ResNet/ConvNeXT ブロック + groupnorm + アテンション + 残差接続 + アップサンプリング演算から構成されます。

  • 最後に、ResNet/ConvNeX ブロックと続いて畳み込み層が適用されます。

究極的には、ニューラルネットワークはレゴブロックのように層を積み上げます (しかしそれらが どのように動作するか理解する ことは重要です)。

class Unet(nn.Module):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        use_convnext=True,
        convnext_mult=2,
    ):
        super().__init__()

        # determine dimensions
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        
        if use_convnext:
            block_klass = partial(ConvNextBlock, mult=convnext_mult)
        else:
            block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.Sequential(
                SinusoidalPositionEmbeddings(dim),
                nn.Linear(dim, time_dim),
                nn.GELU(),
                nn.Linear(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.Sequential(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def forward(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        # downsample
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        # bottleneck
        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        # upsample
        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

デフォルトでは、(use_convnext が True に設定されているため) ノイズ予測器は ConvNeXT ブロックを使用して (with_time_emb が True に設定されているので) 位置埋め込みが追加されます。

 

forward 拡散過程の定義

forward 拡散過程は実分布からの画像に多くの時間ステップ \(T\) 內で徐々にノイズを追加します。これは 分散スケジュール に従って発生します。オリジナルの DDPM 著者らは線形スケジュールを採用しました :

forward 過程分散を \(\beta_1 = 10^{−4}\) から \(\beta_T = 0.02\) に線形に増加する定数に設定します。

けれども、(Nichol et al., 2021) でコサインスケジュールを採用したときより良い結果が獲得できることが示されました。

以下で、\(T\) 時間ステップに対する様々なスケジュールを定義します (we’ll choose one later on)。

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

まずは、\(T=200\) 時間ステップについて線形スケジュールを使用し、分散 \(\bar{\alpha}_t\) の累積の積 (= cumulative product) のような \(\beta_t\) から必要となる様々な変数を定義しましょう。下の変数の各々は単なる 1 次元テンソルで、\(t\) から \(T\) の値をストアします。重要なのは、extract 関数も定義していることで、これはバッチのインデックスのために適切な \(t\) インデックスを抽出することを可能にします。

timesteps = 200

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

拡散過程の各時間ステップでノイズがどのように追加されるか、猫の画像で例示します。

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
image

ノイズは Pillow 画像ではなく、PyTorch テンソルに追加されます。最初に画像変換を定義します、これは PIL 画像から PyTorch テンソル (その上でノイズを加えることができます) に進むことを可能にします、逆もまた同様です。

これらの変換は非常に単純です : 最初に \(255\) で除算して画像を (それらが \([0,1]\) にあるように) 正規化して、それらが \([-1, 1]\) 範囲にあることを確実にします。DDPM 論文から :

画像データは \({0, 1, … , 255}\) の整数から成り、\([−1, 1]\) に線形にスケールされることを仮定しています。これは、ニューラルネットワークの逆過程が標準正規事前分布 \(p(\mathbf{x}_T )\) からはじまりスケールされた入力上で作用することを保証します。

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(), # turn into Numpy array of shape HWC, divide by 255
    Lambda(lambda t: (t * 2) - 1),
    
])

x_start = transform(image).unsqueeze(0)
x_start.shape
Output:
----------------------------------------------------------------------------------------------------
torch.Size([1, 3, 128, 128])

逆変換もまた定義します、これは \([-1, 1]\) の値を含む PyTorch テンソルを受け取り PIL 画像に変換し戻します :

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])

Let’s verify this:

reverse_transform(x_start.squeeze())

これで論文内のように forward 拡散過程を定義できます :

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

特定の時間ステップでそれをテストしましょう :

def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)

  # turn back into PIL image
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image
# take time step
t = torch.tensor([40])

get_noisy_image(x_start, t)

様々な時間ステップについてこれを可視化しましょう :

import matplotlib.pyplot as plt

# use seed for reproducability
torch.manual_seed(0)

# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

これは、モデルが与えられたとき損失関数を以下のように定義できることを意味します :

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

denoise_model は上で定義された U-Net になります。真のノイズと予測されたノイズ間の Huber 損失を採用します。

 

PyTorch Dataset + DataLoader を定義する

ここで通常の PyTorch データセット を定義します。データセットは Fashion-MNIST, CIFAR-10 or ImageNet のような実際のデータセットからの画像で単純に構成され、\([−1, 1]\) に線形にスケールされます。

各画像は同じサイズにリサイズされます。注意すべき面白いことは、画像もランダムに水平に反転されることです。論文から :

CIFAR10 に対して訓練の間ランダムな水平反転を使用しました。訓練を反転ありとなしで試しましたが、反転がサンプルの品質を僅かに改良することを見出しました。

ここではハブから Fashion MNIST データセットを簡単にロードするために 🤗 Datasets ライブラリ を使用します。このデータセットは、同じ解像度、つまり 28×28 を既に持つ画像から構成されます。

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset("fashion_mnist")
image_size = 28
channels = 1
batch_size = 128

次に、データセット全体上で on-the-fly に適用する関数を定義します。そのために with_transform 機能 を使用します。この関数は幾つかの基本的な画像前処理を適用するだけです : ランダムな水平反転, 再スケーリングそして最後に \([-1,1]\) 範囲の値を持つようにします。

from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations (e.g. using torchvision)
transform = Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Lambda(lambda t: (t * 2) - 1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())
Output:
----------------------------------------------------------------------------------------------------
dict_keys(['pixel_values'])

 

サンプリング

(進捗を追跡するために) 訓練中にモデルからサンプリングしますので、そのためにコードを下で定義します。サンプリングは論文では Algorithm 2 としてまとめられています :

拡散モデルからの新しい画像の生成は拡散過程を逆にすることにより発生します : \(T\) から始めます、そこではガウス分布から純粋なノイズをサンプリングします、それから (学習した条件付き確率を使用して) ニューラルネットワークを使用してそれを徐々にノイズ除去し、時間ステップ \(t = 0\) で終了します。上で示されたように、ノイズ予測器を使用して、平均の再パラメータ化を行なうことで僅かに少なくノイズ除去された画像 \(\mathbf{x}_{t-1 }\) を導出できます。分散は前もって知られていることを忘れないでください。

理想的には、実データ分布から得られたような画像で終わることです。

以下のコードはこれを実装しています。

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

上のコードはオリジナル実装の単純化バージョンであることに注意してください。私達の単純化 (論文の Algorithm 2 に沿っています) は、クリッピング を使用する より複雑な実装であるオリジナル と同様に上手く動作することを見い出しています。

 

モデルの訓練

次に、モデルを通常の PyTorch 流儀で訓練します。上で定義された sample メソッドを使用して生成画像を定期的にセーブするためのロジックも定義します。

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

以下で、モデルを定義し、それを GPU に移します。標準的な optimizer (Adam) も定義します。

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

Let’s start training!

from torchvision.utils import save_image

epochs = 5

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, batch_size=n, channels=channels), batches))
        all_images = torch.cat(all_images_list, dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow = 6)
Output:
----------------------------------------------------------------------------------------------------
Loss: 0.46477368474006653
Loss: 0.12143351882696152
Loss: 0.08106148988008499
Loss: 0.0801810547709465
Loss: 0.06122320517897606
Loss: 0.06310459971427917
Loss: 0.05681884288787842
Loss: 0.05729678273200989
Loss: 0.05497899278998375
Loss: 0.04439849033951759
Loss: 0.05415581166744232
Loss: 0.06020551547408104
Loss: 0.046830907464027405
Loss: 0.051029372960329056
Loss: 0.0478244312107563
Loss: 0.046767622232437134
Loss: 0.04305662214756012
Loss: 0.05216279625892639
Loss: 0.04748568311333656
Loss: 0.05107741802930832
Loss: 0.04588869959115982
Loss: 0.043014321476221085
Loss: 0.046371955424547195
Loss: 0.04952816292643547
Loss: 0.04472338408231735

 

サンプリング (推論)

モデルからサンプリングするには、上で定義された sample 関数を使用できます。

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

モデルは素敵な T-シャツを生成できるようです!(その上で) 訓練したデータセットはかなり低解像度 (28×28) であることに留意してください。

ノイズ除去過程の gif を作成することもできます :

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

 

Follow-up reads

DDPM 論文は、拡散モデルが (非) 条件付き画像生成のための有望な方向性であることを示したことに注意してください。それ以降、これは特にテキスト条件付き画像生成のために (大いに) 改良されてきました。以下に、幾つかの重要な (しかし完全ではない) follow-up ワークを列挙します :

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021) : 改良されたノイズ除去拡散確率モデル – 条件付き分布の (平均に加えて) 分散の学習はパフォーマンスの改良に役立つことを見い出しました。

  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021) : 高忠実度な画像生成のための cascaded 拡散モデル – 高忠実度な画像合成のために解像度を上げた画像を生成するマルチ拡散モデルのパイプラインから成る、cascaded 拡散を導入する。

  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021) : 拡散モデル Beat GAN on 画像合成 – 拡散モデルが、U-Net アーキテクチャを改良して分類器ガイダンスを導入することにより、現在の最先端の生成モデルよりも優れた画像サンプル品質を獲得できることを示します。

  • Classifier-Free Diffusion Guidance (Ho et al., 2021) : 分類器フリーな拡散ガイダンス – 条件付きそして条件無し拡散モデルを単一のニューラルネットワークで一緒に訓練することにより、拡散モデルをガイドするために分類器は必要ないことを示します。

  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022) : CLIP 潜在 (変数) による階層的テキスト条件付き画像生成 (DALL-E 2) – テキストキャプションを CLIP 画像埋め込みに変換するために事前分布を使用し、その後で拡散モデルはそれを画像にデコードします。

  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (ImageGen) (Saharia et al., 2022) : 深層言語理解によるフォトリアリスティックなテキスト-to-画像拡散モデル (ImageGen) – 大規模な事前訓練済み言語モデル (e.g. T5) と cascaded 拡散との組み合わせはテキスト-to-画像合成について上手く動作することを示します。

このリストは執筆時点 (2022/06/07) までの重要なワークだけを含むことに注意してください。

今のところ、拡散モデルの主な (おそらくは唯一の) 欠点はそれらが画像を生成するために複数の forward パスを必要とすることのようです (これは GAN のような生成モデルには当てはまりません)。けれども、わずか 10 ほどののノイズ除去ステップで高忠実度な生成を可能にする研究が進んでいます。

 

以上