PyTorch Lightning 1.1 : Tutorials : ウォークスルー (MNIST からオートエンコーダ)

PyTorch Lightning 1.1: Tutorials : ウォークスルー (MNIST からオートエンコーダ) (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/06/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

Tutorials : ウォークスルー (MNIST からオートエンコーダ)

このガイドは PyTorch Lightning の中心的なピースを貴方にガイドします。

以下を達成します :

  • MNIST 分類器を実装します。
  • AutoEncoder を実装するために継承を利用します。

NOTE : 任意の DL/ML PyTorch プロジェクトは Lightning 構造に適合します。ここでは例示するために単に 3 タイプの研究にフォーカスします。

 

MNIST からオートエンコーダ

Lightning をインストールする

Lightning のインストールは自明です。 conda 環境の使用を勧めます :

conda activate my_env
pip install pytorch-lightning

あるいは conda 環境なしに、pip を使用します。

pip install pytorch-lightning

Or conda.

conda install pytorch-lightning -c conda-forge

 

研究

モデル

LightningModule は総ての中心的な研究構成要素を保持します :

  • モデル
  • optimizer
  • train/ val/ test ステップ。

最初はモデルから始めましょう。この場合、3-層ニューラルネットワークを設計します。

import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule

class LitMNIST(LightningModule):

  def __init__(self):
    super().__init__()

    # mnist images are (1, 28, 28) (channels, width, height)
    self.layer_1 = torch.nn.Linear(28 * 28, 128)
    self.layer_2 = torch.nn.Linear(128, 256)
    self.layer_3 = torch.nn.Linear(256, 10)

  def forward(self, x):
    batch_size, channels, width, height = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)

    x = F.log_softmax(x, dim=1)
    return x

これは torch.nn.Module の代わりに LightningModule であることに気付くでしょう。LightningModule はそれが追加された機能を持つことを除いて純粋な PyTorch Model に同値です。けれども、貴方はそれを PyToch Module を利用するのと 正確に 同じに利用できます。

net = LitMNIST()
x = torch.randn(1, 1, 28, 28)
out = net(x)
tensor([[-2.2484, -2.2880, -2.3452, -2.2685, -2.2250, -2.5208, -2.2343, -2.3637,
         -2.1987, -2.3739]], grad_fn=<LogSoftmaxBackward>)
out.size()
torch.Size([1, 10])

今は training_step を追加します、これは総ての訓練ループロジックを持ちます。

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

 

データ

Lightning は純粋なデータローダ上で動作します。ここに MNIST をロードするための PyTorch コードがあります。

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
import os
from torchvision import datasets, transforms

# transforms
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),
                              transforms.Normalize((0.1307,), (0.3081,))])

# data
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=64)

DataLoader を 3 つの方法で利用できます :

 

1. DataLoaders を .fit() に渡す

データローダを .fit() 関数に渡します。

model = LitMNIST()
trainer = Trainer()
trainer.fit(model, mnist_train)

 

2. LightningModule DataLoader

高速な研究プロトタイピングのためには、モデルをデータローダとリンクすることが容易かもしれません。

class LitMNIST(pl.LightningModule):

    def train_dataloader(self):
        # transforms
        # prepare transforms standard to MNIST
        transform=transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))])
        # data
        mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
        return DataLoader(mnist_train, batch_size=64)

    def val_dataloader(self):
        transforms = ...
        mnist_val = ...
        return DataLoader(mnist_val, batch_size=64)

    def test_dataloader(self):
        transforms = ...
        mnist_test = ...
        return DataLoader(mnist_test, batch_size=64)

DataLoader は既にモデル内にあります、.fit() 上で指定する必要はありません。

model = LitMNIST()
trainer = Trainer()
trainer.fit(model)

 

3. DataModule (推奨)

自由に動く (= free-floating) データローダ、分割、ダウンロード手順を定義すると、そのようなものは乱雑になり得ます。この場合、データセットの完全な定義を DataModule にグループ分けするのがより良いです、これは以下を含みます :

  • ダウンロード手順
  • (前) 処理手順
  • 分割手順
  • 訓練データローダ
  • Val データローダ
  • テストデータローダ
class MyDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()
        self.train_dims = None
        self.vocab_size = 0

    def prepare_data(self):
        # called only on 1 GPU
        download_dataset()
        tokenize()
        build_vocab()

    def setup(self):
        # called on every GPU
        vocab = load_vocab()
        self.vocab_size = len(vocab)

        self.train, self.val, self.test = load_datasets()
        self.train_dims = self.train.next_batch.size()

    def train_dataloader(self):
        transforms = ...
        return DataLoader(self.train, batch_size=64)

    def val_dataloader(self):
        transforms = ...
        return DataLoader(self.val, batch_size=64)

    def test_dataloader(self):
        transforms = ...
        return DataLoader(self.test, batch_size=64)

DataModule の使用は完全なデータセット定義の容易な共有を可能にします。

# use an MNIST dataset
mnist_dm = MNISTDatamodule()
model = LitModel(num_classes=mnist_dm.num_classes)
trainer.fit(model, mnist_dm)

# or other datasets with the same model
imagenet_dm = ImagenetDatamodule()
model = LitModel(num_classes=imagenet_dm.num_classes)
trainer.fit(model, imagenet_dm)

NOTE : prepare_data() は分散訓練で一つの GPU 上だけで (自動的に) 呼び出されます。

NOTE : setup() は総ての GPU 上で (自動的に) 呼び出されます。

 

データにより定義されるモデル

モデルがデータについて知る必要があるとき、データをモデルに渡す前に渡すのがベストです。

# init dm AND call the processing manually
dm = ImagenetDataModule()
dm.prepare_data()
dm.setup()

model = LitModel(out_features=dm.num_classes, img_width=dm.img_width, img_height=dm.img_height)
trainer.fit(model, dm)
  1. データセットをダウンロードして処理するために prepare_data() を使用する。
  2. 分割を行ない、モデル内部をビルドするために setup() を使用する。

DataModule を使用することへの代替は次のようにモデルモジュールの初期化を LightningModule の setup メソッドまで遅延することです :

class LitMNIST(LightningModule):

    def __init__(self):
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

    def setup(self, step):
        # step is either 'fit' or 'test' 90% of the time not relevant
        data = load_data()
        num_classes = data.classes
        self.l1 = nn.Linear(..., num_classes)

 

Optimizer

次にシステムを訓練するために何の optimizer を使用するか選択します。PyTorch ではそれは次のように行ないます :

from torch.optim import Adam
optimizer = Adam(LitMNIST().parameters(), lr=1e-3)

Lightning では同じことを行ないますがそれを configure_optimizers() メソッド下に体系化します。

class LitMNIST(LightningModule):

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

NOTE : LightningModule 自身がパラメータを持ちますので、self.parameters() を渡します。

けれども、複数の optimizer に適合するパラメータを使用させる場合には :

class LitMNIST(LightningModule):

    def configure_optimizers(self):
        return Adam(self.generator(), lr=1e-3), Adam(self.discriminator(), lr=1e-3)

 

訓練ステップ

訓練ステップは訓練ループの内部で発生するものです。

for epoch in epochs:
    for batch in data:
        # TRAINING STEP
        # ....
        # TRAINING STEP
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

MNIST の場合、以下を行ないます :

for epoch in epochs:
    for batch in data:
        # ------ TRAINING STEP START ------
        x, y = batch
        logits = model(x)
        loss = F.nll_loss(logits, y)
        # ------ TRAINING STEP END ------

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

Lightning では、訓練ステップにある総ては LightningModule の training_step() 関数下に体系化されます。

class LitMNIST(LightningModule):

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

再度、これは同じ PyTorch コードです、それが LightningModule により体系化されたことを除いてです。このコードは制限されません、これはそれは完全な seq-2-seq, RL ループ, GAN 等… のように複雑であり得ることを意味します。

 

エンジニアリング

訓練

ここまで純粋な PyTorch で 4 つの主要な構成要素を定義しましたが LightningModule でコードを体系化しました。

  1. モデル。
  2. 訓練データ。
  3. Optimizer.
  4. 訓練ループで発生するもの。

明確にするため、完全な LightningModule が今ではこのように見えることを思い出しましょう。

class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

再度、これは同じ PyTorch コードです、それが LightningModule により体系化されたことを除いてです。

 

ロギング

TensorBoard, お気に入りのロガー、and/or 進捗バーにロギングするためには、log() メソッドを使用します、これは LightningModule の任意のメソッドから呼び出せます。

def training_step(self, batch, batch_idx):
    self.log('my_metric', x)

