PyTorch : Pyro イントロ (2) Pyro の推論 – 確率関数から周辺分布

PyTorch : Pyro イントロ (2) Pyro の推論 – 確率関数から周辺分布 (翻訳)

翻訳 : (株)クラスキャット セールスインフォメーション
更新日時 : 11/20, 11/07/2018 (v0.2.1)
作成日時 : 10/10/2018 (v0.2.1)

* 本ページは、Pyro のドキュメント Introduction : Models in Pyro: Inference in Pyro: From Stochastic Functions to Marginal Distributions を動作確認・翻訳した上で適宜、補足説明したものです:

* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

Pyro の推論: 確率関数から周辺分布

確率関数はそれら (同時確率分布) の潜在変数 $z$ に渡る同時確率分布 $p(y, z \; \vert \; x)$ を導いて (= induce) 値 y を返します、そしてこの同時分布は関数の返り値に渡る周辺分布を導きます。けれども、非プリミティブ確率関数については、出力 $p(y \; \vert \; x)$ の周辺確率は決して明白に計算したりあるいは戻り値 $y \sim p (y \; \vert \; x)$ に渡る周辺分布からサンプルをドローしたりはできません。

その最も一般的な定式化では、Pyro のような universal な確率プログラミング言語の推論は (これらの計算を遂行できるように) 任意の boolean 制約が与えられたときにこの周辺分布を構築する問題です。制約は返り値、内部的なランダムネス、あるいは両者の決定論的な関数であり得ます。

ベイズ推論 または 事後推論 は (扱いやすい近似を許容する) このより一般的な定式化の重要な特別なケースです。ベイズ推論では、返り値は常にあるサブセットの内部的なサンプル・ステートメントの値で、制約は他の内部的なサンプル・ステートメント上の等式制約です。かなりの現代的な機械学習は近似ベイズ推論として位置づけられて Pyro のような言語で簡潔に表わされます。

このチュートリアルの残りを動機づけるために、最初に単純な物理問題のための生成モデルを構築しましょう。結果としてそれを解くために Pyro の推論機構が利用できます。

 

単純な例

あるものがどのくらいの重さかを見出そうとしていますが、使用している測りが当てにならずに同じ物体を測るたびにわずかに異なる答えを与えるものと仮定します。ノイズを持つ測定情報を (その密度や材料特性のような) 物体についての何某かの事前知識をもとにした推測と統合することによりこの変化性を補完しようとすることができるでしょう。次のモデルはこの過程をエンコードします :

# import some dependencies
import numpy as np
import matplotlib.pyplot as plt

import torch

import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist

torch.manual_seed(101);
def scale(guess):
    # 重みに渡る事前分布は私達の推測についての不確かさをエンコードします。
    weight = pyro.sample("weight", dist.Normal(guess, 1.0))

    # これはスケールのノイズ性についての私達の信念をエンコードします :
    # 測定は真の重みまわりで変動します
    return pyro.sample("measurement", dist.Normal(weight, 0.75))

 

周辺分布を表わす

物体の重さを推定するために実際にモデルを使用してみる前にモデルの挙動を解析してみましょう。特に、与えられた guess に対してアプリオリに見ることを期待する測定値の周辺分布をシミュレートするために 重点サンプリング を使用することができます。

pyro.infer.EmpiricalMarginal による Pyro の周辺化は 2 つのステップに分割されます。まず、モデルの重み付けられた実行トレースを数多く集めます。それから、それらのトレースを (引数の特定のセットが与えられたとき可能性のある返り値に渡る) ヒストグラムに折り畳むことができます。

実行トレースを集めることはサンプリングか、離散な潜在変数だけを持つモデルについては正確な列挙を通して遂行されます。実行トレース (提案分布として事前分布を使用して) に渡る基本的な重点サンプラーを作成するために、次を書くことができます :

posterior = pyro.infer.Importance(scale, num_samples=100)

posterior はそれ単独では特に有用ではありません。代わりに、posterior の出力 (posterior.run で計算され、これは単一の入力値に対して推論を実行します) は (scale と同じ出力型を持つプリミティブ確率関数を作成する) pyro.infer.EmpiricalMarginal によって消費されることを意図されています。

