PyTorch Ignite 0.4.8 : ガイド : タイムプロファイリングの方法

PyTorch Ignite 0.4.8 : How To ガイド : タイムプロファイリングの方法 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/12/2022 (0.4.8)

* 本ページは、Pytorch Ignite AI の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

ガイド : タイムプロファイリングの方法

このサンプルは以下のための時間の分析を取得する方法を実演します :

  • 訓練の間の個々のエポック
  • 合計の訓練時間
  • 個々のイベント
  • イベントに対応する総てのハンドラ
  • 個々のハンドラ
  • データロードとデータ処理

このサンプルでは、MNIST データセット上で ResNet18 モデルを使用していきます。ベースコードは Getting Start ガイド で使用されものと同じです。

 

基本的なセットアップ

import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer, BasicTimeProfiler, HandlersTimeProfiler
torch.cuda.is_available()
True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)


model = Net().to(device)

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,
    shuffle=True,
)

val_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=False),
    batch_size=256,
    shuffle=False,
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()

総てのエポックの最後に訓練と検証データセットのメトリクス ( AccuracyLoss) を表示するために 2 つのハンドラを trainer に装着します。

trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
evaluator = create_supervised_evaluator(
    model, metrics={"accuracy": Accuracy(), "loss": Loss(criterion)}, device=device
)


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print(
        f"Training Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['accuracy']:.2f} Avg loss: {metrics['loss']:.2f}"
    )

 

イベントの状態の使用

総てのエポックにかかった時間と訓練のための合計時間をプリントすることを望むだけなら、単純に trainer の State を利用できます。trainer.state.times により返される時間をログ記録するため、エポックが完了した時と訓練が完了した時にトリガーされる 2 つの別個のハンドラを装着します。

@trainer.on(Events.EPOCH_COMPLETED)
def log_epoch_time():
    print(
        f"Epoch {trainer.state.epoch}, Time Taken : {trainer.state.times['EPOCH_COMPLETED']}"
    )


@trainer.on(Events.COMPLETED)
def log_total_time():
    print(f"Total Time: {trainer.state.times['COMPLETED']}")
trainer.run(train_loader, max_epochs=2)
Training Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.11
Validation Results - Epoch[1] Avg accuracy: 0.97 Avg loss: 0.10
Epoch 1, Time Taken : 31.281248569488525
Training Results - Epoch[2] Avg accuracy: 0.99 Avg loss: 0.05
Validation Results - Epoch[2] Avg accuracy: 0.98 Avg loss: 0.05
Epoch 2, Time Taken : 30.54600954055786
Total Time: 107.31757092475891

State:
	iteration: 938
	epoch: 2
	epoch_length: 469
	max_epochs: 2
	output: 0.013461492024362087
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

 

BasicTimeProfiler を使用したイベントベースのプロファイリング

データ処理, データロードと総ての事前定義されたイベントにかかる時間のような情報を更に望む場合、BasicTimeProfiler() を使用できます。

# Attach basic profiler
basic_profiler = BasicTimeProfiler()
basic_profiler.attach(trainer)

trainer.run(train_loader, max_epochs=2)
Training Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.04
Validation Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.04
Epoch 1, Time Taken : 30.6413791179657
Training Results - Epoch[2] Avg accuracy: 0.97 Avg loss: 0.10
Validation Results - Epoch[2] Avg accuracy: 0.97 Avg loss: 0.11
Epoch 2, Time Taken : 30.38310170173645
Total Time: 106.3881447315216

State:
	iteration: 938
	epoch: 2
	epoch_length: 469
	max_epochs: 2
	output: 0.0808301642537117
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

そして get_results() を通して結果辞書を得てそれを print_results() に渡してきれいにフォーマットされた結果を得ることができます、これはかかった時間の合計, 最少, 最大, 平均と標準偏差を含みます。

results = basic_profiler.get_results()
basic_profiler.print_results(results);
 ----------------------------------------------------
| Time profiling stats (in seconds):                 |
 ----------------------------------------------------
