PyTorch Lightning 1.1 : Lightning API : Trainer

PyTorch Lightning 1.1: Lightning API : Trainer (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/26/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

Lightning API : Trainer

貴方の PyTorch を LightningModule にひとたび体系化したならば、Trainer は他の総てを自動化します。

この抽象は以下を達成します :

  1. 追加の抽象なしに PyTorch コードを通して総ての局面に渡り制御を維持します。
  2. Facebook AI Research, NYU, MIT, Stanford 等… のようなトップ AI ラボからの contributors とユーザにより埋め込まれたベストプラクティスを trainer は利用します。
  3. trainer は自動化されることを望まない任意の主要パートを override することを可能にします。

 

基本的な使用方法

これは trainer の基本的な使用法です :

model = MyLightningModule()

trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)

 

Under the hood

内部的には、Lightning Trainer は訓練ループの詳細を処理します、幾つかのサンプルは以下を含みます :

  • 自動的に勾配を有効/無効にします。
  • 訓練、検証とテストデータローダを実行します。
  • 適切な時にコールバックを呼び出します。
  • 正しいデバイスにバッチと計算を配置します。

ここに trainer が内部的に行なうもののための疑似コードがあります (訓練ループのみ示します)

# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # calls hooks like this one
    on_train_batch_start()

    # train step
    loss = training_step(batch)

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

    losses.append(loss)

 

Python スクリプトの Trainer

Python スクリプトでは、Trainer を呼ぶために main 関数を使用することが推奨されます。

from argparse import ArgumentParser

def main(hparams):
    model = LightningModule()
    trainer = Trainer(gpus=hparams.gpus)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--gpus', default=None)
    args = parser.parse_args()

    main(args)

そしてそれをこのように実行できます :

python main.py --gpus 2

NOTE

Pro-tip: 総てのフラグを手動で定義する必要はありません。Lightning はそれらを自動的に追加できます

from argparse import ArgumentParser

def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(
        # group the Trainer arguments together
        parser.add_argument_group(title="pl.Trainer args")
    )
    args = parser.parse_args()

    main(args)

そしてそれをこのように実行できます :

python main.py --gpus 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x

Note

訓練実行を早期に停止することを望む場合、キーボードで “Ctrl + C” を押すことができます。trainer は KeyboardInterrupt を捕捉して on_train_end のようなコールバックの実行を含む、graceful シャットダウンを試みます。trainer オブジェクトはまたそのような場合に属性 interrupted を True に設定します。計算リソースをシャットダウンするコールバックを持つ場合、例えば、uninterrupted 実行のためだけにシャットダウン・ロジックを条件付きで実行できます。

 

テスト

訓練をひとたび行えば、自由にテストセットを実行してください!(Only right before publishing your paper or pushing to production)

trainer.test(test_dataloaders=test_dataloader)

 

配備 / 予測

単に LightningModule を訓練しました、これはまた単に torch.nn.Module です。Use it to do whatever!

# load model
pretrained_model = LightningModule.load_from_checkpoint(PATH)
pretrained_model.freeze()

# use it for finetuning
def forward(self, x):
    features = pretrained_model(x)
    classes = classifier(features)

# or for prediction
out = pretrained_model(x)
api_write({'response': out}

様々なデバイス上でモデルを実行したいかもしれません。データを手動で正しいデバイスに移動する代わりに、forward メソッド (あるいは推論のために使用する任意の他のメソッド) を auto_move_data() でデコレートします、そして Lightning は残りを処理します。

 

再現性

実行から実行への完全な再現性を確実にするには疑似ランダム generator のためにシードを設定し、そして Trainer で deterministic フラグを設定する必要があります。

サンプル :

from pytorch_lightning import Trainer, seed_everything

seed_everything(42)
# sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
model = Model()
trainer = Trainer(deterministic=True)
 

以上