PyTorch Lightning 1.1: notebooks : PyTorch Lightning で TPU 訓練 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/16/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : TPU training with PyTorch Lightning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |

notebooks : PyTorch Lightning で TPU 訓練
- TPU 訓練 – Lightning で TPU を使用して MNIST 上でモデルを訓練します。
このノートブックでは、TPU 上でモデルを訓練します。コードの一行の変更がそのために必要なことの総てです。
TPU 訓練に関連する最も最新のドキュメントはここで見つけられます。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning
! pip install pytorch-lightning -qU
Colab TPU 互換 PyTorch/TPU wheels と依存性をインストールする
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
import torch from torch import nn import torch.nn.functional as F from torch.utils.data import random_split, DataLoader # Note - you must have torchvision installed for this example from torchvision.datasets import MNIST from torchvision import transforms import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy
MNISTDataModule を定義する
下で MNISTDataModule を定義します。データモジュールについて更に docs と datamodule ノートブック で学習できます。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)
self.num_classes = 10
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
LitModel を定義する
下で、モデル LitMNIST を定義します。
class LitModel(pl.LightningModule):
def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
super().__init__()
self.save_hyperparameters()
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, num_classes)
)
def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
self.log('train_loss', loss, prog_bar=False)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
TPU 訓練
Lightning は単一 TPU コアあるいは 8 TPU コア上の訓練をサポートします。
Trainer パラメータ tpu_cores は幾つの TPU コア上 (1 or 8) で訓練するか / 単一 TPU コア上 [1] で訓練するかを定義します。
単一 TPU 訓練のために、リストで TPU コア ID [1-8] を単に渡します。tpu_cores=[5] を設定すると TPU コア ID 5 上で訓練します。
tpu_cores=[5] 上で TPU コア ID 5 上で訓練します。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5]) # Train trainer.fit(model, dm)
tpu_cores=1 で単一 TPU コア上で訓練します。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1) # Train trainer.fit(model, dm)
GPU available: False, used: False TPU available: True, using: 1 TPU cores training on 1 TPU cores INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None | Name | Type | Params ------------------------------------- 0 | model | Sequential | 55.1 K ------------------------------------- 55.1 K Trainable params 0 Non-trainable params 55.1 K Total params
tpu_cores=8 で 8 TPU コア上で訓練します。単一 TPU コア上で訓練した後では 8 TPU コア上でそれを訓練するにはノートブックをリスタートしなければならないかもしれません。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8) # Train trainer.fit(model, dm)
GPU available: False, used: False TPU available: True, using: 8 TPU cores training on 8 TPU cores INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None | Name | Type | Params ------------------------------------- 0 | model | Sequential | 55.1 K ------------------------------------- 55.1 K Trainable params 0 Non-trainable params 55.1 K Total params

以上