einops 0.4 : PyTorch サンプル : pytorch と einops でより良いコードを書く

einops 0.4 : PyTorch サンプル : pytorch と einops でより良いコードを書く (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 03/27/2022 (0.4.1)

* 本ページは、einops の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

einops 0.4 : PyTorch サンプル : pytorch と einops でより良いコードを書く

深層学習のビルディングブロックを書き直す

では現実世界からのサンプルに進みましょう。これらのコード断片は公式チュートリアルとポピュラーなレポジトリから抜粋されました。

コードを改良する方法と einops がどのように役立てるかを学習します。

# start from importing some stuff
import torch
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
import math

from einops import rearrange, reduce, asnumpy, parse_shape
from einops.layers.torch import Rearrange, Reduce

 

単純な ConvNet

オリジナル

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

conv_net_old = Net()

einops 改良版

conv_net_new = nn.Sequential(
    nn.Conv2d(1, 10, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Conv2d(10, 20, kernel_size=5),
    nn.MaxPool2d(kernel_size=2),
    nn.ReLU(),
    nn.Dropout2d(),
    Rearrange('b c h w -> b (c h w)'),
    nn.Linear(320, 50),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(50, 10),
    nn.LogSoftmax(dim=1)
)

新しい実装が好ましい理由 :

  • 元のコードでは、入力サイズが変更されてバッチサイズが 16 で割り切れる場合 (通常はそうです)、reshape の後で無意味な何かを得ます。
    • 新しいコードはこの場合エラーを明示的に上げます。

  • 新しいバージョンではフラグ self.training を有する dropout を使用することを忘れないようにします。

  • コードは読んで解析することが容易です。

  • sequential は出力 / セーブ / 受け渡しを自明にします。そしてコードでモデルをロードする必要がありません (これはまた多くの利点があります)。

  • logsoftmax は必要ありませんか?それなら conv_net_new[:-1] を使用することができます。nn.Sequential を好むもう一つの理由です。

  • … そしてまた ReLU のために inplace を追加できます。

 

超解像度

オリジナル

class SuperResolutionNetOld(nn.Module):
    def __init__(self, upscale_factor):
        super(SuperResolutionNetOld, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

einops 改良版

def SuperResolutionNetNew(upscale_factor):
    return nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=5, padding=2),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 64, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(64, 32, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(32, upscale_factor ** 2, kernel_size=3, padding=1),
        Rearrange('b (h2 w2) h w -> b (h h2) (w w2)', h2=upscale_factor, w2=upscale_factor),
    )

違いは以下です :

  • 特殊な命令 pixel_shuffle は必要ありません (そして結果はフレームワーク間で転送可能です)

  • 出力は fake 軸を含みません (そして入力に対して同じことができるでしょう)

  • inplace ReLU が今は使用されています、高解像度画像のために重要で多くのメモリを節約します。

  • そして再度 nn.Sequential の総ての利点。

 

スタイル変換の Gram 行列のリスタイル (= Restyling)

元のコードは既に良いです – 最初の行はどのような種類の入力が想定されているか示しています :

  • einsum 演算は以下のように読まれるべきです :

  • 各バッチと各チャネルのペアに対して、h と w に渡り合計します。

  • 正規化も変更しました、それは Gram 行列が定義される方法であるからです、そうでないならばそれを正規化された Gram 行列のようなものと呼ぶべきです。

オリジナル

def gram_matrix_old(y):
    (b, ch, h, w) = y.size()
    features = y.view(b, ch, w * h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (ch * h * w)
    return gram

einops 改良版

def gram_matrix_new(y):
    b, ch, h, w = y.shape
    return torch.einsum('bchw,bdhw->bcd', [y, y]) / (h * w)

単に ‘b c1 h w,b c2 h w->b c1 c2’ を使用できれば素晴らしいですが、einsum は 1 文字の軸だけをサポートします

 

リカレントモデル

ここで行なったことの総ては解読をスキップするために単に shape の情報を明示的にしたことです :

オリジナル

class RNNModelOld(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

einops 改良版

class RNNModelNew(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""
    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(p=dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

    def forward(self, input, hidden):
        t, b = input.shape
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = rearrange(self.drop(output), 't b nhid -> (t b) nhid')
        decoded = rearrange(self.decoder(output), '(t b) token -> t b token', t=t, b=b)
        return decoded, hidden

 

チャネル・シャッフル (from shufflenet)

オリジナル

def channel_shuffle_old(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

einops 改良版

def channel_shuffle_new(x, groups):
    return rearrange(x, 'b (c1 c2) h w -> b (c2 c1) h w', c1=groups)

向上は明らかですが、これが限界ではありません。下で見るように、これらの数行を書く必要さえありません。

 

Shufflenet

オリジナル

from collections import OrderedDict

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()

    channels_per_group = num_channels // groups
    
    # reshape
    x = x.view(batchsize, groups, 
        channels_per_group, height, width)

    # transpose
    # - contiguous() required if transpose() is used before view().
    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

class ShuffleUnitOld(nn.Module):
    def __init__(self, in_channels, out_channels, groups=3,
                 grouped_conv=True, combine='add'):
        
        super(ShuffleUnitOld, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.grouped_conv = grouped_conv
        self.combine = combine
        self.groups = groups
        self.bottleneck_channels = self.out_channels // 4

        # define the type of ShuffleUnit
        if self.combine == 'add':
            # ShuffleUnit Figure 2b
            self.depthwise_stride = 1
            self._combine_func = self._add
        elif self.combine == 'concat':
            # ShuffleUnit Figure 2c
            self.depthwise_stride = 2
            self._combine_func = self._concat
            
            # ensure output of concat has the same channels as 
            # original output channels.
            self.out_channels -= self.in_channels
        else:
            raise ValueError("Cannot combine tensors with \"{}\"" \
                             "Only \"add\" and \"concat\" are" \
                             "supported".format(self.combine))

        # Use a 1x1 grouped or non-grouped convolution to reduce input channels
        # to bottleneck channels, as in a ResNet bottleneck module.
        # NOTE: Do not use group convolution for the first conv1x1 in Stage 2.
        self.first_1x1_groups = self.groups if grouped_conv else 1

        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
            self.in_channels,
            self.bottleneck_channels,
            self.first_1x1_groups,
            batch_norm=True,
            relu=True
            )

        # 3x3 depthwise convolution followed by batch normalization
        self.depthwise_conv3x3 = conv3x3(
            self.bottleneck_channels, self.bottleneck_channels,
            stride=self.depthwise_stride, groups=self.bottleneck_channels)
        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)

        # Use 1x1 grouped convolution to expand from 
        # bottleneck_channels to out_channels
        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
            self.bottleneck_channels,
            self.out_channels,
            self.groups,
            batch_norm=True,
            relu=False
            )


    @staticmethod
    def _add(x, out):
        # residual connection
        return x + out


    @staticmethod
    def _concat(x, out):
        # concatenate along channel axis
        return torch.cat((x, out), 1)


    def _make_grouped_conv1x1(self, in_channels, out_channels, groups,
        batch_norm=True, relu=False):

        modules = OrderedDict()
        conv = conv1x1(in_channels, out_channels, groups=groups)
        modules['conv1x1'] = conv

        if batch_norm:
            modules['batch_norm'] = nn.BatchNorm2d(out_channels)
        if relu:
            modules['relu'] = nn.ReLU()
        if len(modules) > 1:
            return nn.Sequential(modules)
        else:
            return conv


    def forward(self, x):
        # save for combining later with output
        residual = x
        if self.combine == 'concat':
            residual = F.avg_pool2d(residual, kernel_size=3, 
                stride=2, padding=1)

        out = self.g_conv_1x1_compress(x)
        out = channel_shuffle(out, self.groups)
        out = self.depthwise_conv3x3(out)
        out = self.bn_after_depthwise(out)
        out = self.g_conv_1x1_expand(out)
        
        out = self._combine_func(residual, out)
        return F.relu(out)

einops 改良版

class ShuffleUnitNew(nn.Module):
    def __init__(self, in_channels, out_channels, groups=3, 
                 grouped_conv=True, combine='add'):
        super().__init__()
        first_1x1_groups = groups if grouped_conv else 1
        bottleneck_channels = out_channels // 4
        self.combine = combine
        if combine == 'add':
            # ShuffleUnit Figure 2b
            self.left = Rearrange('...->...') # identity
            depthwise_stride = 1
        else:
            # ShuffleUnit Figure 2c
            self.left = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
            depthwise_stride = 2
            # ensure output of concat has the same channels as original output channels.
            out_channels -= in_channels
            assert out_channels > 0

        self.right = nn.Sequential(
            # Use a 1x1 grouped or non-grouped convolution to reduce input channels
            # to bottleneck channels, as in a ResNet bottleneck module.
            conv1x1(in_channels, bottleneck_channels, groups=first_1x1_groups),
            nn.BatchNorm2d(bottleneck_channels),
            nn.ReLU(inplace=True),
            # channel shuffle
            Rearrange('b (c1 c2) h w -> b (c2 c1) h w', c1=groups),
            # 3x3 depthwise convolution followed by batch 
            conv3x3(bottleneck_channels, bottleneck_channels,
                    stride=depthwise_stride, groups=bottleneck_channels),
            nn.BatchNorm2d(bottleneck_channels),
            # Use 1x1 grouped convolution to expand from 
            # bottleneck_channels to out_channels
            conv1x1(bottleneck_channels, out_channels, groups=groups),
            nn.BatchNorm2d(out_channels),
        )        
        
    def forward(self, x):
        if self.combine == 'add':
            combined = self.left(x) + self.right(x)
        else:
            combined = torch.cat([self.left(x), self.right(x)], dim=1)
        return F.relu(combined, inplace=True)

コードの書き換えは以下を識別するのに役立ちました :

  • 再シャッフルして最初の畳込みでグループを使用しないことは意味がないです (実際、論文でそのようにはなっていません)。けれども、同等のモデルになります。

  • 最初の畳込みはグループ化されないかもしれない一方で、最後の畳込みが常にグループ化されることは奇妙です (そしてそれは論文と異なります)。

他のコメント :

  • ここで導入された pytorch のための identity 層があります。

  • 残された最後のことはコードの conv1x1 と conv3x3 を取り除くことです – それらは標準より良くないです。

 

ResNet の単純化

オリジナル

class ResNetOld(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNetOld, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

einops 改良版

def make_layer(inplanes, planes, block, n_blocks, stride=1):
    downsample = None
    if stride != 1 or inplanes != planes * block.expansion:
        # output size won't match input, so adjust residual
        downsample = nn.Sequential(
            nn.Conv2d(inplanes, planes * block.expansion,
                      kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(planes * block.expansion),
        )
    return nn.Sequential(
        block(inplanes, planes, stride, downsample),
        *[block(planes * block.expansion, planes) for _ in range(1, n_blocks)]
    )


def ResNetNew(block, layers, num_classes=1000):    
    e = block.expansion
    
    resnet = nn.Sequential(
        Rearrange('b c h w -> b c h w', c=3, h=224, w=224),
        nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        make_layer(64,      64,  block, layers[0], stride=1),
        make_layer(64 * e,  128, block, layers[1], stride=2),
        make_layer(128 * e, 256, block, layers[2], stride=2),
        make_layer(256 * e, 512, block, layers[3], stride=2),
        # combined AvgPool and view in one averaging operation
        Reduce('b c h w -> b c', 'mean'),
        nn.Linear(512 * e, num_classes),
    )
    
    # initialization
    for m in resnet.modules():
        if isinstance(m, nn.Conv2d):
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    return resnet

変更点 :

  • 入力 shape に対する明示的なチェック

  • view はなく単純なシーケンシャル構造で、出力は単なる nn.Sequential ですので、常にセーブ/受け渡し/等が可能です。

  • AvgPool と追加の view は必要ありません、この箇所が今では遥かに明らかです。

  • make_layer は内部状態を使用しません (それは非常に欠陥がある箇所です)。

 

RNN 言語モデリングの改良

オリジナル

class RNNOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, 
                           bidirectional=bidirectional, dropout=dropout)
        self.fc = nn.Linear(hidden_dim*2, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        #x = [sent len, batch size]
        
        embedded = self.dropout(self.embedding(x))
        
        #embedded = [sent len, batch size, emb dim]
        
        output, (hidden, cell) = self.rnn(embedded)
        
        #output = [sent len, batch size, hid dim * num directions]
        #hidden = [num layers * num directions, batch size, hid dim]
        #cell = [num layers * num directions, batch size, hid dim]
        
        #concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
        #and apply dropout
        
        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
                
        #hidden = [batch size, hid dim * num directions]
            
        return self.fc(hidden.squeeze(0))

einops 改良版

class RNNNew(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional, dropout):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=n_layers, 
                           bidirectional=bidirectional, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_dim * self.directions, output_dim)
        
    def forward(self, x):
        #x = [sent len, batch size]        
        embedded = self.dropout(self.embedding(x))
        
        #embedded = [sent len, batch size, emb dim]
        output, (hidden, cell) = self.rnn(embedded)
        
        hidden = rearrange(hidden, '(layer dir) b c -> layer b (dir c)', 
                           dir=self.directions)
        # take the final layer's hidden
        return self.fc(self.dropout(hidden[-1]))
  • 元のコードは非双方向モデルに対しては正しく動作しません。

  • … そして bidirectional = False のときは失敗し、1 層だけしかありません。

  • コードの改変は hidden が構成される方法とそれが変更される方法の両者を示しています。

 

FastText をより高速に書く

オリジナル

class FastTextOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim):
        super().__init__()
        
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.fc = nn.Linear(embedding_dim, output_dim)
        
    def forward(self, x):
        
        #x = [sent len, batch size]
        
        embedded = self.embedding(x)
                
        #embedded = [sent len, batch size, emb dim]
        
        embedded = embedded.permute(1, 0, 2)
        
        #embedded = [batch size, sent len, emb dim]
        
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) 
        
        #pooled = [batch size, embedding_dim]
                
        return self.fc(pooled)

einops 改良版

def FastTextNew(vocab_size, embedding_dim, output_dim):
    return nn.Sequential(
        Rearrange('t b -> t b'),
        nn.Embedding(vocab_size, embedding_dim),
        Reduce('t b c -> b c', 'mean'),
        nn.Linear(embedding_dim, output_dim),
        Rearrange('b c -> b c'),
    )

新しいコードについての幾つかのコメント :

  • 最初と最後の演算は何もしませんので削除できます。
    • しかし想定される入力と出力を明示的に示すために追加されました。

  • これはまた単一行の編集でインターフェイスを変更する柔軟性を与えます。入力を (batch, time) として受け取る必要があるならば、単に最初の行を Rearrange(‘b t -> t b’) に変更します。

 

テキスト分類のための CNN

オリジナル

class CNNOld(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv_0 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[0],embedding_dim))
        self.conv_1 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[1],embedding_dim))
        self.conv_2 = nn.Conv2d(in_channels=1, out_channels=n_filters, kernel_size=(filter_sizes[2],embedding_dim))
        self.fc = nn.Linear(len(filter_sizes)*n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        #x = [sent len, batch size]
        
        x = x.permute(1, 0)
                
        #x = [batch size, sent len]
        
        embedded = self.embedding(x)
                
        #embedded = [batch size, sent len, emb dim]
        
        embedded = embedded.unsqueeze(1)
        
        #embedded = [batch size, 1, sent len, emb dim]
        
        conved_0 = F.relu(self.conv_0(embedded).squeeze(3))
        conved_1 = F.relu(self.conv_1(embedded).squeeze(3))
        conved_2 = F.relu(self.conv_2(embedded).squeeze(3))
            
        #conv_n = [batch size, n_filters, sent len - filter_sizes[n]]
        
        pooled_0 = F.max_pool1d(conved_0, conved_0.shape[2]).squeeze(2)
        pooled_1 = F.max_pool1d(conved_1, conved_1.shape[2]).squeeze(2)
        pooled_2 = F.max_pool1d(conved_2, conved_2.shape[2]).squeeze(2)
        
        #pooled_n = [batch size, n_filters]
        
        cat = self.dropout(torch.cat((pooled_0, pooled_1, pooled_2), dim=1))

        #cat = [batch size, n_filters * len(filter_sizes)]
            
        return self.fc(cat)

einops 改良版

class CNNNew(nn.Module):
    def __init__(self, vocab_size, embedding_dim, n_filters, filter_sizes, output_dim, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv1d(embedding_dim, n_filters, kernel_size=size) for size in filter_sizes
        ])
        self.fc = nn.Linear(len(filter_sizes) * n_filters, output_dim)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        x = rearrange(x, 't b -> t b')
        emb = rearrange(self.embedding(x), 't b c -> b c t')
        pooled = [reduce(conv(emb), 'b c t -> b c', 'max') for conv in self.convs]
        concatenated = rearrange(pooled, 'filter b c -> b (filter c)')
        return self.fc(self.dropout(F.relu(concatenated)))
  • 元のコードは Conv2D を誤用していて、Conv1d が正しい選択です。

  • 修正コードは任意の数の filter_sizes で動作できます (そして失敗しません)。

  • 新しいコードの最初の行は何もしませんが、単純化のために追加されました。

 

ハイウェイ畳込み

  • ハイウェイ畳込みは TTS システムでは一般的です。下のコードは分割を少し明示的にしています。

  • 入力がチャネル軸に対して前にグループ化されていた場合 (グループ (化) 畳込み or 双方向 LSTMs/GRU)、結局は分割ポリシーが重要であることが判明するかもしれません。

  • 同じことが GLU と gated ユニットにも一般に当てはまります。

オリジナル

class HighwayConv1dOld(nn.Conv1d):
    def forward(self, inputs):
        L = super(HighwayConv1dOld, self).forward(inputs)
        H1, H2 = torch.chunk(L, 2, 1)  # chunk at the feature dim
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs

einops 改良版

class HighwayConv1dNew(nn.Conv1d):
    def forward(self, inputs):
        L = super().forward(inputs)
        H1, H2 = rearrange(L, 'b (split c) t -> split b c t', split=2)
        torch.sigmoid_(H1)
        return H1 * H2 + (1.0 - H1) * inputs

 

Tacotron の CBHG モジュール

オリジナル

class CBHG_Old(nn.Module):
    """CBHG module: a recurrent neural network composed of:
        - 1-d convolution banks
        - Highway networks + residual connections
        - Bidirectional gated recurrent units
    """

    def __init__(self, in_dim, K=16, projections=[128, 128]):
        super(CBHG, self).__init__()
        self.in_dim = in_dim
        self.relu = nn.ReLU()
        self.conv1d_banks = nn.ModuleList(
            [BatchNormConv1d(in_dim, in_dim, kernel_size=k, stride=1,
                             padding=k // 2, activation=self.relu)
             for k in range(1, K + 1)])
        self.max_pool1d = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

        in_sizes = [K * in_dim] + projections[:-1]
        activations = [self.relu] * (len(projections) - 1) + [None]
        self.conv1d_projections = nn.ModuleList(
            [BatchNormConv1d(in_size, out_size, kernel_size=3, stride=1,
                             padding=1, activation=ac)
             for (in_size, out_size, ac) in zip(
                 in_sizes, projections, activations)])

        self.pre_highway = nn.Linear(projections[-1], in_dim, bias=False)
        self.highways = nn.ModuleList(
            [Highway(in_dim, in_dim) for _ in range(4)])

        self.gru = nn.GRU(
            in_dim, in_dim, 1, batch_first=True, bidirectional=True)

オリジナル

def forward_old(self, inputs):
    # (B, T_in, in_dim)
    x = inputs

    # Needed to perform conv1d on time-axis
    # (B, in_dim, T_in)
    if x.size(-1) == self.in_dim:
        x = x.transpose(1, 2)

    T = x.size(-1)

    # (B, in_dim*K, T_in)
    # Concat conv1d bank outputs
    x = torch.cat([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], dim=1)
    assert x.size(1) == self.in_dim * len(self.conv1d_banks)
    x = self.max_pool1d(x)[:, :, :T]

    for conv1d in self.conv1d_projections:
        x = conv1d(x)

    # (B, T_in, in_dim)
    # Back to the original shape
    x = x.transpose(1, 2)

    if x.size(-1) != self.in_dim:
        x = self.pre_highway(x)

    # Residual connection
    x += inputs
    for highway in self.highways:
        x = highway(x)

    # (B, T_in, in_dim*2)
    outputs, _ = self.gru(x)

    return outputs

einops 改良版

def forward_new(self, inputs, input_lengths=None):
    x = rearrange(inputs, 'b t c -> b c t')
    _, _, T = x.shape
    # Concat conv1d bank outputs
    x = rearrange([conv1d(x)[:, :, :T] for conv1d in self.conv1d_banks], 
                 'bank b c t -> b (bank c) t', c=self.in_dim)
    x = self.max_pool1d(x)[:, :, :T]

    for conv1d in self.conv1d_projections:
        x = conv1d(x)
    x = rearrange(x, 'b c t -> b t c')
    if x.size(-1) != self.in_dim:
        x = self.pre_highway(x)

    # Residual connection
    x += inputs
    for highway in self.highways:
        x = highway(x)

    # (B, T_in, in_dim*2)
    outputs, _ = self.gru(self.highways(x))

    return outputs    

依然として大きな改良の余地がありますが、このサンプルでは forward 関数だけが変更されました。

 

単純な attention

朗報 : 次元の順序を推測する必要はもうありません。入力に対しても出力に対しても (必要ありません)

オリジナル

class Attention(nn.Module):
    def __init__(self):
        super(Attention, self).__init__()
    
    def forward(self, K, V, Q):
        A = torch.bmm(K.transpose(1,2), Q) / np.sqrt(Q.shape[1])
        A = F.softmax(A, 1)
        R = torch.bmm(V, A)
        return torch.cat((R, Q), dim=1)

einops 改良版

def attention(K, V, Q):
    _, n_channels, _ = K.shape
    A = torch.einsum('bct,bcl->btl', [K, Q])
    A = F.softmax(A * n_channels ** (-0.5), 1)
    R = torch.einsum('bct,btl->bcl', [V, A])
    return torch.cat((R, Q), dim=1)

 

Transformer の attention は更なる注意が必要です

オリジナル

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn



class MultiHeadAttentionOld(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model)

        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)

        self.dropout = nn.Dropout(dropout)


    def forward(self, q, k, v, mask=None):
        
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        
        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()
        
        residual = q
        
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
        
        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
        
        mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)
        
        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)
        
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        
        return output, attn

