PyTorch Ignite 0.4.8 : Tutorials : 変分オートエンコーダ (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/31/2022 (0.4.8)
* 本ページは、Pytorch Ignite の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Tutorials : 変分オートエンコーダ
このチュートリアルは、ニューラルネットワーク・モデルを訓練し、実験をセットアップしてモデルを検証するために Ignite を使用します。
この実験では、Auto-Encoding Variational Bayes by Kingma と Welling を再現していきます。この論文は画像をベクトルにエンコードしてから画像を再構築するためにエンコーダ・デコーダ・アーキテクチャを使用します。
MNIST 画像をエンコードして再構築できることを望みます。MNIST は古典的な機械学習データセットで、それは数字 0 から 9 の白黒画像を含みます。50000 訓練画像と 10000 テスト画像があります。データセットは画像とラベルのペアから構成されます。
モデルを作成するために PyTorch を、データをインポートするために torchvision を、そしてモデルを訓練してモニタするために Ignite を使用していきます!
このコードの多くは 公式 PyTorch example から借りたものであることに注意してください。それに類似して、sigmoid と adagrad の代わりに、ReLU と adam optimizer を使用しています。
Let’s get started!
必要な依存関係
このサンプルでは、torch と ignite は既にインストールされていると仮定し、torchvision パッケージだけを必要とします。pip を使用してそれをインストールできます :
!pip install pytorch-ignite torchvision
ライブラリのインポート
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
モデルを作成するために torch, nn と functional モジュールをインポートします!DataLoader はダウンロードされたデータセットに対するイテレータを作成します。
以下のコードはまたマシンで利用可能な GPU があるかもチェックして存在するならばデバイスを GPU に割当てます。
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
from torch.nn import functional as F
SEED = 1234
torch.manual_seed(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torchvision はコンピュータビジョン・タスクのための複数のデータセットを提供するライブラリです。下では以下をインポートしています :
- MNIST : MNIST データセットをダウンロードするモジュール。
- save_image : テンソルを画像としてセーブする。
- make_grid : テンソルの連結を受け取り画像のグリッドを作成する。
- ToTensor : 画像をテンソルに変換する。
- Compose : 変換を集める。
from torchvision.datasets import MNIST
from torchvision.utils import save_image, make_grid
from torchvision.transforms import Compose, ToTensor
Ignite は PyTorch でニューラルネットワークを訓練する手助けをする高位ライブラリです。それは訓練ループ、様々なメトリクス、ハンドラと有用な contrib セクションをセットアップする Engine を装備しています!
下で以下をインポートします :
- Engine : データセットのバッチに対して与えられた process_function を実行し、進むにつれてイベントを発行します。
- Events : 特定のイベントで関数を発火させるためにユーザがエンジンに関数を装着することを可能にします。Eg: EPOCH_COMPLETED, ITERATION_STARTED, 等。
- MeanSquaredError : 平均二乗誤差を計算するメトリック。
- Loss : パラメータとして損失関数を取り、データセットに対して損失を計算する一般的なメトリック。
- RunningAverage : 訓練の間にエンジンに装着する一般的なメトリック。
- ModelCheckpoint : モデルをチェックポイントするハンドラ。
from ignite.engine import Engine, Events
from ignite.metrics import MeanSquaredError, Loss, RunningAverage
データ処理
下で使用する唯一の変換は画像をテンソルに変換するもので、MNIST はデータセットを貴方のマシンにダウンロードします。
- train_data は画像テンソルとラベルのタプルのリストです。val_data も同じで、画像の数が異なるだけです。
- 画像は 1 チャネルを持つ 28 x 28 テンソル、つまり 28 x 28 グレースケール画像です。
- ラベルは単一の整数値で、画像が示しているものを表します。
data_transform = Compose([ToTensor()])
train_data = MNIST(download=True, root="/tmp/mnist/", transform=data_transform, train=True)
val_data = MNIST(download=True, root="/tmp/mnist/", transform=data_transform, train=False)
image = train_data[0][0]
label = train_data[0][1]
print ('len(train_data) : ', len(train_data))
print ('len(val_data) : ', len(val_data))
print ('image.shape : ', image.shape)
print ('label : ', label)
img = plt.imshow(image.squeeze().numpy(), cmap='gray')
次に訓練と検証データセットのイテレータをセットアップしましょう。PyTorch の DataLoader を活用できます、これはデータセット、バッチサイズ、ワーカー数、デバイス、そして他の役立つパラメータを指定することを可能にします。
イテレータの出力が何かを見てみましょう :
- 各バッチが 32 画像と対応するラベルから構成されていることがわかります。
- サンプルはシャッフルされています。
- データは利用可能であれば GPU に置かれ、そうでなければ CPU を使用します。
kwargs = {'num_workers': 1, 'pin_memory': True} if device == 'cuda' else {}
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, **kwargs)
val_loader = DataLoader(val_data, batch_size=32, shuffle=True, **kwargs)
for batch in train_loader:
x, y = batch
break
print ('x.shape : ', x.shape)
print ('y.shape : ', y.shape)
x.shape : torch.Size([32, 1, 28, 28]) y.shape : torch.Size([32])
モデルが画像をどのくらい上手く再構築しているか可視化するため、上の x の値を、モデルから生成された再構築と比較するために使用できる画像のセットとしてセーブしましょう。
fixed_images = x.to(device)
VAE モデル
VAE は完全結合層から成るモデルで、これは平坦化された画像を受け取り、それらを完全結合層に渡して画像を低次元ベクトルに reduce します。それからベクトルは、入力と同じサイズのベクトルを生成するために、エンコーディング・ステップからの完全結合層の重みのミラーリングされたセットに渡されます。
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.fc1 = nn.Linear(784, 400)
self.fc21 = nn.Linear(400, 20)
self.fc22 = nn.Linear(400, 20)
self.fc3 = nn.Linear(20, 400)
self.fc4 = nn.Linear(400, 784)
def encode(self, x):
h1 = F.relu(self.fc1(x))
return self.fc21(h1), self.fc22(h1)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return eps.mul(std).add_(mu)
def decode(self, z):
h3 = F.relu(self.fc3(z))
return torch.sigmoid(self.fc4(h3))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
モデル、Optimizer と損失を作成する
以下で VAE モデルのインスタンスを作成します。モデルはデバイスに置かれ、そして Binary Cross Entropy + KL Divergence の損失関数が使用されて Adam optimizer がセットアップされます。
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
def kld_loss(x_pred, x, mu, logvar):
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
bce_loss = nn.BCELoss(reduction='sum')
Ignite を使用した訓練と評価
Trainer エンジン – process_function
Ignite の Engine は与えられたバッチを処理する process_function を定義することをユーザに可能にします、これはデータセットの総てのバッチに適用されます。これはモデルを訓練して検証するために適用できる一般的なクラスです!process_function は 2 つのパラメータ engine と batch を持ちます。
trainer の関数が何を行なうかを通り抜けましょう :
- モデルを train モードに設定する。
- optimizer の勾配をゼロに設定する。
- バッチから x を生成する。
- x を shape (-1, 784) に平坦化する。
- モデルと x を使用して x を x_pred として再構築するために forward パスを実行する。モデルはまた mu, logvar を返します。
- x_pred, x, logvar and mu を使用して損失を計算する。
- モデルパラメータのための勾配を計算するために損失を使用して backward パスを実行します。
- 勾配と optimizer を使用してモデルパラメータは最適化されます。
- スカラー損失を返します。
下は訓練プロセスの間の単一演算です。この process_function は訓練エンジンに装着されます。
def process_function(engine, batch):
model.train()
optimizer.zero_grad()
x, _ = batch
x = x.to(device)
x = x.view(-1, 784)
x_pred, mu, logvar = model(x)
BCE = bce_loss(x_pred, x)
KLD = kld_loss(x_pred, x, mu, logvar)
loss = BCE + KLD
loss.backward()
optimizer.step()
return loss.item(), BCE.item(), KLD.item()
評価エンジン – process_function
訓練プロセス関数と同様に、単一バッチを評価する関数をセットアップします。ここに eval_function が何を行なうかがあります :
- モデルを eval モードに設定する。
- バッチから x を生成する。
- torch.no_grad() で、どのような後続のステップに対しても勾配は計算されません。
- x を shape (-1, 784) に平坦化する。
- モデルと x を使用して x を x_pred として再構築するために forward パスを実行します。モデルはまた mu, logvar も返します。
- x_pred, x, mu と logvar を返します。
Ignite はメトリクスを trainer ではなく evaluator に装着することを勧めています、何故ならば訓練の間、モデルパラメータは常に変化するので安定モデル上でモデルを評価することが最善であるからです。この情報は重要です、訓練と評価のための関数には違いがあるからです。訓練は単一のスカラー損失を返します。評価は y_pred と y を返し、その出力はデータセット全体に対するバッチ毎のメトリクスを計算するために使用されます。
Ignite の総てのメトリクスはエンジンに装着された関数の出力として y_pred と y を必要とします。
def evaluate_function(engine, batch):
model.eval()
with torch.no_grad():
x, _ = batch
x = x.to(device)
x = x.view(-1, 784)
x_pred, mu, logvar = model(x)
kwargs = {'mu': mu, 'logvar': logvar}
return x_pred, x, kwargs
訓練と評価エンジンのインスタンス化
上で定義された関数を使用して下で 2 つのエンジン、trainer と evaluator を作成します。訓練と検証セット上でメトリクスの履歴を追跡し続けるために辞書も定義します。
trainer = Engine(process_function)
evaluator = Engine(evaluate_function)
training_history = {'bce': [], 'kld': [], 'mse': []}
validation_history = {'bce': [], 'kld': [], 'mse': []}
メトリクス – RunningAverage, MeanSquareError と Loss
最初に、各バッチに対するスカラー損失、バイナリ交差エントロピーと KL ダイバージェンス出力の移動平均を追跡するために RunningAverage のメトリックを装着します。
RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'loss')
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'bce')
RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'kld')
ここで評価のために使用したい 3 つのメトリクスがあります – 平均二乗誤差、バイナリ交差エントロピーと KL ダイバージェンスです。先に気付いた場合、eval_function は x_pred, x と 2,3 の他の値を返し、MeanSquaredError はバッチ毎に 2 つの値だけを必要とします。
各バッチに対して、engine.state.output は x_pred, x と kwargs です、これはメトリックの必要性に基づいて engine.state.output から値だけを抽出するために output_transform を使用する理由です。
Loss については、loss_function に必要な総てのパラメータは engine.state.output が出力するので、定義した loss_function を渡してそれを単純に evaluator に装着します。
MeanSquaredError(output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'mse')
Loss(bce_loss, output_transform=lambda x: [x[0], x[1]]).attach(evaluator, 'bce')
Loss(kld_loss).attach(evaluator, 'kld')
カスタム関数を特定のイベントでエンジンに装着する
下で独自のカスタム関数を定義して訓練プロセスの様々な Event にそれらを装着する方法を見ます。
最初の方法はデコレータを使用します、シンタクスは単純です – @ trainer.on(Events.EPOCH_COMPLETED)、つまりデコレートされた関数はトレーナーに装着されて各エポックの最後に呼び出されます。
2 番目の方法は trainer の add_event_handler ハンドラを使用します – trainer.add_event_handler(Events.EPOCH_COMPLETED, custom_function)。これは上と同じ結果を得ます。
下の関数は各エポックの最後に訓練の間の損失をプリントします。
@trainer.on(Events.EPOCH_COMPLETED)
def print_trainer_logs(engine):
avg_loss = engine.state.metrics['loss']
avg_bce = engine.state.metrics['bce']
avg_kld = engine.state.metrics['kld']
print("Trainer Results - Epoch {} - Avg loss: {:.2f} Avg bce: {:.2f} Avg kld: {:.2f}"
.format(engine.state.epoch, avg_loss, avg_bce, avg_kld))
下の関数は evaluator のログをプリントして訓練と検証データセットに対するメトリクスの履歴を更新し、パラメータ DataLoader と mode を受け取ることが分かります。このように関数を再目的化して trainer に 2 度装着します、一度は訓練データセット上の評価で他方は検証データセット上です。
def print_logs(engine, dataloader, mode, history_dict):
evaluator.run(dataloader, max_epochs=1)
metrics = evaluator.state.metrics
avg_mse = metrics['mse']
avg_bce = metrics['bce']
avg_kld = metrics['kld']
avg_loss = avg_bce + avg_kld
print(
mode + " Results - Epoch {} - Avg mse: {:.2f} Avg loss: {:.2f} Avg bce: {:.2f} Avg kld: {:.2f}"
.format(engine.state.epoch, avg_mse, avg_loss, avg_bce, avg_kld))
for key in evaluator.state.metrics.keys():
history_dict[key].append(evaluator.state.metrics[key])
trainer.add_event_handler(Events.EPOCH_COMPLETED, print_logs, train_loader, 'Training', training_history)
trainer.add_event_handler(Events.EPOCH_COMPLETED, print_logs, val_loader, 'Validation', validation_history)
下の関数は再構築された画像を生成するために画像のセット (fixed_images) と VAE モデルを使用します、それから画像はグリッド内に形成され、ローカルマシンにセーブされて下でノートブックで表示されます。この関数を訓練プロセスの開始とエポックの最後に装着します、このようにしてモデルが画像の再構築においてどのくらい良くなるか可視化することができます。
def compare_images(engine, save_img=False):
epoch = engine.state.epoch
reconstructed_images = model(fixed_images.view(-1, 784))[0].view(-1, 1, 28, 28)
comparison = torch.cat([fixed_images, reconstructed_images])
if save_img:
save_image(comparison.detach().cpu(), 'reconstructed_epoch_' + str(epoch) + '.png', nrow=8)
comparison_image = make_grid(comparison.detach().cpu(), nrow=8)
fig = plt.figure(figsize=(5, 5));
output = plt.imshow(comparison_image.permute(1, 2, 0));
plt.title('Epoch ' + str(epoch));
plt.show();
trainer.add_event_handler(Events.STARTED, compare_images, save_img=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED(every=5), compare_images, save_img=False)
エンジンの実行
次に、20 エポック trainer を実行して結果を監視します。
e = trainer.run(train_loader, max_epochs=20)
Trainer Results - Epoch 1 - Avg loss: 3726.80 Avg bce: 2979.88 Avg kld: 746.93 Training Results - Epoch 1 - Avg mse: 14.30 Avg loss: 3707.20 Avg bce: 2951.70 Avg kld: 755.49 Validation Results - Epoch 1 - Avg mse: 14.04 Avg loss: 3676.85 Avg bce: 2918.95 Avg kld: 757.90 Trainer Results - Epoch 2 - Avg loss: 3560.70 Avg bce: 2778.40 Avg kld: 782.31 Training Results - Epoch 2 - Avg mse: 12.50 Avg loss: 3537.77 Avg bce: 2768.78 Avg kld: 768.99 Validation Results - Epoch 2 - Avg mse: 12.28 Avg loss: 3513.11 Avg bce: 2741.00 Avg kld: 772.11 Trainer Results - Epoch 3 - Avg loss: 3525.03 Avg bce: 2727.09 Avg kld: 797.93 Training Results - Epoch 3 - Avg mse: 11.53 Avg loss: 3485.33 Avg bce: 2676.06 Avg kld: 809.27 Validation Results - Epoch 3 - Avg mse: 11.39 Avg loss: 3466.74 Avg bce: 2656.27 Avg kld: 810.47 Trainer Results - Epoch 4 - Avg loss: 3475.28 Avg bce: 2674.27 Avg kld: 801.01 Training Results - Epoch 4 - Avg mse: 11.33 Avg loss: 3443.58 Avg bce: 2655.76 Avg kld: 787.82 Validation Results - Epoch 4 - Avg mse: 11.23 Avg loss: 3428.17 Avg bce: 2640.41 Avg kld: 787.76 Trainer Results - Epoch 5 - Avg loss: 3426.41 Avg bce: 2626.99 Avg kld: 799.42 Training Results - Epoch 5 - Avg mse: 11.39 Avg loss: 3429.65 Avg bce: 2659.27 Avg kld: 770.38 Validation Results - Epoch 5 - Avg mse: 11.36 Avg loss: 3420.94 Avg bce: 2650.27 Avg kld: 770.67 Trainer Results - Epoch 6 - Avg loss: 3431.80 Avg bce: 2629.85 Avg kld: 801.94 Training Results - Epoch 6 - Avg mse: 10.91 Avg loss: 3398.44 Avg bce: 2611.61 Avg kld: 786.83 Validation Results - Epoch 6 - Avg mse: 10.86 Avg loss: 3387.81 Avg bce: 2601.49 Avg kld: 786.32 Trainer Results - Epoch 7 - Avg loss: 3408.46 Avg bce: 2604.93 Avg kld: 803.54 Training Results - Epoch 7 - Avg mse: 10.79 Avg loss: 3389.06 Avg bce: 2600.44 Avg kld: 788.61 Validation Results - Epoch 7 - Avg mse: 10.75 Avg loss: 3379.07 Avg bce: 2591.14 Avg kld: 787.93 Trainer Results - Epoch 8 - Avg loss: 3378.76 Avg bce: 2574.27 Avg kld: 804.48 Training Results - Epoch 8 - Avg mse: 10.63 Avg loss: 3371.22 Avg bce: 2584.10 Avg kld: 787.12 Validation Results - Epoch 8 - Avg mse: 10.64 Avg loss: 3364.51 Avg bce: 2578.62 Avg kld: 785.89 Trainer Results - Epoch 9 - Avg loss: 3369.00 Avg bce: 2570.99 Avg kld: 798.01 Training Results - Epoch 9 - Avg mse: 10.47 Avg loss: 3359.88 Avg bce: 2569.40 Avg kld: 790.48 Validation Results - Epoch 9 - Avg mse: 10.48 Avg loss: 3354.08 Avg bce: 2563.92 Avg kld: 790.16 Trainer Results - Epoch 10 - Avg loss: 3384.09 Avg bce: 2581.98 Avg kld: 802.12 Training Results - Epoch 10 - Avg mse: 10.29 Avg loss: 3360.48 Avg bce: 2550.36 Avg kld: 810.12 Validation Results - Epoch 10 - Avg mse: 10.31 Avg loss: 3354.99 Avg bce: 2545.82 Avg kld: 809.17 Trainer Results - Epoch 11 - Avg loss: 3381.03 Avg bce: 2572.36 Avg kld: 808.67 Training Results - Epoch 11 - Avg mse: 10.02 Avg loss: 3343.59 Avg bce: 2525.49 Avg kld: 818.11 Validation Results - Epoch 11 - Avg mse: 10.07 Avg loss: 3341.57 Avg bce: 2524.22 Avg kld: 817.35 Trainer Results - Epoch 12 - Avg loss: 3334.09 Avg bce: 2533.24 Avg kld: 800.85 Training Results - Epoch 12 - Avg mse: 10.21 Avg loss: 3343.52 Avg bce: 2543.29 Avg kld: 800.23 Validation Results - Epoch 12 - Avg mse: 10.25 Avg loss: 3340.80 Avg bce: 2541.98 Avg kld: 798.82 Trainer Results - Epoch 13 - Avg loss: 3349.17 Avg bce: 2546.56 Avg kld: 802.61 Training Results - Epoch 13 - Avg mse: 10.09 Avg loss: 3341.05 Avg bce: 2530.32 Avg kld: 810.73 Validation Results - Epoch 13 - Avg mse: 10.16 Avg loss: 3339.41 Avg bce: 2530.93 Avg kld: 808.48 Trainer Results - Epoch 14 - Avg loss: 3341.74 Avg bce: 2531.26 Avg kld: 810.48 Training Results - Epoch 14 - Avg mse: 10.07 Avg loss: 3337.68 Avg bce: 2528.95 Avg kld: 808.73 Validation Results - Epoch 14 - Avg mse: 10.11 Avg loss: 3335.04 Avg bce: 2528.32 Avg kld: 806.72 Trainer Results - Epoch 15 - Avg loss: 3355.13 Avg bce: 2546.82 Avg kld: 808.31 Training Results - Epoch 15 - Avg mse: 9.88 Avg loss: 3321.04 Avg bce: 2509.58 Avg kld: 811.46 Validation Results - Epoch 15 - Avg mse: 9.95 Avg loss: 3321.06 Avg bce: 2511.42 Avg kld: 809.64 Trainer Results - Epoch 16 - Avg loss: 3328.52 Avg bce: 2519.90 Avg kld: 808.61 Training Results - Epoch 16 - Avg mse: 9.82 Avg loss: 3323.66 Avg bce: 2504.24 Avg kld: 819.41 Validation Results - Epoch 16 - Avg mse: 9.88 Avg loss: 3323.40 Avg bce: 2504.71 Avg kld: 818.69 Trainer Results - Epoch 17 - Avg loss: 3326.10 Avg bce: 2519.43 Avg kld: 806.68 Training Results - Epoch 17 - Avg mse: 10.00 Avg loss: 3316.19 Avg bce: 2522.44 Avg kld: 793.75 Validation Results - Epoch 17 - Avg mse: 10.08 Avg loss: 3318.46 Avg bce: 2526.19 Avg kld: 792.27 Trainer Results - Epoch 18 - Avg loss: 3330.66 Avg bce: 2520.28 Avg kld: 810.38 Training Results - Epoch 18 - Avg mse: 9.78 Avg loss: 3308.34 Avg bce: 2499.15 Avg kld: 809.19 Validation Results - Epoch 18 - Avg mse: 9.86 Avg loss: 3309.42 Avg bce: 2501.67 Avg kld: 807.75 Trainer Results - Epoch 19 - Avg loss: 3305.99 Avg bce: 2500.76 Avg kld: 805.23 Training Results - Epoch 19 - Avg mse: 9.83 Avg loss: 3308.12 Avg bce: 2503.78 Avg kld: 804.34 Validation Results - Epoch 19 - Avg mse: 9.91 Avg loss: 3309.38 Avg bce: 2507.52 Avg kld: 801.86 Trainer Results - Epoch 20 - Avg loss: 3305.14 Avg bce: 2503.61 Avg kld: 801.53 Training Results - Epoch 20 - Avg mse: 9.93 Avg loss: 3311.43 Avg bce: 2511.87 Avg kld: 799.55 Validation Results - Epoch 20 - Avg mse: 10.05 Avg loss: 3316.48 Avg bce: 2519.03 Avg kld: 797.44
結果のプロット
下で訓練と検証セット上で収集されたメトリクスをプロットするのを見ます。バイナリ交差エントロピー、平均二乗誤差と KL ダイバージェンスの履歴をプロットします。
plt.plot(range(20), training_history['bce'], 'dodgerblue', label='training')
plt.plot(range(20), validation_history['bce'], 'orange', label='validation')
plt.xlim(0, 20);
plt.xlabel('Epoch')
plt.ylabel('BCE')
plt.title('Binary Cross Entropy on Training/Validation Set')
plt.legend();
plt.plot(range(20), training_history['kld'], 'dodgerblue', label='training')
plt.plot(range(20), validation_history['kld'], 'orange', label='validation')
plt.xlim(0, 20);
plt.xlabel('Epoch')
plt.ylabel('KLD')
plt.title('KL Divergence on Training/Validation Set')
plt.legend();
plt.plot(range(20), training_history['mse'], 'dodgerblue', label='training')
plt.plot(range(20), validation_history['mse'], 'orange', label='validation')
plt.xlim(0, 20);
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.title('Mean Squared Error on Training/Validation Set')
plt.legend();
以上