ノイズ除去拡散確率モデル on CIFAR-10 (code)

ノイズ除去拡散確率モデル on CIFAR-10 (code)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/24/2022 (No releases published)

* 本ページは、以下のレポジトリのスクリプトをノートブック形式に変換したコードです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

ノイズ除去拡散確率モデル on CIFAR-10 (code)

インポート

import os,sys
import math

import numpy as np

import torch

from torch import nn
from torch.nn import init
from torch.nn import functional as F

import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler

from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image

from tqdm import tqdm

 

ハイパーパラメータ


# --- Runtime device ---------------------------------------------------------
# The notebook originally pinned "cuda:1", which raises on hosts with fewer
# than two GPUs. Keep "cuda:1" when it exists, otherwise fall back to the first
# GPU, and finally to the CPU.
if torch.cuda.is_available():
    CONF_DEVICE = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    CONF_DEVICE = "cpu"

# --- Optimisation -----------------------------------------------------------
CONF_EPOCH = 200          # number of training epochs; also T_max of the cosine schedule
CONF_IMG_SIZE = 32        # spatial size (height == width) of the images
CONF_LR = 1e-4            # base learning rate passed to AdamW
CONF_BATCH_SIZE = 200     # DataLoader batch size; also the number of images sampled

# --- U-Net architecture -----------------------------------------------------
CONF_CHANNEL = 128                 # base channel width `ch`
CONF_CHANNEL_MULT = [1, 2, 3, 4]   # per-level multipliers applied to `ch`
CONF_ATTN = [2]                    # indices into CONF_CHANNEL_MULT whose ResBlocks use attention
CONF_NUM_RES_BLOCKS = 2            # ResBlocks per level on the down path
CONF_DROPOUT = 0.15                # dropout probability inside each ResBlock (training)
CONF_MULTIPLIER = 2.               # LR multiplier reached at the end of warm-up

# --- Checkpointing ----------------------------------------------------------
CONF_TRAINING_LOAD_WEIGHT = None         # checkpoint file name to resume from, or None
CONF_SAVE_WEIGHT_DIR = "./Checkpoints/"  # directory that holds the checkpoints

# --- Diffusion schedule -----------------------------------------------------
CONF_BETA_1 = 1e-4     # first value of the linear beta schedule
CONF_BETA_T = 0.02     # last value of the linear beta schedule
CONF_T = 1000          # number of diffusion timesteps
CONF_GRAD_CLIP = 1.    # max gradient norm used by clip_grad_norm_

# --- Sampling output --------------------------------------------------------
CONF_SAMPLED_DIR = "./SampledImgs/"
CONF_SAMPLED_NOISY_IMG_NAME = "NoisyNoGuidenceImgs.png"
CONF_SAMPLED_IMG_NAME = "SampledNoGuidenceImgs.png"
CONF_NROW = 8          # `nrow` argument forwarded to save_image

 

データのロード

# Device on which the model and every batch are placed.
device = torch.device(CONF_DEVICE)

# CIFAR-10 training split. Each image is randomly flipped horizontally,
# converted to a tensor, then normalised per channel as (x - 0.5) / 0.5.
dataset = CIFAR10(
    root='./CIFAR10',
    download=True,
    train=True,
    transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)

# Shuffled mini-batches; the final incomplete batch is dropped.
dataloader = DataLoader(
    dataset,
    batch_size=CONF_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)

 

モデル

class Swish(nn.Module):
    """Swish activation: returns ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class TimeEmbedding(nn.Module):
    """Map integer timesteps to embedding vectors of size ``dim``.

    A sinusoidal table of shape ``[T, d_model]`` is precomputed and loaded
    into an ``nn.Embedding`` via ``from_pretrained``; the looked-up row is
    then passed through Linear -> Swish -> Linear.

    Args:
        T: number of timesteps (rows of the table).
        d_model: width of the sinusoidal table; must be even.
        dim: output width of the two-layer projection.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()

        # Frequencies exp(-log(10000) * k / d_model) for k = 0, 2, 4, ...
        exponents = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        freqs = torch.exp(-exponents)
        steps = torch.arange(T).float()
        angles = steps[:, None] * freqs[None, :]
        assert list(angles.shape) == [T, d_model // 2]

        # Stack on a trailing axis and flatten, so columns alternate sin, cos.
        table = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1)
        assert list(table.shape) == [T, d_model // 2, 2]
        table = table.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        linears = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linears:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)

    def forward(self, t):
        """Return the embedding for the integer timestep tensor ``t``."""
        return self.timembedding(t)


