PyTorch 1.4 Tutorials : PyTorch モデル配備 : (オプション) PyTorch から ONNX へモデルをエクスポートして ONNX ランタイムを使用してそれを実行する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 01/21/2020 (1.4.0)
* 本ページは、PyTorch 1.4 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:
- Deploying PyTorch Models in Production : (optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
PyTorch モデル配備 : (オプション) PyTorch から ONNX へモデルをエクスポートして ONNX ランタイムを使用してそれを実行する
このチュートリアルでは、PyTorch で定義されたモデルを ONNX フォーマットにどのように変換してそれを ONNX ランタイムで実行するかを説明します。
ONNX ランタイムは ONNX モデルのためのパフォーマンスにフォーカスしたエンジンで、これは複数のプラットフォームとハードウェアに渡り効率的に推論します (Windows, Linux と Mac そして CPU と GPU の両者の上で)。ONNX ランタイムは ここ で説明されているように複数のモデルに渡りパフォーマンスをかなり増加することが証明されています。
このチュートリアルのために、ONNX と ONNX ランタイム をインストールする必要があります。”pip install onnx onnxruntime” で ONNX と ONNX ランタイムのバイナリビルドを得ることができます。ONNX ランタイムは Python バージョン 3.5 から 3.7 と互換であることに注意してください。
NOTE: このチュートリアルは PyTorch マスター・ブランチが必要です、それは ここ の手順をフォローしてインストールできます。
# Some standard imports import io import numpy as np from torch import nn import torch.utils.model_zoo as model_zoo import torch.onnx
超解像 (= super-resolution) は画像、動画の解像度を増大させる方法で画像処理や動画編集で広く使用されます。このチュートリアルのために、小さい超解像モデルを使用します。
最初に、PyTorch で SuperResolution モデルを作成しましょう。このモデルは画像の解像度をアップスケール因子で増大させるために “Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network” – Shi et al で説明されている efficient sub-pixel 畳込み層を使用します。モデルは入力として画像の YCbCr の Y 成分を想定し、そして超解像度でアップスケールされた Y 成分を出力します。
このモデル は PyTorch の examples から変更なしに直接的に由来しています :
# Super Resolution model definition in PyTorch import torch.nn as nn import torch.nn.init as init class SuperResolutionNet(nn.Module): def __init__(self, upscale_factor, inplace=False): super(SuperResolutionNet, self).__init__() self.relu = nn.ReLU(inplace=inplace) self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) self.pixel_shuffle = nn.PixelShuffle(upscale_factor) self._initialize_weights() def forward(self, x): x = self.relu(self.conv1(x)) x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.pixel_shuffle(self.conv4(x)) return x def _initialize_weights(self): init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv4.weight) # Create the super-resolution model by using the above model definition. torch_model = SuperResolutionNet(upscale_factor=3)
通常は、このモデルを今訓練するでしょう ; けれども、このチュートリアルのために、代わりにある事前訓練された重みをダウンロードします。このモデルは良い精度のために完全には訓練されていません、そしてここではデモ目的のみのために使用されることに注意してください。
モデルを推論モードに変えるためモデルをエクスポートする前に torch_model.eval() または torch_model.train(False) を呼び出すことは重要です。これは必要です、何故ならば dropout や batchnorm のような演算子は推論と訓練モードでは異なる動作をするからです。
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth' batch_size = 1 # just a random number # Initialize model with the pretrained weights map_location = lambda storage, loc: storage if torch.cuda.is_available(): map_location = None torch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location)) # set the model to inference mode torch_model.eval()
PyTorch のモデルのエクスポートは tracing かスクリプティングを通して動作します。このチュートリアルはサンプルとして tracing によりエクスポートされたモデルを使用します。モデルをエクスポートするためには、torch.onnx.export() 関数を呼び出します。これはモデルを実行し、出力を計算するためにどの演算子が使用されるかの trace を記録します。export はモデルを実行しますので、入力 tensor x を提供する必要があります。これの値はランダムであり得ます、それが正しい型とサイズでありさえすれば。総ての入力次元についてエクスポートされた ONNX グラフでは入力サイズは固定されることに注意してください、動的軸 (= dynamic axes) として指定されない限りは。このサンプルではモデルを batch_size 1 の入力でエクスポートしていますが、torch.onnx.export() の dynamic_axes パラメータで最初の次元を動的として指定します。そしてエクスポートされたモデルはサイズ [batch_size, 1, 224, 224] の入力を受け取ります、そこでは batch_size は可変であり得ます。
PyTorch の export インターフェイスについての更なる詳細を学習するには、torch.onnx ドキュメント を調べてください。
# Input to the model x = torch.randn(batch_size, 1, 224, 224, requires_grad=True) torch_out = torch_model(x) # Export the model torch.onnx.export(torch_model, # model being run x, # model input (or a tuple for multiple inputs) "super_resolution.onnx", # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=10, # the ONNX version to export the model to do_constant_folding=True, # whether to execute constant folding for optimization input_names = ['input'], # the model's input names output_names = ['output'], # the model's output names dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes 'output' : {0 : 'batch_size'}})
また torch_out、モデルの出力も計算します、これをエクスポートしたモデルが ONNX ランタイムで実行されるときに同じ値を計算するかを検証するために使用します。
しかし ONNX ランタイムによるモデルの出力を検証する前に、ONNX API により ONNX モデルを確認します。最初に、onnx.load(“super_resolution.onnx”) はセーブされたモデルをロードして onnx.ModelProto 構造を出力します(ML モデルをバンドルするためのトップレベルのファイル/コンテナ形式です。より多くの情報については onnx.proto ドキュメント)。それから、onnx.checker.check_model(onnx_model) はモデルの構造を検証してモデルが正当なスキーマを持つことを確かめます。ONNX の妥当性はモデルのバージョン、グラフ構造、加えてノードとそれらの入力と出力をチェックすることにより検証されます。
import onnx onnx_model = onnx.load("super_resolution.onnx") onnx.checker.check_model(onnx_model)
今は ONNX ランタイムの Python API を使用して出力を計算しましょう。このパートは通常は別のプロセスかもう一つのマシン上で行なうことができますが、同じプロセスで続けます、その結果 ONNX ランタイムと PyTorch がネットワークのために同じ値を計算していることを検証できます。
モデルを ONNX ランタイムで実行するため、選択された configuration パラメータでモデルのための推論セッションを作成する必要があります (ここではデフォルト config を使用します)。ひとたびセッションが作成されれば、run() api を使用してモデルを評価します。この呼び出しの出力は ONNX ランタイムにより計算されたモデルの出力を含むリストです。
import onnxruntime ort_session = onnxruntime.InferenceSession("super_resolution.onnx") def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() # compute ONNX Runtime output prediction ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)} ort_outs = ort_session.run(None, ort_inputs) # compare ONNX Runtime and PyTorch results np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) print("Exported model has been tested with ONNXRuntime, and the result looks good!")
PyTorch と ONNX ランタイム実行の出力が数値的に与えられた精度 (rtol=1e-03 と atol=1e-05) で一致することを見ずはずです。ついでに言うと、それらが一致しないのであれば ONNX exporter に問題がありますのでその場合には私達にコンタクトしてください。
ONNX ランタイムを使用して画像上でモデルを実行する
ここまで PyTorch からモデルをエクスポートしてそれを入力としてダミー tensor を伴い ONNX ランタイムでどのようにロードしてそれを実行するかを示しました。
このチュートリアルのために、下のように見える広く使われている有名なネコ画像を使用します。
最初に、画像をロードして、それを標準的な PIL python ライブラリを使用して前処理しましょう。この前処理はニューラルネットワークを訓練/テストするためのデータ処理の標準的な実践であることに注意してください。
最初にモデルの入力 (224×224) のサイズに fit させるために画像をリサイズします。それから画像をその Y, Cb と Cr 成分に分割します。これらの成分はグレースケール画像 (Y)、と青色差 (Cb) と赤色差 (Cr) クロマ成分です。Y 成分は人間の目により敏感ですので、私達はこの成分に関心があります、これを変換していきます。Y 成分を抽出した後、それをモデルの入力である tensor に変換します。
from PIL import Image import torchvision.transforms as transforms img = Image.open("./_static/img/cat.jpg") resize = transforms.Resize([224, 224]) img = resize(img) img_ycbcr = img.convert('YCbCr') img_y, img_cb, img_cr = img_ycbcr.split() to_tensor = transforms.ToTensor() img_y = to_tensor(img_y) img_y.unsqueeze_(0)
今は、次のステップとして、グレースケールのりサイズされたネコ画像を表す tensor を取りそして前に説明したように ONNX ランタイムで超解像モデルを実行しましょう。
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img_y)} ort_outs = ort_session.run(None, ort_inputs) img_out_y = ort_outs[0]
この時点で、モデルの出力は tensor です。今は、出力 tensor から最後の出力画像を構築し戻すためにモデルの出力を処理して、そして画像をセーブします。後処理ステップは ここ の超解像モデルの PyTorch 実装から採用されています。
img_out_y = Image.fromarray(np.uint8((img_out_y[0] * 255.0).clip(0, 255)[0]), mode='L') # get the output image follow post-processing step from PyTorch implementation final_img = Image.merge( "YCbCr", [ img_out_y, img_cb.resize(img_out_y.size, Image.BICUBIC), img_cr.resize(img_out_y.size, Image.BICUBIC), ]).convert("RGB") # Save the image, we will compare this with the output image from mobile device final_img.save("./_static/img/cat_superres_with_ort.jpg")
ONNX ランタイムはクロスプラットフォーム・エンジンで、複数のプラットフォームに渡り CPU と GPU の両者の上でそれを実行できます。
ONNX ランタイムはまた Azure Machine Learning サービスを使用してモデル推論のためにクラウドに配備できます。より多くの情報は こちら です。
ONNX ランタイムのパフォーマンスについてのより多くの情報は こちら です。
ONNX ランタイムについてのより多くの情報は こちら です。
以上