PyTorch Lightning 1.1 : research: CIFAR10 (ShuffleNet)

PyTorch Lightning 1.1: research : CIFAR10 (ShuffleNet)
作成 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/23/2021 (1.1.x)

* 本ページは、以下のリソースを参考にして遂行した実験結果のレポートです:

* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

research: CIFAR10 (ShuffleNet)

仕様

  • Total params: 352,042 (352K)
  • Trainable params: 352,042
  • Non-trainable params: 0

 

結果

  • ShuffleNetV2
  • {'test_acc': 0.8831999897956848, 'test_loss': 0.3897647559642792}
  • 100 エポック ; Wall time: 1h 6min 42s
  • ‘Tesla M60’ x 2
  • ReduceLROnPlateau

 

コード

import torch
import torch.nn as nn
import torch.nn.functional as F


class ShuffleBlock(nn.Module):
    """Channel-shuffle layer from ShuffleNet.

    Mixes information across channel groups by reshaping
    [N, C, H, W] -> [N, g, C/g, H, W], swapping the group and
    per-group axes, and flattening back to [N, C, H, W].
    """

    def __init__(self, groups=2):
        super(ShuffleBlock, self).__init__()
        self.groups = groups

    def forward(self, x):
        batch, channels, height, width = x.shape
        per_group = channels // self.groups
        shuffled = x.view(batch, self.groups, per_group, height, width)
        shuffled = shuffled.transpose(1, 2).contiguous()
        return shuffled.view(batch, channels, height, width)


class SplitBlock(nn.Module):
    """Splits the input along the channel axis.

    The first ``int(C * ratio)`` channels go to the first output,
    the remainder to the second.
    """

    def __init__(self, ratio):
        super(SplitBlock, self).__init__()
        self.ratio = ratio

    def forward(self, x):
        split_at = int(x.size(1) * self.ratio)
        left = x[:, :split_at]
        right = x[:, split_at:]
        return left, right


