PyTorch 1.5 レシピ : TorchScript : Flask で配備する (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/20/2020 (1.5.0)
* 本ページは、PyTorch 1.5 Recipes の以下のページを翻訳した上で適宜、補足説明したものです:
- Production, TorchScript : Deploying with Flask
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
- Windows PC のブラウザからご参加が可能です。スマートデバイスもご利用可能です。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション |
E-Mail:sales-info@classcat.com ; WebSite: https://www.classcat.com/ |
Facebook: https://www.facebook.com/ClassCatJP/ |
プロダクション, TorchScript : Flask で配備する
このレシピでは、以下を学習します :
- 訓練された PyTorch モデルを web API を通して公開するために Flask コンテナでどのようにラップするか
- incoming web リクエストをモデルのために PyTorch tensor にどのように変換するか
- HTTP レスポンスのためにモデルの出力をどのようにパッケージ化するか
要件
以下のパッケージ (と依存性) がインストールされた Python 3 環境が必要です :
- PyTorch 1.5
- TorchVision 0.6.0
- Flask 1.1
オプションで、補助的ファイルの幾つかを得るために、git が必要です。
PyTorch と TorchVision をインストールする手順は pytorch.org で利用可能です。Flask をインストールするための手順は Flask サイト で利用可能です。
Flask とは何でしょう?
Flask は Python で書かれた軽量 web サーバです。それは直接的な利用のため、または巨大なシステム内の web サービスとして、訓練された PyTorch モデルからの予測のための web API を素早くセットアップする便利な方法を提供します。
セットアップと補助的なファイル
画像を取り、それらを ImageNet データセットの 1000 クラスの一つにマップする web サービスを作成します。
これを行なうため、テストのための画像ファイルが必要です。オプションで、モデルによるクラスインデックス出力を可読なクラス名にマップするファイルも得ることができます。
オプション 1: 両者のファイルを素早く取得する
TorchServe レポジトリをチェックアウトしてそれらを作業フォルダにコピーすることで補助的なファイルの両者を素早く引っ張ることができます。(NB: このチュートリアルのために TorchServe 上の依存性はありません – それは単にファイルを得るための素早い方法です。) シェルプロンプトから以下のコマンドを発行します :
git clone https://github.com/pytorch/serve cp serve/examples/image_classifier/kitten.jpg . cp serve/examples/image_classifier/index_to_name.json .
And you’ve got them!
オプション 2: 貴方自身の画像を持ち込む
下の Flask サービスで index_to_name.json ファイルはオプションです。貴方自身の画像でサービスをテストすることができます – 単にそれが 3-色 JPEG であることを確実にしてください。
Flask サービスを構築する
Flask サービスのための完全な Python スクリプトはこのレシピの最後で示されます ; それを貴方自身の app.py ファイルにコピーしてペーストできます。下でそれらの関数を明白にするために個々のセクションを見ます。
インポート
import torchvision.models as models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request
In order:
- torchvision.models から事前訓練された DenseNet モデルを利用しています。
- torchvision.transforms は画像データを操作するためのツールを含みます。
- Pillow (PIL) は初期的に画像ファイルをロードするために使用するものです。
- そしてもちろん flask からのクラスが必要です。
前処理
def transform_image(infile): input_transforms = [transforms.Resize(255), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] my_transforms = transforms.Compose(input_transforms) image = Image.open(infile) timg = my_transforms(image) timg.unsqueeze_(0) return timg
web リクエストは画像ファイルを与えますが、モデルは shape (N, 3, 224, 224) の PyTorch tensor を想定しています、ここで N は入力バッチの項目数です (私達は単に 1 のバッチサイズを持ちます)。私達が行なう最初のことは TorchVision 変換のセットを構成することです、これは画像をリサイズしてクロップし、それを tensor に変換してから、tensor の値を正規化します。(この正規化のより多くの情報については、torchvision.models_ のためのドキュメント参照)
その後、ファイルを開いて変換を適用します。変換は shape (3, 224, 224) の tensor を返します – 224×224 画像の 3 色チャネルです。この単一画像をバッチにする必要がありますので、新しい最初の次元を追加するために tensor をその場で変更するために unsqueeze_(0) を使用します。tensor は同じデータを含みますが、今は shape (1, 3, 224, 224) を持ちます。
一般に、画像データで作業していない場合でさえ、入力を HTTP リクエストから PyTorch が消費できる tensor に変換する必要があります。
推論
def get_prediction(input_tensor): outputs = model.forward(input_tensor) _, y_hat = outputs.max(1) prediction = y_hat.item() return prediction
推論自身は最も単純なパートです : 入力 tensor をモデルに渡すとき、画像が特定のクラスに属する、モデルの推定尤度を表す値の tensor を代わりに得ます。max() 呼び出しは最大尤度値を持つクラスを見つけて ImageNet クラスインデックスを持つ値を返します。最後に、それを含む tensor から item() 呼び出しでクラスインデックスを抽出してそれを返します。
後処理
def render_prediction(prediction_idx): stridx = str(prediction_idx) class_name = 'Unknown' if img_class_map is not None: if stridx in img_class_map is not None: class_name = img_class_map[stridx][1] return prediction_idx, class_name
render_prediction メソッドは予測されたクラスインデックスを可読なクラスラベルにマップします。モデルから予測を得た後、可読な消費のためかソフトウェアの他のピースのために準備するために後処理を遂行することは典型的です。
完全な Flask App を実行する
以下を app.py という名前のファイルにペーストしてください :
import io import json import os import torchvision.models as models import torchvision.transforms as transforms from PIL import Image from flask import Flask, jsonify, request app = Flask(__name__) model = models.densenet121(pretrained=True) # Trained on 1000 classes from ImageNet model.eval() # Turns off autograd and img_class_map = None mapping_file_path = 'index_to_name.json' # Human-readable names for Imagenet classes if os.path.isfile(mapping_file_path): with open (mapping_file_path) as f: img_class_map = json.load(f) # Transform input into the form our model expects def transform_image(infile): input_transforms = [transforms.Resize(255), # We use multiple TorchVision transforms to ready the image transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], # Standard normalization for ImageNet model input [0.229, 0.224, 0.225])] my_transforms = transforms.Compose(input_transforms) image = Image.open(infile) # Open the image file timg = my_transforms(image) # Transform PIL image to appropriately-shaped PyTorch tensor timg.unsqueeze_(0) # PyTorch models expect batched input; create a batch of 1 return timg # Get a prediction def get_prediction(input_tensor): outputs = model.forward(input_tensor) # Get likelihoods for all ImageNet classes _, y_hat = outputs.max(1) # Extract the most likely class prediction = y_hat.item() # Extract the int value from the PyTorch tensor return prediction # Make the prediction human-readable def render_prediction(prediction_idx): stridx = str(prediction_idx) class_name = 'Unknown' if img_class_map is not None: if stridx in img_class_map is not None: class_name = img_class_map[stridx][1] return prediction_idx, class_name @app.route('/', methods=['GET']) def root(): return jsonify({'msg' : 'Try POSTing to the /predict endpoint with an RGB image attachment'}) @app.route('/predict', methods=['POST']) def predict(): if request.method == 'POST': file = request.files['file'] if file is not None: input_tensor = transform_image(file) prediction_idx = get_prediction(input_tensor) class_id, class_name = render_prediction(prediction_idx) return jsonify({'class_id': class_id, 'class_name': class_name}) if __name__ == '__main__': app.run()
シェルプロンプトからサーバを開始するため、次のコマンドを発行します :
FLASK_APP=app.py flask run
デフォルトでは、Flask サーバはポート 5000 でリスンします。ひとたびサーバが動作していれば、他の端末ウィンドウを開いて、新しい推論サーバをテストしてください :
curl -X POST -H "Content-Type: multipart/form-data" http://localhost:5000/predict -F "file=@kitten.jpg"
総てが正しくセットアップされれば、以下に類似のレスポンスを受け取るはずです :
{"class_id":285,"class_name":"Egyptian_cat"}
以上