PyTorch Lightning 1.1: Lightning API : Trainer (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/26/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Lightning API : Trainer
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
Lightning API : Trainer
貴方の PyTorch を LightningModule にひとたび体系化したならば、Trainer は他の総てを自動化します。
この抽象は以下を達成します :
- 追加の抽象なしに PyTorch コードを通して総ての局面に渡り制御を維持します。
- Facebook AI Research, NYU, MIT, Stanford 等… のようなトップ AI ラボからの contributors とユーザにより埋め込まれたベストプラクティスを trainer は利用します。
- trainer は自動化されることを望まない任意の主要パートを override することを可能にします。
基本的な使用方法
これは trainer の基本的な使用法です :
model = MyLightningModule() trainer = Trainer() trainer.fit(model, train_dataloader, val_dataloader)
Under the hood
内部的には、Lightning Trainer は訓練ループの詳細を処理します、幾つかのサンプルは以下を含みます :
- 自動的に勾配を有効/無効にします。
- 訓練、検証とテストデータローダを実行します。
- 適切な時にコールバックを呼び出します。
- 正しいデバイスにバッチと計算を配置します。
ここに trainer が内部的に行なうもののための疑似コードがあります (訓練ループのみ示します)
# put model in train mode model.train() torch.set_grad_enabled(True) losses = [] for batch in train_dataloader: # calls hooks like this one on_train_batch_start() # train step loss = training_step(batch) # backward loss.backward() # apply and clear grads optimizer.step() optimizer.zero_grad() losses.append(loss)
Python スクリプトの Trainer
Python スクリプトでは、Trainer を呼ぶために main 関数を使用することが推奨されます。
from argparse import ArgumentParser def main(hparams): model = LightningModule() trainer = Trainer(gpus=hparams.gpus) trainer.fit(model) if __name__ == '__main__': parser = ArgumentParser() parser.add_argument('--gpus', default=None) args = parser.parse_args() main(args)
そしてそれをこのように実行できます :
python main.py --gpus 2
NOTE
Pro-tip: 総てのフラグを手動で定義する必要はありません。Lightning はそれらを自動的に追加できます
from argparse import ArgumentParser def main(args): model = LightningModule() trainer = Trainer.from_argparse_args(args) trainer.fit(model) if __name__ == '__main__': parser = ArgumentParser() parser = Trainer.add_argparse_args( # group the Trainer arguments together parser.add_argument_group(title="pl.Trainer args") ) args = parser.parse_args() main(args)
そしてそれをこのように実行できます :
python main.py --gpus 2 --max_steps 10 --limit_train_batches 10 --any_trainer_arg x
Note
訓練実行を早期に停止することを望む場合、キーボードで “Ctrl + C” を押すことができます。trainer は KeyboardInterrupt を捕捉して on_train_end のようなコールバックの実行を含む、graceful シャットダウンを試みます。trainer オブジェクトはまたそのような場合に属性 interrupted を True に設定します。計算リソースをシャットダウンするコールバックを持つ場合、例えば、uninterrupted 実行のためだけにシャットダウン・ロジックを条件付きで実行できます。
テスト
訓練をひとたび行えば、自由にテストセットを実行してください!(Only right before publishing your paper or pushing to production)
trainer.test(test_dataloaders=test_dataloader)
配備 / 予測
単に LightningModule を訓練しました、これはまた単に torch.nn.Module です。Use it to do whatever!
# load model pretrained_model = LightningModule.load_from_checkpoint(PATH) pretrained_model.freeze() # use it for finetuning def forward(self, x): features = pretrained_model(x) classes = classifier(features) # or for prediction out = pretrained_model(x) api_write({'response': out}
様々なデバイス上でモデルを実行したいかもしれません。データを手動で正しいデバイスに移動する代わりに、forward メソッド (あるいは推論のために使用する任意の他のメソッド) を auto_move_data() でデコレートします、そして Lightning は残りを処理します。
再現性
実行から実行への完全な再現性を確実にするには疑似ランダム generator のためにシードを設定し、そして Trainer で deterministic フラグを設定する必要があります。
サンプル :
from pytorch_lightning import Trainer, seed_everything seed_everything(42) # sets seeds for numpy, torch, python.random and PYTHONHASHSEED. model = Model() trainer = Trainer(deterministic=True)
以上