PyTorch Lightning 1.1 : notebooks : CIFAR10 ~94% ベースライン・チュートリアル

PyTorch Lightning 1.1: notebooks : CIFAR10 ~94% ベースライン・チュートリアル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/18/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

notebooks : CIFAR10 ~94% ベースライン・チュートリアル

  • 94% Baseline CIFAR10 – Lightning で Resnet を使用して CIFAR 10 上 ~94% 精度の素早いベースラインを達成します。

Resnet を Cifar10 上で 94% 精度にまで訓練します!

主要な重要点は :

  1. pl.LightningModule の configure_optimizers メソッドの異なる学習率スケジュールと頻度を持つ実験。
  2. Lightning による直接的な変更を伴う既存の Resnet アーキテクチャを使用します。

 

セットアップ

Lightning は容易にインストールできます。単純に pip install pytorch-lightning です。事前のデータモジュールとモデルについては bolts も確認してください。

! pip install pytorch-lightning pytorch-lightning-bolts -qU
# Run this if you intend to use TPUs
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);

 

CIFAR10 データモジュール

bolts から既存のデータモジュールをインポートして訓練とテスト変換を変更します。

batch_size = 32

train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

cifar10_dm = CIFAR10DataModule(
    batch_size=batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)

 

Resnet

TorchVision からの事前の Resnet アーキテクチャを変更します。事前のアーキテクチャは入力として ImageNet 画像 (224×224) に基づいています。そのためそれを CIFAR10 images (32×32) のために修正する必要があります。

def create_model():
    model = torchvision.models.resnet18(pretrained=False, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()
    return model

 

Lightning モジュール

カスタム学習率スケジューラを使用するためには configure_optimizers メソッドを確認してください。SGD を伴う OneCycleLR は 20-30 エポックでおよそ 92-93% 精度にそして 40-50 エポックで 93-94 % 精度に到達させます。https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate から様々な LR スケジュールで自由に実験してください。

class LitResnet(pl.LightningModule):
    def __init__(self, lr=0.05):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()

    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        steps_per_epoch = 45000 // batch_size
        scheduler_dict = {
            'scheduler': OneCycleLR(optimizer, 0.1, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}
model = LitResnet(lr=0.05)
model.datamodule = cifar10_dm

trainer = pl.Trainer(
    progress_bar_refresh_rate=20,
    max_epochs=40,
    gpus=1,
    logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='resnet'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)

trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);

 

ボーナス : パフォーマンス上のブーストを得るために確率的重み付け平均法を使用する

素早いパフォーマンス・ブーストを得るために torch.optim から SWA を利用します。また Lightning から 2, 3 のクールな特徴を示します :

  • 総てのエポックの終わりの後にコードを実行するために training_epoch_end を使用します。
  • SWA のためのこのラッパーで事前訓練モデルを直接使用します。
class SWAResnet(LitResnet):
    def __init__(self, trained_model, lr=0.01):
        super().__init__()

        self.save_hyperparameters('lr')
        self.model = trained_model
        self.swa_model = AveragedModel(self.model)

    def forward(self, x):
        out = self.swa_model(x)
        return F.log_softmax(out, dim=1)

    def training_epoch_end(self, training_step_outputs):
        self.swa_model.update_parameters(self.model)

    def validation_step(self, batch, batch_idx, stage=None):
        x, y = batch
        logits = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        self.log(f'val_loss', loss, prog_bar=True)
        self.log(f'val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        return optimizer

    def on_train_end(self):
        update_bn(self.datamodule.train_dataloader(), self.swa_model, device=self.device)
swa_model = SWAResnet(model.model, lr=0.01)
swa_model.datamodule = cifar10_dm

swa_trainer = pl.Trainer(
    progress_bar_refresh_rate=20,
    max_epochs=20,
    gpus=1,
    logger=pl.loggers.TensorBoardLogger('lightning_logs/', name='swa_resnet'),
)

swa_trainer.fit(swa_model, cifar10_dm)
swa_trainer.test(swa_model, datamodule=cifar10_dm);
# Start tensorboard.
%reload_ext tensorboard
%tensorboard --logdir lightning_logs/
 

以上