HuggingFace ブログ : 画像分類用 ViT の微調整

HuggingFace ブログ : 画像分類用 ViT の微調整 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/16/2022

* 本ページは、HuggingFace Blog の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

HuggingFace ブログ : 画像分類用 ViT の微調整

ちょうど transformers ベースのモデルが NLP に変革をもたらしたように、今はそれらをあらゆる種類の他のドメインに適用した論文の爆発 (的な増加) を見ています。これらの最も画期的な一つが Vision Transformer (ViT) で、これは Google Brain の研究者のチームにより 2021年6月 に紹介されました。

この論文は、ちょうどセンテンスをトークン化するように、画像をトークン化して訓練のために transformer モデルに渡せる方法を探求しました。それは非常に単純なコンセプトで、実際には …

  1. 画像を部分画像パッチのグリッドに分割する。
  2. 各パッチを線形射影で埋め込む
  3. 各々の埋め込まれたパッチはトークンになり、そして結果としての埋め込まれたパッチのシークエンスはモデルに渡すシークエンスになります。

上記をひとたび行えば、NLP タスクで馴染みがあるように transformers を事前訓練して微調整できることが判っています。Pretty sweet 😎.

 
このブログ記事では、画像分類データセットをダウンロードして処理するために 🤗 datasets を利用し、🤗 transformers で事前訓練済み ViT を微調整するためにそれらを使用する方法をガイドします。

始めるには、最初にそれら両方のパッケージをインストールしましょう。

pip install datasets transformers

 

データセットのロード

小さい画像分類データセットをロードしてその構造を見ることから始めましょう。

beans データセットを使用します、これは健康な豆の葉と不健康な豆の葉の画像のコレクションです。🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

beans データセットの ‘train’ 分割から 400 番目のサンプルを見てみましょう。

データセットの各サンプルが 3 つの特徴量を持っていることに気づくでしょう。

  1. image: PIL 画像

  2. image_file_path: image としてロードされた画像ファイルへの str パス。

  3. labels: datasets.ClassLabel 特徴、これはラベルの整数表現です。(Later you’ll see how to get the string class names, don’t worry!)
ex = ds['train'][400]
ex
{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

Let’s take a look at the image 👀

image = ex['image']
image

That’s definitely a leaf! But what kind? 😅

このデータセットの ‘labels’ 特徴は datasets.features.ClassLabel ですから、このサンプルのラベル ID に対応する名前を検索するためにそれを使用できます。

最初に、’labels’ に対する特徴定義にアクセスしましょう。

labels = ds['train'].features['labels']
labels
ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

そして、サンプルに対するクラスラベルをプリントアウトしましょう。ClassLabel の int2str 関数を使用してそれを行なうことができます、これは名前が暗黙のうちに示しているように、文字列ラベルを検索するためにクラスの整数表現を渡すことができます。

labels.int2str(ex['labels'])
'bean_rust'

上で示された葉は豆さび病 (Bean Rust) に感染していることが判ります、豆の苗木の深刻な病気です。

貴方が取り組んでいるものをより良く理解するために各クラスからのサンプルのグリッドを表示する関数を書きましょう。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)


A grid of a few examples from each class in the dataset

私がわかることからは、

  • Angular Leaf Spot (角葉スポット): 不規則な茶色の斑点を持つ。

  • Bean Rust (豆さび病) : 白っぽい黄色の輪に囲まれた丸い茶色の斑点を持つ。

  • Healthy: …健康に見えます。 🤷‍♂️

 

ViT 特徴抽出器のロード

今では画像がどのようなものか知り、解こうとしている問題をより良く理解しました。モデルのためにこれらの画像をどのように準備できるかを見ましょう!

ViT モデルを訓練するとき、それらに供給される画像に特定の変換が適用されます。画像に間違った変換を使用すると、モデルは見ているものを理解できません!🖼 ➡️ 🔢

正しい変換を適用していることを確実にするため、利用を計画している事前訓練済みモデルとともにセーブされた設定で初期化された ViTFeatureExtractor を使用します。私たちのケースでは、google/vit-base-patch16-224-in21k モデルを使用していきますので、Hugging Face ハブからその特徴抽出器をロードしましょう。

from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

You can see the feature extractor configuration by printing it.

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

画像を処理するには、それを特徴抽出器の call 関数に渡すだけです。これはピクセル値を含む辞書を返します、これはモデルに渡される数値表現です。

