Pyro 1.4 : Examples : 生成モデルによる高速情景理解

Pyro 1.4 : Examples : 生成モデルによる高速情景理解 (翻訳)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/06/2020 (1.4.0)

* 本ページは、Pyro の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

生成モデルによる高速情景理解

このチュートリアルでは “Attend, Infer, Repeat: Fast Scene Understanding with Generative Models” (AIR) [1] で説明されるモデルと推論ストラテジーを実装してそれをマルチ-mnist データセットに適用します。

スタンドアロン実装 も利用可能です。

%pylab inline
import os
from collections import namedtuple
import pyro
import pyro.optim as optim
from pyro.infer import SVI, TraceGraph_ELBO
import pyro.distributions as dist
import pyro.poutine as poutine
import pyro.contrib.examples.multi_mnist as multi_mnist
import torch
import torch.nn as nn
from torch.nn.functional import relu, sigmoid, softplus, grid_sample, affine_grid
import numpy as np

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.4.0')
pyro.enable_validation(True)

 

イントロダクション

[1] で説明されているモデルは情景の生成モデルです。このチュートリアルでは [1] のマルチ-mnist データセットに類似したデータセットから画像をモデル化するためにそれを使用します。ここにこのデータセットからの幾つかのデータポイントがあります :

inpath = '../../examples/air/.data'
X_np, _ = multi_mnist.load(inpath)
X_np = X_np.astype(np.float32)
X_np /= 255.0
mnist = torch.from_numpy(X_np)
def show_images(imgs):
    figure(figsize=(8, 2))
    for i, img in enumerate(imgs):
        subplot(1, len(imgs), i + 1)
        axis('off')
        imshow(img.data.numpy(), cmap='gray')
show_images(mnist[9:14])

どこに進んでいるかのアイデアを得るため、最初にモデルと推論するために取るアプローチの簡潔な概要を与えます。できる限り密接に [1] で使用された命名規則に従います。

AIR は画像を生成する過程を離散ステップに分解します、それらの各々は画像の一部だけを生成します。より具体的には、各ステップでニューラルネットワークを通して潜在「コード」変数 (z_what) を渡すことによりモデルは小さい画像 (y_att) を生成します。これらの小さい画像を「オブジェクト」として参照します。マルチ- mnist データセットに適用される AIR のケースではこれらのオブジェクトの各々が単一の数字を表すことを想定します。モデルはまた各オブジェクトの位置とサイズについての不確かさも含みます。オブジェクトの位置とサイズをその「ポーズ」(z_where) として記述します。最終的な画像を生成するため、各オブジェクトはポーズ情報 z_where を使用して最初はより大きい画像 (y) 内に位置します。最後に、総ての時間ステップからの y 群は最終的な画像 x を生成するために追加的に (= additively) 結合されます。

ここにこの過程の 2 ステップを示す ([1] から再生成された) 図があります :


Figure 1: 生成過程の 2 ステップ

推論はこのモデルで amortized 確率的変分推論 (SVI) を使用して遂行されます。ニューラルネットワークのパラメータはまた推論の間に最適化されます。そのようなリッチなモデルで推論を遂行することは常に困難ですが、離散的な選択 (このケースではステップ数) の存在はこのモデルの推論を特にトリッキーにします。この理由のために著者は良いパフォーマンスを得るためにデータ依存ベースラインと呼ばれるテクニックを使用しています。このテクニックは Pyro で実装できて、そしてチュートリアルで後でどのようにかを見ます。

 

モデル

単一オブジェクトを生成する

モデルをより密接に見ましょう。モデルの中心部には単一オブジェクトの生成過程があります。以下を思い出してください :

  • 各ステップで単一オブジェクトが生成される。
  • 各オブジェクトはその潜在コードをニューラルネットワークに渡すことにより生成される。
  • 各オブジェクトを生成されるために使用される潜在コードについて不確かさを維持します、そのポーズとともに。

これは Pyro で次のように表すことができます :

# Create the neural network. This takes a latent code, z_what, to pixel intensities.
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(50, 200)
        self.l2 = nn.Linear(200, 400)

    def forward(self, z_what):
        h = relu(self.l1(z_what))
        return sigmoid(self.l2(h))

decode = Decoder()

z_where_prior_loc = torch.tensor([3., 0., 0.])
z_where_prior_scale = torch.tensor([0.1, 1., 1.])
z_what_prior_loc = torch.zeros(50)
z_what_prior_scale = torch.ones(50)

