Pyro 1.3 : Pyro のテンソル shape (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/15/2020 (1.3.1)
* 本ページは、Pyro の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |
Pyro のテンソル shape
このチュートリアルは tensor 次元の Pyro の体系を紹介します。始まる前に、PyTorch のブロードキャスト・セマンティクス に精通するべきです。
要約
- 学習かデバッグしているとき、pyro.enable_validation(True) を設定します。
- 右側にアラインされる tensor ブロードキャスト: torch.ones(3,4,5) + torch.ones(5).
>>> torch.ones(3,4,5)+torch.ones(5) tensor([[[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]], [[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]], [[2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.], [2., 2., 2., 2., 2.]]]) - Distribution .sample().shape == batch_shape + event_shape.
- Distribution .log_prob(x).shape == batch_shape (but not event_shape!).
- サンプルのバッチをドローするために .expand() を使用するか、自動的に expand するために plate に依拠します。
- 次元が従属であることを宣言するために my_dist.to_event(1) を使用します。
- pyro.plate(‘name’, size) と共に使用する: 次元を条件的に独立であると宣言するため。
- 総ての次元は従属か条件的に独立であるかが宣言されなければなりません。
- 左側のバッチ処理をサポートしようとしています。これは Pyro を自動並列化させます。
- x.sum(2) よりも x.sum(-1) のような負のインデックスを使用します。
- pixel = image[…, i, j] のような ellipsis 表記を使用します。
- i, j が列挙されれば、Vindex を使用します、pixel = Vindex(image)[…, i, j]
- デバッグ時、Trace.format_shapes() を使用して trace の総ての shape を検査してください。
import os
import torch
import pyro
from torch.distributions import constraints
from pyro.distributions import Bernoulli, Categorical, MultivariateNormal, Normal
from pyro.distributions.util import broadcast_shape
from pyro.infer import Trace_ELBO, TraceEnum_ELBO, config_enumerate
import pyro.poutine as poutine
from pyro.optim import Adam
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.3.0')
pyro.enable_validation(True) # <---- This is always a good idea!
# We'll ue this helper to check our models are correct.
def test_model(model, guide, loss):
pyro.clear_param_store()
loss.loss(model, guide)
Distributions shapes: batch_shape と event_shape
PyTorch Tensor は単一の .shape 属性を持ちますが、Distributions は特別な意味を持つ 2 つの shape 属性を持ちます : .batch_shape と .event_shape です。これら 2 つは結び付いてサンプルのトータル shape を定義します :
x = d.sample() assert x.shape == d.batch_shape + d.event_shape
.batch_shape に渡るインデックスは条件的に独立な確率変数を示す一方で、.event_shape に渡るインデックスは従属確率変数を示します (i.e. 分布からの 1 つのドロー)。従属確率変数は一緒に確率を定義しますので、.log_prob() メソッドは shape .event_shape の各事象に対して単一の数を生成するだけです。こうして .log_prob() のトータル shape は .batch_shape です :
assert d.log_prob(x).shape == d.batch_shape
Distribution.sample() メソッドはまた sample_shape パラメータを取ることに注意してください、これは独立同分布 (iid) 確率変数に渡りインデックスします、その結果 :
x2 = d.sample(sample_shape) assert x2.shape == sample_shape + batch_shape + event_shape
In summary
| iid | independent | dependent
------+--------------+-------------+------------
shape = sample_shape + batch_shape + event_shape
例えば単変量分布は empty event shape を持ちます (何故ならば各数字は独立事象であるからです)。MultivariateNormal のようなベクトルに渡る分布は len(event_shape) == 1 を持ちます。InverseWishart のような行列に渡る分布は len(event_shape) == 2 を持ちます。
Examples
最も単純な分布 shape は単一の単変量分布です。
d = Bernoulli(0.5) assert d.batch_shape == () assert d.event_shape == () x = d.sample() assert x.shape == () assert d.log_prob(x).shape == ()
分布は batched パラメータを渡すことによりバッチ処理できます。
d = Bernoulli(0.5 * torch.ones(3,4)) assert d.batch_shape == (3, 4) assert d.event_shape == () x = d.sample() assert x.shape == (3, 4) assert d.log_prob(x).shape == (3, 4)
分布をバッチ処理するもう一つの方法は .expand() メソッドを通すことです。これはパラメータが左端次元に沿って同一である場合に動作するだけです。
d = Bernoulli(torch.tensor([0.1, 0.2, 0.3, 0.4])).expand([3, 4]) assert d.batch_shape == (3, 4) assert d.event_shape == () x = d.sample() assert x.shape == (3, 4) assert d.log_prob(x).shape == (3, 4)
多変量分布は空でない (= nonempty) .event_shape を持ちます。これらの分布については、.sample() と .log_prob(x) は異なります :
d = MultivariateNormal(torch.zeros(3), torch.eye(3, 3)) assert d.batch_shape == () assert d.event_shape == (3,) x = d.sample() assert x.shape == (3,) # == batch_shape + event_shape assert d.log_prob(x).shape == () # == batch_shape
分布を reshape する
Pyro では .to_event(n) プロパティを呼び出すことにより単変量分布を多変量分布として扱うことができます、ここで n は従属と宣言するための (右から) バッチ次元の数です。
d = Bernoulli(0.5 * torch.ones(3,4)).to_event(1) assert d.batch_shape == (3,) assert d.event_shape == (4,) x = d.sample() assert x.shape == (3, 4) assert d.log_prob(x).shape == (3,)
Pyro プログラムで作業する間、サンプルは shape batch_shape + event_shape を持つ一方で、.log_prob(x) 値は shape batch_shape を持つことに留意してください。batch_shape が、それを .to_event(n) でトリムダウンするか pyro.plate を通して次元を独立として宣言することにより、注意深く制御されることを確かなものにする必要があるでしょう。
従属性を仮定することは常に安全です
Pyro ではしばしば幾つかの次元を従属として宣言します、それらが実際には独立であるとしても、e.g.
x = pyro.sample("x", dist.Normal(0, 1).expand([10]).to_event(1))
assert x.shape == (10,)
これは 2 つの理由で有用です : 第一にそれは後で MultivariateNormal 分布でスワップインすることを容易に可能にします。2 番目にそれはコードを少し単純化します、何故ならば次のように plate を必要としないからです (下参照) :
with pyro.plate("x_plate", 10):
x = pyro.sample("x", dist.Normal(0, 1)) # .expand([10]) is automatic
assert x.shape == (10,)
これら 2 つのバージョンの違いは、plate を伴う 2 番目のバージョンでは (Pyro が) 勾配を見積もるとき条件付き独立の情報を利用できることを Pyro に知らせる一方で、最初のバージョンでは Pyro はそれらが従属であることを仮定しなければなりません (正規分布が実際には条件的に独立であるとしても)。これはグラフィカルモデルでの有向分離 (= d-separation) への類推です : エッジを追加して変数が従属であるかもしれないことを (i.e. モデルクラスを広くするために) 仮定することは常に安全ですが、変数が実際に従属であるとき独立性を仮定することは安全ではありません (i.e. モデルクラスを狭めますので真のモデルがクラスの外側にあります、as in mean field)。実際には Pyro の SVI 推論アルゴリズムは正規分布のために再パラメータ化された勾配推定器を使用しますので、両者の勾配推定器は同じパフォーマンスを持ちます。
plate で独立 dims を宣言する
Pyro モデルはあるバッチ次元が独立であることを宣言するためにコンテキスト・マネージャ pyro.plate を使用できます。そして推論アルゴリズムは e.g. lower variance 勾配推定器を構築したり指数空間よりも線形空間で列挙するためにこの独立性を利用できます。独立次元の例はミニバッチのデータに渡るインデックスです : 各データは他の総てから独立であるはずです。
次元を独立として宣言する最も単純な方法は単純な次を通して右端バッチ次元を独立として宣言することです :
with pyro.plate("my_plate"):
# within this context, batch dimension -1 is independent
shape のデバッグに役立てるためにオプションの size 引数を常に提供することを勧めます :
with pyro.plate("my_plate", len(my_data)):
# within this context, batch dimension -1 is independent
Pyro 0.2 から加えて plate をネストできます e.g. ピクセル毎に独立性を持つ場合 :
with pyro.plate("x_axis", 320):
# within this context, batch dimension -1 is independent
with pyro.plate("y_axis", 200):
# within this context, batch dimensions -2 and -1 are independent
-2, -1 のような負のインデックスを使用して常に右からカウントすることに注意してください。
最後に e.g. x だけに依拠するノイズ、y だけに依拠するあるノイズ、そして両者に依拠するあるノイズのために plate を混在させて適合させることを望む場合、複数の plate を宣言してそれらを再利用可能なコンテキストマネージャとして利用することができます。この場合 Pyro は自動的には次元を割り当てられませんので、dim 引数を提供する必要があります (再度、右から数えます) :
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
# within this context, batch dimension -2 is independent
with y_axis:
# within this context, batch dimension -3 is independent
with x_axis, y_axis:
# within this context, batch dimensions -3 and -2 are independent
plate 内の batch サイズを良く見ましょう。
def model1():
a = pyro.sample("a", Normal(0, 1))
b = pyro.sample("b", Normal(torch.zeros(2), 1).to_event(1))
with pyro.plate("c_plate", 2):
c = pyro.sample("c", Normal(torch.zeros(2), 1))
with pyro.plate("d_plate", 3):
d = pyro.sample("d", Normal(torch.zeros(3,4,5), 1).to_event(2))
assert a.shape == () # batch_shape == () event_shape == ()
assert b.shape == (2,) # batch_shape == () event_shape == (2,)
assert c.shape == (2,) # batch_shape == (2,) event_shape == ()
assert d.shape == (3,4,5) # batch_shape == (3,) event_shape == (4,5)
x_axis = pyro.plate("x_axis", 3, dim=-2)
y_axis = pyro.plate("y_axis", 2, dim=-3)
with x_axis:
x = pyro.sample("x", Normal(0, 1))
with y_axis:
y = pyro.sample("y", Normal(0, 1))
with x_axis, y_axis:
xy = pyro.sample("xy", Normal(0, 1))
z = pyro.sample("z", Normal(0, 1).expand([5]).to_event(1))
assert x.shape == (3, 1) # batch_shape == (3,1) event_shape == ()
assert y.shape == (2, 1, 1) # batch_shape == (2,1,1) event_shape == ()
assert xy.shape == (2, 3, 1) # batch_shape == (2,3,1) event_shape == ()
assert z.shape == (2, 3, 1, 5) # batch_shape == (2,3,1) event_shape == (5,)
test_model(model1, model1, Trace_ELBO())
各 sample site の .shapes を batch_shape と event_shape の間の境界でそれらをアラインすることにより可視化することは役立ちます : 右側の次元は .log_prob() で総計されて左側の次元はそのままです。
batch dims | event dims
-----------+-----------
| a = sample("a", Normal(0, 1))
|2 b = sample("b", Normal(zeros(2), 1)
| .to_event(1))
| with plate("c", 2):
2| c = sample("c", Normal(zeros(2), 1))
| with plate("d", 3):
3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| .to_event(2))
|
| x_axis = plate("x", 3, dim=-2)
| y_axis = plate("y", 2, dim=-3)
| with x_axis:
3 1| x = sample("x", Normal(0, 1))
| with y_axis:
2 1 1| y = sample("y", Normal(0, 1))
| with x_axis, y_axis:
2 3 1| xy = sample("xy", Normal(0, 1))
2 3 1|5 z = sample("z", Normal(0, 1).expand([5])
| .to_event(1))
プログラムの sample site の shape を自動的に調べるため、プログラムを追跡して Trace.format_shapes() メソッドを使用することができます、これは各 sample site のための 3 つの shape をプリントします : distribution shape (site["fn"].batch_shape と site["fn"].event_shape の両者)、value shape (site["value"].shape) そして対数確率が計算されれば log_prob shape (site["log_prob"].shape) もです :
trace = poutine.trace(model1).get_trace() trace.compute_log_prob() # optional, but allows printing of log_prob shapes print(trace.format_shapes())
Trace Shapes:
Param Sites:
Sample Sites:
a dist |
value |
log_prob |
b dist | 2
value | 2
log_prob |
c_plate dist |
value 2 |
log_prob |
c dist 2 |
value 2 |
log_prob 2 |
d_plate dist |
value 3 |
log_prob |
d dist 3 | 4 5
value 3 | 4 5
log_prob 3 |
x_axis dist |
value 3 |
log_prob |
y_axis dist |
value 2 |
log_prob |
x dist 3 1 |
value 3 1 |
log_prob 3 1 |
y dist 2 1 1 |
value 2 1 1 |
log_prob 2 1 1 |
xy dist 2 3 1 |
value 2 3 1 |
log_prob 2 3 1 |
z dist 2 3 1 | 5
value 2 3 1 | 5
log_prob 2 3 1 |
plate 内で tensor をサブサンプリングする
plate の主要な利用の一つはデータをサブサンプリングすることです。これは plate 内で可能です、何故ならばデータは条件的に独立であるので、半分のデータの上の損失の期待値は、そうですね、完全なデータの半分の期待損失になるはずです。
データをサブサンプリングするには、Pyro に元のデータサイズと subsample サイズを知らせる必要があります ; それから Pyro はデータのランダムサブセットを選択してインデックスのセットを生成します。
data = torch.arange(100.)
def model2():
mean = pyro.param("mean", torch.zeros(len(data)))
with pyro.plate("data", len(data), subsample_size=10) as ind:
assert len(ind) == 10 # ind is a LongTensor that indexes the subsample.
batch = data[ind] # Select a minibatch of data.
mean_batch = mean[ind] # Take care to select the relevant per-datum parameters.
# Do stuff with batch:
x = pyro.sample("x", Normal(mean_batch, 1), obs=batch)
assert len(x) == 10
test_model(model2, guide=lambda: None, loss=Trace_ELBO())
並列列挙 (= enumeration) を可能にするブロードキャスト
Pyro 0.2 は離散潜在変数を並列に列挙する機能を導入します。これは SVI を通して事後分布を学習するとき勾配推定器の分散を本質的に減じることができます。
並列列挙を使用するため、Pyro は (それが) 列挙のために利用可能な tensor 次元を割り当てる必要があります。plate のために使用することを望む他の次元との衝突を避けるため、使用する tensor 次元の最大数のバジェットを宣言する必要があります。このバジェットは max_plate_nesting と呼称されて SVI への引数です (引数は単純に TraceEnum_ELBO 経由で渡されます)。通常は Pyro はそれ自身の上でこのバジェットを決定できます (それは (model,guide) ペアを一度実行して何が起きるかを記録します) が、動的モデル構造の場合には max_plate_nesting を手動で宣言する必要があるかもしれません。
max_plate_nesting と Pyro が列挙のためにどのように次元を割り当てるかを理解するためには、上から model1() に立ち戻りましょう。今回は 3 つのタイプの次元を精密に示します : 左側の列挙次元 (Pyro はこれらを制御します)、中央のバッチ次元、そして右側の事象次元です。
max_plate_nesting = 3
|<--->|
enumeration|batch|event
-----------+-----+-----
|. . .| a = sample("a", Normal(0, 1))
|. . .|2 b = sample("b", Normal(zeros(2), 1)
| | .to_event(1))
| | with plate("c", 2):
|. . 2| c = sample("c", Normal(zeros(2), 1))
| | with plate("d", 3):
|. . 3|4 5 d = sample("d", Normal(zeros(3,4,5), 1)
| | .to_event(2))
| |
| | x_axis = plate("x", 3, dim=-2)
| | y_axis = plate("y", 2, dim=-3)
| | with x_axis:
|. 3 1| x = sample("x", Normal(0, 1))
| | with y_axis:
|2 1 1| y = sample("y", Normal(0, 1))
| | with x_axis, y_axis:
|2 3 1| xy = sample("xy", Normal(0, 1))
|2 3 1|5 z = sample("z", Normal(0, 1).expand([5]))
| | .to_event(1))
供給過多 max_plate_nesting=4 は安全ですが、供給不足 max_plate_nesting=2 はできない (or Pyro will error) ことに注意してください。これが実際にどのように動作するか見ましょう。
@config_enumerate
def model3():
p = pyro.param("p", torch.arange(6.) / 6)
locs = pyro.param("locs", torch.tensor([-1., 1.]))
a = pyro.sample("a", Categorical(torch.ones(6) / 6))
b = pyro.sample("b", Bernoulli(p[a])) # Note this depends on a.
with pyro.plate("c_plate", 4):
c = pyro.sample("c", Bernoulli(0.3))
with pyro.plate("d_plate", 5):
d = pyro.sample("d", Bernoulli(0.4))
e_loc = locs[d.long()].unsqueeze(-1)
e_scale = torch.arange(1., 8.)
e = pyro.sample("e", Normal(e_loc, e_scale)
.to_event(1)) # Note this depends on d.
# enumerated|batch|event dims
assert a.shape == ( 6, 1, 1 ) # Six enumerated values of the Categorical.
assert b.shape == ( 2, 1, 1, 1 ) # Two enumerated Bernoullis, unexpanded.
assert c.shape == ( 2, 1, 1, 1, 1 ) # Only two Bernoullis, unexpanded.
assert d.shape == (2, 1, 1, 1, 1, 1 ) # Only two Bernoullis, unexpanded.
assert e.shape == (2, 1, 1, 1, 5, 4, 7) # This is sampled and depends on d.
assert e_loc.shape == (2, 1, 1, 1, 1, 1, 1,)
assert e_scale.shape == ( 7,)
test_model(model3, model3, TraceEnum_ELBO(max_plate_nesting=2))
それらの次元を良く見ましょう。最初に Pyro は max_plate_nesting の右側から始めて列挙 dims を割り当てることに注意してください : Pyro は enumerate a に dim -3 を、それから enumerate b に dim -4 を、それから enumerate c に dim -5 を、そして最後に enumerate d に dim -6 を割り当てます。次に samples は新しい列挙次元で extent (size > 1) を持つだけであることに注意してください。これは tensor を小さくそして計算を安価に維持するのに役立ちます。(log_prob shape は enumeratin shape と batch shape の両者を含むまでブロードキャストされます、従って e.g. trace.nodes['d']['log_prob'].shape == (2, 1, 1, 1, 5, 4) です。)
tensor 次元の類似のマップを描くことができます :
max_plate_nesting = 2
|<->|
enumeration batch event
------------|---|-----
6|1 1| a = pyro.sample("a", Categorical(torch.ones(6) / 6))
2 1|1 1| b = pyro.sample("b", Bernoulli(p[a]))
| | with pyro.plate("c_plate", 4):
2 1 1|1 1| c = pyro.sample("c", Bernoulli(0.3))
| | with pyro.plate("d_plate", 5):
2 1 1 1|1 1| d = pyro.sample("d", Bernoulli(0.4))
2 1 1 1|1 1|1 e_loc = locs[d.long()].unsqueeze(-1)
| |7 e_scale = torch.arange(1., 8.)
2 1 1 1|5 4|7 e = pyro.sample("e", Normal(e_loc, e_scale)
| | .to_event(1))
enumeration セマンティクスを持つこのモデルを自動的に検査するため、enumerated trace を作成してから Trace.format_shapes() を使用することができます :
trace = poutine.trace(poutine.enum(model3, first_available_dim=-3)).get_trace() trace.compute_log_prob() # optional, but allows printing of log_prob shapes print(trace.format_shapes())
Trace Shapes:
Param Sites:
p 6
locs 2
Sample Sites:
a dist |
value 6 1 1 |
log_prob 6 1 1 |
b dist 6 1 1 |
value 2 1 1 1 |
log_prob 2 6 1 1 |
c_plate dist |
value 4 |
log_prob |
c dist 4 |
value 2 1 1 1 1 |
log_prob 2 1 1 1 4 |
d_plate dist |
value 5 |
log_prob |
d dist 5 4 |
value 2 1 1 1 1 1 |
log_prob 2 1 1 1 5 4 |
e dist 2 1 1 1 5 4 | 7
value 2 1 1 1 5 4 | 7
log_prob 2 1 1 1 5 4 |
並列化可能なコードを書く
並列化された sample sites を正しく処理する Pyro モデルを書くことは技巧的であり得ます。2 つのトリックが役立ちます : ブロードキャスト と Ellipsis スライシング です。これらが実際にどのように動作するかを見るために考案されたモデルを見ましょう。私達の目標は列挙ありとなしの両者で動作するモデルを書くことです。
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None # set to either True or False below
def fun(observe):
p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
# Note that the shapes of these sites depend on whether Pyro is enumerating.
with x_axis:
x_active = pyro.sample("x_active", Bernoulli(p_x))
with y_axis:
y_active = pyro.sample("y_active", Bernoulli(p_y))
if enumerated:
assert x_active.shape == (2, 1, 1)
assert y_active.shape == (2, 1, 1, 1)
else:
assert x_active.shape == (width, 1)
assert y_active.shape == (height,)
# The first trick is to broadcast. This works with or without enumeration.
p = 0.1 + 0.5 * x_active * y_active
if enumerated:
assert p.shape == (2, 2, 1, 1)
else:
assert p.shape == (width, height)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
# The second trick is to index using ellipsis slicing.
# This allows Pyro to add arbitrary dimensions on the left.
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
if enumerated:
assert dense_pixels.shape == (2, 2, width, height)
else:
assert dense_pixels.shape == (width, height)
with x_axis, y_axis:
if observe:
pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)
def model4():
fun(observe=True)
def guide4():
fun(observe=False)
# Test without enumeration.
enumerated = False
test_model(model4, guide4, Trace_ELBO())
# Test with enumeration.
enumerated = True
test_model(model4, config_enumerate(guide4, "parallel"),
TraceEnum_ELBO(max_plate_nesting=2))
pyro.plate 内の自動ブロードキャスト
私達の総てのモデル/ガイド仕様では、pyro.sample ステートメントにより強要されるバッチ shape 上の制約を満たすために sample shape を自動的に拡張 (= expand) するためには pyro.plate に依拠したことに注意してください。けれどもこのブロードキャストは手動アノテートされた .expand() ステートメントと同値です。
前のセクション からの model4 を使用してこれを実演します。前のものからコードへの以下の変更に注意してください :
- この例の目的のため、「並列な」列挙だけを考慮しますが、ブロードキャストは列挙がなくてもあるいは「シーケンシャルな」列挙でも期待通りに動作するはずです。
- サンプリング関数を分離しました、これは acive ピクセルに対応する tensor を返します。モデルコードのコンポーネントへのもジュール化は一般的な実戦で、巨大なモデルのメンテナンス性に役立ちます。
- num_particles に渡る ELBO 推定器を並列化するため pyro.plate 構成部品もまた利用したいです。これは最も外側の pyro.plate コンテキスト内でモデル/ガイドの内容をラップすることで成されます。
num_particles = 100 # Number of samples for the ELBO estimator
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
def sample_pixel_locations_no_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample("x_active", Bernoulli(p_x).expand([num_particles, width, 1]))
with y_axis:
y_active = pyro.sample("y_active", Bernoulli(p_y).expand([num_particles, 1, height]))
return x_active, y_active
def sample_pixel_locations_full_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample("x_active", Bernoulli(p_x))
with y_axis:
y_active = pyro.sample("y_active", Bernoulli(p_y))
return x_active, y_active
def sample_pixel_locations_partial_broadcasting(p_x, p_y, x_axis, y_axis):
with x_axis:
x_active = pyro.sample("x_active", Bernoulli(p_x).expand([width, 1]))
with y_axis:
y_active = pyro.sample("y_active", Bernoulli(p_y).expand([height]))
return x_active, y_active
def fun(observe, sample_fn):
p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
x_axis = pyro.plate('x_axis', width, dim=-2)
y_axis = pyro.plate('y_axis', height, dim=-1)
with pyro.plate("num_particles", 100, dim=-3):
x_active, y_active = sample_fn(p_x, p_y, x_axis, y_axis)
# Indices corresponding to "parallel" enumeration are appended
# to the left of the "num_particles" plate dim.
assert x_active.shape == (2, 1, 1, 1)
assert y_active.shape == (2, 1, 1, 1, 1)
p = 0.1 + 0.5 * x_active * y_active
assert p.shape == (2, 2, 1, 1, 1)
dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
for x, y in sparse_pixels:
dense_pixels[..., x, y] = 1
assert dense_pixels.shape == (2, 2, 1, width, height)
with x_axis, y_axis:
if observe:
pyro.sample("pixels", Bernoulli(p), obs=dense_pixels)
def test_model_with_sample_fn(sample_fn):
def model():
fun(observe=True, sample_fn=sample_fn)
@config_enumerate
def guide():
fun(observe=False, sample_fn=sample_fn)
test_model(model, guide, TraceEnum_ELBO(max_plate_nesting=3))
test_model_with_sample_fn(sample_pixel_locations_no_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_full_broadcasting)
test_model_with_sample_fn(sample_pixel_locations_partial_broadcasting)
最初のサンプリング関数では、pyro.plate コンテキストにより追加された条件的に独立な次元を構成するために、ある手動の簿記をして Bernoulli 分布のバッチ shape を拡張しなければなりませんでした。特に、sample_pixel_locations は num_particles, width and height の知識をどのように必要としてグローバルスコープからこれらの変数にアクセスしているかに注意してください、これは理想的ではありません。
- pyro.plate への 2 番目の引数、i.e. オプションの size 引数は暗黙的なブロードキャストのために提供される必要があります、その結果それは sample sites の各々のためにバッチ shape 要件を推論できます。
- sample saite の既存の batch_shape は pyro.plate コンテキストのサイズでブロードキャスト可能でなければなりません。私達の特定の例では、Bernoulli(p_x) は empty バッチ shape を持ち、これは普遍的にブロードキャスト可能です。
pyro.plate を使用して tensor 化された演算を通して並列化を獲得することがどれほど単純であるかに注意してください!pyro.plate はまたコードのモジュール化にも役立ちます、何故ならばモデルコンポーネントは plate コンテキストの不可知論として書かれるからです、そこではそれらはその後に埋め込まれるかもしれません。
以上