Pyro 1.4 : Examples : ベイジアン回帰 – イントロダクション (Part 1)

Pyro 1.4 : Examples : ベイジアン回帰 – イントロダクション (Part 1) (翻訳)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/06/2020 (1.4.0)

* 本ページは、Pyro の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

ベイジアン回帰 – イントロダクション (Part 1)

回帰は機械学習における最も一般的で基本的な教師あり学習タスクの一つです。次の形式のデータセット $\mathcal{D}$ が与えられたと仮定します :

\[
\mathcal{D} = \{ (X_i, y_i) \} \qquad \text{for}\qquad i=1,2,…,N
\]

線形回帰のゴールは次の形式の関数をデータに fit させることです :

\[
y = w X + b + \epsilon
\]

ここで $w$ と $b$ は学習可能なパラメータで $\epsilon$ は観測ノイズを表します。特に $w$ は重み行列で $b$ はバイアス・ベクトルです。

このチュートリアルでは、最初に PyTorch で線形回帰を実装してパラメータ $w$ と $b$ のための点推定を学習しましょう。それからベイジアン回帰を実装するために Pyro を使用することによりどのように不確かさを推定に組み込むかを見ましょう。更に、予測を行なうために Pyro のユティリティ関数をどのように使用するかそして TorchScript を使用してモデルをどのようにサーブするかを学習します。

 

セットアップ

必要なモジュールをインポートすることから始めましょう :

%reset -s -f
import os
from functools import partial
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn

import matplotlib.pyplot as plt

import pyro
from pyro.distributions import Normal, Uniform, Delta
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.distributions.util import logsumexp
from pyro.infer import EmpiricalMarginal, SVI, Trace_ELBO, TracePredictive
from pyro.infer.mcmc import MCMC, NUTS
import pyro.optim as optim
import pyro.poutine as poutine

# for CI testing
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('0.3.0')
pyro.enable_validation(True)
pyro.set_rng_seed(1)
pyro.enable_validation(True)

 

データセット

次のサンプルは [1] から改作しました。Terrain Ruggedness Index (直訳: 地形起伏指標) (データセットで起伏の多い変数) により計測された国の topographic heterogeneity (直訳: 位置特異的な不均質) と一人あたりの GDP の関係性を探究したいです。特に、地形起伏 (= terrain ruggedness) あるいは悪い地形 (= bad geography) はアフリカの外ではより貧しい経済的なパフォーマンスに関係しますが、起伏の激しい地形はアフリカ諸国のためには収入上は逆の効果を持っていたことが [1] で著者により記されています。データを見てこの関係性を調査しましょう。データセットから 3 つの特徴に注目します :

  • rugged: Terrain Ruggedness Index を定量化する
  • cont_africa: 与えられた国がアフリカにあるか否か
  • rgdppc_2000: 2000 年についての一人あたりの Real GDP

応答変数 GDP はかなり斜め (= skewed) ですので、それを対数変換します。

DATA_URL = "https://d2hg8soec8ck9v.cloudfront.net/datasets/rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
df = data[["cont_africa", "rugged", "rgdppc_2000"]]
df = df[np.isfinite(df.rgdppc_2000)]
df["rgdppc_2000"] = np.log(df["rgdppc_2000"])
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = df[df["cont_africa"] == 1]
non_african_nations = df[df["cont_africa"] == 0]
sns.scatterplot(non_african_nations["rugged"],
            non_african_nations["rgdppc_2000"],
            ax=ax[0])
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
sns.scatterplot(african_nations["rugged"],
                african_nations["rgdppc_2000"],
                ax=ax[1])
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

 

線形回帰

国の一人あたりの log GDP をデータセットからの 2 つの特徴の関数として予測したいです – 国がアフリカ内であるか、そしてその Terrain Ruggedness Index です。PyroModule[nn.Linear] と呼ばれる自明なクラスを作成します、これは PyroModule と torch.nn.Linear をサブクラス化しています。PyroModule は PyTorch の nn.Module に非常に類似していますが、追加で Pyro プリミティブ を属性としてサポートします、これらは Pyro の effect handlers で変更可能です (pyro.sample プリミティブであるもジュール属性をどのように持てるかについては 次のセクション 参照)。以下は幾つかの一般的な注意 :

  • PyTorch モジュールの学習可能パラメータは nn.Parameter のインスタンスで、この場合 nn.Linear クラスの重みとバイアス・パラメータです。PyroModule の内部で属性として宣言されるとき、これらは自動的に Pyro の param ストアに登録されます。このモデルは最適化の間これらのパラメータの値を制約することを要求しませんが、これは PyroParam ステートメントを使用して PyroModule で容易に獲得できます。
  • PyroModule[nn.Linear] の forward メソッドが nn.Liner から継承する一方で、それは容易に override されることもできることに注意してください。e.g. ロジスティック回帰の場合、私達は線形予測器に sigmoid を適用します。
