PyTorch 0.4.1 examples : テキスト分類 – TorchText IMDB (RNN)

PyTorch 0.4.1 examples (コード解説) : テキスト分類 – TorchText IMDB (RNN)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/14/2018 (0.4.1)

* 本ページは、github 上の以下の pytorch/examples と keras/examples レポジトリのサンプル・コードを参考にしています:

* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

TorchText IMDB

◇ 最初に data.Field と data.LabelField のインスタンスを作成します。LabelField は Field の派生クラスです :

from torchtext import data

TEXT = data.Field(sequential=True, fix_length=80, tensor_type=torch.LongTensor, batch_first=True, lower=True)
LABEL = data.LabelField(sequential=False, tensor_type=torch.FloatTensor)

◇ 次に、作成した Field インスタンスを引数として datasets.IMDB の .splits() クラスメソッドを呼び出すと torchtext.data.Dataset オブジェクトが返されます :

from torchtext import datasets

ds_train, ds_test = datasets.IMDB.splits(TEXT, LABEL)

要素数それぞれ 25,000 の訓練用データセットとテスト用データセットが生成されます。
.fields プロパティはキー ‘text’ と ‘label’ を持つ辞書で上で作成した Field インスタンスを保持しています :

print('train : ', len(ds_train))
print('test : ', len(ds_test))
print('train.fields :', ds_train.fields)
train : 25000
test : 25000
train.fields : {'text' : <torchtext.data.field.Field object at 0x7f27636c9080>, 
'label': <torchtext.data.field.Field object at 0x7f27636c90b8>}

◇ 更に訓練データセット i.e. torchtext.data.Dataset の .split() メソッドで検証データセットも生成できます。split_ratio のデフォルトは 0.7 です :

import random

ds_train, ds_valid = ds_train.split(random_state=random.seed(SEED))
print('train : ', len(ds_train))
print('valid : ', len(ds_valid))
train : 17500
valid : 7500

◇ 続いて、Field インスタンスの .build_vocab() メソッドでそのフィールドについて Vocab オブジェクトを構築します :

TEXT.build_vocab(ds_train, max_size=25000)
LABEL.build_vocab(ds_train)

幾つか属性を確認しましょう :

print('TEXT.vocab size : %d ; LABEL size : %d' % (len(TEXT.vocab), len(LABEL.vocab)))

print(TEXT.vocab.freqs.most_common(20))
print(LABEL.vocab.freqs.most_common(20))

# print(TEXT.vocab.__dict__)
print(LABEL.vocab.__dict__)

print(TEXT.vocab.itos[:10])

print(LABEL.vocab.stoi)
TEXT.vocab size : 25002 ; LABEL size : 2

[('the', 225398), ('a', 111814), ('and', 110519), ('of', 101005), ('to', 93624), ('is', 72546), ('in', 63126), ('i', 49081), ('this', 48813), ('that', 46443), ('it', 45681),
('/><br', 35589), ('was', 32813), ('as', 31380), ('for', 29868), ('with', 29781), ('but', 27740), ('on', 22248), ('movie', 21611), ('his', 20327)]
[('neg', 8810), ('pos', 8690)]

{'freqs': Counter({'neg': 8810, 'pos': 8690}), 'itos': ['neg', 'pos'], 
'stoi': defaultdict(<function _default_unk_index at 0x7fd5346de2f0>, {'neg': 0, 'pos': 1}), 'vectors': None}

['<unk>', '<pad>', 'the', 'a', 'and', 'of', 'to', 'is', 'in', 'i']

defaultdict(<function _default_unk_index at 0x7ff8f2a4f2f0>, {'neg': 0, 'pos': 1})

◇ 最後に、BucketIterator でラップします :torchtext.data.Iterator は Dataset からバッチをロードする iterator です。つまり、DataLoader と同じ役割りを果たします :

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (ds_train, ds_valid, ds_test), batch_size=BATCH_SIZE, sort_key=lambda x: len(x.text), repeat=False
)

 

RNN モデル

import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)


    def forward(self, x):
        embedded = self.embedding(x)
        output, hidden = self.rnn(embedded)
        hidden = hidden.squeeze(0)

        return self.fc(hidden)

 

訓練と評価

損失関数は nn.BCEWithLogitsLoss() です :

検証データセットを作成したので、エポック毎に精度を確認します :

最終的なテスト精度は 62.56% でした。

 

以上