PyTorch デザインノート : Multiprocessing ベストプラクティス

PyTorch デザインノート : Multiprocessing ベストプラクティス (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 05/28/2018 (0.4.0)

* 本ページは、PyTorch Doc Notes の – Multiprocessing best practices を動作確認・翻訳した上で
適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、適宜、追加改変している場合もあります。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

本文

torch.multiprocessing は Python の multiprocessing モジュールの代替品です。それは正確に同じ演算をサポートしますが、それを拡張し、結果的に multiprocessing.Queue を通して送られた総ての tensor は共有メモリに移されたデータを持ち他のプロセスにハンドルを送るだけです。

Note: Tensor が他のプロセスに送られるとき、Tensor データと torch.Tensor.grad の両者が共有されます。

これは、Hogwild, A3C あるいは非同期演算を要求する任意の他の様々な訓練メソッドを実装することを可能にします。

 

CUDA tensor を共有する

プロセス間での CUDA tensor の共有は spawn や forkserver start メソッドを使用して Python 3 でのみサポートされます。Python 2 の multiprocessing は fork を使用してサブプロセスを作成できるだけで、それは CUDA ランタイムではサポートされません。

警告:

CUDA API は、他のプロセスにエクスポートされた割当てが (それらに使用される限り) 有効であり続けることを要求します。共有した CUDA tensor が (それが必要である限り) スコープの外に出ないことに注意して確実にするべきです。これはモデルパラメータの共有のためには問題ではありませんが、他の種類のデータを渡すことはケアされるべきです。この制限は共有 CPU メモリには適用されないことに注意してください。

マルチプロセス処理の代わりに nn.DataParallel を使用する もまた見てください :

 

ベストプラクティスとティップス

デッドロックを回避して戦う

新しいプロセスが spawn されるとき上手くいかない多くのものがあり、デッドロックの最も一般的な原因はバックグラウンド・スレッドです。ロックを保持するかモジュールをインポートする任意のスレッドがある場合、サブプロセスが壊れた状態になり異なる方法でデッドロックや失敗することは非常に可能性があります。もし貴方がそうしない場合でさえ、Python 組込みライブラリがそうすることに注意してください – multiprocessing より先を見る必要はありません。multiprocessing.Queue は実際には非常に複雑なクラスで、これはオブジェクトのシリアライズ、送信と受信のために使用されるマルチスレッドを spawn し、そしてそれらは前述の問題をまた引き起こすことができます。そのような状況にあることを見出した場合には、multiprocessing.queues.SimpleQueue を使用してみてください、これはどのような追加スレッドも使用しません。

私達はそれが貴方にとって容易になるようにベストを尽くしこれらのデッドロックが起きないことを確実にしようとしていますが、幾つかのことは制御外です。

 

Queue を通して渡されたバッファを再利用する

貴方が Tensormultiprocessing.Queue 内へと置くたびに、それは共有メモリ内に移されなければならないことを忘れないでください。もしそれが既に共有されているのであれば、それは no-op です、そうでなければそれは追加のメモリコピーを招きこれはプロセス全体をスローダウンする可能性があります。データを単一の一つに送るようなプロセスのプールを持つ場合でさえ、それにバッファを送り返させてください – これは殆ど (コスト) フリーで次のバッチを送るとき貴方にコピーを回避させるでしょう。

 

非同期マルチプロセス訓練 (e.g. Hogwild)

torch.multiprocessing を使用して、モデルを非同期に訓練することが可能です、パラメータはずっと共有するか定期的に同期されます。最初のケースでは、モデル・オブジェクト全体に渡り送ることを勧めます、その一方で後者では、state_dict() だけを送るように勧めます。

プロセス間で総ての種類の PyTorch オブジェクトを渡すために multiprocessing.Queue を使用することを勧めます。例えば、fork start メソッドを使用するとき、既に共有メモリにある tensor とストレージから継承することは可能ですが、それは非常にバグでありがちで、進んだユーザによってのみ、注意深く使用されるべきです。Queue は時に洗練されていない解法ですけれども、総てのケースで適切に動作するでしょう。

警告:

if __name__ == ‘__main__’ でガードされていないグローバル・ステートメントを持つことについて注意深くあるべきです。fork とは異なる start メソッドが使用される場合、それらは総てのサブプロセスで実行されるでしょう。

 

Hogwild

具体的な Hogwild 実装は examples レポジトリ で見つかりますが、コードの全体的な構造を示すためには、下の最小限のサンプルもまたあります :

import torch.multiprocessing as mp
from model import MyModel

def train(model):
    # Construct data_loader, optimizer, etc.
    for data, labels in data_loader:
        optimizer.zero_grad()
        loss_fn(model(data), labels).backward()
        optimizer.step()  # This will update the shared parameters

if __name__ == '__main__':
    num_processes = 4
    model = MyModel()
    # NOTE: this is required for the ``fork`` method to work
    model.share_memory()
    processes = []
    for rank in range(num_processes):
        p = mp.Process(target=train, args=(model,))
        p.start()
        processes.append(p)
    for p in processes:
      p.join()
 

以上