ノイズ除去拡散確率モデル on CIFAR-10 (code)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 09/24/2022
* 本ページは、以下のレポジトリのスクリプトをノートブック形式に変換したコードです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
クラスキャット 人工知能 研究開発支援サービス
◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
ノイズ除去拡散確率モデル on CIFAR-10 (code)
インポート
import os,sys
import math
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from tqdm import tqdm
ハイパーパラメータ
# --- Hardware ----------------------------------------------------------------
# The notebook was written for the second GPU ("cuda:1"). Fall back to the
# default GPU, then to the CPU, when that device does not exist so the notebook
# still runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    CONF_DEVICE = "cuda:1"
elif torch.cuda.is_available():
    CONF_DEVICE = "cuda"
else:
    CONF_DEVICE = "cpu"

# --- Training ----------------------------------------------------------------
CONF_EPOCH = 200            # number of passes over the training set
CONF_IMG_SIZE = 32          # image side length (CIFAR-10 images are 32x32)
CONF_LR = 1e-4              # base learning rate for AdamW
CONF_BATCH_SIZE = 200       # training batch size; also the number of sampled images
CONF_GRAD_CLIP = 1.         # max global gradient norm
CONF_MULTIPLIER = 2.        # peak LR after warm-up = CONF_LR * CONF_MULTIPLIER
CONF_TRAINING_LOAD_WEIGHT = None        # checkpoint file name to resume from, or None
CONF_SAVE_WEIGHT_DIR = "./Checkpoints/"  # where per-epoch checkpoints are written

# --- UNet architecture -------------------------------------------------------
CONF_CHANNEL = 128          # base channel width
CONF_CHANNEL_MULT = [1, 2, 3, 4]  # per-resolution channel multipliers
CONF_ATTN = [2]             # indices into CONF_CHANNEL_MULT that use self-attention
CONF_NUM_RES_BLOCKS = 2     # ResBlocks per down-sampling stage
CONF_DROPOUT = 0.15         # dropout inside ResBlocks during training

# --- Diffusion process -------------------------------------------------------
CONF_BETA_1 = 1e-4          # first beta of the linear noise schedule
CONF_BETA_T = 0.02          # last beta of the linear noise schedule
CONF_T = 1000               # number of diffusion timesteps

# --- Sampling output ---------------------------------------------------------
CONF_SAMPLED_DIR = "./SampledImgs/"
CONF_SAMPLED_NOISY_IMG_NAME = "NoisyNoGuidenceImgs.png"
CONF_SAMPLED_IMG_NAME = "SampledNoGuidenceImgs.png"
CONF_NROW = 8               # images per row in the saved grids
データのロード
# Resolve the compute device, then build the CIFAR-10 training pipeline.
device = torch.device(CONF_DEVICE)

# Random horizontal flips are the only augmentation. Normalize with
# mean = std = 0.5 maps ToTensor's [0, 1] pixel range onto [-1, 1].
dataset = CIFAR10(
    root='./CIFAR10',
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    ),
)

