PyTorch 1.8 : 画像と動画 : DCGAN チュートリアル

PyTorch 1.8 チュートリアル : 画像と動画 : DCGAN チュートリアル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/16/2021 (1.8.1+cu102)

* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

画像と動画 : DCGAN チュートリアル

イントロダクション

このチュートリアルはサンプルを通して DCGAN へのイントロダクションを与えます。多くの実際のセレブ (= celebrities, 有名人) の写真を見せた後で新しいセレブを生成するために敵対的生成ネットワーク (GAN, Generative Adversarial Network) を訓練します。ここでのコードの殆どは pytorch/examples の dcgan 実装からのもので、そしてこのドキュメントは実装の徹底的な説明を与えてこのモデルがどのようにそして何故動作するのかを明らかにします。しかし心配しないでください、GAN の事前知識は必要ありません、しかし内部で実際に何が起きているかについて納得するために初心者はある程度の時間を費やす必要があるかもしれません。また、時間 (節約) のために一つ、または 2 つの GPU を持つことが役立つでしょう。最初から始めましょう。

 

敵対的生成ネットワーク

What is a GAN?

GAN は訓練データの分布を捕捉することを DL モデルに教えるためのフレームワークですのでその同じ分布から新しいデータを生成できます。GAN は 2014 年に Ian Goodfellow により創案されてペーパー Generative Adversarial Nets で最初に説明されました。それらは 2 つの別個のモデル、generator と discriminator から成ります。generator のジョブは訓練画像のように見える ‘fake’ 画像を生むことです。discriminator のジョブは画像を見てそれが real 訓練画像か (generator からの) fake 画像かを出力することです。訓練の間、generator はより良い fake を生成することにより絶えず discriminator を出し抜こうとします、その一方で discriminator はより良い探偵になり real と fake 画像を正しく分類するために動作しています。このゲームの均衡は generator が訓練データに直接由来するかのように見える完全な fake を生成しているとき、そして discriminator は generator 出力が real か fake か常に 50% の信頼度で推測する状態にされることです。

今は、discriminator から始める、チュートリアルを通して使用される幾つかの記法を定義しましょう。$x$ を画像を表わすデータとしましょう。\(D(x)\) は discriminator ネットワークで、これは \(x\) が generator ではなく訓練データに由来する (スカラー) 確率を出力します。ここでは、画像を扱っていますから \(D(x)\) への入力は CHW サイズ 3x64x64 の画像です。直感的には、\(x\) が訓練データに由来するときは \(D(x)\) は HIGH で、\(x\) が generator 由来であるときは LOW です。\(D(x)\) は伝統的な二値分類器として考えることもできます。

generator の記法については、\(z\) を標準正規分布からサンプリングされた潜在空間ベクトルとしましょう。\(G(z)\) は generator 関数を表します、これは潜在ベクトル \(z\) をデータ空間にマップします。\(G\) の目標は訓練データが由来する分布 (\(p_{data}\)) を推定することですからそれはその推定された分布 (\(p_g\)) からの fake サンプルを生成できます。

従って、\(D(G(z))\) は generator \(G\) の出力が real 画像である確率 (スカラー) です。Goodfellow のペーパー で記述されているように、\(D\) と \(G\) は minimax ゲームをプレーしていてそこでは \(D\) はそれが real と fake を正しく分類する確率 (\(logD(x)\)) を最大化することを試み、そして \(G\) は \(D\) がその出力が fake であると予測する確率 (\(log(1-D(G(x)))\)) を最小化しようとします。ペーパーから、GAN 損失関数は :

\[
\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(x)))\big]
\]

理論的には、この minimax ゲームへの解は \(p_g = p_{data}\) であるところで、そしてdiscriminator は入力が real か fake かランダムに推測します。けれども、GAN の収束理論は依然として活発に研究されていて現実にはモデルは常にこのポイントまで訓練が進むわけではありません。

 

What is a DCGAN?

DCGAN は上で記述された GAN の直接的な拡張です、それが discriminator と generator でそれぞれ畳込みと転置畳込み (= convolutional-transpose) 層を明示的に使用することを除いてです。それは Radford et. al. によりペーパー Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks で最初に説明されました。discriminator は strided 畳込み 層、 バッチ正規化 層そして LeakyReLU 活性から成ります。入力は 3x64x64 入力画像で出力は real データ分布からの入力であるスカラー確率です。generator は 転置畳込み 層、バッチ正規化層そして ReLU 活性から成ります。入力は潜在ベクトル $z$、これは標準正規分布からドローされます、そして出力は 3x64x64 RGB 画像です。strided 転置畳込み層は潜在ベクトルに画像と同じ shape を持つボリュームへと変換されることを可能にします。このペーパーでは、optimizer をどのようにセットアップするか、損失関数をどのように計算するか、そしてモデル重みをどのように初期化するかについて著者はまた幾つかの tips を与えています、その総てが以下のセクションで説明されます。


from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
Random Seed:  999

 

入力