from torch import nn
from pyro.nn import PyroModule

assert issubclass(PyroModule[nn.Linear], nn.Linear)
assert issubclass(PyroModule[nn.Linear], PyroModule)

 

PyTorch Optimizer で訓練する

2 つの特徴 rugged と cont_africa に加えて、モデルに交互作用項も含めることに注意してください、これはアフリカの内部と外部の国について GDP 上の起伏の効果を個別にモデル化させます。

損失として平均二乗誤差 (MSE) をそして optimizer として torch.optim モジュールから Adam を使用します。モデルのパラメータ、つまりネットワークの重みとバイアスパラメータを最適化したいです、これは回帰係数と切片に対応します。

# Dataset: Add a feature to capture the interaction between "cont_africa" and "rugged"
df["cont_africa_x_rugged"] = df["cont_africa"] * df["rugged"]
data = torch.tensor(df[["cont_africa", "rugged", "cont_africa_x_rugged", "rgdppc_2000"]].values,
                        dtype=torch.float)
x_data, y_data = data[:, :-1], data[:, -1]

# Regression model
linear_reg_model = PyroModule[nn.Linear](3, 1)

# Define loss and optimize
loss_fn = torch.nn.MSELoss(reduction='sum')
optim = torch.optim.Adam(linear_reg_model.parameters(), lr=0.05)
num_iterations = 1500 if not smoke_test else 2

def train():
    # run the model forward on the data
    y_pred = linear_reg_model(x_data).squeeze(-1)
    # calculate the mse loss
    loss = loss_fn(y_pred, y_data)
    # initialize gradients to zero
    optim.zero_grad()
    # backpropagate
    loss.backward()
    # take a gradient step
    optim.step()
    return loss

for j in range(num_iterations):
    loss = train()
    if (j + 1) % 50 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss.item()))


# Inspect learned parameters
print("Learned parameters:")
for name, param in linear_reg_model.named_parameters():
    print(name, param.data.numpy())
[iteration 0050] loss: 3179.7852
[iteration 0100] loss: 1616.1371
[iteration 0150] loss: 1109.4117
[iteration 0200] loss: 833.7545
[iteration 0250] loss: 637.5822
[iteration 0300] loss: 488.2652
[iteration 0350] loss: 376.4650
[iteration 0400] loss: 296.0483
[iteration 0450] loss: 240.6140
[iteration 0500] loss: 203.9386
[iteration 0550] loss: 180.6171
[iteration 0600] loss: 166.3493
[iteration 0650] loss: 157.9457
[iteration 0700] loss: 153.1786
[iteration 0750] loss: 150.5735
[iteration 0800] loss: 149.2020
[iteration 0850] loss: 148.5065
[iteration 0900] loss: 148.1668
[iteration 0950] loss: 148.0070
[iteration 1000] loss: 147.9347
[iteration 1050] loss: 147.9032
[iteration 1100] loss: 147.8900
[iteration 1150] loss: 147.8847
[iteration 1200] loss: 147.8827
[iteration 1250] loss: 147.8819
[iteration 1300] loss: 147.8817
[iteration 1350] loss: 147.8816
[iteration 1400] loss: 147.8815
[iteration 1450] loss: 147.8815
[iteration 1500] loss: 147.8815
Learned parameters:
weight [[-1.9478593  -0.20278624  0.39330274]]
bias [9.22308]

 

回帰 fit をプロットする

アフリカの外側と内部の諸国のために個別に、モデルのための回帰 fit をプロットしましょう。

fit = df.copy()
fit["mean"] = linear_reg_model(x_data).detach().cpu().numpy()

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = fit[fit["cont_africa"] == 1]
non_african_nations = fit[fit["cont_africa"] == 0]
fig.suptitle("Regression Fit", fontsize=16)
ax[0].plot(non_african_nations["rugged"], non_african_nations["rgdppc_2000"], "o")
ax[0].plot(non_african_nations["rugged"], non_african_nations["mean"], linewidth=2)
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
ax[1].plot(african_nations["rugged"], african_nations["rgdppc_2000"], "o")
ax[1].plot(african_nations["rugged"], african_nations["mean"], linewidth=2)
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

