Pyro 1.4 : Pyro の推論へのイントロダクション

Pyro 1.4 : Pyro の推論へのイントロダクション (翻訳)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/03/2020 (1.4.0)

* 本ページは、Pyro の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

Pyro の推論へのイントロダクション

現代的な機械学習の大半は近似推論として位置づけられて Pyro のような言語で簡潔に表わされます。このチュートリアルの残りを動機づけるために、単純な物理問題のための生成モデルを構築しましょう。結果としてそれを解くために Pyro の推論機構が利用できます。けれども、最初にこのチュートリアルのために必要なモジュールをインポートします :

import matplotlib.pyplot as plt
import numpy as np
import torch

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

pyro.set_rng_seed(101)

 

単純な例

あるものがどのくらいの重さかを算定しようとしていますが、使用している測りが当てにならずに同じ物体を測るたびにわずかに異なる答えを与えるものと仮定します。このノイズを持つ測定情報を (その密度や材料特性のような) 物体についての何某かの事前知識をもとにした推測と統合することによりこの変化性を補完しようとすることができるでしょう。次のモデルはこの過程をエンコードします :

$$
{\sf weight} \, | \, {\sf guess} \sim \cal {\sf Normal}({\sf guess}, 1) \\
{\sf measurement} \, | \, {\sf guess}, {\sf weight} \sim {\sf Normal}({\sf weight}, 0.75)
$$

これは重みに渡る私達の信念のためだけでなく、それの測定を取る結果のためのモデルでもあることに注意してください。モデルは次の確率関数に対応します :

def scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

 

条件付け (= Conditioning)

確率プログラミングの実際の有用性は 観測データ上で生成モデルを条件付け してデータを生成したかもしれない潜在要因を推論する能力にあります。Pyro では、推論を通したその評価から条件付ける式を分離し、モデルを一度書いてそして多くの異なる観測上でそれを条件付けることを可能にします。Pyro は与えられた観測のセットと同じになるようにモデル内部の sample ステートメントを制約することをサポートします。

scale (関数) を再度考えます。入力 guess = 8.5 が与えられたとき weight の分布からサンプリングすることを望みますが、しかし今 measurement == 9.5 を観測したと仮定します。つまり、次の分布を推論することを望みます :

\[
({\sf weight} \, | \, {\sf guess}, {\sf measurement} = 9.5) \sim \, ?
\]

Pyro は sample ステートメントの値を制約することを可能にする関数 pyro.condition を提供します。pyro.condition は高階関数でモデルと観測の辞書を取り新しいモデルを返します、このモデルは同じ入力と出力シグネチャを持ちますが観測された sample ステートメントで与えられた値を常に利用します :

conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})

それはちょうど通常の Python 関数のように動作しますので、条件付けは遅延されるか Python の lambda か def でパラメータ化されます :

def deferred_conditioned_scale(measurement, guess):
    return pyro.condition(scale, data={"measurement": measurement})(guess)

ある場合には pyro.condition を使用する代わりに観測を個々の pyro.sample ステートメントに直接渡す方が便利かもしれません。オプションの obs キーワード引数はその目的で pyro.sample により予約されています :

def scale_obs(guess):  # equivalent to conditioned_scale above
    weight = pyro.sample("weight", dist.Normal(guess, 1.))
     # here we condition on measurement == 9.5
    return pyro.sample("measurement", dist.Normal(weight, 0.75), obs=9.5)

最後に、観測を組み込むための pyro.condition に加えて、Pyro はまた pyro.condition への同一のインターフェイスを持つ因果推論のために使用される pyro.do、Pearl の do-演算子の実装も含みます。condition と do は自由に混合されて組み合わせることができて、Pyro をモデル・ベースの因果推論のためのパワフルなツールにします。

 

ガイド関数を持つ柔軟な近似推論

conditioned_scale に戻りましょう。measurement の観測上で条件付けた今、guess と measurement == data が与えられたときの weight に渡る分布を推定するために Pyro の近似推論アルゴリズムを利用できます。

pyro.infer.SVI のような、Pyro の推論アルゴリズムは (ガイド関数またはガイドと呼ぶ) 任意の確率関数を近似事後分布として使用することを可能にします。ガイド関数は特定のモデルのための正当な近似であるためにはこれらの 2 つの基準を満たさなければなりません:

  1. モデルに現れる総ての未観測の (i.e. 条件付けられていない) sample ステートメントはガイドに出現します。
  2. ガイドはモデルと同じ入力シグネチャを持ちます (i.e. 同じ引数を取ります)。

ガイド関数は重点サンプリング、棄却サンプリング、逐次モンテカルロ、MCMC、そして独立 (型) Metropolis-Hastings のためのプログラム可能な、データ依存の提案分布として、そして確率的変分推論のための 変分分布推論ネットワーク としてサーブできます。現在、重点サンプリング、MCMC と確率的変分推論が Pyro で実装されていますが、将来的には他のアルゴリズムを追加する予定です。

ガイドの正確な意味は異なる推論アルゴリズムに渡り異なりますが、ガイド関数は一般に、原理的には、モデルの総ての未観測 sample ステートメントに渡る分布を密接に近似するために十分に柔軟であるように選択されるべきです。

