PyTorch Lightning 1.1: Lightning API : LightningModule (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 02/25/2021 (1.1.x)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Lightning API : LightningModule
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
| 人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
| PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) | ||
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
| 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
| E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
| Facebook: https://www.facebook.com/ClassCatJP/ |

Lightning API : LightningModule
LightningModule は貴方の PyTorch コードを 5 つのセクションに体系化します :
- 計算 (init)
- 訓練ループ (training_step)
- 検証ループ (validation_step)
- テストループ (test_step)
- Optimizers (configure_optimizers)
幾つかのことに注意してください。
- それは 同じ コードです。
- PyTorch コードは 抽象化されません – 単に体系化されます。
- LightningModule にない総ての他のコードは trainer により自動化されます。
net = Net() trainer = Trainer() trainer.fit(net)
- .cuda() や .to() 呼び出しはありません… Lightning はこれらを貴方のために行ないます。
# don't do in lightning x = torch.Tensor(2, 3) x = x.cuda() x = x.to(device) # do this instead x = x # leave it alone! # or to init a new tensor new_x = torch.Tensor(2, 3) new_x = new_x.type_as(x)
- 分散のためのサンプラーはありません、Lightning はこれも貴方のために行ないます。
# Don't do in Lightning... data = MNIST(...) sampler = DistributedSampler(data) DataLoader(data, sampler=sampler) # do this instead data = MNIST(...) DataLoader(data)
- LightningModule は torch.nn.Module ですが追加の機能を伴います。Use it as such!
net = Net.load_from_checkpoint(PATH) net.freeze() out = net(x)
このように、Lightning を利用するには、貴方のコードを単に体系化する必要があります、これは約 30 分かかります (そして本音で、貴方は多分どのようにしても行なうはずです)。
最小限のサンプル
ここに必要なメソッドだけがあります :
>>> import pytorch_lightning as pl >>> class LitModel(pl.LightningModule): ... ... def __init__(self): ... super().__init__() ... self.l1 = nn.Linear(28 * 28, 10) ... ... def forward(self, x): ... return torch.relu(self.l1(x.view(x.size(0), -1))) ... ... def training_step(self, batch, batch_idx): ... x, y = batch ... y_hat = self(x) ... loss = F.cross_entropy(y_hat, y) ... return loss ... ... def configure_optimizers(self): ... return torch.optim.Adam(self.parameters(), lr=0.02)
これは次を行なうことにより訓練できます :
train_loader = DataLoader(MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())) trainer = pl.Trainer() model = LitModel() trainer.fit(model, train_loader)
LightningModule は多くの便利なメソッドを持ちますが、それについて知る必要がある中心的なものは以下です :
init
- ここで計算を定義します
forward
- 推論のためのみに使用します (training_step から分離)
training_step
- 完全な訓練ループ
validation_step
- 完全な検証ループ
test_step
- 完全なテストループ
configure_optimizers
- optimzer と LR スケジューラを定義します。
訓練
訓練ループ
訓練ループを追加するには training_step メソッドを使用します :
class LitClassifier(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return loss
内部的には、Lightning は以下を行ないます (擬似コード) :
# put model in train mode
model.train()
torch.set_grad_enabled(True)
losses = []
for batch in train_dataloader:
# forward
loss = training_step(batch)
losses.append(loss.detach())
# backward
loss.backward()
# apply and clear grads
optimizer.step()
optimizer.zero_grad()
訓練エポック-level メトリクス
エポック-level メトリクスを計算してそれらをログ記録することを望む場合、.log メソッドを使用します。
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
# logs metrics for each training_step,
# and the average across the epoch, to the progress bar and logger
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
.log オブジェクトは要求されたメトリクスを full エポックに渡り自動的に reduce します。それが内部的に何を行なうかの疑似コードがここにあります。
outs = []
for batch in train_dataloader:
# forward
out = training_step(val_batch)
# backward
loss.backward()
# apply and clear grads
optimizer.step()
optimizer.zero_grad()
epoch_metric = torch.mean(torch.stack([x['train_loss'] for x in outs]))
エポック-level 演算を訓練する
各 training_step の総ての出力で何かを行なう必要がある場合、training_epoch_end を貴方自身で override します。
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
preds = ...
return {'loss': loss, 'other_stuff': preds}
def training_epoch_end(self, training_step_outputs):
for pred in training_step_outputs:
# do something
一致する疑似コードは :
outs = []
for batch in train_dataloader:
# forward
out = training_step(val_batch)
# backward
loss.backward()
# apply and clear grads
optimizer.step()
optimizer.zero_grad()
training_epoch_end(outs)
DataParallel で訓練する
GPU に渡り各バッチからデータを分割する accelerator を使用して訓練するとき、時にそれらを処理のためにマスター GPU 上で集計する必要があるかもしれません (dp, or ddp2)。
この場合、training_step_end メソッドを実装します :
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
return {'loss': loss, 'pred': pred}
def training_step_end(self, batch_parts):
gpu_0_prediction = batch_parts[0]['pred']
gpu_1_prediction = batch_parts[1]['pred']
# do something with both outputs
return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
def training_epoch_end(self, training_step_outputs):
for out in training_step_outputs:
# do something with preds
lightning が内部的に行なう完全な疑似コードは :
outs = []
for train_batch in train_dataloader:
batches = split_batch(train_batch)
dp_outs = []
for sub_batch in batches:
# 1
dp_out = training_step(sub_batch)
dp_outs.append(dp_out)
# 2
out = training_step_end(dp_outs)
outs.append(out)
# do something with the outputs for all batches
# 3
training_epoch_end(outs)
検証ループ
検証ループを追加するには、LightningModule の validation_step メソッドを override します :
class LitModel(pl.LightningModule):
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
self.log('val_loss', loss)
内部的には、Lightning は以下を行ないます :
# ...
for batch in train_dataloader:
loss = model.training_step()
loss.backward()
# ...
if validate_at_some_point:
# disable grads + batchnorm + dropout
torch.set_grad_enabled(False)
model.eval()
# ----------------- VAL LOOP ---------------
for val_batch in model.val_dataloader:
val_out = model.validation_step(val_batch)
# ----------------- VAL LOOP ---------------
# enable grads + batchnorm + dropout
torch.set_grad_enabled(True)
model.train()
検証エポック-level メトリクス
各 validation_step の総ての出力で何かを行なう必要がある場合、validation_epoch_end を override します。
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
return pred
def validation_epoch_end(self, validation_step_outputs):
for pred in validation_step_outputs:
# do something with a pred
DataParallel で検証する
GPU に渡り各バッチからデータを分割する accelerator を使用して検証するとき、時にそれらを処理のためにマスター GPU 上で集計する必要があるかもしれません (dp, or ddp2)。
この場合、validation_step_end メソッドを実装します :
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
pred = ...
return {'loss': loss, 'pred': pred}
def validation_step_end(self, batch_parts):
gpu_0_prediction = batch_parts.pred[0]['pred']
gpu_1_prediction = batch_parts.pred[1]['pred']
# do something with both outputs
return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2
def validation_epoch_end(self, validation_step_outputs):
for out in validation_step_outputs:
# do something with preds
内部的に lightning が行なう完全な疑似コードは :
outs = []
for batch in dataloader:
batches = split_batch(batch)
dp_outs = []
for sub_batch in batches:
# 1
dp_out = validation_step(sub_batch)
dp_outs.append(dp_out)
# 2
out = validation_step_end(dp_outs)
outs.append(out)
# do something with the outputs for all batches
# 3
validation_epoch_end(outs)
テストループ
テストループを追加するためのプロセスは検証ループを追加するためのプロセスと同じです。詳細については上のセクションを参照してください。
唯一の違いはテストループは .test() が使用されるときに呼び出されるだけであることです :
model = Model() trainer = Trainer() trainer.fit() # automatically loads the best weights for you trainer.test(model)
test() を呼び出す 2 つの方法があります :
# call after training trainer = Trainer() trainer.fit(model) # automatically auto-loads the best weights trainer.test(test_dataloaders=test_dataloader) # or call with pretrained model model = MyLightningModule.load_from_checkpoint(PATH) trainer = Trainer() trainer.test(model, test_dataloaders=test_dataloader)
推論
研究のために、LightningModule はシステムとして最善に構造化されます。
import pytorch_lightning as pl
import torch
from torch import nn
class Autoencoder(pl.LightningModule):
def __init__(self, latent_dim=2):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
def training_step(self, batch, batch_idx):
x, _ = batch
# encode
x = x.view(x.size(0), -1)
z = self.encoder(x)
# decode
recons = self.decoder(z)
# reconstruction
reconstruction_loss = nn.functional.mse_loss(recons, x)
return reconstruction_loss
def validation_step(self, batch, batch_idx):
x, _ = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
recons = self.decoder(z)
reconstruction_loss = nn.functional.mse_loss(recons, x)
self.log('val_reconstruction', reconstruction_loss)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0002)
これはこのように訓練できます :
autoencoder = Autoencoder() trainer = pl.Trainer(gpus=1) trainer.fit(autoencoder, train_dataloader, val_dataloader)
この単純なモデルはこのように見えるサンプルを生成します (エンコーダとデコーダは非常に弱いです)

