Kornia 0.6 : Tutorials (データ増強) : データ増強セマンティックセグメンテーション (翻訳/解説)
翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/26/2022 (v0.6.8)
* 本ページは、Kornia Tutorials の以下のドキュメントを翻訳した上で適宜、補足説明したものです:
- Data Augmentation : Data Augmentation Semantic Segmentation
* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。
クラスキャット 人工知能 研究開発支援サービス
◆ クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :
- 人工知能研究開発支援
- 人工知能研修サービス(経営者層向けオンサイト研修)
- テクニカルコンサルティングサービス
- 実証実験(プロトタイプ構築)
- アプリケーションへの実装
- 人工知能研修サービス
- PoC(概念実証)を失敗させないための支援
◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール。
- お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。
◆ お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。
- 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
- sales-info@classcat.com ; Web: www.classcat.com ; ClassCatJP
Kornia 0.6 : Tutorials (データ増強) : データ増強セマンティックセグメンテーション
このチュートリアルでは、kornia.augmentation API を使用して、セマンティックセグメンテーションに対してデータ増強を素早く実行できる方法を示します。
インストールとデータの習得
Kornia と幾つかの依存関係をインストールして、簡単なデータサンプルをダウンロードします :
%%capture
!pip install kornia opencv-python matplotlib
%%capture
!wget http://www.zemris.fer.hr/~ssegvic/multiclod/images/causevic16semseg3.png
# import the libraries
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np
import torch
import torch.nn as nn
import kornia as K
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
増強パイプラインの定義
nn.Module を使用して増強 API を定義するためにクラスを定義します :
class MyAugmentation(nn.Module):
def __init__(self):
super(MyAugmentation, self).__init__()
# we define and cache our operators as class members
self.k1 = K.augmentation.ColorJitter(0.15, 0.25, 0.25, 0.25)
self.k2 = K.augmentation.RandomAffine([-45., 45.], [0., 0.15], [0.5, 1.5], [0., 0.15])
def forward(self, img: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# 1. apply color only in image
# 2. apply geometric tranform
img_out = self.k2(self.k1(img))
# 3. infer geometry params to mask
# TODO: this will change in future so that no need to infer params
mask_out = self.k2(mask, self.k2._params)
return img_out, mask_out
データをロードして変換を適用します :
def load_data(data_path: str) -> torch.Tensor:
data: np.ndarray = cv2.imread(data_path, cv2.IMREAD_COLOR)
data_t: torch.Tensor = K.image_to_tensor(data, keepdim=False)
data_t = K.color.bgr_to_rgb(data_t)
data_t = K.enhance.normalize(data_t, torch.tensor(0.), torch.tensor(255.))
img, labels = data_t[..., :571], data_t[..., 572:]
return img, labels
# load data (B, C, H, W)
img, labels = load_data("causevic16semseg3.png")
# create augmentation instance
aug = MyAugmentation()
# apply the augmenation pipelone to our batch of data
img_aug, labels_aug = aug(img, labels)
# visualize
img_out = torch.cat([img, labels], dim=-1)
plt.imshow(K.tensor_to_image(img_out))
plt.axis('off')
# generate several samples
num_samples: int = 10
for img_id in range(num_samples):
# generate data
img_aug, labels_aug = aug(img, labels)
img_out = torch.cat([img_aug, labels_aug], dim=-1)
# save data
plt.figure()
plt.imshow(K.tensor_to_image(img_out))
plt.axis('off')
plt.savefig(f"img_{img_id}.png", bbox_inches='tight')
以上