def prior_step_sketch(t):
    # Sample object pose. This is a 3-dimensional vector representing x,y position and size.
    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.Normal(z_where_prior_loc.expand(1, -1),
                                      z_where_prior_scale.expand(1, -1))
                              .to_event(1))

    # Sample object code. This is a 50-dimensional vector.
    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.Normal(z_what_prior_loc.expand(1, -1),
                                     z_what_prior_scale.expand(1, -1))
                             .to_event(1))

    # Map code to pixel space using the neural network.
    y_att = decode(z_what)

    # Position/scale object within larger image.
    y = object_to_image(z_where, y_att)

    return y

望ましくはこの時点で pyro.sample の使用とモデル内の PyTorch ネットワークに馴染みがあるように見えることです。そうでないなら VAE チュートリアル を見直すことを望むかもしれません。注意すべき一つのことは名前がステップに渡り一意であることを確実にするために pyro.sample に渡される名前に現在のステップ t を含むことです。

object_to_image 関数はこのモデルに固有で更なる注意が必要です。ニューラルネットワーク (ここではデコード) は小さい画像を出力し、z_where で記述されるポーズ (位置とサイズ) を獲得するために必要な任意の変換とスケーリングを遂行した後これを出力画像に追加したいことを思い出してください。これをどのように行なうかは明らかではなく、特にこれが (SVI を遂行するために必要とする) モデルの微分可能性を保存する方法で実装できるか明白ではありません。けれども、spatial transformer ネットワーク (STN) [2] を使用してこれを行えることが判明しています。

私達にとって幸いなことに、PyTorch は grid_sampleaffine_grid 関数を使用して STN を実装することを容易にします。object_to_image は単純な関数で、これらを呼び出し、z_where を想定する形式にマッサージする少し余分の作業を行ないます。

def expand_z_where(z_where):
    # Takes 3-dimensional vectors, and massages them into 2x3 matrices with elements like so:
    # [s,x,y] -> [[s,0,x],
    #             [0,s,y]]
    n = z_where.size(0)
    expansion_indices = torch.LongTensor([1, 0, 2, 0, 1, 3])
    out = torch.cat((torch.zeros([1, 1]).expand(n, 1), z_where), 1)
    return torch.index_select(out, 1, expansion_indices).view(n, 2, 3)

def object_to_image(z_where, obj):
    n = obj.size(0)
    theta = expand_z_where(z_where)
    grid = affine_grid(theta, torch.Size((n, 1, 50, 50)))
    out = grid_sample(obj.view(n, 1, 20, 20), grid)
    return out.view(n, 50, 50)

STN の詳細の議論はこのチュートリアルの範囲を越えています。けれども私達の目的のためには、object_to_image はニューラルネットワークにより生成された小さい画像を取りそしてそれを望まれるポーズを持つより大きな画像内に配置することに留意すれば十分です。

これを明らかにするために prior_step_sketch を何回か呼び出した結果を可視化しましょう :

pyro.set_rng_seed(0)
samples = [prior_step_sketch(0)[0] for _ in range(5)]
show_images(samples)

 

画像全体を生成する

単一ステップの実装を完了し、次に画像全体を生成するためにこれをどのように使用できるか考えます。各データポイントを生成するために使用されるステップ数に渡り不確かさを維持したいことを思い出してください。ステップ数に渡る事前分布のために行える一つの選択は幾何分布で、これは次のように表せます :

pyro.set_rng_seed(0)
def geom(num_trials=0):
    p = torch.tensor([0.5])
    x = pyro.sample('x{}'.format(num_trials), dist.Bernoulli(p))
    if x[0] == 1:
        return num_trials
    else:
        return geom(num_trials + 1)

# Generate some samples.
for _ in range(5):
    print('sampled {}'.format(geom()))
sampled 2
sampled 3
sampled 0
sampled 1
sampled 0

これはベルヌーイ試行のシリーズの成功の前の失敗の数としての幾何分布の定義の直接的な翻訳です。ここでこれを再帰関数として表します、これは行なわれる試行数 num_trials を表すカウンターまわりを渡します。この関数はベルヌーイからサンプリングして x==1 (これは成功を表します) ならば num_trials を返し、そうでなければ再帰呼出しをしてカウンターを増加させます。

