Lightly 1.2 : Getting Started : 自己教師あり学習 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/12/2022 (v1.2.25)
* 本ページは、Lightly の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Getting Started : Self-supervised learning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Lightly 1.2 : Getting Started : 自己教師あり学習
Lightly は自己教師あり学習を使用して深層学習モデルを訓練するためのコンピュータビジョン・フレームワークです。フレームワークは、最近傍を見つける、類似性検索、転移学習やデータ分析のような幅広い有用なアプリケーションのために利用できます。
更に、lightly プラットフォーム と直接相互作用するために lightly フレームワークを使用できます。詳細は The Lightly Platform のセクションを確認してください。
Lightly がどのように動作するか
Lightly の柔軟なデザインは貴方の Python コードに統合することを容易にします。Lightly は PyTorch フレームワーク周りに完全に構築されていて、様々なピースは要件に合わせて組み立てることができます。
データと変換
SimCLR のような自己教師あり手法の基本的なビルディングブロックは画像変換です。各画像はランダムに増強が適用されて 2 つの新しい画像に変換されます。そして自己教師ありモデルのタスクはネガティブサンプルのセットの中で同じオリジナルに由来する画像を識別することです。
Lightly はこれらの変換を dataloader の collate 関数で torchvision transforms として実装しています。例えば、下の collate 関数は各画像に 2 つの異なる、ランダム化された transforms を適用します : ランダムなリサイズクロップとランダムなカラー jitter です。
import lightly.data as data
# the collate function applies random transforms to the input images
collate_fn = data.ImageCollateFunction(input_size=32, cj_prob=0.5)
そして画像データセットをロードして PyTorch dataloader を上からの collate 関数で作成しましょう。
import torch
# create a dataset from your image folder
dataset = data.LightlyDataset(input_dir='./my/cute/cats/dataset/')
# build a PyTorch dataloader
dataloader = torch.utils.data.DataLoader(
dataset, # pass the dataset to the dataloader
batch_size=128, # a large batch size helps with the learning
shuffle=True, # shuffling is important!
collate_fn=collate_fn) # apply transformations to the input images
Note : LightlyDataset の代わりにカスタム PyTorch Dataset を使用することもできます。モデルを訓練するための基本的な関数をサポートするために Dataset 実装が (sample, target, filename) のタプルを返すことを確実にするだけです。詳細は lightly.data.dataset を見てください。
次のセクションに向かい、ちょうど準備したデータで ResNet を訓練できる方法を見ます。
モデル、損失と訓練
次に、埋め込みモデル、optimizer と損失関数が必要です。正規化 temperature-scaled 交差エントロピー損失と単純な確率的勾配降下と共に ResNet を使用します。
import torchvision
from lightly.loss import NTXentLoss
from lightly.models.modules.heads import SimCLRProjectionHead
# use a resnet backbone
resnet = torchvision.models.resnet18()
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
# build a SimCLR model
class SimCLR(torch.nn.Module):
def __init__(self, backbone, hidden_dim, out_dim):
super().__init__()
self.backbone = backbone
self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)
def forward(self, x):
h = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(h)
return z
model = SimCLR(resnet, hidden_dim=512, out_dim=128)
# use a criterion for self-supervised learning
# (normalized temperature-scaled cross entropy loss)
criterion = NTXentLoss(temperature=0.5)
# get a PyTorch optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=1e-0, weight_decay=1e-5)
Note : カスタム・バックボーン使用し、自己教師あり学習を使用してそれらを訓練するために lightly を使用することもできます。colab playground でカスタム・バックボーンを使用する方法について更に学習してください。
10 エポックの間、モデルを訓練します。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
max_epochs = 10
for epoch in range(max_epochs):
for (x0, x1), _, _ in dataloader:
x0 = x0.to(device)
x1 = x1.to(device)
z0 = model(x0)
z1 = model(x1)
loss = criterion(z0, z1)
loss.backward()
optimizer.step()
optimizer.zero_grad()
Congrats, 自己教師あり学習を使用して最初のモデルをちょうど訓練しました!
モデルを実装して訓練するために PyTorch Lightning を使用することももちろんできます。
import pytorch_lightning as pl
class SimCLR(pl.LightningModule):
def __init__(self, backbone, hidden_dim, out_dim):
super().__init__()
self.backbone = backbone
self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, out_dim)
self.criterion = NTXentLoss(temperature=0.5)
def forward(self, x):
h = self.backbone(x).flatten(start_dim=1)
z = self.projection_head(h)
return z
def training_step(self, batch, batch_idx):
(x0, x1), _, _ = batch
z0 = self.forward(x0)
z1 = self.forward(x1)
loss = self.criterion(z0, z1)
return loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=1e-0)
return optimizer
model = SimCLR(resnet, hidden_dim=512, out_dim=128)
gpus = 1 if torch.cuda.is_available() else None
trainer = pl.Trainer(max_epochs=max_epochs, gpus=gpus)
trainer.fit(
model,
dataloader
)
マルチ GPU を装備するマシン上で訓練するためには分散データ並列バックエンドの使用を勧めます。
# if we have a machine with 4 GPUs we set gpus=4
trainer = pl.Trainer(
max_epochs=max_epochs,
gpus=4,
distributed_backend='ddp'
)
trainer.fit(
model,
dataloader
)
埋め込み
画像を埋め込んだり、埋め込みモデルに直接アクセスするために訓練済みモデルを利用できます。
# make a new dataloader without the transformations
# The only transformation needed is to make a torch tensor out of the PIL image
dataset.transform = torchvision.transforms.ToTensor()
dataloader = torch.utils.data.DataLoader(
dataset, # use the same dataset as before
batch_size=1, # we can use batch size 1 for inference
shuffle=False, # don't shuffle your data during inference
)
# embed your image dataset
embeddings = []
model.eval()
with torch.no_grad():
for img, label, fnames in dataloader:
img = img.to(model.device)
emb = model.backbone(img).flatten(start_dim=1)
embeddings.append(emb)
embeddings = torch.cat(embeddings, 0)
Done! 最近傍を見つけたり類似検索を行なうために埋め込みを続けて利用できます。更に、転移や few-shot 学習のために ResNet バックボーンが利用できます。
# access the ResNet backbone
resnet = model.backbone
Note : 自己教師あり学習はモデルが訓練されるラベルを必要としません。けれども Lightly は追加のラベルの使用をサポートします。例えば、サブフォルダ ‘Maine Coon’, ‘Bengal’ と ‘British Shorthair’ を含むフォルダ ‘cats’ でモデルを訓練する場合、Lightly は列挙されたラベルをリストとして自動的に返します。
Lightly in Three Lines
Lightly はまた easy-to-use なインターフェースも提供します。以下の行は自己教師あり学習でモデルを訓練して埋め込みを作成するために 3 行のコードだけでパッケージを使用する方法を示しています。
from lightly import train_embedding_model, embed_images
# first we train our model for 10 epochs
checkpoint = train_embedding_model(input_dir='./my/cute/cats/dataset/', trainer={'max_epochs': 10})
# let's embed our 'cats' dataset using our trained model
embeddings, labels, filenames = embed_images(input_dir='./my/cute/cats/dataset/', checkpoint=checkpoint)
# now, let's inspect the shape of our embeddings
print(embeddings.shape)
以上