guess = 8.5

marginal = pyro.infer.EmpiricalMarginal(posterior.run(guess))
print(marginal())

Out:

tensor(8.0281)

入力 guess とともに呼び出されたとき、最初に marginal は guess が与えられたとき重み付けられた実行トレースのシークエンスを生成するために posterior を使用し、それからトレースからの返り値に渡るヒストグラムを構築して、そして最後にヒストグラムからドローされたサンプルを返します。同じ引数で一回より多く marginal を呼び出すことは同じヒストグラムからサンプリングするでしょう。

plt.hist([marginal().item() for _ in range(100)], range=(5.0, 12.0))
plt.title("P(measurement | guess)")
plt.xlabel("weight")
plt.ylabel("#");

pyro.infer.EmpiricalMarginal はまた、潜在変数の名前を提供する、オプションのキーワード引数 sites=name を受け取ります。sites が指定されたとき、marginal は返り値のものではなく、site の周辺分布を計算します。これは有用です、何故ならば同じ posterior オブジェクトから多くの異なる周辺を計算することを望むかも知れないからです。

 

データ上で条件付けるモデル

確率プログラミングの実際の有用性は生成モデルを観測されたデータ上で条件付けてデータを生成したかもしれない潜在要因を推論する能力にあります。Pyro では、条件付ける式を推論を通したその評価から分離し、モデルを一度書いて多くの異なる観測上でそれを条件付けることを可能にします。Pyro は与えられた観測のセットと同じになるようにモデルの内部の sample ステートメントを制約することをサポートします。

scale を再度考えます。入力 guess = 8.5 が与えられたとき重さの周辺分布からサンプリングすることを望みますが、しかし今 measurement == 9.5 を観測したと仮定します。Pyro は sample ステートメントの値を制約することを可能にする関数 pyro.condition を提供します。pyro.condition は高階関数でモデルとデータの辞書を取り新しいモデルを返します。このモデルは同じ入力と出力シグネチャを持ちますが観測された sample ステートメントで与えられた値を常に使用します :

conditioned_scale = pyro.condition(
    scale, data={"measurement": 9.5})

それはちょうど通常の Python 関数のように振る舞いますので、条件付けは伸ばされるか Python の lambda か def でパラメータ化されます :

def deferred_conditioned_scale(measurement, *args, **kwargs):
    return pyro.condition(scale, data={"measurement": measurement})(*args, **kwargs)

ある場合には pyro.condition を使用する代わりに観測を個々の pyro.sample ステートメントに直接渡す方がより便利かもしれません。オプションの obs キーワード引数はその目的で pyro.sample により予約されています :

# equivalent to pyro.condition(scale, data={"measurement": torch.tensor([9.5])})
def scale_obs(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.))
     # here we attach an observation measurement == 9.5
    return pyro.sample("measurement", dist.Normal(weight, 1.),
                       obs=9.5)

けれども、その侵略的な非合成的な性質によりハードコーディングは通常は推奨されません。それに反して、pyro.condition を使用すれば、条件付けは、基礎的なモデルを変更することなく確率モデル上複数の複雑な問い合わせを形成するために自由に構成されます。唯一の制限は単一の site は一度だ制約されるだけかもしれないことです。

def scale2(guess):
    weight = pyro.sample("weight", dist.Normal(guess, 1.))
    tolerance = torch.abs(pyro.sample("tolerance", dist.Normal(0., 1.)))
    return pyro.sample("measurement", dist.Normal(weight, tolerance))

# conditioning composes:
# the following are all equivalent and do not interfere with each other
conditioned_scale2_1 = pyro.condition(
    pyro.condition(scale2, data={"weight": 9.2}),
    data={"measurement": 9.5})

conditioned_scale2_2 = pyro.condition(
    pyro.condition(scale2, data={"measurement": 9.5}),
    data={"weight": 9.2})

