PyTorch Lightning 1.1: notebooks : CIFAR10 ~94% ベースライン・チュートリアル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/18/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : PyTorch Lightning CIFAR10 ~94% Baseline Tutorial
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |
![]()
notebooks : CIFAR10 ~94% ベースライン・チュートリアル
- 94% Baseline CIFAR10 – Lightning で Resnet を使用して CIFAR 10 上 ~94% 精度の素早いベースラインを達成します。
Resnet を Cifar10 上で 94% 精度にまで訓練します!
主要な重要点は :
- pl.LightningModule の configure_optimizers メソッドの異なる学習率スケジュールと頻度を持つ実験。
- Lightning による直接的な変更を伴う既存の Resnet アーキテクチャを使用します。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning です。事前のデータモジュールとモデルについては bolts も確認してください。
! pip install pytorch-lightning pytorch-lightning-bolts -qU
# Run this if you intend to use TPUs # !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py # !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
import torch import torch.nn as nn import torch.nn.functional as F from torch.optim.lr_scheduler import OneCycleLR from torch.optim.swa_utils import AveragedModel, update_bn import torchvision import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.metrics.functional import accuracy from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);
CIFAR10 データモジュール
bolts から既存のデータモジュールをインポートして訓練とテスト変換を変更します。
batch_size = 32
train_transforms = torchvision.transforms.Compose([
torchvision.transforms.RandomCrop(32, padding=4),
torchvision.transforms.RandomHorizontalFlip(),
torchvision.transforms.ToTensor(),
cifar10_normalization(),
])
test_transforms = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
cifar10_normalization(),
])
cifar10_dm = CIFAR10DataModule(
batch_size=batch_size,
train_transforms=train_transforms,
test_transforms=test_transforms,
val_transforms=test_transforms,
)
Resnet
TorchVision からの事前の Resnet アーキテクチャを変更します。事前のアーキテクチャは入力として ImageNet 画像 (224×224) に基づいています。そのためそれを CIFAR10 images (32×32) のために修正する必要があります。
def create_model():
model = torchvision.models.resnet18(pretrained=False, num_classes=10)
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
return model
Lightning モジュール
カスタム学習率スケジューラを使用するためには configure_optimizers メソッドを確認してください。SGD を伴う OneCycleLR は 20-30 エポックでおよそ 92-93% 精度にそして 40-50 エポックで 93-94 % 精度に到達させます。https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate から様々な LR スケジュールで自由に実験してください。
class LitResnet(pl.LightningModule):
def __init__(self, lr=0.05):
super().__init__()
self.save_hyperparameters()
self.model = create_model()
def forward(self, x):
out = self.model(x)
return F.log_softmax(out, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = F.log_softmax(self.model(x), dim=1)
loss = F.nll_loss(logits, y)
self.log('train_loss', loss)
return loss
def evaluate(self, batch, stage=None):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
if stage:
self.log(f'{stage}_loss', loss, prog_bar=True)
self.log(f'{stage}_acc', acc, prog_bar=True)
def validation_step(self, batch, batch_idx):
self.evaluate(batch, 'val')
def test_step(self, batch, batch_idx):
self.evaluate(batch, 'test')
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
steps_per_epoch = 45000 // batch_size
scheduler_dict = {
'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),
'interval': 'step',
}
return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}
model = LitResnet(lr=0.05)
model.datamodule = cifar10_dm
trainer = pl.Trainer(
progress_bar_refresh_rate=20,
max_epochs=40,
gpus=1,
logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),
callbacks=[LearningRateMonitor(logging_interval='step')],
)
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);
ボーナス : パフォーマンス上のブーストを得るために確率的重み付け平均法を使用する
素早いパフォーマンス・ブーストを得るために torch.optim から SWA を利用します。また Lightning から 2, 3 のクールな特徴を示します :
- 総てのエポックの終わりの後にコードを実行するために training_epoch_end を使用します。
- SWA のためのこのラッパーで事前訓練モデルを直接使用します。
class SWAResnet(LitResnet):
def __init__(self, trained_model, lr=0.01):
super().__init__()
self.save_hyperparameters('lr')
self.model = trained_model
self.swa_model = AveragedModel(self.model)
def forward(self, x):
out = self.swa_model(x)
return F.log_softmax(out, dim=1)
def training_epoch_end(self, training_step_outputs):
self.swa_model.update_parameters(self.model)
def validation_step(self, batch, batch_idx, stage=None):
x, y = batch
logits = F.log_softmax(self.model(x), dim=1)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
self.log(f'val_loss', loss, prog_bar=True)
self.log(f'val_acc', acc, prog_bar=True)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
return optimizer
def on_train_end(self):
update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)
swa_model = SWAResnet(model.model, lr=0.01)
swa_model.datamodule = cifar10_dm
swa_trainer = pl.Trainer(
progress_bar_refresh_rate=20,
max_epochs=20,
gpus=1,
logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),
)
swa_trainer.fit(swa_model, cifar10_dm)
swa_trainer.test(swa_model, datamodule=cifar10_dm);
# Start tensorboard. %reload_ext tensorboard %tensorboard --logdir lightning_logs/
以上