einops 改良版

class MultiHeadAttentionNew(nn.Module):
    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        super().__init__()
        self.n_head = n_head
        
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v)))
        
        self.fc = nn.Linear(n_head * d_v, d_model)
        nn.init.xavier_normal_(self.fc.weight)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, mask=None):
        residual = q
        q = rearrange(self.w_qs(q), 'b l (head k) -> head b l k', head=self.n_head)
        k = rearrange(self.w_ks(k), 'b t (head k) -> head b t k', head=self.n_head)
        v = rearrange(self.w_vs(v), 'b t (head v) -> head b t v', head=self.n_head)
        attn = torch.einsum('hblk,hbtk->hblt', [q, k]) / np.sqrt(q.shape[-1])
        if mask is not None:
            attn = attn.masked_fill(mask[None], -np.inf)
        attn = torch.softmax(attn, dim=3)
        output = torch.einsum('hblt,hbtv->hblv', [attn, v])
        output = rearrange(output, 'head b l v -> b l (head v)')
        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)
        return output, attn

新しい実装の利点

  • 2 つではなく、1 つのモジュールを持ちます。

  • 新しいコードは None マスクに対して失敗しません。

  • 削除した元のコードの注意書き (= caveats) の総量は膨大です。コメントを消去してそこで何が起きているか解読してみてください。

 