幾何事前分布の使用は好ましいです、何故ならばそれモデルが使用できるステップ数を先験的には束縛しないからです。それはまた便利です、何故ならば各再帰呼出しの前にオブジェクトを生成するために幾何分布を拡張することにより、これをカウントに渡る幾何分布から幾何的に分布したステップを持つ画像に渡る分布に変えるからです。

def geom_prior(x, step=0):
    p = torch.tensor([0.5])
    i = pyro.sample('i{}'.format(step), dist.Bernoulli(p))
    if i[0] == 1:
        return x
    else:
        x = x + prior_step_sketch(step)
        return geom_prior(x, step + 1)

この分布から幾つかサンプルを可視化しましょう :

pyro.set_rng_seed(4)
x_empty = torch.zeros(1, 50, 50)
samples = [geom_prior(x_empty)[0] for _ in range(5)]
show_images(samples)

 

尤度を指定する

モデルの使用を完結させるために必要な最後のものは尤度関数です。[1] に従って 0.3 の固定標準偏差を持つガウシアン尤度を使用します。これは obs 引数を使用して pyro.sample で実装することは簡単です。

後で推論を遂行するようになるとき事前 (確率) と尤度を一つの関数にパッケージ化することが便利であることを見出すでしょう。これはまた plate を導入するに便利な場所で、これをデータ・サブサンプリングを実装するために使用します、そしてネットワークを登録するために pyro.module を最適化したいです。

def model(data):
    # Register network for optimization.
    pyro.module("decode", decode)
    with pyro.plate('data', data.size(0)) as indices:
        batch = data[indices]
        x = prior(batch.size(0)).view(-1, 50 * 50)
        sd = (0.3 * torch.ones(1)).expand_as(x)
        pyro.sample('obs', dist.Normal(x, sd).to_event(1),
                    obs=batch)

 

ガイド

[1] に従ってこのモデルで amortized 確率的変分推論 を遂行します。Pyro はこの推論ストラテジーの殆どを実装する一般目的メカニズムを提供しますが、先のチュートリアルで見たようにモデル固有のガイドを提供することが要求されます。Pyro でガイドと呼称するものは正確に、ペーパーで「推論ネットワーク」と呼ばれる実在です。

真の事後 (確率) に在ることを想定する依存性 (の一部) を捕捉することをガイドに可能にするためにリカレント・ネットワーク回りにガイドを構築します。各ステップでリカレントネットワークはステップ内で行なわれる選択のためのパラメータを生成します。サンプリングされた値はリカレントネットワークに供給し戻されて、その結果この情報は次のステップのためのパラメータを計算するとき利用できます。深層マルコフモデル のためのガイドも同様の構造を共有します。

モデルでのように、ガイドの中心は単一ステップのためのロジックです。ここにこれの実装の概略があります :

def guide_step_basic(t, data, prev):

    # The RNN takes the images and choices from the previous step as input.
    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))

    # Compute parameters for all choices made this step, by passing
    # the RNN hidden start through another neural network.
    z_pres_p, z_where_loc, z_where_scale, z_what_loc, z_what_scale = predict_basic(h)

    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.Bernoulli(z_pres_p * prev.z_pres))

    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.Normal(z_where_loc, z_where_scale))

    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.Normal(z_what_loc, z_what_scale))

    return # values for next step

これはこのモデルで利用するために合理的なガイドですが、ペーパーは上のコードに行える重要な改良を説明しています。ガイドは各ステップでオブジェクトのポーズとその潜在コードについての情報を出力することを思い出してください。私達が行なうことができる改良は次の観察に基づいています: ひとたびオブジェクトのポーズを推論したならば、入力画像からオブジェクトをクロップするためにポーズ情報を利用して、潜在コードのパラメータを計算するために追加のネットワークに結果を渡す場合、その潜在コードを推論するより良いジョブを行えます。このネットワークを下で「エンコーダ」と呼称します。

ここにこの改良されたガイドをどのように実装できるか、そして関係するネットワークの具体的な実装ががあります :

rnn = nn.LSTMCell(2554, 256)

# Takes pixel intensities of the attention window to parameters (mean,
# standard deviation) of the distribution over the latent code,
# z_what.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(400, 200)
        self.l2 = nn.Linear(200, 100)

    def forward(self, data):
        h = relu(self.l1(data))
        a = self.l2(h)
        return a[:, 0:50], softplus(a[:, 50:])