total  |  min/index  |  max/index  |  mean  |  std

Processing function:
28.62366 | 0.02439/937 | 0.05147/0 | 0.03052 | 0.00191

Dataflow:
32.23854 | 0.02618/936 | 0.15481/937 | 0.03437 | 0.00455

Event handlers:
45.38009

- Events.STARTED: []
0.00001

- Events.EPOCH_STARTED: []
0.00001 | 0.00000/0 | 0.00000/1 | 0.00000 | 0.00000

- Events.ITERATION_STARTED: []
0.00246 | 0.00000/351 | 0.00003/609 | 0.00000 | 0.00000

- Events.ITERATION_COMPLETED: []
0.00556 | 0.00000/12 | 0.00002/653 | 0.00001 | 0.00000

- Events.EPOCH_COMPLETED: ['log_training_results', 'log_validation_results', 'log_epoch_time']
45.36316 | 22.66037/1 | 22.70279/0 | 22.68158 | 0.02999

- Events.COMPLETED: ['log_total_time']
0.00004

Note : このアプローチは個々のハンドラによりかかった時間ではなく、事前定義されたイベントに対応する総てのハンドラによりかかった時間の合計を取得します。

 

HandlersTimeProfiler を使用したハンドラベースのプロファイリング

HandlersTimeProfiler を使用して上記の問題を解決できます、これは必要な情報だけを与えます。これを通して、前は可能ではなかった、Custom Event に装着されたハンドラにかかった時間を計算することもできます。

# Attach handlers profiler
handlers_profiler = HandlersTimeProfiler()
handlers_profiler.attach(trainer)
trainer.run(train_loader, max_epochs=2)
Training Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.02
Validation Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.03
Epoch 1, Time Taken : 30.685564279556274
Training Results - Epoch[2] Avg accuracy: 1.00 Avg loss: 0.01
Validation Results - Epoch[2] Avg accuracy: 0.99 Avg loss: 0.03
Epoch 2, Time Taken : 30.860342502593994
Total Time: 107.25911617279053

State:
	iteration: 938
	epoch: 2
	epoch_length: 469
	max_epochs: 2
	output: 0.005279005039483309
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

上記と同様にプロファイラの結果をプリントできます。出力は装着された書くハンドラのための実行時間の合計, 平均とその他の詳細を示します。それはまたデータ処理とデータロード時間も表示します。

results = handlers_profiler.get_results()
handlers_profiler.print_results(results)
---------------------------------------  -------------------  --------------  --------------  --------------  --------------  --------------  
Handler                                  Event Name                 Total(s)      Min(s)/IDX      Max(s)/IDX         Mean(s)          Std(s)  
---------------------------------------  -------------------  --------------  --------------  --------------  --------------  --------------  
log_training_results                     EPOCH_COMPLETED            39.35234      19.31905/0      20.03329/1        19.67617         0.50504  
log_validation_results                   EPOCH_COMPLETED             6.35954       3.16563/0       3.19391/1         3.17977            0.02  
log_epoch_time                           EPOCH_COMPLETED               7e-05         3e-05/1         3e-05/0           3e-05             0.0  
BasicTimeProfiler._as_first_started      STARTED                     0.00034       0.00034/0       0.00034/0         0.00034            None  
log_total_time                           COMPLETED                     4e-05         4e-05/0         4e-05/0           4e-05            None  
---------------------------------------  -------------------  --------------  --------------  --------------  --------------  --------------  
Total                                                               45.71233                                                                  
---------------------------------------  -------------------  --------------  --------------  --------------  --------------  --------------  
Processing took total 29.2974s [min/index: 0.0238s/468, max/index: 0.06095s/726, mean: 0.03123s, std: 0.00228s]
Dataflow took total 32.09461s [min/index: 0.02433s/468, max/index: 0.06684s/1, mean: 0.03422s, std: 0.00291s]

basic_profiler と handler_profiler により得られたプロファイリング結果は write_results メソッドを使用して CSV ファイルにエクスポートできます。

