PyTorch Metric Learning (距離学習) 0.99 : 概要

PyTorch Metric Learning (距離学習) 0.99 : 概要 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/16/2021 (v0.9.98)

* 本ページは、PyTorch Metric Learning の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料セミナー実施中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマにウェビナー (WEB セミナー) を定期的に開催しています。スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

 

PyTorch Metric Learning 0.99 (距離学習) : 概要

概要

このライブラリは 9 モジュールを含み、その各々は貴方の既存のコードベース内で独立に利用できたり、完全な訓練/テスト・ワークフローのために一緒に組み合わせることができます。

 

損失関数がどのように動作するか

訓練ループで損失と miner を使用する

通常の TripletMarginLoss を初期化しましょう :

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss()

訓練ループで損失を計算するには、貴方のモデルで計算された embeddings、そして対応するラベルを渡します。embeddings はサイズ (N, embedding_size) を持ち、そしてラベルはサイズ (N) を持つべきです、ここで N はバッチサイズです。

# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()

TripletMarginLoss はそれに渡したラベルに基づいて、バッチ内の総ての可能なトリプレットを計算します。アンカー positive ペアは同じラベルを共有する埋め込みにより形成され、そしてアンカー negative ペアは異なるラベルを持つ埋め込みにより形成されます。

時にそれは mining 関数を追加するのに役立つことができます :

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner()
loss_func = losses.TripletMarginLoss()

# your training loop
for i, (data, labels) in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = model(data)
    hard_pairs = miner(embeddings, labels)
    loss = loss_func(embeddings, labels, hard_pairs)
    loss.backward()
    optimizer.step()

上のコードでは、miner はそれが特に難しいと考える positive と negative ペアを見つけます。TripletMarginLoss はトリプレット上で作用しますが、依然としてペアを渡すことも可能です。これはライブラリが必要なときには、ペアをトリプレットにそしてトリプレットをペアに自動的に変換するからです。

 

損失関数をカスタマイズする

損失関数は distance, reducerregularizer を使用してカスタマイズできます。下の図では、miner がバッチ内のハード・ペアのインデックスを見つけます。これらは distance オブジェクトにより計算される、距離行列内にインデックスするために使用されます。この図については、損失関数はペア・ベースなので、それは損失をペア毎に計算します。加えて、regularizer が供給されていますので、バッチの各埋め込みに対して正則化損失が計算されます。ペア毎と要素毎損失は reducer に渡されます、これは (この図では) 高値を持つ損失を保持するだけです。高値のペアと要素損失のために平均が計算されて、それから最終的な損失を得るために一緒に加算されます。

今はここにカスタマイズされた TripletMarginLoss のサンプルがあります :

from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.reducers import ThresholdReducer
from pytorch_metric_learning.regularizers import LpRegularizer
from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(distance = CosineSimilarity(), 
                                    reducer = ThresholdReducer(high=0.3), 
                                    embedding_regularizer = LpRegularizer())

このカスタマイズされたトリプレット損失は以下の特性を持ちます :

  • 損失はユークリッド距離の代わりにコサイン類似度を使用して計算されます。
  • 0.3 より高い総てのトリプレット損失は破棄されます。
  • 埋め込みは L2 正則化されます。

 

教師なし / 自己教師あり学習のために損失関数を使用する

TripletMarginLoss は埋め込みベースまたはタプルベースの損失です。これは内部的には、「クラス」のことは考えていないことを意味します。タプル (トリプレットのペア) は各反復で、それが受け取るラベルに基づいて、形成されます。ラベルはクラスを表さなくても構いません。それらは単純に埋め込み間の positive と negative 関係を示す必要があります。そして、教師なし or 自己教師あり学習のためにこれらの損失関数を使用することは容易です。

例えば、下のコードは自己教師 (= self-supervision) で一般に使用される増強ストラテジーの単純化されたバージョンです。データセットはどのようなラベルも装備していません。代わりに、どの埋め込みが positive ペアであるかを単に示すため、ラベルは訓練ループで作成されます。

# your training for-loop
for i, data in enumerate(dataloader):
    optimizer.zero_grad()
    embeddings = your_model(data)
    augmented = your_model(your_augmentation(data))
    labels = torch.arange(embeddings.size(0))

    embeddings = torch.cat([embeddings, augmented], dim=0)
    labels = torch.cat([labels, labels], dim=0)

    loss = loss_func(embeddings, labels)
    loss.backward()
    optimizer.step()

MoCo-スタイルの自己教師に関心があるならば、CIFAR10 上の MoCo ノートブックを見てください。それはモメンタム encoder キューを実装するために CrossBatchMemory を使用します、これはキューからハードサンプルを抽出するために任意のタプル損失とタプル miner を利用できることを意味します。

 

ライブラリの残りのハイライト

  • 貴方のモデルを訓練するための便利な方法については、trainer を見てください。
  • データセット上で貴方のモデルの精度をテストすることを望みますか?tester を試してください。
  • 埋め込み空間の精度を直接的に計算するには、AccuracyCalculator を使用してください。
 

以上