class DownSample(nn.Module):
    """Halve the spatial resolution with a 3x3, stride-2 convolution.

    The channel count is preserved (``in_ch`` -> ``in_ch``).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted so this layer can be called like ResBlock; it is not used.
        return self.main(x)


class UpSample(nn.Module):
    """Double the spatial resolution, then apply a 3x3 convolution.

    Upsampling uses nearest-neighbour interpolation with ``scale_factor=2``;
    the convolution keeps the channel count (``in_ch`` -> ``in_ch``).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        """Xavier-uniform weight and zero bias for the convolution."""
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        """Upsample ``x`` by 2 and convolve.

        Args:
            x: 4-D input tensor.
            temb: accepted so this layer can be called like ResBlock; unused.

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        # The original code unpacked `x.shape` into four names it never used,
        # which only served as an implicit rank check. Make that check
        # explicit (same exception type) instead of binding dead locals.
        if x.dim() != 4:
            raise ValueError(
                "UpSample expects a 4-D tensor, got {}-D input".format(x.dim()))
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(x)


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    The input is group-normalised, projected to query/key/value with 1x1
    convolutions, attended over the ``H * W`` positions, projected again and
    added back to the input (residual connection).
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, kernel_size=1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights, zero biases; tiny gain for the output projection."""
        for conv in (self.proj_q, self.proj_k, self.proj_v, self.proj):
            init.xavier_uniform_(conv.weight)
            init.zeros_(conv.bias)
        # Re-draw the output projection with gain 1e-5 so the attention branch
        # contributes almost nothing at initialisation.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        n_pos = H * W
        normed = self.group_norm(x)

        # Queries / values as [B, H*W, C]; keys as [B, C, H*W].
        query = self.proj_q(normed).permute(0, 2, 3, 1).view(B, n_pos, C)
        key = self.proj_k(normed).view(B, C, n_pos)
        value = self.proj_v(normed).permute(0, 2, 3, 1).view(B, n_pos, C)

        # Scaled dot-product scores between every pair of positions.
        scores = torch.bmm(query, key) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, n_pos, n_pos]
        weights = F.softmax(scores, dim=-1)

        mixed = torch.bmm(weights, value)
        assert list(mixed.shape) == [B, n_pos, C]
        # Back to [B, C, H, W] before the output projection.
        mixed = mixed.view(B, H, W, C).permute(0, 3, 1, 2)

        return x + self.proj(mixed)


