PyTorch 1.8 レシピ : 基本 : 一般的なチェックポイントをセーブ/ロードする

PyTorch 1.8 チュートリアル : レシピ : 基本 :- 一般的なチェックポイントをセーブ/ロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/17/2021 (1.8.1+cu102)

* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

レシピ : 基本 :- 一般的なチェックポイントをセーブ/ロードする

推論や訓練を再開するために一般的なチェックポイント・モデルをセーブしてロードすることは貴方が最後に離れたところからピックアップするために有用であり得ます。一般的なチェックポイントをセーブするとき、モデルの state_dict だけではなくそれ以上をセーブしなければなりません。optimizer の state_dict もセーブすることは重要です、何故ならばこれはモデルが訓練されるとき更新されるバッファとパラメータを含むからです。セーブすることを望むかもしれない他の項目は貴方自身のアルゴリズムに基づいて、離れたエポック、最新の記録された訓練損失、外部 torch.nn.Embedding 層、等々。

 

イントロダクション

複数のチェックポイントをセーブするには、それらを辞書で体系化して辞書をシリアライズするために torch.save() を使用しなければなりません。一般的な PyTorch 慣習はこれらのチェックポイントを .tar ファイル拡張子を使用してセーブします。項目をロードするには、最初にモデルと optimizer を初期化してから、torch.load() を使用して辞書をローカルでロードします。ここから、想定するような辞書に単純に問い合せることによりセーブされた項目に容易にアクセス可能です。

このレシピでは、複数のチェックポイントをどのようにセーブしてロードするかを調べましょう。

 

セットアップ

始める前に、torch をそれがまだ利用可能でないならばインストールする必要があります。

pip install torch

 

ステップ

  1. データをロードするために総ての必要なライブラリをインポートする
  2. ニューラルネットワークを定義して初期化する
  3. optimizer を初期化する
  4. 一般的なチェックポイントをセーブする
  5. 一般的なチェックポイントをロードする

 

1. データをロードするために必要なライブラリをインポートする

このレシピのために、torch とその補助 torch.nn と torch.optim を使用します。

import torch
import torch.nn as nn
import torch.optim as optim

 

2. ニューラルネットワークを定義して初期化する

例のために、訓練画像のためのニューラルネットワークを作成します。更に学習するためには ニューラルネットワークを定義する レシピを見てください。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

 

3. optimizer を初期化する

SGD with モメンタムを使用します。

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

 

4. 一般的なチェックポイントをセーブする

総ての関連情報を集めて貴方の辞書を構築します。

# Additional information
EPOCH = 5
PATH = "model.pt"
LOSS = 0.4

torch.save({
            'epoch': EPOCH,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': LOSS,
            }, PATH)

 

5. 一般的なチェックポイントをロードする

最初にモデルと optimizer を初期化してから、辞書をローカルでロードすることを忘れないでください。

model = Net()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

推論を実行する前に dropout とバッチ正規化層を評価モードに設定するために model.eval() を呼び出さなければなりません。これをし損なうと一貫性のない推論結果をもたらします。

訓練を再開することを望む場合には、これらの層が訓練モードにあることを確かなものにするために model.train() を呼び出してください。

Congratulations! PyTorch で推論 and/or 訓練を再開するために一般的なチェックポイントを成功的にセーブしてロードしました。

 

以上