PyTorch 1.8 レシピ : 基本 : デバイスに渡りモデルをセーブ/ロードする

PyTorch 1.8 チュートリアル : レシピ : 基本 :- デバイスに渡りモデルをセーブ/ロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/29/2021 (1.8.1+cu102)

* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料 Web セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。
スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

レシピ : 基本 :- デバイスに渡りモデルをセーブ/ロードする

異なるデバイスに渡りニューラルネットワークをセーブしてロードすることを望む事例があるかもしれません。

 

イントロダクション

デバイスに渡りモデルをセーブしてロードすることは PyTorch を使用すれば比較的簡単です。このレシピでは、CPU と GPU に渡りモデルをセーブしてロードする実験をします。

 

セットアップ

総てのコードブロックがこのレシピで正しく動作するために、最初にランタイムを “GPU” またはそれ以上に変更しなければなりません。ひとたびそれを行えば、torch をそれがまだ利用可能でないならばインストールする必要があります。

pip install torch

 

ステップ

  1. データをロードするために総ての必要なライブラリをインポートする。
  2. ニューラルネットワークを定義して初期化する。
  3. GPU 上にセーブして、CPU 上にロードする。
  4. GPU 上にセーブして、GPU 上にロードする。
  5. CPU 上にセーブして、GPU 上にロードする
  6. DataParallel モデルをセーブしてロードする。

 

1. データをロードするために必要なライブラリをインポートする

このレシピのために、torch とその補助 torch.nn と torch.optim を使用します。

import torch
import torch.nn as nn
import torch.optim as optim

 

2. ニューラルネットワークを定義して初期化する

例のために、訓練画像のためのニューラルネットワークを作成します。更に学習するためには ニューラルネットワークを定義する レシピを見てください。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

 

3. GPU 上でセーブして、CPU 上にロードする

GPU で訓練されたモデルを CPU 上でロードするとき、torch.load() 関数の map_location 引数に torch.device(‘cpu’) を渡します。

# Specify a path to save to
PATH = "model.pt"

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device('cpu')
model = Net()
model.load_state_dict(torch.load(PATH, map_location=device))

この場合、tensor の礎となるストレージは map_location 引数を使用して CPU デバイスに動的に再マップされます。

 

4. GPU 上でセーブして、GPU 上にロードする

GPU で訓練されてセーブされたモデルを GPU 上でロードするとき、初期化されたモデルを model.to(torch.device(‘cuda’)) を使用して CUDA 最適化モデルに単純に変換します。

モデルのためのデータを準備するために総てのモデル入力上で .to(torch.device(‘cuda’)) 関数を使用することを確実にしてください。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
model.load_state_dict(torch.load(PATH))
model.to(device)

my_tensor.to(device) の呼び出しは GPU 上で my_tensor の新しいコピーを返すことに注意してください。それは my_tensor を上書き しません。従って、tensor を手動で上書きするためには忘れないでください : my_tensor = my_tensor.to(torch.device(‘cuda’))

 

5. CPU 上でセーブして、GPU 上にロードする

CPU 上で訓練されてセーブされたモデルを GPU 上でロードするとき、torch.load() 関数の map_location 引数を cuda:device_id に設定します。これはモデルを与えられた GPU デバイスにロードします。

モデルのパラメータ tensor を CUDA tensor に変換するために model.to(torch.device(‘cuda’)) を呼び出すことを確実にしてください。

最後に、CUDA 最適化モデルのためのデータを準備するために総てのモデル入力上で .to(torch.device(‘cuda’)) 関数を使用することもまた確実にしてください。

# Save
torch.save(net.state_dict(), PATH)

# Load
device = torch.device("cuda")
model = Net()
# Choose whatever GPU device number you want
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
model.to(device)

 

6. torch.nn.DataParallel モデルをセーブする

torch.nn.DataParallel は並列 GPU 利用を可能にするモデル・ラッパーです。

DataParallel モデルを一般的にセーブするには、model.module.state_dict() をセーブします。このように、望む任意のデバイスに望む任意の方法でモデルをロードする柔軟性を持ちます。

# Save
torch.save(net.module.state_dict(), PATH)

# Load to whatever device you want

Congratulations! PyTorch でデバイスに渡りモデルを成功的にセーブしてロードしました。

 

以上