PyTorch Lightning 1.1: notebooks : CIFAR10 ~94% ベースライン・チュートリアル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/18/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : PyTorch Lightning CIFAR10 ~94% Baseline Tutorial
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
notebooks : CIFAR10 ~94% ベースライン・チュートリアル
- 94% Baseline CIFAR10 – Lightning で Resnet を使用して CIFAR 10 上 ~94% 精度の素早いベースラインを達成します。
Resnet を Cifar10 上で 94% 精度にまで訓練します!
主要な重要点は :
- pl.LightningModule の configure_optimizers メソッドの異なる学習率スケジュールと頻度を持つ実験。
- Lightning による直接的な変更を伴う既存の Resnet アーキテクチャを使用します。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning です。事前のデータモジュールとモデルについては bolts も確認してください。
! pip install pytorch-lightning pytorch-lightning-bolts -qU
# Run this if you intend to use TPUs # !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py # !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
import torch import torch.nn as nn import torch.nn.functional as F from torch.optim.lr_scheduler import OneCycleLR from torch.optim.swa_utils import AveragedModel, update_bn import torchvision import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.metrics.functional import accuracy from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);
CIFAR10 データモジュール
bolts から既存のデータモジュールをインポートして訓練とテスト変換を変更します。
batch_size = 32 train_transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomCrop(32, padding=4), torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.ToTensor(), cifar10_normalization(), ]) test_transforms = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), cifar10_normalization(), ]) cifar10_dm = CIFAR10DataModule( batch_size=batch_size, train_transforms=train_transforms, test_transforms=test_transforms, val_transforms=test_transforms, )
Resnet
TorchVision からの事前の Resnet アーキテクチャを変更します。事前のアーキテクチャは入力として ImageNet 画像 (224×224) に基づいています。そのためそれを CIFAR10 images (32×32) のために修正する必要があります。
def create_model(): model = torchvision.models.resnet18(pretrained=False, num_classes=10) model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) model.maxpool = nn.Identity() return model
Lightning モジュール
カスタム学習率スケジューラを使用するためには configure_optimizers メソッドを確認してください。SGD を伴う OneCycleLR は 20-30 エポックでおよそ 92-93% 精度にそして 40-50 エポックで 93-94 % 精度に到達させます。https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate から様々な LR スケジュールで自由に実験してください。
class LitResnet(pl.LightningModule): def __init__(self, lr=0.05): super().__init__() self.save_hyperparameters() self.model = create_model() def forward(self, x): out = self.model(x) return F.log_softmax(out, dim=1) def training_step(self, batch, batch_idx): x, y = batch logits = F.log_softmax(self.model(x), dim=1) loss = F.nll_loss(logits, y) self.log('train_loss', loss) return loss def evaluate(self, batch, stage=None): x, y = batch logits = self(x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) acc = accuracy(preds, y) if stage: self.log(f'{stage}_loss', loss, prog_bar=True) self.log(f'{stage}_acc', acc, prog_bar=True) def validation_step(self, batch, batch_idx): self.evaluate(batch, 'val') def test_step(self, batch, batch_idx): self.evaluate(batch, 'test') def configure_optimizers(self): optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) steps_per_epoch = 45000 // batch_size scheduler_dict = { 'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch), 'interval': 'step', } return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}
model = LitResnet(lr=0.05) model.datamodule = cifar10_dm trainer = pl.Trainer( progress_bar_refresh_rate=20, max_epochs=40, gpus=1, logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'), callbacks=[LearningRateMonitor(logging_interval='step')], ) trainer.fit(model, cifar10_dm) trainer.test(model, datamodule=cifar10_dm);
ボーナス : パフォーマンス上のブーストを得るために確率的重み付け平均法を使用する
素早いパフォーマンス・ブーストを得るために torch.optim から SWA を利用します。また Lightning から 2, 3 のクールな特徴を示します :
- 総てのエポックの終わりの後にコードを実行するために training_epoch_end を使用します。
- SWA のためのこのラッパーで事前訓練モデルを直接使用します。
class SWAResnet(LitResnet): def __init__(self, trained_model, lr=0.01): super().__init__() self.save_hyperparameters('lr') self.model = trained_model self.swa_model = AveragedModel(self.model) def forward(self, x): out = self.swa_model(x) return F.log_softmax(out, dim=1) def training_epoch_end(self, training_step_outputs): self.swa_model.update_parameters(self.model) def validation_step(self, batch, batch_idx, stage=None): x, y = batch logits = F.log_softmax(self.model(x), dim=1) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim=1) acc = accuracy(preds, y) self.log(f'val_loss', loss, prog_bar=True) self.log(f'val_acc', acc, prog_bar=True) def configure_optimizers(self): optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4) return optimizer def on_train_end(self): update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)
swa_model = SWAResnet(model.model, lr=0.01) swa_model.datamodule = cifar10_dm swa_trainer = pl.Trainer( progress_bar_refresh_rate=20, max_epochs=20, gpus=1, logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'), ) swa_trainer.fit(swa_model, cifar10_dm) swa_trainer.test(swa_model, datamodule=cifar10_dm);
# Start tensorboard. %reload_ext tensorboard %tensorboard --logdir lightning_logs/
以上