地形起伏 (= terrain ruggednes) 間の関係は非アフリカ諸国については GDP と逆相関を持ちますが、それはアフリカ諸国については GDP にプラスに影響します。けれどもこの傾向がどれほど強固であるかは不明確です。特に、回帰 fit がパラメータの不確かさによりどれほど変化するかを理解したいです。これに対処するため、線形回帰のための単純な bayesian モデルを構築します。bayesian モデリング はモデルの不確かさを推論するためのシステムマチックなフレームワークを提供します。単に点推定を学習する代わりに、観測データを調和したパラメータに渡る分布を学習していきます。

 

Pyro の確率的変分推論 (SVI) による Bayesian 回帰

モデル

私達の線形回帰を Bayesian にするためには、パラメータ $w$ と $b$ に 事前分布 を置く必要があります。これらは (任意のデータを観測する前に) $w$ と $b$ のための合理的な値についての私達の事前信念を表す分布です

前のように PyroModule を利用して線形回帰のための Bayesian モデルを作成することは非常に直感的です。以下に注意してください :

  • BayesianRegression モジュールは内部的に同じ PyroModule[nn.Linear] モジュールを使用します。けれども、このもジュールの重みとバイアスを PyroSample ステートメントで置き換えることに注意してください。これらのステートメントは重みとバイアスパラメータに渡る事前分布を置くことを可能にします、それらを固定された学習可能なパラメータとして扱う代わりに。バイアス成分については、適度に広い事前分布を設定します、何故ならばそれは大体は 0 より上でありがちだからです。
  • BayesianRegression.forward は生成プロセスを指定します。線形モジュールを呼び出すことにより応答の平均値を生成します (これは貴方が見たように、事前分布から重みとバイアスパラメータをサンプリングして平均応答のための値を返します)。最後に、学習された観測ノイズ sigma を持つ観測データ y_data 上で条件付けするために pyro.sample ステートメントへの obs 引数を使用します。モデルは変数 mean により与えられた回帰ラインを返します。
from pyro.nn import PyroSample


class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

 

AutoGuide を使用する

推論を行なう、i.e. 無観測パラメータに渡る事後分布を学習するため、確率的変分推論 (SVI) を利用します。ガイドは分布の族を決定します、そして SVI はこの族から近似事後分布を見つける目標を設定します、これは真の事後分布から最低の KL ダイバージェンスを持つものです。

ユーザは Pyro で恣意的に柔軟なカスタムガイドを書けますが、このチュートリアルでは Pyro の autoguide ライブラリ に制限します。次の チュートリアル で、ガイドを手動でどのように書くかを探求します。

手始めに、AutoDiagonalNormal ガイドを使用します、これはモデルの未観測パラメータの分布を対角共分散を持つ正規分布としてモデル化します、i.e. それは潜在変数の中に相関がないことを仮定します (Part II で見るように非常に強いモデリング仮定です)。内部的には、これは、モデルの各 sample ステートメントに対応する学習可能なパラメータを持つ正規分布を利用するガイドを定義します。e.g. 私達のケースでは、この分布は項の各々のための 3 つの回帰係数、そしてモデルの切片項と sigma により与えられる 1 成分に対応する (5,) のサイズを持つはずです。

Autoguide はまた AutoDelta で MAP 推定を学習することや AutoGuideList でガイドを構成することもサポートします (より多くの情報については docs 参照)。

from pyro.infer.autoguide import AutoDiagonalNormal

model = BayesianRegression(3, 1)
guide = AutoDiagonalNormal(model)

 

エビデンス下限を最適化する

推論を行なうために確率的変分推論 (SVI) を使用します (SVI のイントロダクションのためには、SVI パート I 参照)。ちょうど非ベイジアン線形回帰におけるように、訓練ループの各反復は勾配ステップを取ります、違いとしてこの場合、MSE 損失の代わりに (SVI に渡す Trace_ELBO オブジェクトを構築することにより) エビデンス下限 (ELBO) 目的関数を使用します。

from pyro.infer import SVI, Trace_ELBO


adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())