conditioned_scale2_3 = pyro.condition(
    scale2, data={"weight": 9.2, "measurement": 9.5})

包含的な観測のための pyro.condition に加えて、Pyro はまた、pyro.condition への同一のインターフェイスを持ち、因果推論のために使用される pyro.do、Pearl の do-演算子の実装を含みます。condition と do は自由に混合されて構成されて、Pyro をモデル・ベースの因果推論のためのパワフルなツールにします。

 

ガイド関数を伴う柔軟な近似推論

deferred_conditioned_scale に戻りましょう。何某かのデータに対して measurement を制約した今、guess と measurement == data が与えられたときの重さに渡る分布を推定するために Pyro の近似推論アルゴリズムを使用できます。scale のためにこれを行なうために重点サンプリングをどのように使用するか先に見ました ; 条件付けられたモデルで正確に同じ構成を使用できます :

guess = 8.5
measurement = 9.5

conditioned_scale = pyro.condition(scale, data={"measurement": measurement})

marginal = pyro.infer.EmpiricalMarginal(
    pyro.infer.Importance(conditioned_scale, num_samples=100).run(guess), sites="weight")

# The marginal distribution concentrates around the data
print(marginal())
plt.hist([marginal().item() for _ in range(100)], range=(5.0, 12.0))
plt.title("P(weight | measurement, guess)")
plt.xlabel("weight")
plt.ylabel("#");

けれども、このアプローチは計算的に極めて非効率です、何故ならば重さに渡る事前分布は重さに渡る真の分布から非常に遠いかもしれないからです、特に初期 guess がそれほど良くない場合は。

そのため、pyro.infer.Importance と pyro.infer.SVI のような、Pyro のある推論アルゴリズムは (ガイド関数またはガイドと呼ぶ) 任意の確率関数を近似事後分布として使用することを可能にします。ガイド関数は特定のモデルのために正当な近似であるためにこれら 2 つの基準を満たさなければなりません:

  1. モデルに出現する総ての未観測の (= unobserved) sample ステートメントはガイドに出現します。
  2. ガイドはモデルと同じ入力シグネチャを持ちます (i.e. 同じ引数を取ります)。

ガイド関数は、重点サンプリング、棄却サンプリング、逐次モンテカルロ、MCMC、そして独立 (型) Metropolis-Hastings のためのプログラム可能な、データ依存の提案分布として、そして確率的変分推論のための変分分布または推論ネットワークとしてサーブできます。現在、重点サンプリングと確率的変分推論だけが Pyro で実装されていますが、将来的には他のアルゴリズムを追加する計画です。

ガイドの正確な意味は異なる推論アルゴリズムで異なりますが、ガイド関数は一般に、モデルの総ての未観測の (= unobserved) sample ステートメントに渡る分布を密接に近似するようにように選択されるべきです。

deferred_conditioned_scale のための最も単純なガイドは重さに渡る事前分布に適合します :

def scale_prior_guide(guess):
    return pyro.sample("weight", dist.Normal(guess, 1.))

posterior = pyro.infer.Importance(conditioned_scale,
                                  guide=scale_prior_guide,
                                  num_samples=10)

marginal = pyro.infer.EmpiricalMarginal(posterior.run(guess), sites="weight")

prior よりも上手くやれるでしょうか?scale の場合、guess と measurement が与えられたとき重さに渡る真の事後分布は次のように直接書かれます :

def scale_posterior_guide(measurement, guess):
    # note that torch.size(measurement, 0) is the total number of measurements
    # that we're conditioning on
    a = (guess + torch.sum(measurement)) / (measurement.size(0) + 1.0)
    b = 1. / (measurement.size(0) + 1.0)
    return pyro.sample("weight", dist.Normal(a, b))

posterior = pyro.infer.Importance(deferred_conditioned_scale,
                                  guide=scale_posterior_guide,
                                  num_samples=20)

marginal = pyro.infer.EmpiricalMarginal(posterior.run(torch.tensor([measurement]), guess), sites="weight")
plt.hist([marginal().item() for _ in range(100)], range=(5.0, 12.0))
plt.title("P(weight | measurement, guess)")
plt.xlabel("weight")
plt.ylabel("#");

 

