Kornia 0.6 : Tutorials (データ増強) : データ増強セマンティックセグメンテーション

Kornia 0.6 : Tutorials (データ増強) : データ増強セマンティックセグメンテーション (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/26/2022 (v0.6.8)

* 本ページは、Kornia Tutorials の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Kornia 0.6 : Tutorials (データ増強) : データ増強セマンティックセグメンテーション

このチュートリアルでは、kornia.augmentation API を使用して、セマンティックセグメンテーションに対してデータ増強を素早く実行できる方法を示します。

 

インストールとデータの習得

Kornia と幾つかの依存関係をインストールして、簡単なデータサンプルをダウンロードします :

%%capture
!pip install kornia opencv-python matplotlib
%%capture
!wget http://www.zemris.fer.hr/~ssegvic/multiclod/images/causevic16semseg3.png
# import the libraries
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np

import torch
import torch.nn as nn
import kornia as K
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

 

増強パイプラインの定義

nn.Module を使用して増強 API を定義するためにクラスを定義します :

class MyAugmentation(nn.Module):
  def __init__(self):
    super(MyAugmentation, self).__init__()
    # we define and cache our operators as class members
    self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
    self.k2 = K.augmentation.RandomAffine([-45., 45.], [0., 0.15], [0.5, 1.5], [0., 0.15])
  
  def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # 1. apply color only in image
    # 2. apply geometric tranform
    img_out = self.k2(self.k1(img))

    # 3. infer geometry params to mask
    # TODO: this will change in future so that no need to infer params
    mask_out = self.k2(mask, self.k2._params)

    return img_out, mask_out

データをロードして変換を適用します :

def load_data(data_path: str) -> torch.Tensor:
  data: np.ndarray = cv2.imread(data_path, cv2.IMREAD_COLOR)
  data_t: torch.Tensor = K.image_to_tensor(data, keepdim=False)
  data_t = K.color.bgr_to_rgb(data_t)
  data_t = K.enhance.normalize(data_t, torch.tensor(0.), torch.tensor(255.))
  img, labels = data_t[..., :571], data_t[..., 572:]
  return img, labels

# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")

# create augmentation instance
aug = MyAugmentation()

# apply the augmenation pipelone to our batch of data
img_aug, labels_aug = aug(img, labels)

# visualize
img_out = torch.cat([img, labels], dim=-1)
plt.imshow(K.tensor_to_image(img_out))
plt.axis('off')

# generate several samples
num_samples: int = 10

for img_id in range(num_samples):
  # generate data
  img_aug, labels_aug = aug(img, labels)
  img_out = torch.cat([img_aug, labels_aug], dim=-1)

  # save data
  plt.figure()
  plt.imshow(K.tensor_to_image(img_out))
  plt.axis('off')
  plt.savefig(f"img_{img_id}.png", bbox_inches='tight')

 

以上