PyTorch 1.8 チュートリアル : レシピ : 基本 :- 推論のためにモデルをセーブ/ロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/17/2021 (1.8.1+cu102)
* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- PyTorch Recipes : Basics : Saving and Loading Models for Inference in PyTorch
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
レシピ : 基本 :- 推論のためにモデルをセーブ/ロードする
PyTorch で推論のためにモデルをセーブしてロードするために 2 つのアプローチがあります。最初のものは state_dict をセーブしてロードするもので、そして 2 番目はモデル全体をセーブしてロードします。
イントロダクション
torch.save() 関数でモデルの state_dict をセーブすることは後でモデルを復元するために最大限の柔軟性を与えます。これはモデルをセーブするための推奨方法です、何故ならばそれは訓練モデルの学習されたパラメータをセーブするために実際に必要であるだけからです。モデル全体をセーブしてロードするとき、Python の pickle モジュールを使用してモジュール全体をセーブします。このアプローチの利用は最も直感的なシンタックスをもたらしコードの最小の総量を伴います。このアプローチの不利な点はシリアライズされたデータが、モデルがセーブされたときに使用された、特定のクラスと正確なディレクトリ構造に束縛されることです。この理由は pickle はモデルクラス自身はセーブしないからです。むしろ、それはクラスを含むファイルへのパスをセーブします、これはロード時の際に使用されます。このため、貴方のコードは他のプロジェクトやリファクタリング後に使用されるとき様々な方法で壊れる可能性があります。このレシピでは、推論のためにモデルをどのようにセーブしてロードするかについて両者の方法を調べます。
セットアップ
始める前に、torch をそれがまだ利用可能でないならばインストールする必要があります。
pip install torch
ステップ
- データをロードするために総ての必要なライブラリをインポートします。
- ニューラルネットワークを定義して初期化します。
- optimizer を初期化します。
- state_dict を通してモデルをセーブしてロードします。
- モデル全体をセーブしてロードする。
1. データをロードするために必要なライブラリをインポートする
このレシピのため、torch とその補助 torch.nn と torch.optim を使用します。
import torch
import torch.nn as nn
import torch.optim as optim
2. ニューラルネットワークを定義して初期化する
サンプルのために、訓練画像のためのニューラルネットワークを作成します。更に学習するためには ニューラルネットワークを定義する レシピを見てください。
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
print(net)
3. optimizer を初期化する
SGD with モメンタムを利用します。
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4. state_dict を通してモデルをセーブしてロードする
単に state_dict を使用してモデルをセーブしてロードしましょう。
# Specify a path
PATH = "state_dict_model.pt"
# Save
torch.save(net.state_dict(), PATH)
# Load
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()
一般的な PyToch の慣習は .pt か .pth ファイル拡張子を使用してモデルをセーブします。
load_state_dict() 関数はセーブされたオブジェクトへの パスではなく、辞書オブジェクトを取ることに注意してください。これはセーブされた stated_dict を load_state_dict() 関数に渡す前にデシリアライズしなければならないことを意味します。例えば、model.load_state_dict(PATH) を使用してはロード できません。
推論を実行する前に dropout とバッチ正規化層を評価モードに設定するために model.eval() を呼び出さなければならないことも忘れないでください。これをし損なうと一貫性のない推論結果をもたらします。
5. モデル全体をセーブしてロードする
今はモデル全体で同じことを試しましょう。
# Specify a path
PATH = "entire_model.pt"
# Save
torch.save(net, PATH)
# Load
model = torch.load(PATH)
model.eval()
再度ここで、推論を実行する前に dropout とバッチ正規化層を評価モードに設定するために model.eval() を呼び出さなければならないことを忘れないでください。
Congratulations! 貴方は PyTorch で推論のためにモデルを成功的にセーブしてロードしました。
以上