Paella : ノートブック : サンプリング

Paella : ノートブック : サンプリング (翻訳/解説)

翻訳 : (株)クラスキャット セールスインフォメーション
作成日時 : 12/11/2022 (v2.7.0)

* 本ページは、dome272/Paella の以下のドキュメントを翻訳した上で適宜、補足説明したものです:

* サンプルコードの動作確認はしておりますが、必要な場合には適宜、追加改変しています。
* ご自由にリンクを張って頂いてかまいませんが、sales-info@classcat.com までご一報いただけると嬉しいです。

 

クラスキャット 人工知能 研究開発支援サービス

クラスキャット は人工知能・テレワークに関する各種サービスを提供しています。お気軽にご相談ください :

◆ 人工知能とビジネスをテーマに WEB セミナーを定期的に開催しています。スケジュール
  • お住まいの地域に関係なく Web ブラウザからご参加頂けます。事前登録 が必要ですのでご注意ください。

お問合せ : 本件に関するお問い合わせ先は下記までお願いいたします。

  • 株式会社クラスキャット セールス・マーケティング本部 セールス・インフォメーション
  • sales-info@classcat.com  ;  Web: www.classcat.com  ;   ClassCatJP

 

 

Paella : ノートブック : サンプリング

Paella レポジトリとともに OpenCLIP も使用します。

!nvidia-smi
Thu Dec 15 03:40:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  A100-SXM4-40GB      Off  | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0    53W / 400W |      0MiB / 40536MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

※ 訳注 : GPU メモリは 15GB では不足します。

!pip install einops rudalle open_clip_torch git+https://github.com/pabloppp/pytorch-tools
!git clone https://github.com/dome272/Paella

 

モデルのダウンロード

最初のコマンドは、テキスト埋め込みで訓練された標準モデルをダウンロードします。

2 番目のコマンドは、画像埋め込みで微調整されたモデルを追加でダウンロードします。このモデルは画像のバリエーション生成に役立ちます。

!curl -L -s -o ./model_600000.pt 'https://drive.google.com/uc?id=1ACIb2nrZk_jQID8bP9qYKkNZVldM94zw&confirm=t'  # Text-Model
!curl -L -s -o ./model_50000_img.pt 'https://drive.google.com/uc?id=1VVWPXFLdI18TXhqlYmfMNT41Suz9pJr8&confirm=t'  # Finetuned Image-Model

 

セットアップコード

必要なモジュールをインポートします :

import os
import time
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
import requests
from io import BytesIO
from Paella.modules import DenoiseUNet
import open_clip
from open_clip import tokenizer
from rudalle import get_vae
from einops import rearrange

# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device:", device)

サンプリング関数 sample() とともに各種ユーティリティ関数を定義します :

def showmask(mask):
    """Render the first mask of a batch as an image with the axes hidden.

    Only ``mask[0:1]`` is shown. The concatenations below keep the original
    shape requirements; they presumably expect an ``(N, H, W)``-like tensor.
    """
    plt.axis("off")
    leading = list(mask[:1].cpu())
    row = torch.cat(leading, dim=-1)
    canvas = torch.cat([row], dim=-2)
    plt.imshow(canvas.cpu())
    plt.show()

def showimages(imgs, **kwargs):
    """Show a batch of channel-first images side by side in one figure.

    Keyword args ``width`` and ``height`` (both default 32) set the
    matplotlib figure size.
    """
    figsize = (kwargs.get("width", 32), kwargs.get("height", 32))
    plt.figure(figsize=figsize)
    plt.axis("off")
    # Concatenate along the width, then move channels last for imshow.
    strip = torch.cat([torch.cat(list(imgs), dim=-1)], dim=-2)
    plt.imshow(strip.permute(1, 2, 0).cpu())
    plt.show()
    
def unique_image_path(name, directory="outputs", ext=".jpg"):
    """Return a path inside ``directory`` for ``name`` that does not exist yet.

    The name is sanitised:
    - spaces become underscores;
    - dots are removed;
    - path separators become underscores, so prompt text can neither create
      sub-directories nor escape ``directory``.

    If ``<name><ext>`` already exists, the suffixes ``_0``, ``_1``, ... are
    tried in order until a free path is found.
    """
    name = name.replace(" ", "_").replace(".", "")
    name = name.replace("/", "_").replace(os.sep, "_")
    path = os.path.join(directory, name + ext)
    counter = 0
    while os.path.exists(path):
        # Append a counter instead of parsing the name. The old parser
        # "incremented" names that already ended in _<digits>.
        path = os.path.join(directory, f"{name}_{counter}{ext}")
        counter += 1
    return path


