PyTorch Lightning 1.1 : notebooks : PyTorch Lightning で TPU 訓練

PyTorch Lightning 1.1: notebooks : PyTorch Lightning で TPU 訓練 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/16/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

notebooks : PyTorch Lightning で TPU 訓練

  • TPU 訓練 – Lightning で TPU を使用して MNIST 上でモデルを訓練します。

このノートブックでは、TPU 上でモデルを訓練します。コードの一行の変更がそのために必要なことの総てです。

TPU 訓練に関連する最も最新のドキュメントはここで見つけられます。

 

セットアップ

Lightning は容易にインストールできます。単純に pip install pytorch-lightning

! pip install pytorch-lightning -qU

 

Colab TPU 互換 PyTorch/TPU wheels と依存性をインストールする

! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

 

MNISTDataModule を定義する

下で MNISTDataModule を定義します。データモジュールについて更に docsdatamodule ノートブック で学習できます。

class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

 

LitModel を定義する

下で、モデル LitMNIST を定義します。

class LitModel(pl.LightningModule):
    
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        self.save_hyperparameters()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss, prog_bar=False)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

 

TPU 訓練

Lightning は単一 TPU コアあるいは 8 TPU コア上の訓練をサポートします。

Trainer パラメータ tpu_cores は幾つの TPU コア上 (1 or 8) で訓練するか / 単一 TPU コア上 [1] で訓練するかを定義します。

単一 TPU 訓練のために、リストで TPU コア ID [1-8] を単に渡します。tpu_cores=[5] を設定すると TPU コア ID 5 上で訓練します。

tpu_cores=[5] 上で TPU コア ID 5 上で訓練します。

# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5])
# Train
trainer.fit(model, dm)

tpu_cores=1 で単一 TPU コア上で訓練します。

# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1)
# Train
trainer.fit(model, dm)
GPU available: False, used: False
TPU available: True, using: 1 TPU cores
training on 1 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params

tpu_cores=8 で 8 TPU コア上で訓練します。単一 TPU コア上で訓練した後では 8 TPU コア上でそれを訓練するにはノートブックをリスタートしなければならないかもしれません。

# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8)
# Train
trainer.fit(model, dm)
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params

 

以上