上のメソッドは lightning インターフェイスの一部です :
- training_step
- validation_step
- test_step
- configure_optimizers
この場合、訓練ループと val ループは正確に同じであることに注意してください。このコードをもちろん再利用できます。
class Autoencoder(pl.LightningModule):
def __init__(self, latent_dim=2):
super().__init__()
self.encoder = nn.Sequential(nn.Linear(28 * 28, 256), nn.ReLU(), nn.Linear(256, latent_dim))
self.decoder = nn.Sequential(nn.Linear(latent_dim, 256), nn.ReLU(), nn.Linear(256, 28 * 28))
def training_step(self, batch, batch_idx):
loss = self.shared_step(batch)
return loss
def validation_step(self, batch, batch_idx):
loss = self.shared_step(batch)
self.log('val_loss', loss)
def shared_step(self, batch):
x, _ = batch
# encode
x = x.view(x.size(0), -1)
z = self.encoder(x)
# decode
recons = self.decoder(z)
# loss
return nn.functional.mse_loss(recons, x)
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.0002)
総てのループが利用できる shared_step と呼ばれる新しいメソッドを作成します。このメソッド名は任意で 予約されていません。
研究内の推論
システムで推論を遂行することを望むところのケースでは LightningModule に forward メソッドを追加できます。
class Autoencoder(pl.LightningModule):
def forward(self, x):
return self.decoder(x)
forward を追加する優位点は複雑なシステムで、テキスト生成のような、遥かに多くの関係する推論手続きを行なうことができることです :
class Seq2Seq(pl.LightningModule):
def forward(self, x):
embeddings = self(x)
hidden_states = self.encoder(embeddings)
for h in hidden_states:
# decode
...
return decoded
プロダクション内の推論
プロダクションのようなケースについては、LightningModule 内で異なるモデルを iterate することを望むかもしれません。
import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
class ClassificationTask(pl.LightningModule):
def __init__(self, model):
super().__init__()
self.model = model
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
return loss
def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self.model(x)
loss = F.cross_entropy(y_hat, y)
acc = FM.accuracy(y_hat, y)
# loss is tensor. The Checkpoint Callback is monitoring 'checkpoint_on'
metrics = {'val_acc': acc, 'val_loss': loss}
self.log_dict(metrics)
return metrics
def test_step(self, batch, batch_idx):
metrics = self.validation_step(batch, batch_idx)
metrics = {'test_acc': metrics['val_acc'], 'test_loss': metrics['val_loss']}
self.log_dict(metrics)
def configure_optimizers(self):
return torch.optim.Adam(self.model.parameters(), lr=0.02)
それから任意のモデルをこのタスクで fit されるために渡します
for model in [resnet50(), vgg16(), BidirectionalRNN()]:
task = ClassificationTask(model)
trainer = Trainer(gpus=2)
trainer.fit(task, train_dataloader, val_dataloader)
タスクは GAN 訓練、自己教師ありあるいは RL さえ実装するように任意に複雑であり得ます。
class GANTask(pl.LightningModule):
def __init__(self, generator, discriminator):
super().__init__()
self.generator = generator
self.discriminator = discriminator
...
このように使用されるとき、モデルはタスクから分離できてそれを LightningModule 内に保持する必要なくプロダクションで利用できます。
- onnx にエクスポートできます。
- あるいは jit を使用してトレースできます。
- あるいは python ランタイムで実行できます。
task = ClassificationTask(model) trainer = Trainer(gpus=2) trainer.fit(task, train_dataloader, val_dataloader) # use model after training or load weights and drop into the production system model.eval() y_hat = model(x)
以上