PyTorch Lightning 1.1: notebooks : CIFAR10 ~94% ベースライン・チュートリアル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/18/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : PyTorch Lightning CIFAR10 ~94% Baseline Tutorial
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
notebooks : CIFAR10 ~94% ベースライン・チュートリアル
- 94% Baseline CIFAR10 – Lightning で Resnet を使用して CIFAR 10 上 ~94% 精度の素早いベースラインを達成します。
Resnet を Cifar10 上で 94% 精度にまで訓練します!
主要な重要点は :
- pl.LightningModule の configure_optimizers メソッドの異なる学習率スケジュールと頻度を持つ実験。
- Lightning による直接的な変更を伴う既存の Resnet アーキテクチャを使用します。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning です。事前のデータモジュールとモデルについては bolts も確認してください。
1 | ! pip install pytorch - lightning pytorch - lightning - bolts - qU |
1 2 3 | # Run this if you intend to use TPUs # !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py # !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev |
1 2 3 4 5 6 7 8 9 10 11 12 | import torch import torch.nn as nn import torch.nn.functional as F from torch.optim.lr_scheduler import OneCycleLR from torch.optim.swa_utils import AveragedModel, update_bn import torchvision import pytorch_lightning as pl from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.metrics.functional import accuracy from pl_bolts.datamodules import CIFAR10DataModule from pl_bolts.transforms.dataset_normalizations import cifar10_normalization |
1 | pl.seed_everything( 7 ); |
CIFAR10 データモジュール
bolts から既存のデータモジュールをインポートして訓練とテスト変換を変更します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | batch_size = 32 train_transforms = torchvision.transforms.Compose([ torchvision.transforms.RandomCrop( 32 , padding = 4 ), torchvision.transforms.RandomHorizontalFlip(), torchvision.transforms.ToTensor(), cifar10_normalization(), ]) test_transforms = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), cifar10_normalization(), ]) cifar10_dm = CIFAR10DataModule( batch_size = batch_size, train_transforms = train_transforms, test_transforms = test_transforms, val_transforms = test_transforms, ) |
Resnet
TorchVision からの事前の Resnet アーキテクチャを変更します。事前のアーキテクチャは入力として ImageNet 画像 (224×224) に基づいています。そのためそれを CIFAR10 images (32×32) のために修正する必要があります。
1 2 3 4 5 | def create_model(): model = torchvision.models.resnet18(pretrained = False , num_classes = 10 ) model.conv1 = nn.Conv2d( 3 , 64 , kernel_size = ( 3 , 3 ), stride = ( 1 , 1 ), padding = ( 1 , 1 ), bias = False ) model.maxpool = nn.Identity() return model |
Lightning モジュール
カスタム学習率スケジューラを使用するためには configure_optimizers メソッドを確認してください。SGD を伴う OneCycleLR は 20-30 エポックでおよそ 92-93% 精度にそして 40-50 エポックで 93-94 % 精度に到達させます。https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate から様々な LR スケジュールで自由に実験してください。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | class LitResnet(pl.LightningModule): def __init__( self , lr = 0.05 ): super ().__init__() self .save_hyperparameters() self .model = create_model() def forward( self , x): out = self .model(x) return F.log_softmax(out, dim = 1 ) def training_step( self , batch, batch_idx): x, y = batch logits = F.log_softmax( self .model(x), dim = 1 ) loss = F.nll_loss(logits, y) self .log( 'train_loss' , loss) return loss def evaluate( self , batch, stage = None ): x, y = batch logits = self (x) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim = 1 ) acc = accuracy(preds, y) if stage: self .log(f '{stage}_loss' , loss, prog_bar = True ) self .log(f '{stage}_acc' , acc, prog_bar = True ) def validation_step( self , batch, batch_idx): self .evaluate(batch, 'val' ) def test_step( self , batch, batch_idx): self .evaluate(batch, 'test' ) def configure_optimizers( self ): optimizer = torch.optim.SGD( self .parameters(), lr = self .hparams.lr, momentum = 0.9 , weight_decay = 5e - 4 ) steps_per_epoch = 45000 / / batch_size scheduler_dict = { 'scheduler' : OneCycleLR(optimizer, 0.1 , epochs = self .trainer.max_epochs, steps_per_epoch = steps_per_epoch), 'interval' : 'step' , } return { 'optimizer' : optimizer, 'lr_scheduler' : scheduler_dict} |
1 2 3 4 5 6 7 8 9 10 11 12 13 | model = LitResnet(lr = 0.05 ) model.datamodule = cifar10_dm trainer = pl.Trainer( progress_bar_refresh_rate = 20 , max_epochs = 40 , gpus = 1 , logger = pl.loggers.TensorBoardLogger( 'lightning_logs/' , name = 'resnet' ), callbacks = [LearningRateMonitor(logging_interval = 'step' )], ) trainer.fit(model, cifar10_dm) trainer.test(model, datamodule = cifar10_dm); |
ボーナス : パフォーマンス上のブーストを得るために確率的重み付け平均法を使用する
素早いパフォーマンス・ブーストを得るために torch.optim から SWA を利用します。また Lightning から 2, 3 のクールな特徴を示します :
- 総てのエポックの終わりの後にコードを実行するために training_epoch_end を使用します。
- SWA のためのこのラッパーで事前訓練モデルを直接使用します。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | class SWAResnet(LitResnet): def __init__( self , trained_model, lr = 0.01 ): super ().__init__() self .save_hyperparameters( 'lr' ) self .model = trained_model self .swa_model = AveragedModel( self .model) def forward( self , x): out = self .swa_model(x) return F.log_softmax(out, dim = 1 ) def training_epoch_end( self , training_step_outputs): self .swa_model.update_parameters( self .model) def validation_step( self , batch, batch_idx, stage = None ): x, y = batch logits = F.log_softmax( self .model(x), dim = 1 ) loss = F.nll_loss(logits, y) preds = torch.argmax(logits, dim = 1 ) acc = accuracy(preds, y) self .log(f 'val_loss' , loss, prog_bar = True ) self .log(f 'val_acc' , acc, prog_bar = True ) def configure_optimizers( self ): optimizer = torch.optim.SGD( self .model.parameters(), lr = self .hparams.lr, momentum = 0.9 , weight_decay = 5e - 4 ) return optimizer def on_train_end( self ): update_bn( self .datamodule.train_dataloader(), self .swa_model, device = self .device) |
1 2 3 4 5 6 7 8 9 10 11 12 | swa_model = SWAResnet(model.model, lr = 0.01 ) swa_model.datamodule = cifar10_dm swa_trainer = pl.Trainer( progress_bar_refresh_rate = 20 , max_epochs = 20 , gpus = 1 , logger = pl.loggers.TensorBoardLogger( 'lightning_logs/' , name = 'swa_resnet' ), ) swa_trainer.fit(swa_model, cifar10_dm) swa_trainer.test(swa_model, datamodule = cifar10_dm); |
1 2 3 | # Start tensorboard. % reload_ext tensorboard % tensorboard - - logdir lightning_logs / |
以上