自己 attention GAN

SAGAN は現在、画像生成のための SotA で、同じトリックを使用して単純化できます。

オリジナル

class Self_Attn_Old(nn.Module):
    """ Self attention Layer"""
    def __init__(self,in_dim,activation):
        super(Self_Attn_Old,self).__init__()
        self.chanel_in = in_dim
        self.activation = activation
        
        self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
        self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        
        self.softmax  = nn.Softmax(dim=-1) #

    def forward(self, x):
        """
            inputs :
                x : input feature maps( B X C X W X H)
            returns :
                out : self attention value + input feature 
                attention: B X N X N (N is Width*Height)
        """
        
        m_batchsize,C,width ,height = x.size()
        proj_query  = self.query_conv(x).view(m_batchsize,-1,width*height).permute(0,2,1) # B X CX(N)
        proj_key =  self.key_conv(x).view(m_batchsize,-1,width*height) # B X C x (*W*H)
        energy =  torch.bmm(proj_query,proj_key) # transpose check
        attention = self.softmax(energy) # BX (N) X (N) 
        proj_value = self.value_conv(x).view(m_batchsize,-1,width*height) # B X C X N

        out = torch.bmm(proj_value,attention.permute(0,2,1) )
        out = out.view(m_batchsize,C,width,height)
        
        out = self.gamma*out + x
        return out,attention