デフォルトでは NumPy 配列を得ますが、return_tensors=’pt’ 引数を追加すれば、代わりに torch テンソルが返されます。

feature_extractor(image, return_tensors='pt')

Should give you something like…

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

…where the shape of the tensor is (1, 3, 224, 224).

 

データセットの処理

画像を読みそれらを入力に変換する方法を知ったので、それら 2 つのことをまとめてデータセットの単一サンプルを処理する関数を書きましょう。

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs
process_example(ds['train'][0])
{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': 0
}

ds.map を呼び出してこれをすべてのサンプルに一度に適用することも可能である一方で、これは特に大規模なデータセットの場合、非常に遅い可能性があります。代わりに、transform をデータセットに適用することができます。transform はそれらのインデックスを作成したときだけにサンプルに適用されます。

けれどもまずは、最後の関数をデータのバッチを受け取るようにアップデートする必要があります、それが ds.with_transform が想定するものだからです。

ds = load_dataset('beans')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['labels'] = example_batch['labels']
    return inputs

You can directly apply this to the dataset using ds.with_transform(transform).

prepared_ds = ds.with_transform(transform)

これで、データセットからサンプルを得るときはいつでも、(下で示されるようにサンプルとスライスの両方に対して) transform がリアルタイムに適用されます。

prepared_ds['train'][0:2]

今回は、結果としての pixel_values テンソルは shape (2, 3, 224, 224) を持ちます。

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

 

訓練と評価

データが処理されて、訓練パイプラインのセットアップを開始する準備ができました。このブログ記事は 🤗 の Trainer を使用しますが、最初に幾つかのことを行なう必要があります :

  • collate 関数を定義する。

  • 評価メトリックを定義する。訓練中、モデルはその予測精度で評価される必要があります。それに応じて compute_metrics 関数を定義する必要があります。

  • 事前訓練済みチェックポイントをロードします。事前訓練済みチェックポイントをロードして訓練用にそれを正しく設定する必要があります。

  • 訓練 configuration を定義する。

モデルを微調整した後、それを評価データの上で正しく評価して、画像を正しく分類することを実際に学習したことを検証します。

 

データ collator の定義

バッチは辞書のリストとして入ってきますので、単にそれらをバッチテンソルに unpack + stack することができます。

collate_fn はバッチの辞書を返しますので、後で入力をモデルに **unpack できます。 ✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

 

評価メトリックの定義

データセットからの 精度 メトリックは予測をラベルと比較するために簡単に使用することができます。

以下で、Trainer により使用される compute_metrics 関数でそれを使用する方法を見ることができます。

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

事前訓練済みモデルをロードしましょう。init で num_labels を追加しますので、モデルは正しいユニット数を持つ分類ヘッドを作成します。(push_to_hub を選択する場合) ハブ・ウイジェットで可読なラベルを持つ id2label と label2id マッピングも含みます。

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

Almost ready to train! 訓練前に必要な最後のものは TrainingArguments を定義して訓練設定をセットアップすることです。

これらの殆どは自明ですが、ここで非常に重要なものは remove_unused_columns=False です。これはモデルの call 関数により使用されない特徴をドロップします。デフォルトではそれは True です、通常は使用されない特徴カラムはドロップし、入力をモデルの call 関数にアンパックすることを簡単にするのが理想的だからです。しかし、私たちのケースでは、’pixel_values’ を作成するために未使用の特徴 (特に ‘image’) を必要とします。

言いたいことは、remove_unused_columns=False を設定するのを忘れた場合、ひどい目に合うということです。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

Now, all instances can be passed to Trainer and we are ready to start training!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

 

訓練

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

 

評価 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

ここに評価結果があります – Cool beans! Sorry, had to say it.

***** eval metrics *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

最後に、希望すれば、モデルをハブに push (アップロード) できます。ここでは、訓練設定で push_to_hub=True を指定したのであればそれをプッシュします。ハブにプッシュするためには、git-lfs をインストールして貴方の Hugging Face アカウントにログインする必要があることに注意してください (huggingface-cli login でなされます)。

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 cheers', **kwargs)
else:
    trainer.create_model_card(**kwargs)

The resulting model has been shared to nateraw/vit-base-beans. I’m assuming you don’t have pictures of bean leaves laying around, so I added some examples for you to give it a try! 🚀

 

以上