PyTorch Ignite 0.4.2 : コンセプト (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/10/2021 (0.4.2)
* 本ページは、PyTorch Ignite ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
PyTorch Ignite 0.4.2 : コンセプト
エンジン
フレームワークのエッセンスはクラス Engine で、これは提供されたデータに渡り与えられた数の回数をループし、処理関数を実行して結果を返す抽象です :
while epoch < max_epochs: # run an epoch on data data_iter = iter(data) while True: try: batch = next(data_iter) output = process_function(batch) iter_counter += 1 except StopIteration: data_iter = iter(data) if iter_counter == epoch_length: break
従って、モデル trainer は単純に訓練データセットに渡り複数回ループしてモデル・パラメータを更新するエンジンです。同様に、モデル評価は検証データセットに渡り単一回数実行してメトリクスを計算するエンジンにより成されます。
例えば、教師ありタスクのためのモデル trainer は :
def train_step(trainer, batch): model.train() optimizer.zero_grad() x, y = prepare_batch(batch) y_pred = model(x) loss = loss_fn(y_pred, y) loss.backward() optimizer.step() return loss.item() trainer = Engine(train_step) trainer.run(data, max_epochs=100)
訓練ステップの出力のタイプ (i.e. 上のサンプルでは loss.item()) は制限されません。訓練ステップ関数はユーザが望む総てを返すことができます。出力は trainer.state.output に設定されて更に任意のタイプの処理のために利用できます。
NOTE :
trainer.run(data, max_epochs=100, epoch_length=200)
データが (ユーザにとって) 未知の長さを持つ有限のデータ iterator である場合、引数 epoch_length は省略できてそれはデータ iterator が使い尽くされたときに自動的に決定されます。
大抵は任意の複雑性訓練ロジック は train_step メソッドとこのメソッドを使用して構築された trainer としてコーディングされます。train_step 関数の引数 batch はユーザ定義で単一の反復のために必要な任意のデータを含むことができます。
model_1 = ... model_2 = ... # ... optimizer_1 = ... optimizer_2 = ... # ... criterion_1 = ... criterion_2 = ... # ... def train_step(trainer, batch): data_1 = batch["data_1"] data_2 = batch["data_2"] # ... model_1.train() optimizer_1.zero_grad() loss_1 = forward_pass(data_1, model_1, criterion_1) loss_1.backward() optimizer_1.step() # ... model_2.train() optimizer_2.zero_grad() loss_2 = forward_pass(data_2, model_2, criterion_2) loss_2.backward() optimizer_2.step() # ... # User can return any type of structure. return { "loss_1": loss_1, "loss_2": loss_2, # ... } trainer = Engine(train_step) trainer.run(data, max_epochs=100)
GAN のようなマルチモデル訓練サンプルについては、サンプル を見てください。
イベントとハンドラ
Engine の柔軟性を改良するために、イベントシステムが導入されました、これは以下の実行の各ステップ上の相互作用を容易にします :
- エンジンが開始 (= started) /完了 (= completed) する
- エポックが開始/完了する
- バッチ反復が開始/完了する
イベントの完全なリストは Events で見つけられます。
このように、ユーザはカスタムコードをイベントハンドラとして実行できます。ハンドラな任意の関数であり得ます : e.g. lambda, 単純な関数, クラスメソッド等。最初の引数はオプションで engine であり得ますが、必要ではありません。
run() が呼び出されたときに何が起きるかより詳細に考えましょう :
fire_event(Events.STARTED) while epoch < max_epochs: fire_event(Events.EPOCH_STARTED) # run once on data for batch in data: fire_event(Events.ITERATION_STARTED) output = process_function(batch) fire_event(Events.ITERATION_COMPLETED) fire_event(Events.EPOCH_COMPLETED) fire_event(Events.COMPLETED)
最初に「エンジンが開始される」イベントが発火されてそしてその総てのイベントハンドラが実行されます (イベントハンドラをどのように追加されるかを次のパラグラフで見ます)。次に、"while" ループが開始されて「エポックが開始される」イベントが派生します、等。イベントが「発火される (= fired)」たびに、装着されたハンドラが実行されます。
イベントを装着するには単純にメソッド add_event_handler() or on() デコレータを使用します :
trainer = Engine(update_model) trainer.add_event_handler(Events.STARTED, lambda _: print("Start training")) # or @trainer.on(Events.STARTED) def on_training_started(engine): print("Another message of start training") # or even simpler, use only what you need ! @trainer.on(Events.STARTED) def on_training_started(): print("Another message of start training") # attach handler with args, kwargs mydata = [1, 2, 3, 4] def on_training_ended(data): print("Training is ended. mydata={}".format(data)) trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
イベントハンドラは remove_event_handler() or (add_event_handler() により返される) RemovableEventHandle リファレンスを通してデタッチされます。これはマルチループのために configured エンジンを再利用するために使用できます :
model = ... train_loader, validation_loader, test_loader = ... trainer = create_supervised_trainer(model, optimizer, loss) evaluator = create_supervised_evaluator(model, metrics={"acc": Accuracy()}) def log_metrics(engine, title): print("Epoch: {} - {} accuracy: {:.2f}" .format(trainer.state.epoch, title, engine.state.metrics["acc"])) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(trainer): with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "train"): evaluator.run(train_loader) with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "validation"): evaluator.run(validation_loader) with evaluator.add_event_handler(Events.COMPLETED, log_metrics, "test"): evaluator.run(test_loader) trainer.run(train_loader, max_epochs=100)
イベントハンドラはまたユーザパターンにより呼び出されるように configure することもできます : 総ての n-th イベント、一度だけあるいはカスタムイベント・フィルタリング関数を使用して :
model = ... train_loader, validation_loader, test_loader = ... trainer = create_supervised_trainer(model, optimizer, loss) @trainer.on(Events.ITERATION_COMPLETED(every=50)) def log_training_loss_every_50_iterations(): print("{} / {} : {} - loss: {:.2f}" .format(trainer.state.epoch, trainer.state.max_epochs, trainer.state.iteration, trainer.state.output)) @trainer.on(Events.EPOCH_STARTED(once=25)) def do_something_once_on_25_epoch(): # do something def custom_event_filter(engine, event): if event in [1, 2, 5, 10, 50, 100]: return True return False @engine.on(Events.ITERATION_STARTED(event_filter=custom_event_filter)) def call_on_special_event(engine): # do something on 1, 2, 5, 10, 50, 100 iterations trainer.run(train_loader, max_epochs=100)
ユーザはまたカスタムイベントを定義できます。ユーザにより定義されるイベントは EventEnum を継承して engine 内に register_events() で登録されるべきです。
class CustomEvents(EventEnum): """ Custom events defined by user """ CUSTOM_STARTED = 'custom_started' CUSTOM_COMPLETED = 'custom_completed' engine.register_events(*CustomEvents)
これらのイベントは任意のハンドラを装着して使用できるでしょう、そして fire_event() を使用して発火されます。
@engine.on(CustomEvents.CUSTOM_STARTED) def call_on_custom_event(engine): # do something @engine.on(Events.STARTED) def fire_custom_events(engine): engine.fire_event(CustomEvents.CUSTOM_STARTED)
NOTE : カスタムイベントの使用方法のサンプルについては create_supervised_tbptt_trainer のソースコードを見てください。
ハンドラ
ライブラリは、訓練パイプラインをチェックポイントし、ベストモデルをセーブし、改善がなければ訓練を停止し、実験的な追跡システムを使用する等のために組込みハンドラのセットを提供します。それらは以下の 2 つのモジュールで見つかります :
幾つかのクラスは callable 関数として単純に Engine に追加できます。例えば :
from ignite.handlers import TerminateOnNan trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
他は幾つかのハンドラを Engine に内部的に追加するために attach メソッドを提供します :
from ignite.contrib.handlers.tensorboard_logger import * # Create a logger tb_logger = TensorboardLogger(log_dir="experiments/tb_logs") # Attach the logger to the trainer to log model's weights as a histogram after each epoch tb_logger.attach( trainer, event_name=Events.EPOCH_COMPLETED, log_handler=WeightsHistHandler(model) )
タイムラインとイベント
下でイベントと幾つかの典型的なハンドラが (総てのエポック後に評価を伴う) 訓練ループのためのタイムライン上で表示されています :
State
state は process_function の出力、現在のエポック、反復そして他の役立つ情報をストアするために Engine に導入されます。各エンジンは State を含み、これは以下を含みます :
- engine.state.seed : Seed to set at each data “epoch”.
- engine.state.epoch : エンジンが完了したエポック数。0 として初期化されて最初のエポックは 1 です。
- engine.state.iteration: エンジンが完了した反復数。0 として初期化されて最初の反復は 1 です。
- engine.state.max_epochs: 実行するエポック数。1 として初期化されます。
- engine.state.output: エンジンのために定義された process_function の出力。下を見てください。
- 等々
他の属性は State の docs で見つかります。
下のコードでは、engine.state.output はバッチ損失をストアします。出力は総ての反復で損失をプリントするために使用されます。
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item() def on_iteration_completed(engine): iteration = engine.state.iteration epoch = engine.state.epoch loss = engine.state.output print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss)) trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
process_function の出力上の制限はありませんので、Ignite はそのメトリクスとハンドラのために output_transform を提供します。output_transform は engine.state.output を目的のために変換するために利用される関数です。下で様々なタイプの engine.state.output とそれらをどのように変換するかを見ます。
下のコードでは、engine.state.output は処理されたバッチのための損失、y_pred, y のリストです。Accuracy をエンジンに装着することを望む場合、output_transform は engine.state.output から y_pred と y を得るために必要とされます。それがどのように成されるか見ましょう :
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item(), y_pred, y trainer = Engine(update) @trainer.on(Events.EPOCH_COMPLETED) def print_loss(engine): epoch = engine.state.epoch loss = engine.state.output[0] print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss)) accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]]) accuracy.attach(trainer, 'acc') trainer.run(data, max_epochs=10)
上に類似していますが、今回は process_function の出力は処理されたバッチに対して loss, y_pred, y の辞書で、これは engine.state.output から y_pred と y を得るためにユーザは output_transform をどのように使用できるかです。下を見てください :
def update(engine, batch): x, y = batch y_pred = model(inputs) loss = loss_fn(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return {'loss': loss.item(), 'y_pred': y_pred, 'y': y} trainer = Engine(update) @trainer.on(Events.EPOCH_COMPLETED) def print_loss(engine): epoch = engine.state.epoch loss = engine.state.output['loss'] print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss)) accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']]) accuracy.attach(trainer, 'acc') trainer.run(data, max_epochs=10)
NOTE :
def user_handler_function(engine): engine.state.new_attribute = 12345
メトリクス
ライブラリは様々な機械学習タスクのために out-of-the-box なメトリクスのリストを提供します。メトリクスを計算する 2 つの方法がサポートされます : 1) オンライン、そして 2) 出力履歴全体をストアする。
メトリクスはエンジンに装着できます :
from ignite.metrics import Accuracy accuracy = Accuracy() accuracy.attach(evaluator, "accuracy") state = evaluator.run(validation_data) print("Result:", state.metrics) # > {"accuracy": 0.12345}
あるいはスタンドアロン・オブジェクトとして利用できます :
from ignite.metrics import Accuracy accuracy = Accuracy() accuracy.reset() for y_pred, y in get_prediction_target(): accuracy.update((y_pred, y)) print("Result:", accuracy.compute())
メトリクスの完全なリストと API は ignite.metrics モジュールで見つけられます。
Where to go next? サンプル とチュートリアル・ノートブックを確認してください。
以上