PyTorch 0.4.1 examples (code walkthrough): Image Classification – CIFAR-10 (Simple Network)
Translation: ClassCat Co., Ltd. Sales Information
Created: 08/03/2018 (0.4.1)
* This page draws on sample code from the following pytorch/examples and keras/examples repositories on GitHub:
- https://github.com/pytorch/examples/tree/master/mnist
- https://github.com/keras-team/keras/tree/master/examples
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.
Simple Network
“Simple Network” is the name of the model.
Model definition
import torch.nn as nn
import torch.nn.functional as F

num_classes = 10  # CIFAR-10 has 10 classes

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Block 1: 3 -> 64 -> 128 -> 128 -> 128 channels
        self.conv11 = nn.Conv2d(3, 64, 3, stride=1, padding=1)
        self.bn11 = nn.BatchNorm2d(64)
        self.conv12 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.bn12 = nn.BatchNorm2d(128)
        self.conv13 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn13 = nn.BatchNorm2d(128)
        self.conv14 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn14 = nn.BatchNorm2d(128)
        self.dropout1 = nn.Dropout2d(0.1)
        # Block 2: 128 -> 128 -> 128 -> 256 channels
        self.conv21 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn21 = nn.BatchNorm2d(128)
        self.conv22 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.bn22 = nn.BatchNorm2d(128)
        self.conv23 = nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.bn23 = nn.BatchNorm2d(256)
        self.dropout2 = nn.Dropout2d(0.1)
        # Block 3: 256 -> 256 -> 256 channels
        self.conv31 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.bn31 = nn.BatchNorm2d(256)
        self.conv32 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.bn32 = nn.BatchNorm2d(256)
        self.dropout3 = nn.Dropout2d(0.1)
        # Block 4: 256 -> 512 channels
        self.conv41 = nn.Conv2d(256, 512, 3, stride=1, padding=1)
        self.bn41 = nn.BatchNorm2d(512)
        self.dropout4 = nn.Dropout2d(0.1)
        # Block 5: 1x1 convolutions, 512 -> 2048 -> 256 channels
        self.conv51 = nn.Conv2d(512, 2048, 1, stride=1, padding=0)
        self.bn51 = nn.BatchNorm2d(2048)
        self.conv52 = nn.Conv2d(2048, 256, 1, stride=1, padding=0)
        self.bn52 = nn.BatchNorm2d(256)
        self.dropout5 = nn.Dropout2d(0.1)
        # Block 6: 256 -> 256 channels, followed by global max pooling
        self.conv61 = nn.Conv2d(256, 256, 3, stride=1, padding=1)
        self.bn61 = nn.BatchNorm2d(256)
        self.dropout6 = nn.Dropout2d(0.1)
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x):
        # Block 1: 32x32 -> 16x16
        x = self.conv11(x)
        x = F.relu(self.bn11(x))
        x = self.conv12(x)
        x = F.relu(self.bn12(x))
        x = self.conv13(x)
        x = F.relu(self.bn13(x))
        x = self.conv14(x)
        x = F.relu(self.bn14(x))
        x = F.max_pool2d(x, 2, stride=2)
        x = self.dropout1(x)
        # Block 2: 16x16 -> 8x8
        x = self.conv21(x)
        x = F.relu(self.bn21(x))
        x = self.conv22(x)
        x = F.relu(self.bn22(x))
        x = self.conv23(x)
        x = F.relu(self.bn23(x))
        x = F.max_pool2d(x, 2, stride=2)
        x = self.dropout2(x)
        # Block 3: 8x8 -> 4x4
        x = self.conv31(x)
        x = F.relu(self.bn31(x))
        x = self.conv32(x)
        x = F.relu(self.bn32(x))
        x = F.max_pool2d(x, 2, stride=2)
        x = self.dropout3(x)
        # Block 4: 4x4 -> 2x2
        x = self.conv41(x)
        x = F.relu(self.bn41(x))
        x = F.max_pool2d(x, 2, stride=2)
        x = self.dropout4(x)
        # Block 5: 2x2 -> 1x1
        x = self.conv51(x)
        x = F.relu(self.bn51(x))
        x = self.conv52(x)
        x = F.relu(self.bn52(x))
        x = F.max_pool2d(x, 2, stride=2)
        x = self.dropout5(x)
        # Block 6
        x = self.conv61(x)
        x = F.relu(self.bn61(x))
        # Global Max Pooling
        x = F.max_pool2d(x, kernel_size=x.size()[2:])
        x = self.dropout6(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
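As a quick sanity check (not part of the original article), the model can be instantiated and run on a dummy CIFAR-10-sized batch; the batch size of 4 is illustrative only, and the expected output shape assumes num_classes = 10 as defined above:

import torch

model = Net()
model.eval()                        # disable dropout and use BatchNorm running statistics
dummy = torch.randn(4, 3, 32, 32)   # random batch of 4 RGB images at CIFAR-10 resolution
with torch.no_grad():
    logits = model(dummy)
print(logits.shape)                 # torch.Size([4, 10]) with num_classes = 10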
Training
Output during training:
--- Epoch : 1 ---
lr : 0.001500
Epoch [1/100], Step [100/500], Loss: 1.613496
Epoch [1/100], Step [200/500], Loss: 1.379638
Epoch [1/100], Step [300/500], Loss: 1.345135
Epoch [1/100], Step [400/500], Loss: 1.097351
Epoch [1/100], Step [500/500], Loss: 1.076568
Val Acc : 62.2600
>> Max Val Acc : 62.2600 (index: 1)

...

--- Epoch : 75 ---
lr : 0.000717
Epoch [75/100], Step [100/500], Loss: 0.091090
Epoch [75/100], Step [200/500], Loss: 0.001435
Epoch [75/100], Step [300/500], Loss: 0.001089
Epoch [75/100], Step [400/500], Loss: 0.002576
Epoch [75/100], Step [500/500], Loss: 0.000330
Val Acc : 88.7900
>> Max Val Acc : 88.7900 (index: 75)

...

--- Epoch : 100 ---
lr : 0.000581
Epoch [100/100], Step [100/500], Loss: 0.000132
Epoch [100/100], Step [200/500], Loss: 0.000755
Epoch [100/100], Step [300/500], Loss: 0.000058
Epoch [100/100], Step [400/500], Loss: 0.000587
Epoch [100/100], Step [500/500], Loss: 0.001706
Val Acc : 88.4000
>> Max Val Acc : 88.7900 (index: 75)
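The training loop itself is not reproduced in the article. The following is a minimal sketch of a setup that would produce logs of this shape; the torchvision CIFAR-10 loaders, batch size 100 (500 steps per epoch), Adam with an initial learning rate of 0.0015, the normalization values, and the exponential learning-rate decay are all assumptions, not the author's confirmed code:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Assumed preprocessing; the original normalization values are not shown.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=100, shuffle=True)   # 500 steps per epoch
test_loader = torch.utils.data.DataLoader(test_set, batch_size=100, shuffle=False)

model = Net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0015)                # initial lr matches the log above
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)  # assumed decay schedule

num_epochs = 100
best_acc, best_epoch = 0.0, 0
for epoch in range(num_epochs):
    print('--- Epoch : %d ---' % (epoch + 1))
    print('lr : %f' % optimizer.param_groups[0]['lr'])
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) % 100 == 0:
            print('Epoch [%d/%d], Step [%d/%d], Loss: %f'
                  % (epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))
    scheduler.step()

    # Validation accuracy on the CIFAR-10 test split.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100.0 * correct / total
    if acc > best_acc:
        best_acc, best_epoch = acc, epoch + 1
    print('Val Acc : %.4f' % acc)
    print('>> Max Val Acc : %.4f (index: %d)' % (best_acc, best_epoch))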
The loss graph comes down cleanly:
A simple ConvNet gave 78.37 % accuracy and Network in Network 87.76 %, but Simple Network reaches 88.79 %:
Explicit initialization
Here are the results when training with the convolution layer weights explicitly initialized (a sketch of how the initializers can be applied follows the accuracy list below).
The accuracies:
- No initialization – 88.79 %
- init.kaiming_uniform_ – 88.47 %
- init.orthogonal_ – 89.01 %
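The article does not show how the initializers were applied. One common pattern, given here as a sketch rather than the author's confirmed code, is to walk the modules with model.apply and reset every convolution weight before training:

import torch.nn as nn
import torch.nn.init as init

def init_weights_kaiming(m):
    # Kaiming uniform initialization for every convolution layer.
    if isinstance(m, nn.Conv2d):
        init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)

def init_weights_orthogonal(m):
    # Orthogonal initialization of the convolution weights instead.
    if isinstance(m, nn.Conv2d):
        init.orthogonal_(m.weight)
        if m.bias is not None:
            init.constant_(m.bias, 0)

model = Net()
model.apply(init_weights_orthogonal)   # or model.apply(init_weights_kaiming)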
That's all.