PyTorch Lightning 1.1: 拡張 (オプション) : LightningDataModule (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/27/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Optional extensions : LightningDataModule
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |
拡張 (オプション) : LightningDataModule
datamodule は共有可能な、再利用可能なクラスでデータを処理するために必要な総てのステップをカプセル化します :
datamodule は PyTorch のデータ処理に伴う 5 つのステップをカプセル化します :
- ダウンロード / トークン化 / (前) 処理。
- クリーンアップしてそして (おそらく) ディスクにセーブします。
- データセット内にロードします。
- 変換 (回転、トークン化等…) を適用します。
- DataLoader 内にラップします。
そしてこのクラスはどこでもシェアして利用できます :
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule model = LitClassifier() trainer = Trainer() imagenet = ImagenetDataModule() trainer.fit(model, imagenet) cifar10 = CIFAR10DataModule() trainer.fit(model, cifar10)
何故 DataModule が必要なのでしょう?
通常の PyTorch コードでは、データ・クリーニング/準備は通常は多くのファイルに渡り散乱しています。これはプロジェクトに渡り正確な分割と変換を共有して再利用することを不可能にしています。
かつて質問をしたことがあるならば datamodule は貴方のためのものです :
# regular PyTorch test_data = MNIST(my_path, train=False, download=True) train_data = MNIST(my_path, train=True, download=True) train_data, val_data = random_split(train_data, [55000, 5000]) train_loader = DataLoader(train_data, batch_size=32) val_loader = DataLoader(val_data, batch_size=32) test_loader = DataLoader(test_data, batch_size=32)
同値な DataModule は同じ正確なコードを単に体系化しますが、プロジェクトに渡り再利用可能にします。
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
super().__init__()
self.data_dir = data_dir
self.batch_size = batch_size
def setup(self, stage=None):
self.mnist_test = MNIST(self.data_dir, train=False)
mnist_full = MNIST(self.data_dir, train=True)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=self.batch_size)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=self.batch_size)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=self.batch_size)
しかし今、処理の複雑さが増大するにつれて (変換、マルチ GPU 訓練) 、これらのデータセットを再利用可能にしながら Lightning にそれらの詳細を処理させることができますので、同僚と共有したり異なるプロジェクトで利用することができます。
mnist = MNISTDataModule(my_path) model = LitClassifier() trainer = Trainer() trainer.fit(model, mnist)
ここにより現実的で、複雑な DataModule があります、これは datamodule がどれほどより再利用可能であるかを示します。
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms
class MNISTDataModule(pl.LightningDataModule):
def __init__(self, data_dir: str = './'):
super().__init__()
self.data_dir = data_dir
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# self.dims is returned when you call dm.size()
# Setting default dims here because we know them.
# Could optionally be assigned dynamically in dm.setup()
self.dims = (1, 28, 28)
def prepare_data(self):
# download
MNIST(self.data_dir, train=True, download=True)
MNIST(self.data_dir, train=False, download=True)
def setup(self, stage=None):
# Assign train/val datasets for use in dataloaders
if stage == 'fit' or stage is None:
mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
# Optionally...
# self.dims = tuple(self.mnist_train[0][0].shape)
# Assign test dataset for use in dataloader(s)
if stage == 'test' or stage is None:
self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
# Optionally...
# self.dims = tuple(self.mnist_test[0][0].shape)
def train_dataloader(self):
return DataLoader(self.mnist_train, batch_size=32)
def val_dataloader(self):
return DataLoader(self.mnist_val, batch_size=32)
def test_dataloader(self):
return DataLoader(self.mnist_test, batch_size=32)
NOTE: setup は文字列 arg stage を想定します。それは trainer.fit と trainer.test のための setup ロジックを分離するために使用されます。
以上