Kornia 0.6 : Tutorials (中級) : SOLD2 によるライン検出とマッチング

Kornia 0.6 : Tutorials (中級) : SOLD2 によるライン検出とマッチング:自己教師ありオクルージョン-aware なライン記述と検出 (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 10/28/2022 (v0.6.8)

* 本ページは、Kornia Tutorials の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Kornia 0.6 : Tutorials (中級) : SOLD2 によるライン検出とマッチング:自己教師ありオクルージョン-aware なライン記述と検出

このチュートリアルでは、kornia.feature.sold2 API を使用してライン検出、そしてマッチングを素早く実行できる方法を示します。

 

セットアップ

ライブラリをインストールします :

%%capture
!pip install git+https://github.com/kornia/kornia
!pip install opencv-python --upgrade # Just for windows
!pip install matplotlib

次に画像をダウンロードします :

%%capture
!wget https://github.com/cvg/SOLD2/raw/main/assets/images/terrace0.JPG
!wget https://github.com/cvg/SOLD2/raw/main/assets/images/terrace1.JPG

そして、ライブラリをロードします :

import kornia as K
import kornia.feature as KF
import torch
/home/docs/checkouts/readthedocs.org/user_builds/kornia-tutorials/envs/latest/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

画像をロードして torch テンソルに変換します。

def load_img(img_path):
    try:
        # not ready on Windows machine
        img = K.io.load_image(img_path, K.io.ImageLoadType.RGB32)
    except:
        import cv2

        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = K.image_to_tensor(img).float() / 255.0
    return img
fname1 = "terrace0.JPG"
fname2 = "terrace1.JPG"

torch_img1 = load_img(fname1)
torch_img2 = load_img(fname2)

torch_img1.shape, torch_img2.shape
(torch.Size([3, 496, 744]), torch.Size([3, 496, 744]))

モデル用にデータを準備します、これはグレースケール (shape: (Batch size, 1, Height, Width)) の画像のバッチが想定されます。

SOLD2 モデルは config=None を使用するとき範囲 400~800px の画像に対して調整されました。

# First, convert the images to gray scale
torch_img1_gray = K.color.rgb_to_grayscale(torch_img1)
torch_img2_gray = K.color.rgb_to_grayscale(torch_img2)
torch_img1_gray.shape, torch_img2_gray.shape
(torch.Size([1, 496, 744]), torch.Size([1, 496, 744]))
# then, stack the images to create/simulate a batch
imgs = torch.stack(
    [torch_img1_gray, torch_img2_gray],
)

imgs.shape
torch.Size([2, 1, 496, 744])

 

ライン検出とマッチングの実行

sold2 モデルを pre-trained=True でロードします、これは事前訓練済み重みをダウンロードしてモデルに設定します。

%%capture
sold2 = KF.SOLD2(pretrained=True, config=None)

 

モデル予測の実行

%%capture
with torch.inference_mode():
    outputs = sold2(imgs)

デモ用の出力を体系化します。

Attention : 検出された線分は慣習で ij 座標にあります。

outputs.keys()
dict_keys(['junction_heatmap', 'line_heatmap', 'dense_desc', 'line_segments'])
line_seg1 = outputs["line_segments"][0]
line_seg2 = outputs["line_segments"][1]
desc1 = outputs["dense_desc"][0]
desc2 = outputs["dense_desc"][1]

 

ライン・マッチングの実行

with torch.inference_mode():
    matches = sold2.match(line_seg1, line_seg2, desc1[None], desc2[None])
valid_matches = matches != -1
match_indices = matches[valid_matches]

matched_lines1 = line_seg1[valid_matches]
matched_lines2 = line_seg2[match_indices]

 

検出されたラインとマッチングのプロット

元のコード から適応されたプロット関数 :

import copy

import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np


def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=6, pad=0.5):
    """Plot a set of images horizontally.
    Args:
        imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
    """
    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n
    figsize = (size * n, size * 3 / 4) if size is not None else None
    fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])
    fig.tight_layout(pad=pad)


def plot_lines(
    lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
):
    """Plot lines and endpoints for existing images.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float pixels.
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    if not isinstance(line_colors, list):
        line_colors = [line_colors] * len(lines)
    if not isinstance(point_colors, list):
        point_colors = [point_colors] * len(lines)

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines and junctions
    for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
        for i in range(len(l)):
            line = matplotlib.lines.Line2D(
                (l[i, 1, 1], l[i, 0, 1]),
                (l[i, 1, 0], l[i, 0, 0]),
                zorder=1,
                c=lc,
                linewidth=lw,
            )
            a.add_line(line)
        pts = l.reshape(-1, 2)
        a.scatter(pts[:, 1], pts[:, 0], c=pc, s=ps, linewidths=0, zorder=2)


def plot_color_line_matches(lines, lw=2, indices=(0, 1)):
    """Plot line matches for existing images with multiple colors.
    Args:
        lines: list of ndarrays of size (N, 2, 2).
        lw: line width as float pixels.
        indices: indices of the images to draw the matches on.
    """
    n_lines = len(lines[0])

    cmap = plt.get_cmap("nipy_spectral", lut=n_lines)
    colors = np.array([mcolors.rgb2hex(cmap(i)) for i in range(cmap.N)])

    np.random.shuffle(colors)

    fig = plt.gcf()
    ax = fig.axes
    assert len(ax) > max(indices)
    axes = [ax[i] for i in indices]
    fig.canvas.draw()

    # Plot the lines
    for a, l in zip(axes, lines):
        for i in range(len(l)):
            line = matplotlib.lines.Line2D(
                (l[i, 1, 1], l[i, 0, 1]),
                (l[i, 1, 0], l[i, 0, 0]),
                zorder=1,
                c=colors[i],
                linewidth=lw,
            )
            a.add_line(line)
imgs_to_plot = [K.tensor_to_image(torch_img1), K.tensor_to_image(torch_img2)]
lines_to_plot = [line_seg1.numpy(), line_seg2.numpy()]

plot_images(imgs_to_plot, ["Image 1 - detected lines", "Image 2 - detected lines"])
plot_lines(lines_to_plot, ps=3, lw=2, indices={0, 1})

plot_images(imgs_to_plot, ["Image 1 - matched lines", "Image 2 - matched lines"])
plot_color_line_matches([matched_lines1, matched_lines2], lw=2)

 

以上