def saveimages(imgs, name, **kwargs):
    """Save ``imgs`` as one image grid in ``outputs/`` without overwriting files.

    Args:
        imgs: image tensor(s) accepted by ``torchvision.utils.save_image``.
        name: free-form label (typically mode + prompt) used as the file name.
        **kwargs: forwarded to ``torchvision.utils.save_image`` (e.g. ``nrow``).
    """
    path = unique_image_path(name)
    torchvision.utils.save_image(imgs, path, **kwargs)
def log(t, eps=1e-20):
    """Return ``log(t + eps)``: a natural log that stays finite at ``t == 0``."""
    return torch.log(t + eps)


def gumbel_noise(t):
    """Return standard Gumbel noise with the shape, dtype and device of ``t``."""
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))


def gumbel_sample(t, temperature=1., dim=-1):
    """Sample one index along ``dim`` of the logits ``t`` (Gumbel-max trick).

    ``argmax(t / temperature + g)`` with Gumbel noise ``g`` is a draw from
    ``softmax(t / temperature)``. The temperature is floored at 1e-10, so a
    zero temperature approaches a plain argmax instead of dividing by zero.
    """
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim)


def _typical_filter(logits, mass, min_tokens):
    """Set to -inf the logits of a 2-D ``(rows, classes)`` tensor outside each row's typical set.

    Method (locally typical sampling):
    - Rank tokens by how close their surprisal ``-log p`` is to the row
      entropy.
    - Keep the shortest prefix of that ranking whose probability mass reaches
      ``mass``, plus ties.
    - When ``min_tokens > 1``, at least ``min_tokens`` tokens are kept.
    """
    log_p = torch.nn.functional.log_softmax(logits, dim=-1)
    p = torch.exp(log_p)
    entropy = -(log_p * p).nansum(-1, keepdim=True)

    shifted = torch.abs((-log_p) - entropy)
    sorted_shifted, sorted_idx = torch.sort(shifted, descending=False)
    cum_probs = logits.gather(-1, sorted_idx).softmax(dim=-1).cumsum(dim=-1)

    last_ind = (cum_probs < mass).sum(dim=-1)
    # If the cumulative mass never reaches ``mass`` (mass >= 1, or float
    # rounding), last_ind equals the number of classes and gather() would fail.
    last_ind = last_ind.clamp(max=logits.size(-1) - 1)
    sorted_to_remove = sorted_shifted > sorted_shifted.gather(1, last_ind.view(-1, 1))
    if min_tokens > 1:
        sorted_to_remove[..., :min_tokens] = 0
    to_remove = sorted_to_remove.scatter(1, sorted_idx, sorted_to_remove)
    return logits.masked_fill(to_remove, -float("Inf"))


