Pyro 1.4 : SVI (2) 条件付き独立、サブサンプリング及び Amortization (翻訳)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/05/2020 (1.4.0)
* 本ページは、Pyro の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
SVI (2) 条件付き独立、サブサンプリング及び Amortization
ゴール: SVI を巨大なデータセットにスケールする
$N$ 観測を持つモデルについて、モデルとガイドを実行して ELBO を構築することは、複雑さが $N$ で酷くスケールされる log pdf を評価することを伴います。これは巨大なデータセットにスケールすることを望む場合に問題です。幸い、モデル/ガイドが (活用することができる) ある条件付き独立構造を持つ条件下では ELBO 目的 (関数) はサブサンプリングを自然にサポートします。例えば、潜在 (変数) が与えられたとき観測が条件的に独立である場合、ELBO の対数尤度項は次で近似されます :
\[
\sum_{i=1}^N \log p({\bf x}_i | {\bf z}) \approx \frac{N}{M}
\sum_{i\in{\mathcal{I}_M}} \log p({\bf x}_i | {\bf z})
\]
ここで $\mathcal{I}_M$ は $M<N$ であるサイズ $M$ のインデックスのミニバッチです (議論についてはリファレンス [1,2] 参照)。Great, problem solved! しかしこれを Pyro でどのように行なうのでしょう ?
Pyro で条件付き独立をマーキングする
ユーザがこの類のことを Pyro で行なうことを望むのであれば、彼または彼女は最初にモデルとガイドが Pyro が関連する条件付き独立を活用できるような方法で書かれることを確かにする必要があります。これがどのように成されるのか見ましょう。Pyro は条件付き独立をマーキングするために 2 つの言語プリミティブを提供します : plate と markov です。2 つの単純なものから始めましょう。
sequential (逐次) plate
前のチュートリアル で使用した例に戻りましょう。便利のためにモデルの主要ロジックをここで複製しましょう :
def model(data): # sample f from the beta prior f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) # loop over the observed data using pyro.sample with the obs keyword argument for i in range(len(data)): # observe datapoint i using the bernoulli likelihood pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
このモデルについて潜在確率変数 latent_fairness が与えられたとき観測は条件的に独立です。これを Pyro で明示的にマーキングするためには基本的には Python 組込みの range を Pyro 構成物 plate で単に置き換える必要があります :
def model(data): # sample f from the beta prior f = pyro.sample("latent_fairness", dist.Beta(alpha0, beta0)) # loop over the observed data [WE ONLY CHANGE THE NEXT LINE] for i in pyro.plate("data_loop", len(data)): # observe datapoint i using the bernoulli likelihood pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
pyro.plate は一つの主要な違いとともに range に非常に類似していることを見ます : plate の各起動はユーザに一意な名前を提供することを要求します。2 番目の引数はちょうど range のためのように整数です。
今のところ順調です (= So far so good)。Pyro は今では潜在確率変数が与えられたとき観測の条件付き独立を活用できます。しかしこれは実際にはどのように動作するのでしょう?基本的には pyro.plate はコンテキスト・マネージャを使用して実装されています。for ループの本体の総ての実行において新しい (条件付き) 独立コンテキストに入り、それからこれは for ループ本体の最後に抜け出します。これについてはっきりと明示しましょう :
- 各観測された pyro.sample ステートメントは for ループ本体の異なる実行内部で発生しますので、Pyro は各観測を独立であるとマーキングします。
- latent_fairness が与えられたときこの独立性は正しく条件付き独立です、何故ならば latent_fairness は data_loop のコンテキストの外側でサンプリングされるからです。
先に進む前に、逐次 plate を使用するときに避けるべき幾つかの落とし穴に言及しましょう。上のコード・スニペットの次の変形を考えます :
# WARNING do not do this! my_reified_list = list(pyro.plate("data_loop", len(data))) for i in my_reified_list: pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
これは望まれる動作を達成しません、何故ならば単一の pyro.sample ステートメントが呼び出される前に list() は data_loop コンテキストに完全に入って抜けるからです。同様に、ユーザはコンテキスト・マネージャの境界に渡り可変な計算を漏らす (= leak) ことのないように気をつける必要があります、何故ならばこれは微妙なバグに繋がるかもしれないからです。例えば、pyro.plate は時間モデル (= temporal model) のためには適切ではありません、そこではループの各 iteration は前の iteration に依拠します ; この場合代わりに range か pyro.markov が使用されるべきです。
ベクトル化 plate
概念的にはベクトル化 plate はそれがベクトル化された演算であることを除いて逐次 plate と同じです (torch.arange が range に対するように)。そのようなものとしてそれは逐次 plate と共に出現する明示的な for ループに比較して潜在的に膨大なスピードアップを可能にします。私達の実行例に対してこれがどのように見えるか見てみましょう。最初に tensor の形式にあるデータが必要です :
data = torch.zeros(10) data[0:6] = torch.ones(6) # 6 heads and 4 tails
それから次を持ちます :
with plate('observe_data'): pyro.sample('obs', dist.Bernoulli(f), obs=data)
これを類似の逐次 plate 使用方法とポイント毎に比較してみましょう : – 両者のパターンはユーザに一意な名前を指定することを要求します。- このコードスニペットは単一の (観測される) 確率変数 (つまり obs) だけを導入することに注意してください、何故ならば全体の tensor が一度に考慮されるからです。- このケースでは iterator の必要性はありませんので、plate コンテキストに伴う tensor(s) の長さを指定する必要はありません。
逐次 plate のケースで言及した落とし穴はベクトル化 plate にもまた当てはまることに注意してください。
サブサンプリング
私達は今では Pyro で条件付き独立をどのようにマーキングするかを知っています。これはそれ自体有用ですが (SVI Part III の 依存性追跡セクション 参照)、巨大なデータセット上で SVI を行えるようにサブサンプリングもまた行ないたいです。モデルとガイドの構造に依拠して、Pyro はサブサンプリングを行なう幾つかの方法をサポートします。これらを一つずつ調べましょう。
plate で自動サブサンプリング
最初に最も単純なケースを見ましょう、そこでは plate への 1 つか 2 つの追加引数でサブサンプリングを代償なく得ます :
for i in pyro.plate("data_loop", len(data), subsample_size=5): pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])
これだけのことです: 単に引数 subsample_size を使用します。model() を実行するときはいつでも今ではデータの 5 つのランダムに選択されたデータポイントのために対数尤度を評価するだけです ; 更に、対数尤度は $\tfrac{10}{5} = 2$ の適切な因子で自動的にスケールされます。ベクトル化 plate についてはどうでしょう?呪文は全く類似しています :
with plate('observe_data', size=10, subsample_size=5) as ind: pyro.sample('obs', dist.Bernoulli(f), obs=data.index_select(0, ind))
重要なこととして、plate は今ではインデックスの tensor ind を返します、これはこの場合は長さ 5 になります。引数 subsample_size に加えて plate が tensor データのフルサイズを知るように引数 size も渡すことにも注意してください、その結果それは正しいスケーリング因子を計算できます。ちょうど逐次 plate のためのように、ユーザは plate により提供されるインデックスを使用して正しいデータポイントを選択する責任を負います。
最後に、もしデータが GPU 上にあるのであればユーザは device 引数を plate に渡さなければならないことに注意してください。
plate によるカスタム・サブサンプリング・ストラテジー
上の model() が実行されるたびに plate は新しいサブサンプル・インデックスをサンプリングします。このサブサンプリングはステートレスですから、これは幾つかの問題に繋がる可能性があります : 基本的に十分に巨大なデータセットに対して巨大な数の iteration の後でさえも、データポイントの一部が決して選択されないであろう無視できない可能性があります。これを回避するためにユーザは plate への subsample 引数を利用してサブサンプリングを制御できます。詳細は ドキュメント を見てください。
ローカル確率変数だけがあるときのサブサンプリング
次で与えられる同時確率密度を持つモデルを持つことを念頭におきます :
$$
p({\bf x}, {\bf z}) = \prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i)
$$
この依存構造を持つモデルについて、サブサンプリングで導入されるスケール因子は ELBO の総ての項を同じ総量でスケールします。これは例えば、vanilla VAE の場合です。これは VAE に対してサブサンプリングを完全に制御してミニバッチを直接モデルとガイドに渡すことを何故ユーザに許容するかを説明しています; plate は依然として使用されますが、subsample_size と subsample はそうではありません。これがどのように見えるかを詳細に見るためには、VAE チュートリアル を見てください。
グローバルとローカル確率変数の両者があるときのサブサンプリング
上のコインフリップの例では plate はモデルに現れましたがガイドには現れません、何故ならばサブサンプリングされる唯一のものは観測であったからです。より複雑な例を見てみましょう、そこではサブサンプリングはモデルとガイドの両者に現れます。簡潔にするために、議論を幾分抽象的なものにして完全なモデルとガイドを書くことは回避しましょう。
次の同時分布で指定されるモデルを考えます :
$$
p({\bf x}, {\bf z}, \beta) = p(\beta)
\prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta)
$$
N 観測 $\{ {\bf x}_i \}$ と $N$ ローカル潜在確率変数 $\{ {\bf z}_i \}$ があります。グローバル潜在確率変数 $\beta$ もまたあります。ガイドは次のように分解できます :
$$
q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)
$$
ここで $N$ ローカル変分パラメータ $\{\lambda_i \}$ を導入することについては明白でしたが、一方で他の変分パラメータは暗黙的なままです。モデルとガイドの両者は条件付き独立を持ちます。特に、モデル側では、$\{ {\bf z}_i \}$ が与えられたとき観測 $\{ {\bf x}_i \}$ は独立です。更に、$\beta$ が与えられたとき潜在確率変数 $\{\bf {z}_i \}$ は独立です。ガイド側では、変分パラメータ $\{\lambda_i \}$ と $\beta$ が与えられたとき潜在確率変数 $\{\bf {z}_i \}$ は独立です。Pyro でこれらの条件付き独立をマーキングしてサブサンプリングを行なうためにはモデルとガイドの両者で plate を利用する必要があります。逐次 plate を使用して基本的なロジックの概略を述べましょう (コードの完全なピースは pyro.param ステートメント, etc. を含むでしょう)。最初に、モデルです :
def model(data): beta = pyro.sample("beta", ...) # sample the global RV for i in pyro.plate("locals", len(data)): z_i = pyro.sample("z_{}".format(i), ...) # compute the parameter used to define the observation # likelihood using the local random variable theta_i = compute_something(z_i) pyro.sample("obs_{}".format(i), dist.MyDist(theta_i), obs=data[i])
コインフリップを実行する例と対称的に、ここでは plate ループの内側と外側の両者で pyro.sample ステートメントを持つことに注意してください。次にガイドです :
def guide(data): beta = pyro.sample("beta", ...) # sample the global RV for i in pyro.plate("locals", len(data), subsample_size=5): # sample the local RVs pyro.sample("z_{}".format(i), ..., lambda_i)
インデックスはガイドでは一度だけサブサンプリングされることに十分に注意してください ; Pyro バックエンドは、モデルの実行中インデックスの同じセットが使用されることを確かなものにします。この理由で、subsample_size はガイドでのみ指定される必要があります。
Amortization
グローバルとローカル潜在確率変数とローカル変分パラメータを持つモデルを再度考えましょう :
$$
p({\bf x}, {\bf z}, \beta) = p(\beta)
\prod_{i=1}^N p({\bf x}_i | {\bf z}_i) p({\bf z}_i | \beta) \qquad \qquad
q({\bf z}, \beta) = q(\beta) \prod_{i=1}^N q({\bf z}_i | \beta, \lambda_i)
$$
スモールからミディアムサイズの $N$ に対してこのようにローカル変分パラメータを使用することは良いアプローチであり得ます。けれども、$N$ が巨大であれば、それに渡り最適化している空間が $N$ で増大する事実は現実問題となる可能性があります。データセットのサイズによるこの嫌な増大を回避する一つの方法は amortization です。
これは次のように動作します。ローカル変分パラメータを導入する代わりに、単一パラメトリック関数 $f(\cdot)$ を学習して次の形式を持つ変分分布で作業していきます :
$$
q(\beta) \prod_{n=1}^N q({\bf z}_i | f({\bf x}_i))
$$
関数 $f(\cdot)$ — これは基本的には与えられた観測を (そのデータポイントに適合された) 変分パラメータのセットにマップします — は事後分布を正確に捕捉するために十分にリッチである必要がありますが、今では変分パラメータの非常識な数を導入しなければならないことなく巨大なデータセットを処理できます。このアプローチは他の恩恵もまたあります : 例えば、学習の間 $f(\cdot)$ は異なるデータポイント内の統計的パワーを共有することを効果的に可能にします。これは正確に VAE で使用されるアプローチであることに注意してください。
Tensor shape とベクトル化 plate
このチュートリアルの pyro.plate の使用方法は比較的単純なケースに制限されていました。例えば、どの plate も他の plate の内側にネストされませんでした。plate をフル活用するためには、ユーザは Pyro の tensor shape セマンティクスを使用することに注意しなければなりません。議論のためには tensor shape チュートリアル を見てください。
References
- Stochastic Variational Inference, Matthew D. Hoffman, David M. Blei, Chong Wang, John Paisley
- Auto-Encoding Variational Bayes, Diederik P Kingma, Max Welling
以上