PyTorch Lightning 1.1 : Getting Started : 2 ステップで Lightning

PyTorch Lightning 1.1: Getting Started : 2 ステップで Lightning (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/02/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

Getting Started : 2 ステップで Lightning

このガイドでは貴方の PyTorch コードを 2 ステップで Lightning にどのように体系化するかを示します

PyTorch Lightning によるコードの体系化は貴方のコードを以下のようにします :

  • 総ての柔軟性を保持しながら (これは依然として総て純粋な PyTorch です)、多くのボイラープレートは取り除きます。
  • 研究コードをエンジニアリングから切り離すことにより可読性を高める。
  • より容易に再現する
  • 訓練ループとトリッキーなエンジニアリングの殆どを自動化してエラーを少ない傾向にします。
  • モデルを変更することなく任意のハードウェアにスケーラブルです。

 

ステップ 0 : PyTorch Lightning をインストールする

pip を使用してインストールできます。

pip install pytorch-lightning

Or with conda (conda をどのようにインストールするかを ここ で見てください) :

conda install pytorch-lightning -c conda-forge

conda 環境を利用することもできるでしょう。

conda activate my_env
pip install pytorch-lightning

以下をインポートします :

import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

 

ステップ 1 : Lightning モジュールを定義する

class LitAutoEncoder(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28*28, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 64),
            nn.ReLU(),
            nn.Linear(64, 28*28)
        )

    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Logging to TensorBoard by default
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

 
システム VS モデル

LightningModule はモデルではなくシステムを定義します。


システムのサンプルは :

内部的には LightningModule は依然として単なる torch.nn.Module で、これは総ての研究コードを自己充足的にするために単一のファイル内にグループ分けします :

  • 訓練ループ
  • 検証ループ
  • テストループ
  • モデル or モデル (群) のシステム
  • Optimizer

利用可能なコールバック・フック で見つかる 20+ フックのいずれかを override して (backward パスのような) 訓練の任意の部分をカスタマイズできます。

class LitAutoEncoder(pl.LightningModule):

    def backward(self, loss, optimizer, optimizer_idx):
        loss.backward()

 
FORWARD vs TRAINING_STEP

Lightning では推論から訓練を分離しています。training_step は完全な訓練ステップを定義します。推論アクションを定義するために forward を利用することをユーザに奨励します。

例えば、この場合にはオートエンコーダを embedding 抽出器として動作するように定義できるでしょう :

def forward(self, x):
    embeddings = self.encoder(x)
    return embeddings

もちろん、training_step 内から forward を使用することから貴方を止めるものはありません。

def training_step(self, batch, batch_idx):
    ...
    z = self(x)

それは実際に貴方のアプリケーションへ下りていきます。けれども、両者の意図を分離することを勧めます。

  • 推論 (予測) のために forward を使用する。
  • 訓練のために training_step を使用する。

LightningModule docs により多くの詳細があります。

 

ステップ 2 : Lightning Trainer で Fit させる

最初に、貴方が望むようなやり方でデータを定義します。Lightning は train/val/test のために DataLoader を単に必要とします。

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())
train_loader = DataLoader(dataset)

次に、LightningModule と PyTorch Lightning Trainer を初期化してから、データとモデルの両者とともに fit を呼び出します。

# init model
autoencoder = LitAutoEncoder()

# most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more)
# trainer = pl.Trainer(gpus=8) (if you have GPUs)
trainer = pl.Trainer()
trainer.fit(autoencoder, train_loader)

Trainer は以下を自動化します :

TIP : optimizer を手動で管理することを好むのであれば Manual optimization モードを利用できます (ie: RL, GAN, 等…)。

 
That’s it!

これらが Lightning で知る必要がある主要な 2 つのコンセプトです。lightning の他の総ての特徴は Trainer か LightningModule のいずれかの特徴です。

 

以上