PyTorch Lightning 1.1 : Lightning API : LightningModule

PyTorch Lightning 1.1: Lightning API : LightningModule (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/25/2021 (1.1.x)

* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/
Facebook: https://www.facebook.com/ClassCatJP/

 

 

Lightning API : LightningModule

LightningModule は貴方の PyTorch コードを 5 つのセクションに体系化します :

  • 計算 (init)
  • 訓練ループ (training_step)
  • 検証ループ (validation_step)
  • テストループ (test_step)
  • Optimizers (configure_optimizers)

幾つかのことに注意してください。

  1. それは 同じ コードです。
  2. PyTorch コードは 抽象化されません – 単に体系化されます。
  3. LightningModule にない総ての他のコードは trainer により自動化されます。
    net = Net()
    trainer = Trainer()
    trainer.fit(net)
    
  4. .cuda() や .to() 呼び出しはありません… Lightning はこれらを貴方のために行ないます。
    # don't do in lightning
    x = torch.Tensor(2, 3)
    x = x.cuda()
    x = x.to(device)
    
    # do this instead
    x = x  # leave it alone!
    
    # or to init a new tensor
    new_x = torch.Tensor(2, 3)
    new_x = new_x.type_as(x)
    
  5. 分散のためのサンプラーはありません、Lightning はこれも貴方のために行ないます。
    # Don't do in Lightning...
    data = MNIST(...)
    sampler = DistributedSampler(data)
    DataLoader(data, sampler=sampler)
    
    # do this instead
    data = MNIST(...)
    DataLoader(data)
    
  6. LightningModule は torch.nn.Module ですが追加の機能を伴います。Use it as such!
    net = Net.load_from_checkpoint(PATH)
    net.freeze()
    out = net(x)
    

このように、Lightning を利用するには、貴方のコードを単に体系化する必要があります、これは約 30 分かかります (そして本音で、貴方は多分どのようにしても行なうはずです)。

 

最小限のサンプル

ここに必要なメソッドだけがあります :

>>> import pytorch_lightning as pl
>>> class LitModel(pl.LightningModule):
...
...     def __init__(self):
...         super().__init__()
...         self.l1 = nn.Linear(28 * 28, 10)
...
...     def forward(self, x):
...         return torch.relu(self.l1(x.view(x.size(0), -1)))
...
...     def training_step(self, batch, batch_idx):
...         x, y = batch
...         y_hat = self(x)
...         loss = F.cross_entropy(y_hat, y)
...         return loss
...
...     def configure_optimizers(self):
...         return torch.optim.Adam(self.parameters(), lr=0.02)

これは次を行なうことにより訓練できます :

train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor()))
trainer = pl.Trainer()
model = LitModel()

trainer.fit(model, train_loader)

LightningModule は多くの便利なメソッドを持ちますが、それについて知る必要がある中心的なものは以下です :

init

  • ここで計算を定義します

forward

  • 推論のためのみに使用します (training_step から分離)

training_step

  • 完全な訓練ループ

validation_step

  • 完全な検証ループ

test_step

  • 完全なテストループ

configure_optimizers

  • optimzer と LR スケジューラを定義します。

 

訓練

訓練ループ

訓練ループを追加するには training_step メソッドを使用します :

class LitClassifier(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

内部的には、Lightning は以下を行ないます (擬似コード) :

# put model in train mode
model.train()
torch.set_grad_enabled(True)

losses = []
for batch in train_dataloader:
    # forward
    loss = training_step(batch)
    losses.append(loss.detach())

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

 

訓練エポック-level メトリクス

エポック-level メトリクスを計算してそれらをログ記録することを望む場合、.log メソッドを使用します。

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)

    # logs metrics for each training_step,
    # and the average across the epoch, to the progress bar and logger
    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
    return loss

.log オブジェクトは要求されたメトリクスを full エポックに渡り自動的に reduce します。それが内部的に何を行なうかの疑似コードがここにあります。