scale のケースでは、guess と measurement が与えられたとき weight に渡る真の事後分布は実際には Normal(9.14,0.6) であることが判明しています。モデルが非常に単純なので、関心のある事後分布を解析的に決定することができます (導出については、例えば http://www.stat.cmu.edu/~brian/463-663/week09/Chapter%2003.pdf の Section 3.4 参照)。

def perfect_guide(guess):
    loc =(0.75**2 * guess + 9.5) / (1 + 0.75**2) # 9.14
    scale = np.sqrt(0.75**2/(1 + 0.75**2)) # 0.6
    return pyro.sample("weight", dist.Normal(loc, scale))

 

パラメータ化された確率関数と変分推論

scale のための正確な事後分布を書き出すことはできるでしょうが、一般には任意の条件付けられた確率関数の事後分布への良い近似となるガイドを指定することは扱いにくいです。実際に、(確率関数のために) 真の事後分布を正確に決定できる確率関数は通例よりも寧ろ例外です。例えば、中に非線形関数を持つ scale 例のバージョンでさえも扱いにくいかもしれません :

def intractable_scale(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))
    return pyro.sample("measurement", dist.Normal(some_nonlinear_function(weight), 0.75))

代わりに私達ができることは、名前付けられたパラメータでインデックスされたガイドの族を指定するためにトップレベル関数 pyro.param を使用して、そしてある損失関数に従って最善の近似であるその族のメンバーを探し求めることです。事後推論を近似するためのこのアプローチは 変分推論 と呼ばれます。

pyro.param は Pyro の キーバリュー・パラメータストア のためのフロントエンドで、これはドキュメントで詳細に記述されます。pyro.sample のように、pyro.param はその最初の引数として名前とともに常に呼び出されます。最初に pyro.param が特定の名前で呼ばれたとき、それはその引数をパラメータストアにストアしてから値を返します。その後、それがその名前で呼び出されたとき、それはどのような他の引数でもそれはパラメータストアから値を返します。それはここでは simple_param_store.setdefault に類似していますが、幾つかの追加の追跡と管理機能を伴います。

simple_param_store = {}
a = simple_param_store.setdefault("a", torch.randn(1))

例えば、scale_posterior_guide で a と b を手動で指定する代わりにパラメータ化できます :

def scale_parametrized_guide(guess):
    a = pyro.param("a", torch.tensor(guess))
    b = pyro.param("b", torch.tensor(1.))
    return pyro.sample("weight", dist.Normal(a, torch.abs(b)))

余談ですが、scale_parametrized_guide では、パラメータ b に torch.abs を適用しなければならなかったことに注意してください、何故ならば正規分布の標準偏差は正でなければならないからです ; 類似の制限はまた多くの他の分布のパラメータにも適用されます。PyTorch distributions ライブラリ (その上で Pyro は構築されています) はそのような制限を強制するための constraints モジュール を含み、そして constraints を Pyro パラメータに適用することは関係する constraint オブジェクトを pyro.param に渡す程度に容易です :

from torch.distributions import constraints

def scale_parametrized_guide_constrained(guess):
    a = pyro.param("a", torch.tensor(guess))
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    return pyro.sample("weight", dist.Normal(a, b))  # no more torch.abs

Pyro は確率的変分推論を可能にするために構築されています、以下の 3 つの主要な特色を持つパワフルで広く適用可能な変分推論アルゴリズムのクラスです :

  1. パラメータは常に実数値 tensor です。
  2. モデルとガイドの実行履歴のサンプルから損失関数のモンテカルロ推定を計算します。
  3. 最適なパラメータを探し求めるために 確率的勾配降下 を使用します。

確率的勾配降下を PyTorch の GPU で高速化された tensor math 及び自動微分と結合することは、変分推論を非常に高次元なパラメータ空間と大規模なデータセットにスケールすることを可能にします。

Pyro の SVI 機能は SVI チュートリアル で詳細に記述されます。ここにそれを scale に適用した非常に単純な例があります :

guess = 8.5

pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale,
                     guide=scale_parametrized_guide,
                     optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
                     loss=pyro.infer.Trace_ELBO())


losses, a,b  = [], [], []
num_steps = 2500
for t in range(num_steps):
    losses.append(svi.step(guess))
    a.append(pyro.param("a").item())
    b.append(pyro.param("b").item())

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");
print('a = ',pyro.param("a").item())
print('b = ', pyro.param("b").item())
a =  9.107474327087402
b =  0.6285384893417358

plt.subplot(1,2,1)
plt.plot([0,num_steps],[9.14,9.14], 'k:')
plt.plot(a)
plt.ylabel('a')

plt.subplot(1,2,2)
plt.ylabel('b')
plt.plot([0,num_steps],[0.6,0.6], 'k:')
plt.plot(b)
plt.tight_layout()

SVI は望まれる条件付き分布の真のパラメータに非常に近いパラメータを得ることに注意してください。これは、ガイドが同じ族からですので当然です。

最適化はパラメータストアのガイド・パラメータの値を更新する ことに注意してください、その結果、ひとたび良いパラメータ値を見い出せば、ガイドからのサンプルをダウンストリーム・タスクのための事後サンプルとして利用できます。

 

Next Steps

変分オートエンコーダ・チュートリアル では、scale のようなモデルを深層ニューラルネットワークでどのように増強するかを見てそして画像の生成モデルを構築するために確率的変分推論を利用します。

 

以上