PyTorch Lightning 1.1: notebooks : Lightning へのイントロダクション (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/12/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : Introduction to Pytorch Lightning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |

notebooks : Lightning へのイントロダクション
- MNIST Hello World – 古典的な MNIST 手書き数字データセット上で最初の Lightning モジュールを訓練します。
このノートブックでは、MNIST 手書き数字データセット 上で訓練するためのモデルを準備することにより lightning の基本を調べます。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning
! pip install pytorch-lightning --quiet
import os import torch from torch import nn from torch.nn import functional as F from torch.utils.data import DataLoader, random_split from torchvision.datasets import MNIST from torchvision import transforms import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy
最も単純なサンプル
ここに訓練ループだけを持つ (検証、テストはありません) 最も単純な最小限のサンプルがあります。
念頭に置いてください – LightningModule は PyTorch nn.Module です – それは幾つかのより役立つ特徴を持つだけです。
class MNISTModel(pl.LightningModule):
def __init__(self):
super(MNISTModel, self).__init__()
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def training_step(self, batch, batch_nb):
x, y = batch
loss = F.cross_entropy(self(x), y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.02)
Trainer を使用することで以下を自動的に得ます :
- Tensorboard ロギング
- Model チェックポインティング
- 訓練と検証ループ
- early-stopping
# Init our model mnist_model = MNISTModel() # Init DataLoader from MNIST Dataset train_ds = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) train_loader = DataLoader(train_ds, batch_size=32) # Initialize a trainer trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20) # Train the model ⚡ trainer.fit(mnist_model, train_loader)
より完全な MNIST Lightning モジュール・サンプル
That wasn’t so hard was it?
私達は手がけ始めた今、少しだけより深く潜りそして MNIST のためより完全な LightningModule を書きましょう…
今回、総てのデータセット固有のピースを LightningModule で直接焼きます (= bake in)。このようにして、スクリプトの最初でそれを実行することを望むたびに extra コードを書くことを回避できます。
以下の組込み関数が行なっていることに注意する :
- prepare_data()
- これはデータセットをダウンロードできるところです。望まれるデータセットをポイントしてデータセットがそこで見つからない場合には torchvision の MNIST データセット・クラスにダウンロードすることを依頼します。
- この関数内でどのような状態割当ても行なわないことに注意してください (i.e. self.something = …)。
- setup(stage)
- データをファイルからロードして各分割 (train, val, test) のために PyTorch tensor データセットを準備します。
- setup は ‘stage’ arg を想定します、これはロジックを ‘fit’ と ‘test’ のために分離するために使用されます。
- 総てのデータセットを一度にロードすることが嫌でないならば、stage に None が渡されたときにいつでも ‘fit’ 関連のセットアップと ‘test’ 関連のセットアップの両者を実行することを許す条件をセットアップできます (あるいはそれを一緒に無視してどのような条件も除外する (条件))。
- これは総ての GPU に渡り実行されそしてここで状態割当てを行なうことは安全であることに注意してください。
- x_dataloader()
- train_dataloader(), val_dataloader() と test_dataloader() は総て setup() で準備したそれぞれのデータセットをラップすることにより作成された PyTorch DataLoader インスタンスを返します。
class LitMNIST(pl.LightningModule):
def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):
super().__init__()
# We hardcode dataset specific stuff here.
self.data_dir = data_dir
self.num_classes = 10
self.dims = (1, 28, 28)
channels, width, height = self.dims
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
self.hidden_size = hidden_size
self.learning_rate = learning_rate
# Build model
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, self.num_classes)
)
def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
####################
# DATA RELATED HOOKS
####################
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
model = LitMNIST() trainer = pl.Trainer(gpus=1, max_epochs=3, progress_bar_refresh_rate=20) trainer.fit(model)
テスト
モデルをテストするため、trainer.test(model) を呼び出します。
あるいは、モデルを丁度訓練したのであれば、trainer.test() を単に呼び出すこともできて Lightning は (val_loss で条件付けられた) 最善のセーブされたチェックポイントを使用して自動的にテストします。
trainer.test()
ボーナス Tip
訓練を続けたいだけの回数 trainer.fit(model) を呼び出し続けることができます。
trainer.fit(model)
Colab では、Lightning が作成したログを見るために TensorBoard マジック関数を利用できます!
# Start tensorboard. %load_ext tensorboard %tensorboard --logdir lightning_logs/

以上