PyTorch Lightning 1.1: research : Transfer Learning (CIFAR10, VGG)
Author : ClassCat Co., Ltd. Sales Information
Date : 03/01/2021 (1.1.x)
* This page provides transfer-learning sample code implemented with reference to the PyTorch Lightning documentation.
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.
research: Transfer Learning (CIFAR10, VGG)
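This example applies transfer learning to CIFAR10 classification: the convolutional feature extractor of a VGG11 network pretrained on ImageNet is kept frozen, and only a new, much smaller classifier head is trained on top of it. The code below assumes a notebook environment with a GPU; the torchsummary calls are used only to inspect layer shapes along the way.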
! pip install pytorch-lightning pytorch-lightning-bolts -qU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR, CyclicLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);
batch_size = 32
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])

cifar10_dm = CIFAR10DataModule(
    batch_size=batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
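The DataModule can be sanity-checked before training. The snippet below is an optional sketch; depending on the installed pl_bolts version, setup() may be a no-op or may build the datasets, so the exact behaviour can vary.

# Optional sanity check of the DataModule (sketch; behaviour may vary with the pl_bolts version)
cifar10_dm.prepare_data()                      # download CIFAR10 if it is not cached yet
cifar10_dm.setup()                             # build the train/val/test splits (no-op in some versions)
images, labels = next(iter(cifar10_dm.train_dataloader()))
print(images.shape, labels.shape)              # expected: torch.Size([32, 3, 32, 32]) torch.Size([32])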
import torchvision.models as models
model_to_check = models.vgg11(pretrained=False)
from torchsummary import summary  # may require: pip install torchsummary
summary(model_to_check.to('cuda:0'), (3, 224, 224))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
         MaxPool2d-3         [-1, 64, 112, 112]               0
            Conv2d-4        [-1, 128, 112, 112]          73,856
              ReLU-5        [-1, 128, 112, 112]               0
         MaxPool2d-6          [-1, 128, 56, 56]               0
            Conv2d-7          [-1, 256, 56, 56]         295,168
              ReLU-8          [-1, 256, 56, 56]               0
            Conv2d-9          [-1, 256, 56, 56]         590,080
             ReLU-10          [-1, 256, 56, 56]               0
        MaxPool2d-11          [-1, 256, 28, 28]               0
           Conv2d-12          [-1, 512, 28, 28]       1,180,160
             ReLU-13          [-1, 512, 28, 28]               0
           Conv2d-14          [-1, 512, 28, 28]       2,359,808
             ReLU-15          [-1, 512, 28, 28]               0
        MaxPool2d-16          [-1, 512, 14, 14]               0
           Conv2d-17          [-1, 512, 14, 14]       2,359,808
             ReLU-18          [-1, 512, 14, 14]               0
           Conv2d-19          [-1, 512, 14, 14]       2,359,808
             ReLU-20          [-1, 512, 14, 14]               0
        MaxPool2d-21            [-1, 512, 7, 7]               0
AdaptiveAvgPool2d-22            [-1, 512, 7, 7]               0
           Linear-23                 [-1, 4096]     102,764,544
             ReLU-24                 [-1, 4096]               0
          Dropout-25                 [-1, 4096]               0
           Linear-26                 [-1, 4096]      16,781,312
             ReLU-27                 [-1, 4096]               0
          Dropout-28                 [-1, 4096]               0
           Linear-29                 [-1, 1000]       4,097,000
================================================================
Total params: 132,863,336
Trainable params: 132,863,336
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 125.37
Params size (MB): 506.83
Estimated Total Size (MB): 632.78
----------------------------------------------------------------
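Nearly all of VGG11's roughly 132.9M parameters sit in the three fully connected layers of the classifier (over 123M of them), which is why only the convolutional part is reused below and a much smaller classifier head is trained from scratch. The next call repeats the summary with CIFAR10-sized 32x32 inputs.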
summary(model_to_check.to('cuda:0'), (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
            Conv2d-4          [-1, 128, 16, 16]          73,856
              ReLU-5          [-1, 128, 16, 16]               0
         MaxPool2d-6            [-1, 128, 8, 8]               0
            Conv2d-7            [-1, 256, 8, 8]         295,168
              ReLU-8            [-1, 256, 8, 8]               0
            Conv2d-9            [-1, 256, 8, 8]         590,080
             ReLU-10            [-1, 256, 8, 8]               0
        MaxPool2d-11            [-1, 256, 4, 4]               0
           Conv2d-12            [-1, 512, 4, 4]       1,180,160
             ReLU-13            [-1, 512, 4, 4]               0
           Conv2d-14            [-1, 512, 4, 4]       2,359,808
             ReLU-15            [-1, 512, 4, 4]               0
        MaxPool2d-16            [-1, 512, 2, 2]               0
           Conv2d-17            [-1, 512, 2, 2]       2,359,808
             ReLU-18            [-1, 512, 2, 2]               0
           Conv2d-19            [-1, 512, 2, 2]       2,359,808
             ReLU-20            [-1, 512, 2, 2]               0
        MaxPool2d-21            [-1, 512, 1, 1]               0
AdaptiveAvgPool2d-22            [-1, 512, 7, 7]               0
           Linear-23                 [-1, 4096]     102,764,544
             ReLU-24                 [-1, 4096]               0
          Dropout-25                 [-1, 4096]               0
           Linear-26                 [-1, 4096]      16,781,312
             ReLU-27                 [-1, 4096]               0
          Dropout-28                 [-1, 4096]               0
           Linear-29                 [-1, 1000]       4,097,000
================================================================
Total params: 132,863,336
Trainable params: 132,863,336
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.94
Params size (MB): 506.83
Estimated Total Size (MB): 509.78
----------------------------------------------------------------
layers = list(model_to_check.children())[:-1]   # drop only the final classifier block; the conv features and avgpool remain
feature_extractor = nn.Sequential(*layers)
summary(feature_extractor, (3, 32, 32))
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
         MaxPool2d-3           [-1, 64, 16, 16]               0
            Conv2d-4          [-1, 128, 16, 16]          73,856
              ReLU-5          [-1, 128, 16, 16]               0
         MaxPool2d-6            [-1, 128, 8, 8]               0
            Conv2d-7            [-1, 256, 8, 8]         295,168
              ReLU-8            [-1, 256, 8, 8]               0
            Conv2d-9            [-1, 256, 8, 8]         590,080
             ReLU-10            [-1, 256, 8, 8]               0
        MaxPool2d-11            [-1, 256, 4, 4]               0
           Conv2d-12            [-1, 512, 4, 4]       1,180,160
             ReLU-13            [-1, 512, 4, 4]               0
           Conv2d-14            [-1, 512, 4, 4]       2,359,808
             ReLU-15            [-1, 512, 4, 4]               0
        MaxPool2d-16            [-1, 512, 2, 2]               0
           Conv2d-17            [-1, 512, 2, 2]       2,359,808
             ReLU-18            [-1, 512, 2, 2]               0
           Conv2d-19            [-1, 512, 2, 2]       2,359,808
             ReLU-20            [-1, 512, 2, 2]               0
        MaxPool2d-21            [-1, 512, 1, 1]               0
AdaptiveAvgPool2d-22            [-1, 512, 7, 7]               0
================================================================
Total params: 9,220,480
Trainable params: 9,220,480
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 2.74
Params size (MB): 35.17
Estimated Total Size (MB): 37.93
----------------------------------------------------------------
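children()[:-1] only removes the classifier block, so the remaining nn.Sequential still contains all five MaxPool2d layers plus the AdaptiveAvgPool2d. For a 32x32 CIFAR10 image the feature maps are already reduced to 1x1 before the (7, 7) average pooling, which throws away most of the spatial information. The next cell therefore works on the features block directly and replaces the first three MaxPool2d layers (indices 2, 5 and 10) with nn.Identity, so that a 32x32 input is only downsampled twice.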
#layers = list(model_to_check.features.children())[:-1]
layers = list(model_to_check.features.children())
print(len(layers))

# replace the first three MaxPool2d layers (indices 2, 5, 10) with no-ops,
# so a 32x32 input is only downsampled by the remaining two pooling layers
layers[2] = nn.Identity()
layers[5] = nn.Identity()
layers[10] = nn.Identity()
feature_extractor2 = nn.Sequential(*layers)
#avgpool = nn.AdaptiveAvgPool2d((7, 8))
summary(feature_extractor2, (3, 32, 32))
21
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 64, 32, 32]           1,792
              ReLU-2           [-1, 64, 32, 32]               0
          Identity-3           [-1, 64, 32, 32]               0
            Conv2d-4          [-1, 128, 32, 32]          73,856
              ReLU-5          [-1, 128, 32, 32]               0
          Identity-6          [-1, 128, 32, 32]               0
            Conv2d-7          [-1, 256, 32, 32]         295,168
              ReLU-8          [-1, 256, 32, 32]               0
            Conv2d-9          [-1, 256, 32, 32]         590,080
             ReLU-10          [-1, 256, 32, 32]               0
         Identity-11          [-1, 256, 32, 32]               0
           Conv2d-12          [-1, 512, 32, 32]       1,180,160
             ReLU-13          [-1, 512, 32, 32]               0
           Conv2d-14          [-1, 512, 32, 32]       2,359,808
             ReLU-15          [-1, 512, 32, 32]               0
        MaxPool2d-16          [-1, 512, 16, 16]               0
           Conv2d-17          [-1, 512, 16, 16]       2,359,808
             ReLU-18          [-1, 512, 16, 16]               0
           Conv2d-19          [-1, 512, 16, 16]       2,359,808
             ReLU-20          [-1, 512, 16, 16]               0
        MaxPool2d-21            [-1, 512, 8, 8]               0
================================================================
Total params: 9,220,480
Trainable params: 9,220,480
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 35.75
Params size (MB): 35.17
Estimated Total Size (MB): 70.94
----------------------------------------------------------------
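With the first three pooling layers replaced, the feature extractor now produces 512 x 8 x 8 maps for a 32x32 input. The LightningModule below additionally applies an AdaptiveAvgPool2d((8, 8)) after the backbone, so the new classifier head can always be sized for 512 * 8 * 8 input features.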
class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, lr=0.05, factor=0.5):
        super().__init__()
        self.save_hyperparameters()

        # backbone: ImageNet-pretrained VGG11 features, with the last MaxPool2d dropped
        # and the first three MaxPool2d layers replaced by no-ops (as explored above)
        backbone = models.vgg11(pretrained=True)
        layers = list(backbone.features.children())[:-1]
        layers[2] = nn.Identity()
        layers[5] = nn.Identity()
        layers[10] = nn.Identity()
        self.feature_extractor = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((8, 8))

        # new classifier head for the 10 CIFAR10 classes
        self.classifier = nn.Sequential(
            nn.Linear(512 * 8 * 8, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 10),
        )
        """
        self.classifier = nn.Sequential(
            nn.Linear(512 * 8 * 8, 2048),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(2048, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 10),
        )
        """

    def forward(self, x):
        # keep the pretrained backbone frozen: eval mode and no gradient tracking
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x)
            #representations = self.feature_extractor(x).flatten(1)
        out = self.avgpool(representations).flatten(1)
        out = self.classifier(out)
        logits = F.log_softmax(out, dim=1)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        return {
            'optimizer': optimizer,
            'lr_scheduler': ReduceLROnPlateau(
                optimizer, 'max', patience=5, factor=self.hparams.factor,
                verbose=True, threshold=0.0001, threshold_mode='abs',
                cooldown=1, min_lr=1e-5),
            'monitor': 'val_acc'
        }
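configure_optimizers pairs SGD with a ReduceLROnPlateau scheduler in 'max' mode; the 'monitor': 'val_acc' entry tells Lightning to step the scheduler on the validation accuracy logged in evaluate(), so the learning rate is multiplied by factor whenever val_acc stops improving.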
%%time
model = ImagenetTransferLearning(lr=0.001, factor=0.9)
model.datamodule = cifar10_dm
trainer = pl.Trainer(
    gpus=1,
    max_epochs=20,
    progress_bar_refresh_rate=50,
    logger=pl.loggers.TensorBoardLogger('tblogs/', name='vgg11'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);
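As a possible follow-up (not part of the run above), the frozen backbone could also be fine-tuned end to end. The sketch below assumes the ImagenetTransferLearning class defined above and simply removes the eval()/no_grad() freeze in forward(); the hypothetical FineTuneVGG name and the lower learning rate are illustrative choices.

# Sketch only: unfreeze the backbone so gradients also flow into the VGG11 features
class FineTuneVGG(ImagenetTransferLearning):
    def forward(self, x):
        representations = self.feature_extractor(x)    # no eval()/no_grad(): backbone is trainable
        out = self.avgpool(representations).flatten(1)
        out = self.classifier(out)
        return F.log_softmax(out, dim=1)

# e.g. continue training for a few epochs with a smaller learning rate
# finetune_model = FineTuneVGG(lr=1e-4, factor=0.9)
# pl.Trainer(gpus=1, max_epochs=5).fit(finetune_model, cifar10_dm)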
That's all.