# drop_last=True keeps every batch at exactly CONF_BATCH_SIZE images.
dataloader = DataLoader(
    dataset,
    batch_size=CONF_BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=4,
    pin_memory=True,
)
モデル
class Swish(nn.Module):
    """Element-wise Swish (a.k.a. SiLU) activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
class TimeEmbedding(nn.Module):
    """Embed integer timesteps into ``dim``-wide vectors.

    A fixed sinusoidal table of shape ``(T, d_model)`` is looked up by
    timestep (column ``2i`` holds a sine and column ``2i + 1`` the matching
    cosine), then passed through ``Linear -> Swish -> Linear``.

    Args:
        T: number of timesteps (rows in the table).
        d_model: width of the sinusoidal encoding; must be even.
        dim: output embedding width.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()
        # Geometric frequencies exp(-(2i / d_model) * log 10000).
        even_idx = torch.arange(0, d_model, step=2)
        inv_freq = torch.exp(-(even_idx / d_model * math.log(10000)))
        steps = torch.arange(T).float()
        angles = steps.unsqueeze(1) * inv_freq.unsqueeze(0)
        assert angles.shape == (T, d_model // 2)
        # Interleave sin/cos along the last axis, then flatten pairs into columns.
        pairs = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
        assert pairs.shape == (T, d_model // 2, 2)
        table = pairs.view(T, d_model)
        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(table),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

    def initialize(self):
        # Xavier-uniform weights and zero biases for both Linear layers.
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)

    def forward(self, t):
        return self.timembedding(t)
class DownSample(nn.Module):
    """Halve the spatial resolution (rounding up) with a stride-2 3x3 conv.

    The channel count is unchanged.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        conv = self.main
        init.xavier_uniform_(conv.weight)
        init.zeros_(conv.bias)

    def forward(self, x, temb):
        # `temb` is accepted only so every down/up block shares the
        # (x, temb) call signature; it is not used here.
        return self.main(x)
class UpSample(nn.Module):
    """Double the spatial resolution: nearest-neighbour 2x upsampling, then a 3x3 conv.

    The channel count is unchanged.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        # `temb` is accepted only so every down/up block shares the
        # (x, temb) call signature; it is not used here.
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.main(x)
class AttnBlock(nn.Module):
    """Single-head self-attention across all spatial positions, with a residual add.

    Input and output have shape ``(B, C, H, W)``; the attention matrix is
    ``(B, H*W, H*W)``. ``in_ch`` must be divisible by 32 because of
    ``GroupNorm(32, in_ch)``.

    The output projection is initialised with a tiny gain, so a freshly
    initialised block (used on its own) is close to the identity. Note that an
    enclosing module that re-initialises all of its Conv2d submodules will
    overwrite this.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for conv in (self.proj_q, self.proj_k, self.proj_v, self.proj):
            init.xavier_uniform_(conv.weight)
            init.zeros_(conv.bias)
        # Near-zero output projection: x + proj(...) starts close to x.
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        n_pos = H * W
        normed = self.group_norm(x)
        q = self.proj_q(normed)
        k = self.proj_k(normed)
        v = self.proj_v(normed)

        # Queries and values become (B, H*W, C); keys stay channel-major (B, C, H*W).
        q_seq = q.permute(0, 2, 3, 1).view(B, n_pos, C)
        k_seq = k.view(B, C, n_pos)
        v_seq = v.permute(0, 2, 3, 1).view(B, n_pos, C)

        # Scaled dot-product scores, normalised over the key positions.
        scores = torch.bmm(q_seq, k_seq) * (int(C) ** (-0.5))
        assert list(scores.shape) == [B, n_pos, n_pos]
        weights = F.softmax(scores, dim=-1)

        mixed = torch.bmm(weights, v_seq)
        assert list(mixed.shape) == [B, n_pos, C]
        mixed = mixed.view(B, H, W, C).permute(0, 3, 1, 2)
        return x + self.proj(mixed)
class ResBlock(nn.Module):
    """Pre-activation residual block conditioned on a timestep embedding.

    Path: GroupNorm -> Swish -> 3x3 conv, plus a projection of ``temb``
    broadcast over H and W, then GroupNorm -> Swish -> Dropout -> 3x3 conv.
    The result is added to a shortcut of the input: a 1x1 conv when the
    channel count changes, identity otherwise. An optional AttnBlock is
    applied to the sum.

    Args:
        in_ch: input channels (must be divisible by 32 for GroupNorm).
        out_ch: output channels (must be divisible by 32 for GroupNorm).
        tdim: width of the timestep embedding ``temb`` (last dimension).
        dropout: dropout probability inside ``block2``.
        attn: if True, apply self-attention after the residual sum.
    """

    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        self.shortcut = (
            nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
            if in_ch != out_ch
            else nn.Identity()
        )
        self.attn = AttnBlock(out_ch) if attn else nn.Identity()
        self.initialize()

    def initialize(self):
        # NOTE: self.modules() recurses, so this also re-initialises the convs
        # inside self.attn (including its output projection, overriding the
        # small-gain init that AttnBlock applied to itself).
        weighted = [m for m in self.modules() if isinstance(m, (nn.Conv2d, nn.Linear))]
        for layer in weighted:
            init.xavier_uniform_(layer.weight)
            init.zeros_(layer.bias)
        # Small final conv: the residual branch starts close to zero.
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        # (B, out_ch) -> (B, out_ch, 1, 1) so the time bias broadcasts over H and W.
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h) + self.shortcut(x)
        return self.attn(h)
class UNet(nn.Module):
    """Noise-prediction U-Net used as the DDPM epsilon model.

    ``forward(x, t)`` takes a 3-channel image batch ``x`` and integer
    timesteps ``t`` and returns a 3-channel tensor. Head and tail are
    stride-1 convolutions and every DownSample is matched by an UpSample, so
    the output keeps the spatial size of ``x`` (for H and W divisible by
    ``2 ** (len(ch_mult) - 1)``).

    Args:
        T: number of diffusion timesteps (rows of the sinusoidal time table).
        ch: base channel width; stage ``i`` uses ``ch * ch_mult[i]`` channels.
        ch_mult: per-stage channel multipliers; resolution halves between stages.
        attn: indices into ``ch_mult`` of the stages whose ResBlocks use
            self-attention.
        num_res_blocks: ResBlocks per down stage (each up stage uses one more).
        dropout: dropout probability inside every ResBlock.
    """
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4  # width of the timestep embedding passed to every ResBlock
        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)
        self.downblocks = nn.ModuleList()
        # Channel count of every feature map that forward() pushes onto the skip
        # stack (head output, each down ResBlock, each DownSample), in push order;
        # the up path pops from it to size its concatenated inputs.
        chs = [ch]  # record output channel when downsample for upsample
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            # No DownSample after the last (lowest-resolution) stage.
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
        # Bottleneck: one ResBlock with attention, one without.
        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])
        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            # num_res_blocks + 1 ResBlocks per stage, so every recorded skip is consumed;
            # each one's input is the current features concatenated with a popped skip.
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0  # every skip connection is matched by exactly one up ResBlock
        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1)
        )
        self.initialize()
    def initialize(self):
        # Only head and tail are (re)initialised here; sub-blocks initialise themselves.
        # The tail starts with a tiny gain and zero bias, so initial predictions are near zero.
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)
    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling: keep every intermediate output as a skip connection.
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling: each ResBlock first concatenates the most recent skip on the channel axis.
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)
        assert len(hs) == 0
        return h
モデルインスタンスの作成/ロード
# Model setup: build the epsilon-prediction UNet on the training device and,
# when CONF_TRAINING_LOAD_WEIGHT names a checkpoint, resume from it.
net_model = UNet(
    T=CONF_T,
    ch=CONF_CHANNEL,
    ch_mult=CONF_CHANNEL_MULT,
    attn=CONF_ATTN,
    num_res_blocks=CONF_NUM_RES_BLOCKS,
    dropout=CONF_DROPOUT,
).to(device)
if CONF_TRAINING_LOAD_WEIGHT is not None:
    net_model.load_state_dict(
        torch.load(
            os.path.join(CONF_SAVE_WEIGHT_DIR, CONF_TRAINING_LOAD_WEIGHT),
            map_location=device,
        )
    )
    print("weight : {} loaded.".format(CONF_TRAINING_LOAD_WEIGHT))
訓練
訓練準備
class GradualWarmupScheduler(_LRScheduler):
    """Linear learning-rate warm-up, optionally followed by another scheduler.

    While ``last_epoch <= warm_epoch`` every parameter group's LR is
    ``base_lr * ((multiplier - 1) * last_epoch / warm_epoch + 1)``. That ramps
    linearly from ``base_lr`` up to ``base_lr * multiplier``.

    After that, the LR either stays at ``base_lr * multiplier`` or, if
    ``after_scheduler`` is given, that scheduler takes over with its
    ``base_lrs`` rescaled to ``base_lr * multiplier``.

    Args:
        optimizer: wrapped optimizer.
        multiplier: peak LR as a multiple of the optimizer's initial LR.
        warm_epoch: number of ``step()`` calls over which the LR ramps up.
        after_scheduler: optional scheduler to delegate to once warm-up ends.
    """
    def __init__(self, optimizer, multiplier, warm_epoch, after_scheduler=None):
        self.multiplier = multiplier
        self.total_epoch = warm_epoch
        self.after_scheduler = after_scheduler
        self.finished = False  # set once control has been handed to after_scheduler
        # Placeholders only; the base-class constructor assigns the real values.
        self.last_epoch = None
        self.base_lrs = None
        super().__init__(optimizer)
    def get_lr(self):
        # Warm-up is over once last_epoch exceeds total_epoch.
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # One-time hand-off: the follow-up scheduler starts from the peak LR.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                # NOTE(review): recent PyTorch versions emit a UserWarning when a
                # scheduler's get_lr() is called outside its own step() — confirm acceptable.
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]
        # Linear ramp: factor 1 at last_epoch == 0, factor `multiplier` at last_epoch == total_epoch.
        return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
    def step(self, epoch=None, metrics=None):
        """Advance one epoch (or jump to ``epoch``).

        After the hand-off, calls go to ``after_scheduler`` with ``epoch``
        shifted back by the warm-up length. ``metrics`` is accepted but unused.
        """
        if self.finished and self.after_scheduler:
            if epoch is None:
                self.after_scheduler.step(None)
            else:
                self.after_scheduler.step(epoch - self.total_epoch)
        else:
            return super(GradualWarmupScheduler, self).step(epoch)
# AdamW with decoupled weight decay. The LR first ramps linearly from CONF_LR
# to CONF_LR * CONF_MULTIPLIER over CONF_EPOCH // 10 epochs, then follows the
# cosine schedule. Because the cosine schedule only runs after warm-up, its
# T_max of CONF_EPOCH is never fully reached.
optimizer = optim.AdamW(
    net_model.parameters(),
    lr=CONF_LR,
    weight_decay=1e-4,
)
cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=CONF_EPOCH,
    eta_min=0,
    last_epoch=-1,
)
warmUpScheduler = GradualWarmupScheduler(
    optimizer=optimizer,
    multiplier=CONF_MULTIPLIER,
    warm_epoch=CONF_EPOCH // 10,
    after_scheduler=cosineScheduler,
)
トレーナー
def extract(v, t, x_shape):
    """Gather per-timestep coefficients and reshape them for broadcasting.

    Args:
        v: 1-D tensor of coefficients indexed by timestep (e.g. a schedule buffer).
        t: 1-D integer tensor of timesteps, one per batch element.
        x_shape: shape of the tensor the result will multiply; only its
            length (rank) is used.

    Returns:
        float32 tensor on ``t``'s device with shape ``[t.shape[0], 1, 1, ...]``
        (``len(x_shape)`` dimensions).
    """
    device = t.device
    # torch.gather requires `input` and `index` on the same device. Move the
    # table *before* gathering; a .to(device) after the gather cannot rescue
    # a CPU table indexed by GPU timesteps.
    out = torch.gather(v.to(device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
class GaussianDiffusionTrainer(nn.Module):
    """DDPM training objective (Algorithm 1 of Ho et al., 2020).

    For each image, a timestep is drawn uniformly from ``[0, T)``. The image
    is noised in closed form as
    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``, and the
    per-element squared error between ``model(x_t, t)`` and ``eps`` is
    returned (no reduction).

    Args:
        model: network called as ``model(x_t, t)`` that predicts the noise.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        T: number of diffusion steps.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        # The schedule is built with linspace in float32, then cast to float64.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas_bar = torch.cumprod(1. - self.betas, dim=0)
        # Coefficients of q(x_t | x_0).
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        batch = x_0.shape[0]
        t = torch.randint(self.T, size=(batch, ), device=x_0.device)
        noise = torch.randn_like(x_0)
        signal_scale = extract(self.sqrt_alphas_bar, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape)
        x_t = signal_scale * x_0 + noise_scale * noise
        predicted_noise = self.model(x_t, t)
        return F.mse_loss(predicted_noise, noise, reduction='none')
# Instantiate the trainer: wrap the UNet in the diffusion training objective
# and move its schedule buffers (and the wrapped model) to the device.
trainer = GaussianDiffusionTrainer(
    net_model,
    CONF_BETA_1,
    CONF_BETA_T,
    CONF_T,
).to(device)
訓練ループ
%%time
os.makedirs(CONF_SAVE_WEIGHT_DIR, exist_ok=True)
# Start training. Make sure dropout is active even if the model was switched
# to eval mode earlier in the session.
net_model.train()
for e in range(CONF_EPOCH):
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for images, _ in tqdmDataLoader:
            optimizer.zero_grad()
            x_0 = images.to(device)
            # The trainer returns per-element squared errors (reduction='none');
            # they are summed over the batch and scaled by a fixed 1/1000.
            loss = trainer(x_0).sum() / 1000.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), CONF_GRAD_CLIP)
            optimizer.step()
            tqdmDataLoader.set_postfix(ordered_dict={
                "epoch": e,
                "loss: ": loss.item(),
                "img shape: ": x_0.shape,
                # Read the live LR directly; optimizer.state_dict() would rebuild
                # the whole optimizer state dictionary on every iteration.
                "LR": optimizer.param_groups[0]["lr"],
            })
    # The LR schedule advances once per epoch; a checkpoint is written every epoch.
    warmUpScheduler.step()
    torch.save(net_model.state_dict(),
               os.path.join(CONF_SAVE_WEIGHT_DIR, f"ckpt_{e}_.pt"))
