PyTorch Lightning 1.1 : research : Transfer Learning (CIFAR10, ResNet50)

Author : ClassCat Co., Ltd. Sales Information
Created : 03/01/2021 (1.1.x)

* This page provides sample transfer-learning code implemented with reference to the PyTorch Lightning documentation.

* You are welcome to link to this page, but we would appreciate a note to sales-info@classcat.com.

 


research : Transfer Learning (CIFAR10, ResNet50)


! pip install pytorch-lightning pytorch-lightning-bolts -qU

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR, CyclicLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision
 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

pl.seed_everything(7);

batch_size = 32
 
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
 
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
 
cifar10_dm = CIFAR10DataModule(
    batch_size=batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
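If you want to sanity-check the data pipeline before training, the datamodule can be prepared and set up by hand. The following is a minimal sketch, assuming CIFAR10 is downloaded into the default data_dir of pl_bolts; the variable names are only for illustration.

# Minimal sketch: download/setup the datamodule and look at one training batch.
cifar10_dm.prepare_data()     # downloads CIFAR10 if it is not cached yet
cifar10_dm.setup()            # builds the train/val/test splits
images, labels = next(iter(cifar10_dm.train_dataloader()))
print(images.shape, labels.shape)   # expected: torch.Size([32, 3, 32, 32]) torch.Size([32])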

import torchvision.models as models

class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, lr=0.05, factor=0.5):
        super().__init__()
        self.save_hyperparameters()

        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features   # 2048 for ResNet50
        layers = list(backbone.children())[:-1] # drop the final fully-connected classification layer
        layers[3] = nn.Identity()               # replace the stride-2 max-pool so the small 32x32 CIFAR images are not downsampled too early
        self.feature_extractor = nn.Sequential(*layers)

        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10

        layers = [
                  nn.Linear(num_filters, 512),
                  nn.ReLU(inplace=True),
                  nn.Dropout(p=0.4),
                  nn.Linear(512, num_target_classes)
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        # The pretrained backbone stays frozen: eval() fixes the BatchNorm statistics
        # and torch.no_grad() prevents gradients from flowing into it.
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        out = self.classifier(representations)
        logits = F.log_softmax(out, dim=1)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss
 
    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
 
        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)
 
    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')
 
    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')
 
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
 
        return {
            'optimizer': optimizer,
            'lr_scheduler': ReduceLROnPlateau(
                optimizer, 'max', patience=5, factor=self.hparams.factor,
                verbose=True, threshold=0.0001, threshold_mode='abs',
                cooldown=1, min_lr=1e-5),
            'monitor': 'val_acc'
        }

    # Alternative scheduler setup, disabled by the 'x' prefix so Lightning ignores it.
    # It drives a OneCycleLR schedule once per optimizer step instead of per epoch.
    def xconfigure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        steps_per_epoch = 45000 // batch_size   # CIFAR10DataModule keeps 45,000 images for training by default
        scheduler_dict = {
            #'scheduler': ExponentialLR(optimizer, gamma=0.1),
            #'interval': 'epoch',
            'scheduler': OneCycleLR(optimizer, max_lr=0.1, pct_start=0.25, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),
            #'scheduler': CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size_up=steps_per_epoch*2, mode="triangular2"),
            #'scheduler': CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size_up=steps_per_epoch, mode="exp_range", gamma=0.85),
            #'scheduler': CosineAnnealingLR(optimizer, T_max=200),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}

model = ImagenetTransferLearning(lr=0.001, factor=0.9)
model.datamodule = cifar10_dm
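
Before printing the layer summary, it can be reassuring to confirm the output shape with a dummy batch. This is a minimal sketch; the random tensor serves only as a shape check and is not part of the original article.

# Minimal sketch: a random batch of two 32x32 RGB images should yield 2 x 10 log-probabilities.
with torch.no_grad():
    dummy = torch.randn(2, 3, 32, 32)
    print(model(dummy).shape)   # expected: torch.Size([2, 10])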

from torchsummary import summary   # torchsummary is a separate package (pip install torchsummary) if it is not preinstalled

summary(model.feature_extractor.to('cuda'), (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 16, 16]           9,408
       BatchNorm2d-2           [-1, 64, 16, 16]             128
              ReLU-3           [-1, 64, 16, 16]               0
          Identity-4           [-1, 64, 16, 16]               0
            Conv2d-5           [-1, 64, 16, 16]           4,096
       BatchNorm2d-6           [-1, 64, 16, 16]             128
              ReLU-7           [-1, 64, 16, 16]               0
            Conv2d-8           [-1, 64, 16, 16]          36,864
       BatchNorm2d-9           [-1, 64, 16, 16]             128
             ReLU-10           [-1, 64, 16, 16]               0
           Conv2d-11          [-1, 256, 16, 16]          16,384
      BatchNorm2d-12          [-1, 256, 16, 16]             512
           Conv2d-13          [-1, 256, 16, 16]          16,384
      BatchNorm2d-14          [-1, 256, 16, 16]             512
             ReLU-15          [-1, 256, 16, 16]               0
       Bottleneck-16          [-1, 256, 16, 16]               0
           Conv2d-17           [-1, 64, 16, 16]          16,384
      BatchNorm2d-18           [-1, 64, 16, 16]             128
             ReLU-19           [-1, 64, 16, 16]               0
           Conv2d-20           [-1, 64, 16, 16]          36,864
      BatchNorm2d-21           [-1, 64, 16, 16]             128
             ReLU-22           [-1, 64, 16, 16]               0
           Conv2d-23          [-1, 256, 16, 16]          16,384
      BatchNorm2d-24          [-1, 256, 16, 16]             512
             ReLU-25          [-1, 256, 16, 16]               0
       Bottleneck-26          [-1, 256, 16, 16]               0
           Conv2d-27           [-1, 64, 16, 16]          16,384
      BatchNorm2d-28           [-1, 64, 16, 16]             128
             ReLU-29           [-1, 64, 16, 16]               0
           Conv2d-30           [-1, 64, 16, 16]          36,864
      BatchNorm2d-31           [-1, 64, 16, 16]             128
             ReLU-32           [-1, 64, 16, 16]               0
           Conv2d-33          [-1, 256, 16, 16]          16,384
      BatchNorm2d-34          [-1, 256, 16, 16]             512
             ReLU-35          [-1, 256, 16, 16]               0
       Bottleneck-36          [-1, 256, 16, 16]               0
           Conv2d-37          [-1, 128, 16, 16]          32,768
      BatchNorm2d-38          [-1, 128, 16, 16]             256
             ReLU-39          [-1, 128, 16, 16]               0
           Conv2d-40            [-1, 128, 8, 8]         147,456
      BatchNorm2d-41            [-1, 128, 8, 8]             256
             ReLU-42            [-1, 128, 8, 8]               0
           Conv2d-43            [-1, 512, 8, 8]          65,536
      BatchNorm2d-44            [-1, 512, 8, 8]           1,024
           Conv2d-45            [-1, 512, 8, 8]         131,072
      BatchNorm2d-46            [-1, 512, 8, 8]           1,024
             ReLU-47            [-1, 512, 8, 8]               0
       Bottleneck-48            [-1, 512, 8, 8]               0
           Conv2d-49            [-1, 128, 8, 8]          65,536
      BatchNorm2d-50            [-1, 128, 8, 8]             256
             ReLU-51            [-1, 128, 8, 8]               0
           Conv2d-52            [-1, 128, 8, 8]         147,456
      BatchNorm2d-53            [-1, 128, 8, 8]             256
             ReLU-54            [-1, 128, 8, 8]               0
           Conv2d-55            [-1, 512, 8, 8]          65,536
      BatchNorm2d-56            [-1, 512, 8, 8]           1,024
             ReLU-57            [-1, 512, 8, 8]               0
       Bottleneck-58            [-1, 512, 8, 8]               0
           Conv2d-59            [-1, 128, 8, 8]          65,536
      BatchNorm2d-60            [-1, 128, 8, 8]             256
             ReLU-61            [-1, 128, 8, 8]               0
           Conv2d-62            [-1, 128, 8, 8]         147,456
      BatchNorm2d-63            [-1, 128, 8, 8]             256
             ReLU-64            [-1, 128, 8, 8]               0
           Conv2d-65            [-1, 512, 8, 8]          65,536
      BatchNorm2d-66            [-1, 512, 8, 8]           1,024
             ReLU-67            [-1, 512, 8, 8]               0
       Bottleneck-68            [-1, 512, 8, 8]               0
           Conv2d-69            [-1, 128, 8, 8]          65,536
      BatchNorm2d-70            [-1, 128, 8, 8]             256
             ReLU-71            [-1, 128, 8, 8]               0
           Conv2d-72            [-1, 128, 8, 8]         147,456
      BatchNorm2d-73            [-1, 128, 8, 8]             256
             ReLU-74            [-1, 128, 8, 8]               0
           Conv2d-75            [-1, 512, 8, 8]          65,536
      BatchNorm2d-76            [-1, 512, 8, 8]           1,024
             ReLU-77            [-1, 512, 8, 8]               0
       Bottleneck-78            [-1, 512, 8, 8]               0
           Conv2d-79            [-1, 256, 8, 8]         131,072
      BatchNorm2d-80            [-1, 256, 8, 8]             512
             ReLU-81            [-1, 256, 8, 8]               0
           Conv2d-82            [-1, 256, 4, 4]         589,824
      BatchNorm2d-83            [-1, 256, 4, 4]             512
             ReLU-84            [-1, 256, 4, 4]               0
           Conv2d-85           [-1, 1024, 4, 4]         262,144
      BatchNorm2d-86           [-1, 1024, 4, 4]           2,048
           Conv2d-87           [-1, 1024, 4, 4]         524,288
      BatchNorm2d-88           [-1, 1024, 4, 4]           2,048
             ReLU-89           [-1, 1024, 4, 4]               0
       Bottleneck-90           [-1, 1024, 4, 4]               0
           Conv2d-91            [-1, 256, 4, 4]         262,144
      BatchNorm2d-92            [-1, 256, 4, 4]             512
             ReLU-93            [-1, 256, 4, 4]               0
           Conv2d-94            [-1, 256, 4, 4]         589,824
      BatchNorm2d-95            [-1, 256, 4, 4]             512
             ReLU-96            [-1, 256, 4, 4]               0
           Conv2d-97           [-1, 1024, 4, 4]         262,144
      BatchNorm2d-98           [-1, 1024, 4, 4]           2,048
             ReLU-99           [-1, 1024, 4, 4]               0
      Bottleneck-100           [-1, 1024, 4, 4]               0
          Conv2d-101            [-1, 256, 4, 4]         262,144
     BatchNorm2d-102            [-1, 256, 4, 4]             512
            ReLU-103            [-1, 256, 4, 4]               0
          Conv2d-104            [-1, 256, 4, 4]         589,824
     BatchNorm2d-105            [-1, 256, 4, 4]             512
            ReLU-106            [-1, 256, 4, 4]               0
          Conv2d-107           [-1, 1024, 4, 4]         262,144
     BatchNorm2d-108           [-1, 1024, 4, 4]           2,048
            ReLU-109           [-1, 1024, 4, 4]               0
      Bottleneck-110           [-1, 1024, 4, 4]               0
          Conv2d-111            [-1, 256, 4, 4]         262,144
     BatchNorm2d-112            [-1, 256, 4, 4]             512
            ReLU-113            [-1, 256, 4, 4]               0
          Conv2d-114            [-1, 256, 4, 4]         589,824
     BatchNorm2d-115            [-1, 256, 4, 4]             512
            ReLU-116            [-1, 256, 4, 4]               0
          Conv2d-117           [-1, 1024, 4, 4]         262,144
     BatchNorm2d-118           [-1, 1024, 4, 4]           2,048
            ReLU-119           [-1, 1024, 4, 4]               0
      Bottleneck-120           [-1, 1024, 4, 4]               0
          Conv2d-121            [-1, 256, 4, 4]         262,144
     BatchNorm2d-122            [-1, 256, 4, 4]             512
            ReLU-123            [-1, 256, 4, 4]               0
          Conv2d-124            [-1, 256, 4, 4]         589,824
     BatchNorm2d-125            [-1, 256, 4, 4]             512
            ReLU-126            [-1, 256, 4, 4]               0
          Conv2d-127           [-1, 1024, 4, 4]         262,144
     BatchNorm2d-128           [-1, 1024, 4, 4]           2,048
            ReLU-129           [-1, 1024, 4, 4]               0
      Bottleneck-130           [-1, 1024, 4, 4]               0
          Conv2d-131            [-1, 256, 4, 4]         262,144
     BatchNorm2d-132            [-1, 256, 4, 4]             512
            ReLU-133            [-1, 256, 4, 4]               0
          Conv2d-134            [-1, 256, 4, 4]         589,824
     BatchNorm2d-135            [-1, 256, 4, 4]             512
            ReLU-136            [-1, 256, 4, 4]               0
          Conv2d-137           [-1, 1024, 4, 4]         262,144
     BatchNorm2d-138           [-1, 1024, 4, 4]           2,048
            ReLU-139           [-1, 1024, 4, 4]               0
      Bottleneck-140           [-1, 1024, 4, 4]               0
          Conv2d-141            [-1, 512, 4, 4]         524,288
     BatchNorm2d-142            [-1, 512, 4, 4]           1,024
            ReLU-143            [-1, 512, 4, 4]               0
          Conv2d-144            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-145            [-1, 512, 2, 2]           1,024
            ReLU-146            [-1, 512, 2, 2]               0
          Conv2d-147           [-1, 2048, 2, 2]       1,048,576
     BatchNorm2d-148           [-1, 2048, 2, 2]           4,096
          Conv2d-149           [-1, 2048, 2, 2]       2,097,152
     BatchNorm2d-150           [-1, 2048, 2, 2]           4,096
            ReLU-151           [-1, 2048, 2, 2]               0
      Bottleneck-152           [-1, 2048, 2, 2]               0
          Conv2d-153            [-1, 512, 2, 2]       1,048,576
     BatchNorm2d-154            [-1, 512, 2, 2]           1,024
            ReLU-155            [-1, 512, 2, 2]               0
          Conv2d-156            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-157            [-1, 512, 2, 2]           1,024
            ReLU-158            [-1, 512, 2, 2]               0
          Conv2d-159           [-1, 2048, 2, 2]       1,048,576
     BatchNorm2d-160           [-1, 2048, 2, 2]           4,096
            ReLU-161           [-1, 2048, 2, 2]               0
      Bottleneck-162           [-1, 2048, 2, 2]               0
          Conv2d-163            [-1, 512, 2, 2]       1,048,576
     BatchNorm2d-164            [-1, 512, 2, 2]           1,024
            ReLU-165            [-1, 512, 2, 2]               0
          Conv2d-166            [-1, 512, 2, 2]       2,359,296
     BatchNorm2d-167            [-1, 512, 2, 2]           1,024
            ReLU-168            [-1, 512, 2, 2]               0
          Conv2d-169           [-1, 2048, 2, 2]       1,048,576
     BatchNorm2d-170           [-1, 2048, 2, 2]           4,096
            ReLU-171           [-1, 2048, 2, 2]               0
      Bottleneck-172           [-1, 2048, 2, 2]               0
AdaptiveAvgPool2d-173           [-1, 2048, 1, 1]               0
================================================================
Total params: 23,508,032
Trainable params: 23,508,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 22.28
Params size (MB): 89.68
Estimated Total Size (MB): 111.97
----------------------------------------------------------------
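
The summary above reports all 23.5M backbone parameters as trainable, but in this LightningModule they are effectively frozen because forward() runs the backbone under torch.no_grad(). A common alternative, shown below as a sketch that is not used in this article, is to make the intent explicit by turning off requires_grad.

# Minimal sketch (alternative, not used here): freeze the backbone explicitly.
for param in model.feature_extractor.parameters():
    param.requires_grad = False
# configure_optimizers could then pass only the trainable parameters, e.g.
#   torch.optim.SGD(filter(lambda p: p.requires_grad, self.parameters()), lr=self.hparams.lr, ...)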

%%time
 
model = ImagenetTransferLearning(lr=0.001, factor=0.9)
model.datamodule = cifar10_dm

trainer = pl.Trainer(
    gpus=1,
    max_epochs=50,
    progress_bar_refresh_rate=50,
    logger=pl.loggers.TensorBoardLogger('tblogs/', name='resnet50'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
 
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);
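
After training, the curves written by TensorBoardLogger can be viewed and the saved weights reloaded. A minimal sketch, assuming the tblogs/ directory from the Trainer above and a notebook environment; the checkpoint attribute may differ slightly across Lightning versions.

# Minimal sketch: inspect the logs and reload the trained model.
# In a notebook cell:
#   %load_ext tensorboard
#   %tensorboard --logdir tblogs/
ckpt_path = trainer.checkpoint_callback.best_model_path   # path of the checkpoint Lightning saved
restored = ImagenetTransferLearning.load_from_checkpoint(ckpt_path)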
 

That's all.