Lightly 1.2 : Tutorials : 3. Train SimCLR on Clothing
In this tutorial, we will train a SimCLR model using lightly. The model, augmentations and training procedure come from A Simple Framework for Contrastive Learning of Visual Representations.
The paper explores a rather simple training procedure for contrastive learning. Because it uses the usual contrastive loss based on NCE, the method benefits greatly from larger batch sizes. In this example we use a batch size of 256, an input resolution of 64×64 pixels per image and a resnet-18 model; this example requires 16GB of GPU memory.
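For intuition, here is a minimal, illustrative sketch of that loss, the normalized temperature-scaled cross entropy (NT-Xent): every image yields two augmented views, and each view has to identify its sibling among all other views in the batch. This is only a sketch; the tutorial itself uses lightly's NTXentLoss further below.

import torch
import torch.nn.functional as F

def nt_xent(z0, z1, temperature=0.5):
    # z0, z1: projections of two augmented views, shape (batch_size, dim)
    z = F.normalize(torch.cat([z0, z1]), dim=1)  # 2N unit-length embeddings
    sim = z @ z.t() / temperature                # pairwise cosine similarities
    sim.fill_diagonal_(float('-inf'))            # a view never matches itself
    n = z0.shape[0]
    # the positive for view i in z0 is view i in z1, and vice versa
    target = torch.cat([torch.arange(n) + n, torch.arange(n)])
    return F.cross_entropy(sim, target)

Because every other view in the batch acts as a negative, the number of negatives, and with it the quality of the contrastive signal, grows with the batch size.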
For this tutorial we use the clothing dataset by Alex Grigorev.
In this tutorial you will learn:
- How to create a SimCLR model
- How to generate image representations
- How different augmentations impact the learned representations
Imports
Import the Python frameworks we need for this tutorial.
import os
import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl
import lightly
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize
from PIL import Image
import numpy as np
Configuration
We set some configuration parameters for our experiment. Feel free to change them and analyze their effect.
The default configuration with a batch size of 256 and an input resolution of 128 requires 6GB of GPU memory.
num_workers = 8
batch_size = 256
seed = 1
max_epochs = 20
input_size = 128
num_ftrs = 32
Let's set the seed for our experiments.
pl.seed_everything(seed)
Make sure path_to_data points to the downloaded clothing dataset. You can download it using git clone https://github.com/alexeygrigorev/clothing-dataset.git.
path_to_data = '/datasets/clothing-dataset/images'
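As a quick, optional check that the download worked (os was imported above; the training log further below shows 22 batches of 256 images, so you should see roughly 5,700 files):

print(len(os.listdir(path_to_data)))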
Setup data augmentations and loaders
The images in the dataset were taken from above while the clothes lay on a table, bed or floor. Therefore, we can make use of additional augmentations such as vertical flips or random (90 degree) rotations. By adding these augmentations the model learns to be invariant to the orientation of a piece of clothing: e.g. we don't care whether a shirt is upside down, but rather about the structure that makes it a shirt.
You can learn more about the different augmentations and the invariances they induce here: Advanced Concepts in Self-Supervised Learning.
collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5
)
# We create a torchvision transformation for embedding the dataset after
# training
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])
dataset_train_simclr = lightly.data.LightlyDataset(
    input_dir=path_to_data
)

dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_data,
    transform=test_transforms
)
dataloader_train_simclr = torch.utils.data.DataLoader(
    dataset_train_simclr,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)
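To see what the collate function produces, you can peek at a single batch (illustrative; the shapes assume the default configuration above):

# each batch holds two differently augmented views of the same 256 images
(x0, x1), labels, fnames = next(iter(dataloader_train_simclr))
print(x0.shape)  # torch.Size([256, 3, 128, 128]) -- first view
print(x1.shape)  # torch.Size([256, 3, 128, 128]) -- second view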
Create the SimCLR model
Now we create the SimCLR model. We implement it as a PyTorch Lightning module and use a ResNet-18 backbone from Torchvision. Lightly provides implementations of the SimCLR projection head and loss function in the SimCLRProjectionHead and NTXentLoss classes. We can simply import them and combine the building blocks in the module.
from lightly.models.modules.heads import SimCLRProjectionHead
from lightly.loss import NTXentLoss


class SimCLRModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        # create a ResNet backbone and remove the classification head
        resnet = torchvision.models.resnet18()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        hidden_dim = resnet.fc.in_features
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, 128)

        self.criterion = NTXentLoss()

    def forward(self, x):
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_ssl", loss)
        return loss

    def configure_optimizers(self):
        optim = torch.optim.SGD(
            self.parameters(), lr=6e-2, momentum=0.9, weight_decay=5e-4
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optim, max_epochs
        )
        return [optim], [scheduler]
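As a quick, illustrative shape check (not part of the original tutorial): the backbone maps images to 512-dimensional features, and the projection head maps those to 128 dimensions, as configured above.

_model = SimCLRModel()
_dummy = torch.randn(4, 3, input_size, input_size)
print(_model(_dummy).shape)  # torch.Size([4, 128])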
We first check whether a GPU is available and then train the module using the PyTorch Lightning trainer.
gpus = 1 if torch.cuda.is_available() else 0

model = SimCLRModel()
trainer = pl.Trainer(
    max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100
)
trainer.fit(model, dataloader_train_simclr)
Output (abridged): PyTorch Lightning warns that Trainer(progress_bar_refresh_rate=100) is deprecated since v1.5, and that the number of training samples (22 batches) is smaller than the logging interval (log_every_n_steps=50). Training then runs for 20 epochs at 22 batches per epoch (roughly 6 seconds each); the loss decreases from 5.77 after epoch 0 to 5.17 after epoch 19.
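The deprecation warning comes from progress_bar_refresh_rate. On PyTorch Lightning >= 1.5 you can silence it, as the warning itself suggests, by passing the progress bar as a callback; a sketch:

from pytorch_lightning.callbacks import TQDMProgressBar

trainer = pl.Trainer(
    max_epochs=max_epochs,
    gpus=gpus,
    callbacks=[TQDMProgressBar(refresh_rate=100)],
)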
Next we create a helper function to generate embeddings for our test images using the model we just trained. Note that only the backbone is needed to generate embeddings; the projection head is only required during training. Make sure to put the model into eval mode for this part!
def generate_embeddings(model, dataloader):
    """Generates representations for all images in the dataloader with
    the given model
    """

    embeddings = []
    filenames = []
    with torch.no_grad():
        for img, label, fnames in dataloader:
            img = img.to(model.device)
            emb = model.backbone(img).flatten(start_dim=1)
            # move to CPU so sklearn's normalize below can handle the tensor
            # even when the model sits on a GPU
            embeddings.append(emb.cpu())
            filenames.extend(fnames)

    embeddings = torch.cat(embeddings, 0)
    embeddings = normalize(embeddings)
    return embeddings, filenames
model.eval()
embeddings, filenames = generate_embeddings(model, dataloader_test)
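As a quick, illustrative sanity check: generate_embeddings returns one L2-normalized 512-dimensional vector per image, since the ResNet-18 backbone ends in a 512-channel pooling layer and we normalize with sklearn.

print(embeddings.shape)               # (number of images, 512)
print(np.linalg.norm(embeddings[0]))  # ~1.0 after normalize()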
Visualize nearest neighbors
Let's look at the trained embedding and visualize the nearest neighbors for a few random samples.
We create some helper functions to simplify the work.
def get_image_as_np_array(filename: str):
    """Returns an image as a numpy array
    """
    img = Image.open(filename)
    return np.asarray(img)


def plot_knn_examples(embeddings, filenames, n_neighbors=3, num_examples=6):
    """Plots multiple rows of random images with their nearest neighbors
    """
    # let's look at the nearest neighbors for some samples
    # we use the sklearn library
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)

    # pick num_examples random samples
    samples_idx = np.random.choice(len(indices), size=num_examples, replace=False)

    # loop through our randomly picked samples
    for idx in samples_idx:
        fig = plt.figure()
        # loop through their nearest neighbors
        for plot_x_offset, neighbor_idx in enumerate(indices[idx]):
            # add the subplot
            ax = fig.add_subplot(1, len(indices[idx]), plot_x_offset + 1)
            # get the corresponding filename for the current index
            fname = os.path.join(path_to_data, filenames[neighbor_idx])
            # plot the image
            plt.imshow(get_image_as_np_array(fname))
            # set the title to the distance of the neighbor
            ax.set_title(f'd={distances[idx][plot_x_offset]:.3f}')
            # let's disable the axis
            plt.axis('off')
Let's plot the images. The leftmost image is the query image; the ones next to it in the same row are its nearest neighbors. The distance to each neighbor is shown in its title.
plot_knn_examples(embeddings, filenames)
Color invariance
Let's train again without color augmentation. This will force our model to respect the colors in the images.
# Set color jitter and gray scale probability to 0
new_collate_fn = lightly.data.SimCLRCollateFunction(
    input_size=input_size,
    vf_prob=0.5,
    rr_prob=0.5,
    cj_prob=0.0,
    random_gray_scale=0.0
)

# let's update our collate method and reuse our dataloader
dataloader_train_simclr.collate_fn = new_collate_fn

# then train a new model
model = SimCLRModel()
trainer = pl.Trainer(
    max_epochs=max_epochs, gpus=gpus, progress_bar_refresh_rate=100
)
trainer.fit(model, dataloader_train_simclr)

# and generate again embeddings from the test set
model.eval()
embeddings, filenames = generate_embeddings(model, dataloader_test)
Output (abridged): the same deprecation and logging-interval warnings appear, then training runs for another 20 epochs at 22 batches per epoch; the loss decreases from 5.13 after epoch 0 to 4.55 after epoch 19.
Other examples
plot_knn_examples(embeddings, filenames)
What’s next?
# You could use the pre-trained model and train a classifier on top.
pretrained_resnet_backbone = model.backbone
# you can also store the backbone and use it in another code
state_dict = {
    'resnet18_parameters': pretrained_resnet_backbone.state_dict()
}
torch.save(state_dict, 'model.pth')
# THIS COULD BE IN A NEW FILE (e.g. inference.py)
# Make sure you place the model.pth file in the same folder as this code
# load the model in a new file for inference
resnet18_new = torchvision.models.resnet18()
# note that we need to create exactly the same backbone in order to load the weights
backbone_new = nn.Sequential(*list(resnet18_new.children())[:-1])
ckpt = torch.load('model.pth')
backbone_new.load_state_dict(ckpt['resnet18_parameters'])
<All keys matched successfully>
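From here you could, for example, train a linear classifier on top of the frozen backbone. A minimal sketch, where num_classes and the labeled data pipeline are hypothetical placeholders you would have to provide yourself:

num_classes = 10  # hypothetical: set this to your number of clothing categories

# freeze the pre-trained backbone and train only the linear layer on top
for param in backbone_new.parameters():
    param.requires_grad = False

classifier = nn.Sequential(
    backbone_new,
    nn.Flatten(),
    nn.Linear(512, num_classes),
)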