PyTorch 1.8 チュートリアル : PyTorch の学習 : 基本 – Transform (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/13/2021 (1.8.0)
* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- Learning PyTorch : Learn the Basics : Transforms
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
PyTorch の学習 : 基本 – Transform
データは機械学習アルゴリズムを訓練するために必要な最終的な処理形式で常に与えられるわけではありません。データのある操作を遂行するために transform を利用してそれを訓練に適合するようにします。
総ての TorchVision データセットは 2 つのパラメータを持ちます – 特徴を変更するための transform とラベルを変更するための target_transform – これらは変換ロジックを含む callable を受け取ります。torchvision.transforms モジュールはすぐに使える幾つかの一般的に利用される変換を提供します。
FashionMNIST 特徴は PIL 画像形式にあり、そしてラベルは整数です。訓練のためには、正規化された tensor としての特徴、そして one-hot エンコードされた tensor としての特徴を必要とします。これらの変換を行なうため、ToTensor と Lambda を使用します。
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor(),
target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw Processing... Done!
ToTensor()
ToTensor は PIL 画像や NumPy ndarray を FloatTensor に変換して画像のピクセル強度値を範囲 [0., 1.] にスケールします。
Lambda 変換
Lambda 変換は任意のユーザ定義 lambda 関数を適用します。ここでは、整数を one-hot エンコード tensor に変えます。それは最初にサイズ 10 のゼロ tensor を作成して scatter_ を呼び出します、これはラベル y により与えられたインデックス上 value=1 を割当てます。
target_transform = Lambda(lambda y: torch.zeros(
10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
以上