MXNet : Keras バックエンドとしての MXNet

MXNet : Keras バックエンドとしての MXNet

作成 : (株)クラスキャット セールスインフォメーション
日時 : 06/14/2017

 

深層学習の実装を Keras で行なわれている方も多いかと思いますが、そのバックエンドとして、Theano, TensorFlow に続いて MXNet も利用可能になりました。

以前から MXNet の Keras のバックエンド・サポートのための作業は DMLC (Distributed (Deep) Machine Learning Community) により行なわれていましたが、先日 Beta リリースとなりました。Keras の開発者である François Chollet (@fchollet) 氏もその旨ツイートされていました。

そして AWS の Principal Tech Evangelist の Julien Simon (@julsimon) 氏がその技術的な詳細について以下のブログ記事で紹介しています :

せっかくですので、簡単に試してみたいと思います。

なお、経緯の詳細は以下の Keras/MXNet それぞれの github issue を参照してください :

 

インストール手順

基本的には上記のブログ記事を参考にすれば良いです。
* MXNet はインストール済みとします。
* Ubuntu trusty 上 python 2 の virtualenv 環境を使用しています。
* 今回は Beta 版ということもあり CPU 環境でのみ動作検証しました。

(1) dmlc/keras を git clone します :

git clone https://github.com/dmlc/keras.git

(2) python setup.py install を実行します :

$ cd keras
$ python setup.py install

以下のように終了すれば成功です :

Installed .../venv/lib/python2.7/site-packages/scipy-0.19.0-py2.7-linux-x86_64.egg
Searching for numpy==1.13.0
Best match: numpy 1.13.0
Adding numpy 1.13.0 to easy-install.pth file

Using .../venv/lib/python2.7/site-packages
Finished processing dependencies for Keras==1.2.2

(3) 最後に ~/.keras/keras.json を以下のように編集します :

{
    "image_dim_ordering": "tf", 
    "epsilon": 1e-07, 
    "floatx": "float32", 
    "backend": "mxnet"
}

image_dim_ordering は “tf” のままで良く、backend として “mxnet” を設定すれば完了です。

 

サンプル実行

まずはバージョン確認です :

$ python
Python 2.7.6 (default, Oct 26 2016, 20:30:19) 
[GCC 4.8.4] on linux2
Type "help", "copyright", "credits" or "license" for more information.
>>> import mxnet
>>> mxnet.__version__
'0.10.1'
>>> import keras
Using MXNet backend.
>>> keras.__version__
'1.2.2'

続いて定番の MNIST の MLP サンプルを実行します :

