MXNet : Module API 概要 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
日時 : 02/24/2017
* 本ページは、MXNet 本家サイトの Module API の概要説明を翻訳した上で適宜、補足したものです:
http://mxnet.io/api/python/module.html
序
module API は、MXNet においてニューラルネットワークで計算を実行するために中位と高位-レベルな I/F を提供します。module は $BaseModule$ のサブクラスのインスタンスです。もっとも広く使われる module クラスは単純に $Module$ と呼ばれ、$Symbol$ と一つまたはそれ以上の $Executors$ をラップします。関数の完全なリストは $BaseModule$ を参照してください。modules の各サブクラスは幾つかの追加の I/F 関数を持つかもしれません。このトピックでは、一般的なユースケースの幾つかの例を提供します。全ての module APIs は $mxnet.module$ 名前空間にあり、単に $mxnet.mod$ と呼称されます。
計算のために Module を準備する
module を構築するためには特定の module クラスのためのコンストラクタを参照してください。例えば、$Module$ クラスは $Symbol$ を入力として受け取ります :
import mxnet as mx # construct a simple MLP data = mx.symbol.Variable('data') fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128) act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu") fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64) act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu") fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10) out = mx.symbol.SoftmaxOutput(fc3, name = 'softmax') # construct the module mod = mx.mod.Module(out)
$Symbol$ の $data\_names$ と $label\_names$ もまた指定します。それらのパラメータはスキップします、何故ならば私たちの $Symbol$ は命名規則に従うからです、従ってデフォルトの挙動 (data は $data$ として命名、ラベルは $softmax\_label$ として命名) で構いません。$context$、デフォルトは CPU、は他の重要なパラメータです。GPU context を指定できますし、データ並列化が必要な時には GPU context のリストさえ指定可能です。
module で計算を可能にするには、デバイスメモリを割り当てるために $bind()$ を呼ぶ必要がありそしてパラメータを初期化するために $init\_params()$ あるいは $set\_params()$ を呼び出します。
mod.bind(data_shapes=train_dataiter.provide_data, label_shapes=train_dataiter.provide_label) mod.init_params()
これで $foward()$ や $backword()$, etc. のような関数を使って module で計算可能です。単に module を fit させたいだけであれば $bind$ と $init\_parames()$ を明示的に呼ぶ必要はありません、何故ならば $fit()$ 関数が必要な場合には自動的にそれらを呼び出すからです。
トレーニングし、予測し、そして評価する
Modules provide high-level APIs for training, predicting, and evaluating. To fit a module, call the fit() function with some DataIters:
mod = mx.mod.Module(softmax) mod.fit(train_dataiter, eval_data=eval_dataiter, optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=n_epoch)
以上