PyTorch Lightning 1.1: Getting Started : PyTorch を Lightning に整理する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/05/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Getting Started : How to organize PyTorch into Lightning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
Getting Started : PyTorch を Lightning に整理する
1. 計算コードを移す
モデル・アーキテクチャと forward パスを貴方の LightningModule に移します。
class LitModel(LightningModule): def __init__(self): super().__init__() self.layer_1 = torch.nn.Linear(28 * 28, 128) self.layer_2 = torch.nn.Linear(128, 10) def forward(self, x): x = x.view(x.size(0), -1) x = self.layer_1(x) x = F.relu(x) x = self.layer_2(x) return x
2. optimizer(s) とスケジューラを移す
optimizer と configure_optimizers() フックを移します。
class LitModel(LightningModule): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer
3. 訓練ループ “meat (本質)” を見つける
Lightning は訓練の殆どを貴方のために自動化します、エポックとバッチ iteration、貴方が保持する必要があるのは訓練ステップのロジックです。これは training_step() フックに入るべきです (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :
class LitModel(LightningModule): def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return loss
4. val ループ “meat” を見つける
(オプションの) 検証ループを追加するにはロジックを validation_step() フックに追加します (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :
class LitModel(LightningModule): def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) val_loss = F.cross_entropy(y_hat, y) return val_loss
NOTE : model.eval() と torch.no_grad() は検証のために自動的に呼び出されます。
5. テストループ “meat” を見つける
(オプションの) テストループを追加するにはロジックを test_step() フックに追加します (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :
class LitModel(LightningModule): def test_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) return loss
NOTE : model.eval() と torch.no_grad() はテストのために自動的に呼び出されます。
テストループは貴方が呼び出すまで使用されません。
trainer.test()
TIP : .test() はベスト・チェックポイントを自動的にロードします。
6. 任意の .cuda() や to.device() 呼び出しを除去する
LightningModule は任意のハードウェア上で自動的に実行できます!
以上