log() メソッドは幾つかオプションを持ちます :

  • on_step (訓練のそのステップでメトリックをロギングする)
  • on_epoch (自動的に累算してエポックの最後にロギングします)
  • prog_bar (進捗バーにロギングします)
  • ロガー (Tensorboard のようなロガーにロギングします)

log がどこから呼び出されるかに依拠して、Lightning は正しいモードを自動決定します。しかしもちろんフラグを手動で設定してデフォルト動作を override できます。

NOTE : on_epoch=True の設定は full 訓練エポックに渡りログ値を累算します。

def training_step(self, batch, batch_idx):
    self.log('my_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

ロガーの任意のメソッドを直接利用することもできます :

def training_step(self, batch, batch_idx):
    tensorboard = self.logger.experiment
    tensorboard.any_summary_writer_method_you_want())

ひとたび訓練が始まれば、お気に入りのロガーを使用するか TensorBoard ログをブートアップしてログを見ることができます :

tensorboard --logdir ./lightning_logs

それは自動的に tensorboard ログを生成します (or 選択したロガーにより)。

しかしサポートされる 任意の数の他のロガー を利用することもできます。

 

CPU 上で訓練する
from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer()
trainer.fit(model, train_loader)

次の重み要約と進捗バーを見るはずです :

 

GPU 上で訓練する

しかし美しさは trainer フラグで行えるマジックです。例えば、GPU 上でこのモデルを実行するには :

model = LitMNIST()
trainer = Trainer(gpus=1)
trainer.fit(model, train_loader)

 

マルチ GPU 上で訓練する

あるいはマルチ GPU 上でも訓練できます。

model = LitMNIST()
trainer = Trainer(gpus=8)
trainer.fit(model, train_loader)

あるいはマルチノード :

# (32 GPUs)
model = LitMNIST()
trainer = Trainer(gpus=8, num_nodes=4, accelerator='ddp')
trainer.fit(model, train_loader)

より詳細については 分散コンピューティングガイド を参照してください。

 

TPU 上で訓練する

PyTorch を TPU 上で利用できることを知っていますか?行なうことは非常に困難ですが、これを out of the box に動作させるため私達は xla チームと素晴らしいライブラリを利用するために作業しました!

Colab 上で訓練しましょう (ここで完全なデモが利用可能です)

最初に、ランタイムを TPU に変更します (そして lightning を再インストールします)。

次に、必要な xla ライブラリをインストールします (TPU 上の PyTorch のためのサポートを追加します) :

!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py

!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

分散訓練では (マルチ GPU とマルチ TPU コア) 各 GPU や TPU コアはこのプログラムのコピーを実行します。これは、どのようなケアもなければデータセットを N 回ダウンロードすることを意味し、これは総ての種類の問題を引き起こします。

この問題を解くには、貴方のダウンロードコードが DataModule の prepare_data メソッドにあることを確実にしてください。このメソッドでは (総ての GPU 上の代わりに) 一度実行する必要がある総ての準備を行ないます。

prepare_data は 2 つの方法で呼び出されます、ノード毎に 1 度かルートノード上でだけ Trainer(prepare_data_per_node=False) です。

class MNISTDataModule(LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # download only
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())

    def setup(self, stage):
        # transform
        transform=transforms.Compose([transforms.ToTensor()])
        mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)
        mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size)

prepare_data メソッドは一度だけ行なわれる必要がある任意のデータ処理を行なうに良い場所でもあります (ie: ダウンロード or トークン化等 …)。

NOTE : Lightning は分散訓練のために正しい DistributedSampler を挿入します。貴方自身で追加する必要はありません!

今では他のどのようなことをすることもなく TPU 上で LightningModule を訓練することができます!

dm = MNISTDataModule()
model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, dm)

今では TPU コアがブートアップするのを見るでしょう。

 

ハイパーパラメータ

Lightning はコマンドライン ArgumentParser とシームレスに相互作用するユティリティを持ち貴方の選択するハイパーパラメータ最適化フレームワークと上手くプレイできます。

 

ArgumentParser

Lightning は組込み Python ArgumentParser の多くの機能を増強するように設計されています。

from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--layer_1_dim', type=int, default=128)
args = parser.parse_args()

これは貴方のプログラムをこのように呼び出すことを可能にします :

python trainer.py --layer_1_dim 64

 

Argparser ベストプラクティス