outs = []
for batch in train_dataloader:
    # forward
    out = training_step(val_batch)

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))

 

エポック-level 演算を訓練する

各 training_step の総ての出力で何かを行なう必要がある場合、training_epoch_end を貴方自身で override します。

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    preds = ...
    return {'loss': loss, 'other_stuff': preds}

def training_epoch_end(self, training_step_outputs):
   for pred in training_step_outputs:
       # do something

一致する疑似コードは :

outs = []
for batch in train_dataloader:
    # forward
    out = training_step(val_batch)

    # backward
    loss.backward()

    # apply and clear grads
    optimizer.step()
    optimizer.zero_grad()

training_epoch_end(outs)

 

DataParallel で訓練する

GPU に渡り各バッチからデータを分割する accelerator を使用して訓練するとき、時にそれらを処理のためにマスター GPU 上で集計する必要があるかもしれません (dp, or ddp2)。

この場合、training_step_end メソッドを実装します :

def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def training_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts[0]['pred']
    gpu_1_prediction = batch_parts[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def training_epoch_end(self, training_step_outputs):
   for out in training_step_outputs:
       # do something with preds

lightning が内部的に行なう完全な疑似コードは :

outs = []
for train_batch in train_dataloader:
    batches = split_batch(train_batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = training_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = training_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
training_epoch_end(outs)

 

検証ループ

検証ループを追加するには、LightningModule の validation_step メソッドを override します :

class LitModel(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('val_loss', loss)

内部的には、Lightning は以下を行ないます :

# ...
for batch in train_dataloader:
    loss = model.training_step()
    loss.backward()
    # ...

    if validate_at_some_point:
        # disable grads + batchnorm + dropout
        torch.set_grad_enabled(False)
        model.eval()

        # ----------------- VAL LOOP ---------------
        for val_batch in model.val_dataloader:
            val_out = model.validation_step(val_batch)
        # ----------------- VAL LOOP ---------------

        # enable grads + batchnorm + dropout
        torch.set_grad_enabled(True)
        model.train()

 

検証エポック-level メトリクス

各 validation_step の総ての出力で何かを行なう必要がある場合、validation_epoch_end を override します。

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred =  ...
    return pred

def validation_epoch_end(self, validation_step_outputs):
   for pred in validation_step_outputs:
       # do something with a pred

 

DataParallel で検証する

GPU に渡り各バッチからデータを分割する accelerator を使用して検証するとき、時にそれらを処理のためにマスター GPU 上で集計する必要があるかもしれません (dp, or ddp2)。

この場合、validation_step_end メソッドを実装します :

def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self.model(x)
    loss = F.cross_entropy(y_hat, y)
    pred = ...
    return {'loss': loss, 'pred': pred}

def validation_step_end(self, batch_parts):
    gpu_0_prediction = batch_parts.pred[0]['pred']
    gpu_1_prediction = batch_parts.pred[1]['pred']

    # do something with both outputs
    return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2

def validation_epoch_end(self, validation_step_outputs):
   for out in validation_step_outputs:
       # do something with preds

内部的に lightning が行なう完全な疑似コードは :

outs = []
for batch in dataloader:
    batches = split_batch(batch)
    dp_outs = []
    for sub_batch in batches:
        # 1
        dp_out = validation_step(sub_batch)
        dp_outs.append(dp_out)

    # 2
    out = validation_step_end(dp_outs)
    outs.append(out)

# do something with the outputs for all batches
# 3
validation_epoch_end(outs)

 

テストループ

テストループを追加するためのプロセスは検証ループを追加するためのプロセスと同じです。詳細については上のセクションを参照してください。

唯一の違いはテストループは .test() が使用されるときに呼び出されるだけであることです :

model = Model()
trainer = Trainer()
trainer.fit()

# automatically loads the best weights for you
trainer.test(model)

test() を呼び出す 2 つの方法があります :

# call after training
trainer = Trainer()
trainer.fit(model)

# automatically auto-loads the best weights
trainer.test(test_dataloaders=test_dataloader)

# or call with pretrained model
model = MyLightningModule.load_from_checkpoint(PATH)
trainer = Trainer()
trainer.test(model, test_dataloaders=test_dataloader)

 

推論

研究のために、LightningModule はシステムとして最善に構造化されます。

import pytorch_lightning as pl
import torch
from torch import nn

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # reconstruction
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        return reconstruction_loss

     def validation_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        recons = self.decoder(z)
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

これはこのように訓練できます :

autoencoder = Autoencoder()
trainer = pl.Trainer(gpus=1)
trainer.fit(autoencoder, train_dataloader, val_dataloader)

この単純なモデルはこのように見えるサンプルを生成します (エンコーダとデコーダは非常に弱いです)

上のメソッドは lightning インターフェイスの一部です :

  • training_step
  • validation_step
  • test_step
  • configure_optimizers

この場合、訓練ループと val ループは正確に同じであることに注意してください。このコードをもちろん再利用できます。

class Autoencoder(pl.LightningModule):

     def __init__(self, latent_dim=2):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
        self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))

     def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)

        return loss

     def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log('val_loss', loss)

     def shared_step(self, batch):
        x, _ = batch

        # encode
        x = x.view(x.size(0), -1)
        z = self.encoder(x)

        # decode
        recons = self.decoder(z)

        # loss
        return nn.functional.mse_loss(recons, x)

     def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)