encode = Encoder()

# Takes the guide RNN hidden state to parameters of
# the guide distributions over z_where and z_pres.
class Predict(nn.Module):
    def __init__(self, ):
        super().__init__()
        self.l = nn.Linear(256, 7)

    def forward(self, h):
        a = self.l(h)
        z_pres_p = sigmoid(a[:, 0:1]) # Squish to [0,1]
        z_where_loc = a[:, 1:4]
        z_where_scale = softplus(a[:, 4:]) # Squish to >0
        return z_pres_p, z_where_loc, z_where_scale

predict = Predict()

def guide_step_improved(t, data, prev):

    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))
    z_pres_p, z_where_loc, z_where_scale = predict(h)

    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.Bernoulli(z_pres_p * prev.z_pres)
                             .to_event(1))

    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.Normal(z_where_loc, z_where_scale)
                              .to_event(1))

    # New. Crop a small window from the input.
    x_att = image_to_object(z_where, data)

    # Compute the parameter of the distribution over z_what
    # by passing the window through the encoder network.
    z_what_loc, z_what_scale = encode(x_att)

    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.Normal(z_what_loc, z_what_scale)
                             .to_event(1))

    return # values for next step

ガイドの微分可能性を維持したいので、必要な「クロッピング」を遂行するために再度 STN を利用します。image_to_object 関数はガイドで使用される object_to_image 関数とは反対の変換を遂行します。つまり、前者は小さい画像を取りそれをより大きい画像上に配置し、後者はより大きい画像から小さい画像をクロップします

def z_where_inv(z_where):
    # Take a batch of z_where vectors, and compute their "inverse".
    # That is, for each row compute:
    # [s,x,y] -> [1/s,-x/s,-y/s]
    # These are the parameters required to perform the inverse of the
    # spatial transform performed in the generative model.
    n = z_where.size(0)
    out = torch.cat((torch.ones([1, 1]).type_as(z_where).expand(n, 1), -z_where[:, 1:]), 1)
    out = out / z_where[:, 0:1]
    return out

def image_to_object(z_where, image):
    n = image.size(0)
    theta_inv = expand_z_where(z_where_inv(z_where))
    grid = affine_grid(theta_inv, torch.Size((n, 1, 20, 20)))
    out = grid_sample(image.view(n, 1, 50, 50), grid)
    return out.view(n, -1)

 

もう一つの視点

ここまでモデルとガイドを個別に考えましたが、ズームアウトしてモデルとガイド計算を全体として見る場合興味深い視点を得ます。それを行えば、各ステップで AIR は 変分オートエンコーダ (VAE) と同じ構造を持つ部分計算を含むことを見ます。これを見るため、ガイドは潜在コードに渡る分布のパラメータを生成するためにニューラルネットワーク (エンコーダ) に window を渡し、そしてモデルはこの潜在コード分布からサンプルをもう一つのニューラルネットワーク (デコーダ) に渡すことに注意してください。この構造は次の図でハイライトされています、[1] から再生成されました :


Figure 2 : 各ステップのガイドとモデル間の相互作用

この視点から AIR は vae の一連の亜種として見られます。入力画像から小さいウィンドウをクロッピングする行為は VAE の注意を入力画像の小さい領域に制限するために役立ちます ; hence “Attend, Infer, Repeat”.

 

推論

イントロダクションで述べたように、このモデルで推論を成功的に遂行することは挑戦です。特に、モデルでの離散的選択の存在は総ての選択が再パラメータ化できるモデルよりも推論を技巧的にします。直面する基礎的な問題は変分推論で遂行される最適化で使用する勾配推定が 非再パラメータ化可能な選択 の存在で遥かに高い分散を持つことです。

この分散を制御下に置くため、ペーパーはモデルの離散的選択に「データ依存なベースライン」(AKA「ニューラル・ベースライン」) と呼ばれるテクニックを適用しています。

 

データ依存ベースライン

私達に幸いなことに、Pyro はデータ依存なベースラインのためのサポートを含みます。貴方がこのアイデアにまだ馴染みがないのであれば、続ける前に イントロダクション を読むことを望むかもしれません。model author としてニューラルネットワークを実装して、それに入力としてデータを渡し、そしてその出力を pyro.sample に供給します。Pyro の推論バックエンドは、ベースラインが推論のために使用される勾配推定器に含まれそしてネットワークパラメータが適切に更新されることを確かなものにします。

