PyTorch Lightning 1.1: notebooks : PyTorch Lightning DataModules (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/13/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- notebooks : PyTorch Lightning DataModules
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |
notebooks : PyTorch Lightning DataModules
- Datamodules – DataModule について学習して MNIST と CIFAR10 上でデータセット不可知なモデルを訓練します。
pytorch-lightning version 0.9.0 のリリースより、LightningModule からデータ関連フックを切り離すのに役立つ LightningDataModule と呼ばれる新しいクラスを含めました。
このノートブックは Datamodule を使用してどのように開始するかを貴方にガイドします。
データモジュールの最も最新のドキュメントは ここ で見つけられます。
セットアップ
Lightning は容易にインストールできます。単純に pip install pytorch-lightning
! pip install pytorch-lightning --quiet
イントロダクション
最初に、LightningDataModule の使用なく通常の LightningModule 実装を調べます。
import pytorch_lightning as pl from pytorch_lightning.metrics.functional import accuracy import torch from torch import nn import torch.nn.functional as F from torch.utils.data import random_split, DataLoader # Note - you must have torchvision installed for this example from torchvision.datasets import MNIST, CIFAR10 from torchvision import transforms
LitMNISTModel を定義する
下で、MNIST 手書き数字を分類する hello world チュートリアルからの LightningModule を再利用します。
不幸なことに、データセット固有の項目をモデル内にハードコードしましたので、それを MNIST データで動作するように永遠に制限します。😢
モデルを異なるデータセット上で訓練/評価する計画を持たないのであればこれは構いません。けれども、多くの場合、貴方のアーキテクチャを異なるデータセット上で試すことを望むときこれは厄介になる可能性があります。
class LitMNIST(pl.LightningModule):
def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):
super().__init__()
# We hardcode dataset specific stuff here.
self.data_dir = data_dir
self.num_classes = 10
self.dims = (1, 28, 28)
channels, width, height = self.dims
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
self.hidden_size = hidden_size
self.learning_rate = learning_rate
# Build model
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, self.num_classes)
)
def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
####################
# DATA RELATED HOOKS
####################
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
LitMNIST モデルを訓練する
model = LitMNIST() trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20) trainer.fit(model)
DataModule を使用する
DataModule は LightningModule からデータ関連フックを切り離す方法ですので、データセット不可知なモデルを開発できます。
MNISTDataModule を定義する
下でクラスの各関数について調べて、そしてそれらが何をしているか話しましょう :
- __init__
- data_dir arg を取ります、これは MNIST データセットをダウンロードしたかダウンロードすることを望むところを指し示します。
- 変換を定義します、これは train, val と test データセット分割に渡り適用されます。
- self.dims を定義します、これは datamodule.size() から返されるタプルでそれはモデルを初期化する手助けができます。
- prepare_data
- これはデータセットをダウンロードできるところです。望まれるデータセットを指し示してデータセットがそこで見つからない場合には torchvision の MNIST データセット・クラスにダウンロードすることを依頼します。
- この関数内でどのような状態割当ても行なわないことに注意してください (i.e. self.something = …)。
- setup
- データをファイルからロードして各分割 (train, val, test) のために PyTorch tensor データセットを準備します。
- setup は ‘stage’ arg を想定します、これはロジックを ‘fit’ と ‘test’ のために分離するために使用されます。
- 総てのデータセットを一度にロードすることが嫌でないならば、stage に None が渡されたときにいつでも ‘fit’ 関連のセットアップと ‘test’ 関連のセットアップの両者を実行することを許す条件をセットアップできます。
- これは総ての GPU に渡り実行されそしてここで状態割当てを行なうことは安全であることに注意してください。
- x_dataloader
- train_dataloader(), val_dataloader() と test_dataloader() は総て setup() で準備したそれぞれのデータセットをラップすることにより作成された PyTorch DataLoader インスタンスを返します。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)
self.num_classes = 10
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
データセット不可知な LitModel を定義する
下で、先に作成した LitMNIST モデルと同じモデルを定義します。
けれども、今回はモデルは使用したい任意の入力データを使用する自由を持ちます。
class LitModel(pl.LightningModule):
def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):
super().__init__()
# We take in input dimensions as parameters and use those to dynamically build model.
self.channels = channels
self.width = width
self.height = height
self.num_classes = num_classes
self.hidden_size = hidden_size
self.learning_rate = learning_rate
self.model = nn.Sequential(
nn.Flatten(),
nn.Linear(channels * width * height, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Dropout(0.1),
nn.Linear(hidden_size, num_classes)
)
def forward(self, x):
x = self.model(x)
return F.log_softmax(x, dim=1)
def training_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self(x)
loss = F.nll_loss(logits, y)
preds = torch.argmax(logits, dim=1)
acc = accuracy(preds, y)
self.log('val_loss', loss, prog_bar=True)
self.log('val_acc', acc, prog_bar=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer
MNISTDataModule を使用して LitModel を訓練する
今は、MNISTDataModule の configuration 設定とデーたローダを使用して LitModel を初期化して訓練します。
# Init DataModule dm = MNISTDataModule() # Init model from datamodule's attributes model = LitModel(*dm.size(), dm.num_classes) # Init trainer trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1) # Pass the datamodule as arg to trainer.fit to override model hooks :) trainer.fit(model, dm)
CIFAR10 DataModule を定義する
CIFAR10 データセットのための新しいデータモジュールを定義することにより先に作成した LitModel がデータセット不可知であることを検証しましょう。
class CIFAR10DataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.dims = (3, 32, 32)
self.num_classes = 10
def prepare_data(self):
# download
CIFAR10(self.data_dir, train=True, download=True)
CIFAR10(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
def train_dataloader(self):
return DataLoader(self.cifar_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.cifar_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.cifar_test, batch_size=32)
CIFAR10DataModule を使用して LitModel を訓練する
モデルは非常に良くはないので、それは CIFAR10 データセット上ではかなり悪く遂行します。
ここでポイントは LitModel が入力データとして異なるデータモジュールを使用して問題がないことを見れることです。
dm = CIFAR10DataModule() model = LitModel(*dm.size(), dm.num_classes, hidden_size=256) trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1) trainer.fit(model, dm)
以上