PyTorch Ignite 0.4.8 : コンセプト (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/02/2022 (0.4.8)
* 本ページは、PyTorch Ignite の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
PyTorch Ignite 0.4.8 : コンセプト
エンジン
フレームワークの本質はクラス Engine で、これは提供されたデータに渡り与えられた数の回数をループし、処理関数を実行して結果を返す抽象です :
while epoch < max_epochs:
# run an epoch on data
data_iter = iter(data)
while True:
try:
batch = next(data_iter)
output = process_function(batch)
iter_counter += 1
except StopIteration:
data_iter = iter(data)
if iter_counter == epoch_length:
break
そして、モデル trainer は単純に訓練データセットに渡り複数回ループしてモデルパラメータを更新するエンジンです。同様に、モデル評価は検証データセットに渡り単一回実行してメトリクスを計算するエンジンにより成されます。
例えば、教師ありタスクのためのモデル trainer は :
def train_step(trainer, batch):
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
訓練ステップの出力のタイプ (i.e. 上のサンプルでは loss.item()) は制限されません。訓練ステップ関数はユーザが望む総てを返すことができます。出力は trainer.state.output に設定されて更に任意のタイプの処理のために利用できます。
NOTE :
trainer.run(data, max_epochs=100, epoch_length=200)
データが (ユーザに対して) 未知の長さを持つ有限のデータ iterator である場合、引数 epoch_length は省略できてそれはデータ iterator が使い尽くされたときに自動的に決定されます。
任意の複雑さの訓練ロジック は train_step メソッドでコーディングできて trainer はこのメソッドを使用して構築できます。train_step 関数の引数 batch はユーザ定義で単一の反復のために必要な任意のデータを含むことができます。
model_1 = ...
model_2 = ...
# ...
optimizer_1 = ...
optimizer_2 = ...
# ...
criterion_1 = ...
criterion_2 = ...
# ...
def train_step(trainer, batch):
data_1 = batch["data_1"]
data_2 = batch["data_2"]
# ...
model_1.train()
optimizer_1.zero_grad()
loss_1 = forward_pass(data_1, model_1, criterion_1)
loss_1.backward()
optimizer_1.step()
# ...
model_2.train()
optimizer_2.zero_grad()
loss_2 = forward_pass(data_2, model_2, criterion_2)
loss_2.backward()
optimizer_2.step()
# ...
# User can return any type of structure.
return {
"loss_1": loss_1,
"loss_2": loss_2,
# ...
}
trainer = Engine(train_step)
trainer.run(data, max_epochs=100)
GAN のようなマルチモデル訓練サンプルについては、サンプル を見てください。
イベントとハンドラ
Engine の柔軟性を改良するために、イベントシステムが導入されました、これは以下の実行の各ステップでの相互作用を容易にします :
- エンジンが開始 (= started) /終了 (= completed) する
- エポックが開始/終了する
- バッチ反復が開始/終了する
イベントの完全なリストは Events で見つかります。
そして、ユーザはカスタムコードをイベントハンドラとして実行できます。ハンドラな任意の関数であり得ます : e.g. lambda, 単純な関数, クラスメソッド等。最初の引数はオプションで engine であり得ますが、必ずしも必要ではありません。
run() が呼び出されたときに何が起きるかより詳細に考えましょう :
fire_event(Events.STARTED)
while epoch < max_epochs:
fire_event(Events.EPOCH_STARTED)
# run once on data
for batch in data:
fire_event(Events.ITERATION_STARTED)
output = process_function(batch)
fire_event(Events.ITERATION_COMPLETED)
fire_event(Events.EPOCH_COMPLETED)
fire_event(Events.COMPLETED)
最初に「エンジンが開始される」イベントが発火されてそしてその総てのイベントハンドラが実行されます (イベントハンドラを追加する方法は次のパラグラフで見ます)。次に、"while" ループが開始されて「エポックが開始される」イベントが派生します、等。イベントが「発火される (= fired)」たびに、装着されたハンドラが実行されます。
イベントを装着するには単純にメソッド add_event_handler() か on() デコレータを使用します :
trainer = Engine(update_model)
trainer.add_event_handler(Events.STARTED, lambda _: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):
print("Another message of start training")
# or even simpler, use only what you need !
@trainer.on(Events.STARTED)
def on_training_started():
print("Another message of start training")
# attach handler with args, kwargs
mydata = [1, 2, 3, 4]
def on_training_ended(data):
print(f"Training is ended. mydata={data}")
trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
イベントハンドラは remove_event_handler() か (add_event_handler() により返される) RemovableEventHandle 参照を通してデタッチされます。これはマルチループのために configured エンジンを再利用するために使用できます :
model = ...
train_loader, validation_loader, test_loader = ...
trainer = create_supervised_trainer(model, optimizer, loss)
evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()})
def log_metrics(engine, title):
print(f"Epoch: {trainer.state.epoch} - {title} accuracy: {engine.state.metrics['acc']:.2f}")
@trainer.on(Events.EPOCH_COMPLETED)
def evaluate(trainer):
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"):
evaluator.run(train_loader)
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"):
evaluator.run(validation_loader)
with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"):
evaluator.run(test_loader)
trainer.run(train_loader, max_epochs=100)
イベントハンドラはまたユーザパターンにより呼び出されるように configure することもできます : 総ての n-th イベント、一度だけあるいはカスタムイベント・フィルタリング関数を使用して :
model = ...
train_loader, validation_loader, test_loader = ...
trainer = create_supervised_trainer(model, optimizer, loss)
@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_training_loss_every_50_iterations():
print(f"{trainer.state.epoch} / {trainer.state.max_epochs} : {trainer.state.iteration} - loss: {trainer.state.output:.2f}")
@trainer.on(Events.EPOCH_STARTED(once=25))
def do_something_once_on_25_epoch():
# do something
def custom_event_filter(engine, event):
if event in [1, 2, 5, 10, 50, 100]:
return True
return False
@engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter))
def call_on_special_event(engine):
# do something on 1, 2, 5, 10, 50, 100 iterations
trainer.run(train_loader, max_epochs=100)
カスタム・イベント
ユーザはまたカスタムイベントを定義することもできます。ユーザにより定義されるイベントは EventEnum を継承して engine 内に register_events() で登録されるべきです。
from ignite.engine import EventEnum
class CustomEvents(EventEnum):
"""
Custom events defined by user
"""
CUSTOM_STARTED = 'custom_started'
CUSTOM_COMPLETED = 'custom_completed'
engine.register_events(*CustomEvents)
これらのイベントは任意のハンドラを装着して使用できるでしょう、そして fire_event() を使用して発火されます。
@engine.on(CustomEvents.CUSTOM_STARTED)
def call_on_custom_event(engine):
# do something
@engine.on(Events.STARTED)
def fire_custom_events(engine):
engine.fire_event(CustomEvents.CUSTOM_STARTED)
NOTE : カスタムイベントの使用方法のサンプルについては create_supervised_tbptt_trainer のソースコードを見てください。
ハンドラ
ライブラリは、訓練パイプラインをチェックポイントし、ベストモデルをセーブし、改善がなければ訓練を停止し、実験的な追跡システムを使用する等のために組込みハンドラのセットを提供しています。それらは以下の 2 つのモジュールで見つかります :
幾つかのクラスは callable 関数として単純に Engine に追加できます。例えば :
from ignite.handlers import TerminateOnNan
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
他は幾つかのハンドラを Engine に内部的に追加するために attach メソッドを提供します :
from ignite.contrib.handlers.tensorboard_logger import *
# Create a logger
tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")
# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(
trainer,
event_name=Events.EPOCH_COMPLETED,
log_handler=WeightsHistHandler(model)
)
タイムラインとイベント
下でイベントと幾つかの典型的なハンドラが (総てのエポック後に評価を伴う) 訓練ループのためのタイムライン上で表示されています :
State
state は process_function の出力、現在のエポック、反復そして他の役立つ情報をストアするために Engine に導入されます。各エンジンは State を含み、これは以下を含みます :
- engine.state.seed : Seed to set at each data “epoch”.
- engine.state.epoch : エンジンが完了したエポック数。0 として初期化されて最初のエポックは 1 です。
- engine.state.iteration: エンジンが完了した反復数。0 として初期化されて最初の反復は 1 です。
- engine.state.max_epochs: 実行するエポック数。1 として初期化されます。
- engine.state.output: エンジンのために定義された process_function の出力。下を見てください。
- 等々
他の属性は State の docs で見つかります。
下のコードでは、engine.state.output はバッチ損失をストアします。この出力は総ての反復で損失をプリントするために使用されます。
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def on_iteration_completed(engine):
iteration = engine.state.iteration
epoch = engine.state.epoch
loss = engine.state.output
print(f"Epoch: {epoch}, Iteration: {iteration}, Loss: {loss}")
trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
process_function の出力上の制限はありませんので、Ignite はその ignite.metrics と ignite.handlers のために output_transform 引数を提供します。引数 output_transform は目的のために engine.state.output を変換するために利用される関数です。下で様々なタイプの engine.state.output とそれらをどのように変換するかを見ます。
下のコードでは、engine.state.output は処理されたバッチのための損失、y_pred, y のリストです。Accuracy をエンジンに装着することを望む場合、output_transform は engine.state.output から y_pred と y を得るために必要とされます。それがどのように成されるか見ましょう :
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item(), y_pred, y
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
epoch = engine.state.epoch
loss = engine.state.output[0]
print (f'Epoch {epoch}: train_loss = {loss}')
accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
上に類似していますが、今回は process_function の出力は処理されたバッチに対して loss, y_pred, y の辞書で、これは engine.state.output から y_pred と y を得るためにユーザは output_transform をどのように使用できるかです。下を見てください :
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return {'loss': loss.item(),
'y_pred': y_pred,
'y': y}
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
epoch = engine.state.epoch
loss = engine.state.output['loss']
print (f'Epoch {epoch}: train_loss = {loss}')
accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
NOTE :
def user_handler_function(engine):
engine.state.new_attribute = 12345
メトリクス
ライブラリは様々な機械学習タスクのために out-of-the-box なメトリクスのリストを提供します。メトリクスを計算する 2 つの方法がサポートされます : 1) オンライン、そして 2) 出力履歴全体をストアする。
メトリクスはエンジンに装着できます :
from ignite.metrics import Accuracy
accuracy = Accuracy()
accuracy.attach(evaluator, "accuracy")
state = evaluator.run(validation_data)
print("Result:", state.metrics)
# > {"accuracy": 0.12345}
あるいはスタンドアロン・オブジェクトとして利用できます :
from ignite.metrics import Accuracy
accuracy = Accuracy()
accuracy.reset()
for y_pred, y in get_prediction_target():
accuracy.update((y_pred, y))
print("Result:", accuracy.compute())
メトリクスと API の完全なリストは ignite.metrics モジュールで見つけられます。
Where to go next? サンプル とチュートリアル・ノートブックを確認してください。
以上