PyTorch 1.8 チュートリアル : テキスト : TorchText でテキスト分類 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 04/24/2021 (1.8.1+cu102)
* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
スケジュールは弊社 公式 Web サイト でご確認頂けます。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
人工知能研究開発支援 | 人工知能研修サービス | テレワーク & オンライン授業を支援 |
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。) |
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ ; Facebook |
テキスト : TorchText でテキスト分類
このチュートリアルでは、テキスト分類分析のためのデータセットを構築するためにどのように torchtext を利用するかを示します。ユーザは以下を行なうための柔軟性を持ちます :
- iterator としての raw データにアクセスする
- raw テキスト文字列を (モデルを訓練するために使用できる) torch.Tensor に変換するためにデータ処理パイプラインを構築する
- torch.utils.data.DataLoader でデータをシャッフルして iterate する
raw データセット iterator にアクセスする
torchtext は幾つかの raw データセット iterator を提供します、これは raw テキスト文字列を yield します。例えば、AG_NEWS データセット iterator はラベルとテキストのタプルとして raw データを yield します。
import torch
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')
next(train_iter)
>>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters -
Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green
again.")
next(train_iter)
>>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private
investment firm Carlyle Group,\\which has a reputation for making well-timed
and occasionally\\controversial plays in the defense industry, has quietly
placed\\its bets on another part of the market.')
next(train_iter)
>>> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring
crude prices plus worries\\about the economy and the outlook for earnings are
expected to\\hang over the stock market next week during the depth of
the\\summer doldrums.")
データ処理パイプラインを準備する
語彙、単語ベクトル、tokenizer を含む、torchtext ライブラリの非常に基本的なコンポーネントを再検討しました。それらは raw テキスト文字列のための基本的なデータ処理ビルディングブロックです。
ここに tonenizer と語彙を伴う典型的な NLP データ処理のためのサンプルがあります。最初のステップは raw 訓練データセットで語彙を構築します。ユーザは Vocab クラスのコンストラクタの引数をセットアップすることによりカスタマイズされた語彙を持つことができます。例えば、トークンのための最小頻度 min_freq を含めます。
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab
tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')
counter = Counter()
for (label, line) in train_iter:
counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)
語彙ブロックはトークンのリストを整数に変換します。
[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]
tokenizer と語彙でテキスト処理パイプラインを準備します。テキストとラベルパイプラインはデータセット iterator からの raw データ文字列を処理するために使用されます。
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1
テキストパイプラインは語彙で定義された検索テーブルに基づいてテキスト文字列を整数のリストに変換します。ラベルパイプラインはラベルを整数に変換します。例えば、
text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9
データバッチと iterator を生成する
torch.utils.data.DataLoader は PyTorch ユーザのために推奨されます (チュートリアルは こちら です)。それはマップ-style データセットとともに動作します、これは getitem() と len() プロトコルを実装し、そしてインデックス/キーからデータサンプルへのマップを表します。それはまた False の shuffle 引数を持つ iterable なデータセットとともに動作もします。
モデルに送る前に、collate_fn 関数は DataLoader から生成されたサンプルのバッチ上で動作します。collate_fn の入力は DataLoader の batch サイズを持つデータのバッチで、そして collate_fn は前に宣言されたデータ処理パイプラインに従ってそれらを処理します。ここで注意してください、collate_fn はトップレベルの def として宣言されることを確実にしてください。これは関数が各ワーカーで利用可能であることを保証します。
この例では、元のデータバッチ入力のテキストエントリはリストにパックされてそして nn.EmbeddingBag の入力のための単一 tensor として結合されます。オフセットはテキスト tensor の個々のシークエンスの開始インデックスを表すデリミタの tensor です。ラベルは個々のテキストエントリのラベルを保存する tensor です。
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def collate_batch(batch):
label_list, text_list, offsets = [], [], [0]
for (_label, _text) in batch:
label_list.append(label_pipeline(_label))
processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
text_list.append(processed_text)
offsets.append(processed_text.size(0))
label_list = torch.tensor(label_list, dtype=torch.int64)
offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
text_list = torch.cat(text_list)
return label_list.to(device), text_list.to(device), offsets.to(device)
train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
モデルを定義する
モデルは nn.EmbeddingBag 層と分類目的のための線形層から成ります。“mean” のデフォルト・モードを持つ nn.EmbeddingBag は埋め込みの「バッグ」の平均値を計算します。ここでのテキストエントリは異なる長さを持ちます。ここでは nn.EmbeddingBag モジュールはパディングを必要としません、何故ならばテキスト長はオフセットにセーブされているからです。
更に、nn.EmbeddingBag は埋め込みに渡り平均を on the fly に累積しますので、nn.EmbeddingBag は tensor のシークエンスを処理するパフォーマンスとメモリ効率を強化できます。
from torch import nn
class TextClassificationModel(nn.Module):
def __init__(self, vocab_size, embed_dim, num_class):
super(TextClassificationModel, self).__init__()
self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
self.fc = nn.Linear(embed_dim, num_class)
self.init_weights()
def init_weights(self):
initrange = 0.5
self.embedding.weight.data.uniform_(-initrange, initrange)
self.fc.weight.data.uniform_(-initrange, initrange)
self.fc.bias.data.zero_()
def forward(self, text, offsets):
embedded = self.embedding(text, offsets)
return self.fc(embedded)
インスタンスを初期化する
AG_NEWS データセットは 4 つのラベルを持ち従ってクラス数は 4 です。
1 : World 2 : Sports 3 : Business 4 : Sci/Tec
64 の埋め込み次元を持つモデルを構築します。語彙サイズは語彙インスタンスの長さに等しいです。クラス数はラベルの数に等しいです。
train_iter = AG_NEWS(split='train')
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)
モデルを訓練して結果を評価する関数を定義する
import time
def train(dataloader):
model.train()
total_acc, total_count = 0, 0
log_interval = 500
start_time = time.time()
for idx, (label, text, offsets) in enumerate(dataloader):
optimizer.zero_grad()
predited_label = model(text, offsets)
loss = criterion(predited_label, label)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
optimizer.step()
total_acc += (predited_label.argmax(1) == label).sum().item()
total_count += label.size(0)
if idx % log_interval == 0 and idx > 0:
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches '
'| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
total_acc/total_count))
total_acc, total_count = 0, 0
start_time = time.time()
def evaluate(dataloader):
model.eval()
total_acc, total_count = 0, 0
with torch.no_grad():
for idx, (label, text, offsets) in enumerate(dataloader):
predited_label = model(text, offsets)
loss = criterion(predited_label, label)
total_acc += (predited_label.argmax(1) == label).sum().item()
total_count += label.size(0)
return total_acc/total_count
データセットを分割してモデルを実行する
元の AG_NEWS は検証データセットを持ちませんので、訓練データセットを 0.95 (訓練) と 0.05 (検証) の分割比率で訓練/検証セットに分割します。ここでは PyTorch コアライブラリの torch.utils.data.dataset.random_split 関数を利用します。
CrossEntropyLoss criterion は nn.LogSoftmax() と nn.NLLLoss() を単一クラスに結合しています。それは C クラスで分類問題を訓練するときに有用です。SGD は optimizer として確率的勾配降下法を実装しています。初期学習率は 5.0 に設定されます。エポックを通して学習率を調整するためにここでは StepLR が使用されます。
from torch.utils.data.dataset import random_split
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5 # learning rate
BATCH_SIZE = 64 # batch size for training
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = list(train_iter)
test_dataset = list(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
random_split(train_dataset, [num_train, len(train_dataset) - num_train])
train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
shuffle=True, collate_fn=collate_batch)
for epoch in range(1, EPOCHS + 1):
epoch_start_time = time.time()
train(train_dataloader)
accu_val = evaluate(valid_dataloader)
if total_accu is not None and total_accu > accu_val:
scheduler.step()
else:
total_accu = accu_val
print('-' * 59)
print('| end of epoch {:3d} | time: {:5.2f}s | '
'valid accuracy {:8.3f} '.format(epoch,
time.time() - epoch_start_time,
accu_val))
print('-' * 59)
| epoch 1 | 500/ 1782 batches | accuracy 0.690 | epoch 1 | 1000/ 1782 batches | accuracy 0.855 | epoch 1 | 1500/ 1782 batches | accuracy 0.875 ----------------------------------------------------------- | end of epoch 1 | time: 11.69s | valid accuracy 0.887 ----------------------------------------------------------- | epoch 2 | 500/ 1782 batches | accuracy 0.895 | epoch 2 | 1000/ 1782 batches | accuracy 0.902 | epoch 2 | 1500/ 1782 batches | accuracy 0.904 ----------------------------------------------------------- | end of epoch 2 | time: 11.68s | valid accuracy 0.902 ----------------------------------------------------------- | epoch 3 | 500/ 1782 batches | accuracy 0.915 | epoch 3 | 1000/ 1782 batches | accuracy 0.915 | epoch 3 | 1500/ 1782 batches | accuracy 0.915 ----------------------------------------------------------- | end of epoch 3 | time: 11.68s | valid accuracy 0.903 ----------------------------------------------------------- | epoch 4 | 500/ 1782 batches | accuracy 0.927 | epoch 4 | 1000/ 1782 batches | accuracy 0.924 | epoch 4 | 1500/ 1782 batches | accuracy 0.921 ----------------------------------------------------------- | end of epoch 4 | time: 11.67s | valid accuracy 0.904 ----------------------------------------------------------- | epoch 5 | 500/ 1782 batches | accuracy 0.933 | epoch 5 | 1000/ 1782 batches | accuracy 0.931 | epoch 5 | 1500/ 1782 batches | accuracy 0.929 ----------------------------------------------------------- | end of epoch 5 | time: 11.67s | valid accuracy 0.904 ----------------------------------------------------------- | epoch 6 | 500/ 1782 batches | accuracy 0.938 | epoch 6 | 1000/ 1782 batches | accuracy 0.933 | epoch 6 | 1500/ 1782 batches | accuracy 0.934 ----------------------------------------------------------- | end of epoch 6 | time: 11.69s | valid accuracy 0.910 ----------------------------------------------------------- | epoch 7 | 500/ 1782 batches | accuracy 0.942 | epoch 7 | 1000/ 1782 batches | accuracy 0.939 | epoch 7 | 1500/ 1782 batches | accuracy 0.937 ----------------------------------------------------------- | end of epoch 7 | time: 11.68s | valid accuracy 0.908 ----------------------------------------------------------- | epoch 8 | 500/ 1782 batches | accuracy 0.950 | epoch 8 | 1000/ 1782 batches | accuracy 0.952 | epoch 8 | 1500/ 1782 batches | accuracy 0.952 ----------------------------------------------------------- | end of epoch 8 | time: 11.67s | valid accuracy 0.910 ----------------------------------------------------------- | epoch 9 | 500/ 1782 batches | accuracy 0.952 | epoch 9 | 1000/ 1782 batches | accuracy 0.952 | epoch 9 | 1500/ 1782 batches | accuracy 0.954 ----------------------------------------------------------- | end of epoch 9 | time: 11.65s | valid accuracy 0.911 ----------------------------------------------------------- | epoch 10 | 500/ 1782 batches | accuracy 0.953 | epoch 10 | 1000/ 1782 batches | accuracy 0.955 | epoch 10 | 1500/ 1782 batches | accuracy 0.954 ----------------------------------------------------------- | end of epoch 10 | time: 11.64s | valid accuracy 0.912 -----------------------------------------------------------
以下のプリントアウトとともに GPU 上でモデルを実行します :
| epoch 1 | 500/ 1782 batches | accuracy 0.684 | epoch 1 | 1000/ 1782 batches | accuracy 0.852 | epoch 1 | 1500/ 1782 batches | accuracy 0.877 ----------------------------------------------------------- | end of epoch 1 | time: 8.33s | valid accuracy 0.867 ----------------------------------------------------------- | epoch 2 | 500/ 1782 batches | accuracy 0.895 | epoch 2 | 1000/ 1782 batches | accuracy 0.900 | epoch 2 | 1500/ 1782 batches | accuracy 0.903 ----------------------------------------------------------- | end of epoch 2 | time: 8.18s | valid accuracy 0.890 ----------------------------------------------------------- | epoch 3 | 500/ 1782 batches | accuracy 0.914 | epoch 3 | 1000/ 1782 batches | accuracy 0.914 | epoch 3 | 1500/ 1782 batches | accuracy 0.916 ----------------------------------------------------------- | end of epoch 3 | time: 8.20s | valid accuracy 0.897 ----------------------------------------------------------- | epoch 4 | 500/ 1782 batches | accuracy 0.926 | epoch 4 | 1000/ 1782 batches | accuracy 0.924 | epoch 4 | 1500/ 1782 batches | accuracy 0.921 ----------------------------------------------------------- | end of epoch 4 | time: 8.18s | valid accuracy 0.895 ----------------------------------------------------------- | epoch 5 | 500/ 1782 batches | accuracy 0.938 | epoch 5 | 1000/ 1782 batches | accuracy 0.935 | epoch 5 | 1500/ 1782 batches | accuracy 0.937 ----------------------------------------------------------- | end of epoch 5 | time: 8.16s | valid accuracy 0.902 ----------------------------------------------------------- | epoch 6 | 500/ 1782 batches | accuracy 0.939 | epoch 6 | 1000/ 1782 batches | accuracy 0.939 | epoch 6 | 1500/ 1782 batches | accuracy 0.938 ----------------------------------------------------------- | end of epoch 6 | time: 8.16s | valid accuracy 0.906 ----------------------------------------------------------- | epoch 7 | 500/ 1782 batches | accuracy 0.941 | epoch 7 | 1000/ 1782 batches | accuracy 0.939 | epoch 7 | 1500/ 1782 batches | accuracy 0.939 ----------------------------------------------------------- | end of epoch 7 | time: 8.19s | valid accuracy 0.903 ----------------------------------------------------------- | epoch 8 | 500/ 1782 batches | accuracy 0.942 | epoch 8 | 1000/ 1782 batches | accuracy 0.941 | epoch 8 | 1500/ 1782 batches | accuracy 0.942 ----------------------------------------------------------- | end of epoch 8 | time: 8.16s | valid accuracy 0.904 ----------------------------------------------------------- | epoch 9 | 500/ 1782 batches | accuracy 0.942 | epoch 9 | 1000/ 1782 batches | accuracy 0.941 | epoch 9 | 1500/ 1782 batches | accuracy 0.942 ----------------------------------------------------------- end of epoch 9 | time: 8.16s | valid accuracy 0.904 ----------------------------------------------------------- | epoch 10 | 500/ 1782 batches | accuracy 0.940 | epoch 10 | 1000/ 1782 batches | accuracy 0.942 | epoch 10 | 1500/ 1782 batches | accuracy 0.942 ----------------------------------------------------------- | end of epoch 10 | time: 8.15s | valid accuracy 0.904 -----------------------------------------------------------
テストデータセットでモデルを評価する
テストデータセットの結果をチェックします…
print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))
Checking the results of test dataset. test accuracy 0.910
test accuracy 0.906
ランダムなニュース上のテスト
ここまでの最善のモデルを使用してゴルフのニュースをテストします。
ag_news_label = {1: "World",
2: "Sports",
3: "Business",
4: "Sci/Tec"}
def predict(text, text_pipeline):
with torch.no_grad():
text = torch.tensor(text_pipeline(text))
output = model(text, torch.tensor([0]))
return output.argmax(1).item() + 1
ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
enduring the season’s worst weather conditions on Sunday at The \
Open on his way to a closing 75 at Royal Portrush, which \
considering the wind and the rain was a respectable showing. \
Thursday’s first round at the WGC-FedEx St. Jude Invitational \
was another story. With temperatures in the mid-80s and hardly any \
wind, the Spaniard was 13 strokes better in a flawless round. \
Thanks to his best putting performance on the PGA Tour, Rahm \
finished with an 8-under 62 for a three-stroke lead, which \
was even more impressive considering he’d never played the \
front nine at TPC Southwind."
model = model.to("cpu")
print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])
This is a Sports news
以上