class BasicBlock(nn.Module):
    """ShuffleNetV2 basic unit (stride 1).

    Splits the channels into two halves, passes one half through a
    1x1 -> depthwise 3x3 -> 1x1 bottleneck, concatenates it with the
    untouched half, and channel-shuffles the result. Preserves the
    spatial size and channel count of the input.
    """

    def __init__(self, in_channels, split_ratio=0.5):
        super(BasicBlock, self).__init__()
        self.split = SplitBlock(split_ratio)
        branch_channels = int(in_channels * split_ratio)
        self.conv1 = nn.Conv2d(branch_channels, branch_channels,
                               kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(branch_channels)
        # depthwise 3x3 (groups == channels)
        self.conv2 = nn.Conv2d(branch_channels, branch_channels,
                               kernel_size=3, stride=1, padding=1,
                               groups=branch_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(branch_channels)
        self.conv3 = nn.Conv2d(branch_channels, branch_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(branch_channels)
        self.shuffle = ShuffleBlock()

    def forward(self, x):
        passthrough, branch = self.split(x)
        branch = F.relu(self.bn1(self.conv1(branch)))
        branch = self.bn2(self.conv2(branch))  # no ReLU after depthwise conv
        branch = F.relu(self.bn3(self.conv3(branch)))
        merged = torch.cat([passthrough, branch], 1)
        return self.shuffle(merged)


class DownBlock(nn.Module):
    """ShuffleNetV2 downsampling unit (stride 2).

    Both branches see the full input: the left branch is a strided
    depthwise 3x3 followed by a 1x1, the right branch is a
    1x1 -> strided depthwise 3x3 -> 1x1 bottleneck. Their outputs
    (``out_channels // 2`` each) are concatenated and shuffled, halving
    the spatial size and producing ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        mid_channels = out_channels // 2
        # left branch: strided depthwise 3x3, then pointwise 1x1
        self.conv1 = nn.Conv2d(in_channels, in_channels,
                               kernel_size=3, stride=2, padding=1,
                               groups=in_channels, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv2 = nn.Conv2d(in_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        # right branch: 1x1, strided depthwise 3x3, 1x1
        self.conv3 = nn.Conv2d(in_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(mid_channels)
        self.conv4 = nn.Conv2d(mid_channels, mid_channels,
                               kernel_size=3, stride=2, padding=1,
                               groups=mid_channels, bias=False)
        self.bn4 = nn.BatchNorm2d(mid_channels)
        self.conv5 = nn.Conv2d(mid_channels, mid_channels,
                               kernel_size=1, bias=False)
        self.bn5 = nn.BatchNorm2d(mid_channels)

        self.shuffle = ShuffleBlock()

    def forward(self, x):
        left = self.bn1(self.conv1(x))  # no ReLU after depthwise conv
        left = F.relu(self.bn2(self.conv2(left)))

        right = F.relu(self.bn3(self.conv3(x)))
        right = self.bn4(self.conv4(right))
        right = F.relu(self.bn5(self.conv5(right)))

        merged = torch.cat([left, right], 1)
        return self.shuffle(merged)


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 for CIFAR-10 (10 classes, 32x32 inputs).

    ``net_size`` selects a width multiplier from the module-level
    ``configs`` dict (0.5, 1, 1.5 or 2). The stem uses stride 1 and no
    max-pool (unlike the ImageNet variant) so 32x32 inputs end up as
    4x4 feature maps before the global average pool.
    """

    def __init__(self, net_size):
        super(ShuffleNetV2, self).__init__()
        out_channels = configs[net_size]['out_channels']
        num_blocks = configs[net_size]['num_blocks']

        self.conv1 = nn.Conv2d(3, 24, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)
        self.in_channels = 24
        self.layer1 = self._make_layer(out_channels[0], num_blocks[0])
        self.layer2 = self._make_layer(out_channels[1], num_blocks[1])
        self.layer3 = self._make_layer(out_channels[2], num_blocks[2])
        self.conv2 = nn.Conv2d(out_channels[2], out_channels[3],
                               kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels[3])
        self.linear = nn.Linear(out_channels[3], 10)

    def _make_layer(self, out_channels, num_blocks):
        """One stage: a stride-2 DownBlock followed by ``num_blocks`` BasicBlocks."""
        stage = [DownBlock(self.in_channels, out_channels)]
        for _ in range(num_blocks):
            stage.append(BasicBlock(out_channels))
            self.in_channels = out_channels  # track width for the next stage
        return nn.Sequential(*stage)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer3(self.layer2(self.layer1(out)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = F.avg_pool2d(out, 4)            # global average pool (4x4 -> 1x1)
        out = out.view(out.size(0), -1)
        return self.linear(out)


# ShuffleNetV2 width-multiplier configurations, keyed by net_size
# (0.5x, 1x, 1.5x, 2x). 'out_channels' gives the widths of the three
# stages plus the final 1x1 conv; 'num_blocks' is the number of
# BasicBlocks following the DownBlock in each stage.
# NOTE(review): these match the kuangliu/pytorch-cifar tables; the 2x row
# (224, ...) differs from the paper's 244 — confirm if matching the paper matters.
configs = {
    0.5: {
        'out_channels': (48, 96, 192, 1024),
        'num_blocks': (3, 7, 3)
    },

    1: {
        'out_channels': (116, 232, 464, 1024),
        'num_blocks': (3, 7, 3)
    },
    1.5: {
        'out_channels': (176, 352, 704, 1024),
        'num_blocks': (3, 7, 3)
    },
    2: {
        'out_channels': (224, 488, 976, 2048),
        'num_blocks': (3, 7, 3)
    }
}
# Smoke test: build the 0.5x model and push a dummy CIFAR-10 sized batch through.
net = ShuffleNetV2(net_size=0.5)
print(net)
x = torch.randn(3, 3, 32, 32)  # (batch=3, RGB, 32x32)
y = net(x)
print(y.shape)  # expected: torch.Size([3, 10])
ShuffleNetV2(
  (conv1): Conv2d(3, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): DownBlock(
      (conv1): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv4): Conv2d(24, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=24, bias=False)
      (bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv5): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (1): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False)
      (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (2): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False)
      (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (3): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False)
      (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
  )
  (layer2): Sequential(
    (0): DownBlock(
      (conv1): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv4): Conv2d(48, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=48, bias=False)
      (bn4): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv5): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn5): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (1): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (2): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (3): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (4): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (5): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (6): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (7): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(48, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=48, bias=False)
      (bn2): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
  )
  (layer3): Sequential(
    (0): DownBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv4): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
      (bn4): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv5): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn5): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (1): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (2): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
    (3): BasicBlock(
      (split): SplitBlock()
      (conv1): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=96, bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(96, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shuffle): ShuffleBlock()
    )
  )
  (conv2): Conv2d(192, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear): Linear(in_features=1024, out_features=10, bias=True)
)
torch.Size([3, 10])
from torchsummary import summary