前のように torch.optim モジュールではなく Pyro の optim モジュールからの Adam optimizer を使用することに注意してください。ここで Adam は torch.optim.Adam 回りの薄いラッパーです (議論のためには ここ を参照)。pyro.optim の optimizer は Pyro のパラメータ・ストアのパラメータ値を最適化して更新するために使用されます。特に、学習可能なパラメータを optimizer に渡す必要がないことに気付くでしょう、何故ならばそれはガイドコードにより決定されて SVI クラス内で内部的に自動的に発生するからです。ELBO 勾配ステップを取るためには SVI の step メソッドを単純に呼び出します。SVI.step に渡す data 引数は model() と guide() の両者に渡されます。完全な訓練ループは次のようなものです :

pyro.clear_param_store()
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_data, y_data)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))
[iteration 0001] loss: 6.2310
[iteration 0101] loss: 3.5253
[iteration 0201] loss: 3.2347
[iteration 0301] loss: 3.0890
[iteration 0401] loss: 2.6377
[iteration 0501] loss: 2.0626
[iteration 0601] loss: 1.4852
[iteration 0701] loss: 1.4631
[iteration 0801] loss: 1.4632
[iteration 0901] loss: 1.4592
[iteration 1001] loss: 1.4940
[iteration 1101] loss: 1.4988
[iteration 1201] loss: 1.4938
[iteration 1301] loss: 1.4679
[iteration 1401] loss: 1.4581

Pyro の param ストアから取得することにより最適化されたパラメータ値を調べることができます。

guide.requires_grad_(False)

for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name))
AutoDiagonalNormal.loc Parameter containing:
tensor([-2.2371, -1.8097, -0.1691,  0.3791,  9.1823])
AutoDiagonalNormal.scale tensor([0.0551, 0.1142, 0.0387, 0.0769, 0.0702])

見れるように、単なる点推定の代わりに、今では学習されたパラメータのための不確かな推定 (AutoDiagonalNormal.scale) を持ちます。Autoguide は潜在変数を単一の tensor にパックすることに注意してください、この場合、モデルでサンプリングされた変数毎に一つのエントリです。loc と scale パラメータの両者はサイズ (5,) を持ち、先に述べたように、モデルの潜在変数の各々のために一つです。

潜在パラメータの分布をより明瞭に見るため、AutoDiagonalNormal.quantiles を利用することができます、これは autoguide からの潜在サンプルをアンパックして、それらを site のサポートに自動的に制約します (e.g. 変数 sigma は (0, 10) に在らなければならない)。パラメータのための median 値が最初のモデルから得た最尤点推定に非常に近いことを見ます。

guide.quantiles([0.25, 0.5, 0.75])
{'sigma': [tensor(0.9328), tensor(0.9647), tensor(0.9976)],
 'linear.weight': [tensor([[-1.8868, -0.1952,  0.3272]]),
  tensor([[-1.8097, -0.1691,  0.3791]]),
  tensor([[-1.7327, -0.1429,  0.4309]])],
 'linear.bias': [tensor([9.1350]), tensor([9.1823]), tensor([9.2297])]}

 

モデル評価

モデルを評価するために、幾つかの予測サンプルを生成して事後分布を見ます。このため Predictive ユティリティクラスを利用します。

  • 訓練モデルから 800 サンプルを生成します。内部的にはこれは、最初にガイドの未観測サイトのためのサンプルを生成して、それからサイトをガイドからサンプリングされた値に条件付けることによりモデルを順方向に実行することにより成されます。Predictive クラスがどのように動作するかについての洞察のためには Model Serving セクションを参照してください。
  • return_sites では、outcome (“obs” saite) と (回帰ラインを捕捉する) モデルの戻り値 (“_RETURN”) の両者を指定することに注意してください。追加で、更なる解析のために (“linear.weight” で与えられる) 回帰係数を捕捉することもしたいです。
  • 残りのコードはモデルからの 2 つの変数のための 90% CI をプロットするために単純に使用されます。
from pyro.infer import Predictive


def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats


predictive = Predictive(model, guide=guide, num_samples=800,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_data)
pred_summary = summary(samples)
mu = pred_summary["_RETURN"]
y = pred_summary["obs"]
predictions = pd.DataFrame({
    "cont_africa": x_data[:, 0],
    "rugged": x_data[:, 1],
    "mu_mean": mu["mean"],
    "mu_perc_5": mu["5%"],
    "mu_perc_95": mu["95%"],
    "y_mean": y["mean"],
    "y_perc_5": y["5%"],
    "y_perc_95": y["95%"],
    "true_gdp": y_data,
})
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = predictions[predictions["cont_africa"] == 1]
non_african_nations = predictions[predictions["cont_africa"] == 0]
african_nations = african_nations.sort_values(by=["rugged"])
non_african_nations = non_african_nations.sort_values(by=["rugged"])
fig.suptitle("Regression line 90% CI", fontsize=16)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["mu_mean"])
ax[0].fill_between(non_african_nations["rugged"],
                   non_african_nations["mu_perc_5"],
                   non_african_nations["mu_perc_95"],
                   alpha=0.5)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["true_gdp"],
           "o")
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
idx = np.argsort(african_nations["rugged"])
ax[1].plot(african_nations["rugged"],
           african_nations["mu_mean"])
