PyTorch 0.4.1 examples (コード解説) : 画像分類 – MNIST (Network in Network)

PyTorch 0.4.1 examples (コード解説) : 画像分類 – MNIST (Network in Network)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 07/29/2018 (0.4.1)

* 本ページは、github 上の以下の pytorch/examples と keras/examples レポジトリのサンプル・コードを参考にしています:

* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 
 

Network In Network

マイクロ・ネットワークで知られる Network in Network (NiN) はより小さいモデルでより高速に訓練することを可能とし、overfitting も比較的起きにくいとされています。2014 年に公開されており今となっては古いのですが、後のモデルにも影響を与えたモデルです。

基本的なアイデアは MLP 畳込み層 (mlpconv) – マイクロ・ネットワーク – の導入とグローバル平均プーリングによる完全結合層の代用です。特に前者については広く利用されて類似のアイデアは Inception や ResNet にも見られます。後者は特徴マップの解釈を容易にして overfitting を起きにくくするそうです。

Network In Network 全体の構成としては mlpconv 層のスタックになり、そのトップにグローバル平均プーリングが続くだけのシンプルなモデルです。

 
◆ Network In Network の原論文は以下です :

  • Network In Network
    Min Lin, Qiang Chen, Shuicheng Yan
    (Submitted on 16 Dec 2013 (v1), last revised 4 Mar 2014 (this version, v3))

Abstract のみ翻訳しておきますが、基本的にはここに書かれていることが全てです :

受容野内のローカルパッチのモデル識別性を強化するために “Network in Network” (NIN) と呼ばれる、新しい深層ネットワーク構造を提案します。従来の畳込み層は入力をスキャンするために、非線形活性化関数が続く線形フィルタを使用します。代わりに、受容野内のデータを抽象化するためにより複雑な構造でマイクロ・ニューラルネットワークを構築します。そのマイクロ・ニューラルネットワークは潜在的な関数近似器である多層パーセプトロンでインスタンス化します。特徴マップは CNN と同様の流儀で入力に渡ってマイクロ・ネットワークをスライドさせることで得られます; そしてそれらは次の層に供給されます。深層 NIN は上で記述された構造を多層化することで実装されます。マイクロ・ネットワークを通したローカルモデリングの強化により、分類層の特徴マップに渡るグローバル平均プーリングを利用することが可能で、これは伝統的な完全結合層よりも解釈がより容易で過剰適合しにくいです。私たちは CIFAR-10 と CIFAR-100 上の NIN で state-of-the-art な分類性能を示し、SVHN と MNIST データセットで合理的な性能を示しました。

 
◆ 実装は以下の Caffe 実装が参考になります :

 

Network in Network for MNIST

PyTorch 0.4.x の自作のサンプルをコードの簡単な解説とともに提供しています。
初級チュートリアル程度の知識は仮定しています。

先に MNIST 画像分類タスクのために MLP / ConvNet モデルを実装しましたが、
同じタスクに対して Network in Network モデルを実装してみます。

 

モデル定義

基本的には上掲の Caffe 実装を参考にしてそのまま PyTorch 実装にしただけですが、バッチ正規化層を追加しています :

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv11 = nn.Conv2d(1, 192, 5, stride=1, padding=2)
        self.conv12 = nn.Conv2d(192, 160, 1, stride=1, padding=0)
        self.conv13 = nn.Conv2d(160, 96, 1, stride=1, padding=0)

        self.bn1 = nn.BatchNorm2d(96)
        self.dropout1 = nn.Dropout2d(0.5)

        self.conv21 = nn.Conv2d(96, 192, 5, stride=1, padding=2)
        self.conv22 = nn.Conv2d(192, 192, 1, stride=1, padding=0)
        self.conv23 = nn.Conv2d(192, 192, 1, stride=1, padding=0)

        self.bn2 = nn.BatchNorm2d(192)
        self.dropout2 = nn.Dropout2d(0.5)

        self.conv31 = nn.Conv2d(192, 192, 3, stride=1, padding=1)
        self.conv32 = nn.Conv2d(192, 192, 1, stride=1, padding=0)
        self.conv33 = nn.Conv2d(192, num_classes, 1, stride=1, padding=0)

        self.bn3 = nn.BatchNorm2d(num_classes)


    def forward(self, x):
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.relu(self.conv13(x))

        x = F.max_pool2d(self.bn1(x), 3, stride=2, padding=1)

        x = self.dropout1(x)

        x = F.relu(self.conv21(x))
        x = F.relu(self.conv22(x))
        x = F.relu(self.conv23(x))
 
        x = F.avg_pool2d(self.bn2(x), 3, stride=2, padding=1)

        x = self.dropout2(x)

        x = F.relu(self.conv31(x))
        x = F.relu(self.conv32(x))
        x = F.relu(self.conv33(x))

        x = F.avg_pool2d(self.bn3(x), 7, stride=1, padding=0)
 
        return x.view(x.size(0), num_classes)

インスタンスを直接プリントすると、含まれる層の情報が得られます :

print(model)

Out:

