PyTorch デザインノート : シリアライゼーション・セマンティクス

PyTorch デザインノート : シリアライゼーション・セマンティクス (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/28/2018 (0.4.0)

* 本ページは、PyTorch Doc Notes の – Serialization semantics を動作確認・翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、適宜、追加改変している場合もあります。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

ベストプラクティス

モデルをセーブするための推奨アプローチ

モデルをシリアライズしてリストアするためには 2 つの主要なアプローチがあります。

1 番目(推奨) はモデル・パラメータだけをセーブしてロードします :

torch.save(the_model.state_dict(), PATH)

それから後で :

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

2 番目はモデル全体をセーブしてロードします :

torch.save(the_model, PATH)

それから後で :

the_model = torch.load(PATH)

けれどもこの場合、シリアライズされたデータは特定のクラスと使用された正確なディレクトリ構造に結合していますので、他のプロジェクトで使用されたり、何某かの重大なリファクタリングの後ではそれは様々な方法で壊れるかもしれません。

 

以上