PyTorch 1.5 レシピ : 基本 : 異なるモデルからのパラメータを使用してモデルをウォームスタートする

PyTorch 1.5 レシピ : 基本 : 異なるモデルからのパラメータを使用してモデルをウォームスタートする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/13/2020 (1.5.0)

* 本ページは、PyTorch 1.5 Recipes の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

基本 : 異なるモデルからのパラメータを使用してモデルをウォームスタートする

モデルを部分的にロードする、あるいは一部のモデルをロードすることは転移学習や新しい複雑なモデルを訓練するとき一般的なシナリオです。訓練パラメータを活用することは、少々だけが利用可能である場合でさえも、訓練プロセスをウォームスタートする助けとなりそして望ましくは貴方のモデルがスクラッチからの訓練よりも高速に収束する助けとなるでしょう。

 

イントロダクション

幾つかのキーが欠落している部分的な state_dict からロードしていようが、あるいは (その中にロードしようとしてる) モデルよりも多いキーを持つ state_dict をロードしていようが、マッチしないキーを無視するために load_state_dict() 関数で strict 引数を False に設定できます。このレシピでは、異なるモデルのパラメータを使用してモデルをウォームスタートする実験を行ないます。

 

セットアップ

始める前に、torch をそれがまだ利用可能でないならばインストールする必要があります。

pip install torch

 

ステップ

  1. データをロードするために総ての必要なライブラリをインポートする。
  2. ニューラルネットワーク A と B を定義して初期化する。
  3. モデル A をセーブする。
  4. モデル B 内にロードする。

 

1. データをロードするために必要なライブラリをインポートする

このレシピのために、torch とその補助 torch.nn と torch.optim を使用します。

import torch
import torch.nn as nn
import torch.optim as optim

 

2. ニューラルネットワーク A と B を定義して初期化する

例のために、訓練画像のためのニューラルネットワークを作成します。更に学習するためには ニューラルネットワークを定義する レシピを見てください。タイプ A の一つのパラメータをタイプ B 内にロードするために 2 つのニューラルネットワークを作成します。

class NetA(nn.Module):
    def __init__(self):
        super(NetA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netA = NetA()

class NetB(nn.Module):
    def __init__(self):
        super(NetB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

netB = NetB()

 

3. モデル A をセーブする

# Specify a path to save to
PATH = "model.pt"

torch.save(netA.state_dict(), PATH)

 

4. モデル B にロードする

一つの層から他の一つにパラメータをロードすることを望み、しかし幾つかのキーが一致しない場合、(それの内にロードしている) モデルのキーに一致するようにロードしている state_dict のパラメータ・キーの名前を単純に変更してください。

netB.load_state_dict(torch.load(PATH), strict=False)

総てのキーが成功的に一致したことを見ることができます!

Congratulations! PyTorch で異なるモデルからのパラメータを使用してモデルを成功的にウォームスタートしました。

 
以上