PyTorch 1.8 チュートリアル : PyTorch の学習 : 基本 – モデルのセーブ & ロード (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/18/2021 (1.8.0)
* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- Learning PyTorch : Learn the Basics : Save & Load the Model
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
PyTorch の学習 : 基本 – モデルのセーブ & ロード
このセクションではセーブ、ロードそしてモデル予測を実行することでモデル状態をどのように永続化するかを見ます。
import torch
import torch.onnx as onnx
import torchvision.models as models
モデル重みをセーブしてロードする
PyTorch モデルは学習されたパラメータを state_dict と呼ばれる、内部状態辞書にストアします。これらは torch.save メソッドを通して永続化できます :
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
モデル重みをロードするには、最初に同じモデルのインスタンスを作成してから load_state_dict() メソッドを使用してパラメータをロードする必要があります。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
Note
dropout とバッチ正規化層を評価モードに設定するために推論の前に model.eval() メソッドを確実に呼び出してください。これを行なうことを失敗すれば一貫性のない推論結果を生成します。
形状とともにモデルをセーブしてロードする
モデル重みをロードするとき、モデルクラスを最初にインスタンス化することが必要でした、何故ならばクラスがネットワークの構造を定義するからです。このクラスの構造をモデルと一緒にセーブすることを望むかもしれません、その場合モデル (そして model.state_dict() ではない) をセービング関数に渡すことができます :
torch.save(model, 'model.pth')
それからこのようにモデルをロードできます :
model = torch.load('model.pth')
Note
このアプローチはモデルをシリアライズするとき Python pickle モジュールを使用しますので、モデルをロードするときそれは実際のクラス定義が利用可能であることに依拠します。
モデルを ONNX にエクスポートする
PyTorch はまたネイティブの ONNX エクスポート・サポートも持ちます。PyTorch 実行グラフの動的性質が与えられたとき、けれども、エクスポート・プロセスは永続化された ONNX モデルを生成するには実行グラフを辿らなければなりません。このため、適切なサイズの test 変数が export ルーチンに渡されるべきです (私達のケースでは、正しいサイズのダミー・ゼロ tensor を作成します) :
input_image = torch.zeros((1,3,224,224))
onnx.export(model, input_image, 'model.onnx')
異なるプラットフォーム上と異なるプログラミング言語で推論を実行することを含む、ONNX モデルでできるたくさんのことがあります。より詳細については、ONNX チュートリアル を訪ねることを勧めます。
以上