einops 改良版

class Self_Attn_New(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_dim):
        super().__init__()
        self.query_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, out_channels=in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros([1]))

    def forward(self, x):
        proj_query = rearrange(self.query_conv(x), 'b c h w -> b (h w) c')
        proj_key = rearrange(self.key_conv(x), 'b c h w -> b c (h w)')
        proj_value = rearrange(self.value_conv(x), 'b c h w -> b (h w) c')
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=2)
        out = torch.bmm(attention, proj_value)
        out = x + self.gamma * rearrange(out, 'b (h w) c -> b c h w',
                                         **parse_shape(x, 'b c h w'))
        return out, attention

 

時系列予測の改良

このサンプルは単純過ぎると思われましたが、どのような入力が想定されるかを理解するために周辺コードを解析する必要がありました。貴方自身で試すことができます。

更に今ではコードは double だけでなく、任意の dtype で動作します ; そして新しいコードは GPU の使用をサポートしています。

オリジナル

class SequencePredictionOld(nn.Module):
    def __init__(self):
        super(SequencePredictionOld, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future = 0):
        outputs = []
        h_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.double)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.double)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
            
        for i in range(future):# if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        outputs = torch.stack(outputs, 1).squeeze(2)
        return outputs

einops 改良版

class SequencePredictionNew(nn.Module):
    def __init__(self):
        super(SequencePredictionNew, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        b, t = input.shape
        h_t, c_t, h_t2, c_t2 = torch.zeros(4, b, 51, dtype=self.linear.weight.dtype, 
                                           device=self.linear.weight.device)

        outputs = []
        for input_t in rearrange(input, 'b t -> t b ()'):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
            
        for i in range(future): # if we should predict the future
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]
        return rearrange(outputs, 't b () -> b t')

 