実行のための幾つかの入力を定義しましょう :

  • dataroot – データセット・フォルダのルートへのパス。次のセクションでデータセットについて更に話します。
  • workers – DataLoader でデータをロードするためのワーカー・スレッドの数
  • batch_size – 訓練で使用されるバッチサイズ。DCGAN ペーパーは 128 のバッチサイズを使用しています。
  • image_size – 訓練のために使用される画像の空間サイズ。この実装は 64×64 をデフォルトとします。他のサイズが望まれる場合、D と G の構造は変更されなければなりません。より詳細は こちら を見てください。
  • nc – 入力画像のカラーチャネルの数。カラー画像に対してはこれは 3 です。
  • nz – 潜在ベクトルの長さ。
  • ngf – generator を通して運ばれる特徴マップの深さに関連します。
  • ndf – discriminator を通して伝播される特徴マップの深さを設定します。
  • num_epochs – 実行する訓練エポックの数。より長い間の訓練は多分より良い結果につながるでしょうがまた遥かに長くかかります。
  • lr – 訓練のための学習率。DCGAN ペーパーで記述されているように、この数字は 0.0002 であるべきです。
  • beta1 – Adam optimizer のための beta1 ハイパーパラメータ。ペーパーで記述されているように、この数は 0.5 であるべきです。
  • ngpu – 利用可能な GPU の数。これが 0 であれば、コードは CPU モードで動作します。この数が 0 より大きい場合にはそれはその数の GPU 上で動作します。
# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparam for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

 

データ

このチュートリアルでは Celeb-A Faces データセット を使用します、これはリンクされたサイトか Google Drive でダウンロードできます。データセットは img_align_celeba.zip という名前のファイルとしてダウンロードされます。ひとたびダウンロードされたら、celeba という名前のディレクトリを作成して zip ファイルをそのディレクトリに解凍します。それからこの notebook のための datroot 入力を (ちょうど作成した) celeba ディレクトリに設定してください。結果としてのディレクトリ構造は以下のようになるはずです :

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

これは重要なステップです、何故ならばデータセットの root フォルダにサブディレクトリがあることを要求する ImageFolder dataset クラスを使用していくからです。今は、データセットを作成して、dataloader を作成して、その上で動作するデバイスを設定してそして最後に訓練データの幾つかを可視化します。

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

 

実装

入力パラメータが設定されてデータセットが準備された今、実装に入ることができます。
重み初期化ストラテジーから始めて、そして generator、discriminator、損失関数そして訓練ループについて詳細に語ります。

 

重み初期化

DCGAN ペーパーから、著者は総てのモデル重みは mean=0, stdev=0.02 を持つ正規分布からランダムに初期化されることを指定しています。weights_init 関数は入力として初期化されたモデルを取りそして総ての畳込み、転置畳込みそしてバッチ正規化層をこの規準を満たすように再初期化します。この関数は初期化後直ちにモデルに適用されます。

# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

 

Generator

generator $G$ は潜在空間ベクトル ($z$) をデータ空間にマップするように設計されています。私達のデータは画像ですので、$z$ をデータ空間に変換することは究極的には訓練画像と同じサイズ (i.e. 3x64x64) を持つ RGB 画像を作成することを意味します。実際には、これは一連の strided 2 次元転置畳込み層を通して達成されます、各々は 2d バッチ正規化層と relu 活性と組み合わされています。generator の出力はそれを [-1, 1] の入力データ範囲に戻すために tanh 関数を通して供給されます。転置畳込み層の後のバッチ正規化関数の存在は注目すべきです、何故ならばこれは DCGAN ペーパーの重要な寄与であるからです。これらの層は訓練の間勾配のフローに役立ちます。DCGAN ペーパーからの generator の画像は下で示されます。

入力セクションで設定した入力 (nz, ngf と nc) がコードの generator アーキテクチャにどのように影響するかに注意してください。nz は z 入力ベクトルの長さで、ngf は generator を通して伝播される特徴マップのサイズに関係し、そして nc は出力画像のチャネル数です (RGB 画像のためには 3 に設定)。下は generator のためのコードです。

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        return self.main(input)

今、generator をインスタンス化して weights_init 関数を適用することができます。generator オブジェクトがどのように構造化されているかを見るためにプリントされたモデルを調べてください。

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netG.apply(weights_init)

# Print the model
print(netG)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

 

Discriminator

言及したように、discriminator $D$ は二値分類ネットワークで、これは入力として画像を取り入力画像が (fake に対立する) real であるスカラー確率を出力します。ここでは、$D$ は 3x64x64 入力画像を取り、それを Conv2d, BatchNorm2d そして LeakyReLU 層のシリーズを通して処理し、そして Sigmoid 活性化関数を通して最後の確率を出力します。問題のために必要であればこのアーキテクチャはより多くの層により拡張可能ですが、strided 畳込み、BatchNorm と LeakyReLU の使用は意義があります。DCGAN ペーパーはダウンサンプリングにプーリングよりも strided 畳込みを利用することは良い実践であることに言及しています、何故ならばそれはネットワークにそれ自身のプーリング関数を学習させるからです。またバッチ正規化と leaky relu 関数は健全な勾配フローを促進します、これは $G$ と $D$ の両者の学習過程のために重要です。

Discriminator コード

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

さて、generator の場合と同様に、discriminator を作成し、weights_init 関数を適用しそしてモデルの構造をプリントすることができます。

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-gpu if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the weights_init function to randomly initialize all weights
#  to mean=0, stdev=0.2.
netD.apply(weights_init)

# Print the model
print(netD)
Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

 

損失関数と Optimizer

$D$ と $G$ のセットアップで、それらが損失関数と optimizer を通してどのように学習するかを指定できます。私達は二値交差エントロピー損失 (BCELoss) 関数を使用します、これは PyTorch で次のように定義されています :