引数を 3 つのセクションに層化することがベストプラクティスです。

  1. Trainer args (gpus, num_nodes, etc…)
  2. モデル固有引数 (layer_dim, num_layers, learning_rate 等…)
  3. プログラム引数 (data_path, cluster_email, 等…)

これを以下のように行えます。最初に、LightningModule で、そのモジュールに固有の引数を定義します。データ分割やデータパスもまたモジュール固有であるかもしれないことを忘れないでください (i.e.: プロジェクトが CIFAR-10 上 Imagenet と他の上で訓練するモデルを持つ場合)。

class LitModel(LightningModule):

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--encoder_layers', type=int, default=12)
        parser.add_argument('--data_path', type=str, default='/some/path')
        return parser

今は主要な trainer ファイルで、Trainer args, プログラム args を追加し、そしてモデル args を追加します。

# ----------------
# trainer_main.py
# ----------------
from argparse import ArgumentParser
parser = ArgumentParser()

# add PROGRAM level args
parser.add_argument('--conda_env', type=str, default='some_name')
parser.add_argument('--notification_email', type=str, default='will@email.com')

# add model specific args
parser = LitModel.add_model_specific_args(parser)

# add all the available trainer options to argparse
# ie: now --gpus --num_nodes ... --fast_dev_run all work in the cli
parser = Trainer.add_argparse_args(parser)

args = parser.parse_args()

今ではこのようにプログラムを実行するために呼び出すことができます :

python trainer_main.py --gpus 2 --num_nodes 2 --conda_env 'my_env' --encoder_layers 12

最後にこのように訓練を開始することを確実にしてください :

# init the trainer like this
trainer = Trainer.from_argparse_args(args, early_stopping_callback=...)

# NOT like this
trainer = Trainer(gpus=hparams.gpus, ...)

# init the model with Namespace directly
model = LitModel(args)

# or init the model with all the key-value pairs
dict_args = vars(args)
model = LitModel(**dict_args)

 

LightningModule ハイパーパラメータ

しばしば何回もモデルの多くのバージョンを訓練します。そのモデルを共有するかそのモデルがどのように訓練されたかを知るために非常に有用なポイントで数カ月後にそれに戻るかもしれません (i.e.: どのような学習率、ニューラルネットワーク等…)。

Lightning はその情報をチェックポイントと yaml ファイルにセーブする幾つかの方法を持ちます。ここでの目標は可読性と再現性を改良することです。

  1. 最初の方法は lightning に __init__ 内の任意のものの値をチェックポイントにセーブすることです。これはまたそれらの値を self.hparams を通して利用可能にもします。
    class LitMNIST(LightningModule):
    
        def __init__(self, layer_1_dim=128, learning_rate=1e-2, **kwargs):
            super().__init__()
            # call this to save (layer_1_dim=128, learning_rate=1e-4) to the checkpoint
            self.save_hyperparameters()
    
            # equivalent
            self.save_hyperparameters('layer_1_dim', 'learning_rate')
    
            # Now possible to access layer_1_dim from hparams
            self.hparams.layer_1_dim
    
  2. 時に init はセーブすることを望まないかもしれないオブジェクトや他のパラメータを持つかもしれません。その場合、幾つかだけを選択します :
    class LitMNIST(LightningModule):
    
        def __init__(self, loss_fx, generator_network, layer_1_dim=128 **kwargs):
            super().__init__()
            self.layer_1_dim = layer_1_dim
            self.loss_fx = loss_fx
    
            # call this to save (layer_1_dim=128) to the checkpoint
            self.save_hyperparameters('layer_1_dim')
    
    # to load specify the other args
    model = LitMNIST.load_from_checkpoint(PATH, loss_fx=torch.nn.SomeOtherLoss, generator_network=MyGenerator())
    
  3. self.hparams に割当てます。self.hparams に割当てられた任意のものもまた自動的にセーブされます。
    # using a argparse.Namespace
    class LitMNIST(LightningModule):
        def __init__(self, hparams, *args, **kwargs):
            super().__init__()
            self.hparams = hparams
            self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
            self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
            self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
        def train_dataloader(self):
            return DataLoader(mnist_train, batch_size=self.hparams.batch_size)
    

    WARNING: v1.1.0 から deprecated です。ハイパーパラメータを LightningModule に割当てるこの方法は v1.3.0 からは最早サポートされません。代わりに上からの self.save_hyperparameters() を使用してください。

  4. dict or Namespace のような full オブジェクトをチェックポイントにセーブすることもできます。
    # using a argparse.Namespace
    class LitMNIST(LightningModule):
    
        def __init__(self, conf, *args, **kwargs):
            super().__init__()
            self.save_hyperparameters(conf)
    
            self.layer_1 = torch.nn.Linear(28 * 28, self.hparams.layer_1_dim)
            self.layer_2 = torch.nn.Linear(self.hparams.layer_1_dim, self.hparams.layer_2_dim)
            self.layer_3 = torch.nn.Linear(self.hparams.layer_2_dim, 10)
    
    conf = OmegaConf.create(...)
    model = LitMNIST(conf)
    
    # Now possible to access any stored variables from hparams
    model.hparams.anything
    

 

