PyTorch : Tutorial 上級 : numpy と scipy を使用してエクステンションを作成する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/18/2018 (0.4.0)
* 本ページは、PyTorch Intermidiate Tutorials の – Creating extensions using numpy and scipy
を
動作確認・翻訳した上で適宜、補足説明したものです:
* サンプルコードの動作確認はしておりますが、適宜、追加改変している場合もあります。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
本文
このチュートリアルでは、2 つのタスクを通り抜けます :
- パラメータを持たないニューラルネットワーク層を作成します。
- これはその実装の一部として numpy を呼び出します。
- 学習可能な重みを持つニューラルネットワーク層を作成します。
- これはその実装の一部として SciPy を呼び出します。
import torch from torch.autograd import Function
Parameter-less サンプル
この層は有用であったり数学的に正しいことは特に何もしません。
それはふさわしく BadFFTFunction と命名されます。
層実装
from numpy.fft import rfft2, irfft2 class BadFFTFunction(Function): def forward(self, input): numpy_input = input.detach().numpy() result = abs(rfft2(numpy_input)) return input.new(result) def backward(self, grad_output): numpy_go = grad_output.numpy() result = irfft2(numpy_go) return grad_output.new(result) # この層はどのようなパラメータも持ちませんので、 # これを nn.Module クラスとしてではなく、単に関数として宣言できます。 def incorrect_fft(input): return BadFFTFunction()(input)
作成された層の使用例:
input = torch.randn(8, 8, requires_grad=True) result = incorrect_fft(input) print(result) result.backward(torch.randn(result.size())) print(input)
Out:
tensor([[ 1.0175, 5.9401, 9.4193, 7.9451, 1.4826], [ 2.7564, 5.7768, 3.3575, 16.3103, 9.5395], [ 6.6113, 5.4757, 8.4662, 5.3982, 3.0258], [ 12.0516, 6.3465, 5.2812, 7.1785, 8.8942], [ 10.6671, 8.8123, 2.1455, 2.3794, 3.1276], [ 12.0516, 9.6364, 6.3934, 5.7216, 8.8942], [ 6.6113, 6.7071, 8.3808, 4.1126, 3.0258], [ 2.7564, 12.5046, 8.5980, 7.9262, 9.5395]]) tensor([[-0.4980, -0.8297, -1.3876, -1.4957, 0.9741, 0.1898, -0.4307, -0.5218], [-0.0814, 0.3802, -1.7602, -0.1787, 0.4242, -1.4913, 0.7898, 0.6024], [ 1.0963, 0.4286, 0.8050, -0.6570, -0.5148, -1.9187, 1.7108, 0.4385], [-0.2829, 0.6173, 1.0591, 1.1673, -1.4923, -0.7836, 0.3336, -0.5917], [ 1.4302, -0.8397, -0.1708, 1.4137, -1.5305, -0.4111, -0.2065, 0.6902], [ 1.6454, -0.8543, 0.2880, 0.6782, -0.3345, 0.3596, -0.8991, 0.2772], [-1.0208, -0.9278, -2.0031, 1.0153, -0.6754, 1.0043, 0.4205, -0.4027], [ 0.6563, 0.5535, 1.5821, 1.4493, 0.0675, 0.4383, -0.2272, 1.4501]])
パラメータ化されたサンプル
これは学習可能な重みを持つ層を実装します。
それは学習可能なカーネルを持つ相互相関を実装します。
深層学習の文献では、それは (混乱を引き起こしますが) 畳み込みとして参照されます。
backward は入力に関する勾配とフィルタに関する勾配を計算します。
実装:
実装は例として役立ちます、そしてそれが正確であることは検証しなかったことに注意してください。
from scipy.signal import convolve2d, correlate2d from torch.nn.modules.module import Module from torch.nn.parameter import Parameter class ScipyConv2dFunction(Function): @staticmethod def forward(ctx, input, filter): input, filter = input.detach(), filter.detach() # detach so we can cast to NumPy result = correlate2d(input.numpy(), filter.detach().numpy(), mode='valid') ctx.save_for_backward(input, filter) return input.new(result) @staticmethod def backward(ctx, grad_output): grad_output = grad_output.detach() input, filter = ctx.saved_tensors grad_input = convolve2d(grad_output.numpy(), filter.t().numpy(), mode='full') grad_filter = convolve2d(input.numpy(), grad_output.numpy(), mode='valid') return grad_output.new_tensor(grad_input), grad_output.new_tensor(grad_filter) class ScipyConv2d(Module): def __init__(self, kh, kw): super(ScipyConv2d, self).__init__() self.filter = Parameter(torch.randn(kh, kw)) def forward(self, input): return ScipyConv2dFunction.apply(input, self.filter)
使用例:
module = ScipyConv2d(3, 3) print(list(module.parameters())) input = torch.randn(10, 10, requires_grad=True) output = module(input) print(output) output.backward(torch.randn(8, 8)) print(input.grad)
Out:
[Parameter containing: tensor([[ 0.4886, -0.4155, 0.5584], [ 2.2424, 0.4452, 0.3586], [-0.2167, 0.9405, 0.2488]])] tensor([[-2.9912, -0.5662, 3.7414, 3.6055, 3.8686, -2.1231, 0.3196, 5.1846], [ 0.5401, -1.8301, -2.9379, 2.0127, 5.4983, 6.4266, 3.1789, -2.1868], [-0.7909, 4.1522, -2.9417, -1.4666, -3.1675, 2.5134, 1.4663, 2.6334], [ 2.3376, 1.7667, -0.1113, 1.3432, -3.7815, 0.3749, -2.7867, 2.7936], [ 1.3951, 2.9989, 0.9232, -0.5367, -0.2340, -2.2505, -1.6442, -3.2447], [-3.3345, -1.6093, 0.0072, 1.4827, 0.4360, -2.7343, -3.7304, -1.4646], [-2.9499, -5.7625, -3.3628, -0.2274, 1.4422, 3.6559, -5.8038, -1.2585], [-3.0484, -5.3444, 0.0130, -2.6582, -0.9732, 3.1847, 0.8779, -6.1164]]) tensor([[ 0.4321, 1.8663, -0.2998, 2.0890, 0.1624, 0.1035, 0.8119, 3.0118, 1.4538, -0.1654], [-0.2396, 1.0648, -0.1204, -1.5785, 0.8905, -0.0230, -3.6403, -0.5714, 2.0351, 0.6735], [ 1.0622, 3.7028, 2.7420, 2.5183, -1.2862, -0.2742, 0.8661, 2.0489, -0.8798, 0.4524], [-1.3321, -4.0190, -0.5341, 0.1137, 0.5872, 1.9134, 0.2770, 0.6109, 3.5335, -0.5052], [ 2.1476, 2.5968, -2.7266, 2.6567, 2.3867, -1.0325, 2.9208, 1.3676, -1.8782, 1.2381], [-2.2506, -2.6762, 3.2016, -1.1409, -2.2895, 2.7274, 1.9179, 1.9295, 0.9343, -0.8167], [ 0.8443, -3.0849, 1.4744, 6.8210, -0.9215, -2.9612, 0.0753, -2.3565, 1.1515, -0.1838], [-0.3818, -0.2170, -0.3137, 2.7301, -4.2095, -1.7367, -2.2464, -4.8457, 1.1600, 0.3315], [-0.5465, -0.0006, 0.4709, 1.9905, 0.2919, -2.4889, -0.7931, -2.0214, -1.8382, 0.7788], [ 0.0480, 0.0116, 0.7040, -0.8114, -0.7938, -1.0880, -1.5437, -0.5219, -0.2780, 0.1757]])
以上