def sample(model, c, x=None, mask=None, T=12, size=(32, 32), starting_t=0, temp_range=(1.0, 1.0),
           typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=-1,
           renoise_steps=11, renoise_mode='start'):
    """Iteratively denoise a grid of discrete tokens with the Paella denoiser.

    Args:
        model: called as ``model(x, c, r)`` and returning logits shaped
            ``(batch, classes, H, W)``. It must expose ``num_labels`` and
            ``add_noise(x, r, random_x=...)``, which returns a pair whose first
            element is the re-noised tokens.
        c: conditioning embeddings; ``c.size(0)`` is the batch size.
        x: optional starting tokens ``(batch, *size)``. Random tokens are used
            when ``None``.
        mask: optional 0/1 tensor. 1 marks tokens to generate; 0 marks tokens
            kept from ``x`` (in-/out-painting).
        T: number of steps in the noise schedule.
        size: token-grid size used when random tokens are drawn.
        starting_t: first step to run. A later start means less noise
            (e.g. morphing).
        temp_range: (first, last) sampling temperature, linearly spaced over T.
        typical_filtering, typical_mass, typical_min_tokens: typical filtering
            of each token's distribution before sampling.
        classifier_free_scale: guidance scale. ``< 0`` disables
            classifier-free guidance. The unconditional pass uses an all-zero
            ``c``.
        renoise_steps: steps ``i < renoise_steps`` are re-noised to the next
            schedule level. The final step is never re-noised.
        renoise_mode: the ``random_x`` passed to ``add_noise``:
            - ``'start'``: the initial tokens;
            - ``'prev'``: this step's input tokens;
            - any other value (``'rand'``): no ``random_x``.

    Returns:
        List of the token grids after each executed step; ``[-1]`` is the
        final result. The list is empty when ``starting_t >= T``.
    """
    with torch.inference_mode():
        batch = c.size(0)
        r_range = torch.linspace(0, 1, T + 1)[:-1][:, None].expand(-1, batch).to(c.device)
        temperatures = torch.linspace(temp_range[0], temp_range[1], T)
        # r_range has no entry for step T, so re-noising must stop before the last step.
        last_renoise = min(renoise_steps, T - 1)
        preds = []
        if x is None:
            x = torch.randint(0, model.num_labels, size=(batch, *size), device=c.device)
        elif mask is not None:
            # Replace the tokens to be generated (mask == 1) with random tokens.
            noise = torch.randint(0, model.num_labels, size=(batch, *size), device=c.device)
            x = noise * mask + (1 - mask) * x
        init_x = x.clone()
        for i in range(starting_t, T):
            if renoise_mode == 'prev':
                prev_x = x.clone()
            r, temp = r_range[i], temperatures[i]
            logits = model(x, c, r)
            if classifier_free_scale >= 0:
                # Classifier-free guidance: blend/extrapolate from the
                # unconditional (zero-conditioning) logits to the conditional ones.
                logits_uncond = model(x, torch.zeros_like(c), r)
                logits = torch.lerp(logits_uncond, logits, classifier_free_scale)
            # (B, C, H, W) -> (B*H*W, C): one categorical distribution per token.
            x_flat = logits.permute(0, 2, 3, 1).reshape(-1, logits.size(1))
            if typical_filtering:
                x_flat = _typical_filter(x_flat, typical_mass, typical_min_tokens)
            x_flat = gumbel_sample(x_flat, temperature=temp)
            x = x_flat.view(logits.size(0), *logits.shape[2:])
            if mask is not None:
                # Restore the known tokens outside the mask.
                x = x * mask + (1 - mask) * init_x
            if i < last_renoise:
                if renoise_mode == 'start':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=init_x)
                elif renoise_mode == 'prev':
                    x, _ = model.add_noise(x, r_range[i + 1], random_x=prev_x)
                else:  # 'rand'
                    x, _ = model.add_noise(x, r_range[i + 1])
            preds.append(x.detach())
    return preds

DenoiseUNet をインスタンス化して重みをロードします。画像とトークンの相互変換に使う VAE (rudalle の get_vae) と、条件付けに使う OpenCLIP モデルもロードします :

# Output directory used by saveimages().
os.makedirs("outputs", exist_ok=True)

# rudalle VAE: its encoder/decoder are used by encode()/decode() to convert
# between images and token grids. Frozen and in eval mode (inference only).
vqmodel = get_vae().to(device)
vqmodel.eval().requires_grad_(False)

# OpenCLIP ViT-g-14 ('laion2b_s12b_b42k' weights) produces the text/image
# embeddings the denoiser is conditioned on. Also frozen.
clip_model, _, _ = open_clip.create_model_and_transforms('ViT-g-14', pretrained='laion2b_s12b_b42k')
clip_model = clip_model.to(device).eval().requires_grad_(False)

# Tensor-space CLIP preprocessing (input is already a tensor, so no ToTensor):
# bicubic resize of the shorter side to 224, then CLIP mean/std normalisation.
# NOTE(review): there is no crop, so non-square inputs stay non-square.
clip_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
])

# PIL image -> tensor in [0, 1], with the shorter side resized to 256.
preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    # torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
])

def encode(x):
    """Encode images into a grid of token indices with the global ``vqmodel``.

    Pixels are rescaled from [0, 1] (``preprocess``/ToTensor range) to [-1, 1]
    before encoding. The tokens are the last element of the last entry of
    ``vqmodel.model.encode``'s output.
    """
    scaled = x * 2 - 1
    encoded = vqmodel.model.encode(scaled)
    return encoded[-1][-1]
    
def decode(img_seq, shape=(32,32)):
    """Decode a batch of token indices back into images in [0, 1].

    Args:
        img_seq: integer token indices. Each batch item is flattened, so it
            must hold ``shape[0] * shape[1]`` tokens.
        shape: (height, width) of the token grid.

    Returns:
        Images from ``vqmodel.model.decode``, clamped from [-1, 1] and rescaled
        to [0, 1].
    """
    img_seq = img_seq.view(img_seq.shape[0], -1)
    # Look the code vectors up directly. The previous approach built a
    # (batch, tokens, num_tokens) one-hot float tensor and multiplied it with
    # the codebook. The result is the same, but this avoids the large
    # temporary and any reduced-precision (TF32/autocast) matmul rounding.
    z = torch.nn.functional.embedding(img_seq, vqmodel.model.quantize.embed.weight)
    z = rearrange(z, 'b (h w) c -> b c h w', h=shape[0], w=shape[1])
    img = vqmodel.model.decode(z)
    return (img.clamp(-1., 1.) + 1) * 0.5
    