summary(ShuffleNetV2(net_size=0.5).to('cuda:0'), (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 24, 32, 32]             648
       BatchNorm2d-2           [-1, 24, 32, 32]              48
            Conv2d-3           [-1, 24, 16, 16]             216
       BatchNorm2d-4           [-1, 24, 16, 16]              48
            Conv2d-5           [-1, 24, 16, 16]             576
       BatchNorm2d-6           [-1, 24, 16, 16]              48
            Conv2d-7           [-1, 24, 32, 32]             576
       BatchNorm2d-8           [-1, 24, 32, 32]              48
            Conv2d-9           [-1, 24, 16, 16]             216
      BatchNorm2d-10           [-1, 24, 16, 16]              48
           Conv2d-11           [-1, 24, 16, 16]             576
      BatchNorm2d-12           [-1, 24, 16, 16]              48
     ShuffleBlock-13           [-1, 48, 16, 16]               0
        DownBlock-14           [-1, 48, 16, 16]               0
       SplitBlock-15  [[-1, 24, 16, 16], [-1, 24, 16, 16]]               0
           Conv2d-16           [-1, 24, 16, 16]             576
      BatchNorm2d-17           [-1, 24, 16, 16]              48
           Conv2d-18           [-1, 24, 16, 16]             216
      BatchNorm2d-19           [-1, 24, 16, 16]              48
           Conv2d-20           [-1, 24, 16, 16]             576
      BatchNorm2d-21           [-1, 24, 16, 16]              48
     ShuffleBlock-22           [-1, 48, 16, 16]               0
       BasicBlock-23           [-1, 48, 16, 16]               0
       SplitBlock-24  [[-1, 24, 16, 16], [-1, 24, 16, 16]]               0
           Conv2d-25           [-1, 24, 16, 16]             576
      BatchNorm2d-26           [-1, 24, 16, 16]              48
           Conv2d-27           [-1, 24, 16, 16]             216
      BatchNorm2d-28           [-1, 24, 16, 16]              48
           Conv2d-29           [-1, 24, 16, 16]             576
      BatchNorm2d-30           [-1, 24, 16, 16]              48
     ShuffleBlock-31           [-1, 48, 16, 16]               0
       BasicBlock-32           [-1, 48, 16, 16]               0
       SplitBlock-33  [[-1, 24, 16, 16], [-1, 24, 16, 16]]               0
           Conv2d-34           [-1, 24, 16, 16]             576
      BatchNorm2d-35           [-1, 24, 16, 16]              48
           Conv2d-36           [-1, 24, 16, 16]             216
      BatchNorm2d-37           [-1, 24, 16, 16]              48
           Conv2d-38           [-1, 24, 16, 16]             576
      BatchNorm2d-39           [-1, 24, 16, 16]              48
     ShuffleBlock-40           [-1, 48, 16, 16]               0
       BasicBlock-41           [-1, 48, 16, 16]               0
           Conv2d-42             [-1, 48, 8, 8]             432
      BatchNorm2d-43             [-1, 48, 8, 8]              96
           Conv2d-44             [-1, 48, 8, 8]           2,304
      BatchNorm2d-45             [-1, 48, 8, 8]              96
           Conv2d-46           [-1, 48, 16, 16]           2,304
      BatchNorm2d-47           [-1, 48, 16, 16]              96
           Conv2d-48             [-1, 48, 8, 8]             432
      BatchNorm2d-49             [-1, 48, 8, 8]              96
           Conv2d-50             [-1, 48, 8, 8]           2,304
      BatchNorm2d-51             [-1, 48, 8, 8]              96
     ShuffleBlock-52             [-1, 96, 8, 8]               0
        DownBlock-53             [-1, 96, 8, 8]               0
       SplitBlock-54  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
           Conv2d-55             [-1, 48, 8, 8]           2,304
      BatchNorm2d-56             [-1, 48, 8, 8]              96
           Conv2d-57             [-1, 48, 8, 8]             432
      BatchNorm2d-58             [-1, 48, 8, 8]              96
           Conv2d-59             [-1, 48, 8, 8]           2,304
      BatchNorm2d-60             [-1, 48, 8, 8]              96
     ShuffleBlock-61             [-1, 96, 8, 8]               0
       BasicBlock-62             [-1, 96, 8, 8]               0
       SplitBlock-63  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
           Conv2d-64             [-1, 48, 8, 8]           2,304
      BatchNorm2d-65             [-1, 48, 8, 8]              96
           Conv2d-66             [-1, 48, 8, 8]             432
      BatchNorm2d-67             [-1, 48, 8, 8]              96
           Conv2d-68             [-1, 48, 8, 8]           2,304
      BatchNorm2d-69             [-1, 48, 8, 8]              96
     ShuffleBlock-70             [-1, 96, 8, 8]               0
       BasicBlock-71             [-1, 96, 8, 8]               0
       SplitBlock-72  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
           Conv2d-73             [-1, 48, 8, 8]           2,304
      BatchNorm2d-74             [-1, 48, 8, 8]              96
           Conv2d-75             [-1, 48, 8, 8]             432
      BatchNorm2d-76             [-1, 48, 8, 8]              96
           Conv2d-77             [-1, 48, 8, 8]           2,304
      BatchNorm2d-78             [-1, 48, 8, 8]              96
     ShuffleBlock-79             [-1, 96, 8, 8]               0
       BasicBlock-80             [-1, 96, 8, 8]               0
       SplitBlock-81  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
           Conv2d-82             [-1, 48, 8, 8]           2,304
      BatchNorm2d-83             [-1, 48, 8, 8]              96
           Conv2d-84             [-1, 48, 8, 8]             432
      BatchNorm2d-85             [-1, 48, 8, 8]              96
           Conv2d-86             [-1, 48, 8, 8]           2,304
      BatchNorm2d-87             [-1, 48, 8, 8]              96
     ShuffleBlock-88             [-1, 96, 8, 8]               0
       BasicBlock-89             [-1, 96, 8, 8]               0
       SplitBlock-90  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
           Conv2d-91             [-1, 48, 8, 8]           2,304
      BatchNorm2d-92             [-1, 48, 8, 8]              96
           Conv2d-93             [-1, 48, 8, 8]             432
      BatchNorm2d-94             [-1, 48, 8, 8]              96
           Conv2d-95             [-1, 48, 8, 8]           2,304
      BatchNorm2d-96             [-1, 48, 8, 8]              96
     ShuffleBlock-97             [-1, 96, 8, 8]               0
       BasicBlock-98             [-1, 96, 8, 8]               0
       SplitBlock-99  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
          Conv2d-100             [-1, 48, 8, 8]           2,304
     BatchNorm2d-101             [-1, 48, 8, 8]              96
          Conv2d-102             [-1, 48, 8, 8]             432
     BatchNorm2d-103             [-1, 48, 8, 8]              96
          Conv2d-104             [-1, 48, 8, 8]           2,304
     BatchNorm2d-105             [-1, 48, 8, 8]              96
    ShuffleBlock-106             [-1, 96, 8, 8]               0
      BasicBlock-107             [-1, 96, 8, 8]               0
      SplitBlock-108  [[-1, 48, 8, 8], [-1, 48, 8, 8]]               0
          Conv2d-109             [-1, 48, 8, 8]           2,304
     BatchNorm2d-110             [-1, 48, 8, 8]              96
          Conv2d-111             [-1, 48, 8, 8]             432
     BatchNorm2d-112             [-1, 48, 8, 8]              96
          Conv2d-113             [-1, 48, 8, 8]           2,304
     BatchNorm2d-114             [-1, 48, 8, 8]              96
    ShuffleBlock-115             [-1, 96, 8, 8]               0
      BasicBlock-116             [-1, 96, 8, 8]               0
          Conv2d-117             [-1, 96, 4, 4]             864
     BatchNorm2d-118             [-1, 96, 4, 4]             192
          Conv2d-119             [-1, 96, 4, 4]           9,216
     BatchNorm2d-120             [-1, 96, 4, 4]             192
          Conv2d-121             [-1, 96, 8, 8]           9,216
     BatchNorm2d-122             [-1, 96, 8, 8]             192
          Conv2d-123             [-1, 96, 4, 4]             864
     BatchNorm2d-124             [-1, 96, 4, 4]             192
          Conv2d-125             [-1, 96, 4, 4]           9,216
     BatchNorm2d-126             [-1, 96, 4, 4]             192
    ShuffleBlock-127            [-1, 192, 4, 4]               0
       DownBlock-128            [-1, 192, 4, 4]               0
      SplitBlock-129  [[-1, 96, 4, 4], [-1, 96, 4, 4]]               0
          Conv2d-130             [-1, 96, 4, 4]           9,216
     BatchNorm2d-131             [-1, 96, 4, 4]             192
          Conv2d-132             [-1, 96, 4, 4]             864
     BatchNorm2d-133             [-1, 96, 4, 4]             192
          Conv2d-134             [-1, 96, 4, 4]           9,216
     BatchNorm2d-135             [-1, 96, 4, 4]             192
    ShuffleBlock-136            [-1, 192, 4, 4]               0
      BasicBlock-137            [-1, 192, 4, 4]               0
      SplitBlock-138  [[-1, 96, 4, 4], [-1, 96, 4, 4]]               0
          Conv2d-139             [-1, 96, 4, 4]           9,216
     BatchNorm2d-140             [-1, 96, 4, 4]             192
          Conv2d-141             [-1, 96, 4, 4]             864
     BatchNorm2d-142             [-1, 96, 4, 4]             192
          Conv2d-143             [-1, 96, 4, 4]           9,216
     BatchNorm2d-144             [-1, 96, 4, 4]             192
    ShuffleBlock-145            [-1, 192, 4, 4]               0
      BasicBlock-146            [-1, 192, 4, 4]               0
      SplitBlock-147  [[-1, 96, 4, 4], [-1, 96, 4, 4]]               0
          Conv2d-148             [-1, 96, 4, 4]           9,216
     BatchNorm2d-149             [-1, 96, 4, 4]             192
          Conv2d-150             [-1, 96, 4, 4]             864
     BatchNorm2d-151             [-1, 96, 4, 4]             192
          Conv2d-152             [-1, 96, 4, 4]           9,216
     BatchNorm2d-153             [-1, 96, 4, 4]             192
    ShuffleBlock-154            [-1, 192, 4, 4]               0
      BasicBlock-155            [-1, 192, 4, 4]               0
          Conv2d-156           [-1, 1024, 4, 4]         196,608
     BatchNorm2d-157           [-1, 1024, 4, 4]           2,048
          Linear-158                   [-1, 10]          10,250
================================================================
Total params: 352,042
Trainable params: 352,042
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 1416.34
Params size (MB): 1.34
Estimated Total Size (MB): 1417.69
----------------------------------------------------------------

 

ReduceLROnPlateau スケジューラ

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR, CyclicLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision
 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
# Fix all RNG seeds (python/numpy/torch) for reproducibility.
pl.seed_everything(7);
batch_size = 100
 
# Training augmentation: pad-and-crop + horizontal flip, then normalize
# with the CIFAR-10 per-channel mean/std.
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
 
# Val/test: normalization only (no augmentation).
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
 
# Lightning DataModule that downloads CIFAR-10 and builds the dataloaders.
cifar10_dm = CIFAR10DataModule(
    batch_size=batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
class LitCifar10(pl.LightningModule):
    """LightningModule wrapping ShuffleNetV2 (0.5x) for CIFAR-10.

    Args:
        optim: 'adam' selects Adam; any other value selects SGD with
            momentum 0.9 and weight decay 5e-4.
        lr: initial learning rate.
        factor: multiplicative LR decay factor for ReduceLROnPlateau.
    """

    def __init__(self, optim, lr=0.05, factor=0.8):
        super().__init__()

        # Stores optim/lr/factor on self.hparams (used in configure_optimizers).
        self.save_hyperparameters()
        self.model = ShuffleNetV2(net_size=0.5)

    def forward(self, x):
        # Return log-probabilities so F.nll_loss can be applied directly.
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Reuse forward() so train/val/test all share the same log-softmax head
        # (the original inlined the same computation here).
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Shared val/test logic; logs '<stage>_loss' and '<stage>_acc'."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        optim = self.hparams.optim
        if optim == 'adam':
            optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr, weight_decay=0, eps=1e-3)
        else:
            optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)

        # Plateau scheduler in 'max' mode: decays the LR by `factor` after
        # 5 epochs without val_acc improving by more than 1e-4 (absolute).
        return {
          'optimizer': optimizer,
          'lr_scheduler': ReduceLROnPlateau(optimizer, 'max', patience=5, factor=self.hparams.factor, verbose=True, threshold=0.0001, threshold_mode='abs', cooldown=1, min_lr=1e-5),
          'monitor': 'val_acc'
        }