\[
\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = – \left[ y_n \cdot \log x_n + (1 – y_n) \cdot \log (1 – x_n) \right]
\]

この関数が目的関数の両者の log 成分 ((i.e. \(log(D(x))\) と \(log(1-D(G(z)))\))) の計算をどのように提供するかに注意してください。$y$ 入力で BCE 等式のどの部分を使用するかを指定できます。これは間もなく取り上げる訓練ループで成されますが、単に $y$ (i.e. GT ラベル) を変更することによりどの成分を計算することを望むかをどのように選択できるかを理解することは重要です。

次に、私達の real ラベルを 1 としてそして fake ラベルを 0 として定義します。これらのラベルは \(D\) と \(G\) の損失を計算するときに使用されて、そしてこれはまた元の GAN ペーパーで使用された慣習です。最後に、2 つの個別の optimizer をセットアップします、一つは $D$ のためで一つは $G$ のためです。DCGAN ペーパーで指定されているように、両者は学習率 0.0002 と Beta1 = 0.5 を持つ Adam optimizer です。generator の学習進捗を追跡するために、ガウス分布からドローされた潜在ベクトルの固定バッチを生成します (i.e. fixed_noise) 。訓練ループでは、この fixed_noise を $G$ に定期的に入力します、そして反復に渡りノイズから形成される画像を見ます。

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

 

訓練

最後に、GAN フレームワークの総てのパートを定義した今、それを訓練できます。GAN を訓練することは幾分芸術形式であることに留意してください、というのは正しくないハイパーパラメータ設定は何が間違っていたか殆ど説明を持たないで崩壊モードに繋がるでしょう。ここでは、Goodfellow のペーパーから Algorithm 1 に密接に従います、その一方で ganhacks で示されるベストプラクティスの幾つかを守っています。すなわち、“construct different mini-batches for real and fake (画像)” を行ない、そしてまた \(logD(G(z))\) を最大化するために G の目的関数を調整します。訓練は 2 つの主要パートに分割されます。Part 1 は Discriminator を更新してそして Part 2 は Generator を更新します。

 
Part 1 – Discriminator を訓練する

思い出してください、discriminator を訓練するゴールは与えられた入力を real か fake として正しく分類する確率を最大化することです。Goodfellow の言葉では、“update the discriminator by ascending its stochastic gradient (その確率的勾配を登ることにより discriminator を更新する)” ことを望みます。実際には、\(log(D(x)) + log(1-D(G(z)))\) を最大化することを望みます。ganhacks からの separate したミニバッチ提案によって、これを 2 つのステップで計算します。最初に、訓練セットから real サンプルのバッチを構築して、\(D\) を通して forward パスして、損失 \(log(D(x))\)) を計算して、それから backward パスで勾配を計算します。2 番目に、現在の generator で fake サンプルのバッチを構築して、\(D\) を通してこのバッチを forward パスして 、損失 (\(log(1-D(G(z)))\)) を計算し、そして backward パスで勾配を累積します。今、all-real と all-fake の両者のバッチから累積された勾配で、Discriminator の optimizer のステップを呼び出します。

 
Part 2 – Generator を訓練する

元のペーパーで述べられているように、より良い fake を生成する努力において \(log(1-D(G(z)))\) を最小化することにより Generator を訓練することを望みます。言及したように、特に学習過程の早期では、これは十分な勾配を提供しないことが Goodfellow により示されました。その処置として、代わりに \(log(D(G(z)))\) を最大化することを望みます。コードではこれを以下により達成します : Part 1 からの Generator 出力を Discriminator で分類し、real ラベルを GT として使用して G の損失を計算し、backward パスで G の勾配を計算し、そして最後に optimizer ステップで G のパラメータを更新します。損失関数のために real ラベルを GT ラベルとして使用することは直感に反するかもしれませんが、これは BCELoss の (\(log(1-x)\) 部ではなく) \(log(x)\) 部を使用することを可能にし、これは正確に私達が望むものです。

最後に、幾つかの統計的レポーティングを行ないそして各エポックの最後に fixed_noise バッチを generator を通してプッシュして G の訓練の進捗を視覚的に追跡します。レポートされる訓練統計情報は :

  • Loss_D – 総ての real と総ての fake バッチのための損失の総計として計算された discriminator 損失 \(log(D(x)) + log(D(G(z)))\)) 。
  • Loss_G – \(log(D(G(z)))\) として計算される generator 損失
  • D(x) – 総ての real バッチに対する discriminator の (バッチに渡る) 平均出力。これは 1 近くから始まるはずで、それから理論的には G がより良くなるとき 0.5 に収束します。Think about why this is.
  • D(G(z)) – 総ての fake バッチ対する平均 discriminator 出力。最初の数は D が更新される前で 2 番目の数は D が更新される後です。これらの数は 0 近くで始まり G がより良くなるとき 0.5 に収束します。Think about why this is.