# Text-conditioned Paella denoiser checkpoint (downloaded above).
state_dict = torch.load("./model_600000.pt", map_location=device)
# For image variations, load the image-embedding fine-tuned checkpoint instead:
# state_dict = torch.load("./model_50000_img.pt", map_location=device)
model = DenoiseUNet(num_labels=8192).to(device)
model.load_state_dict(state_dict)
# Inference only: freeze the weights, like vqmodel and clip_model.
# (requires_grad_() with no argument would *enable* gradients.)
model.eval().requires_grad_(False)
print()
Working with z of shape (1, 256, 32, 32) = 262144 dimensions.
/usr/local/lib/python3.8/dist-packages/huggingface_hub/file_download.py:594: FutureWarning: `cached_download` is the legacy way to download files from the HF hub, please consider upgrading to `hf_hub_download`
  warnings.warn(
Downloading: 100%
346M/346M [00:04<00:00, 57.9MB/s]
vae --> ready
Downloading: 100%
5.47G/5.47G [04:59<00:00, 18.8MB/s]

 

テキスト条件付き

プロンプトとして "highly detailed photograph of darth vader. artstation" を使用します :

# Text-conditional generation: batch_size samples for a single prompt.
mode = "text"
batch_size = 6
text = "highly detailed photograph of darth vader. artstation"
latent_shape = (32, 32)  # token grid (height, width) to sample
tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        clip_embeddings = clip_model.encode_text(tokenized_text)
        s = time.time()
        # 12 steps with classifier-free guidance (scale 5) and typical filtering.
        # After every step except the last (renoise_steps=11 < T=12), the tokens
        # are re-noised using the initial random tokens ("start").
        sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
           typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,
           renoise_mode="start")
        print(time.time() - s)  # sampling wall-clock time in seconds
    # Only the tokens of the final step are decoded into images.
    sampled = decode(sampled[-1], latent_shape)

showimages(sampled)
saveimages(sampled, mode + "_" + text, nrow=len(sampled))

 

潜在空間補間

# Linear interpolation (LERP) between the CLIP text embeddings of two prompts.
# 10 evenly spaced weights run from 0 (text) to 1 (text2), one image per weight.
mode = "interpolation"
text = "surreal painting of a yellow tulip. artstation"
text2 = "surreal painting of a red tulip. artstation"
text_encoded = tokenizer.tokenize([text]).to(device)
text2_encoded = tokenizer.tokenize([text2]).to(device)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        clip_embeddings = clip_model.encode_text(text_encoded).float()
        clip_embeddings2 = clip_model.encode_text(text2_encoded).float()

        l = torch.linspace(0, 1, 10).to(device)  # interpolation weights
        embeddings = []
        for i in l:
            lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)
            embeddings.append(lerp)
        # Stack the 10 interpolated embeddings into one batch of conditions.
        embeddings = torch.cat(embeddings)
        
        s = time.time()
        sampled = sample(model, embeddings, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)
        print(time.time() - s)  # sampling wall-clock time in seconds
    sampled = decode(sampled[-1])
showimages(sampled)
saveimages(sampled, mode + "_" + text + "_" + text2, nrow=len(sampled))

# Spherical interpolation (SLERP) between two prompt embeddings, one image per step.
mode = "interpolation"
text = "High quality front portrait photo of a tiger."
text2 = "High quality front portrait photo of a dog."
text_encoded = tokenizer.tokenize([text]).to(device)
text2_encoded = tokenizer.tokenize([text2]).to(device)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        clip_embeddings = clip_model.encode_text(text_encoded).float()
        clip_embeddings2 = clip_model.encode_text(text2_encoded).float()

        l = torch.linspace(0, 1, 10).to(device)  # interpolation weights
        s = time.time()
        outputs = []
        for i in l:
            # lerp = torch.lerp(clip_embeddings, clip_embeddings2, i)
            low, high = clip_embeddings, clip_embeddings2
            low_norm = low/torch.norm(low, dim=1, keepdim=True)
            high_norm = high/torch.norm(high, dim=1, keepdim=True)
            # omega: angle between the two embeddings.
            # NOTE(review): so == 0 (parallel embeddings) would divide by zero.
            omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)
            so = torch.sin(omega)
            lerp = (torch.sin((1.0-i)*omega)/so)*low + (torch.sin(i*omega)/so) * high
            # Re-seed every step inside fork_rng, so all images start from the
            # same random state and differ only in their conditioning. The
            # outer RNG state is restored afterwards.
            with torch.random.fork_rng():
                torch.random.manual_seed(32)
                sampled = sample(model, lerp, T=12, size=(32, 32), starting_t=0, temp_range=[1.0, 1.0],
                       typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)
                outputs.append(sampled[-1])
        print(time.time() - s)  # sampling wall-clock time in seconds
    sampled = torch.cat(outputs)
    sampled = decode(sampled)
