PyTorch : Pyro examples : ガウス混合モデル

PyTorch : Pyro examples : ガウス混合モデル (翻訳)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 11/26/2018 (v0.2.1)

* 本ページは、Pyro のドキュメント Examples : Gaussian Mixture Model を翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ガウス混合モデル

このチュートリアルは混合モデルの動機づける例を通して Pyro で離散潜在変数をどのように周辺化するかを示します。tiny 5-ポイントのデータセット上で自明の 1-D ガウスモデルを訓練することでモデルを単純に維持して、並列列挙 (= parallel enumeration) のメカニクスに焦点を合わせます。

Pyro の `TraceEnum_ELBO <http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO>`__ はガイドとモデルの両者で変数を自動的に周辺化できます。ガイド変数を列挙するとき、Pyro はシーケンシャルに列挙するか (これは変数がダウンストリーム制御フローを決定する場合に有用です)、あるいは新しい tensor 次元を割り当てて変数のサンプル部位 (= site) で可能な値の tensor を作成するために非標準評価を使用して並列に列挙することができます。それからこれらの非標準値はモデルでリプレーされます。モデルの変数を列挙するとき、変数は並列に列挙されてそしてガイドに現れてはなりません。数学的には、モデル側列挙は変数を正確に周辺化することによりイェンセンの不等式の適用を回避するところで、ガイド側列挙は総ての値を列挙することにより単純に確率的 ELBO の分散を減少させます。

from __future__ import print_function
import os
from collections import defaultdict
import numpy as np
import scipy.stats
import torch
from torch.distributions import constraints
from matplotlib import pyplot
%matplotlib inline

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.contrib.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

smoke_test = ('CI' in os.environ)
pyro.enable_validation(True)

 

データセット

ここに私達の tiny データセットがあります。それは 5 つのポイントを持ちます。

data = torch.tensor([0., 1., 10., 11., 12.])

 

MAP 推定

事前分布とデータが与えられたとき、モデルパラメータ weights、locs と scale を学習することから始めましょう。`AutoDelta <http://docs.pyro.ai/en/dev/contrib.autoguide.html#autodelta>`__ guide (その delta 分布により命名されます) を使用してこれらの点推定を学習します。私達のモデルはグローバルな混合重みと各混合構成要素の位置と、両者の構成要素に共通の共有されたスケールを学習します。推論の間、`TraceEnum_ELBO <http://docs.pyro.ai/en/dev/inference_algos.html#pyro.infer.traceenum_elbo.TraceEnum_ELBO>`__ はデータポイントの割当てをクラスタに周辺化します。

K = 2  # 構成要素の固定数。

@config_enumerate(default='parallel')
@poutine.broadcast
def model(data):
    # グローバル変数。
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
    scale = pyro.sample('scale', dist.LogNormal(0., 2.))
    with pyro.iarange('components', K):
        locs = pyro.sample('locs', dist.Normal(0., 10.))

    with pyro.iarange('data', len(data)):
        # ローカル変数。
        assignment = pyro.sample('assignment', dist.Categorical(weights))
        pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))

この (model,guide) ペアで推論を実行するために、各反復で総ての割り当てに渡り列挙する Pyro の `config_enumerate() <http://docs.pyro.ai/en/dev/poutine.html#pyro.infer.enum.config_enumerate>`__ 関数を使用します。`pyro.iarange <http://docs.pyro.ai/en/dev/primitives.html#pyro.iarange>`__ indepencence コンテキストでバッチ化された Categorical 割当てをラップしますので、この列挙は並列に発生します: 私達は 2**len(data) = 32 ではなくて、2 つの可能性だけを列挙します。最後に、列挙の並列バージョンを使用するために、max_iarange_nesting=1 を通してシングル `iarange <http://docs.pyro.ai/en/dev/primitives.html#pyro.iarange>`__ だけを使用していることを Pyro に知らせます; これは最右端次元 `iarange <http://docs.pyro.ai/en/dev/primitives.html#pyro.iarange>`__ を使用していて Pyro は並列化のために任意の他の次元を使用できることを Pyro に知らせます。

optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_iarange_nesting=1)
svi = SVI(model, global_guide, optim, loss=elbo)

推論の前に推算値 (= plausible values) を初期化します。混合モデルはローカルモードに非常に敏感です。一般的なアプローチは多くのランダムな初期化の中で最善を選択することで、そこではクラスタ平均はデータのランダムなサブサンプリングから初期化されます。`AutoDelta <http://docs.pyro.ai/en/dev/contrib.autoguide.html#autodelta>`__ ガイドを使用していますので、各変数のために一つの param を初期化できます、そこでは名前は “auto_” により prefix されて制約は各分布に対して適切です (`Distribution.support <https://pytorch.org/docs/stable/distributions.html#torch.distributions.distribution.Distribution.support>`__ 属性から制約を見つけることができます)。

def initialize(seed):
    pyro.set_rng_seed(seed)
    pyro.clear_param_store()
    # Initialize weights to uniform.
    pyro.param('auto_weights', 0.5 * torch.ones(K), constraint=constraints.simplex)
    # Assume half of the data variance is due to intra-component noise.
    pyro.param('auto_scale', (data.var() / 2).sqrt(), constraint=constraints.positive)
    # Initialize means from a subsample of data.
    pyro.param('auto_locs', data[torch.multinomial(torch.ones(len(data)) / len(data), K)]);
    loss = svi.loss(model, global_guide, data)
    return loss

