PyTorch Lightning 1.1: notebooks : PyTorch Lightning で TPU 訓練 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/16/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : TPU training with PyTorch Lightning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
notebooks : PyTorch Lightning で TPU 訓練
- TPU 訓練 – Lightning で TPU を使用して MNIST 上でモデルを訓練します。
このノートブックでは、TPU 上でモデルを訓練します。コードの一行の変更がそのために必要なことの総てです。
TPU 訓練に関連する最も最新のドキュメントはここで見つけられます。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning
! pip install pytorch-lightning -qU
Colab TPU 互換 PyTorch/TPU wheels と依存性をインストールする
! pip install cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.7-cp36-cp36m-linux_x86_64.whl
import torch from torch import nn import torch.nn.functional as F from torch.utils.data import random_split, DataLoader # Note - you must have torchvision installed for this example from torchvision.datasets import MNIST from torchvision import transforms import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy
MNISTDataModule を定義する
下で MNISTDataModule を定義します。データモジュールについて更に docs と datamodule ノートブック で学習できます。
class MNISTDataModule(pl.LightningDataModule): def __init__(self, data_dir: str = './'): super().__init__() self.data_dir = data_dir self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) # self.dims is returned when you call dm.size() # Setting default dims here because we know them. # Could optionally be assigned dynamically in dm.setup() self.dims = (1, 28, 28) self.num_classes = 10 def prepare_data(self): # download MNIST(self.data_dir, train=True, download=True) MNIST(self.data_dir, train=False, download=True) def setup(self, stage=None): # Assign train/val datasets for use in dataloaders if stage == 'fit' or stage is None: mnist_full = MNIST(self.data_dir, train=True, transform=self.transform) self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000]) # Assign test dataset for use in dataloader(s) if stage == 'test' or stage is None: self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform) def train_dataloader(self): return DataLoader(self.mnist_train, batch_size=32) def val_dataloader(self): return DataLoader(self.mnist_val, batch_size=32) def test_dataloader(self): return DataLoader(self.mnist_test, batch_size=32)
LitModel を定義する
下で、モデル LitMNIST を定義します。
class LitModel(pl.LightningModule): def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4): super().__init__() self.save_hyperparameters() self.model = nn.Sequential( nn.Flatten(), nn.Linear(channels * width * height, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, hidden_size), nn.ReLU(), nn.Dropout(0.1), nn.Linear(hidden_size, num_classes) ) def forward(self, x): x = self.model(x) return F.log_softmax(x, dim=1) def training_step(self, batch, batch_idx): x, y = batch logits = self(x) loss = F.nll_loss(logits, y) self.log('train_loss', loss, prog_bar=False) return loss def validation_step(self, batch, batch_idx): x, y = batch logits = self(x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) acc = accuracy(preds, y) self.log('val_loss', loss, prog_bar=True) self.log('val_acc', acc, prog_bar=True) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
TPU 訓練
Lightning は単一 TPU コアあるいは 8 TPU コア上の訓練をサポートします。
Trainer パラメータ tpu_cores は幾つの TPU コア上 (1 or 8) で訓練するか / 単一 TPU コア上 [1] で訓練するかを定義します。
単一 TPU 訓練のために、リストで TPU コア ID [1-8] を単に渡します。tpu_cores=[5] を設定すると TPU コア ID 5 上で訓練します。
tpu_cores=[5] 上で TPU コア ID 5 上で訓練します。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=[5]) # Train trainer.fit(model, dm)
tpu_cores=1 で単一 TPU コア上で訓練します。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=1) # Train trainer.fit(model, dm)
GPU available: False, used: False TPU available: True, using: 1 TPU cores training on 1 TPU cores INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None | Name | Type | Params ------------------------------------- 0 | model | Sequential | 55.1 K ------------------------------------- 55.1 K Trainable params 0 Non-trainable params 55.1 K Total params
tpu_cores=8 で 8 TPU コア上で訓練します。単一 TPU コア上で訓練した後では 8 TPU コア上でそれを訓練するにはノートブックをリスタートしなければならないかもしれません。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, tpu_cores=8) # Train trainer.fit(model, dm)
GPU available: False, used: False TPU available: True, using: 8 TPU cores training on 8 TPU cores INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None | Name | Type | Params ------------------------------------- 0 | model | Sequential | 55.1 K ------------------------------------- 55.1 K Trainable params 0 Non-trainable params 55.1 K Total params
以上