PyTorch Lightning 1.1 : Getting Started : PyTorch を Lightning に整理する

PyTorch Lightning 1.1: Getting Started : PyTorch を Lightning に整理する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/05/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

Getting Started : PyTorch を Lightning に整理する

1. 計算コードを移す

モデル・アーキテクチャと forward パスを貴方の LightningModule に移します。

class LitModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        return x

 

2. optimizer(s) とスケジューラを移す

optimizer と configure_optimizers() フックを移します。

class LitModel(LightningModule):

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

 

3. 訓練ループ “meat (本質)” を見つける

Lightning は訓練の殆どを貴方のために自動化します、エポックとバッチ iteration、貴方が保持する必要があるのは訓練ステップのロジックです。これは training_step() フックに入るべきです (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :

class LitModel(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

 

4. val ループ “meat” を見つける

(オプションの) 検証ループを追加するにはロジックを validation_step() フックに追加します (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :

class LitModel(LightningModule):

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = F.cross_entropy(y_hat, y)
        return val_loss

NOTE : model.eval() と torch.no_grad() は検証のために自動的に呼び出されます。

 

5. テストループ “meat” を見つける

(オプションの) テストループを追加するにはロジックを test_step() フックに追加します (フック・パラメータを使用することを確実にしてください、この場合は batch と batch_idx です) :

class LitModel(LightningModule):

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

NOTE : model.eval() と torch.no_grad() はテストのために自動的に呼び出されます。

テストループは貴方が呼び出すまで使用されません。

trainer.test()

TIP : .test() はベスト・チェックポイントを自動的にロードします。

 

6. 任意の .cuda() や to.device() 呼び出しを除去する

LightningModule は任意のハードウェア上で自動的に実行できます!

 

以上