PyTorch 1.8 : PyTorch の学習 : マルチ GPU サンプル

PyTorch 1.8 チュートリアル : PyTorch の学習 : マルチ GPU サンプル (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 06/06/2021 (1.8.1+cu102)

* 本ページは、PyTorch 1.8 Tutorials の以下のページを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

無料 Web セミナー開催中 クラスキャット主催 人工知能 & ビジネス Web セミナー

人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。
スケジュールは弊社 公式 Web サイト でご確認頂けます。
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
  • ウェビナー運用には弊社製品「ClassCat® Webinar」を利用しています。
クラスキャットは人工知能・テレワークに関する各種サービスを提供しております :

人工知能研究開発支援 人工知能研修サービス テレワーク & オンライン授業を支援
PoC(概念実証)を失敗させないための支援 (本支援はセミナーに参加しアンケートに回答した方を対象としています。)

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
E-Mail:sales-info@classcat.com  ;  WebSite: https://www.classcat.com/  ;  Facebook

 

PyTorch の学習 : マルチ GPU サンプル

データ並列性はサンプルのミニバッチを複数のより小さいミニバッチに分割して小さいミニバッチの各々のための計算を並列に実行することです。

データ並列性は torch.nn.DataParallel を使用して実装されます。Module を DataParallel 内にラップできてそれはマルチ GPU に渡りバッチ次元で並列化されます。

 

DataParallel

import torch
import torch.nn as nn


class DataParallelModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)

        # wrap block2 in DataParallel
        self.block2 = nn.Linear(20, 20)
        self.block2 = nn.DataParallel(self.block2)

        self.block3 = nn.Linear(20, 20)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x

コードは CPU-モードで変更される必要はありません。

DataParallel のためのドキュメントは ここ で見つかります。

 

ラップされたモジュールの属性

Module を DataParallel でラップした後、モジュールの属性 (e.g. カスタムメソッド) はアクセスできなくなります。これは DataParallel が幾つかの新しいメンバーを定義し、他の属性の許可はそれらの名前の衝突に繋がるかもしれないためです。依然として属性にアクセスすることを望む人たちのためには、回避方法は下のように DataParallel のサブクラスを使用することです。

class MyDataParallel(nn.DataParallel):
    def __getattr__(self, name):
        return getattr(self.module, name)

 

(その上で) DataParallel が実装されているプリミティブ

一般に、pytorch の nn.parallel プリミティブは独立的に利用できます。私達は単純な MPI-ライクなプリミティブを実装しました :

  • replicate : マルチデバイス上で Module を複製します。
  • scatter : 入力を最初の次元で分散します。
  • gather : 入力を最初の次元で集めて結合します。
  • parallel_apply : 既に分散された入力のセットを既に分散されたモデルのセットに適用します。

より明瞭にするため、ここでは関数 data_parallel はこれらの集合体を使用して組み合わされています。

def data_parallel(module, input, device_ids, output_device=None):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    replicas = nn.parallel.replicate(module, device_ids)
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    return nn.parallel.gather(outputs, output_device)

 

CPU 上のモデルの一部と GPU 上の一部

ネットワークを実装する小さいサンプルを見ましょう、そこではそれの一部は CPU 上にそして一部は GPU 上にあります。

device = torch.device("cuda:0")

class DistributedModel(nn.Module):

    def __init__(self):
        super().__init__(
            embedding=nn.Embedding(1000, 10),
            rnn=nn.Linear(10, 10).to(device),
        )

    def forward(self, x):
        # Compute embedding on CPU
        x = self.embedding(x)

        # Transfer to GPU
        x = x.to(device)

        # Compute RNN on GPU
        x = self.rnn(x)
        return x
 

以上