PyTorch Lightning 1.1: Getting Started : 2 ステップで Lightning (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/02/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Getting Started : Lightning in 2 steps
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
Getting Started : 2 ステップで Lightning
このガイドでは貴方の PyTorch コードを 2 ステップで Lightning にどのように体系化するかを示します。
PyTorch Lightning によるコードの体系化は貴方のコードを以下のようにします :
- 総ての柔軟性を保持しながら (これは依然として総て純粋な PyTorch です)、多くのボイラープレートは取り除きます。
- 研究コードをエンジニアリングから切り離すことにより可読性を高める。
- より容易に再現する
- 訓練ループとトリッキーなエンジニアリングの殆どを自動化してエラーを少ない傾向にします。
- モデルを変更することなく任意のハードウェアにスケーラブルです。
ステップ 0 : PyTorch Lightning をインストールする
pip を使用してインストールできます。
pip install pytorch-lightning
Or with conda (conda をどのようにインストールするかを ここ で見てください) :
conda install pytorch-lightning -c conda-forge
conda 環境を利用することもできるでしょう。
conda activate my_env pip install pytorch-lightning
以下をインポートします :
import os import torch from torch import nn import torch.nn.functional as F from torchvision import transforms from torchvision.datasets import MNIST from torch.utils.data import DataLoader, random_split import pytorch_lightning as pl
ステップ 1 : Lightning モジュールを定義する
class LitAutoEncoder(pl.LightningModule): def __init__(self): super().__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 64), nn.ReLU(), nn.Linear(64, 3) ) self.decoder = nn.Sequential( nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28*28) ) def forward(self, x): # in lightning, forward defines the prediction/inference actions embedding = self.encoder(x) return embedding def training_step(self, batch, batch_idx): # training_step defined the train loop. # It is independent of forward x, y = batch x = x.view(x.size(0), -1) z = self.encoder(x) x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) # Logging to TensorBoard by default self.log('train_loss', loss) return loss def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) return optimizer
システム VS モデル
LightningModule はモデルではなくシステムを定義します。
システムのサンプルは :
内部的には LightningModule は依然として単なる torch.nn.Module で、これは総ての研究コードを自己充足的にするために単一のファイル内にグループ分けします :
- 訓練ループ
- 検証ループ
- テストループ
- モデル or モデル (群) のシステム
- Optimizer
利用可能なコールバック・フック で見つかる 20+ フックのいずれかを override して (backward パスのような) 訓練の任意の部分をカスタマイズできます。
class LitAutoEncoder(pl.LightningModule): def backward(self, loss, optimizer, optimizer_idx): loss.backward()
FORWARD vs TRAINING_STEP
Lightning では推論から訓練を分離しています。training_step は完全な訓練ステップを定義します。推論アクションを定義するために forward を利用することをユーザに奨励します。
例えば、この場合にはオートエンコーダを embedding 抽出器として動作するように定義できるでしょう :
def forward(self, x): embeddings = self.encoder(x) return embeddings
もちろん、training_step 内から forward を使用することから貴方を止めるものはありません。
def training_step(self, batch, batch_idx): ... z = self(x)
それは実際に貴方のアプリケーションへ下りていきます。けれども、両者の意図を分離することを勧めます。
- 推論 (予測) のために forward を使用する。
- 訓練のために training_step を使用する。
LightningModule docs により多くの詳細があります。
ステップ 2 : Lightning Trainer で Fit させる
最初に、貴方が望むようなやり方でデータを定義します。Lightning は train/val/test のために DataLoader を単に必要とします。
dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()) train_loader = DataLoader(dataset)
次に、LightningModule と PyTorch Lightning Trainer を初期化してから、データとモデルの両者とともに fit を呼び出します。
# init model autoencoder = LitAutoEncoder() # most basic trainer, uses good defaults (auto-tensorboard, checkpoints, logs, and more) # trainer = pl.Trainer(gpus=8) (if you have GPUs) trainer = pl.Trainer() trainer.fit(autoencoder, train_loader)
Trainer は以下を自動化します :
- エポックとバッチ iteration
- optimizer.step(), backward, zero_grad() の呼び出し
- .eval(), grads 有効/無効の呼び出し
- 重みのセーブとロード
- Tensorboard (Loggers オプション参照)
- マルチ-GPU 訓練 サポート
- TPU サポート
- 16-ビット訓練 サポート
TIP : optimizer を手動で管理することを好むのであれば Manual optimization モードを利用できます (ie: RL, GAN, 等…)。
That’s it!
これらが Lightning で知る必要がある主要な 2 つのコンセプトです。lightning の他の総ての特徴は Trainer か LightningModule のいずれかの特徴です。
以上