%%time
# ^ IPython cell magic: reports CPU and wall time for the whole training run.

model = LitCifar10(optim='sgd', lr=0.05, factor=0.5)
model.datamodule = cifar10_dm
 
trainer = pl.Trainer(
    gpus=2,
    #num_nodes=1, 
    accelerator='dp',
    max_epochs=100,
    progress_bar_refresh_rate=100,
    logger=pl.loggers.TensorBoardLogger('tblogs/', name='shufflenet2'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
 
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
Files already downloaded and verified
Files already downloaded and verified

  | Name  | Type         | Params
---------------------------------------
0 | model | ShuffleNetV2 | 352 K 
---------------------------------------
352 K     Trainable params
0         Non-trainable params
352 K     Total params
1.408     Total estimated model params size (MB)
(...)
Epoch    59: reducing learning rate of group 0 to 2.5000e-02.
Epoch    81: reducing learning rate of group 0 to 1.2500e-02.
Epoch    91: reducing learning rate of group 0 to 6.2500e-03.
(...)
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.8831999897956848, 'test_loss': 0.3897647559642792}
--------------------------------------------------------------------------------
CPU times: user 1h 28min 43s, sys: 4min 40s, total: 1h 33min 23s
Wall time: 1h 6min 42s
[{'test_loss': 0.3897647559642792, 'test_acc': 0.8831999897956848}]
 

以上