showimages(sampled)
saveimages(sampled, mode + "_" + text + "_" + text2, nrow=len(sampled))

 

マルチ条件付き

条件を調整して、より多くのキャプションを追加することもできます。条件の構造は [[caption_1, end_token_1], [caption_2, end_token_2], ...] です。下のサンプルでは、最初のキャプションが潜在画像を水平方向にトークン 20 まで満たし、次のキャプションがトークン 60 まで満たす、というように続きます。

# Multi-conditioning: each horizontal span of the token grid gets its own caption.
# conditions = [[caption, end_column], ...]; a span runs from the previous
# end_column (0 for the first) up to its own end_column.
mode = "multiconditioning"
batch_size = 4
latent_shape = (32, 100)
conditions = [
    ["an oil painting of a lighthouse standing on a hill", 20],
    ["an oil painting of a majestic boat sailing on the water during a storm. front view", 60],
    ["an oil painting of a majestic castle standing by the water", 100],
]
# Per-position conditioning map (batch, 1024, H, W).
# NOTE(review): 1024 must match the width of clip_model.encode_text's output.
# Columns not covered by any condition stay zero; sample() also uses zeros as
# the unconditional embedding.
clip_embedding = torch.zeros(batch_size, 1024, *latent_shape).to(device)
last_pos = 0
for text, pos in conditions:
    tokenized_text = tokenizer.tokenize([text] * batch_size).to(device)
    # This runs outside inference_mode/autocast. clip_model is frozen, so no
    # autograd graph is kept.
    part_clip_embedding = clip_model.encode_text(tokenized_text).float()[:, :, None, None]
    print(f"{last_pos}:{pos}={text}")
    clip_embedding[:, :, :, last_pos:pos] = part_clip_embedding
    last_pos = pos
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        sampled = sample(model, clip_embedding, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
           typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,
           renoise_mode="start")
    sampled = decode(sampled[-1], latent_shape)
    
showimages(sampled)
saveimages(sampled, mode + "_" + ":".join(list(map(lambda x: x[0], conditions))), nrow=batch_size)

# Spatial interpolation inside one image: each token position gets its own blend
# of the text_a and text_b embeddings.
batch_size = 4
latent_shape = (32, 32)
text_a = "a cute portrait of a dog"
text_b = "a cute portrait of a cat"
mode = "vertical"
# mode = "horizontal"
# Tokenised as [a, b, a, b, ...]: even rows are text_a, odd rows are text_b.
text = tokenizer.tokenize([text_a, text_b] * batch_size).to(device)

with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        # Broadcast every text embedding over the whole token grid.
        clip_embeddings = clip_model.encode_text(text).float()[:, :, None, None].expand(-1, -1, latent_shape[0], latent_shape[1])
        # Weight of text_b per position: 0 -> 1 along the height ('vertical')
        # or along the width (otherwise).
        if mode == 'vertical':
            interp_mask = torch.linspace(0, 1, latent_shape[0], device=device)[None, None, :, None].expand(batch_size, 1, -1, latent_shape[1])
        else: 
            interp_mask = torch.linspace(0, 1, latent_shape[1], device=device)[None, None, None, :].expand(batch_size, 1, latent_shape[0], -1)
        # LERP
        clip_embeddings = clip_embeddings[0::2] * (1-interp_mask) + clip_embeddings[1::2] * interp_mask
        # # SLERP
        # low, high = clip_embeddings[0::2], clip_embeddings[1::2]
        # low_norm = low/torch.norm(low, dim=1, keepdim=True)
        # high_norm = high/torch.norm(high, dim=1, keepdim=True)
        # omega = torch.acos((low_norm*high_norm).sum(1)).unsqueeze(1)
        # so = torch.sin(omega)
        # clip_embeddings = (torch.sin((1.0-interp_mask)*omega)/so)*low + (torch.sin(interp_mask*omega)/so) * high
    
        sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
           typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11,
           renoise_mode="start")
    sampled = decode(sampled[-1], latent_shape)

showimages(sampled)

 

画像のロード : ディスク or Web

