Sentence Transformers 2.2 : ノートブック : 画像検索 – 画像クラスタリング (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/01/2022 (v2.2.2)
* 本ページは、UKPLab/sentence-transformers の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
クラスキャット 人工知能 研究開発支援サービス
◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Sentence Transformers 2.2 : ノートブック : 画像検索 – 画像クラスタリング
このサンプルは SentenceTransformer が画像クラスタリングのためにどのように使用できるかを示します。
モデルとしては OpenAI CLIP モデル を使用します、これは画像と画像の alt テキストの大規模なセットで訓練されました。
写真のソースとしては、Unsplash Dataset Lite を使用します、これは約 25k 画像を含みます。Unsplash 画像については ライセンス をご覧ください。
すべての画像をベクトル空間にエンコードしてからこのベクトル空間で高密度な領域、つまり画像がかなり類似している領域を見つけます。
from sentence_transformers import SentenceTransformer, util
from PIL import Image
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage
import os
from tqdm.autonotebook import tqdm
# 最初に、個別の CLIP モデルをロードします。
model = SentenceTransformer('clip-ViT-B-32')
# 次に Unsplash から約 25k 画像を取得します。
img_folder = 'photos/'
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
os.makedirs(img_folder, exist_ok=True)
photo_filename = 'unsplash-25k-photos.zip'
if not os.path.exists(photo_filename): #Download dataset if does not exist
util.http_get('http://sbert.net/datasets/'+photo_filename, photo_filename)
#Extract all images
with zipfile.ZipFile(photo_filename, 'r') as zf:
for member in tqdm(zf.infolist(), desc='Extracting'):
zf.extract(member, img_folder)
# 今は、埋め込みを計算する必要があります。
# 早めるために、事前計算された埋め込みを分配します。
# そうでないなら画像を貴方自身でエンコードすることもできます。
# 画像をエンコードするには、以下のコードを使用できます :
# from PIL import Image
# img_emb = model.encode(Image.open(filepath))
use_precomputed_embeddings = True
if use_precomputed_embeddings:
emb_filename = 'unsplash-25k-photos-embeddings.pkl'
if not os.path.exists(emb_filename): #Download dataset if does not exist
util.http_get('http://sbert.net/datasets/'+emb_filename, emb_filename)
with open(emb_filename, 'rb') as fIn:
img_names, img_emb = pickle.load(fIn)
print("Images:", len(img_names))
else:
img_names = list(glob.glob('unsplash/photos/*.jpg'))
print("Images:", len(img_names))
img_emb = model.encode([Image.open(filepath) for filepath in img_names], batch_size=128, convert_to_tensor=True, show_progress_bar=True)
Images: 24996
# ベクトル空間の高密度領域を見つけるために独自の効率的なメソッドを実装しました。
def community_detection(embeddings, threshold, min_community_size=10, init_max_size=1000):
"""
高速なコミュニティ検出のための関数
埋め込みですべてのコミュニティ、つまり近い (閾値よりも近い) 埋め込みを見つけます。
min_community_size よりも大きいコミュニティだけを返します。コミュニティは降順で返されます。
各リストの最初の要素はコミュニティの中心点です。
"""
# コサイン類以度スコアを計算します。
cos_scores = util.cos_sim(embeddings, embeddings)
# コミュニティに対する最小サイズ
top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)
# Filter for rows >= min_threshold
extracted_communities = []
for i in range(len(top_k_values)):
if top_k_values[i][-1] >= threshold:
new_cluster = []
# Only check top k most similar entries
top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True)
top_idx_large = top_idx_large.tolist()
top_val_large = top_val_large.tolist()
if top_val_large[-1] < threshold:
for idx, val in zip(top_idx_large, top_val_large):
if val < threshold:
break
new_cluster.append(idx)
else:
# Iterate over all entries (slow)
for idx, val in enumerate(cos_scores[i].tolist()):
if val >= threshold:
new_cluster.append(idx)
extracted_communities.append(new_cluster)
# Largest cluster first
extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)
# Step 2) Remove overlapping communities
unique_communities = []
extracted_ids = set()
for community in extracted_communities:
add_cluster = True
for idx in community:
if idx in extracted_ids:
add_cluster = False
break
if add_cluster:
unique_communities.append(community)
for idx in community:
extracted_ids.add(idx)
return unique_communities
# Now we run the clustering algorithm
# With the threshold parameter, we define at which threshold we identify
# two images as similar. Set the threshold lower, and you will get larger clusters which have
# less similar images in it (e.g. black cat images vs. cat images vs. animal images).
# With min_community_size, we define that we only want to have clusters of a certain minimal size
clusters = community_detection(img_emb, threshold=0.9, min_community_size=10)
print("Total number of clusters:", len(clusters))
Total number of clusters: 147
# Now we output the first 10 (largest) clusters
for cluster in clusters[0:10]:
print("\n\nCluster size:", len(cluster))
#Output 3 images
for idx in cluster[0:3]:
display(IPImage(os.path.join(img_folder, img_names[idx]), width=200))
Cluster size: 482
Cluster size: 401
Cluster size: 360
Cluster size: 258
Cluster size: 194
Cluster size: 145
以上