私達の AIR 実装にどのようにデータ依存ベースラインを追加できるかを見ましょう。ニューラルネットワークが必要で、これはガイドの各離散的選択で (スカラー) ベースライン値を出力できて、入力としてそこまでにガイドによりサンプリングされたマルチ-mnist 画像と値を受け取ります。
これはガイドネットワークの構造に非常に類似していることに気付いてください、そして実際にリカレントネットワークを再度利用します。

これを実装するため、ちょうど説明した RNN の単一ステップを実装する短いヘルパー関数を最初に書きます :

bl_rnn = nn.LSTMCell(2554, 256)
bl_predict = nn.Linear(256, 1)

# Use an RNN to compute the baseline value. This network takes the
# input images and the values samples so far as input.
def baseline_step(x, prev):
    rnn_input = torch.cat((x,
                           prev.z_where.detach(),
                           prev.z_what.detach(),
                           prev.z_pres.detach()), 1)
    bl_h, bl_c = bl_rnn(rnn_input, (prev.bl_h, prev.bl_c))
    bl_value = bl_predict(bl_h) * prev.z_pres
    return bl_value, bl_h, bl_c

ここでハイライトすべき 2 つの重要な詳細があります :

最初に、ガイドによりサンプリングされた値をそれらをベースラインネットワークに渡す前にデタッチします。これは重要です、何故ならばベースラインネットワークとガイドネットワークは異なる目的関数で最適化される全く分離したネットワークであるからです。これなしでは、勾配はベースラインネットワークからガイドネットワークに流れてしまうでしょう。データ依存ベースラインを使用するときガイドからサンプリングされた値をベースラインネットワークに供給するときはいつでもこれを行なわなければなりません。(そうしない場合には PyTorch 実行時エラーを引き起こします。)

2 番目に、ベースラインネットワークの出力を前のステップからの z_pres の値により乗算します。これは完全な (= complete) サンプルのための正確な予測を出力しなければならないという重荷からベースラインネットワークを解放します。(完全なサンプルのための出力はゼロで乗算されますので、これらの出力のためのベースライン損失の導関数はゼロになります。) これは OK です、何故ならば実際には推論目的関数からの完了したサンプルのためのランダム選択は既に除去しているので、それらに任意の分散リダクションを適用する必要がないからです。

今ではガイドの実装を完了するために必要な総てを持っています。最終的な guide_step 関数は上で導入された guide_step_improved に非常に類似しています。唯一の変更は :

  1. 今は baseline_step ヘルパーを呼び出してそれが返すベースライン値を pyro.sample に渡します。
  2. 今は完全なサンプルのために the z_where と z_what 選択をマスクします。
    これはモデルにマスクが追加されたのと正確に同じ目的で役立ちます。
    (この変更の裏の動機のためには先の議論を見てください。)

モデル全体のためのガイドを提供するために guide_step を反復するガイド関数も書きます。

GuideState = namedtuple('GuideState', ['h', 'c', 'bl_h', 'bl_c', 'z_pres', 'z_where', 'z_what'])
def initial_guide_state(n):
    return GuideState(h=torch.zeros(n, 256),
                      c=torch.zeros(n, 256),
                      bl_h=torch.zeros(n, 256),
                      bl_c=torch.zeros(n, 256),
                      z_pres=torch.ones(n, 1),
                      z_where=torch.zeros(n, 3),
                      z_what=torch.zeros(n, 50))

def guide_step(t, data, prev):

    rnn_input = torch.cat((data, prev.z_where, prev.z_what, prev.z_pres), 1)
    h, c = rnn(rnn_input, (prev.h, prev.c))
    z_pres_p, z_where_loc, z_where_scale = predict(h)

    # Here we compute the baseline value, and pass it to sample.
    baseline_value, bl_h, bl_c = baseline_step(data, prev)
    z_pres = pyro.sample('z_pres_{}'.format(t),
                         dist.Bernoulli(z_pres_p * prev.z_pres)
                             .to_event(1),
                         infer=dict(baseline=dict(baseline_value=baseline_value.squeeze(-1))))

    z_where = pyro.sample('z_where_{}'.format(t),
                          dist.Normal(z_where_loc, z_where_scale)
                              .mask(z_pres)
                              .to_event(1))

    x_att = image_to_object(z_where, data)

    z_what_loc, z_what_scale = encode(x_att)

    z_what = pyro.sample('z_what_{}'.format(t),
                         dist.Normal(z_what_loc, z_what_scale)
                             .mask(z_pres)
                             .to_event(1))

    return GuideState(h=h, c=c, bl_h=bl_h, bl_c=bl_c, z_pres=z_pres, z_where=z_where, z_what=z_what)