>>> execfile("examples/mnist_mlp.py")
60000 train samples
10000 test samples
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
====================================================================================================
dense_1 (Dense)                  (None, 512)           401920      dense_input_1[0][0]              
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 512)           0           dense_1[0][0]                    
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 512)           0           activation_1[0][0]               
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 512)           262656      dropout_1[0][0]                  
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 512)           0           dense_2[0][0]                    
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 512)           0           activation_2[0][0]               
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 10)            5130        dropout_2[0][0]                  
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 10)            0           dense_3[0][0]                    
====================================================================================================
Total params: 669,706
Trainable params: 669,706
Non-trainable params: 0
____________________________________________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/20
60000/60000 [==============================] - 18s - loss: 1.2175 - acc: 0.6823 - val_loss: 0.5459 - val_acc: 0.8675
Epoch 2/20                                                                                                                                                   
60000/60000 [==============================] - 18s - loss: 0.5365 - acc: 0.8498 - val_loss: 0.3785 - v                                                       al_acc: 0.9003
Epoch 3/20
60000/60000 [==============================] - 18s - loss: 0.4252 - acc: 0.8784 - val_loss: 0.3234 - val_acc: 0.9097
Epoch 4/20
60000/60000 [==============================] - 18s - loss: 0.3770 - acc: 0.8906 - val_loss: 0.2918 - val_acc: 0.9189
Epoch 5/20
60000/60000 [==============================] - 17s - loss: 0.3424 - acc: 0.9015 - val_loss: 0.2701 - val_acc: 0.9241
Epoch 6/20
60000/60000 [==============================] - 17s - loss: 0.3157 - acc: 0.9093 - val_loss: 0.2537 - val_acc: 0.9302
Epoch 7/20
60000/60000 [==============================] - 18s - loss: 0.2971 - acc: 0.9141 - val_loss: 0.2378 - val_acc: 0.9333
Epoch 8/20
60000/60000 [==============================] - 17s - loss: 0.2795 - acc: 0.9200 - val_loss: 0.2250 - val_acc: 0.9356
Epoch 9/20
60000/60000 [==============================] - 18s - loss: 0.2638 - acc: 0.9232 - val_loss: 0.2157 - val_acc: 0.9379
Epoch 10/20
60000/60000 [==============================] - 18s - loss: 0.2525 - acc: 0.9274 - val_loss: 0.2057 - val_acc: 0.9404
Epoch 11/20
60000/60000 [==============================] - 18s - loss: 0.2398 - acc: 0.9302 - val_loss: 0.1966 - val_acc: 0.9427
Epoch 12/20
60000/60000 [==============================] - 21s - loss: 0.2301 - acc: 0.9336 - val_loss: 0.1886 - val_acc: 0.9457
Epoch 13/20
60000/60000 [==============================] - 18s - loss: 0.2199 - acc: 0.9368 - val_loss: 0.1807 - val_acc: 0.9481
Epoch 14/20
60000/60000 [==============================] - 18s - loss: 0.2130 - acc: 0.9388 - val_loss: 0.1746 - val_acc: 0.9496
Epoch 15/20
60000/60000 [==============================] - 19s - loss: 0.2053 - acc: 0.9406 - val_loss: 0.1683 - val_acc: 0.9525
Epoch 16/20
60000/60000 [==============================] - 18s - loss: 0.1958 - acc: 0.9436 - val_loss: 0.1624 - val_acc: 0.9537
Epoch 17/20
60000/60000 [==============================] - 21s - loss: 0.1903 - acc: 0.9453 - val_loss: 0.1569 - val_acc: 0.9542
Epoch 18/20
60000/60000 [==============================] - 18s - loss: 0.1815 - acc: 0.9479 - val_loss: 0.1535 - val_acc: 0.9553
Epoch 19/20
60000/60000 [==============================] - 18s - loss: 0.1787 - acc: 0.9481 - val_loss: 0.1480 - val_acc: 0.9577
Epoch 20/20
60000/60000 [==============================] - 18s - loss: 0.1713 - acc: 0.9503 - val_loss: 0.1433 - val_acc: 0.9589
Test score: 0.143255090081
Test accuracy: 0.9589

 
精度 95.9 % です。特に問題がないので、MNIST CNN のサンプルも実行してみます :

$ python mnist_cnn.py 
Using MXNet backend.
X_train shape: (60000, 1, 28, 28)
60000 train samples
10000 test samples
Train on 60000 samples, validate on 10000 samples
Epoch 1/12
60000/60000 [==============================] - 321s - loss: 0.3929 - acc: 0.8780 - val_loss: 0.0970 - val_acc: 0.9708
Epoch 2/12
60000/60000 [==============================] - 320s - loss: 0.1429 - acc: 0.9577 - val_loss: 0.0648 - val_acc: 0.9796
Epoch 3/12
60000/60000 [==============================] - 321s - loss: 0.1080 - acc: 0.9677 - val_loss: 0.0518 - val_acc: 0.9840
Epoch 4/12
60000/60000 [==============================] - 320s - loss: 0.0917 - acc: 0.9728 - val_loss: 0.0464 - val_acc: 0.9846
Epoch 5/12
60000/60000 [==============================] - 321s - loss: 0.0754 - acc: 0.9770 - val_loss: 0.0418 - val_acc: 0.9860
Epoch 6/12
60000/60000 [==============================] - 321s - loss: 0.0717 - acc: 0.9787 - val_loss: 0.0391 - val_acc: 0.9863
Epoch 7/12
60000/60000 [==============================] - 321s - loss: 0.0651 - acc: 0.9809 - val_loss: 0.0363 - val_acc: 0.9869
Epoch 8/12
60000/60000 [==============================] - 321s - loss: 0.0612 - acc: 0.9816 - val_loss: 0.0337 - val_acc: 0.9880
Epoch 9/12
60000/60000 [==============================] - 321s - loss: 0.0569 - acc: 0.9830 - val_loss: 0.0338 - val_acc: 0.9876
Epoch 10/12
60000/60000 [==============================] - 329s - loss: 0.0535 - acc: 0.9840 - val_loss: 0.0331 - val_acc: 0.9890
Epoch 11/12
60000/60000 [==============================] - 331s - loss: 0.0515 - acc: 0.9843 - val_loss: 0.0312 - val_acc: 0.9886
Epoch 12/12
60000/60000 [==============================] - 320s - loss: 0.0466 - acc: 0.9862 - val_loss: 0.0311 - val_acc: 0.9890
Test score: 0.0311143597026
Test accuracy: 0.989

 
精度 98.9 % に改善されました!

 

以上