Sentence Transformers 2.2 : 画像検索 – 画像クラスタリング

Sentence Transformers 2.2 : ノートブック : 画像検索 – 画像クラスタリング (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/01/2022 (v2.2.2)

* 本ページは、UKPLab/sentence-transformers の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Sentence Transformers 2.2 : ノートブック : 画像検索 – 画像クラスタリング

このサンプルは SentenceTransformer が画像クラスタリングのためにどのように使用できるかを示します。

モデルとしては OpenAI CLIP モデル を使用します、これは画像と画像の alt テキストの大規模なセットで訓練されました。

写真のソースとしては、Unsplash Dataset Lite を使用します、これは約 25k 画像を含みます。Unsplash 画像については ライセンス をご覧ください。

すべての画像をベクトル空間にエンコードしてからこのベクトル空間で高密度な領域、つまり画像がかなり類似している領域を見つけます。

from sentence_transformers import SentenceTransformer, util
from PIL import Image
import glob
import torch
import pickle
import zipfile
from IPython.display import display
from IPython.display import Image as IPImage
import os
from tqdm.autonotebook import tqdm

# 最初に、個別の CLIP モデルをロードします。
model = SentenceTransformer('clip-ViT-B-32')
# 次に Unsplash から約 25k 画像を取得します。
img_folder = 'photos/'
if not os.path.exists(img_folder) or len(os.listdir(img_folder)) == 0:
    os.makedirs(img_folder, exist_ok=True)
    
    photo_filename = 'unsplash-25k-photos.zip'
    if not os.path.exists(photo_filename):   #Download dataset if does not exist
        util.http_get('http://sbert.net/datasets/'+photo_filename, photo_filename)
        
    #Extract all images
    with zipfile.ZipFile(photo_filename, 'r') as zf:
        for member in tqdm(zf.infolist(), desc='Extracting'):
            zf.extract(member, img_folder)
# 今は、埋め込みを計算する必要があります。
# 早めるために、事前計算された埋め込みを分配します。
# そうでないなら画像を貴方自身でエンコードすることもできます。
# 画像をエンコードするには、以下のコードを使用できます :
# from PIL import Image
# img_emb = model.encode(Image.open(filepath))

use_precomputed_embeddings = True

if use_precomputed_embeddings: 
    emb_filename = 'unsplash-25k-photos-embeddings.pkl'
    if not os.path.exists(emb_filename):   #Download dataset if does not exist
        util.http_get('http://sbert.net/datasets/'+emb_filename, emb_filename)
        
    with open(emb_filename, 'rb') as fIn:
        img_names, img_emb = pickle.load(fIn)  
    print("Images:", len(img_names))
else:
    img_names = list(glob.glob('unsplash/photos/*.jpg'))
    print("Images:", len(img_names))
    img_emb = model.encode([Image.open(filepath) for filepath in img_names], batch_size=128, convert_to_tensor=True, show_progress_bar=True)
Images: 24996
# ベクトル空間の高密度領域を見つけるために独自の効率的なメソッドを実装しました。
def community_detection(embeddings, threshold, min_community_size=10, init_max_size=1000):
    """
    高速なコミュニティ検出のための関数

    埋め込みですべてのコミュニティ、つまり近い (閾値よりも近い) 埋め込みを見つけます。

    min_community_size よりも大きいコミュニティだけを返します。コミュニティは降順で返されます。
    各リストの最初の要素はコミュニティの中心点です。
    """

    # コサイン類以度スコアを計算します。
    cos_scores = util.cos_sim(embeddings, embeddings)

    # コミュニティに対する最小サイズ
    top_k_values, _ = cos_scores.topk(k=min_community_size, largest=True)

    # Filter for rows >= min_threshold
    extracted_communities = []
    for i in range(len(top_k_values)):
        if top_k_values[i][-1] >= threshold:
            new_cluster = []

            # Only check top k most similar entries
            top_val_large, top_idx_large = cos_scores[i].topk(k=init_max_size, largest=True)
            top_idx_large = top_idx_large.tolist()
            top_val_large = top_val_large.tolist()

            if top_val_large[-1] < threshold:
                for idx, val in zip(top_idx_large, top_val_large):
                    if val < threshold:
                        break

                    new_cluster.append(idx)
            else:
                # Iterate over all entries (slow)
                for idx, val in enumerate(cos_scores[i].tolist()):
                    if val >= threshold:
                        new_cluster.append(idx)

            extracted_communities.append(new_cluster)

    # Largest cluster first
    extracted_communities = sorted(extracted_communities, key=lambda x: len(x), reverse=True)

    # Step 2) Remove overlapping communities
    unique_communities = []
    extracted_ids = set()

    for community in extracted_communities:
        add_cluster = True
        for idx in community:
            if idx in extracted_ids:
                add_cluster = False
                break

        if add_cluster:
            unique_communities.append(community)
            for idx in community:
                extracted_ids.add(idx)

    return unique_communities
# Now we run the clustering algorithm
# With the threshold parameter, we define at which threshold we identify
# two images as similar. Set the threshold lower, and you will get larger clusters which have 
# less similar images in it (e.g. black cat images vs. cat images vs. animal images).
# With min_community_size, we define that we only want to have clusters of a certain minimal size
clusters = community_detection(img_emb, threshold=0.9, min_community_size=10)
print("Total number of clusters:", len(clusters))
Total number of clusters: 147
# Now we output the first 10 (largest) clusters
for cluster in clusters[0:10]:
    print("\n\nCluster size:", len(cluster))
    
    #Output 3 images
    for idx in cluster[0:3]:
        display(IPImage(os.path.join(img_folder, img_names[idx]), width=200))
Cluster size: 482






Cluster size: 401






Cluster size: 360






Cluster size: 258






Cluster size: 194






Cluster size: 145






 

以上