総てのループが利用できる shared_step と呼ばれる新しいメソッドを作成します。このメソッド名は任意で 予約されていません

 

研究内の推論

システムで推論を遂行することを望むところのケースでは LightningModule に forward メソッドを追加できます。

class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)

forward を追加する優位点は複雑なシステムで、テキスト生成のような、遥かに多くの関係する推論手続きを行なうことができることです :

class Seq2Seq(pl.LightningModule):

    def forward(self, x):
        embeddings = self(x)
        hidden_states = self.encoder(embeddings)
        for h in hidden_states:
            # decode
            ...
        return decoded

 

プロダクション内の推論

プロダクションのようなケースについては、LightningModule 内で異なるモデルを iterate することを望むかもしれません。

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM

class ClassificationTask(pl.LightningModule):

     def __init__(self, model):
         super().__init__()
         self.model = model

     def training_step(self, batch, batch_idx):
         x, y = batch
         y_hat = self.model(x)
         loss = F.cross_entropy(y_hat, y)
         return loss

     def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.cross_entropy(y_hat, y)
        acc = FM.accuracy(y_hat, y)

        # loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
        metrics = {'val_acc': acc, 'val_loss': loss}
        self.log_dict(metrics)
        return metrics

     def test_step(self, batch, batch_idx):
        metrics = self.validation_step(batch, batch_idx)
        metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
        self.log_dict(metrics)

     def configure_optimizers(self):
         return torch.optim.Adam(self.model.parameters(), lr=0.02)

それから任意のモデルをこのタスクで fit されるために渡します

for model in [resnet50(), vgg16(), BidirectionalRNN()]:
    task = ClassificationTask(model)

    trainer = Trainer(gpus=2)
    trainer.fit(task, train_dataloader, val_dataloader)

タスクは GAN 訓練、自己教師ありあるいは RL さえ実装するように任意に複雑であり得ます。

class GANTask(pl.LightningModule):

     def __init__(self, generator, discriminator):
         super().__init__()
         self.generator = generator
         self.discriminator = discriminator
     ...

このように使用されるとき、モデルはタスクから分離できてそれを LightningModule 内に保持する必要なくプロダクションで利用できます。

  • onnx にエクスポートできます。
  • あるいは jit を使用してトレースできます。
  • あるいは python ランタイムで実行できます。
task = ClassificationTask(model)

trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)

# use model after training or load weights and drop into the production system
model.eval()
y_hat = model(x)
 

以上