PyTorch 1.5 レシピ : 基本 : PyTorch でデータをロードする (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/08/2020 (1.5.0)
* 本ページは、PyTorch 1.5 Recipes の以下のページを翻訳した上で適宜、補足説明したものです:
- Basic : Loading Data in PyTorch
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
基本 : PyTorch でデータをロードする
イントロダクション
PyTorch データローディング・ユティリティの中心は torch.utils.data.DataLoader です。
それはデータセットに渡る Python iterable を表します。PyTorch のライブラリは torch.utils.data.Dataset で使用する組込みの高位データセットを貴方に提供します。これらのデータセットは現在以下で利用可能です :
with more to come. torchaudio.datasets.YESNO からの Yesno データセットをしようして、PyTorch Dataset からのデータをどのように PyTorch DataLoader を効果的にそして効率的にロードするかを実演します。
セットアップ
始める前に、データセットにアクセスするため torchaudio をインストールする必要があります。
pip install torchaudio
ステップ
- データをロードするために総ての必要なライブラリをインポートする
- データセットのデータにアクセスする
- データをロードする
- データに渡り反復する
- [オプション] データを可視化する
1. データをロードするために必要なライブラリをインポートする
このレシピのため、torch と torchaudio を使用します。どの組込みデータセットを利用するかに依拠して、torchvision or torchtext をインストールしてインポートすることもできます。
import torch import torchaudio
データセットのデータにアクセスする
torchaudio の Yesno データセットは個々がヘブライ語で yes or no を言っている 60 の録音をフィーチャーしています ; 各録音は 8 単語長です (更に ここ を読んでください)。
torchaudio.datasets.YESNO は YesNo のためのデータセットを作成します。
torchaudio.datasets.YESNO( root, url='http://www.openslr.org/resources/1/waves_yesno.tar.gz', folder_in_archive='waves_yesno', download=False, transform=None, target_transform=None)
データセットの各項目はタプルの形式です : (waveform, sample_rate, labels)。
Yesno データセットのためのルートを設定しなければなりません、これは訓練とテストデータセットが存在するところです。他のパラメータはオプションで、デフォルト値が示されています。ここに他のパラメータについての幾つかの追加の有用な情報があります :
# * ``download``: If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again. # * ``transform``: Using transforms on your data allows you to take it from its source state and transform it into data that’s joined together, de-normalized, and ready for training. Each library in PyTorch supports a growing list of transformations. # * ``target_transform``: A function/transform that takes in the target and transforms it. # # Let’s access our Yesno data: # # A data point in Yesno is a tuple (waveform, sample_rate, labels) where labels # is a list of integers with 1 for yes and 0 for no. yesno_data_trainset = torchaudio.datasets.YESNO('./', download=True) # Pick data point number 3 to see an example of the the yesno_data: n = 3 waveform, sample_rate, labels = yesno_data[n] print("Waveform: {}\nSample rate: {}\nLabels: {}".format(waveform, sample_rate, labels))
このデータを実際に使用するとき、データを「訓練」データセットと「テスト」データセットに供給するのは最善の実践です。これはモデルのパフォーマンスをテストするとき out-of-sample データを持つことを確かなものにします。
3. データをロードする
データセットへのアクセスを持つ今、それを torch.utils.data.DataLoader を通して渡さなければなりません。DataLoader はデータセットとサンプラーを結合し、データセットに渡る iterable を返します。
data_loader = torch.utils.data.DataLoader(yesno_data, batch_size=1, shuffle=True)
4. データに渡り反復する
私達のデータは data_loader を使用して iterable です。モデルの訓練を始めるときこれは必要になります!今、data_loader オブジェクトの各データエントリは waveform, サンプリングレートとラベルを表す tensor を含む tensor に変換されることに気付くでしょう。
for data in data_loader: print("Data: ", data) print("Waveform: {}\nSample rate: {}\nLabels: {}".format(data[0], data[1], data[2])) break
5. [オプション] データを可視化する
DataLoader からの出力を更に理解するためデータをオプションで可視化できます。
import matplotlib.pyplot as plt print(data[0][0].numpy()) plt.figure() plt.plot(waveform.t().numpy())
Congratulations! データを PyTorch に成功的にロードしました。
以上