Trainer args

まとめると、総ての可能な trainer フラグを argparse に追加して Trainer をこのように初期化します。

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
hparams = parser.parse_args()

trainer = Trainer.from_argparse_args(hparams)

# or if you need to pass in callbacks
trainer = Trainer.from_argparse_args(hparams, checkpoint_callback=..., callbacks=[...])

 

複数の Lightning モジュール

しばしば複数の Lightning モジュールを持ちます、そこでは各々の一つは異なる引数を持ちます。main.py ファイルを汚す代わりに、LightningModule は各々の一つのための引数を貴方に定義させます。

class LitMNIST(LightningModule):

    def __init__(self, layer_1_dim, **kwargs):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, layer_1_dim)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--layer_1_dim', type=int, default=128)
        return parser
class GoodGAN(LightningModule):

    def __init__(self, encoder_layers, **kwargs):
        super().__init__()
        self.encoder = Encoder(layers=encoder_layers)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--encoder_layers', type=int, default=12)
        return parser

今は各モデルが必要な引数を main.py に注入することを許容できます。

def main(args):
    dict_args = vars(args)

    # pick model
    if args.model_name == 'gan':
        model = GoodGAN(**dict_args)
    elif args.model_name == 'mnist':
        model = LitMNIST(**dict_args)

    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)

    # figure out which model to use
    parser.add_argument('--model_name', type=str, default='gan', help='gan or mnist')

    # THIS LINE IS KEY TO PULL THE MODEL NAME
    temp_args, _ = parser.parse_known_args()

    # let the model add what it wants
    if temp_args.model_name == 'gan':
        parser = GoodGAN.add_model_specific_args(parser)
    elif temp_args.model_name == 'mnist':
        parser = LitMNIST.add_model_specific_args(parser)

    args = parser.parse_args()

    # train
    main(args)

そして今ではコマンドライン・インターフェイスを使用して MNIST や GAN を訓練できます!

$ python main.py --model_name gan --encoder_layers 24
$ python main.py --model_name mnist --layer_1_dim 128

 

検証する

殆どの場合、データの検証分割上の性能が最小に達するときモデルの訓練を止めます。

ちょうど training_step のように、ケアするメトリクスが何であれ確認し、サンプルを生成したり、多くをログに追加するために validation_step を定義できます。

def validation_step(self, batch, batch_idx):
    loss = MSE_loss(...)
    self.log('val_loss', loss)

今は検証ループとともに訓練することもできます。

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model, train_loader, val_loader)

ロギングされる単語 Validation sanity check (検証サニティ・チェック) に気付いたかもしれません。これは Lightning は訓練を始める前に 2 バッチの検証を実行するためです。これは検証ループにバグがないか確かめる一種のユニットテストで、それを見つけるために full エポックの間潜在的に待つ必要はありません。

NOTE : Lightning は勾配を無効にし、モデルを eval モードに置き、そして検証のために必要な総てを行ないます。

 

内部での Val ループ

内部的には、Lightning は以下を行ないます :

model = Model()
model.train()
torch.set_grad_enabled(True)

for epoch in epochs:
    for batch in data:
        # ...
        # train

    # validate
    model.eval()
    torch.set_grad_enabled(False)

    outputs = []
    for batch in val_data:
        x, y = batch                        # validation_step
        y_hat = model(x)                    # validation_step
        loss = loss(y_hat, x)               # validation_step
        outputs.append({'val_loss': loss})  # validation_step

    total_loss = outputs.mean()             # validation_epoch_end

 