パラメータ化された確率関数と変分推論

scale のための正確な事後分布を書き出すことは出来るでしょうが、一般には任意の条件付けられた確率関数の事後分布への良い近似となるガイドを指定することは扱いにくいです。代わりにできることは名前付けられたパラメータでインデックスされたガイドのファミリを指定するためにトップレベル関数 pyro.param を使用して、最善の近似であるそのファミリのメンバーを探し求めることです。事後推論を近似するためのこのアプローチは 変分推論 と呼ばれます。

pyro.param は Pyro のキーバリュー・パラメータ・ストアのためのフロントエンドで、これはドキュメントでより詳細に記述されます。pyro.sample のように、pyro.params はその最初の引数として常に名前とともに呼び出されます。最初に pyro.param が特定の名前で呼ばれたとき、それはその引数をパラメータ・ストアにストアしてからその値を返します。その後、それがその名前で呼び出されたとき、どのような他の引数でもそれはパラメータ・ストアから値を返します。それはここの simple_param_store.setdefault に類似していますが、幾つかの追加の追跡と管理機能を伴います。

simple_param_store = {}
a = simple_param_store.setdefault("a", torch.randn(1))

例えば、scale_posterior_guide で a と b を手動で指定する代わりにパラメータ化できます :

def scale_parametrized_guide(guess):
    a = pyro.param("a", torch.tensor(torch.randn(1) + guess))
    b = pyro.param("b", torch.randn(1))
    return pyro.sample("weight", dist.Normal(a, torch.abs(b)))

Pyro は確率的変分推論を可能にするために構築されています、次の 3 つの主要な特色を持つパワフルで広く適用可能な変分推論アルゴリズムのクラスです :

  1. パラメータは常に実数値の tensor です。
  2. モデルとガイドの実行履歴のサンプルから損失関数のモンテカルロ推定を計算します。
  3. オプションのパラメータを探し求めるために確率的勾配降下を使用します。

確率的勾配降下を PyTorch の GPU で高速化された tensor math 及び自動微分と結合すると変分推論を非常に高次元なパラメータ空間と大規模なデータセットにスケールすることを可能にします。

Pyro の SVI 機能は SVI チュートリアル で詳細に記述されます。ここにそれを scale に適用した非常に単純な例があります :

pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale,
                     guide=scale_parametrized_guide,
                     optim=pyro.optim.SGD({"lr": 0.001}),
                     loss=pyro.infer.Trace_ELBO())

losses = []
for t in range(1000):
    losses.append(svi.step(guess))

plt.plot(losses)
plt.title("ELBO")
plt.xlabel("step")
plt.ylabel("loss");

最適化はガイド・パラメータを更新しますが、事後分布オブジェクト自身は生成しないことに注意してください。ひとたび良いパラメータ値を見い出すならば、単にガイドをダウンストリーム・タスクのためのモデルの近似の事後の表現として使用できます。

例えば、事前よりもずっと少ないサンプルを持つ重さに渡る周辺分布を推定するために、最適化されたガイドを重点分布として使用できます :

posterior = pyro.infer.Importance(conditioned_scale, scale_parametrized_guide, num_samples=10)
marginal = pyro.infer.EmpiricalMarginal(posterior.run(guess), sites="weight")

plt.hist([marginal().item() for _ in range(100)], range=(5.0, 12.0))
plt.title("P(weight | measurement, guess)")
plt.xlabel("weight")
plt.ylabel("#");

近似事後としてガイドから直接的にサンプリングすることもできます :

plt.hist([scale_parametrized_guide(guess).item() for _ in range(100)], range=(5.0, 12.0))
plt.title("P(weight | measurement, guess)")
plt.xlabel("weight")
plt.ylabel("#");

 

Next Steps

変分オートエンコーダ・チュートリアル では、scale のようなモデルを深層ニューラルネットワークでどのように増強するかを見てそして画像の生成モデルを構築するために確率的変分推論を使用します。

 

以上