Pyro 0.3.0 : Pyro の推論へのイントロダクション (翻訳)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/15/2018 (v0.3.0)
* 本ページは、Pyro のドキュメント An Introduction to Inference in Pyro を
翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
Pyro の推論へのイントロダクション
かなりの現代的な機械学習は近似推論として位置づけられて Pyro のような言語で簡潔に表わされます。このチュートリアルの残りを動機づけるために、最初に単純な物理問題のための生成モデルを構築しましょう。結果としてそれを解くために Pyro の推論機構が利用できます。けれども、最初にこのチュートリアルのために必要なモジュールをインポートします :
import matplotlib.pyplot as plt import numpy as np import torch import pyro import pyro.infer import pyro.optim import pyro.distributions as dist pyro.set_rng_seed(101)
単純な例
あるものがどのくらいの重さかを見出そうとしていますが、使用している測りが当てにならずに同じ物体を測るたびにわずかに異なる答えを与えるものと仮定します。ノイズを持つ測定情報を (その密度や材料特性のような) 物体についての何某かの事前知識をもとにした推測と統合することによりこの変化性を補完しようとすることができるでしょう。次のモデルはこの過程をエンコードします :
\[
{\sf weight} \, | \, {\sf guess} \sim \cal {\sf Normal}({\sf guess}, 1)\\
{\sf measurement} \, | \, {\sf guess}, {\sf weight} \sim {\sf Normal}({\sf weight}, 0.75)
\]
これは重みに渡る私達の信念のためだけでなく、それの測定を取る結果のためのモデルでもあることに注意してください。モデルは次の確率関数に対応します :
def scale(guess): weight = pyro.sample("weight", dist.Normal(guess, 1.0)) return pyro.sample("measurement", dist.Normal(weight, 0.75))
条件付け (= Conditioning)
確率プログラミングの実際の有用性は生成モデルを観測データ上で条件付けしてデータを生成したかもしれない潜在要因を推論する能力にあります。Pyro では、条件付ける式を推論を通したその評価から分離し、モデルを一度書いて多くの異なる観測上でそれを条件付けることを可能にします。Pyro は与えられた観測のセットと同じになるようにモデルの内部の sample ステートメントを制約することをサポートします。
scale を再度考えます。入力 guess = 8.5 が与えられたとき重さの分布からサンプリングすることを望みますが、しかし今 measurement == 9.5 を観測したと仮定します。つまり、次の分布を推論することを望みます :
\[
({\sf weight} \, | \, {\sf guess}, {\sf measurement} = 9.5) \sim \, ?
\]
Pyro は sample ステートメントの値を制約することを可能にする関数 pyro.condition を提供します。pyro.condition は高階関数でモデルと観測の辞書を取り新しいモデルを返します。このモデルは同じ入力と出力シグネチャを持ちますが観測された sample ステートメントで与えられた値を常に使用します :
conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})
それはちょうど通常の Python 関数のように振る舞いますので、条件付けは延ばされるか Python の lambda か def でパラメータ化されます :
def deferred_conditioned_scale(measurement, *args, **kwargs): return pyro.condition(scale, data={"measurement": measurement})(*args, **kwargs)
ある場合には pyro.condition を使用する代わりに観測を個々の pyro.sample ステートメントに直接渡す方がより便利かもしれません。オプションの obs キーワード引数はその目的で pyro.sample により予約されています :
def scale_obs(guess): # equivalent to conditioned_scale above weight = pyro.sample("weight", dist.Normal(guess, 1.)) # here we condition on measurement == 9.5 return pyro.sample("measurement", dist.Normal(weight, 1.), obs=9.5)
最後に、包含的な観測のための pyro.condition に加えて、Pyro はまた、pyro.condition への同一のインターフェイスを持つ因果推論のために使用される pyro.do、Pearl の do-演算子の実装を含みます。condition と do は自由に mix されて組み合わせることができて、Pyro をモデル・ベースの因果推論のためのパワフルなツールにします。
ガイド関数を持つ柔軟な近似推論
conditioned_scale に戻りましょう。measurement (測定) の観測上で条件付けた今、guess と measurement == data が与えられたときの重さに渡る分布を推定するために Pyro の近似推論アルゴリズムを使用できます。
pyro.infer.SVI のような、Pyro の推論アルゴリズムは (ガイド関数またはガイドと呼ぶ) 任意の確率関数を近似事後分布として使用することを可能にします。ガイド関数は特定のモデルのために正当な近似であるためにこれら 2 つの基準を満たさなければなりません:
- モデルに現れる総ての未観測の i.e. 条件付けられていない (= unobserved i.e. not conditioned) sample ステートメントはガイドに出現します。
- ガイドはモデルと同じ入力シグネチャを持ちます (i.e. 同じ引数を取ります)。
ガイド関数は重点サンプリング、棄却サンプリング、逐次モンテカルロ、MCMC、そして独立 (型) Metropolis-Hastings のためのプログラム可能な、データ依存の提案分布として、そして確率的変分推論のための変分分布または推論ネットワークとしてサーブできます。現在、重点サンプリング、MCMC、そして確率的変分推論が Pyro で実装されていますが、将来的には他のアルゴリズムを追加する計画です。
ガイドの正確な意味は異なる推論アルゴリズムで異なりますが、ガイド関数は一般に、原理的には、モデルの総ての未観測 sample ステートメントに渡る分布を密接に近似するために十分に柔軟であるように選択されるべきです。
scale のケースでは、guess と measurement が与えられたとき重さに渡る真の事後分布は実際には Normal(9.14,0.6) であることが判明しています。モデルが非常に単純なので、関心のある事後分布を解析的に決定することができます (導出については、例えば http://www.stat.cmu.edu/~brian/463-663/week09/Chapter%2003.pdf の Section 3.4 参照)。
def perfect_guide(guess): loc =(0.75**2 * guess + 9.5) / (1 + 0.75**2) # 9.14 scale = np.sqrt(0.75**2/(1 + 0.75**2)) # 0.6 return pyro.sample("weight", dist.Normal(loc, scale))
パラメータ化された確率関数と変分推論
scale のための正確な事後分布を書き出すことはできるでしょうが、一般には任意の条件付けられた確率関数の事後分布への良い近似となるガイドを指定することは扱いにくいです。実際に、(確率関数のために) 正確に真の事後分布を決定できる確率関数は通例よりも寧ろ例外です。例えば、真ん中に非線形関数を持つ scale サンプル版でさえも扱いにくいかもしれません :
def intractable_scale(guess): weight = pyro.sample("weight", dist.Normal(guess, 1.0)) return pyro.sample("measurement", dist.Normal(some_nonlinear_function(weight), 0.75))
代わりにできることは名前付けられたパラメータでインデックスされたガイドの族を指定するためにトップレベル関数 pyro.param を使用して、そしてある損失関数による最善の近似であるその族のメンバーを探し求めることです。事後推論を近似するためのこのアプローチは 変分推論 と呼ばれます。
pyro.param は Pyro のキーバリュー・パラメータ・ストアのためのフロントエンドで、これはドキュメントでより詳細に記述されます。pyro.sample のように、pyro.param はその最初の引数として常に名前とともに呼び出されます。最初に pyro.param が特定の名前で呼ばれたとき、それはその引数をパラメータ・ストアにストアしてからその値を返します。その後、それがその名前で呼び出されたとき、それはどのような他の引数でもそれはパラメータ・ストアから値を返します。それはここでは simple_param_store.setdefault に類似していますが、幾つかの追加の追跡と管理機能を伴います。
simple_param_store = {} a = simple_param_store.setdefault("a", torch.randn(1))
例えば、scale_posterior_guide で a と b を手動で指定する代わりにパラメータ化できます :
def scale_parametrized_guide(guess): a = pyro.param("a", torch.tensor(guess)) b = pyro.param("b", torch.tensor(1.)) return pyro.sample("weight", dist.Normal(a, torch.abs(b)))
余談ですが、scale_parametrized_guide では、パラメータ b に torch.abs を適用しなければならないことに注意してください、何故ならば正規分布の標準偏差は正でなけれればならないからです ; 類似の制限はまた多くの他の分布のパラメータにも適用されます。PyTorch distributions ライブラリ (その上で Pyro は構築されています) はそのような制限を強制するための constraints モジュールを含み、そして constraints を Pyro パラメータに適用することは関係する constraint オブジェクトを pyro.param に渡すように容易です :
from torch.distributions import constraints def scale_parametrized_guide_constrained(guess): a = pyro.param("a", torch.tensor(guess)) b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive) return pyro.sample("weight", dist.Normal(a, b)) # no more torch.abs
Pyro は確率的変分推論を可能にするために構築されています、次の 3 つの主要な特色を持つパワフルで広く適用可能な変分推論アルゴリズムのクラスです :
- パラメータは常に実数値の tensor です。
- モデルとガイドの実行履歴のサンプルから損失関数のモンテカルロ推定を計算します。
- 最適なパラメータを探し求めるために確率的勾配降下を使用します。
確率的勾配降下を PyTorch の GPU で高速化された tensor math 及び自動微分と結合すると変分推論を非常に高次元なパラメータ空間と大規模なデータセットにスケールすることを可能にします。
Pyro の SVI 機能は SVI チュートリアル で詳細に記述されます。ここにそれを scale に適用した非常に単純な例があります :
guess = torch.tensor(8.5) pyro.clear_param_store() svi = pyro.infer.SVI(model=conditioned_scale, guide=scale_parametrized_guide, optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}), loss=pyro.infer.Trace_ELBO()) losses, a,b = [], [], [] num_steps = 2500 for t in range(num_steps): losses.append(svi.step(guess)) a.append(pyro.param("a").item()) b.append(pyro.param("b").item()) plt.plot(losses) plt.title("ELBO") plt.xlabel("step") plt.ylabel("loss"); print('a = ',pyro.param("b").item()) print('b = ', pyro.param("a").item())
a = 0.6285385489463806 b = 9.107474327087402
plt.subplot(1,2,1) plt.plot([0,num_steps],[9.14,9.14], 'k:') plt.plot(a) plt.ylabel('a') plt.subplot(1,2,2) plt.ylabel('b') plt.plot([0,num_steps],[0.6,0.6], 'k:') plt.plot(b) plt.tight_layout()
SVI は望まれる条件付き分布の真のパラメータに非常に近いパラメータを得ることに注意してください。 This is to be expected as our guide is from the same family.
最適化はパラメータストアのガイド・パラメータの値を更新することに注意してください、その結果、ひとたび良いパラメータ値を見い出せばガイドからのサンプルをダウンストリーム・タスクからの事後サンプルとして使用できます。
Next Steps
変分オートエンコーダ・チュートリアル では、scale のようなモデルを深層ニューラルネットワークでどのように増強するかを見てそして生成モデルを構築するために確率的変分推論を使用します。
以上