PyTorch 2.0 チュートリアル : 入門 : モデルパラメータの最適化 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/19/2023 (2.0.0)
* 本ページは、PyTorch 2.0 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- Introduction to PyTorch : Learn the Basics : Optimizing Model Parameters
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Website: www.classcat.com ; ClassCatJP
PyTorch 2.0 チュートリアル : 入門 : モデルパラメータの最適化
モデルとデータを持つ今、データ上でパラメータを最適化することによりモデルを訓練、検証そしてテストするときです。モデルの訓練は反復的なプロセスです ; 各反復でモデルは出力について推測を行ない、その推測内で誤差を計算し (損失)、(前のセクション で見たように) そのパラメータに関する誤差の導関数を集めて、そして勾配降下を使用してこれらのパラメータを 最適化します。このプロセスのより詳細なウォークスルーについては、backpropagation from 3Blue1Brown の動画を確認してください。
前提となるコード (Prerequisite Code)
Dataset & DataLoader と モデルの構築 の前のセクションからコードをロードします。
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
train_dataloader = DataLoader(training_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
class NeuralNetwork(nn.Module):
def __init__(self):
super(NeuralNetwork, self).__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28*28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
model = NeuralNetwork()
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz 0%| | 0/26421880 [00:00<?, ?it/s] 0%| | 32768/26421880 [00:00<01:25, 308533.34it/s] 0%| | 65536/26421880 [00:00<01:26, 303989.64it/s] 0%| | 131072/26421880 [00:00<00:59, 440126.28it/s] 1%| | 229376/26421880 [00:00<00:41, 623857.36it/s] 2%|1 | 491520/26421880 [00:00<00:20, 1266466.02it/s] 4%|3 | 950272/26421880 [00:00<00:11, 2269763.66it/s] 7%|7 | 1933312/26421880 [00:00<00:05, 4471185.06it/s] 15%|#4 | 3833856/26421880 [00:00<00:02, 8607269.59it/s] 26%|##6 | 6946816/26421880 [00:00<00:01, 14841015.69it/s] 37%|###6 | 9732096/26421880 [00:01<00:00, 18064398.08it/s] 49%|####8 | 12877824/26421880 [00:01<00:00, 21318997.08it/s] 59%|#####9 | 15695872/26421880 [00:01<00:00, 22636137.26it/s] 71%|#######1 | 18808832/26421880 [00:01<00:00, 24427652.64it/s] 83%|########2 | 21889024/26421880 [00:01<00:00, 25578163.49it/s] 95%|#########4| 25001984/26421880 [00:01<00:00, 26420546.75it/s] 100%|##########| 26421880/26421880 [00:01<00:00, 16087653.96it/s] Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz 0%| | 0/29515 [00:00<?, ?it/s] 100%|##########| 29515/29515 [00:00<00:00, 266477.70it/s] 100%|##########| 29515/29515 [00:00<00:00, 265043.97it/s] Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz 0%| | 0/4422102 [00:00<?, ?it/s] 1%| | 32768/4422102 [00:00<00:14, 301419.72it/s] 1%|1 | 65536/4422102 [00:00<00:14, 299886.54it/s] 3%|2 | 131072/4422102 [00:00<00:09, 435724.08it/s] 5%|5 | 229376/4422102 [00:00<00:06, 617548.80it/s] 11%|#1 | 491520/4422102 [00:00<00:03, 1257879.50it/s] 21%|##1 | 950272/4422102 [00:00<00:01, 2253560.03it/s] 44%|####3 | 1933312/4422102 [00:00<00:00, 4446841.74it/s] 87%|########6 | 3833856/4422102 [00:00<00:00, 8563990.63it/s] 100%|##########| 4422102/4422102 [00:00<00:00, 5030366.42it/s] Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz 0%| | 0/5148 [00:00<?, ?it/s] 100%|##########| 5148/5148 [00:00<00:00, 24508827.46it/s] Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ハイパーパラメータ
ハイパーパラメータはモデル最適化プロセスを貴方に制御させる調整可能なパラメータです。異なるハイパーパラメータ値はモデル訓練と収束率に影響を与えることができます (ハイパーパラメータ調整について 更に読んでください)。
訓練のために以下のハイパーパラメータを定義します :
- エポック数 – データセットに対して iterate する回数
- バッチサイズ – パラメータが更新される前にネットワークを通して伝搬されるデータサンプルの数
- 学習率 – 各バッチ/エポックでモデル・パラメータをどのくらい更新するか。より小さい値はスローな学習スピードを生成し、その一方で大きな値は訓練の間に予想できない動作という結果になるかもしれません。
learning_rate = 1e-3
batch_size = 64
epochs = 5
最適化ループ
ひとたびハイパーパラメータを設定すれば、最適化ループでモデルを訓練して最適化することができます。最適化ループの各反復は エポック と呼ばれます。
各エポックは 2 つの主要パートから構成されます :
- 訓練ループ – 訓練データセットに渡り反復して最適なパラメータに収束することを試みます。
- 検証/テストループ – モデルパフォーマンスが改良されているか確認するためにテストデータセットに対して反復します。
訓練ループで使用される概念の幾つかに私達自身を簡潔に慣れさせましょう。最適化ループの 完全な実装 を見るためには前にジャンプします。
損失関数
ある訓練データが提示されたとき、未訓練ネットワークは正しい答えを与えない傾向にあります。損失関数 は得られた結果のターゲット値に対する相違の程度を測定し、そして訓練の間に最小化することを望むものが損失関数です。損失を計算するため、与えられたデータサンプルの入力を使用して予測を行ない、そしてそれを真のデータラベル値に対して比較します。
一般的な損失関数は、回帰タスクのための nn.MSELoss (平均二乗誤差)、そして分類のための nn.NLLLoss (負対数尤度) を含みます。nn.CrossEntropyLoss は nn.LogSoftmax と nn.NLLLoss を組み合わせています。
モデルの出力ロジットを nn.CrossEntropyLoss に渡します、これはロジットを正規化して予測誤差を計算します。
# Initialize the loss function
loss_fn = nn.CrossEntropyLoss()
Optimizer
最適化は各訓練ステップでモデル誤差を減少させるためにモデルパラメータを調整するプロセスです。最適化アルゴリズム はこのプロセスがどのように実行されるかを定義します (この例では確率的勾配降下を使用します)。総ての最適化ロジックは optimizer オブジェクトでカプセル化されます。ここでは、SGD optimizer を利用します ; 更に、PyTorch には ADAM と RMSProp のような利用可能な多くの 様々な optimizer があります、これらは異なる種類のモデルとデータに対してより良く動作します。
訓練される必要があるモデルのパラメータを登録して学習率ハイパーパラメータを渡すことにより optimizer を初期化します。
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
訓練ループの内側で、最適化は 3 つのステップで発生します :
- モデルパラメータの勾配をリセットするために optimizer.zero_grad() を呼び出します。勾配はデフォルトでは合計されます ; 二重カウントを防ぐため、各反復でそれを明示的にゼロに設定します。
- loss.backwards() の呼び出しで予測損失を逆伝播します。PyTorch は各パラメータに関する損失の勾配を deposit (正確に配置) します。
- ひとたび勾配を持てば、backward パスで集められた勾配によりパラメータを調整するために optimizer.step() を呼び出します。
完全な実装
私達は最適化コードについてループする train_loop、そしてテストデータに対してモデルのパフォーマンスを評価する test_loop を定義します。
def train_loop(dataloader, model, loss_fn, optimizer):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
# Compute prediction and loss
pred = model(X)
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), (batch + 1) * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def test_loop(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
損失関数と optimizer を初期化し、そしてそれを train_loop と test_loop に渡します。モデルのパフォーマンスの向上を追跡するためにエポック数を自由に増やしてください。
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train_loop(train_dataloader, model, loss_fn, optimizer)
test_loop(test_dataloader, model, loss_fn)
print("Done!")
Epoch 1 ------------------------------- loss: 2.306274 [ 64/60000] loss: 2.285838 [ 6464/60000] loss: 2.269943 [12864/60000] loss: 2.261921 [19264/60000] loss: 2.245953 [25664/60000] loss: 2.216744 [32064/60000] loss: 2.223908 [38464/60000] loss: 2.186029 [44864/60000] loss: 2.199185 [51264/60000] loss: 2.156675 [57664/60000] Test Error: Accuracy: 43.2%, Avg loss: 2.147765 Epoch 2 ------------------------------- loss: 2.169220 [ 64/60000] loss: 2.154036 [ 6464/60000] loss: 2.094223 [12864/60000] loss: 2.105725 [19264/60000] loss: 2.057751 [25664/60000] loss: 1.993958 [32064/60000] loss: 2.032755 [38464/60000] loss: 1.943888 [44864/60000] loss: 1.964005 [51264/60000] loss: 1.880819 [57664/60000] Test Error: Accuracy: 49.6%, Avg loss: 1.877261 Epoch 3 ------------------------------- loss: 1.923636 [ 64/60000] loss: 1.892730 [ 6464/60000] loss: 1.766222 [12864/60000] loss: 1.803750 [19264/60000] loss: 1.695832 [25664/60000] loss: 1.644898 [32064/60000] loss: 1.681978 [38464/60000] loss: 1.567808 [44864/60000] loss: 1.608559 [51264/60000] loss: 1.498805 [57664/60000] Test Error: Accuracy: 58.4%, Avg loss: 1.512764 Epoch 4 ------------------------------- loss: 1.588199 [ 64/60000] loss: 1.555499 [ 6464/60000] loss: 1.396250 [12864/60000] loss: 1.469044 [19264/60000] loss: 1.354479 [25664/60000] loss: 1.347111 [32064/60000] loss: 1.376084 [38464/60000] loss: 1.281405 [44864/60000] loss: 1.330694 [51264/60000] loss: 1.232292 [57664/60000] Test Error: Accuracy: 62.6%, Avg loss: 1.253344 Epoch 5 ------------------------------- loss: 1.334092 [ 64/60000] loss: 1.319183 [ 6464/60000] loss: 1.146310 [12864/60000] loss: 1.253771 [19264/60000] loss: 1.136652 [25664/60000] loss: 1.153831 [32064/60000] loss: 1.188763 [38464/60000] loss: 1.104697 [44864/60000] loss: 1.157200 [51264/60000] loss: 1.075848 [57664/60000] Test Error: Accuracy: 64.3%, Avg loss: 1.092215 Epoch 6 ------------------------------- loss: 1.164995 [ 64/60000] loss: 1.170451 [ 6464/60000] loss: 0.982307 [12864/60000] loss: 1.118346 [19264/60000] loss: 1.000636 [25664/60000] loss: 1.022203 [32064/60000] loss: 1.070253 [38464/60000] loss: 0.991670 [44864/60000] loss: 1.042103 [51264/60000] loss: 0.976573 [57664/60000] Test Error: Accuracy: 65.3%, Avg loss: 0.986656 Epoch 7 ------------------------------- loss: 1.046831 [ 64/60000] loss: 1.072952 [ 6464/60000] loss: 0.869186 [12864/60000] loss: 1.026576 [19264/60000] loss: 0.912367 [25664/60000] loss: 0.928342 [32064/60000] loss: 0.991693 [38464/60000] loss: 0.918040 [44864/60000] loss: 0.961682 [51264/60000] loss: 0.910086 [57664/60000] Test Error: Accuracy: 66.7%, Avg loss: 0.914412 Epoch 8 ------------------------------- loss: 0.960576 [ 64/60000] loss: 1.005463 [ 6464/60000] loss: 0.788259 [12864/60000] loss: 0.960789 [19264/60000] loss: 0.852388 [25664/60000] loss: 0.859124 [32064/60000] loss: 0.936759 [38464/60000] loss: 0.868817 [44864/60000] loss: 0.903424 [51264/60000] loss: 0.862730 [57664/60000] Test Error: Accuracy: 68.1%, Avg loss: 0.862345 Epoch 9 ------------------------------- loss: 0.894723 [ 64/60000] loss: 0.954783 [ 6464/60000] loss: 0.727722 [12864/60000] loss: 0.911476 [19264/60000] loss: 0.809471 [25664/60000] loss: 0.806541 [32064/60000] loss: 0.895423 [38464/60000] loss: 0.834586 [44864/60000] loss: 0.859842 [51264/60000] loss: 0.826616 [57664/60000] Test Error: Accuracy: 69.3%, Avg loss: 0.822960 Epoch 10 ------------------------------- loss: 0.842345 [ 64/60000] loss: 0.914199 [ 6464/60000] loss: 0.680164 [12864/60000] loss: 0.873395 [19264/60000] loss: 0.776440 [25664/60000] loss: 0.765965 [32064/60000] loss: 0.861726 [38464/60000] loss: 0.809308 [44864/60000] loss: 0.825955 [51264/60000] loss: 0.797490 [57664/60000] Test Error: Accuracy: 70.6%, Avg loss: 0.791684 Done!
以上