# Choose the best among 100 random initializations.
loss, seed = min((initialize(seed), seed) for seed in range(100))
initialize(seed)
print('seed = {}, initial_loss = {}'.format(seed, loss))
seed = 7, initial_loss = 25.6655845642

訓練の間、収束を監視するために損失と勾配 norm を収集します。PyTorch の .register_hook() メソッドを使用してこれを行なうことができます。

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')
...................................................................................................
...................................................................................................
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

ここに学習されたパラメータがあります :

map_estimates = global_guide(data)
weights = map_estimates['weights']
locs = map_estimates['locs']
scale = map_estimates['scale']
print('weights = {}'.format(weights.data.numpy()))
print('locs = {}'.format(locs.data.numpy()))
print('scale = {}'.format(scale.data.numpy()))
weights = [0.375      0.62500006]
locs = [ 0.49898112 10.984461  ]
scale = 0.651433706284

モデルの重みは予測されたように、データのおよそ 2/5 が最初の構成要素で 3/5 が 2 番目の構成要素にあります。次に混合モデルを可視化しましょう。

X = np.arange(-3,15,0.1)
Y1 = weights[0].item() * scipy.stats.norm.pdf((X - locs[0].item()) / scale.item())
Y2 = weights[1].item() * scipy.stats.norm.pdf((X - locs[1].item()) / scale.item())

pyplot.figure(figsize=(10, 4), dpi=100).set_facecolor('white')
pyplot.plot(X, Y1, 'r-')
pyplot.plot(X, Y2, 'b-')
pyplot.plot(X, Y1 + Y2, 'k--')
pyplot.plot(data.data.numpy(), np.zeros(len(data)), 'k*')
pyplot.title('Density of two-component mixture model')
pyplot.ylabel('probability density');

最後に混合モデルの最適化は非凸 (= non-convex) でしばしばローカル最適化条件で行き詰まる可能性があることに注意してください。例えばこのチュートリアルで、scale が大きすぎるように初期化された場合混合モデルが everthing-in-one-cluster 仮説で行き詰まることを観測しました。

 

ガイド内で列挙する: メンバーシップを予測する

ここまでモデルの割当変数を周辺化しました。これが高速な収束を提供する一方で、それはガイドからクラスタ割当を読むことを妨げます。

ガイドからクラスタ割当を読むために、(上のような) グローバルパラメータと (前に周辺化された) ローカルパラメータの両者に fit する新しい full_guide を定義します。global のための良い値は既に学習したので、`poutine.block <http://docs.pyro.ai/en/dev/poutine.html#pyro.poutine.block>`__ を使用してそれらを SVI が更新することを防ぎます。

@config_enumerate(default="parallel")
@poutine.broadcast
def full_guide(data):
    # Global variables.
    with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
        global_guide(data)

    # Local variables.
    with pyro.iarange('data', len(data)):
        assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                      constraint=constraints.unit_interval)
        pyro.sample('assignment', dist.Categorical(assignment_probs))
optim = pyro.optim.Adam({'lr': 0.2, 'betas': [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_iarange_nesting=1)
svi = SVI(model, full_guide, optim, loss=elbo)

# Register hooks to monitor gradient norms.
gradient_norms = defaultdict(list)
svi.loss(model, full_guide, data)  # Initializes param store.
for name, value in pyro.get_param_store().named_parameters():
    value.register_hook(lambda g, name=name: gradient_norms[name].append(g.norm().item()))

losses = []
for i in range(200 if not smoke_test else 2):
    loss = svi.step(data)
    losses.append(loss)
    print('.' if i % 100 else '\n', end='')
...................................................................................................
...................................................................................................
pyplot.figure(figsize=(10,3), dpi=100).set_facecolor('white')
pyplot.plot(losses)
pyplot.xlabel('iters')
pyplot.ylabel('loss')
pyplot.yscale('log')
pyplot.title('Convergence of SVI');

pyplot.figure(figsize=(10,4), dpi=100).set_facecolor('white')
for name, grad_norms in gradient_norms.items():
    pyplot.plot(grad_norms, label=name)
pyplot.xlabel('iters')
pyplot.ylabel('gradient norm')
pyplot.yscale('log')
pyplot.legend(loc='best')
pyplot.title('Gradient norms during SVI');

今ではガイドのローカル assignment_probs 変数を検査できます。

assignment_probs = pyro.param('assignment_probs')
pyplot.figure(figsize=(8, 4), dpi=100).set_facecolor('white')
pyplot.plot(data.data.numpy(), assignment_probs.data.numpy()[:, 0], 'ro',
            label='component with mean {:0.2g}'.format(locs[0]))
pyplot.plot(data.data.numpy(), assignment_probs.data.numpy()[:, 1], 'bo',
            label='component with mean {:0.2g}'.format(locs[1]))
pyplot.title('Mixture assignment probabilities')
pyplot.xlabel('data value')
pyplot.ylabel('assignment probability')
pyplot.legend(loc='center');

 

以上