MXNet : KVStore API 概要

MXNet : KVStore API 概要 (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
日時 : 02/24/2017

* 本ページは、MXNet 本家サイトの KVStore API の概要説明を翻訳した上で適宜、補足したものです:
    http://mxnet.io/api/python/kvstore.html

 

トピック :

  • 基本的な Push と Pull
  • Key-Value ペアを List
  • API Reference

* 訳注 : API リファレンスについては原文参照のこと。

 

基本的な Push と Pull

シングル・デバイス上のマルチデバイス (GPUs) に渡る基本的な操作を提供します。

初期化

単純な例を考えましょう。(int, NDArray) ペアをストアに初期化して、そして値をプルアウトします。

>>> kv = mx.kv.create('local') # create a local kv store.
>>> shape = (2,3)
>>> kv.init(3, mx.nd.ones(shape)*2)
>>> a = mx.nd.zeros(shape)
>>> kv.pull(3, out = a)
>>> print a.asnumpy()
[[ 2.  2.  2.]
 [ 2.  2.  2.]]

Push, Aggregation, と Updater

初期化された任意のキーについて、次のように、キーに対して新しい値を同じ shape で push できます :

>>> kv.push(3, mx.nd.ones(shape)*8)
>>> kv.pull(3, out = a) # pull out the value
>>> print a.asnumpy()
[[ 8.  8.  8.]
 [ 8.  8.  8.]]

pus したいデータは任意のデバイス上にストアできます。更に、同じキーに複数の値を push できます、そこでは KVStore は最初にこれら全ての値を合計してそしてそれから集約された値を次のように push します :

>>> gpus = [mx.gpu(i) for i in range(4)]
>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.push(3, b)
>>> kv.pull(3, out = a)
>>> print a.asnumpy()
[[ 4.  4.  4.]
 [ 4.  4.  4.]]

各 push コマンドに対して、KVStore は push された値をストアされた値に updater で適用します。デフォルトの updater は ASSIGN です。どのようにデータが merge されるかを制御するためにデフォルトを置き換えることができます。

>>> def update(key, input, stored):
>>>     print "update on key: %d" % key
>>>     stored += input * 2
>>> kv._set_updater(update)
>>> kv.pull(3, out=a)
>>> print a.asnumpy()
[[ 4.  4.  4.]
 [ 4.  4.  4.]]
>>> kv.push(3, mx.nd.ones(shape))
update on key: 3
>>> kv.pull(3, out=a)
>>> print a.asnumpy()
[[ 6.  6.  6.]
 [ 6.  6.  6.]]

Pull

単一の key-value ペアをどのように pull するかを既に見ました。push コマンドを使用しるのと同様の方法で、値を幾つかのデバイスに単一の呼び出しで pull することができます。

>>> b = [mx.nd.ones(shape, gpu) for gpu in gpus]
>>> kv.pull(3, out = b)
>>> print b[1].asnumpy()
[[ 6.  6.  6.]
 [ 6.  6.  6.]]

 

Key-Value ペアを List

ここまで議論した操作の全ては単一のキー上で実行されました。KVStore はまた key-value ペアのリストを生成するための I/F もまた提供します。単一のデバイスに対しては、以下を使います :

>>> keys = [5, 7, 9]
>>> kv.init(keys, [mx.nd.ones(shape)]*len(keys))
>>> kv.push(keys, [mx.nd.ones(shape)]*len(keys))
update on key: 5
update on key: 7
update on key: 9
>>> b = [mx.nd.zeros(shape)]*len(keys)
>>> kv.pull(keys, out = b)
>>> print b[1].asnumpy()
[[ 3.  3.  3.]
 [ 3.  3.  3.]]

複数のデバイスに対しては :

>>> b = [[mx.nd.ones(shape, gpu) for gpu in gpus]] * len(keys)
>>> kv.push(keys, b)
update on key: 5
update on key: 7
update on key: 9
>>> kv.pull(keys, out = b)
>>> print b[1][1].asnumpy()
[[ 11.  11.  11.]
 [ 11.  11.  11.]]
 

以上