PyTorch Lightning 1.1: research: CIFAR100 (ShuffleNet)

Author: ClassCat, Inc. Sales Information
Date: 02/24/2021 (1.1.x)

* This page reports the results of an experiment carried out on CIFAR100 by adapting existing CIFAR10 resources.

* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.

 


 

research: CIFAR100 (ShuffleNet)

Specs

  • Total params: 1,360,896 (1.4M)
  • Trainable params: 1,360,896
  • Non-trainable params: 0

 

Results

100 epochs

  • {'test_acc': 0.6672999858856201, 'test_loss': 1.4519164562225342}
  • Wall time: 1h 26min 15s
  • Tesla T4
  • ReduceLROnPlateau
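
For reference, the learning-rate trajectory printed in the training log below matches eight plateau-triggered halvings of the initial rate 0.05 (factor=0.5); a quick arithmetic check:

lr = 0.05
for k in range(1, 9):
    lr *= 0.5
    print(f'reduction {k}: {lr:.4e}')  # ends at: reduction 8: 1.9531e-04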

 

CIFAR100 DataModule

from typing import Any, Callable, Optional, Sequence, Union
 
from pl_bolts.datamodules.vision_datamodule import VisionDataModule
#from pl_bolts.datasets import TrialCIFAR10
#from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg
 
if _TORCHVISION_AVAILABLE:
    from torchvision import transforms
    #from torchvision import transforms as transform_lib
    from torchvision.datasets import CIFAR100
else:  # pragma: no cover
    warn_missing_pkg('torchvision')
    CIFAR100 = None
def cifar100_normalization():
    if not _TORCHVISION_AVAILABLE:  # pragma: no cover
        raise ModuleNotFoundError(
            'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.'
        )

    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [129.3, 124.1, 112.4]],
        std=[x / 255.0 for x in [68.2, 65.4, 70.4]],
        # cifar10
        #mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        #std=[x / 255.0 for x in [63.0, 62.1, 66.7]],
    )
    return normalize
class CIFAR100DataModule(VisionDataModule):
    """
    .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/
        Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png
        :width: 400
        :alt: CIFAR-10
    Specs:
        - 10 classes (1 per class)
        - Each image is (3 x 32 x 32)
    Standard CIFAR10, train, val, test splits and transforms
    Transforms::
        mnist_transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transforms.Normalize(
                mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
                std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
            )
        ])
    Example::
        from pl_bolts.datamodules import CIFAR10DataModule
        dm = CIFAR10DataModule(PATH)
        model = LitModel()
        Trainer().fit(model, datamodule=dm)
    Or you can set your own transforms
    Example::
        dm.train_transforms = ...
        dm.test_transforms = ...
        dm.val_transforms  = ...
    """
    name = "cifar100"
    dataset_cls = CIFAR100
    dims = (3, 32, 32)

    def __init__(
        self,
        data_dir: Optional[str] = None,
        val_split: Union[int, float] = 0.2,
        num_workers: int = 16,
        normalize: bool = False,
        batch_size: int = 32,
        seed: int = 42,
        shuffle: bool = False,
        pin_memory: bool = False,
        drop_last: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Args:
            data_dir: Where to save/load the data
            val_split: Percent (float) or number (int) of samples to use for the validation split
            num_workers: How many workers to use for loading data
            normalize: If true applies image normalize
            batch_size: How many samples per batch to load
            seed: Random seed to be used for train/val/test splits
            shuffle: If true shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                        returning them
            drop_last: If true drops the last incomplete batch
        """
        super().__init__(  # type: ignore[misc]
            data_dir=data_dir,
            val_split=val_split,
            num_workers=num_workers,
            normalize=normalize,
            batch_size=batch_size,
            seed=seed,
            shuffle=shuffle,
            pin_memory=pin_memory,
            drop_last=drop_last,
            *args,
            **kwargs,
        )

    @property
    def num_samples(self) -> int:
        train_len, _ = self._get_splits(len_dataset=50_000)
        return train_len

    @property
    def num_classes(self) -> int:
        """
        Return:
            10
        """
        return 100

    def default_transforms(self) -> Callable:
        if self.normalize:
            cf100_transforms = transforms.Compose([transforms.ToTensor(), cifar100_normalization()])
        else:
            cf100_transforms = transforms.Compose([transforms.ToTensor()])

        return cf100_transforms
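
The mean/std constants hard-coded in cifar100_normalization above can be re-derived from the raw training images; a minimal sketch (downloads CIFAR100 into ./data on first run):

from torchvision.datasets import CIFAR100

train_set = CIFAR100('./data', train=True, download=True)
data = train_set.data / 255.0           # uint8 array (50000, 32, 32, 3) -> floats in [0, 1]
print(data.mean(axis=(0, 1, 2)) * 255)  # ~ [129.3, 124.1, 112.4]
print(data.std(axis=(0, 1, 2)) * 255)   # ~ [68.2, 65.4, 70.4]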

 

Model

import torch
import torch.nn as nn
import torch.nn.functional as F
def channel_split(x, split):
    """split a tensor into two pieces along channel dimension
    Args:
        x: input tensor
        split:(int) channel size for each pieces
    """
    assert x.size(1) == split * 2
    return torch.split(x, split, dim=1)

def channel_shuffle(x, groups):
    """channel shuffle operation
    Args:
        x: input tensor
        groups: input branch number
    """

    batch_size, channels, height, width = x.size()
    channels_per_group = int(channels // groups)

    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = x.transpose(1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)

    return x

class ShuffleUnit(nn.Module):

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()

        self.stride = stride
        self.in_channels = in_channels
        self.out_channels = out_channels

        if stride != 1 or in_channels != out_channels:
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, int(out_channels / 2), 1),
                nn.BatchNorm2d(int(out_channels / 2)),
                nn.ReLU(inplace=True)
            )

            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, int(out_channels / 2), 1),
                nn.BatchNorm2d(int(out_channels / 2)),
                nn.ReLU(inplace=True)
            )
        else:
            self.shortcut = nn.Sequential()

            in_channels = int(in_channels / 2)
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, in_channels, 1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels, in_channels, 3, stride=stride, padding=1, groups=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.Conv2d(in_channels, in_channels, 1),
                nn.BatchNorm2d(in_channels),
                nn.ReLU(inplace=True)
            )


    def forward(self, x):

        if self.stride == 1 and self.out_channels == self.in_channels:
            shortcut, residual = channel_split(x, int(self.in_channels / 2))
        else:
            shortcut = x
            residual = x

        shortcut = self.shortcut(shortcut)
        residual = self.residual(residual)
        x = torch.cat([shortcut, residual], dim=1)
        x = channel_shuffle(x, 2)

        return x

class ShuffleNetV2(nn.Module):

    def __init__(self, ratio=1, class_num=100):
        super().__init__()
        if ratio == 0.5:
            out_channels = [48, 96, 192, 1024]
        elif ratio == 1:
            out_channels = [116, 232, 464, 1024]
        elif ratio == 1.5:
            out_channels = [176, 352, 704, 1024]
        elif ratio == 2:
            out_channels = [244, 488, 976, 2048]
        else:
            raise ValueError('unsupported ratio number')

        self.pre = nn.Sequential(
            nn.Conv2d(3, 24, 3, padding=1),
            nn.BatchNorm2d(24)
        )

        self.stage2 = self._make_stage(24, out_channels[0], 3)
        self.stage3 = self._make_stage(out_channels[0], out_channels[1], 7)
        self.stage4 = self._make_stage(out_channels[1], out_channels[2], 3)
        self.conv5 = nn.Sequential(
            nn.Conv2d(out_channels[2], out_channels[3], 1),
            nn.BatchNorm2d(out_channels[3]),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(out_channels[3], class_num)

    def forward(self, x):
        x = self.pre(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

    def _make_stage(self, in_channels, out_channels, repeat):
        layers = []
        layers.append(ShuffleUnit(in_channels, out_channels, 2))

        while repeat:
            layers.append(ShuffleUnit(out_channels, out_channels, 1))
            repeat -= 1

        return nn.Sequential(*layers)

def shufflenetv2():
    return ShuffleNetV2()
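
As a quick sanity check that channel_shuffle interleaves the two concatenated branches, a minimal sketch on a 4-channel tensor whose values equal their channel indices:

t = torch.arange(4.).view(1, 4, 1, 1)            # channels [0, 1, 2, 3]
print(channel_shuffle(t, 2).flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]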
# smoke test: one forward pass on a CIFAR-sized input
net = shufflenetv2()
y = net(torch.randn(1, 3, 32, 32))
print(y.size())
torch.Size([1, 100])
from torchsummary import summary
 
summary(shufflenetv2().to('cuda'), (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 24, 32, 32]             672
       BatchNorm2d-2           [-1, 24, 32, 32]              48
            Conv2d-3           [-1, 24, 16, 16]             240
       BatchNorm2d-4           [-1, 24, 16, 16]              48
            Conv2d-5           [-1, 58, 16, 16]           1,450
       BatchNorm2d-6           [-1, 58, 16, 16]             116
              ReLU-7           [-1, 58, 16, 16]               0
            Conv2d-8           [-1, 24, 32, 32]             600
       BatchNorm2d-9           [-1, 24, 32, 32]              48
             ReLU-10           [-1, 24, 32, 32]               0
           Conv2d-11           [-1, 24, 16, 16]             240
      BatchNorm2d-12           [-1, 24, 16, 16]              48
           Conv2d-13           [-1, 58, 16, 16]           1,450
      BatchNorm2d-14           [-1, 58, 16, 16]             116
             ReLU-15           [-1, 58, 16, 16]               0
      ShuffleUnit-16          [-1, 116, 16, 16]               0
           Conv2d-17           [-1, 58, 16, 16]           3,422
      BatchNorm2d-18           [-1, 58, 16, 16]             116
             ReLU-19           [-1, 58, 16, 16]               0
           Conv2d-20           [-1, 58, 16, 16]             580
      BatchNorm2d-21           [-1, 58, 16, 16]             116
           Conv2d-22           [-1, 58, 16, 16]           3,422
      BatchNorm2d-23           [-1, 58, 16, 16]             116
             ReLU-24           [-1, 58, 16, 16]               0
      ShuffleUnit-25          [-1, 116, 16, 16]               0
           Conv2d-26           [-1, 58, 16, 16]           3,422
      BatchNorm2d-27           [-1, 58, 16, 16]             116
             ReLU-28           [-1, 58, 16, 16]               0
           Conv2d-29           [-1, 58, 16, 16]             580
      BatchNorm2d-30           [-1, 58, 16, 16]             116
           Conv2d-31           [-1, 58, 16, 16]           3,422
      BatchNorm2d-32           [-1, 58, 16, 16]             116
             ReLU-33           [-1, 58, 16, 16]               0
      ShuffleUnit-34          [-1, 116, 16, 16]               0
           Conv2d-35           [-1, 58, 16, 16]           3,422
      BatchNorm2d-36           [-1, 58, 16, 16]             116
             ReLU-37           [-1, 58, 16, 16]               0
           Conv2d-38           [-1, 58, 16, 16]             580
      BatchNorm2d-39           [-1, 58, 16, 16]             116
           Conv2d-40           [-1, 58, 16, 16]           3,422
      BatchNorm2d-41           [-1, 58, 16, 16]             116
             ReLU-42           [-1, 58, 16, 16]               0
      ShuffleUnit-43          [-1, 116, 16, 16]               0
           Conv2d-44            [-1, 116, 8, 8]           1,160
      BatchNorm2d-45            [-1, 116, 8, 8]             232
           Conv2d-46            [-1, 116, 8, 8]          13,572
      BatchNorm2d-47            [-1, 116, 8, 8]             232
             ReLU-48            [-1, 116, 8, 8]               0
           Conv2d-49          [-1, 116, 16, 16]          13,572
      BatchNorm2d-50          [-1, 116, 16, 16]             232
             ReLU-51          [-1, 116, 16, 16]               0
           Conv2d-52            [-1, 116, 8, 8]           1,160
      BatchNorm2d-53            [-1, 116, 8, 8]             232
           Conv2d-54            [-1, 116, 8, 8]          13,572
      BatchNorm2d-55            [-1, 116, 8, 8]             232
             ReLU-56            [-1, 116, 8, 8]               0
      ShuffleUnit-57            [-1, 232, 8, 8]               0
           Conv2d-58            [-1, 116, 8, 8]          13,572
      BatchNorm2d-59            [-1, 116, 8, 8]             232
             ReLU-60            [-1, 116, 8, 8]               0
           Conv2d-61            [-1, 116, 8, 8]           1,160
      BatchNorm2d-62            [-1, 116, 8, 8]             232
           Conv2d-63            [-1, 116, 8, 8]          13,572
      BatchNorm2d-64            [-1, 116, 8, 8]             232
             ReLU-65            [-1, 116, 8, 8]               0
      ShuffleUnit-66            [-1, 232, 8, 8]               0
           Conv2d-67            [-1, 116, 8, 8]          13,572
      BatchNorm2d-68            [-1, 116, 8, 8]             232
             ReLU-69            [-1, 116, 8, 8]               0
           Conv2d-70            [-1, 116, 8, 8]           1,160
      BatchNorm2d-71            [-1, 116, 8, 8]             232
           Conv2d-72            [-1, 116, 8, 8]          13,572
      BatchNorm2d-73            [-1, 116, 8, 8]             232
             ReLU-74            [-1, 116, 8, 8]               0
      ShuffleUnit-75            [-1, 232, 8, 8]               0
           Conv2d-76            [-1, 116, 8, 8]          13,572
      BatchNorm2d-77            [-1, 116, 8, 8]             232
             ReLU-78            [-1, 116, 8, 8]               0
           Conv2d-79            [-1, 116, 8, 8]           1,160
      BatchNorm2d-80            [-1, 116, 8, 8]             232
           Conv2d-81            [-1, 116, 8, 8]          13,572
      BatchNorm2d-82            [-1, 116, 8, 8]             232
             ReLU-83            [-1, 116, 8, 8]               0
      ShuffleUnit-84            [-1, 232, 8, 8]               0
           Conv2d-85            [-1, 116, 8, 8]          13,572
      BatchNorm2d-86            [-1, 116, 8, 8]             232
             ReLU-87            [-1, 116, 8, 8]               0
           Conv2d-88            [-1, 116, 8, 8]           1,160
      BatchNorm2d-89            [-1, 116, 8, 8]             232
           Conv2d-90            [-1, 116, 8, 8]          13,572
      BatchNorm2d-91            [-1, 116, 8, 8]             232
             ReLU-92            [-1, 116, 8, 8]               0
      ShuffleUnit-93            [-1, 232, 8, 8]               0
           Conv2d-94            [-1, 116, 8, 8]          13,572
      BatchNorm2d-95            [-1, 116, 8, 8]             232
             ReLU-96            [-1, 116, 8, 8]               0
           Conv2d-97            [-1, 116, 8, 8]           1,160
      BatchNorm2d-98            [-1, 116, 8, 8]             232
           Conv2d-99            [-1, 116, 8, 8]          13,572
     BatchNorm2d-100            [-1, 116, 8, 8]             232
            ReLU-101            [-1, 116, 8, 8]               0
     ShuffleUnit-102            [-1, 232, 8, 8]               0
          Conv2d-103            [-1, 116, 8, 8]          13,572
     BatchNorm2d-104            [-1, 116, 8, 8]             232
            ReLU-105            [-1, 116, 8, 8]               0
          Conv2d-106            [-1, 116, 8, 8]           1,160
     BatchNorm2d-107            [-1, 116, 8, 8]             232
          Conv2d-108            [-1, 116, 8, 8]          13,572
     BatchNorm2d-109            [-1, 116, 8, 8]             232
            ReLU-110            [-1, 116, 8, 8]               0
     ShuffleUnit-111            [-1, 232, 8, 8]               0
          Conv2d-112            [-1, 116, 8, 8]          13,572
     BatchNorm2d-113            [-1, 116, 8, 8]             232
            ReLU-114            [-1, 116, 8, 8]               0
          Conv2d-115            [-1, 116, 8, 8]           1,160
     BatchNorm2d-116            [-1, 116, 8, 8]             232
          Conv2d-117            [-1, 116, 8, 8]          13,572
     BatchNorm2d-118            [-1, 116, 8, 8]             232
            ReLU-119            [-1, 116, 8, 8]               0
     ShuffleUnit-120            [-1, 232, 8, 8]               0
          Conv2d-121            [-1, 232, 4, 4]           2,320
     BatchNorm2d-122            [-1, 232, 4, 4]             464
          Conv2d-123            [-1, 232, 4, 4]          54,056
     BatchNorm2d-124            [-1, 232, 4, 4]             464
            ReLU-125            [-1, 232, 4, 4]               0
          Conv2d-126            [-1, 232, 8, 8]          54,056
     BatchNorm2d-127            [-1, 232, 8, 8]             464
            ReLU-128            [-1, 232, 8, 8]               0
          Conv2d-129            [-1, 232, 4, 4]           2,320
     BatchNorm2d-130            [-1, 232, 4, 4]             464
          Conv2d-131            [-1, 232, 4, 4]          54,056
     BatchNorm2d-132            [-1, 232, 4, 4]             464
            ReLU-133            [-1, 232, 4, 4]               0
     ShuffleUnit-134            [-1, 464, 4, 4]               0
          Conv2d-135            [-1, 232, 4, 4]          54,056
     BatchNorm2d-136            [-1, 232, 4, 4]             464
            ReLU-137            [-1, 232, 4, 4]               0
          Conv2d-138            [-1, 232, 4, 4]           2,320
     BatchNorm2d-139            [-1, 232, 4, 4]             464
          Conv2d-140            [-1, 232, 4, 4]          54,056
     BatchNorm2d-141            [-1, 232, 4, 4]             464
            ReLU-142            [-1, 232, 4, 4]               0
     ShuffleUnit-143            [-1, 464, 4, 4]               0
          Conv2d-144            [-1, 232, 4, 4]          54,056
     BatchNorm2d-145            [-1, 232, 4, 4]             464
            ReLU-146            [-1, 232, 4, 4]               0
          Conv2d-147            [-1, 232, 4, 4]           2,320
     BatchNorm2d-148            [-1, 232, 4, 4]             464
          Conv2d-149            [-1, 232, 4, 4]          54,056
     BatchNorm2d-150            [-1, 232, 4, 4]             464
            ReLU-151            [-1, 232, 4, 4]               0
     ShuffleUnit-152            [-1, 464, 4, 4]               0
          Conv2d-153            [-1, 232, 4, 4]          54,056
     BatchNorm2d-154            [-1, 232, 4, 4]             464
            ReLU-155            [-1, 232, 4, 4]               0
          Conv2d-156            [-1, 232, 4, 4]           2,320
     BatchNorm2d-157            [-1, 232, 4, 4]             464
          Conv2d-158            [-1, 232, 4, 4]          54,056
     BatchNorm2d-159            [-1, 232, 4, 4]             464
            ReLU-160            [-1, 232, 4, 4]               0
     ShuffleUnit-161            [-1, 464, 4, 4]               0
          Conv2d-162           [-1, 1024, 4, 4]         476,160
     BatchNorm2d-163           [-1, 1024, 4, 4]           2,048
            ReLU-164           [-1, 1024, 4, 4]               0
          Linear-165                  [-1, 100]         102,500
================================================================
Total params: 1,360,896
Trainable params: 1,360,896
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 12.66
Params size (MB): 5.19
Estimated Total Size (MB): 17.86
----------------------------------------------------------------
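
The totals reported by torchsummary can be cross-checked directly against the registered parameters:

net = shufflenetv2()
print(sum(p.numel() for p in net.parameters()))  # 1360896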

 

Lightning Module

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR, CyclicLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision
 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
#from pl_bolts.datamodules import CIFAR10DataModule
#from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);
batch_size = 50
 
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar100_normalization(),
])
 
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar100_normalization(),
])
 
cifar100_dm = CIFAR100DataModule(
    batch_size=batch_size,
    num_workers=8,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
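
Before training, the datamodule can be smoke-tested on a single batch (a minimal sketch; the first run downloads the dataset):

cifar100_dm.prepare_data()
cifar100_dm.setup()
x, y = next(iter(cifar100_dm.train_dataloader()))
print(x.shape, y.shape)  # torch.Size([50, 3, 32, 32]) torch.Size([50])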
class LitCifar100(pl.LightningModule):
    def __init__(self, lr=0.05, factor=0.8):
        super().__init__()
  
        self.save_hyperparameters()
        self.model = shufflenetv2()
 
    def forward(self, x):
        out = self.model(x)
        return F.log_softmax(out, dim=1)
  
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = F.log_softmax(self.model(x), dim=1)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss
  
    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
  
        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)
  
    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')
  
    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')
  
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
 
        return {
          'optimizer': optimizer,
          'lr_scheduler': ReduceLROnPlateau(optimizer, 'max', patience=5, factor=self.hparams.factor, verbose=True, threshold=0.0001, threshold_mode='abs', cooldown=1, min_lr=1e-5),
          'monitor': 'val_acc'
        }
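
A one-batch forward pass confirms the module returns log-probabilities over the 100 classes (a minimal sketch):

m = LitCifar100()
out = m(torch.randn(2, 3, 32, 32))
print(out.shape)             # torch.Size([2, 100])
print(out.exp().sum(dim=1))  # rows of log_softmax output exponentiate to ~1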

 

Training / Evaluation

%%time
 
model = LitCifar100(lr=0.05, factor=0.5)
model.datamodule = cifar100_dm
  
trainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    progress_bar_refresh_rate=100,
    logger=pl.loggers.TensorBoardLogger('tblogs/', name='shufflenetv2'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
  
trainer.fit(model, cifar100_dm)
trainer.test(model, datamodule=cifar100_dm);
  | Name  | Type         | Params
---------------------------------------
0 | model | ShuffleNetV2 | 1.4 M 
---------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.444     Total estimated model params size (MB)
(...)
Epoch    32: reducing learning rate of group 0 to 2.5000e-02.
Epoch    39: reducing learning rate of group 0 to 1.2500e-02.
Epoch    47: reducing learning rate of group 0 to 6.2500e-03.
Epoch    56: reducing learning rate of group 0 to 3.1250e-03.
Epoch    66: reducing learning rate of group 0 to 1.5625e-03.
Epoch    78: reducing learning rate of group 0 to 7.8125e-04.
Epoch    85: reducing learning rate of group 0 to 3.9063e-04.
Epoch    95: reducing learning rate of group 0 to 1.9531e-04.
(...)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.6672999858856201, 'test_loss': 1.4519164562225342}
--------------------------------------------------------------------------------
CPU times: user 1h 21min 24s, sys: 1min 37s, total: 1h 23min 1s
Wall time: 1h 26min 15s
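
The loss/accuracy curves and the learning-rate trace recorded by TensorBoardLogger and LearningRateMonitor can be browsed afterwards with the standard tensorboard --logdir tblogs/ command.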
 

That's all.