PyTorch Lightning 1.1: ユースケース : 転移学習 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/01/2021 (1.2.1)
* 本ページは、PyTorch Lightning ドキュメントの以下のページを翻訳した上で適宜、補足説明したものです:
- Common Use Cases : Transfer Learning
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
ユースケース : 転移学習
事前訓練モデルを利用する
時に LightningModule を事前訓練モデルとして利用することを望みます。LightningModule は単なる torch.nn.Module ですからこれは問題ないありません!
Note
LightningModule は 正確に torch.nn.Module であることを忘れないでください、しかしより多くの機能を伴います。
オートエンコーダを別個のモデルの特徴抽出器として使用しましょう。
class Encoder(torch.nn.Module):
...
class AutoEncoder(LightningModule):
def __init__(self):
self.encoder = Encoder()
self.decoder = Decoder()
class CIFAR10Classifier(LightningModule):
def __init__(self):
# init the pretrained LightningModule
self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH)
self.feature_extractor.freeze()
# the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes
self.classifier = nn.Linear(100, 10)
def forward(self, x):
representations = self.feature_extractor(x)
x = self.classifier(representations)
...
転移学習のために事前訓練オートエンコーダ (LightningModule) を使用しました!
サンプル: Imagenet (コンピュータヴィジョン)
import torchvision.models as models
class ImagenetTransferLearning(LightningModule):
def __init__(self):
super().__init__()
# init a pretrained resnet
backbone = models.resnet50(pretrained=True)
num_filters = backbone.fc.in_features
layers = list(backbone.children())[:-1]
self.feature_extractor = nn.Sequential(*layers)
# use the pretrained model to classify cifar-10 (10 image classes)
num_target_classes = 10
self.classifier = nn.Linear(num_filters, num_target_classes)
def forward(self, x):
self.feature_extractor.eval()
with torch.no_grad():
representations = self.feature_extractor(x).flatten(1)
x = self.classifier(representations)
...
再調整
model = ImagenetTransferLearning()
trainer = Trainer()
trainer.fit(model)
そして関心のあるデータを予測するためにそれを使用します。
model = ImagenetTransferLearning.load_from_checkpoint(PATH)
model.freeze()
x = some_images_from_cifar10()
predictions = model(x)
imagenet 上の事前訓練モデルを使用し、CIFAR-10 上で予測するために CIFAR-10 上で再調整しました。非学術的世界では貴方が持つ tiny データセット上で再調整して貴方のデータセット上で予測するでしょう。
サンプル: BERT (NLP)
Lightning は転移学習のために何が使用されるかについて、それが torch.nn.Module サブクラスである限り完全に不可知です。
ここに Huggingface transformers を使用するモデルがあります。
class BertMNLIFinetuner(LightningModule):
def __init__(self):
super().__init__()
self.bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)
self.W = nn.Linear(bert.config.hidden_size, 3)
self.num_classes = 3
def forward(self, input_ids, attention_mask, token_type_ids):
h, _, attn = self.bert(input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids)
h_cls = h[:, 0]
logits = self.W(h_cls)
return logits, attn
以上