# Option 1: load an image from disk. Replace the placeholder with a real path.
# If the file does not exist, this step is skipped.
image_path = "path_to_image"
if os.path.exists(image_path):
    # convert("RGB") normalises grayscale / palette / RGBA files to 3 channels.
    images = preprocess(Image.open(image_path).convert("RGB")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]
    showimages(images)
else:
    print(f"{image_path!r} not found; using the web image below.")

# Option 2: download an image. When both options run, this one wins.
url = "https://media.istockphoto.com/id/1193591781/photo/obedient-dog-breed-welsh-corgi-pembroke-sitting-and-smiles-on-a-white-background-not-isolate.jpg?s=612x612&w=0&k=20&c=ZDKTgSFQFG9QvuDziGsnt55kvQoqJtIhrmVRkpYqxtQ="
# url = "https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1200px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"
response = requests.get(url, timeout=30)
# Fail loudly instead of trying to decode an HTTP error page as an image.
response.raise_for_status()
img = Image.open(BytesIO(response.content)).convert("RGB")
# The image is repeated 4 times to form the batch used by the following cells.
images = preprocess(img).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]
showimages(images)

 

インペインティング

# Inpainting: regenerate the tokens inside a square region of the encoded image
# from the text prompt, and keep all other tokens.
mode = "inpainting"
text = "a delicious spanish paella"
tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        # Alternative: condition on the image itself instead of the text.
        # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float() # clip_embeddings = clip_model.encode_text(text).float()
        clip_embeddings = clip_model.encode_text(tokenized_text).float()
        encoded_tokens = encode(images)
        latent_shape = encoded_tokens.shape[1:]
        # mask == 1: tokens to generate; mask == 0: tokens kept from encoded_tokens.
        mask = torch.zeros_like(encoded_tokens)
        mask[:,5:28,5:28] = 1
        sampled = sample(model, clip_embeddings, x=encoded_tokens, mask=mask, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=11)
    sampled = decode(sampled[-1], latent_shape)

showimages(images[0:1], height=10, width=10)
showmask(mask[0:1])
showimages(sampled, height=16, width=16)
# Saved grid: original image first, followed by the inpainted samples.
saveimages(torch.cat([images[0:1], sampled]), mode + "_" + text, nrow=images.shape[0]+1)

 

アウトペインティング

# Outpainting: place the encoded image on a larger token canvas and generate
# everything around it from the text prompt.
mode = "outpainting"
size = (40, 64)  # canvas size in tokens (height, width)
top_left = (0, 16)  # token position of the original image's top-left corner
text = "black & white photograph of a rocket from the bottom."
tokenized_text = tokenizer.tokenize([text] * images.shape[0]).to(device)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        # clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float()
        clip_embeddings = clip_model.encode_text(tokenized_text).float()
        encoded_tokens = encode(images)
        canvas = torch.zeros((images.shape[0], *size), dtype=torch.long).to(device)
        canvas[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = encoded_tokens
        # mask == 1 everywhere except over the pasted image (0 = keep those tokens).
        mask = torch.ones_like(canvas)
        mask[:, top_left[0]:top_left[0]+encoded_tokens.shape[1], top_left[1]:top_left[1]+encoded_tokens.shape[2]] = 0
        sampled = sample(model, clip_embeddings, x=canvas, mask=mask, T=12, size=size, starting_t=0, temp_range=[1.0, 1.0],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)
    sampled = decode(sampled[-1], size)

showimages(images[0:1], height=10, width=10)
showmask(mask[0:1])
showimages(sampled, height=16, width=16)
saveimages(sampled, mode + "_" + text, nrow=images.shape[0])

 

構造的モーフィング

# Structural morphing:
# 1. partially noise the encoded input image to level init_step / max_steps;
# 2. resume sampling from step init_step with a new prompt.
# The coarse structure of the input survives, while the content follows the text.
mode = "morphing"
max_steps = 24
init_step = 8

text = "A fox posing for a photo. stock photo. highly detailed. 4k"

with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        # images = preprocess(Image.open("data/city sketch.png")).unsqueeze(0).expand(4, -1, -1, -1).to(device)[:, :3]
        latent_image = encode(images)
        latent_shape = latent_image.shape[-2:]
        r = torch.ones(latent_image.size(0), device=device) * (init_step/max_steps)
        noised_latent_image, _ = model.add_noise(latent_image, r)
        
        tokenized_text = tokenizer.tokenize([text] * images.size(0)).to(device)
        clip_embeddings = clip_model.encode_text(tokenized_text).float()
        
        # renoise_mode="prev": re-noise using each step's input tokens rather
        # than the initial ones.
        sampled = sample(model, clip_embeddings, x=noised_latent_image, T=max_steps, size=latent_shape, starting_t=init_step, temp_range=[1.0, 1.0],
                   typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=6, renoise_steps=max_steps-1,
                    renoise_mode="prev")
    sampled = decode(sampled[-1], latent_shape)
showimages(sampled)
showimages(images)
saveimages(torch.cat([images[0:1], sampled]), mode + "_" + text, nrow=images.shape[0]+1)

 

画像バリエーション

Note : 画像バリエーションが正しく動作するためには、画像埋め込みで微調整されたチェックポイント (上でダウンロードした model_50000_img.pt) をダウンロードし、model にロードしておく必要があります。

# Image variations: condition on the CLIP *image* embedding of `images` instead of text.
# NOTE: this works properly only with the image-embedding fine-tuned checkpoint
# (./model_50000_img.pt, downloaded above). `model` is whatever checkpoint was
# loaded earlier: the text model ./model_600000.pt unless that cell was changed.
# Resize to exactly 224x224. The earlier clip_preprocess only fixed the shorter side.
clip_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)),
])
latent_shape = (32, 32)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        clip_embeddings = clip_model.encode_image(clip_preprocess(images)).float()       # text alternative: clip_embeddings = clip_model.encode_text(text).float()
        sampled = sample(model, clip_embeddings, T=12, size=latent_shape, starting_t=0, temp_range=[1.0, 1.0],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=5, renoise_steps=11)
    sampled = decode(sampled[-1], latent_shape)

