PyTorch Ignite 0.4.8 : Tutorials : Fashion-MNIST の分類のための CNN (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/03/2022 (0.4.8)
* 本ページは、Pytorch Ignite の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Tutorials : Fashion-MNIST の分類のための CNN
これはニューラルネットワーク・モデルを訓練し、実験を設定してモデルを検証するために Ignite を使用するチュートリアルです。
このノートブックでは、畳み込みニューラルネットワークを使用して画像の分類を行なっていきます。
Fashion-MNIST データセット を使用していきます、Fashion-MNIST は衣服の 28×28 グレースケール画像のセットです。
Lets get started!
必要な依存性
torch と ignite は既にインストールされていることを仮定しています。pip を使用してそれをインストールできます :
!pip install pytorch-ignite
ライブラリのインポート
numpy, matplotlib と seaborn のような一般的なデータサイエンス・ライブラリ
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
モデルを作成するために torch, nn と functional モジュールをインポートします。
データセットをロードしてデータセットの画像に変換を適用するために torchvision から datasets と transforms もインポートします。
データをモデルにロードする訓練と検証ローダを作成するために Dataloader をインポートします。
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
Ignite は PyTorch でニューラルネットワークを訓練するのに役立つ高位ライブラリです。それは訓練ループ, 様々なメトリクス, ハンドラと有用な contrib セクションをセットアップするためのエンジンを装備しています!
下で、以下をインポートします :
- Engine : データセットの各バッチに対して与えられた process_function を実行し、それと共にイベントを発生させます。
- Events: 特定のイベントで関数を起動するようにエンジンに関数を装着することをユーザに可能にします。Eg: EPOCH_COMPLETED, ITERATION_STARTED, etc.
- Accuracy : 二値, マルチクラス, マルチラベル 等のためのデータセットに対して精度を計算するメトリック。
- Loss : パラメータとして損失関数を受け取る一般的なメトリックで、データセットに対する損失を計算します。
- RunningAverage : 訓練の間にエンジンに装着する一般的なメトリック。
- ModelCheckpoint : モデルをチェックポイントするためのハンドラ。
- EarlyStopping : スコア関数に基づいて訓練を停止するハンドラ。
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, RunningAverage, ConfusionMatrix
from ignite.handlers import ModelCheckpoint, EarlyStopping
下のコードは最初に画像を pytorch テンソルに変換して画像を正規化するために torhvision transfroms を使用して transform をセットアップします。
次に、fashion mnist データセットをダウンロードして、上で定義した transforms を適用するために torchvision datasets を使用します。
- trainset は訓練データを含みます。
- validationset は検証データを含みます。
次に、訓練と検証セットからデータローダを作成するために pytorch dataloader を使用します。
# transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('./data', download=True, train=True, transform=transform)
train_loader = DataLoader(trainset, batch_size=64, shuffle=True)
# Download and load the test data
validationset = datasets.FashionMNIST('./data', download=True, train=False, transform=transform)
val_loader = DataLoader(validationset, batch_size=64, shuffle=True)
CNN モデル
モデルアーキテクチャの説明
- Convolutional layers, 畳み込み層は出力のテンソルを生成するために層入力と共に畳み込まれる畳み込みカーネルを作成するために使用されます。
- Maxpooling layers, Maxpooling 層は前の層からの最もアクティブなピクセルを保持する入力表現をダウンサンプリングするために使用されます。
- 通常の Linear + Dropout 層は過剰適合を回避して 10-次元出力を生成します。
- モデルのために Relu 非線形を、NLLLOSS 損失 を使用していきますので最後の層で logsoftmax を使用しました。
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.convlayer1 = nn.Sequential(
nn.Conv2d(1, 32, 3,padding=1),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.convlayer2 = nn.Sequential(
nn.Conv2d(32,64,3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2)
)
self.fc1 = nn.Linear(64*6*6,600)
self.drop = nn.Dropout2d(0.25)
self.fc2 = nn.Linear(600, 120)
self.fc3 = nn.Linear(120, 10)
def forward(self, x):
x = self.convlayer1(x)
x = self.convlayer2(x)
x = x.view(-1,64*6*6)
x = self.fc1(x)
x = self.drop(x)
x = self.fc2(x)
x = self.fc3(x)
return F.log_softmax(x,dim=1)
モデル, Optimizer と損失の作成
下で CNN モデルのインスタンスを作成します。モデルはデバイスに置かれて負の対数尤度損失 (negative log likelihood loss) の損失関数と 0.001 の学習率の Adam optimizer がセットアップされます。
# creating model,and defining optimizer and loss
model = CNN()
# moving model to gpu if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.NLLLoss()
Ignite を使用した訓練と評価
訓練と評価エンジンのインスタンス化
下で create_supervised_trainer と create_supervised_evaluator を使用して必要な引数を渡すことにより 3 つのエンジン – trainer、訓練セットに対する evaluator と検証セットに対する evaluator を作成します。
ignite.metrics からメトリクスをインポートします、これをモデルに対して計算することを望みます。Accuracy, ConfusionMatrix と Loss のように、それらを evaluator エンジンに渡します、各イテレーションのためにこれらのメトリクスを計算します。
- training_history: 訓練損失と精度をストアします。
- validation_history: 検証損失と精度をストアします。
- last_epoch: モデルが訓練されるまで最後のエポックをストアします。
# defining the number of epochs
epochs = 12
# creating trainer,evaluator
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
metrics = {
'accuracy':Accuracy(),
'nll':Loss(criterion),
'cm':ConfusionMatrix(num_classes=10)
}
train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
training_history = {'accuracy':[],'loss':[]}
validation_history = {'accuracy':[],'loss':[]}
last_epoch = []
メトリクス – RunningAverage
最初に、各バッチに対するスカラー損失の移動平均を追跡するために RunningAverage のメトリックを装着します。
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
EarlyStopping – 検証損失の追跡
次にこの訓練プロセスに対して EarlyStopping ハンドラをセットアップします。EarlyStopping は、訓練を停止する基準が何であれユーザが定義することを可能にする score_function を必要とします。この場合、検証セットの損失が 10 エポック内に減少しなければ、訓練プロセスは早期に停止します。EarlyStopping ハンドラは検証損失に依存しますので、val_evaluator に装着されています。
def score_function(engine):
val_loss = engine.state.metrics['nll']
return -val_loss
handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
val_evaluator.add_event_handler(Events.COMPLETED, handler)
カスタム関数を特定のイベントでエンジンに装着
以下では、独自のカスタム関数を定義してそれらを訓練プロセスの様々な Event に装着する方法を見ます。
下の関数は両者とも同様のタスクを実現し、それらはデータセットで動作する evaluator の結果をプリントします。一つの関数は訓練用 evaluator とデータセット上でそれを行ない、他方は検証用で行ないます。もう一つの違いはこれらの関数が trainer エンジンで装着される方法です。
最初の方法はデコレータを使用し、そのシンタクスは単純です – @ trainer.on(Events.EPOCH_COMPLETED)、これは修飾された関数がトレーナーに装着されて各エポックで呼び出されることを意味します。
2 つ目の方法はトレーナーの add_event_handler を使用します – trainer.add_event_handler(Events.EPOCH_COMPLETED, custom_function)。これは上と同じ結果を得ます。
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
train_evaluator.run(train_loader)
metrics = train_evaluator.state.metrics
accuracy = metrics['accuracy']*100
loss = metrics['nll']
last_epoch.append(0)
training_history['accuracy'].append(accuracy)
training_history['loss'].append(loss)
print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(trainer.state.epoch, accuracy, loss))
def log_validation_results(trainer):
val_evaluator.run(val_loader)
metrics = val_evaluator.state.metrics
accuracy = metrics['accuracy']*100
loss = metrics['nll']
validation_history['accuracy'].append(accuracy)
validation_history['loss'].append(loss)
print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(trainer.state.epoch, accuracy, loss))
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)
混同行列
混同行列は、分類モデルが何を正しく理解してどのようなタイプのエラーを起こすかについて良い考えを与えてくれます。
seaborn ライブラリからの seaborn.heatmap を使用して混同行列を可視化します。
@trainer.on(Events.COMPLETED)
def log_confusion_matrix(trainer):
val_evaluator.run(val_loader)
metrics = val_evaluator.state.metrics
cm = metrics['cm']
cm = cm.numpy()
cm = cm.astype(int)
classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
fig, ax = plt.subplots(figsize=(10,10))
ax= plt.subplot()
sns.heatmap(cm, annot=True, ax = ax,fmt="d")
# labels, title and ticks
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
ax.xaxis.set_ticklabels(classes,rotation=90)
ax.yaxis.set_ticklabels(classes,rotation=0)
モデルチェックポイント
最後に、このモデルをチェックポイントすることを望みます。それを行なうことは重要です、訓練プロセスは時間がかかる可能性があり、訓練の間に何かの理由で問題が発生した場合、失敗したポイントから訓練を再開するためにモデルチェックポイントは有用であり得るからです。
下では各エポックの最後にモデルをチェックポイントするために Ignite の ModelCheckpoint ハンドラを使用しています。
checkpointer = ModelCheckpoint('./saved_models', 'fashionMNIST', n_saved=2, create_dir=True, save_as_state_dict=True, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'fashionMNIST': model})
エンジンの実行
次に、12 エポックの間トレーナーを実行して結果をモニタします。上で定義されたカスタム関数がエポック毎に損失と精度をプリントするのに役立つことが下で分かります。
trainer.run(train_loader, max_epochs=epochs)
Training Results - Epoch: 1 Avg accuracy: 90.05 Avg loss: 0.27 Validation Results - Epoch: 1 Avg accuracy: 88.68 Avg loss: 0.32 Training Results - Epoch: 2 Avg accuracy: 90.74 Avg loss: 0.26 Validation Results - Epoch: 2 Avg accuracy: 89.11 Avg loss: 0.31 Training Results - Epoch: 3 Avg accuracy: 92.62 Avg loss: 0.21 Validation Results - Epoch: 3 Avg accuracy: 90.48 Avg loss: 0.27 Training Results - Epoch: 4 Avg accuracy: 93.07 Avg loss: 0.19 Validation Results - Epoch: 4 Avg accuracy: 90.56 Avg loss: 0.27 Training Results - Epoch: 5 Avg accuracy: 93.72 Avg loss: 0.18 Validation Results - Epoch: 5 Avg accuracy: 90.81 Avg loss: 0.26 Training Results - Epoch: 6 Avg accuracy: 94.07 Avg loss: 0.17 Validation Results - Epoch: 6 Avg accuracy: 90.85 Avg loss: 0.26 Training Results - Epoch: 7 Avg accuracy: 94.21 Avg loss: 0.15 Validation Results - Epoch: 7 Avg accuracy: 90.50 Avg loss: 0.28 Training Results - Epoch: 8 Avg accuracy: 94.94 Avg loss: 0.14 Validation Results - Epoch: 8 Avg accuracy: 90.78 Avg loss: 0.28 Training Results - Epoch: 9 Avg accuracy: 94.19 Avg loss: 0.16 Validation Results - Epoch: 9 Avg accuracy: 90.20 Avg loss: 0.30 Training Results - Epoch: 10 Avg accuracy: 95.80 Avg loss: 0.12 Validation Results - Epoch: 10 Avg accuracy: 91.37 Avg loss: 0.29 Training Results - Epoch: 11 Avg accuracy: 96.08 Avg loss: 0.11 Validation Results - Epoch: 11 Avg accuracy: 90.64 Avg loss: 0.29 Training Results - Epoch: 12 Avg accuracy: 96.77 Avg loss: 0.09 Validation Results - Epoch: 12 Avg accuracy: 90.98 Avg loss: 0.29 State: iteration: 11256 epoch: 12 epoch_length: 938 max_epochs: 12 output: 0.0867251381278038 batch: <class 'list'> metrics: <class 'dict'> dataloader: <class 'torch.utils.data.dataloader.DataLoader'> seed: <class 'NoneType'> times: <class 'dict'>
損失と精度のプロット
次に、損失と精度をプロットします、これらは各エポックで損失と精度がどのように変化していくかを見るために training_history と validation_history dictionary 辞書にストアされています。
plt.plot(training_history['accuracy'],label="Training Accuracy")
plt.plot(validation_history['accuracy'],label="Validation Accuracy")
plt.xlabel('No. of Epochs')
plt.ylabel('Accuracy')
plt.legend(frameon=False)
plt.show()
plt.plot(training_history['loss'],label="Training Loss")
plt.plot(validation_history['loss'],label="Validation Loss")
plt.xlabel('No. of Epochs')
plt.ylabel('Loss')
plt.legend(frameon=False)
plt.show()
ディスクからセーブされたモデルをロードする
推論のためにディスクからセーブされた pytorch モデルをロードします。
# loading the saved model
def fetch_last_checkpoint_model_filename(model_save_path):
import os
from pathlib import Path
checkpoint_files = os.listdir(model_save_path)
checkpoint_files = [f for f in checkpoint_files if '.pt' in f]
checkpoint_iter = [
int(x.split('_')[2].split('.')[0])
for x in checkpoint_files]
last_idx = np.array(checkpoint_iter).argmax()
return Path(model_save_path) / checkpoint_files[last_idx]
model.load_state_dict(torch.load(fetch_last_checkpoint_model_filename('./saved_models')))
print("Model Loaded")
Model Loaded
モデルの推論
下のコードはモデルからの推論と結果の可視化のために使用されます。
ここでは val_loader からイテレーションを行なって最高確率を持つクラスを選択してからそれを実際のクラスと比較します。
# classes of fashion mnist dataset
classes = ['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle Boot']
# creating iterator for iterating the dataset
dataiter = iter(val_loader)
images, labels = dataiter.next()
images_arr = []
labels_arr = []
pred_arr = []
# moving model to cpu for inference
model.to("cpu")
# iterating on the dataset to predict the output
for i in range(0,10):
images_arr.append(images[i].unsqueeze(0))
labels_arr.append(labels[i].item())
ps = torch.exp(model(images_arr[i]))
ps = ps.data.numpy().squeeze()
pred_arr.append(np.argmax(ps))
# plotting the results
fig = plt.figure(figsize=(25,4))
for i in range(10):
ax = fig.add_subplot(2, 20/2, i+1, xticks=[], yticks=[])
ax.imshow(images_arr[i].resize_(1, 28, 28).numpy().squeeze())
ax.set_title("{} ({})".format(classes[pred_arr[i]], classes[labels_arr[i]]),
color=("green" if pred_arr[i]==labels_arr[i] else "red"))
Refrences
以上