basic_profiler.write_results("./basic_profile.csv")
handlers_profiler.write_results("./handlers_profile.csv")

basic_profiler の CSV ファイルを調べれば、イテレーション毎にストアされている情報の深さが分かります。

basic_profile = pd.read_csv("./basic_profile.csv")
basic_profile.head()
epoch	iteration	processing_stats	dataflow_stats	Event_STARTED	Event_COMPLETED	Event_EPOCH_STARTED	Event_EPOCH_COMPLETED	Event_ITERATION_STARTED	Event_ITERATION_COMPLETED	Event_GET_BATCH_STARTED	Event_GET_BATCH_COMPLETED
0	1.0	1.0	0.037031	0.066874	0.000017	0.000084	0.000003	22.484756	0.000005	0.000010	0.000006	0.000013
1	1.0	2.0	0.034586	0.039192	0.000017	0.000084	0.000003	22.484756	0.000005	0.000011	0.000006	0.000009
2	1.0	3.0	0.033999	0.034169	0.000017	0.000084	0.000003	22.484756	0.000005	0.000009	0.000012	0.000008
3	1.0	4.0	0.033792	0.034108	0.000017	0.000084	0.000003	22.484756	0.000004	0.000009	0.000005	0.000009
4	1.0	5.0	0.033714	0.034156	0.000017	0.000084	0.000003	22.484756	0.000004	0.000011	0.000006	0.000008

handlers_profile CSV は (行番号に対応している) ハンドラが起動された時ごとのための詳細をストアしています。

handlers_profile = pd.read_csv("./handlers_profile.csv")
handlers_profile.head()
#	processing_stats	dataflow_stats	log_training_results (EPOCH_COMPLETED)	log_validation_results (EPOCH_COMPLETED)	log_epoch_time (EPOCH_COMPLETED)	BasicTimeProfiler._as_first_started (STARTED)	log_total_time (COMPLETED)
0	1.0	0.037088	0.054261	19.319054	3.165631	0.000034	0.000342	0.000036
1	2.0	0.034641	0.066836	20.033289	3.193913	0.000032	0.000000	0.000000
2	3.0	0.034053	0.039158	0.000000	0.000000	0.000000	0.000000	0.000000
3	4.0	0.033844	0.034130	0.000000	0.000000	0.000000	0.000000	0.000000
4	5.0	0.033771	0.034076	0.000000	0.000000	0.000000	0.000000	0.000000

 

Timer を使用したカスタム・プロファイリング

抽象化の最低位では、任意のイベントのセットの間の時間を計算するために Timer() を提供しています。詳細はその docstring を見てください。

 

訓練経過時間

Timer() は例えば、訓練の間の経過した訓練時間を計算するために使用できます。

elapsed_time = Timer()

elapsed_time.attach(
    trainer,
    start=Events.STARTED,         # Start timer at the beginning of training
    resume=Events.EPOCH_STARTED,  # Resume timer at the beginning of each epoch
    pause=Events.EPOCH_COMPLETED, # Pause timer at the end of each epoch
    step=Events.EPOCH_COMPLETED,  # Step (update) timer at the end of each epoch
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_elapsed_time(trainer):
    print(f"   Elapsed time: {elapsed_time.value()}")

trainer.run(train_loader, max_epochs=2)
Training Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.02
Validation Results - Epoch[1] Avg accuracy: 0.99 Avg loss: 0.04
Epoch 1, Time Taken : 30.887796878814697
   Elapsed time: 53.353810481959954
Training Results - Epoch[2] Avg accuracy: 1.00 Avg loss: 0.01
Validation Results - Epoch[2] Avg accuracy: 0.99 Avg loss: 0.03
Epoch 2, Time Taken : 31.164958238601685
   Elapsed time: 107.81696200894658
Total Time: 107.8185646533966

State:
	iteration: 938
	epoch: 2
	epoch_length: 469
	max_epochs: 2
	output: 0.00048420054372400045
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>
 

以上