showimages(images)
showimages(sampled)

 

実験的: コンセプト学習

def text_encode(x, clip_model, insertion_index):
    """Run the OpenCLIP text tower on already-embedded tokens.

    This mirrors what ``clip_model.encode_text`` does after the token-embedding
    lookup. Callers can therefore swap individual token embeddings (e.g. a
    learned concept vector) before encoding.

    Args:
        x: token embeddings, ``(batch, n_ctx, width)``.
        clip_model: OpenCLIP model providing ``positional_embedding``,
            ``transformer``, ``ln_final``, ``text_projection`` and, when
            present, ``attn_mask``.
        insertion_index: per-row sequence position whose final hidden state
            is pooled. Callers in this notebook pass the EOT position
            (``tokenized_text.argmax(dim=-1)``).

    Returns:
        ``(batch, embed_dim)`` text embeddings (not normalised).
    """
    x = x + clip_model.positional_embedding
    x = x.permute(1, 0, 2)  # NLD -> LND
    # encode_text() passes the causal attention mask to the transformer.
    # Without it every token attends to the whole sequence, and the result
    # drifts away from the encode_text() embeddings the denoiser is
    # conditioned on elsewhere in this notebook.
    x = clip_model.transformer(x, attn_mask=getattr(clip_model, "attn_mask", None))
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x)
    # Pool one position per batch row, then project to the joint embedding space.
    return x[torch.arange(x.shape[0]), insertion_index] @ clip_model.text_projection
from torch.optim import AdamW
batch_size = 1
# Token embedding of the "*" placeholder. Index [0][1] is the first token after
# position 0 (presumably the start-of-text token; verify against the tokenizer).
asteriks_emb = clip_model.token_embedding(tokenizer.tokenize(["*"]).to(device))[0][1]
# Learnable vector that replaces "*" in prompts: the concept being learned.
context_word = torch.randn(batch_size, 1, asteriks_emb.shape[-1]).to(device)
context_word.requires_grad_(True)
# Only the concept vector is optimised; all other weights stay untouched.
optim = AdamW(params=[context_word], lr=0.1)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
import requests
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

# Concept images: shorter side to 256, then a 256x256 center crop, so all
# images share one size and can be stacked into a single tensor.
_preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize(256),
    torchvision.transforms.CenterCrop(256),
    torchvision.transforms.ToTensor(),
])

# Example photos of the concept to learn.
urls = [
    "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcStVHtFcMqIP4xuDYn8n_FzPDKjPtP_iTSbOQ&usqp=CAU",
    "https://i.insider.com/58d919eaf2d0331b008b4bbd?width=700",
    "https://media.cntraveler.com/photos/5539216cab60aad20f3f3aaa/16:9/w_2560%2Cc_limit/eiffel-tower-paris-secret-apartment.jpg",
    "https://static.independent.co.uk/s3fs-public/thumbnails/image/2014/03/25/12/eiffel.jpg?width=1200"
]
images = []
for url in urls:
    response = requests.get(url, timeout=30)
    # Fail loudly instead of trying to decode an HTTP error page as an image.
    response.raise_for_status()
    # convert("RGB"): PNG/GIF/grayscale downloads would otherwise give 1, 2 or 4
    # channels, which breaks torch.stack below. The rest of the notebook works
    # with 3-channel images.
    img = Image.open(BytesIO(response.content)).convert("RGB")
    images.append(_preprocess(img))