ax[1].fill_between(african_nations["rugged"],
                   african_nations["mu_perc_5"],
                   african_nations["mu_perc_95"],
                   alpha=0.5)
ax[1].plot(african_nations["rugged"],
           african_nations["true_gdp"],
           "o")
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

上の図は回帰ラインの私達の推定における不確かさ、そして mean 回りの 90% CI を示します。データポイントの殆どが実際には 90% CI の外側にあることも見れます、そしてこれは想定されたものです、何故ならば sigma により影響される outcome 変数をプロットしていなからです!次にそのようにしましょう。

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
fig.suptitle("Posterior predictive distribution with 90% CI", fontsize=16)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["y_mean"])
ax[0].fill_between(non_african_nations["rugged"],
                   non_african_nations["y_perc_5"],
                   non_african_nations["y_perc_95"],
                   alpha=0.5)
ax[0].plot(non_african_nations["rugged"],
           non_african_nations["true_gdp"],
           "o")
ax[0].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="Non African Nations")
idx = np.argsort(african_nations["rugged"])

ax[1].plot(african_nations["rugged"],
           african_nations["y_mean"])
ax[1].fill_between(african_nations["rugged"],
                   african_nations["y_perc_5"],
                   african_nations["y_perc_95"],
                   alpha=0.5)
ax[1].plot(african_nations["rugged"],
           african_nations["true_gdp"],
           "o")
ax[1].set(xlabel="Terrain Ruggedness Index",
          ylabel="log GDP (2000)",
          title="African Nations");

私達のモデルからの outcome と 90% CI は実際に観測するデータポイントの大多数を説明することを観察します。モデルが妥当な予測を与えるかを見るためにそのような事後予測チェックを行なうことは通常は良いアイデアです。

最後に、地形起伏と GDP の間の関係がモデルからのパラメータ推定における任意の不確かさに対してどの程度強固であるかという先の質問を再び訪ねましょう。このため、アフリカの内側と外側の諸国のための起伏が与えられたとき log GDP の傾斜の分布をプロットします。下で見れるように、アフリカ諸国のための確率質量は正の領域に大部分は集中して他の国のためには反対です、更なる信用を元の仮説に与えます。

weight = samples["linear.weight"]
weight = weight.reshape(weight.shape[0], 3)
gamma_within_africa = weight[:, 1] + weight[:, 2]
gamma_outside_africa = weight[:, 1]
fig = plt.figure(figsize=(10, 6))
sns.distplot(gamma_within_africa, kde_kws={"label": "African nations"},)
sns.distplot(gamma_outside_africa, kde_kws={"label": "Non-African nations"})
fig.suptitle("Density of Slope : log(GDP) vs. Terrain Ruggedness");

 

TorchScript を通したモデルサービング

最後に、モデル、ガイドと Predictive ユティリティクラスは総て torch.nn.Module インスタンスで、TorchScript としてシリアライズ可能であることに注意してください。

ここでは、Pyro モデルを torch.jit.ModuleScript としてどのようにサーブできるかを示します、これは Python ランタイムなしで C++ プログラムとして別個に実行できます。

それを行なうため、Pyro の effect ハンドリング・ライブラリ を使用して Predictive ユティリティクラスの私達自身の単純なバージョンを書き直します。これは以下を使用します :

  • trace poutine、モデル/ガイドコードの実行から実行トレースを捕捉する。
  • replay poutine、モデルの site をガイドトレースからサンプリングされた値に条件付けする。
from collections import defaultdict
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites
import warnings


class Predict(torch.nn.Module):
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        samples = {}
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(*args, **kwargs)
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            samples[site] = model_trace.nodes[site]['value']
        return tuple(v for _, v in sorted(samples.items()))

predict_fn = Predict(model, guide)
predict_module = torch.jit.trace_module(predict_fn, {"forward": (x_data,)}, check_trace=False)