評価
サンプラー
class GaussianDiffusionSampler(nn.Module):
    """Ancestral DDPM sampler (Algorithm 2 of Ho et al., 2020).

    Starting from Gaussian noise ``x_T``, it runs the learned reverse process
    for ``T`` steps using the noise-predicting ``model``. It returns the final
    sample clipped to [-1, 1].

    Args:
        model: network called as ``model(x_t, t)`` that predicts the noise.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        T: number of diffusion steps.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()
        self.model = model
        self.T = T
        # Linear schedule built in float32, then cast to float64.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1 for the first step.
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
        # Reverse mean: mu = coeff1 * x_t - coeff2 * eps, where
        # coeff1 = 1 / sqrt(alpha_t) and coeff2 = coeff1 * beta_t / sqrt(1 - alpha_bar_t).
        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
        # Variance of q(x_{t-1} | x_t, x_0): beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t).
        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Mean of p(x_{t-1} | x_t) given the model's noise prediction ``eps``."""
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """Return ``(mean, var)`` of p(x_{t-1} | x_t) at timesteps ``t``."""
        # Reverse-process variance is beta_t for t >= 1. The t = 0 entry is
        # replaced by posterior_var[1]; forward() adds no noise at t = 0, so
        # that entry never affects the samples.
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Algorithm 2: iterate t = T-1, ..., 0 and return x_0 clipped to [-1, 1].
        """
        x_t = x_T
        batch_size = x_T.shape[0]
        for time_step in reversed(range(self.T)):
            t = x_t.new_full((batch_size, ), time_step, dtype=torch.long)
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # No noise is added on the final step (t == 0).
            noise = torch.randn_like(x_t) if time_step > 0 else 0
            x_t = mean + torch.sqrt(var) * noise
            assert not torch.isnan(x_t).any(), "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
サンプリング
%%time
# Checkpoint written after the final training epoch ("ckpt_199_.pt" for CONF_EPOCH = 200).
CONF_TEST_LOAD_WEIGHT = f"ckpt_{CONF_EPOCH - 1}_.pt"
os.makedirs(CONF_SAMPLED_DIR, exist_ok=True)
with torch.no_grad():
    device = torch.device(CONF_DEVICE)
    # Same architecture as in training, but with dropout disabled for sampling.
    model = UNet(T=CONF_T,
                 ch=CONF_CHANNEL,
                 ch_mult=CONF_CHANNEL_MULT,
                 attn=CONF_ATTN,
                 num_res_blocks=CONF_NUM_RES_BLOCKS, dropout=0.)
    ckpt = torch.load(os.path.join(CONF_SAVE_WEIGHT_DIR, CONF_TEST_LOAD_WEIGHT), map_location=device)
    model.load_state_dict(ckpt)
    print("model load weight done.")
    model.eval()
    # .to(device) also moves the wrapped model, which is a registered submodule.
    sampler = GaussianDiffusionSampler(model, CONF_BETA_1, CONF_BETA_T, CONF_T).to(device)
    # Sampled from standard normal distribution, sized by the configured image resolution.
    noisyImage = torch.randn(
        size=[CONF_BATCH_SIZE, 3, CONF_IMG_SIZE, CONF_IMG_SIZE], device=device)
    # Map [-1, 1] to [0, 1] for saving.
    saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
    save_image(saveNoisy, os.path.join(CONF_SAMPLED_DIR, CONF_SAMPLED_NOISY_IMG_NAME), nrow=CONF_NROW)
    sampledImgs = sampler(noisyImage)
    sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
    save_image(sampledImgs, os.path.join(CONF_SAMPLED_DIR, CONF_SAMPLED_IMG_NAME), nrow=CONF_NROW)
以上