PyTorch 2.0 チュートリアル : 入門 : Transforms

PyTorch 2.0 チュートリアル : 入門 : Transforms (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/18/2023 (2.0.0)

* 本ページは、PyTorch 2.0 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Website: www.classcat.com  ;   ClassCatJP

 

PyTorch 2.0 チュートリアル : 入門 : Transforms

データは機械学習アルゴリズムを訓練するために必要な最終的な処理形式で常に与えられるわけではありません。データのある操作を実行するために transforms を利用してそれを訓練に適合するようにします。

総ての TorchVision データセットは 2 つのパラメータを持ちます – 特徴量を変更するための transform とラベルを変更するための target_transform – これらは変換ロジックを含む callable を受け取ります。torchvision.transforms モジュールはすぐに使える幾つかの一般的に使用される変換を提供します。

FashionMNIST 特徴量は PIL 画像形式にあり、そしてラベルは整数値です。訓練のためには、正規化された tensor としての特徴量、そして one-hot エンコードされたテンソルとしてのラベルを必要とします。これらの変換を行なうため、ToTensor と Lambda を使用します。

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz

  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:27, 299897.12it/s]
  0%|          | 65536/26421880 [00:00<01:28, 298398.41it/s]
  0%|          | 131072/26421880 [00:00<01:00, 434083.08it/s]
  1%|          | 229376/26421880 [00:00<00:42, 615419.41it/s]
  2%|1         | 491520/26421880 [00:00<00:20, 1252568.03it/s]
  4%|3         | 950272/26421880 [00:00<00:11, 2243208.64it/s]
  7%|7         | 1933312/26421880 [00:00<00:05, 4429509.68it/s]
 15%|#4        | 3833856/26421880 [00:00<00:02, 8512958.26it/s]
 26%|##6       | 6979584/26421880 [00:00<00:01, 14757242.33it/s]
 37%|###7      | 9895936/26421880 [00:01<00:00, 18313952.68it/s]
 49%|####9     | 13008896/26421880 [00:01<00:00, 21352645.90it/s]
 60%|######    | 15958016/26421880 [00:01<00:00, 22891504.70it/s]
 72%|#######2  | 19070976/26421880 [00:01<00:00, 24507338.07it/s]
 84%|########3 | 22183936/26421880 [00:01<00:00, 25637858.09it/s]
 95%|#########4| 24969216/26421880 [00:01<00:00, 25459500.65it/s]
100%|##########| 26421880/26421880 [00:01<00:00, 15916497.18it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz

  0%|          | 0/29515 [00:00<?, ?it/s]
100%|##########| 29515/29515 [00:00<00:00, 269231.10it/s]
100%|##########| 29515/29515 [00:00<00:00, 267946.75it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz

  0%|          | 0/4422102 [00:00<?, ?it/s]
  1%|          | 32768/4422102 [00:00<00:14, 304153.27it/s]
  1%|1         | 65536/4422102 [00:00<00:14, 302904.48it/s]
  3%|2         | 131072/4422102 [00:00<00:09, 440532.50it/s]
  5%|5         | 229376/4422102 [00:00<00:06, 624684.22it/s]
 11%|#1        | 491520/4422102 [00:00<00:03, 1271647.45it/s]
 21%|##1       | 950272/4422102 [00:00<00:01, 2278198.61it/s]
 44%|####3     | 1933312/4422102 [00:00<00:00, 4493081.88it/s]
 87%|########6 | 3833856/4422102 [00:00<00:00, 8640067.62it/s]
100%|##########| 4422102/4422102 [00:00<00:00, 5087879.30it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz

  0%|          | 0/5148 [00:00<?, ?it/s]
100%|##########| 5148/5148 [00:00<00:00, 26493591.40it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

 

ToTensor()

ToTensor は PIL 画像や NumPy ndarray を FloatTensor に変換して画像のピクセル強度値を範囲 [0., 1.] にスケールします。

 

Lambda 変換

Lambda 変換は任意のユーザ定義 lambda 関数を適用します。ここでは、整数を one-hot エンコード・テンソルに変えます。それは最初にサイズ 10 (データセットのラベル数) のゼロ・テンソルを作成して scatter_ を呼び出します、これはラベル y により与えられたインデックス上 value=1 を割当てます。

target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

 

以上