class ResBlock(nn.Module):
    """Residual block conditioned on a time embedding, with optional attention.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
        tdim: width of the time embedding passed to ``forward``.
        dropout: dropout probability used in the second conv stage.
        attn: when True an ``AttnBlock(out_ch)`` is applied to the block output.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        # Stage 1: GroupNorm -> Swish -> 3x3 conv (in_ch -> out_ch).
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        # Projects the time embedding to one bias per output channel.
        self.temb_proj = nn.Sequential(Swish(), nn.Linear(tdim, out_ch))
        # Stage 2: GroupNorm -> Swish -> Dropout -> 3x3 conv (out_ch -> out_ch).
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
        )
        # A 1x1 conv matches channel counts on the skip path when they differ.
        self.shortcut = (
            nn.Identity() if in_ch == out_ch
            else nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0)
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        """Xavier-uniform weights / zero biases; tiny gain for the last conv."""
        # NOTE(review): self.modules() is recursive, so this also re-initialises
        # the convolutions inside self.attn (when it is an AttnBlock) with the
        # default gain — confirm that overriding AttnBlock's own init is intended.
        for layer in self.modules():
            if not isinstance(layer, (nn.Conv2d, nn.Linear)):
                continue
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        final_conv = self.block2[-1]
        init.xavier_uniform_(final_conv.weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        # Broadcast the per-channel time bias over the two spatial axes.
        time_bias = self.temb_proj(temb)
        h += time_bias[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        return self.attn(h)


class UNet(nn.Module):
    """Time-conditioned U-Net mapping a 3-channel image to a 3-channel output.

    Args:
        T: number of timesteps for the time embedding.
        ch: base channel width.
        ch_mult: per-level multipliers applied to ``ch``.
        attn: indices into ``ch_mult`` whose ResBlocks use attention.
        num_res_blocks: ResBlocks per level on the down path (the up path
            uses one more per level).
        dropout: dropout probability forwarded to every ResBlock.
    """

    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all(i < len(ch_mult) for i in attn), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

        # ---- down path ----
        self.downblocks = nn.ModuleList()
        # Output channels of every down-path layer (plus the head), consumed
        # in reverse by the up path to size its skip-connection inputs.
        skip_chs = [ch]
        cur_ch = ch
        last_level = len(ch_mult) - 1
        for level, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                block = ResBlock(
                    in_ch=cur_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(level in attn))
                self.downblocks.append(block)
                cur_ch = out_ch
                skip_chs.append(cur_ch)
            # Every level except the last ends with a downsampling layer.
            if level < last_level:
                self.downblocks.append(DownSample(cur_ch))
                skip_chs.append(cur_ch)

        # ---- bottleneck ----
        self.middleblocks = nn.ModuleList([
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=True),
            ResBlock(cur_ch, cur_ch, tdim, dropout, attn=False),
        ])

        # ---- up path ----
        self.upblocks = nn.ModuleList()
        for level, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                # Input = current features concatenated with one skip tensor.
                block = ResBlock(
                    in_ch=skip_chs.pop() + cur_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(level in attn))
                self.upblocks.append(block)
                cur_ch = out_ch
            # Every level except the first ends with an upsampling layer.
            if level > 0:
                self.upblocks.append(UpSample(cur_ch))
        assert len(skip_chs) == 0

        self.tail = nn.Sequential(
            nn.GroupNorm(32, cur_ch),
            Swish(),
            nn.Conv2d(cur_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()

    def initialize(self):
        """Initialise the head conv and the final tail conv (tiny gain)."""
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        out_conv = self.tail[-1]
        init.xavier_uniform_(out_conv.weight, gain=1e-5)
        init.zeros_(out_conv.bias)

    def forward(self, x, t):
        """Run the network on images ``x`` at integer timesteps ``t``."""
        temb = self.time_embedding(t)

        # Down path: remember every intermediate activation as a skip tensor.
        h = self.head(x)
        skips = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            skips.append(h)

        # Bottleneck.
        for layer in self.middleblocks:
            h = layer(h, temb)

        # Up path: each ResBlock consumes one skip tensor; UpSample layers do not.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, skips.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(skips) == 0
        return h

 

モデルインスタンスの作成/ロード

# Build the U-Net from the hyperparameters and move it to the target device.
net_model = UNet(
    T=CONF_T,
    ch=CONF_CHANNEL,
    ch_mult=CONF_CHANNEL_MULT,
    attn=CONF_ATTN,
    num_res_blocks=CONF_NUM_RES_BLOCKS,
    dropout=CONF_DROPOUT,
).to(device)

# Optionally resume from a checkpoint stored under CONF_SAVE_WEIGHT_DIR.
if CONF_TRAINING_LOAD_WEIGHT is not None:
    net_model.load_state_dict(
        torch.load(
            os.path.join(CONF_SAVE_WEIGHT_DIR, CONF_TRAINING_LOAD_WEIGHT),
            map_location=device,
        )
    )
    print("weight : {} loaded.".format(CONF_TRAINING_LOAD_WEIGHT))

 

訓練

訓練準備

class GradualWarmupScheduler(_LRScheduler):
    """Linear learning-rate warm-up followed by an optional second scheduler.

    For the first ``warm_epoch`` steps the learning rate rises linearly from
    ``base_lr`` to ``base_lr * multiplier``. Afterwards, if ``after_scheduler``
    is given, its base learning rates are set to ``base_lr * multiplier`` and
    every further ``step`` call is delegated to it; otherwise the learning
    rate stays at ``base_lr * multiplier``.

    Args:
        optimizer: wrapped optimizer.
        multiplier: factor reached at the end of warm-up.
        warm_epoch: number of warm-up steps.
        after_scheduler: scheduler that takes over after warm-up (optional).
    """

    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.total_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # set once control has been handed to after_scheduler
        self.last_epoch = None
        self.base_lrs = None
        super().__init__(optimizer)

    def get_lr(self):
        """Return the learning rates for the current ``last_epoch``."""
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand-off: the follow-up scheduler starts from the
                    # warmed-up learning rate.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        # Linear interpolation between 1x and `multiplier`x of the base LR.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step(self, epoch=None, metrics=None):
        """Advance the schedule by one step.

        ``metrics`` is accepted for signature compatibility and is not used.
        """
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
            # Fix: the delegated scheduler updates the optimizer directly, so
            # mirror its result here; otherwise get_last_lr() on this wrapper
            # keeps returning the value from the hand-off step.
            self._last_lr = self.after_scheduler.get_last_lr()
        else:
            return super().step(epoch)
# AdamW over all U-Net parameters.
optimizer = optim.AdamW(
    net_model.parameters(),
    lr=CONF_LR,
    weight_decay=1e-4,
)
# Cosine annealing down to 0 over CONF_EPOCH steps; used after warm-up.
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=CONF_EPOCH,
    eta_min=0,
    last_epoch=-1,
)
# Linear warm-up over the first tenth of the epochs, then hand over to cosine.
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer,
    multiplier=CONF_MULTIPLIER,
    warm_epoch=CONF_EPOCH // 10,
    after_scheduler=cosineScheduler,
)

 

トレーナー

def extract(v, t, x_shape):
    """
    Pick the entries of the 1-D table ``v`` at the indices ``t`` and return
    them as float32 with shape ``[len(t), 1, 1, ...]`` (as many trailing
    singleton axes as needed to match ``len(x_shape)``), ready to broadcast
    against a tensor of shape ``x_shape``.
    """
    target_device = t.device
    picked = v.gather(0, t).float().to(target_device)
    trailing_ones = [1] * (len(x_shape) - 1)
    return picked.view([t.shape[0], *trailing_ones])

class GaussianDiffusionTrainer(nn.Module):
    """Computes the noise-prediction training loss for a diffusion model.

    Args:
        model: network called as ``model(x_t, t)`` to predict the noise.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        T: number of diffusion timesteps.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # Linear beta schedule, kept in float64.
        schedule = torch.linspace(beta_1, beta_T, T).double()
        self.register_buffer('betas', schedule)
        alphas_bar = torch.cumprod(1. - self.betas, dim=0)

        # Coefficients of x_0 and of the noise in the closed-form q(x_t | x_0).
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.

        Draws one random timestep per sample, noises ``x_0`` accordingly and
        returns the element-wise (unreduced) squared error between the
        model's prediction and the noise that was added.
        """
        batch = x_0.shape[0]
        t = torch.randint(self.T, size=(batch, ), device=x_0.device)
        noise = torch.randn_like(x_0)

        signal_scale = extract(self.sqrt_alphas_bar, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape)
        x_t = signal_scale * x_0 + noise_scale * noise

        predicted = self.model(x_t, t)
        return F.mse_loss(predicted, noise, reduction='none')
# Instantiate the trainer around the U-Net and move it to the target device.

trainer = GaussianDiffusionTrainer(
    model=net_model,
    beta_1=CONF_BETA_1,
    beta_T=CONF_BETA_T,
    T=CONF_T,
).to(device)

 

訓練ループ

%%time

os.makedirs(CONF_SAVE_WEIGHT_DIR, exist_ok=True)

# Train for CONF_EPOCH epochs. The LR scheduler is stepped once per epoch and
# a checkpoint of the model weights is written after every epoch.
for e in range(CONF_EPOCH):
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for images, _ in tqdmDataLoader:
            optimizer.zero_grad()
            x_0 = images.to(device)
            # The trainer returns an unreduced loss; sum it and scale by 1/1000.
            loss = trainer(x_0).sum() / 1000.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), CONF_GRAD_CLIP)
            optimizer.step()
            tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    # Read the LR straight from the live param group instead
                    # of rebuilding optimizer.state_dict() on every iteration.
                    "LR": optimizer.param_groups[0]["lr"]
                })
    warmUpScheduler.step()
    torch.save(net_model.state_dict(),
               os.path.join(CONF_SAVE_WEIGHT_DIR, 'ckpt_' + str(e) + "_.pt"))

 

評価

サンプラー

class GaussianDiffusionSampler(nn.Module):
    """Generates samples by iterating the reverse diffusion steps.

    Args:
        model: network called as ``model(x_t, t)`` to predict the noise.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        T: number of diffusion timesteps.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # Linear beta schedule, kept in float64.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # Cumulative product shifted right by one, with 1 in the first slot.
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        # mean = coeff1 * x_t - coeff2 * eps
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer(
            'coeff2',
            self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return ``coeff1[t] * x_t - coeff2[t] * eps``."""
        assert x_t.shape == eps.shape
        scale_x = extract(self.coeff1, t, x_t.shape)
        scale_eps = extract(self.coeff2, t, x_t.shape)
        return scale_x * x_t - scale_eps * eps

    def p_mean_variance(self, x_t, t):
        """Return the mean and variance used for the step at timestep ``t``."""
        # Variance table: posterior_var[1] for t == 0, betas[t] for t >= 1.
        var_table = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var_table, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)

        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2.

        Starts from ``x_T`` and applies ``T`` reverse steps, printing the
        current timestep at each one. The result is clipped to [-1, 1].
        """
        batch = x_T.shape[0]
        x_t = x_T
        for time_step in range(self.T - 1, -1, -1):
            print(time_step)
            t = x_t.new_ones([batch, ], dtype=torch.long) * time_step
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # No noise is added on the final step (time_step == 0).
            noise = torch.randn_like(x_t) if time_step > 0 else 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        return torch.clip(x_t, -1, 1)

 

サンプリング

%%time

# Checkpoint file (inside CONF_SAVE_WEIGHT_DIR) whose weights are sampled from.
CONF_TEST_LOAD_WEIGHT = "ckpt_199_.pt"

os.makedirs(CONF_SAMPLED_DIR, exist_ok=True)

with torch.no_grad():
    device = torch.device(CONF_DEVICE)

    # Same architecture as in training, but with dropout disabled.
    model = UNet(T=CONF_T,
                 ch=CONF_CHANNEL,
                 ch_mult=CONF_CHANNEL_MULT,
                 attn=CONF_ATTN,
                 num_res_blocks=CONF_NUM_RES_BLOCKS,
                 dropout=0.)

    ckpt = torch.load(os.path.join(CONF_SAVE_WEIGHT_DIR, CONF_TEST_LOAD_WEIGHT), map_location=device)
    model.load_state_dict(ckpt)
    print("model load weight done.")

    model.eval()

    # Moving the sampler also moves the wrapped model, which is its submodule.
    sampler = GaussianDiffusionSampler(model, CONF_BETA_1, CONF_BETA_T, CONF_T).to(device)

    # Start from standard-normal noise. The spatial size now comes from
    # CONF_IMG_SIZE instead of a hard-coded 32, so it follows the config.
    noisyImage = torch.randn(
        size=[CONF_BATCH_SIZE, 3, CONF_IMG_SIZE, CONF_IMG_SIZE], device=device)
    # Map from [-1, 1] to [0, 1] (clamping the tails) before saving.
    saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
    save_image(saveNoisy, os.path.join(CONF_SAMPLED_DIR, CONF_SAMPLED_NOISY_IMG_NAME), nrow=CONF_NROW)

    sampledImgs = sampler(noisyImage)
    sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
    save_image(sampledImgs, os.path.join(CONF_SAMPLED_DIR, CONF_SAMPLED_IMG_NAME), nrow=CONF_NROW)

 

以上