PyTorch Ignite 0.4.8 : Examples : MNIST (解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/02/2022 (0.4.8)
* 本ページは、PyTorch Ignite の以下のサンプルを適宜書き換えた上で解説を加えたものです :
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
PyTorch Ignite 0.4.8 : Examples : MNIST
このサンプルは MNIST を題材にして Ignite の基本的な使い方を説明しています。
インポート
以下をインポートします。
ignite の他、torch, torchvision 更には tqdm をインポートします :
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm
from ignite.engine import create_supervised_evaluator, create_supervised_trainer, Events
from ignite.metrics import Accuracy, Loss
from ignite.utils import setup_logger
ハイパーパラメータ
ここでは optimizer として SGD momentum を利用します。
log_interval はイベント ITERATION_COMPLETED のハンドラの設定として 10 バッチ毎を指定するために使用します。
train_batch_size = 64
val_batch_size = 1000
epochs = 10
lr = 0.01
momentum = 0.5
log_interval = 10
MNIST データセットとデータローダ
訓練と検証データセットを torch.utils.data.DataLoader として定義してそれぞれ train_loader と val_loader にストアします。MNIST データセットのダウンロードには torchvision.datasets を利用しています :
def get_data_loaders(train_batch_size, val_batch_size):
data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_loader = DataLoader(
MNIST(download=True, root=".", transform=data_transform, train=True), batch_size=train_batch_size, shuffle=True
)
val_loader = DataLoader(
MNIST(download=False, root=".", transform=data_transform, train=False), batch_size=val_batch_size, shuffle=False
)
return train_loader, val_loader
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
デバイス
利用可能な場合には device を cuda に、そうでない場合には cpu を設定して高速化します。
device = "cpu"
if torch.cuda.is_available():
device = "cuda"
モデル
CNN (畳み込みニューラルネットワーク) モデルのクラスを定義します :
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x, dim=-1)
モデルをインスタンス化してデバイスに転送します :
model = Net()
model.to(device) # optimizer を作成する前にモデルを転送します。
Optimizer と損失関数
optimizer と損失関数を指定します :
optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
criterion = nn.NLLLoss()
そしてプロジェクトの重要なパートのセットアップは終了です。PyTorch-Ignite は以下で見るように総ての他のボイラープレートを処理します。
エンジンの作成
次に trainer と evaluator エンジンを定義します。このサンプルではヘルパー・メソッド create_supervised_trainer() と create_supervised_evaluator() を使用してそれぞれをインスタンス化しています。
オブジェクト trainer と evaluator は Engine のインスタンスで Ignite の主要コンポーネントです。Engine は訓練/検証ループに渡る抽象化です。
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
trainer.logger = setup_logger("trainer")
create_supervised_trainer にモデル、optimizer と損失関数を渡すことにより trainer エンジンを、そして create_supervised_evaluator にモデルと out-of-the-box な メトリックス を渡すことにより evaluator エンジンを定義しなければなりません。
val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)
evaluator.logger = setup_logger("evaluator")
evaluator を作成するヘルパー関数 create_supervised_evaluator は引数 metrics を受け取ることに注意してください。
ここでは 2 つのメトリクスを定義しています : 検証データセット上で計算するための accuracy (精度) と loss (損失) です。メトリクスの詳細は ignite.metrics で見つかります。
※ 訓練と検証ループに渡りより多くの制御を必要とする場合、Engine のステップロジックをラップすることにより カスタム trainer と evaluator オブジェクトを作成することもできます。
イベントハンドラの追加
コードスニペットの最も興味深いパートはイベントハンドラの追加です。あらゆる種類のイベントハンドラを追加することによりコードを更にカスタマイズできます。Engine は実行の間にトリガーされる様々なイベント上のハンドラを追加装着することが可能です。イベントがトリガーされたとき、装着されたハンドラ (関数) が実行されます。
以下はロギング目的で総ての log_interval -th 反復毎に最後に実行される関数を追加しています :
pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=f"ITERATION - loss: {0:.2f}")
trainer エンジンへのハンドラの追加
ITERATION_COMPLETED (バッチ終了時)
(log_interval 毎の) 損失 :
@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
pbar.desc = f"ITERATION - loss: {engine.state.output:.2f}"
pbar.update(log_interval)
EPOCH_COMPLETED (エポック終了時)
エポックが終了するとき訓練と検証メトリクスを計算したいです。その目的で train_loader と val_loader 上で前に定義した evaluator を実行できます。そのため epoch complete イベント上で trainer に 3 つの追加のハンドラを装着できます :
訓練メトリクス :
@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
pbar.refresh()
evaluator.run(train_loader)
metrics = evaluator.state.metrics
avg_accuracy = metrics["accuracy"]
avg_nll = metrics["nll"]
tqdm.write(
f"Training Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
)
検証メトリクス :
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
evaluator.run(val_loader)
metrics = evaluator.state.metrics
avg_accuracy = metrics["accuracy"]
avg_nll = metrics["nll"]
tqdm.write(
f"Validation Results - Epoch: {engine.state.epoch} Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
)
pbar.n = pbar.last_print_n = 0
EPOCH_COMPLETED (エポック終了時) | COMPLETED (訓練終了時)
実行時間 :
@trainer.on(Events.EPOCH_COMPLETED | Events.COMPLETED)
def log_time(engine):
tqdm.write(f"{trainer.last_event_name.name} took { trainer.state.times[trainer.last_event_name.name]} seconds")
訓練の実行
最後に、訓練データセット上でエンジンをスタートさせて 10 エポックの間それを実行します :
trainer.run(train_loader, max_epochs=epochs)
pbar.close()
2022-03-02 16:46:41,231 trainer INFO: Engine run starting with max_epochs=10. ITERATION - loss: 0.63: 99%|█████████▉| 930/938 [01:08<00:00, 89.66it/s]2022-03-02 16:46:52,553 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:47:01,313 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:47:01,315 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.63: 99%|█████████▉| 930/938 [01:17<00:00, 89.66it/s]2022-03-02 16:47:01,326 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 1 Avg accuracy: 0.94 Avg loss: 0.19 2022-03-02 16:47:02,647 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:47:02,648 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.63: 0%| | 0/938 [01:18<00:10, 89.66it/s]2022-03-02 16:47:02,669 trainer INFO: Epoch[1] Complete. Time taken: 00:00:21 ITERATION - loss: 0.26: 2%|▏ | 20/938 [01:18<03:25, 4.47it/s]Validation Results - Epoch: 1 Avg accuracy: 0.95 Avg loss: 0.18 EPOCH_COMPLETED took 11.310486555099487 seconds ITERATION - loss: 0.17: 940it [01:28, 93.41it/s]2022-03-02 16:47:12,944 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:47:21,974 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:47:21,975 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.17: 940it [01:37, 93.41it/s]2022-03-02 16:47:21,988 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 2 Avg accuracy: 0.96 Avg loss: 0.12 2022-03-02 16:47:23,261 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:47:23,262 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.17: 0%| | 0/938 [01:39<00:10, 93.41it/s]2022-03-02 16:47:23,288 trainer INFO: Epoch[2] Complete. Time taken: 00:00:21 ITERATION - loss: 0.27: 2%|▏ | 20/938 [01:39<03:29, 4.38it/s]Validation Results - Epoch: 2 Avg accuracy: 0.96 Avg loss: 0.11 EPOCH_COMPLETED took 10.273233652114868 seconds ITERATION - loss: 0.38: 940it [01:49, 89.07it/s]2022-03-02 16:47:33,696 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:47:42,290 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:47:42,292 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.38: 940it [01:58, 89.07it/s]2022-03-02 16:47:42,306 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 3 Avg accuracy: 0.97 Avg loss: 0.10 2022-03-02 16:47:43,587 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:47:43,588 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.38: 0%| | 0/938 [01:59<00:10, 89.07it/s]2022-03-02 16:47:43,613 trainer INFO: Epoch[3] Complete. Time taken: 00:00:20 ITERATION - loss: 0.25: 2%|▏ | 20/938 [01:59<03:21, 4.55it/s]Validation Results - Epoch: 3 Avg accuracy: 0.97 Avg loss: 0.09 EPOCH_COMPLETED took 10.404369115829468 seconds ITERATION - loss: 0.19: 940it [02:09, 91.47it/s]2022-03-02 16:47:54,072 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:48:02,718 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:48:02,719 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.19: 940it [02:18, 91.47it/s]2022-03-02 16:48:02,729 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 4 Avg accuracy: 0.97 Avg loss: 0.09 2022-03-02 16:48:04,009 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:48:04,011 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.19: 0%| | 0/938 [02:19<00:10, 91.47it/s]2022-03-02 16:48:04,032 trainer INFO: Epoch[4] Complete. Time taken: 00:00:20 ITERATION - loss: 0.26: 1%| | 10/938 [02:19<04:47, 3.22it/s]Validation Results - Epoch: 4 Avg accuracy: 0.98 Avg loss: 0.08 EPOCH_COMPLETED took 10.454880475997925 seconds ITERATION - loss: 0.31: 940it [02:30, 91.48it/s]2022-03-02 16:48:14,255 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:48:22,980 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:48:22,981 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.31: 940it [02:38, 91.48it/s]2022-03-02 16:48:22,994 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 5 Avg accuracy: 0.98 Avg loss: 0.07 2022-03-02 16:48:24,325 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:48:24,326 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.31: 0%| | 0/938 [02:40<00:10, 91.48it/s]2022-03-02 16:48:24,355 trainer INFO: Epoch[5] Complete. Time taken: 00:00:20 ITERATION - loss: 0.18: 1%| | 10/938 [02:40<04:51, 3.18it/s]Validation Results - Epoch: 5 Avg accuracy: 0.98 Avg loss: 0.07 EPOCH_COMPLETED took 10.218773365020752 seconds ITERATION - loss: 0.12: 99%|█████████▉| 930/938 [02:50<00:00, 91.54it/s]2022-03-02 16:48:34,791 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:48:43,298 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:48:43,299 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.12: 99%|█████████▉| 930/938 [02:59<00:00, 91.54it/s]2022-03-02 16:48:43,310 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 6 Avg accuracy: 0.98 Avg loss: 0.07 2022-03-02 16:48:44,563 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:48:44,564 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.12: 0%| | 0/938 [03:00<00:10, 91.54it/s]2022-03-02 16:48:44,584 trainer INFO: Epoch[6] Complete. Time taken: 00:00:20 ITERATION - loss: 0.17: 2%|▏ | 20/938 [03:00<03:19, 4.61it/s]Validation Results - Epoch: 6 Avg accuracy: 0.98 Avg loss: 0.06 EPOCH_COMPLETED took 10.432302474975586 seconds ITERATION - loss: 0.18: 940it [03:10, 90.58it/s]2022-03-02 16:48:55,125 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:49:04,211 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:49:04,212 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.18: 940it [03:20, 90.58it/s]2022-03-02 16:49:04,224 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 7 Avg accuracy: 0.98 Avg loss: 0.06 ITERATION - loss: 0.18: 940it [03:21, 90.58it/s]2022-03-02 16:49:05,553 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:49:05,554 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.18: 0%| | 0/938 [03:21<00:10, 90.58it/s]2022-03-02 16:49:05,581 trainer INFO: Epoch[7] Complete. Time taken: 00:00:21 ITERATION - loss: 0.18: 2%|▏ | 20/938 [03:21<03:32, 4.33it/s]Validation Results - Epoch: 7 Avg accuracy: 0.98 Avg loss: 0.06 EPOCH_COMPLETED took 10.538563013076782 seconds ITERATION - loss: 0.32: 940it [03:32, 84.84it/s]2022-03-02 16:49:16,426 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:49:25,375 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:49:25,376 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.32: 940it [03:41, 84.84it/s]2022-03-02 16:49:25,391 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 8 Avg accuracy: 0.98 Avg loss: 0.06 2022-03-02 16:49:26,754 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:49:26,755 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.32: 0%| | 0/938 [03:42<00:11, 84.84it/s]2022-03-02 16:49:26,778 trainer INFO: Epoch[8] Complete. Time taken: 00:00:21 ITERATION - loss: 0.16: 1%| | 10/938 [03:42<04:59, 3.10it/s]Validation Results - Epoch: 8 Avg accuracy: 0.98 Avg loss: 0.06 EPOCH_COMPLETED took 10.840869188308716 seconds ITERATION - loss: 0.14: 940it [03:53, 88.55it/s]2022-03-02 16:49:37,510 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:49:46,327 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:49:46,328 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.14: 940it [04:02, 88.55it/s]2022-03-02 16:49:46,340 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 9 Avg accuracy: 0.98 Avg loss: 0.06 2022-03-02 16:49:47,699 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:49:47,701 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.14: 0%| | 0/938 [04:03<00:10, 88.55it/s]2022-03-02 16:49:47,727 trainer INFO: Epoch[9] Complete. Time taken: 00:00:21 ITERATION - loss: 0.25: 1%| | 10/938 [04:03<04:55, 3.14it/s]Validation Results - Epoch: 9 Avg accuracy: 0.98 Avg loss: 0.05 EPOCH_COMPLETED took 10.728539943695068 seconds ITERATION - loss: 0.23: 940it [04:14, 91.41it/s]2022-03-02 16:49:58,335 evaluator INFO: Engine run starting with max_epochs=1. 2022-03-02 16:50:07,454 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:09 2022-03-02 16:50:07,455 evaluator INFO: Engine run complete. Time taken: 00:00:09 ITERATION - loss: 0.23: 940it [04:23, 91.41it/s]2022-03-02 16:50:07,465 evaluator INFO: Engine run starting with max_epochs=1. Training Results - Epoch: 10 Avg accuracy: 0.98 Avg loss: 0.05 2022-03-02 16:50:08,820 evaluator INFO: Epoch[1] Complete. Time taken: 00:00:01 2022-03-02 16:50:08,822 evaluator INFO: Engine run complete. Time taken: 00:00:01 ITERATION - loss: 0.23: 0%| | 0/938 [04:24<00:10, 91.41it/s]2022-03-02 16:50:08,838 trainer INFO: Epoch[10] Complete. Time taken: 00:00:21 ITERATION - loss: 0.23: 0%| | 0/938 [04:24<00:10, 91.41it/s]2022-03-02 16:50:08,848 trainer INFO: Engine run complete. Time taken: 00:03:28 Validation Results - Epoch: 10 Avg accuracy: 0.99 Avg loss: 0.05 EPOCH_COMPLETED took 10.59480333328247 seconds COMPLETED took 207.60126090049744 seconds CPU times: user 3min 25s, sys: 1.49 s, total: 3min 26s Wall time: 3min 27s
以上