PyTorch 0.4.1 examples (コード解説) : テキスト分類 – TorchText IMDB (LSTM, GRU)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/14/2018 (0.4.1)
* 本ページは、github 上の以下の pytorch/examples と keras/examples レポジトリのサンプル・コードを参考にしています:
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
LSTM モデル
import torch.nn as nn
class Net(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5)
self.fc = nn.Linear(hidden_dim*2, output_dim)
def forward(self, x):
embedded = self.embedding(x)
output, (hidden, cell) = self.lstm(embedded)
hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
return self.fc(hidden)


テスト精度は 76.86% まで到達。基本的な RNN では 62.56% でしたので、かなりの改善が見られました。
LSTM モデル with Dropout 層
検証精度のグラフが不自然なので、Dropout を追加してみました :
import torch.nn as nn
class Net(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.dropout1 = nn.Dropout()
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5)
self.dropout2 = nn.Dropout()
self.fc = nn.Linear(hidden_dim*2, output_dim)
def forward(self, x):
embedded = self.dropout1(self.embedding(x))
output, (hidden, cell) = self.lstm(embedded)
hidden = self.dropout2(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
return self.fc(hidden)


グラフは滑らかになり、テスト精度は 77.22% が得られました。
GRU モデル with Dropout 層
比較のために LSTM の代わりに GRU を利用してみました。
class Net(nn.Module):
def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
super().__init__()
self.embedding = nn.Embedding(input_dim, embedding_dim)
self.dropout1 = nn.Dropout()
self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5)
self.dropout2 = nn.Dropout()
self.fc = nn.Linear(hidden_dim*2, output_dim)
def forward(self, x):
embedded = self.dropout1(self.embedding(x))
output, hidden = self.gru(embedded)
hidden = self.dropout2(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
return self.fc(hidden)



テスト精度は 78.78% でした。
以上