このモジュールの forward メソッドをトレースするために torch.jit.trace_module を使用し torch.jit.save を使用してそれをセーブします。このセーブされたモデル reg_predict.pt は PyTorch の C++ API で torch::jit::load(filename) を利用するか、下で行なうように Python API を利用してロードできます。

torch.jit.save(predict_module, '/tmp/reg_predict.pt')
pred_loaded = torch.jit.load('/tmp/reg_predict.pt')
pred_loaded(x_data)
(tensor([9.2165]),
 tensor([[-1.6612, -0.1498,  0.4282]]),
 tensor([ 7.5951,  8.2473,  9.3864,  9.2590,  9.0540,  9.3915,  8.6764,  9.3775,
          9.5473,  9.6144, 10.3521,  8.5452,  5.4008,  8.4601,  9.6219,  9.7774,
          7.1958,  7.2581,  8.9159,  9.0875,  8.3730,  8.7903,  9.3167,  8.8155,
          7.4433,  9.9981,  8.6909,  9.2915, 10.1376,  7.7618, 10.1916,  7.4754,
          6.3473,  7.7584,  9.1307,  6.0794,  8.5641,  7.8487,  9.2828,  9.0763,
          7.9250, 10.9226,  8.0005, 10.1799,  5.3611,  8.1174,  8.0585,  8.5098,
          6.8656,  8.6765,  7.8925,  9.5233, 10.1269, 10.2661,  7.8883,  8.9194,
         10.2866,  7.0821,  8.2370,  8.3087,  7.8408,  8.4891,  8.0107,  7.6815,
          8.7497,  9.3551,  9.9687, 10.4804,  8.5176,  7.1679, 10.8805,  7.4919,
          8.7088,  9.2417,  9.2360,  9.7907,  8.4934,  7.8897,  9.5338,  9.6572,
          9.6604,  9.9855,  6.7415,  8.1721, 10.0646, 10.0817,  8.4503,  9.2588,
          8.4489,  7.7516,  6.8496,  9.2208,  8.9852, 10.6585,  9.4218,  9.1290,
          9.5631,  9.7422, 10.2814,  7.2624,  9.6727,  8.9743,  6.9666,  9.5856,
          9.2518,  8.4207,  8.6988,  9.1914,  7.8161,  9.8446,  6.5528,  8.5518,
          6.7168,  7.0694,  8.9211,  8.5311,  8.4545, 10.8346,  7.8768,  9.2537,
          9.0776,  9.4698,  7.9611,  9.2177,  8.0880,  8.5090,  9.2262,  8.9242,
          9.3966,  7.5051,  9.1014,  8.9601,  7.7225,  8.7569,  8.5847,  8.8465,
          9.7494,  8.8587,  6.5624,  6.9372,  9.9806, 10.1259,  9.1864,  7.5758,
          9.8258,  8.6375,  7.6954,  8.9718,  7.0985,  8.6360,  8.5951,  8.9163,
          8.4661,  8.4551, 10.6844,  7.5948,  8.7568,  9.5296,  8.9530,  7.1214,
          9.1401,  8.4992,  8.9115, 10.9739,  8.1593, 10.1162,  9.7072,  7.8641,
          8.8606,  7.5935]),
 tensor(0.9631))

私達の Predict モジュールが実際に正しくシリアライズされたか、ロードされたもジュールからサンプルを生成して前のプロットを再生成することにより確認しましょう。

weight = []
for _ in range(800):
    # index = 1 corresponds to "linear.weight"
    weight.append(pred_loaded(x_data)[1])
weight = torch.stack(weight).detach()
weight = weight.reshape(weight.shape[0], 3)
gamma_within_africa = weight[:, 1] + weight[:, 2]
gamma_outside_africa = weight[:, 1]
fig = plt.figure(figsize=(10, 6))
sns.distplot(gamma_within_africa, kde_kws={"label": "African nations"},)
sns.distplot(gamma_outside_africa, kde_kws={"label": "Non-African nations"})
fig.suptitle("Loaded TorchScript Module : log(GDP) vs. Terrain Ruggedness");

次のセクションでは、変分推論のためのガイドをどのように書くかそして HMC を通した推論による結果と比較するかを見ます。

 

References

  1. McElreath, D., Statistical Rethinking, Chapter 7, 2016
  2. Nunn, N. & Puga, D., Ruggedness: The blessing of bad geography in Africa, Review of Economics and Statistics 94(1), Feb. 2012
 

以上