def guide(data):
    # Register networks for optimization.
    pyro.module('rnn', rnn),
    pyro.module('predict', predict),
    pyro.module('encode', encode),
    pyro.module('bl_rnn', bl_rnn)
    pyro.module('bl_predict', bl_predict)

    with pyro.plate('data', data.size(0), subsample_size=64) as indices:
        batch = data[indices]
        state = initial_guide_state(batch.size(0))
        steps = []
        for t in range(3):
            state = guide_step(t, batch, state)
            steps.append(state)
        return steps

 

総てを上手くまとめる

今ではモデルとガイドの実装を完了しました。前のほうのチュートリアルで見たように、推論を遂行し始めるために数行多いコードだけを書く必要があります :

data = mnist.view(-1, 50 * 50)

svi = SVI(model,
          guide,
          optim.Adam({'lr': 1e-4}),
          loss=TraceGraph_ELBO())

for i in range(5):
    loss = svi.step(data)
    print('i={}, elbo={:.2f}'.format(i, loss / data.size(0)))
i=0, elbo=2806.79
i=1, elbo=3656.81
i=2, elbo=3222.37
i=3, elbo=3872.77
i=4, elbo=2818.27

ここでの一つの主要な細部は単純な Trace_ELBO ではなくTraceGraph_ELBO 損失を使用することです。これはデータ依存ベースラインをサポートする勾配推定器を利用することを望むことを示します。この推定器はまたモデルに含まれる独立性情報を利用して勾配推定の 分散を削減 もします。類似のものは [1] で使用される暗黙性で、このモデル上で良い結果を獲得するために必要です。

 

結果

私達の実装をサニティチェックするため スタンドアロン実装 を使用して推論を実行してそのパフォーマンスを [1] でレポートされた結果の幾つかに対して比較しました。

ここで最適化の間の ELBO と訓練セット・カウント精度の進捗を示します :


Figure 3: 左: 最適化の間のエビデンス下界 (ELBO) の進捗。; 右: 最適化の間の訓練セット・カウント精度の進捗

カウント精度は 98.7 % あたりまで達しました、これは [1] でレポートされているカウント精度と同じ範囲 (= ballpark) にあります。ELBO 上で到達した値は [1] でレポートされたものとは少し異なります、これは使用された事前 (確率) に小さい違いがあるためかもしれません。

次の図はトップ行はテストセットからの 10 データポイントを示しています。ボトム行はこれらの入力の各々のためのガイドから単一サンプルの可視化です、それは z_pres と z_where のためにサンプリングされた値を示します。[1] に従って、1 番目、2 番目そして 3 番目のステップはそれぞれ赤、緑と青色の境界を使用して表示されます。(ガイドはこれらのサンプルの任意に対して 3 ステップを使用しないので青い境界は示されません)それはまた出力を生成するため、モデルを通してガイドからサンプリングされた潜在変数を渡すことで得られた入力の再構築を示します。


Figure 4: トップ行: マルチ-mnist テストセットからのデータポイント ; ボトム行: ガイドからのサンプルとモデルの入力の再構築の可視化

これらの結果は次のパラメータを使用して収集されました :

python main.py -n 200000 -blr 0.1 --z-pres-prior 0.01 --scale-prior-sd 0.2 --predict-net 200 --bl-predict-net 200 --decoder-output-use-sigmoid --decoder-output-bias -2 --seed 287710

PyTorch 0.2.0.post4 による Pyro commit c0b38ad を利用しました。推論は NVIDIA K80 GPU 上でおよそ 4 時間実行しました。(ランダムシードを設定してさえも、CUDA を使用するときこれは推論を決定論的にすうには十分ではありません。)

 

References

  1. Attend, Infer, Repeat: Fast Scene Understanding with Generative Models S. M. Ali Eslami and Nicolas Heess and Theophane Weber and Yuval Tassa and Koray Kavukcuoglu and Geoffrey E. Hinton
  2. Spatial Transformer Networks Max Jaderberg and Karen Simonyan and Andrew Zisserman


















 

以上