spacial transformer ネットワーク (STN) の変形

オリジナル

class SpacialTransformOld(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x

einops 改良版

class SpacialTransformNew(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Spatial transformer localization-network
        linear = nn.Linear(32, 3 * 2)
        # Initialize the weights/bias with identity transformation
        linear.weight.data.zero_()
        linear.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
        
        self.compute_theta = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            Rearrange('b c h w -> b (c h w)', h=3, w=3),
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            linear,
            Rearrange('b (row col) -> b row col', row=2, col=3),
        )

    # Spatial transformer network forward function
    def stn(self, x):
        grid = F.affine_grid(self.compute_theta(x), x.size())
        return F.grid_sample(x, grid)
  • 新しいコードは、渡された画像サイズが想定とは異なるとき妥当なエラーを与えます。

  • バッチサイズが 18 で割り切れる場合、古いコードでどのような入力であれ、それは affine_grid で直ちに失敗します。

 

GLOW の改良

それは手動で書かれた古き良き depth-to-space です!

GLOW は revertible ですから、rearrange ライクな演算に頻繁に依存します。

オリジナル

def unsqueeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    factor2 = factor ** 2
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert C % (factor2) == 0, "{}".format(C)
    x = input.view(B, C // factor2, factor, factor, H, W)
    x = x.permute(0, 1, 4, 2, 5, 3).contiguous()
    x = x.view(B, C // (factor2), H * factor, W * factor)
    return x

def squeeze2d_old(input, factor=2):
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    size = input.size()
    B = size[0]
    C = size[1]
    H = size[2]
    W = size[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W))
    x = input.view(B, C, H // factor, factor, W // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4).contiguous()
    x = x.view(B, C * factor * factor, H // factor, W // factor)
    return x

einops 改良版

def unsqueeze2d_new(input, factor=2):
    return rearrange(input, 'b (c h2 w2) h w -> b c (h h2) (w w2)', h2=factor, w2=factor)

def squeeze2d_new(input, factor=2):
    return rearrange(input, 'b c (h h2) (w w2) -> b (c h2 w2) h w', h2=factor, w2=factor)
  • 用語 squeeze はそれほど役立ちません ; どの次元が squeeze されるのでしょう?torch.squeeze がありますが、それは全く異なります。

  • 実際、関数を作成することを完全にスキップできます – それはとにかく einops への単純な呼び出しです。

 

YOLO 検出の問題の検出

オリジナル

def YOLO_prediction_old(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    bs = input.size(0)
    in_h = input.size(2)
    in_w = input.size(3)
    scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in anchors]

    prediction = input.view(bs, num_anchors,
                            5 + num_classes, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
    # Get outputs
    x = torch.sigmoid(prediction[..., 0])  # Center x
    y = torch.sigmoid(prediction[..., 1])  # Center y
    w = prediction[..., 2]  # Width
    h = prediction[..., 3]  # Height
    conf = torch.sigmoid(prediction[..., 4])  # Conf
    pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

    FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
    # Calculate offsets for each grid
    grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_w, 1).repeat(
        bs * num_anchors, 1, 1).view(x.shape).type(FloatTensor)
    grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_h, 1).t().repeat(
        bs * num_anchors, 1, 1).view(y.shape).type(FloatTensor)
    # Calculate anchor w, h
    anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
    anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
    anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
    anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
    # Add offset and scale with anchors
    pred_boxes = FloatTensor(prediction[..., :4].shape)
    pred_boxes[..., 0] = x.data + grid_x
    pred_boxes[..., 1] = y.data + grid_y
    pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
    pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
    # Results
    _scale = torch.Tensor([stride_w, stride_h] * 2).type(FloatTensor)
    output = torch.cat((pred_boxes.view(bs, -1, 4) * _scale,
                        conf.view(bs, -1, 1), pred_cls.view(bs, -1, num_classes)), -1)
    return output

einops 改良版

def YOLO_prediction_new(input, num_classes, num_anchors, anchors, stride_h, stride_w):
    raw_predictions = rearrange(input, 'b (anchor prediction) h w -> prediction b anchor h w', 
                                anchor=num_anchors, prediction=5 + num_classes)
    anchors = torch.FloatTensor(anchors).to(input.device)
    anchor_sizes = rearrange(anchors, 'anchor dim -> dim () anchor () ()')

    _, _, _, in_h, in_w = raw_predictions.shape
    grid_h = rearrange(torch.arange(in_h).float(), 'h -> () () h ()').to(input.device)
    grid_w = rearrange(torch.arange(in_w).float(), 'w -> () () () w').to(input.device)

    predicted_bboxes = torch.zeros_like(raw_predictions)
    predicted_bboxes[0] = (raw_predictions[0].sigmoid() + grid_w) * stride_w  # center x
    predicted_bboxes[1] = (raw_predictions[1].sigmoid() + grid_h) * stride_h  # center y
    predicted_bboxes[2:4] = (raw_predictions[2:4].exp()) * anchor_sizes  # bbox width and height
    predicted_bboxes[4] = raw_predictions[4].sigmoid()  # confidence
    predicted_bboxes[5:] = raw_predictions[5:].sigmoid()  # class predictions
    # merging all predicted bboxes for each image
    return rearrange(predicted_bboxes, 'prediction b anchor h w -> b (anchor h w) prediction')

多くを変更して修正しました :

  • 入力が最初の GPU になくても新しいコードは失敗しません。

  • 古いコードは非正方形画像に対して誤った grid_x と grid_y を持ちます。

  • 新しいコードはブロードキャストが十分であるとき、レプリケーションを使用しません。

  • 古いコードは時々奇妙に .data を取りますが、これは実際の効果はありません、幾つかの分岐は最後まで勾配を保持するからです。
    • 勾配が必要でない場合、torch.no_grad が使用されるべきですから、それは冗長です。

 

多くの絵 (= pictures) のための単純な出力

次に生成モデルの絵を出力する必要があるとき、このトリックが使用できます。

オリジナル

device = 'cpu'
plt.imshow(np.transpose(vutils.make_grid(fake_batch.to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

einops 改良版

padded = F.pad(fake_batch[:64], [1, 1, 1, 1])
plt.imshow(rearrange(padded, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8).cpu())
 

以上