PyTorch 0.4.1 examples (コード解説) : テキスト分類 – TorchText IMDB (LSTM, GRU)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 08/14/2018 (0.4.1)
* 本ページは、github 上の以下の pytorch/examples と keras/examples レポジトリのサンプル・コードを参考にしています:
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
LSTM モデル
import torch.nn as nn class Net(nn.Module): def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim): super().__init__() self.embedding = nn.Embedding(input_dim, embedding_dim) self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5) self.fc = nn.Linear(hidden_dim*2, output_dim) def forward(self, x): embedded = self.embedding(x) output, (hidden, cell) = self.lstm(embedded) hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1) return self.fc(hidden)
テスト精度は 76.86% まで到達。基本的な RNN では 62.56% でしたので、かなりの改善が見られました。
LSTM モデル with Dropout 層
検証精度のグラフが不自然なので、Dropout を追加してみました :
import torch.nn as nn class Net(nn.Module): def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim): super().__init__() self.embedding = nn.Embedding(input_dim, embedding_dim) self.dropout1 = nn.Dropout() self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5) self.dropout2 = nn.Dropout() self.fc = nn.Linear(hidden_dim*2, output_dim) def forward(self, x): embedded = self.dropout1(self.embedding(x)) output, (hidden, cell) = self.lstm(embedded) hidden = self.dropout2(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)) return self.fc(hidden)
グラフは滑らかになり、テスト精度は 77.22% が得られました。
GRU モデル with Dropout 層
比較のために LSTM の代わりに GRU を利用してみました。
class Net(nn.Module): def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim): super().__init__() self.embedding = nn.Embedding(input_dim, embedding_dim) self.dropout1 = nn.Dropout() self.gru = nn.GRU(embedding_dim, hidden_dim, batch_first=True, num_layers=2, bidirectional=True, dropout=0.5) self.dropout2 = nn.Dropout() self.fc = nn.Linear(hidden_dim*2, output_dim) def forward(self, x): embedded = self.dropout1(self.embedding(x)) output, hidden = self.gru(embedded) hidden = self.dropout2(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)) return self.fc(hidden)
テスト精度は 78.78% でした。
以上