オプションのメソッド

依然として更なる細かい制御さえ必要な場合には、ループのために他のオプションのメソッドを定義します。

def validation_step(self, batch, batch_idx):
    preds = ...
    return preds

def validation_epoch_end(self, val_step_outputs):
    for pred in val_step_outputs:
        # do something with all the predictions from each validation_step

 

テストする

ひとたび研究が成されてモデルを公開するか配備しようとするとき、通常はそれが「現実世界」でどのように一般化されるかを見出すことを望みます。このため、テストのためにデータの取り置いた分割を利用します。

ちょうど検証ループのように、テストループを次のように定義します :

class LitMNIST(LightningModule):
    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('test_loss', loss)

けれども、テストセットが不注意に使用されていないことを確実にするため、Lightning はテストを実行するために分離した API を持ちます。ひとたびモデルを訓練したのであれば単純に .test() を呼び出します。

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(tpu_cores=8)
trainer.fit(model)

# run test set
result = trainer.test()
print(result)
--------------------------------------------------------------
TEST RESULTS
{'test_loss': 1.1703}
--------------------------------------------------------------

セーブされた lightning モデルからテストを実行することもできます。

model = LitMNIST.load_from_checkpoint(PATH)
trainer = Trainer(tpu_cores=8)
trainer.test(model)

NOTE : Lightning は勾配を無効にし、モデルを eval モードに置き、そしてテストのために必要な総てを行ないます。

WARNING : .test() は TPU 上ではステーブルではありません。私達はマルチ処理の課題を回避するために作業しています。

 

予測する

再度、LightningModule は PyTorch モジュールを正確に同じです。これは予測のためにそれをロードして利用できることを意味します。

model = LitMNIST.load_from_checkpoint(PATH)
x = torch.randn(1, 1, 28, 28)
out = model(x)

表面上は、forward と training_step は同様に見えます。一般に、 私達はモデルに行なうことを望むものは forward で発生するものであることを確かにすることを望みます。一方で training_step はそれの内部から forward を呼び出しがちです。

class MNISTClassifier(LightningModule):

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss
model = MNISTClassifier()
x = mnist_image()
logits = model(x)

この場合、ロジットを予測するようにこの LightningModel を設定しました。しかしそれを特徴マップを予測させるようにもできたでしょう :

class MNISTRepresentator(LightningModule):

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x1 = F.relu(x)
        x = self.layer_2(x1)
        x2 = F.relu(x)
        x3 = self.layer_3(x2)
        return [x, x1, x2, x3]

    def training_step(self, batch, batch_idx):
        x, y = batch
        out, l1_feats, l2_feats, l3_feats = self(x)
        logits = F.log_softmax(out, dim=1)
        ce_loss = F.nll_loss(logits, y)
        loss = perceptual_loss(l1_feats, l2_feats, l3_feats) + ce_loss
        return loss
model = MNISTRepresentator.load_from_checkpoint(PATH)
x = mnist_image()
feature_maps = model(x)

あるいは生成を行なうために使用するモデルをおそらくは持ちます。

class LitMNISTDreamer(LightningModule):

    def forward(self, z):
        imgs = self.decoder(z)
        return imgs

    def training_step(self, batch, batch_idx):
        x, y = batch
        representation = self.encoder(x)
        imgs = self(representation)

        loss = perceptual_loss(imgs, x)
        return loss
model = LitMNISTDreamer.load_from_checkpoint(PATH)
z = sample_noise()
generated_imgs = model(z)

forward vs training_step で進むものをどのように分割するかはこのモデルを予測のためにどのように利用することを望むかに依拠します。

 

必須ではないもの (= nonessentials)

拡張性

lightning は総てを非常に単純にしますが、どのような柔軟性や制御も犠牲にしません。lightning は訓練状態を管理する複数の方法を提供します。

 

訓練 override

訓練、検証とテストループの任意の部分は変更できます。例えば、貴方自身の backward パスを望んだ場合、次のデフォルト実装を :

def backward(self, use_amp, loss, optimizer):
    loss.backward()

貴方自身のもので override するでしょう :

class LitMNIST(LightningModule):

    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        # do a custom way of backward
        loss.backward(retain_graph=True)

訓練の総ての単一パートはこのように configurable です。完全なリストについては LightningModule を見てください。

 

コールバック

