PyTorch Lightning 1.1: research : Transfer Learning (CIFAR10, ResNet50)
Author: ClassCat Co., Ltd. Sales Information
Date: 03/01/2021 (1.1.x)
* This page provides sample transfer-learning code implemented with reference to the PyTorch Lightning documentation.
* Feel free to link to this page, but we would appreciate a note to sales-info@classcat.com.
research: Transfer Learning (CIFAR10, ResNet50)
! pip install pytorch-lightning pytorch-lightning-bolts -qU
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR, CyclicLR, ExponentialLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.optim.swa_utils import AveragedModel, update_bn
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, GPUStatsMonitor, EarlyStopping
from pytorch_lightning.metrics.functional import accuracy
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
pl.seed_everything(7);
batch_size = 32
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    cifar10_normalization(),
])
cifar10_dm = CIFAR10DataModule(
    batch_size=batch_size,
    train_transforms=train_transforms,
    test_transforms=test_transforms,
    val_transforms=test_transforms,
)
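The pl_bolts datamodule handles downloading and splitting CIFAR-10 (the alternative scheduler later in the notebook assumes 45,000 training samples, with the rest held out for validation). As an optional sanity check (an illustrative sketch that is not part of the original notebook; prepare_data()/setup() are LightningDataModule hooks the Trainer normally calls for you), you can pull one batch and confirm the shapes:

# Illustrative sketch: download, split, and inspect a single training batch.
cifar10_dm.prepare_data()   # downloads CIFAR-10 if it is not cached yet
cifar10_dm.setup()          # builds the train/val/test splits
x, y = next(iter(cifar10_dm.train_dataloader()))
print(x.shape, y.shape)     # expected: torch.Size([32, 3, 32, 32]) torch.Size([32])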
import torchvision.models as models
class ImagenetTransferLearning(pl.LightningModule):
    def __init__(self, lr=0.05, factor=0.5):
        super().__init__()
        self.save_hyperparameters()
        # init a pretrained resnet
        backbone = models.resnet50(pretrained=True)
        num_filters = backbone.fc.in_features  # 2048
        layers = list(backbone.children())[:-1]  # drop the final classification layer
        layers[3] = nn.Identity()  # replace the initial max-pool so 32x32 inputs keep more spatial resolution
        self.feature_extractor = nn.Sequential(*layers)
        # use the pretrained model to classify cifar-10 (10 image classes)
        num_target_classes = 10
        layers = [
            nn.Linear(num_filters, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(512, num_target_classes)
        ]
        self.classifier = nn.Sequential(*layers)
    def forward(self, x):
        self.feature_extractor.eval()
        with torch.no_grad():
            representations = self.feature_extractor(x).flatten(1)
        out = self.classifier(representations)
        logits = F.log_softmax(out, dim=1)
        return logits

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        if stage:
            self.log(f'{stage}_loss', loss, prog_bar=True)
            self.log(f'{stage}_acc', acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        return {
            'optimizer': optimizer,
            'lr_scheduler': ReduceLROnPlateau(optimizer, 'max', patience=5, factor=self.hparams.factor, verbose=True, threshold=0.0001, threshold_mode='abs', cooldown=1, min_lr=1e-5),
            'monitor': 'val_acc'
        }
    # Disabled alternative (the leading 'x' keeps Lightning from picking this method up):
    # OneCycleLR stepped once per batch instead of ReduceLROnPlateau stepped per epoch.
    def xconfigure_optimizers(self):
        #print("###")
        #print(self.hparams)
        optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.lr, momentum=0.9, weight_decay=5e-4)
        steps_per_epoch = 45000 // batch_size
        scheduler_dict = {
            #'scheduler': ExponentialLR(optimizer, gamma=0.1),
            #'interval': 'epoch',
            'scheduler': OneCycleLR(optimizer, max_lr=0.1, pct_start=0.25, epochs=self.trainer.max_epochs, steps_per_epoch=steps_per_epoch),
            #'scheduler': CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size_up=steps_per_epoch*2, mode="triangular2"),
            #'scheduler': CyclicLR(optimizer, base_lr=0.001, max_lr=0.1, step_size_up=steps_per_epoch, mode="exp_range", gamma=0.85),
            #'scheduler': CosineAnnealingLR(optimizer, T_max=200),
            'interval': 'step',
        }
        return {'optimizer': optimizer, 'lr_scheduler': scheduler_dict}
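Note how the backbone stays frozen: forward() switches the feature extractor to eval mode and wraps it in torch.no_grad(), so only the small classifier head receives gradients. configure_optimizers() pairs SGD with a ReduceLROnPlateau scheduler that Lightning steps once per epoch against the monitored 'val_acc'. As a quick, optional shape check (an illustrative sketch that is not part of the original notebook; the random batch is made up), the head should map a CIFAR-sized input to 10 log-probabilities:

# Illustrative sketch: a fresh instance mapping a dummy 32x32 RGB batch to 10 classes.
m = ImagenetTransferLearning(lr=0.001, factor=0.9)
dummy = torch.randn(2, 3, 32, 32)   # hypothetical mini-batch of two images
with torch.no_grad():
    logits = m(dummy)
print(logits.shape)                 # torch.Size([2, 10])
print(logits.exp().sum(dim=1))      # exp(log_softmax) rows sum to ~1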
model = ImagenetTransferLearning(lr=0.001, factor=0.9)
model.datamodule = cifar10_dm
from torchsummary import summary
summary(model.feature_extractor.to('cuda'), (3, 32, 32))
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Conv2d-1 [-1, 64, 16, 16] 9,408
BatchNorm2d-2 [-1, 64, 16, 16] 128
ReLU-3 [-1, 64, 16, 16] 0
Identity-4 [-1, 64, 16, 16] 0
Conv2d-5 [-1, 64, 16, 16] 4,096
BatchNorm2d-6 [-1, 64, 16, 16] 128
ReLU-7 [-1, 64, 16, 16] 0
Conv2d-8 [-1, 64, 16, 16] 36,864
BatchNorm2d-9 [-1, 64, 16, 16] 128
ReLU-10 [-1, 64, 16, 16] 0
Conv2d-11 [-1, 256, 16, 16] 16,384
BatchNorm2d-12 [-1, 256, 16, 16] 512
Conv2d-13 [-1, 256, 16, 16] 16,384
BatchNorm2d-14 [-1, 256, 16, 16] 512
ReLU-15 [-1, 256, 16, 16] 0
Bottleneck-16 [-1, 256, 16, 16] 0
Conv2d-17 [-1, 64, 16, 16] 16,384
BatchNorm2d-18 [-1, 64, 16, 16] 128
ReLU-19 [-1, 64, 16, 16] 0
Conv2d-20 [-1, 64, 16, 16] 36,864
BatchNorm2d-21 [-1, 64, 16, 16] 128
ReLU-22 [-1, 64, 16, 16] 0
Conv2d-23 [-1, 256, 16, 16] 16,384
BatchNorm2d-24 [-1, 256, 16, 16] 512
ReLU-25 [-1, 256, 16, 16] 0
Bottleneck-26 [-1, 256, 16, 16] 0
Conv2d-27 [-1, 64, 16, 16] 16,384
BatchNorm2d-28 [-1, 64, 16, 16] 128
ReLU-29 [-1, 64, 16, 16] 0
Conv2d-30 [-1, 64, 16, 16] 36,864
BatchNorm2d-31 [-1, 64, 16, 16] 128
ReLU-32 [-1, 64, 16, 16] 0
Conv2d-33 [-1, 256, 16, 16] 16,384
BatchNorm2d-34 [-1, 256, 16, 16] 512
ReLU-35 [-1, 256, 16, 16] 0
Bottleneck-36 [-1, 256, 16, 16] 0
Conv2d-37 [-1, 128, 16, 16] 32,768
BatchNorm2d-38 [-1, 128, 16, 16] 256
ReLU-39 [-1, 128, 16, 16] 0
Conv2d-40 [-1, 128, 8, 8] 147,456
BatchNorm2d-41 [-1, 128, 8, 8] 256
ReLU-42 [-1, 128, 8, 8] 0
Conv2d-43 [-1, 512, 8, 8] 65,536
BatchNorm2d-44 [-1, 512, 8, 8] 1,024
Conv2d-45 [-1, 512, 8, 8] 131,072
BatchNorm2d-46 [-1, 512, 8, 8] 1,024
ReLU-47 [-1, 512, 8, 8] 0
Bottleneck-48 [-1, 512, 8, 8] 0
Conv2d-49 [-1, 128, 8, 8] 65,536
BatchNorm2d-50 [-1, 128, 8, 8] 256
ReLU-51 [-1, 128, 8, 8] 0
Conv2d-52 [-1, 128, 8, 8] 147,456
BatchNorm2d-53 [-1, 128, 8, 8] 256
ReLU-54 [-1, 128, 8, 8] 0
Conv2d-55 [-1, 512, 8, 8] 65,536
BatchNorm2d-56 [-1, 512, 8, 8] 1,024
ReLU-57 [-1, 512, 8, 8] 0
Bottleneck-58 [-1, 512, 8, 8] 0
Conv2d-59 [-1, 128, 8, 8] 65,536
BatchNorm2d-60 [-1, 128, 8, 8] 256
ReLU-61 [-1, 128, 8, 8] 0
Conv2d-62 [-1, 128, 8, 8] 147,456
BatchNorm2d-63 [-1, 128, 8, 8] 256
ReLU-64 [-1, 128, 8, 8] 0
Conv2d-65 [-1, 512, 8, 8] 65,536
BatchNorm2d-66 [-1, 512, 8, 8] 1,024
ReLU-67 [-1, 512, 8, 8] 0
Bottleneck-68 [-1, 512, 8, 8] 0
Conv2d-69 [-1, 128, 8, 8] 65,536
BatchNorm2d-70 [-1, 128, 8, 8] 256
ReLU-71 [-1, 128, 8, 8] 0
Conv2d-72 [-1, 128, 8, 8] 147,456
BatchNorm2d-73 [-1, 128, 8, 8] 256
ReLU-74 [-1, 128, 8, 8] 0
Conv2d-75 [-1, 512, 8, 8] 65,536
BatchNorm2d-76 [-1, 512, 8, 8] 1,024
ReLU-77 [-1, 512, 8, 8] 0
Bottleneck-78 [-1, 512, 8, 8] 0
Conv2d-79 [-1, 256, 8, 8] 131,072
BatchNorm2d-80 [-1, 256, 8, 8] 512
ReLU-81 [-1, 256, 8, 8] 0
Conv2d-82 [-1, 256, 4, 4] 589,824
BatchNorm2d-83 [-1, 256, 4, 4] 512
ReLU-84 [-1, 256, 4, 4] 0
Conv2d-85 [-1, 1024, 4, 4] 262,144
BatchNorm2d-86 [-1, 1024, 4, 4] 2,048
Conv2d-87 [-1, 1024, 4, 4] 524,288
BatchNorm2d-88 [-1, 1024, 4, 4] 2,048
ReLU-89 [-1, 1024, 4, 4] 0
Bottleneck-90 [-1, 1024, 4, 4] 0
Conv2d-91 [-1, 256, 4, 4] 262,144
BatchNorm2d-92 [-1, 256, 4, 4] 512
ReLU-93 [-1, 256, 4, 4] 0
Conv2d-94 [-1, 256, 4, 4] 589,824
BatchNorm2d-95 [-1, 256, 4, 4] 512
ReLU-96 [-1, 256, 4, 4] 0
Conv2d-97 [-1, 1024, 4, 4] 262,144
BatchNorm2d-98 [-1, 1024, 4, 4] 2,048
ReLU-99 [-1, 1024, 4, 4] 0
Bottleneck-100 [-1, 1024, 4, 4] 0
Conv2d-101 [-1, 256, 4, 4] 262,144
BatchNorm2d-102 [-1, 256, 4, 4] 512
ReLU-103 [-1, 256, 4, 4] 0
Conv2d-104 [-1, 256, 4, 4] 589,824
BatchNorm2d-105 [-1, 256, 4, 4] 512
ReLU-106 [-1, 256, 4, 4] 0
Conv2d-107 [-1, 1024, 4, 4] 262,144
BatchNorm2d-108 [-1, 1024, 4, 4] 2,048
ReLU-109 [-1, 1024, 4, 4] 0
Bottleneck-110 [-1, 1024, 4, 4] 0
Conv2d-111 [-1, 256, 4, 4] 262,144
BatchNorm2d-112 [-1, 256, 4, 4] 512
ReLU-113 [-1, 256, 4, 4] 0
Conv2d-114 [-1, 256, 4, 4] 589,824
BatchNorm2d-115 [-1, 256, 4, 4] 512
ReLU-116 [-1, 256, 4, 4] 0
Conv2d-117 [-1, 1024, 4, 4] 262,144
BatchNorm2d-118 [-1, 1024, 4, 4] 2,048
ReLU-119 [-1, 1024, 4, 4] 0
Bottleneck-120 [-1, 1024, 4, 4] 0
Conv2d-121 [-1, 256, 4, 4] 262,144
BatchNorm2d-122 [-1, 256, 4, 4] 512
ReLU-123 [-1, 256, 4, 4] 0
Conv2d-124 [-1, 256, 4, 4] 589,824
BatchNorm2d-125 [-1, 256, 4, 4] 512
ReLU-126 [-1, 256, 4, 4] 0
Conv2d-127 [-1, 1024, 4, 4] 262,144
BatchNorm2d-128 [-1, 1024, 4, 4] 2,048
ReLU-129 [-1, 1024, 4, 4] 0
Bottleneck-130 [-1, 1024, 4, 4] 0
Conv2d-131 [-1, 256, 4, 4] 262,144
BatchNorm2d-132 [-1, 256, 4, 4] 512
ReLU-133 [-1, 256, 4, 4] 0
Conv2d-134 [-1, 256, 4, 4] 589,824
BatchNorm2d-135 [-1, 256, 4, 4] 512
ReLU-136 [-1, 256, 4, 4] 0
Conv2d-137 [-1, 1024, 4, 4] 262,144
BatchNorm2d-138 [-1, 1024, 4, 4] 2,048
ReLU-139 [-1, 1024, 4, 4] 0
Bottleneck-140 [-1, 1024, 4, 4] 0
Conv2d-141 [-1, 512, 4, 4] 524,288
BatchNorm2d-142 [-1, 512, 4, 4] 1,024
ReLU-143 [-1, 512, 4, 4] 0
Conv2d-144 [-1, 512, 2, 2] 2,359,296
BatchNorm2d-145 [-1, 512, 2, 2] 1,024
ReLU-146 [-1, 512, 2, 2] 0
Conv2d-147 [-1, 2048, 2, 2] 1,048,576
BatchNorm2d-148 [-1, 2048, 2, 2] 4,096
Conv2d-149 [-1, 2048, 2, 2] 2,097,152
BatchNorm2d-150 [-1, 2048, 2, 2] 4,096
ReLU-151 [-1, 2048, 2, 2] 0
Bottleneck-152 [-1, 2048, 2, 2] 0
Conv2d-153 [-1, 512, 2, 2] 1,048,576
BatchNorm2d-154 [-1, 512, 2, 2] 1,024
ReLU-155 [-1, 512, 2, 2] 0
Conv2d-156 [-1, 512, 2, 2] 2,359,296
BatchNorm2d-157 [-1, 512, 2, 2] 1,024
ReLU-158 [-1, 512, 2, 2] 0
Conv2d-159 [-1, 2048, 2, 2] 1,048,576
BatchNorm2d-160 [-1, 2048, 2, 2] 4,096
ReLU-161 [-1, 2048, 2, 2] 0
Bottleneck-162 [-1, 2048, 2, 2] 0
Conv2d-163 [-1, 512, 2, 2] 1,048,576
BatchNorm2d-164 [-1, 512, 2, 2] 1,024
ReLU-165 [-1, 512, 2, 2] 0
Conv2d-166 [-1, 512, 2, 2] 2,359,296
BatchNorm2d-167 [-1, 512, 2, 2] 1,024
ReLU-168 [-1, 512, 2, 2] 0
Conv2d-169 [-1, 2048, 2, 2] 1,048,576
BatchNorm2d-170 [-1, 2048, 2, 2] 4,096
ReLU-171 [-1, 2048, 2, 2] 0
Bottleneck-172 [-1, 2048, 2, 2] 0
AdaptiveAvgPool2d-173 [-1, 2048, 1, 1] 0
================================================================
Total params: 23,508,032
Trainable params: 23,508,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 22.28
Params size (MB): 89.68
Estimated Total Size (MB): 111.97
----------------------------------------------------------------
%%time
model = ImagenetTransferLearning(lr=0.001, factor=0.9)
model.datamodule = cifar10_dm
trainer = pl.Trainer(
    gpus=1,
    max_epochs=50,
    progress_bar_refresh_rate=50,
    logger=pl.loggers.TensorBoardLogger('tblogs/', name='resnet50'),
    callbacks=[LearningRateMonitor(logging_interval='step')],
)
trainer.fit(model, cifar10_dm)
trainer.test(model, datamodule=cifar10_dm);
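EarlyStopping and GPUStatsMonitor are imported above but not used in this run. As a hedged sketch of one possible extension (not part of the original notebook), early stopping on the monitored validation accuracy could be added to the callback list:

# Illustrative sketch: stop training once 'val_acc' has not improved for 10 epochs.
early_stop = EarlyStopping(monitor='val_acc', mode='max', patience=10)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=50,
    callbacks=[LearningRateMonitor(logging_interval='step'), early_stop],
)
trainer.fit(model, cifar10_dm)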
That's all.