PyTorch 1.8 チュートリアル : PyTorch の学習 : 基本 – モデル・パラメータを最適化する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/17/2021 (1.8.0)
* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- Learning PyTorch : Learn the Basics : Optimizing Model Parameters
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
PyTorch の学習 : 基本 – モデル・パラメータを最適化する
モデルとデータを持つ今、データ上でパラメータを最適化することによりモデルを訓練、検証そしてテストするときです。モデルの訓練は反復的なプロセスです ; (エポックと呼ばれる) 各反復でモデルは出力について推測を行ない、その推測内で誤差を計算し (損失)、(前のセクション で見たように) そのパラメータに関する誤差の導関数を集めて、そして勾配降下を使用してこれらのパラメータを最適化します。このプロセスのより詳細なウォークスルーについては、backpropagation from 3Blue1Brown の動画を確認してください。
必要なコード
Dataset & DataLoader と モデルの構築 の前のセクションからコードをロードします。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
nn.ReLU()
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw Processing... Done!
ハイパーパラメータ
ハイパーパラメータはモデル最適化プロセスを貴方に制御させる調整可能なパラメータです。異なるハイパーパラメータ値はモデル訓練と収束レートに影響を与えることができます (ハイパーパラメータ調整について 更に読んでください)。
訓練のために以下のハイパーパラメータを定義します :
- エポック数 – データセットに渡り iterate する回数
- バッチサイズ – 各エポックでモデルにより見られるデータサンプルの数
- 学習率 – 各バッチ/エポックでモデル・パラメータをどのくらい更新するかです。より小さい値はスローな学習スピードを yield し、その一方で大きな値は訓練の間に予想できない動作という結果になるかもしれません。
learning_rate = 1e-3
batch_size = 64
epochs = 5
最適化ループ
ひとたびハイパーパラメータを設定すれば、最適化ループでモデルを訓練して最適化することができます。最適化ループの各反復は エポック と呼ばれます。
各エポックは 2 つの主要パートから成ります :
- 訓練ループ – 訓練データセットに渡り反復して最適なパラメータに収束することを試みます。
- 検証/テストループ – モデルパフォーマンスが改良されているか確認するためにテストデータセットに渡り反復します。
訓練ループで使用される概念の幾つかに私達自身を簡潔に慣れさせましょう。最適化ループの Full 実装 を見るためには前にジャンプします。
損失関数
ある訓練データが提示されたとき、未訓練ネットワークは正しい答えを与えない傾向にあります。損失関数は得られた結果のターゲット値への相違点の程度を測定し、そして訓練の間に最小化することを望むものが損失関数です。損失を計算するため、与えられたデータサンプルの入力を使用して予測を行ない、そしてそれを真のデータラベル値に対して比較します。
一般的な損失関数は回帰タスクのための nn.MSELoss (平均二乗誤差)、そして分類のための nn.NLLLoss (負対数尤度) を含みます。nn.CrossEntropyLoss は nn.LogSoftmax と nn.NLLLoss を連結しています。
モデルの出力ロジットを nn.CrossEntropyLoss に渡します、これはロジットを正規化して予測誤差を計算します。
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
Optimizer
最適化は各訓練ステップでモデル誤差を減じるためにモデルパラメータを調整するプロセスです。最適化アルゴリズムはこのプロセスがどのように遂行されるかを定義します (この例では確率的勾配降下を使用します)。総ての最適化ロジックは optimizer オブジェクトでカプセル化されます。ここでは、SGD optimizer を利用します ; 更に、PyTorch では ADAM と RMSProp のような利用可能な多くの 様々な optimizer があります、これらは異なる種類のモデルとデータのためにより良く動作します。
訓練される必要があるモデルのパラメータを登録して学習率ハイパーパラメータを渡すことにより optimizer を初期化します。
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
訓練ループの内側で、最適化は 3 つのステップで発生します :
- モデルパラメータの勾配をリセットするために optimizer.zero_grad() を呼び出します。勾配はデフォルトでは合計されます ; 二重カウントを防ぐため、各反復でそれを明示的にゼロに設定します。
- loss.backwards() への呼び出しで予測損失を逆伝播します。PyTorch は各パラメータに関する損失の勾配を deposit します。
- ひとたび勾配を持てば、backward パスで集められた勾配によりパラメータを調整するために optimizer.step() を呼び出します。
完全な実装
私達は最適化コードに渡りループする train_loop、そしてテストデータに対してモデルのパフォーマンスを評価する test_loop を定義します。
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= size
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
損失関数と optimizer を初期化し、そしてそれを train_loop と test_loop に渡します。モデルがパフォーマンスを改善していることを追跡するためにエポック数を自由に増やしてください。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1 ------------------------------- loss: 2.311547 [ 0/60000] loss: 2.301839 [ 6400/60000] loss: 2.297934 [12800/60000] loss: 2.284441 [19200/60000] loss: 2.287610 [25600/60000] loss: 2.293657 [32000/60000] loss: 2.268280 [38400/60000] loss: 2.273554 [44800/60000] loss: 2.276760 [51200/60000] loss: 2.244509 [57600/60000] Test Error: Accuracy: 35.5%, Avg loss: 0.035344 Epoch 2 ------------------------------- loss: 2.280794 [ 0/60000] loss: 2.257797 [ 6400/60000] loss: 2.243957 [12800/60000] loss: 2.216885 [19200/60000] loss: 2.205212 [25600/60000] loss: 2.234457 [32000/60000] loss: 2.186525 [38400/60000] loss: 2.198901 [44800/60000] loss: 2.228637 [51200/60000] loss: 2.146443 [57600/60000] Test Error: Accuracy: 37.2%, Avg loss: 0.034046 Epoch 3 ------------------------------- loss: 2.237602 [ 0/60000] loss: 2.185120 [ 6400/60000] loss: 2.159259 [12800/60000] loss: 2.110712 [19200/60000] loss: 2.081194 [25600/60000] loss: 2.145947 [32000/60000] loss: 2.059634 [38400/60000] loss: 2.083288 [44800/60000] loss: 2.157574 [51200/60000] loss: 1.995380 [57600/60000] Test Error: Accuracy: 36.6%, Avg loss: 0.032120 Epoch 4 ------------------------------- loss: 2.173856 [ 0/60000] loss: 2.080033 [ 6400/60000] loss: 2.043522 [12800/60000] loss: 1.964996 [19200/60000] loss: 1.923586 [25600/60000] loss: 2.032249 [32000/60000] loss: 1.899945 [38400/60000] loss: 1.948169 [44800/60000] loss: 2.077470 [51200/60000] loss: 1.823172 [57600/60000] Test Error: Accuracy: 36.9%, Avg loss: 0.030078 Epoch 5 ------------------------------- loss: 2.106833 [ 0/60000] loss: 1.977626 [ 6400/60000] loss: 1.939335 [12800/60000] loss: 1.823382 [19200/60000] loss: 1.794451 [25600/60000] loss: 1.936425 [32000/60000] loss: 1.759290 [38400/60000] loss: 1.839000 [44800/60000] loss: 2.009614 [51200/60000] loss: 1.686389 [57600/60000] Test Error: Accuracy: 38.6%, Avg loss: 0.028461 Epoch 6 ------------------------------- loss: 2.049446 [ 0/60000] loss: 1.899526 [ 6400/60000] loss: 1.857706 [12800/60000] loss: 1.713123 [19200/60000] loss: 1.704621 [25600/60000] loss: 1.865643 [32000/60000] loss: 1.658126 [38400/60000] loss: 1.759566 [44800/60000] loss: 1.960299 [51200/60000] loss: 1.595586 [57600/60000] Test Error: Accuracy: 41.5%, Avg loss: 0.027330 Epoch 7 ------------------------------- loss: 2.003137 [ 0/60000] loss: 1.845121 [ 6400/60000] loss: 1.795158 [12800/60000] loss: 1.636422 [19200/60000] loss: 1.645869 [25600/60000] loss: 1.814831 [32000/60000] loss: 1.591536 [38400/60000] loss: 1.705516 [44800/60000] loss: 1.923961 [51200/60000] loss: 1.532001 [57600/60000] Test Error: Accuracy: 40.8%, Avg loss: 0.026488 Epoch 8 ------------------------------- loss: 1.958604 [ 0/60000] loss: 1.796570 [ 6400/60000] loss: 1.729746 [12800/60000] loss: 1.569006 [19200/60000] loss: 1.540844 [25600/60000] loss: 1.759719 [32000/60000] loss: 1.510318 [38400/60000] loss: 1.647093 [44800/60000] loss: 1.817549 [51200/60000] loss: 1.458255 [57600/60000] Test Error: Accuracy: 43.5%, Avg loss: 0.025153 Epoch 9 ------------------------------- loss: 1.861018 [ 0/60000] loss: 1.717396 [ 6400/60000] loss: 1.649649 [12800/60000] loss: 1.495671 [19200/60000] loss: 1.409449 [25600/60000] loss: 1.711848 [32000/60000] loss: 1.442910 [38400/60000] loss: 1.608273 [44800/60000] loss: 1.740068 [51200/60000] loss: 1.412943 [57600/60000] Test Error: Accuracy: 44.4%, Avg loss: 0.024262 Epoch 10 ------------------------------- loss: 1.792431 [ 0/60000] loss: 1.664259 [ 6400/60000] loss: 1.596819 [12800/60000] loss: 1.445964 [19200/60000] loss: 1.330631 [25600/60000] loss: 1.679749 [32000/60000] loss: 1.400718 [38400/60000] loss: 1.583857 [44800/60000] loss: 1.695087 [51200/60000] loss: 1.384648 [57600/60000] Test Error: Accuracy: 45.4%, Avg loss: 0.023702 Done!
以上