Net(
  (conv11): Conv2d(1, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv12): Conv2d(192, 160, kernel_size=(1, 1), stride=(1, 1))
  (conv13): Conv2d(160, 96, kernel_size=(1, 1), stride=(1, 1))
  (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout2d(p=0.5)
  (conv21): Conv2d(96, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv22): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
  (conv23): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
  (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout2): Dropout2d(p=0.5)
  (conv31): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv32): Conv2d(192, 192, kernel_size=(1, 1), stride=(1, 1))
  (conv33): Conv2d(192, 10, kernel_size=(1, 1), stride=(1, 1))
  (bn3): BatchNorm2d(10, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

 

損失関数と optimizer

損失関数としては nn.CrossEntropyLoss() を使用します :

criterion = nn.CrossEntropyLoss()

optimizer は先に SGD を使用しましたので、ここでは Adam を使用してみましょう :

optimizer = optim.Adam(model.parameters(), lr=0.002)

ついでに学習率のスケジューラも試してみます (設定数は直感で決めています) :

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.90)

 

訓練コード (per epoch)

各エポックの訓練コードです。100 ステップ毎に損失値を表示してやります。
学習率スケジューラを使用していますので scheduler.step() を忘れずにエポック毎に呼び出します :

global_step = 0

def train(epoch, writer):
    model.train()
    scheduler.step()

    print("\n--- Epoch : %2d ---" % epoch)
    print("lr : %f" % optimizer.param_groups[0]['lr'])

    steps = len(ds_train)//batch_size
    for step, (images, labels) in enumerate(dataloader_train, 1):
        global global_step
        global_step += 1
        #print(labels.numpy().shape)
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print ('Epoch [%d/%d], Step [%d/%d], Loss: %.4f' % (epoch, epochs, step, steps, loss.item()))
            writer.add_scalar('train/train_loss', loss.item() , global_step)

TensorBoard を利用するためには、グローバルステップをカウントする global_step を追加して、SummaryWriter の .add_scalar() メソッドで損失を書き込みます。

 

評価コード (per epoch)

各エポックの最後にテスト・データセットで評価します。単純に正解数をカウントしています。
acc_trans は各エポックの精度をストアするために使用しています :

acc_trans = []

def eval(epoch, writer):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for (images, labels) in dataloader_test:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    val_acc = correct*100./total
    acc_trans.append(val_acc)
    print("Val Acc : %.4f" % val_acc)
    writer.add_scalar('eval/val_acc', val_acc, epoch)

TensorBoard を利用するために、(評価ルーチンはエポック毎ですので、) epoch カウントを利用して検証精度を書き込みます。

 

訓練の実行

次のループで訓練が実行できます :

from tensorboardX import SummaryWriter
writer = SummaryWriter()

for epoch in range(1, epochs+1):
    train(epoch, writer)
    eval(epoch, writer)

acc_trans = torch.Tensor(acc_trans)
print("* Max Val Acc : %.4f (index: %d)" % (torch.max(acc_trans), torch.argmax(acc_trans)))

writer.close()

tensorboardX モジュールを使用すれば、TensorBoard が利用できます。
SummaryWriter をインスタンス化して訓練/評価ルーチンに渡してやります。

モデルは pickle でセーブできます :

torch.save(model.state_dict(), 'model_mnist_nin.pkl')

実行時出力です :

--- Epoch :  1 ---
lr : 0.001500
Epoch [1/100], Step [100/600], Loss: 1.2303
Epoch [1/100], Step [200/600], Loss: 0.8310
Epoch [1/100], Step [300/600], Loss: 0.6912
Epoch [1/100], Step [400/600], Loss: 0.4581
Epoch [1/100], Step [500/600], Loss: 0.3829
Epoch [1/100], Step [600/600], Loss: 0.2541
Val Acc : 98.4400

...

--- Epoch : 87 ---
lr : 0.000646
Epoch [87/100], Step [100/600], Loss: 0.0005
Epoch [87/100], Step [200/600], Loss: 0.0006
Epoch [87/100], Step [300/600], Loss: 0.0006
Epoch [87/100], Step [400/600], Loss: 0.0002
Epoch [87/100], Step [500/600], Loss: 0.0209
Epoch [87/100], Step [600/600], Loss: 0.0002
Val Acc : 99.7200

...

--- Epoch : 100 ---
lr : 0.000581
Epoch [100/100], Step [100/600], Loss: 0.0001
Epoch [100/100], Step [200/600], Loss: 0.0003
Epoch [100/100], Step [300/600], Loss: 0.0007
Epoch [100/100], Step [400/600], Loss: 0.0005
Epoch [100/100], Step [500/600], Loss: 0.0002
Epoch [100/100], Step [600/600], Loss: 0.0001
Val Acc : 99.6400

精度 99.72 % (エラー率 0.28 %) が出ています。
MLP モデルでは精度 96.99 % (エラー率 3.01 %)、単純な ConvNet では精度 98.13 % (エラー率 1.87%) そして Xavier 一様分布による初期化を伴う同じ ConvNet で精度 99.24 % (エラー率 0.76%)でしたので、かなり改善されています。

補記: 明示的な初期化も試してみました。
ここまで試した結果の精度 (とエラー率) をまとめると以下のようになります :

  • MLP
    • 精度 96.99 % (エラー率 3.01 %)
  • ConvNet
    • 初期化なし : 精度 98.13 % (エラー率 1.87 %)
    • init.xavier_uniform_ : 精度 99.24 % (エラー率 0.76%)
    • init.kaiming_uniform_ : 精度99.35 % (エラー率 0.65 %)
    • init.orthogonal_ : 精度 99.27 % (エラー率 0.73 %)
  • Network in Network
    • 初期化なし : 精度 99.75 % (エラー率 0.28 %)
    • init.xavier_uniform_ : 精度 99.74 % (エラー率 0.26 %)
    • init.kaiming_uniform_ : 精度 99.73 % (エラー率 0.27 %)
    • init.orthogonal_ : 精度 99.74 % (エラー率 0.26 %)

※ (公開されている) 記録上は Network in Network の改良モデルでエラー率 0.24 % が達成されています。

 

以上