CLIP : ノートブック : CLIP との相互作用 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/09/2022 (No releases published)
* 本ページは、CLIP の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
CLIP : ノートブック : CLIP との相互作用
これは自己完結型のノートブックで、CLIP モデルをダウンロードして実行する方法、任意の画像とテキスト入力の間の類以度を計算する方法、そしてゼロショット画像分類を実行する方法を実演します。
Colab のための準備
GPU ランタイムで実行していることを確認してください : そうでないなら、メニューの「ランタイム > ランタイムのタイプを変更」でハードウェアアクセラレータとして “GPU” を選択します。次のセルは clip パッケージとその依存関係をインストールして、PyTorch 1.7.1 またはそれ以上がインストールされているかを確認します。
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
import numpy as np
import torch
from pkg_resources import packaging
print("Torch version:", torch.__version__)
Torch version: 1.13.0+cu116
モデルのロード
clip.available_models() は利用可能な CLIP モデルの名前を列挙します。
import clip
clip.available_models()
['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64', 'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px']
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
Model parameters: 151,277,313 Input resolution: 224 Context length: 77 Vocab size: 49408
画像前処理
入力画像をリサイズしてモデルが想定する画像解像度と一致するようにセンタークロップします。それを行なう前に、データセット平均と標準偏差を使用してピクセル強度を正規化します。
clip.load() からの 2 番目の戻り値はこの前処理を行なう torchvision Transform を含みます。
preprocess
Compose( Resize(size=224, interpolation=bicubic, max_size=None, antialias=None) CenterCrop(size=(224, 224)) <function _convert_image_to_rgb at 0x7f38ce9704c0> ToTensor() Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) )
テキスト前処理
case-insensitive なトークナイザーを使用します、これは clip.tokenize() を使用して起動できます。デフォルトでは、出力は 77 トークン長になるようにパディングされます、これは CLIP モデルが想定するものです。
clip.tokenize("Hello World!")
tensor([[49406, 3306, 1002, 256, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int32)
入力画像とテキストのセットアップ
8 個のサンプル画像とそれらのテキスト説明をモデルに供給し、対応する特徴量の間の類以度を比較していきます。
トークナイザーは case-insensitive で、適切なテキスト説明を自由に与えることができます。
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}
skimage.data_dir
/root/.cache/scikit-image/0.18.3/data
original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}\n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
plt.tight_layout()
特徴量の構築
画像を正規化し、各テキスト入力をトークン化し、そして画像とテキストの特徴量を取得するためにモデルの forward パスを実行します。
image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
image_input.shape, text_tokens.shape
(torch.Size([8, 3, 224, 224]), torch.Size([8, 77]))
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
print(image_features.shape)
print(text_features.shape)
torch.Size([8, 512]) torch.Size([8, 512])
コサイン類以度の計算
特徴量を正規化して各ペアのドット積を計算します。
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
image_features.shape, text_features.shape
(torch.Size([8, 512]), torch.Size([8, 512]))
similarity.shape
(8, 8)
count = len(descriptions)
plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size=20)
Text(0.5, 1.0, 'Cosine similarity between text and image features')
ゼロショット画像分類
コサイン類以度 (100 倍) を softmax 演算へのロジットとして使用して画像を分類することができます。
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /root/.cache/cifar-100-python.tar.gz HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value=''))) Extracting /root/.cache/cifar-100-python.tar.gz to /root/.cache
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
len(text_descriptions), text_tokens.shape
(100, torch.Size([100, 77]))
text_descriptions[0]
This is a photo of a apple
text_tokens[0]
tensor([49406, 589, 533, 320, 1125, 539, 320, 3055, 49407, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0', dtype=torch.int32)
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
print(text_features.shape)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(text_probs.shape)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
print(top_probs)
print(top_labels)
torch.Size([100, 512]) torch.Size([8, 100]) tensor([[3.6263e-01, 1.7735e-01, 1.5459e-01, 4.6055e-02, 2.9277e-02], [9.8798e-01, 6.8940e-03, 1.3679e-03, 5.7813e-04, 3.1579e-04], [9.9349e-01, 1.2864e-03, 5.2795e-04, 5.0379e-04, 4.7412e-04], [1.4034e-01, 9.3087e-02, 9.2334e-02, 8.8002e-02, 8.2898e-02], [8.1203e-01, 6.5702e-02, 2.5351e-02, 2.2453e-02, 1.0816e-02], [5.9194e-01, 4.6337e-02, 3.5940e-02, 1.8315e-02, 1.6976e-02], [4.0153e-01, 1.5970e-01, 8.0615e-02, 6.8911e-02, 3.6448e-02], [9.5844e-01, 1.0107e-02, 8.5635e-03, 4.6983e-03, 3.2219e-03]]) tensor([[19, 15, 60, 29, 38], [48, 8, 41, 58, 83], [69, 76, 13, 22, 12], [ 5, 58, 12, 19, 84], [98, 35, 11, 69, 60], [83, 88, 11, 42, 80], [46, 11, 86, 60, 98], [28, 10, 61, 84, 60]])
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
以上