PyTorch Lightning 1.1 : notebooks : PyTorch Lightning DataModules

PyTorch Lightning 1.1: notebooks : PyTorch Lightning DataModules (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/13/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

notebooks : PyTorch Lightning DataModules

  • Datamodules – DataModule について学習して MNIST と CIFAR10 上でデータセット不可知なモデルを訓練します。

pytorch-lightning version 0.9.0 のリリースより、LightningModule からデータ関連フックを切り離すのに役立つ LightningDataModule と呼ばれる新しいクラスを含めました。

このノートブックは Datamodule を使用してどのように開始するかを貴方にガイドします。

データモジュールの最も最新のドキュメントは ここ で見つけられます。

 

セットアップ

Lightning は容易にインストールできます。単純に pip install pytorch-lightning

! pip install pytorch-lightning --quiet

 

イントロダクション

最初に、LightningDataModule の使用なく通常の LightningModule 実装を調べます。

import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST, CIFAR10
from torchvision import transforms

 

LitMNISTModel を定義する

下で、MNIST 手書き数字を分類する hello world チュートリアルからの LightningModule を再利用します。

不幸なことに、データセット固有の項目をモデル内にハードコードしましたので、それを MNIST データで動作するように永遠に制限します。😢

モデルを異なるデータセット上で訓練/評価する計画を持たないのであればこれは構いません。けれども、多くの場合、貴方のアーキテクチャを異なるデータセット上で試すことを望むときこれは厄介になる可能性があります。

class LitMNIST(pl.LightningModule):
    
    def __init__(self, data_dir='./', hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # We hardcode dataset specific stuff here.
        self.data_dir = data_dir
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Build model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

 

LitMNIST モデルを訓練する

model = LitMNIST()
trainer = pl.Trainer(max_epochs=2, gpus=1, progress_bar_refresh_rate=20)
trainer.fit(model)

 

DataModule を使用する

DataModule は LightningModule からデータ関連フックを切り離す方法ですので、データセット不可知なモデルを開発できます。

 

MNISTDataModule を定義する

下でクラスの各関数について調べて、そしてそれらが何をしているか話しましょう :

  1. __init__
    • data_dir arg を取ります、これは MNIST データセットをダウンロードしたかダウンロードすることを望むところを指し示します。
    • 変換を定義します、これは train, val と test データセット分割に渡り適用されます。
    • self.dims を定義します、これは datamodule.size() から返されるタプルでそれはモデルを初期化する手助けができます。

  2. prepare_data
    • これはデータセットをダウンロードできるところです。望まれるデータセットを指し示してデータセットがそこで見つからない場合には torchvision の MNIST データセット・クラスにダウンロードすることを依頼します。
    • この関数内でどのような状態割当ても行なわないことに注意してください (i.e. self.something = …)。

  3. setup
    • データをファイルからロードして各分割 (train, val, test) のために PyTorch tensor データセットを準備します。
    • setup は ‘stage’ arg を想定します、これはロジックを ‘fit’ と ‘test’ のために分離するために使用されます。
    • 総てのデータセットを一度にロードすることが嫌でないならば、stage に None が渡されたときにいつでも ‘fit’ 関連のセットアップと ‘test’ 関連のセットアップの両者を実行することを許す条件をセットアップできます。
    • これは総ての GPU に渡り実行されそしてここで状態割当てを行なうことは安全であることに注意してください

  4. x_dataloader
    • train_dataloader(), val_dataloader() と test_dataloader() は総て setup() で準備したそれぞれのデータセットをラップすることにより作成された PyTorch DataLoader インスタンスを返します。
class MNISTDataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

 

データセット不可知な LitModel を定義する

下で、先に作成した LitMNIST モデルと同じモデルを定義します。

けれども、今回はモデルは使用したい任意の入力データを使用する自由を持ちます。

class LitModel(pl.LightningModule):
    
    def __init__(self, channels, width, height, num_classes, hidden_size=64, learning_rate=2e-4):

        super().__init__()

        # We take in input dimensions as parameters and use those to dynamically build model.
        self.channels = channels
        self.width = width
        self.height = height
        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):

        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

 

MNISTDataModule を使用して LitModel を訓練する

今は、MNISTDataModule の configuration 設定とデーたローダを使用して LitModel を初期化して訓練します。

# Init DataModule
dm = MNISTDataModule()
# Init model from datamodule's attributes
model = LitModel(*dm.size(), dm.num_classes)
# Init trainer
trainer = pl.Trainer(max_epochs=3, progress_bar_refresh_rate=20, gpus=1)
# Pass the datamodule as arg to trainer.fit to override model hooks :)
trainer.fit(model, dm)

 

CIFAR10 DataModule を定義する

CIFAR10 データセットのための新しいデータモジュールを定義することにより先に作成した LitModel がデータセット不可知であることを検証しましょう。

class CIFAR10DataModule(pl.LightningDataModule):

    def __init__(self, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=32)

 

CIFAR10DataModule を使用して LitModel を訓練する

モデルは非常に良くはないので、それは CIFAR10 データセット上ではかなり悪く遂行します。

ここでポイントは LitModel が入力データとして異なるデータモジュールを使用して問題がないことを見れることです。

dm = CIFAR10DataModule()
model = LitModel(*dm.size(), dm.num_classes, hidden_size=256)
trainer = pl.Trainer(max_epochs=5, progress_bar_refresh_rate=20, gpus=1)
trainer.fit(model, dm)
 

以上