Note: このステップは少し時間がかかるかもしれません、幾つのエポックを実行するかそしてデータセットから幾つのデータを除去したか否かに依拠して。

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.9847  Loss_G: 5.5914  D(x): 0.6004    D(G(z)): 0.6680 / 0.0062
[0/5][50/1583]  Loss_D: 1.2821  Loss_G: 27.1449 D(x): 0.9201    D(G(z)): 0.6066 / 0.0000
[0/5][100/1583] Loss_D: 3.5666  Loss_G: 23.9629 D(x): 0.9697    D(G(z)): 0.8966 / 0.0000
[0/5][150/1583] Loss_D: 0.2043  Loss_G: 5.4659  D(x): 0.8647    D(G(z)): 0.0257 / 0.0077
[0/5][200/1583] Loss_D: 0.6839  Loss_G: 2.8157  D(x): 0.6325    D(G(z)): 0.0437 / 0.0860
[0/5][250/1583] Loss_D: 0.9931  Loss_G: 10.2738 D(x): 0.9582    D(G(z)): 0.5235 / 0.0001
[0/5][300/1583] Loss_D: 2.1958  Loss_G: 2.8396  D(x): 0.2550    D(G(z)): 0.0044 / 0.1051
[0/5][350/1583] Loss_D: 0.4682  Loss_G: 3.5296  D(x): 0.8502    D(G(z)): 0.2111 / 0.0450
[0/5][400/1583] Loss_D: 0.4249  Loss_G: 3.7026  D(x): 0.7436    D(G(z)): 0.0367 / 0.0417
[0/5][450/1583] Loss_D: 0.3498  Loss_G: 7.1409  D(x): 0.9223    D(G(z)): 0.1850 / 0.0030
[0/5][500/1583] Loss_D: 0.6108  Loss_G: 7.7019  D(x): 0.9020    D(G(z)): 0.3498 / 0.0011
[0/5][550/1583] Loss_D: 0.6464  Loss_G: 3.7485  D(x): 0.7302    D(G(z)): 0.1707 / 0.0627
[0/5][600/1583] Loss_D: 0.3570  Loss_G: 4.6648  D(x): 0.8717    D(G(z)): 0.1558 / 0.0155
[0/5][650/1583] Loss_D: 0.4305  Loss_G: 5.3618  D(x): 0.9572    D(G(z)): 0.2779 / 0.0111
[0/5][700/1583] Loss_D: 1.3140  Loss_G: 11.5786 D(x): 0.9664    D(G(z)): 0.5969 / 0.0011
[0/5][750/1583] Loss_D: 0.3630  Loss_G: 4.9583  D(x): 0.8572    D(G(z)): 0.1514 / 0.0132
[0/5][800/1583] Loss_D: 0.3699  Loss_G: 5.2018  D(x): 0.8577    D(G(z)): 0.1481 / 0.0122
[0/5][850/1583] Loss_D: 0.1652  Loss_G: 5.4559  D(x): 0.9043    D(G(z)): 0.0371 / 0.0093
[0/5][900/1583] Loss_D: 1.1085  Loss_G: 1.9395  D(x): 0.4598    D(G(z)): 0.0200 / 0.2342
[0/5][950/1583] Loss_D: 0.3559  Loss_G: 3.6596  D(x): 0.8194    D(G(z)): 0.0927 / 0.0498
[0/5][1000/1583]        Loss_D: 0.4936  Loss_G: 5.2182  D(x): 0.8310    D(G(z)): 0.2068 / 0.0106
[0/5][1050/1583]        Loss_D: 0.8654  Loss_G: 9.1896  D(x): 0.9585    D(G(z)): 0.4688 / 0.0004
[0/5][1100/1583]        Loss_D: 0.5183  Loss_G: 2.6914  D(x): 0.7115    D(G(z)): 0.0774 / 0.1028
[0/5][1150/1583]        Loss_D: 0.2161  Loss_G: 5.3966  D(x): 0.9458    D(G(z)): 0.1339 / 0.0083
[0/5][1200/1583]        Loss_D: 0.4653  Loss_G: 4.3271  D(x): 0.7217    D(G(z)): 0.0179 / 0.0272
[0/5][1250/1583]        Loss_D: 1.3157  Loss_G: 10.5584 D(x): 0.9759    D(G(z)): 0.6598 / 0.0001
[0/5][1300/1583]        Loss_D: 0.3837  Loss_G: 4.9804  D(x): 0.9150    D(G(z)): 0.2083 / 0.0143
[0/5][1350/1583]        Loss_D: 0.6663  Loss_G: 3.3218  D(x): 0.6151    D(G(z)): 0.0332 / 0.0713
[0/5][1400/1583]        Loss_D: 1.2427  Loss_G: 6.6631  D(x): 0.9398    D(G(z)): 0.6179 / 0.0035
[0/5][1450/1583]        Loss_D: 0.6161  Loss_G: 5.5180  D(x): 0.8229    D(G(z)): 0.2724 / 0.0091
[0/5][1500/1583]        Loss_D: 0.7386  Loss_G: 4.1078  D(x): 0.8265    D(G(z)): 0.3423 / 0.0295
[0/5][1550/1583]        Loss_D: 0.7012  Loss_G: 6.4715  D(x): 0.9362    D(G(z)): 0.3971 / 0.0029
[1/5][0/1583]   Loss_D: 2.9760  Loss_G: 10.0054 D(x): 0.9849    D(G(z)): 0.8227 / 0.0006
[1/5][50/1583]  Loss_D: 0.5924  Loss_G: 3.0524  D(x): 0.7396    D(G(z)): 0.1524 / 0.0741
[1/5][100/1583] Loss_D: 0.7677  Loss_G: 5.8320  D(x): 0.9481    D(G(z)): 0.4500 / 0.0057
[1/5][150/1583] Loss_D: 0.6516  Loss_G: 2.2812  D(x): 0.6712    D(G(z)): 0.1371 / 0.1652
[1/5][200/1583] Loss_D: 0.4927  Loss_G: 3.9598  D(x): 0.9099    D(G(z)): 0.2689 / 0.0345
[1/5][250/1583] Loss_D: 0.3345  Loss_G: 4.8404  D(x): 0.9267    D(G(z)): 0.2030 / 0.0125
[1/5][300/1583] Loss_D: 0.4107  Loss_G: 2.9924  D(x): 0.8283    D(G(z)): 0.1565 / 0.0777
[1/5][350/1583] Loss_D: 1.2549  Loss_G: 1.0486  D(x): 0.3975    D(G(z)): 0.0398 / 0.4263
[1/5][400/1583] Loss_D: 0.2304  Loss_G: 4.1551  D(x): 0.8473    D(G(z)): 0.0469 / 0.0241
[1/5][450/1583] Loss_D: 0.5795  Loss_G: 2.9901  D(x): 0.6567    D(G(z)): 0.0567 / 0.0857
[1/5][500/1583] Loss_D: 0.4109  Loss_G: 3.2403  D(x): 0.8181    D(G(z)): 0.1448 / 0.0647
[1/5][550/1583] Loss_D: 1.0792  Loss_G: 8.3332  D(x): 0.9176    D(G(z)): 0.5636 / 0.0005
[1/5][600/1583] Loss_D: 0.3855  Loss_G: 3.1238  D(x): 0.8530    D(G(z)): 0.1665 / 0.0658
[1/5][650/1583] Loss_D: 0.4945  Loss_G: 2.8641  D(x): 0.8035    D(G(z)): 0.1759 / 0.0847
[1/5][700/1583] Loss_D: 1.5954  Loss_G: 1.5269  D(x): 0.3392    D(G(z)): 0.0072 / 0.3193
[1/5][750/1583] Loss_D: 0.4016  Loss_G: 3.6690  D(x): 0.8822    D(G(z)): 0.2126 / 0.0363
[1/5][800/1583] Loss_D: 0.4503  Loss_G: 2.9058  D(x): 0.7179    D(G(z)): 0.0467 / 0.0806
[1/5][850/1583] Loss_D: 0.3217  Loss_G: 3.7841  D(x): 0.8585    D(G(z)): 0.1249 / 0.0369
[1/5][900/1583] Loss_D: 0.7214  Loss_G: 2.4609  D(x): 0.5798    D(G(z)): 0.0452 / 0.1310
[1/5][950/1583] Loss_D: 0.7223  Loss_G: 5.1093  D(x): 0.9782    D(G(z)): 0.4242 / 0.0146
[1/5][1000/1583]        Loss_D: 0.5028  Loss_G: 2.0050  D(x): 0.7228    D(G(z)): 0.1010 / 0.1813
[1/5][1050/1583]        Loss_D: 0.3377  Loss_G: 4.4794  D(x): 0.9293    D(G(z)): 0.2026 / 0.0213
[1/5][1100/1583]        Loss_D: 1.8651  Loss_G: 8.1047  D(x): 0.9523    D(G(z)): 0.7664 / 0.0009
[1/5][1150/1583]        Loss_D: 1.0144  Loss_G: 7.3735  D(x): 0.9708    D(G(z)): 0.5438 / 0.0013
[1/5][1200/1583]        Loss_D: 0.7163  Loss_G: 1.6915  D(x): 0.5856    D(G(z)): 0.0409 / 0.2580
[1/5][1250/1583]        Loss_D: 0.4433  Loss_G: 4.2207  D(x): 0.8766    D(G(z)): 0.2377 / 0.0234
[1/5][1300/1583]        Loss_D: 0.8081  Loss_G: 5.1962  D(x): 0.9104    D(G(z)): 0.4510 / 0.0088
[1/5][1350/1583]        Loss_D: 0.9024  Loss_G: 5.4644  D(x): 0.9482    D(G(z)): 0.5014 / 0.0084
[1/5][1400/1583]        Loss_D: 0.6234  Loss_G: 0.9321  D(x): 0.6405    D(G(z)): 0.0920 / 0.4512
[1/5][1450/1583]        Loss_D: 0.5227  Loss_G: 3.8051  D(x): 0.8997    D(G(z)): 0.3084 / 0.0306
[1/5][1500/1583]        Loss_D: 0.4734  Loss_G: 2.5366  D(x): 0.7280    D(G(z)): 0.0874 / 0.1156
[1/5][1550/1583]        Loss_D: 0.6074  Loss_G: 3.7703  D(x): 0.8886    D(G(z)): 0.3458 / 0.0320
[2/5][0/1583]   Loss_D: 0.4978  Loss_G: 2.1119  D(x): 0.7535    D(G(z)): 0.1516 / 0.1538
[2/5][50/1583]  Loss_D: 0.8479  Loss_G: 1.0940  D(x): 0.5147    D(G(z)): 0.0348 / 0.3839
[2/5][100/1583] Loss_D: 0.5531  Loss_G: 4.7267  D(x): 0.9178    D(G(z)): 0.3359 / 0.0145
[2/5][150/1583] Loss_D: 0.6142  Loss_G: 2.4585  D(x): 0.6621    D(G(z)): 0.1180 / 0.1228
[2/5][200/1583] Loss_D: 0.6200  Loss_G: 3.1059  D(x): 0.6215    D(G(z)): 0.0527 / 0.0682
[2/5][250/1583] Loss_D: 0.4501  Loss_G: 2.9026  D(x): 0.8017    D(G(z)): 0.1746 / 0.0761
[2/5][300/1583] Loss_D: 0.9992  Loss_G: 0.9579  D(x): 0.4551    D(G(z)): 0.0633 / 0.4272
[2/5][350/1583] Loss_D: 0.4531  Loss_G: 3.2955  D(x): 0.8345    D(G(z)): 0.2097 / 0.0569
[2/5][400/1583] Loss_D: 1.4651  Loss_G: 0.8333  D(x): 0.3276    D(G(z)): 0.0168 / 0.5268
[2/5][450/1583] Loss_D: 0.4721  Loss_G: 3.1386  D(x): 0.8820    D(G(z)): 0.2601 / 0.0592
[2/5][500/1583] Loss_D: 1.0095  Loss_G: 1.3625  D(x): 0.4675    D(G(z)): 0.0382 / 0.2987
[2/5][550/1583] Loss_D: 0.4755  Loss_G: 3.1508  D(x): 0.8616    D(G(z)): 0.2491 / 0.0577
[2/5][600/1583] Loss_D: 0.6649  Loss_G: 3.6077  D(x): 0.7807    D(G(z)): 0.3009 / 0.0424
[2/5][650/1583] Loss_D: 1.2826  Loss_G: 4.5301  D(x): 0.9652    D(G(z)): 0.6521 / 0.0197
[2/5][700/1583] Loss_D: 0.5502  Loss_G: 3.7380  D(x): 0.9252    D(G(z)): 0.3314 / 0.0350
[2/5][750/1583] Loss_D: 1.4062  Loss_G: 5.1900  D(x): 0.9003    D(G(z)): 0.6633 / 0.0097
[2/5][800/1583] Loss_D: 0.9581  Loss_G: 4.0939  D(x): 0.8554    D(G(z)): 0.5062 / 0.0245
[2/5][850/1583] Loss_D: 0.5956  Loss_G: 2.6712  D(x): 0.9178    D(G(z)): 0.3573 / 0.0977
[2/5][900/1583] Loss_D: 1.7876  Loss_G: 1.4199  D(x): 0.2265    D(G(z)): 0.0166 / 0.3049
[2/5][950/1583] Loss_D: 0.5993  Loss_G: 3.2266  D(x): 0.8728    D(G(z)): 0.3331 / 0.0527
[2/5][1000/1583]        Loss_D: 0.6875  Loss_G: 1.6206  D(x): 0.5945    D(G(z)): 0.0617 / 0.2494
[2/5][1050/1583]        Loss_D: 0.5171  Loss_G: 2.8921  D(x): 0.8456    D(G(z)): 0.2624 / 0.0747
[2/5][1100/1583]        Loss_D: 0.5507  Loss_G: 2.5814  D(x): 0.6882    D(G(z)): 0.1103 / 0.1036
[2/5][1150/1583]        Loss_D: 0.4891  Loss_G: 3.0394  D(x): 0.7822    D(G(z)): 0.1805 / 0.0715
[2/5][1200/1583]        Loss_D: 0.9422  Loss_G: 1.0834  D(x): 0.4981    D(G(z)): 0.1049 / 0.3744
[2/5][1250/1583]        Loss_D: 0.9472  Loss_G: 0.7164  D(x): 0.4634    D(G(z)): 0.0561 / 0.5296
[2/5][1300/1583]        Loss_D: 0.4667  Loss_G: 2.9211  D(x): 0.7725    D(G(z)): 0.1541 / 0.0748
[2/5][1350/1583]        Loss_D: 0.6084  Loss_G: 4.2239  D(x): 0.9073    D(G(z)): 0.3590 / 0.0216
[2/5][1400/1583]        Loss_D: 0.6736  Loss_G: 2.0391  D(x): 0.6200    D(G(z)): 0.1129 / 0.1762
[2/5][1450/1583]        Loss_D: 0.5695  Loss_G: 2.0845  D(x): 0.6678    D(G(z)): 0.1006 / 0.1636
[2/5][1500/1583]        Loss_D: 0.7781  Loss_G: 3.9123  D(x): 0.9292    D(G(z)): 0.4600 / 0.0306
[2/5][1550/1583]        Loss_D: 0.5800  Loss_G: 2.8121  D(x): 0.7136    D(G(z)): 0.1664 / 0.0841
[3/5][0/1583]   Loss_D: 0.6866  Loss_G: 1.2368  D(x): 0.5888    D(G(z)): 0.0629 / 0.3474
[3/5][50/1583]  Loss_D: 0.4827  Loss_G: 2.7606  D(x): 0.7790    D(G(z)): 0.1755 / 0.0903
[3/5][100/1583] Loss_D: 1.1030  Loss_G: 4.3139  D(x): 0.8421    D(G(z)): 0.5342 / 0.0220
[3/5][150/1583] Loss_D: 1.0086  Loss_G: 4.4525  D(x): 0.9477    D(G(z)): 0.5511 / 0.0192
[3/5][200/1583] Loss_D: 0.6956  Loss_G: 2.5286  D(x): 0.7302    D(G(z)): 0.2661 / 0.1020
[3/5][250/1583] Loss_D: 1.0426  Loss_G: 3.9627  D(x): 0.8986    D(G(z)): 0.5501 / 0.0298
[3/5][300/1583] Loss_D: 0.6225  Loss_G: 2.6796  D(x): 0.7152    D(G(z)): 0.1908 / 0.0887
[3/5][350/1583] Loss_D: 1.3732  Loss_G: 5.0278  D(x): 0.9563    D(G(z)): 0.6832 / 0.0094
[3/5][400/1583] Loss_D: 0.4851  Loss_G: 2.8528  D(x): 0.8215    D(G(z)): 0.2247 / 0.0756
[3/5][450/1583] Loss_D: 0.6482  Loss_G: 3.6800  D(x): 0.8935    D(G(z)): 0.3771 / 0.0379
[3/5][500/1583] Loss_D: 1.2615  Loss_G: 4.2520  D(x): 0.8589    D(G(z)): 0.6115 / 0.0250
[3/5][550/1583] Loss_D: 0.6365  Loss_G: 1.8921  D(x): 0.6713    D(G(z)): 0.1606 / 0.1847
[3/5][600/1583] Loss_D: 0.5330  Loss_G: 2.7128  D(x): 0.8131    D(G(z)): 0.2437 / 0.0856
[3/5][650/1583] Loss_D: 1.9794  Loss_G: 0.5431  D(x): 0.2068    D(G(z)): 0.0397 / 0.6388
[3/5][700/1583] Loss_D: 0.6074  Loss_G: 2.2992  D(x): 0.7487    D(G(z)): 0.2274 / 0.1330
[3/5][750/1583] Loss_D: 0.5139  Loss_G: 2.1525  D(x): 0.7426    D(G(z)): 0.1422 / 0.1570
[3/5][800/1583] Loss_D: 0.8958  Loss_G: 3.7582  D(x): 0.9187    D(G(z)): 0.5037 / 0.0336
[3/5][850/1583] Loss_D: 0.7184  Loss_G: 2.6059  D(x): 0.7339    D(G(z)): 0.2741 / 0.1005
[3/5][900/1583] Loss_D: 0.5275  Loss_G: 3.8261  D(x): 0.8821    D(G(z)): 0.2905 / 0.0327
[3/5][950/1583] Loss_D: 0.6070  Loss_G: 2.4334  D(x): 0.7912    D(G(z)): 0.2790 / 0.1101
[3/5][1000/1583]        Loss_D: 0.7844  Loss_G: 2.7326  D(x): 0.7999    D(G(z)): 0.3798 / 0.0929
[3/5][1050/1583]        Loss_D: 0.8756  Loss_G: 2.1391  D(x): 0.6474    D(G(z)): 0.2944 / 0.1412
[3/5][1100/1583]        Loss_D: 1.0252  Loss_G: 4.8418  D(x): 0.9227    D(G(z)): 0.5351 / 0.0146
[3/5][1150/1583]        Loss_D: 0.9527  Loss_G: 1.2363  D(x): 0.4865    D(G(z)): 0.1076 / 0.3313
[3/5][1200/1583]        Loss_D: 0.8031  Loss_G: 3.0929  D(x): 0.7701    D(G(z)): 0.3550 / 0.0674
[3/5][1250/1583]        Loss_D: 0.6219  Loss_G: 1.7979  D(x): 0.7600    D(G(z)): 0.2595 / 0.2005
[3/5][1300/1583]        Loss_D: 0.7066  Loss_G: 2.9271  D(x): 0.8065    D(G(z)): 0.3317 / 0.0760
[3/5][1350/1583]        Loss_D: 0.5467  Loss_G: 1.4356  D(x): 0.6813    D(G(z)): 0.1080 / 0.2827
[3/5][1400/1583]        Loss_D: 0.5424  Loss_G: 3.0990  D(x): 0.9155    D(G(z)): 0.3330 / 0.0591
[3/5][1450/1583]        Loss_D: 1.1482  Loss_G: 0.7654  D(x): 0.4039    D(G(z)): 0.0690 / 0.5224
[3/5][1500/1583]        Loss_D: 0.6861  Loss_G: 2.4627  D(x): 0.7815    D(G(z)): 0.3135 / 0.1112
[3/5][1550/1583]        Loss_D: 0.5983  Loss_G: 3.3769  D(x): 0.8517    D(G(z)): 0.3215 / 0.0449
[4/5][0/1583]   Loss_D: 0.6654  Loss_G: 2.2644  D(x): 0.7580    D(G(z)): 0.2801 / 0.1315
[4/5][50/1583]  Loss_D: 0.5557  Loss_G: 2.2323  D(x): 0.7087    D(G(z)): 0.1515 / 0.1395
[4/5][100/1583] Loss_D: 0.6370  Loss_G: 1.6936  D(x): 0.6772    D(G(z)): 0.1802 / 0.2180
[4/5][150/1583] Loss_D: 0.6787  Loss_G: 1.0779  D(x): 0.6064    D(G(z)): 0.1114 / 0.3910
[4/5][200/1583] Loss_D: 0.6140  Loss_G: 2.5970  D(x): 0.7473    D(G(z)): 0.2257 / 0.1048
[4/5][250/1583] Loss_D: 0.5363  Loss_G: 3.5347  D(x): 0.9215    D(G(z)): 0.3422 / 0.0362
[4/5][300/1583] Loss_D: 1.5164  Loss_G: 1.3299  D(x): 0.3114    D(G(z)): 0.0342 / 0.3438
[4/5][350/1583] Loss_D: 0.5139  Loss_G: 2.2716  D(x): 0.7593    D(G(z)): 0.1762 / 0.1264
[4/5][400/1583] Loss_D: 1.2510  Loss_G: 4.8018  D(x): 0.9065    D(G(z)): 0.6168 / 0.0124
[4/5][450/1583] Loss_D: 0.7356  Loss_G: 2.0707  D(x): 0.7513    D(G(z)): 0.3120 / 0.1564
[4/5][500/1583] Loss_D: 0.6553  Loss_G: 3.6002  D(x): 0.9199    D(G(z)): 0.3942 / 0.0383
[4/5][550/1583] Loss_D: 0.7996  Loss_G: 1.5098  D(x): 0.5130    D(G(z)): 0.0391 / 0.2819
[4/5][600/1583] Loss_D: 0.5633  Loss_G: 1.5441  D(x): 0.6907    D(G(z)): 0.1265 / 0.2745
[4/5][650/1583] Loss_D: 0.6551  Loss_G: 3.2228  D(x): 0.8717    D(G(z)): 0.3562 / 0.0558
[4/5][700/1583] Loss_D: 0.6168  Loss_G: 3.7609  D(x): 0.8887    D(G(z)): 0.3581 / 0.0336
[4/5][750/1583] Loss_D: 1.0102  Loss_G: 1.7060  D(x): 0.5355    D(G(z)): 0.2248 / 0.2284
[4/5][800/1583] Loss_D: 0.7609  Loss_G: 1.4617  D(x): 0.5551    D(G(z)): 0.0695 / 0.2792
[4/5][850/1583] Loss_D: 0.6161  Loss_G: 1.9848  D(x): 0.6841    D(G(z)): 0.1563 / 0.2069
[4/5][900/1583] Loss_D: 0.5277  Loss_G: 1.6981  D(x): 0.7130    D(G(z)): 0.1415 / 0.2142
[4/5][950/1583] Loss_D: 0.5122  Loss_G: 2.2607  D(x): 0.7418    D(G(z)): 0.1627 / 0.1325
[4/5][1000/1583]        Loss_D: 0.6613  Loss_G: 4.3808  D(x): 0.9238    D(G(z)): 0.4027 / 0.0185
[4/5][1050/1583]        Loss_D: 0.6141  Loss_G: 1.8883  D(x): 0.6740    D(G(z)): 0.1380 / 0.1894
[4/5][1100/1583]        Loss_D: 0.8745  Loss_G: 4.8708  D(x): 0.9070    D(G(z)): 0.4908 / 0.0119
[4/5][1150/1583]        Loss_D: 0.9964  Loss_G: 0.4272  D(x): 0.4654    D(G(z)): 0.0773 / 0.6834
[4/5][1200/1583]        Loss_D: 0.6035  Loss_G: 2.7608  D(x): 0.8706    D(G(z)): 0.3290 / 0.0830
[4/5][1250/1583]        Loss_D: 0.8481  Loss_G: 2.4196  D(x): 0.6765    D(G(z)): 0.2971 / 0.1181
[4/5][1300/1583]        Loss_D: 0.4499  Loss_G: 2.4512  D(x): 0.7474    D(G(z)): 0.1212 / 0.1163
[4/5][1350/1583]        Loss_D: 0.5246  Loss_G: 3.6333  D(x): 0.9294    D(G(z)): 0.3325 / 0.0342
[4/5][1400/1583]        Loss_D: 0.5456  Loss_G: 2.5568  D(x): 0.8039    D(G(z)): 0.2386 / 0.1028
[4/5][1450/1583]        Loss_D: 0.6094  Loss_G: 2.5911  D(x): 0.8320    D(G(z)): 0.3132 / 0.0932
[4/5][1500/1583]        Loss_D: 0.4767  Loss_G: 2.7012  D(x): 0.8140    D(G(z)): 0.2076 / 0.0863
[4/5][1550/1583]        Loss_D: 0.4742  Loss_G: 2.5811  D(x): 0.7891    D(G(z)): 0.1840 / 0.0966

 