任意の機能を追加するもう一つの方法は配慮するかもしれないフックのためのカスタム・コールバックを追加することです :

from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('Trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

そしてコールバックを trainer 内に渡します :

trainer = Trainer(callbacks=[MyPrintingCallback()])

TIP : コールバック の 12+ フックの完全なリストを見てください。

 

Child モジュール

研究プロジェクトは同じデータセットに異なるアプローチをテストする傾向にあります。これは継承により Lightning で行なうことは非常に容易です。

例えば、今 MNIST 画像のための特徴抽出器としてオートエンコーダを訓練することを望むと想像してください。私達は既に総てのデータローディングを定義している LitMNIST-モジュールからオートエンコーダを拡張しています。オートエンコーダ・モデルで変更されるものは init, forward, training, validation そして test ステップだけです。

class Encoder(torch.nn.Module):
    pass

class Decoder(torch.nn.Module):
    pass

class AutoEncoder(LitMNIST):

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.metric = MSE()

    def forward(self, x):
        return self.encoder(x)

    def training_step(self, batch, batch_idx):
        x, _ = batch

        representation = self.encoder(x)
        x_hat = self.decoder(representation)

        loss = self.metric(x, x_hat)
        return loss

    def validation_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, 'val')

    def test_step(self, batch, batch_idx):
        self._shared_eval(batch, batch_idx, 'test')

    def _shared_eval(self, batch, batch_idx, prefix):
        x, _ = batch
        representation = self.encoder(x)
        x_hat = self.decoder(representation)

        loss = self.metric(x, x_hat)
        self.log(f'{prefix}_loss', loss)

そして同じ trainer を使用してこれを訓練することができます。

autoencoder = AutoEncoder()
trainer = Trainer()
trainer.fit(autoencoder)

そして forward メソッドは LightningModule の実際の使用を定義するべきであることを忘れないでください。この場合、AutoEncoder を画像表現を抽出するために使用することを望みます。

some_images = torch.Tensor(32, 1, 28, 28)
representations = autoencoder(some_images)

 

転移学習

事前訓練モデルを使用する

時に LightningModule を事前訓練モデルとして使用することを望みます。LightningModule は単に torch.nn.Module ですので、これは構いません!

NOTE : LightningModule はより多くの機能を持ちますが 正確に torch.nn.Module であることを思い出してください。

別のモデルで AutoEncoder を特徴抽出機として使用しましょう。

class Encoder(torch.nn.Module):
    ...

class AutoEncoder(LightningModule):
    def __init__(self):
        self.encoder = Encoder()
        self.decoder = Decoder()

class CIFAR10Classifier(LightningModule):
    def __init__(self):
        # init the pretrained LightningModule
        self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
        self.feature_extractor.freeze()

        # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
        self.classifier = nn.Linear(100, 10)

    def forward(self, x):
        representations = self.feature_extractor(x)
        x = self.classifier(representations)
        ...

転移学習のために事前訓練 Autoencoder (LightningModule) を使用しました!

 

サンプル: Imagenet (コンピュータビジョン)
import torchvision.models as models

class ImagenetTransferLearning(LightningModule):
    def __init__(self):
        super().__init__()

        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        self.feature_extractor = torch.nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        self.classifier = nn.Linear(num_filters, num_target_classes)

    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        x = self.classifier(representations)
        ...

再調整

model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)

そして関心のあるデータを予測するためにそれを使用します。

model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()

x = some_images_from_cifar10()
predictions = model(x)

imagenet 上の事前訓練モデルを使用して、CIFAR-10 上で予測するために CIFAR-10 上で再調整しました。非-学会では貴方が持つ tiny データセット上で再調整してそのデータセット上で予測するでしょう。

 

サンプル: BERT (NLP)

lightning は転移学習のために何が使用されたかには完全に不可知です、それが torch.nn.Module サブクラスである限りは。

ここに Huggingface transformers を使用するモデルがあります。

class BertMNLIFinetuner(LightningModule):

    def __init__(self):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
        self.W = nn.Linear(bert.config.hidden_size, 3)
        self.num_classes = 3


    def forward(self, input_ids, attention_mask, token_type_ids):

        h, _, attn = self.bert(input_ids=input_ids,
                         attention_mask=attention_mask,
                         token_type_ids=token_type_ids)

        h_cls = h[:, 0]
        logits = self.W(h_cls)
        return logits, attn
 

以上