PyTorch 1.5 レシピ : 基本 : 推論のためにモデルをセーブ/ロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/11/2020 (1.5.0)
* 本ページは、PyTorch 1.5 Recipes の以下のページを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
基本 : 推論のためにモデルをセーブ/ロードする
PyTorch で推論のためにモデルをセーブしてロードするために 2 つのアプローチがあります。最初のものは state_dict をセーブしてロードするもので、そして 2 番目はモデル全体をセーブしてロードします。
イントロダクション
torch.save() 関数で state_dict をセーブすることは後でモデルを復元するために最大限の柔軟さを与えます。これはモデルをセーブするための推奨方法です、何故ならばそれは訓練モデルの学習されたパラメータをセーブするために実際に必要であるだけからです。モデル全体をセーブしてロードするとき、Python の pickle モジュールを使用してモジュール全体をセーブします。このアプローチの利用は最も直感的なシンタックスをもたらしてコードの最小の総量を伴います。このアプローチの不利な点はシリアライズされたデータが、モデルがセーブされたときに使用された、特定のクラスと正確なディレクトリ構造に束縛されることです。この理由は pickle はモデルクラス自身はセーブしないからです。むしろ、それはクラスを含むファイルへのパスをセーブします、これはロードの間に使用されます。このため、貴方のコードは他のプロジェクトやリファクタリング後に使用されるとき様々な方法で壊れる可能性があります。このレシピでは、推論のためにモデルをどのようにセーブしてロードするかについて両者の方法を調べます。
セットアップ
始める前に、torch をそれがまだ利用可能でないならばインストールする必要があります。
pip install torch
ステップ
- データをロードするために総ての必要なライブラリをインポートします。
- ニューラルネットワークを定義して初期化します。
- optimizer を初期化します。
- state_dict を通してモデルをセーブしてロードします。
- モデル全体をセーブしてロードする。
1. データをロードするために必要なライブラリをインポートする
このレシピのため、torch とその補助 torch.nn と torch.optim を使用します。
import torch import torch.nn as nn import torch.optim as optim
2. ニューラルネットワークを定義して初期化する
サンプルのために、訓練画像のためのニューラルネットワークを作成します。更に学習するためには ニューラルネットワークを定義する レシピを見てください。
class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() print(net)
3. optimizer を初期化する
SGD with momentum を利用します。
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4. state_dict を通してモデルをセーブしてロードする
単に state_dict を使用してモデルをセーブしてロードしましょう。
# Specify a path PATH = "state_dict_model.pt" # Save torch.save(net.state_dict(), PATH) # Load model = Net() model.load_state_dict(torch.load(PATH)) model.eval()
一般的な PyToch の慣習は .pt か .pth ファイル拡張子を使用してモデルをセーブします。
load_state_dict() 関数はセーブされたオブジェクトへの パスではなく、辞書オブジェクトを取ることに注意してください。これはセーブされた stated_dict を load_state_dict() 関数に渡す前にデシリアライズしなければならないことを意味します。例えば、model.load_state_dict(PATH) を使用してはロード できません。
推論を実行する前に dropout とバッチ正規化層を評価モードに設定するために model.eval() を呼び出さなければならないことも忘れないでください。これをし損なうと一貫性のない推論結果をもたらします。
5. モデル全体をセーブしてロードする
今はモデル全体で同じことを試しましょう。
# Specify a path PATH = "entire_model.pt" # Save torch.save(net, PATH) # Load model = torch.load(PATH) model.eval()
再度ここで、推論を実行する前に dropout とバッチ正規化層を評価モードに設定するために model.eval() を呼び出さなければならないことを忘れないでください。
Congratulations! 貴方は PyTorch で推論のためにモデルを成功的にセーブしてロードしました。
以上