結果

最後に、私達がどのように行なったかを確認しましょう。ここでは、3 つの異なる結果を見ます。最初に、D と G の損失が訓練の間にどのように変わったかを見ます。2 番目に、総てのエポックについて fixed_noise バッチ上の G の出力を可視化します。そして 3 番目に、G からの fake データのバッチの次に real データのバッチを見ます。

 
損失 versus 訓練反復

下は D & G の損失 vs 訓練反復のプロットです。

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

 
G の進捗の可視化

訓練の総てのエポック後に fixed_noise バッチ上の generator の出力をどのようにセーブしたかを思い出してください。今、G の訓練進捗をアニメーションで可視化できます。アニメーションを開始するには play ボタンを押してください。

#%%capture
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

 
Real 画像 vs. Fake 画像

最後に、幾つかの real 画像と fake 画像を並べて見てみましょう。

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
plt.show()

 

WHERE TO GO NEXT

旅の最後に到達しましたが、ここから行ける幾つかの場所があります。貴方は以下ができるでしょう :

  • より長く訓練して結果がどのように良くなるかを見ます。
  • 異なるデータセットを取るためにモデルを変更します、多分画像のサイズとモデル・アーキテクチャを変更します。
  • ここ で幾つかの他のクールな GAN プロジェクトを調べます。
  • 音楽 を生成する GAN を作成します。
 

以上