data = torch.stack(images)  # (num_images, 3, 256, 256)
dataset = DataLoader(TensorDataset(data), batch_size=1, shuffle=True)
loader = iter(dataset)
steps = 100
total_loss = 0
total_acc = 0

# The prompt is fixed, so tokenise and embed it once; only the "*" slot changes.
text = "a photo of *"
tokenized_text = tokenizer.tokenize([text]).to(device)
insertion_index = tokenized_text.argmax(dim=-1)  # EOT position: the pooled token
base_text_embedding = clip_model.token_embedding(tokenized_text)
# Find positions whose *whole* embedding vector equals the "*" embedding.
# Matching single elements could also hit unrelated tokens that happen to share
# one value with it.
insertion_idx = torch.where((base_text_embedding == asteriks_emb).all(dim=-1))[1].unique()
if insertion_idx.numel() == 0:
    raise ValueError(f"placeholder '*' not found in prompt {text!r}")

# Only context_word is trained. Freeze the denoiser so backward() does not
# accumulate unused gradients into its parameters.
model.requires_grad_(False)

pbar = tqdm(range(steps))
for i in pbar:
    # Cycle through the (shuffled) concept images indefinitely.
    try:
        images = next(loader)[0]
    except StopIteration:
        loader = iter(dataset)
        images = next(loader)[0]
    images = images.to(device)

    # Insert the learnable concept vector at the "*" position and encode the prompt.
    neutral_text_encoded = base_text_embedding.clone()
    neutral_text_encoded[:, insertion_idx, :] = context_word
    clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)

    with torch.no_grad():
        image_indices = encode(images)
        r = torch.rand(images.size(0), device=device)
        noised_indices, mask = model.add_noise(image_indices, r)

    # with torch.autocast(device_type="cuda"):
    # Denoising objective: predict the clean tokens from the noised ones.
    pred = model(noised_indices, clip_embeddings, r)
    loss = criterion(pred, image_indices)

    loss.backward()
    optim.step()
    optim.zero_grad()

    acc = (pred.argmax(1) == image_indices).float().mean()

    total_loss += loss.item()
    total_acc += acc.item()
    pbar.set_postfix({"total_loss": total_loss / (i + 1), "total_acc": total_acc / (i + 1)})
# Sample 4 images with the embedding of the last training prompt ("a photo of *"),
# i.e. the learned concept.
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)
    sampled = decode(sampled[-1])

# Top row: the last training image repeated 4 times; bottom row: the samples.
plt.figure(figsize=(32, 32))
plt.axis("off")
plt.imshow(torch.cat([
    torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),
    torch.cat([i for i in sampled.cpu()], dim=-1),
], dim=-2).permute(1, 2, 0).cpu())
plt.show()

# Use the learned concept in a new prompt.
text = "* at night"
tokenized_text = tokenizer.tokenize([text]).to(device)
insertion_index = tokenized_text.argmax(dim=-1)  # EOT position: the pooled token
# Pure inference: no_grad avoids recording an autograd graph through context_word.
with torch.no_grad():
    neutral_text_encoded = clip_model.token_embedding(tokenized_text)
    # Match whole embedding vectors, not single elements, to locate "*".
    insertion_idx = torch.where((neutral_text_encoded == asteriks_emb).all(dim=-1))[1].unique()
    neutral_text_encoded[:, insertion_idx, :] = context_word
    clip_embeddings = text_encode(neutral_text_encoded, clip_model, insertion_index)
with torch.inference_mode():
    with torch.autocast(device_type="cuda"):
        sampled = sample(model, clip_embeddings.expand(4, -1), T=12, size=(32, 32), starting_t=0, temp_range=[1., 1.],
               typical_filtering=True, typical_mass=0.2, typical_min_tokens=1, classifier_free_scale=4, renoise_steps=11)
    sampled = decode(sampled[-1])

# Top row: the last training image repeated 4 times; bottom row: the samples.
plt.figure(figsize=(32, 32))
plt.axis("off")
plt.imshow(torch.cat([
    torch.cat([i for i in images.expand(4, -1, -1, -1).cpu()], dim=-1),
    torch.cat([i for i in sampled.cpu()], dim=-1),
], dim=-2).permute(1, 2, 0).cpu())
plt.show()

 

以上