In [3]:
import os
import zipfile
from PIL import Image
from io import BytesIO
import numpy as np
from datasets import load_dataset
import torch
from diffusers import AutoencoderKL, UNet2DModel, UNet2DConditionModel
import pickle

In [2]:
def preprocess_image(image, multiple=32):
    """Convert a PIL image into a float32 array ready for the VAE encoder.

    The image is converted to RGB, then resized so that both sides are
    rounded down to an integer multiple of ``multiple`` (LANCZOS resampling).
    Finally, pixel values are rescaled from [0, 255] to [-1, 1].

    Args:
        image: PIL image in any mode; it is converted to RGB first.
        multiple: each output side is rounded down to a multiple of this
            value (32 by default, as before).

    Returns:
        np.ndarray of dtype float32 and shape (1, 3, H', W'), where
        H' = H - H % multiple and W' = W - W % multiple.

    Raises:
        ValueError: if either side is smaller than ``multiple``, since it
            would round down to zero.
    """
    # Grayscale ("L") or palette images would otherwise produce a 2-D array,
    # which breaks the channel transpose below. RGBA would produce 4 channels,
    # which the 3-channel encoder input presumably does not accept.
    image = image.convert("RGB")
    w, h = image.size
    w, h = w - w % multiple, h - h % multiple
    if w == 0 or h == 0:
        raise ValueError(
            f"image of size {image.size} is smaller than {multiple}px on at least one side"
        )
    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)
    arr = np.asarray(image, dtype=np.float32) / 255.0
    # HWC -> NCHW with a leading batch axis of size 1.
    arr = arr[None].transpose(0, 3, 1, 2)
    return 2.0 * arr - 1.0

def vae_embedding(preprocessed, num_samples=5, device="cuda", model=None,
                  scaling_factor=0.18215):
    """Encode images with the VAE and average several sampled latents.

    Args:
        preprocessed: torch tensor of images, presumably shaped (B, 3, H, W)
            with values in [-1, 1] as produced by ``preprocess_image``
            (TODO confirm for other callers).
        num_samples: number of latents drawn from the encoder's
            ``latent_dist`` and averaged; must be >= 1.
        device: device the input is moved to before encoding.
        model: encoder exposing ``.encode(x).latent_dist.sample()``.
            Defaults to the notebook-level ``vae``.
        scaling_factor: multiplier applied to the averaged latent.
            0.18215 is the constant this notebook originally hard-coded
            (presumably the SD v1 latent scale; verify against
            ``vae.config.scaling_factor``).

    Returns:
        np.ndarray on the CPU with the batch axis always kept: one latent per
        input image. The notebook's concatenated result has shape
        (N, 4, 64, 64).

    Raises:
        ValueError: if ``num_samples`` < 1.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    encoder = vae if model is None else model
    with torch.no_grad():
        latent_dist = encoder.encode(preprocessed.to(device=device)).latent_dist
        # Draw num_samples latents and stack them on a new leading axis.
        samples = torch.stack(
            [latent_dist.sample().to("cpu") for _ in range(num_samples)]
        )
        mean_latent = scaling_factor * samples.mean(dim=0)
    # No .squeeze(): with a batch of one it would drop the batch axis, and the
    # per-batch results could no longer be np.concatenate-d along axis 0.
    # It would also drop any other size-1 axis.
    return mean_latent.numpy()

In [3]:
%%time
# dset = load_dataset("./dataset.py", ignore_verifications=True) This uses the loading script and loads data from the zipped folders
dset = load_dataset("SDbiaseval/dataset")
ds = dset["train"]

Using custom data configuration SDbiaseval--dataset-cc8e38e46c1acd54
Found cached dataset parquet (/mnt/1da05489-3812-4f15-a6e5-c8d3c57df39e/cache/huggingface/SDbiaseval___parquet/SDbiaseval--dataset-cc8e38e46c1acd54/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

CPU times: user 196 ms, sys: 23.3 ms, total: 219 ms
Wall time: 2.51 s


In [4]:
# Load only the "vae" subfolder of the Stable Diffusion v1-4 checkpoint,
# switch it to eval mode and place it on the "cuda" device.
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    subfolder="vae",
)
vae.eval()
vae.to("cuda");

In [5]:
# The embeddings are computed in this shuffled order, so the permutation must
# be reproducible: with a fixed seed the same order can be rebuilt after a
# kernel restart, and row k of the saved embeddings maps to dataset row ix[k].
SEED = 0
rng = np.random.default_rng(SEED)
ix = rng.permutation(len(ds))
assert len(ix) > 0, "dataset is empty; nothing to batch"
batch_size = 16
n_batches = -(-len(ix) // batch_size)  # integer ceil(len(ix) / batch_size)
# array_split yields n_batches chunks of size batch_size or batch_size - 1.
batche_indices = np.array_split(ix, n_batches)
print(f"{len(batche_indices)} batches of {batch_size}. Last batch of size {len(batche_indices[-1])}.")

19554 batches of 16. Last batch of size 15.


In [15]:
%%time
embs = []
for i in batche_indices:
    imx = ds.select(i)["image"]
    preprocessed = np.concatenate([preprocess_image(im) for im in imx])
    emb = vae_embedding(torch.from_numpy(preprocessed), num_samples=10)
    embs.append(emb)

CPU times: user 9h 52min 30s, sys: 2min 25s, total: 9h 54min 55s
Wall time: 7h 54min 48s


In [16]:
# The batches come from a shuffled permutation, so row k of the concatenated
# embeddings belongs to dataset row embedding_order[k]. Persist that order next
# to the embeddings; otherwise it is lost on kernel restart and the rows can no
# longer be matched back to `ds`.
embedding_order = np.concatenate(batche_indices)
assert len(embedding_order) == sum(len(batch) for batch in embs), \
    "embedding rows and dataset indices are misaligned"
np.save("embs_order.npy", embedding_order)
with open('embs.pkl', 'wb') as f:
    pickle.dump(embs, f)

In [None]:
# Same-session alternative to reloading embs.pkl: join the per-batch latent
# arrays along the batch axis.
embeddings = np.concatenate(tuple(embs), axis=0)

In [4]:
# NOTE: pickle.load can execute arbitrary code; only load an embs.pkl that
# this notebook wrote itself. The result is the list of per-batch arrays.
with open("embs.pkl", mode="rb") as pickle_file:
    embeddings = pickle.load(pickle_file)

In [5]:
%%time
embeddings = np.concatenate(embeddings)

CPU times: user 3.82 s, sys: 4.34 s, total: 8.16 s
Wall time: 8.2 s


In [6]:
# Expected: (number of encoded images, *latent shape of one image).
np.shape(embeddings)

(312860, 4, 64, 64)

In [7]:
%%time
with open('vae_embeddings.npy', 'wb') as f:
    np.save(f, embeddings)

CPU times: user 0 ns, sys: 10.3 s, total: 10.3 s
Wall time: 10.3 s
