# vqgan-jax-encoding-yfcc100m

Same as `vqgan-jax-encoding-with-captions`, but for YFCC100M.

This dataset was prepared by @borisdayma in Json lines format.

In [1]:
import io

import requests
from PIL import Image
import numpy as np
from tqdm import tqdm

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets.folder import default_loader

import jax
from jax import pmap

## VQGAN-JAX model

`dalle_mini` is a local package that contains the VQGAN-JAX model and other utilities.

In [2]:
from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel

We'll use a VQGAN trained by using Taming Transformers and converted to a JAX model.

**Disabling** Does not work in my local system right now.

In [3]:
#model = VQModel.from_pretrained("flax-community/vqgan_f16_16384")

## Dataset

In [79]:
import pandas as pd
from pathlib import Path

In [80]:
yfcc100m = Path('/sddata/dalle-mini/YFCC100M_OpenAI_subset')
# Images are 'sharded' from the following directory
yfcc100m_images = yfcc100m/'data'/'images'
yfcc100m_metadata = yfcc100m/'metadata_YFCC100M.jsonl'
yfcc100m_output = yfcc100m/'metadata_encoded.jsonl'

### Cleanup

We need to select entries with images that exist. Otherwise we can't build batches because `Dataloader` does not support `None` in batches. We use Huggingface Datasets, I understand they support threaded reading of jsonl files, and I was running out of memory when using pandas.

In [81]:
import datasets
from datasets import Dataset, load_dataset

In [82]:
dataset = load_dataset("json", data_files=[str(yfcc100m_metadata)])

Using custom data configuration default-57592e8ed16d752b
Reusing dataset json (/home/pedro/.cache/huggingface/datasets/json/default-57592e8ed16d752b/0.0.0/793d004298099bd3c4e61eb7878475bcf1dc212bf2e34437d85126758720d7f9)


In [83]:
dataset = dataset['train']
dataset

Dataset({
    features: ['photoid', 'uid', 'unickname', 'datetaken', 'dateuploaded', 'capturedevice', 'title', 'description', 'usertags', 'machinetags', 'longitude', 'latitude', 'accuracy', 'pageurl', 'downloadurl', 'licensename', 'licenseurl', 'serverid', 'farmid', 'secret', 'secretoriginal', 'ext', 'marker', 'key', 'title_clean', 'description_clean'],
    num_rows: 14825233
})

In [84]:
def image_exists(root: str, name: str, ext: str):
    image_path = (Path(root)/name[0:3]/name[3:6]/name).with_suffix(ext)
    return image_path.exists()

In [90]:
def select_existing_rows(examples):
    # Select lists we want to keep
    keys = examples['key']
    titles_clean = examples['title_clean']
    descriptions_clean = examples.get('description_clean', '')
    exts = examples['ext']
    
    result = {'key': [], 'title_clean': [], 'description_clean': [], 'ext': []}
    for i, image_name in enumerate(keys):
        print(i, image_name)
        if image_exists(root=str(yfcc100m_images), name=image_name, ext='.' + exts[i]):
            result["key"].append(image_name)
            result["title_clean"].append(titles_clean[i])
            result["description_clean"].append(descriptions_clean[i])
            result["ext"].append(exts[i])
    print(f'returning {len(result["key"])}')
    return result

In [91]:
filtered_dataset = dataset.map(
    select_existing_rows,
    remove_columns = dataset.column_names,
    batched = True,
    num_proc = 1,
    desc = "Selecting rows with images that exist"
)

Selecting rows with images that exist:   0%|          | 0/14826 [00:00<?, ?ba/s]

0 d29e7c6a3028418c64eb15e3cf577c2
1 d29f01b149167d683f9ddde464bb3db
2 d296e9e34bdae41edb6c679ff824ab2a
3 d29ce96395848478b1e8396e44899
4 d29abf32c4e12ff881f975b70e0cec0
5 d298a61f2f7be6c9e2c2af81755b489
6 d29b1b973ab1a95a37cd4cda37999fb
7 d290d566266ad568e94128d4135b41a
8 d29b1ac2a497b0d9a4a43c3a51d13fb
9 d29ebe6c96f53b2f5d7f5eed9b2b2898
10 d29ec1b3f75749a231ee1d9d206baf6e
11 d290bee419ce98d9a79ccf512a47a79
12 d29bc1eff62a477131516c40a54f2dce
13 d292a123bcf58e13128d2067593d81
14 d294424637d532d8cfbcf2ca99b85f
15 d29a51d8502f531115b108d59c811ab
16 d29a9f0fce210c7e050877a53697031
17 d290c750469f11795ed85fa62e4b52
18 d29e13badf42d839b421478be4452dbe
19 d29c1d635348aa35474a90f57aafb7
20 d291a7c7c71455d5b3cdd97ca5e4c
21 d295f95d7cb204dc812a476af5f4f8a
22 d2932ecd1053165aa3d7b9e68547e0b6
23 d29cd5a4b1d6a759b63df357ef2b
24 d294e885117ca7d9b328c5b9388f52
25 d2999b54832bb275a7e2eea47e98f11
26 d29f89d491812beb84e62223b4541d7
27 d2993599afe456ba786060129fc9cdfd
28 d290ceb78d0f7c8c49930cd96b12b27


IndexError: index out of bounds

In [109]:
# df['image_exists'] = df.apply(lambda row: image_exists(row['key']), axis=1)

In [113]:
image_size = 256
def image_transform(image):
    s = min(image.size)
    r = image_size / s
    s = (round(r * image.size[1]), round(r * image.size[0]))
    image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
    image = TF.center_crop(image, output_size = 2 * [image_size])
    image = torch.unsqueeze(T.ToTensor()(image), 0)
    image = image.permute(0, 2, 3, 1).numpy()
    return image

In [98]:
class YFC100Dataset(Dataset):
    def __init__(self, image_list_path: str, images_root: str, image_size: int, max_items=None):
        """
        :param image_list_path: Path to a file containing a list of all images, in jsonl format.
        :param images_root: Root directory containing the images
        :param image_size: Image size. Source images will be resized and center-cropped.
        :max_items: Limit dataset size for debugging
        """
        self.image_list = pd.read_json(image_list_path, orient="records", lines=True)
        self.images_root = Path(images_root)
        if max_items is not None: self.image_list = self.image_list[:max_items]
        self.image_size = image_size
        
    def __len__(self):
        return len(self.image_list)
    
    def _get_raw_image(self, i):
        image_name = self.image_list.iloc[0].key
        image_path = (self.images_root/image_name[0:3]/image_name[3:6]/image_name).with_suffix('.jpg')
        return default_loader(image_path) if image_path.exists() else None
    
    # TODO: we could maybe use jax resizing / scaling functions
    def resize_image(self, image):
        s = min(image.size)
        r = self.image_size / s
        s = (round(r * image.size[1]), round(r * image.size[0]))
        image = TF.resize(image, s, interpolation=InterpolationMode.LANCZOS)
        image = TF.center_crop(image, output_size = 2 * [self.image_size])
        image = torch.unsqueeze(T.ToTensor()(image), 0)
        image = image.permute(0, 2, 3, 1).numpy()
        return image
    
    def _get_caption(self, i):
        # We are currently appending title and caption. Should we use another separator?
        row = self.image_list.iloc[i]
        return ' '.join(row.title_clean, row.description_clean)
    
    def __getitem__(self, i):
        image = self._get_raw_image(i)
        if image is None: return None
        image = self.resize_image(image)
        caption = self._get_caption(i)
        return {'image': image, 'text': caption}

In [99]:
dataset = YFC100Dataset(
    image_list_path = yfc100m_metadata,
    images_root = yfc100m_images,
    image_size = 256,
)

In [100]:
len(dataset)

5000

In [102]:
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

In [103]:
next(iter(dataloader))

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 86, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>


## Encoding

In [89]:
def encode(model, batch):
    print("jitting encode function")
#     _, indices = model.encode(batch)

    # The model does not run in my computer (no cudNN currently installed) - faking it
    indices = [random.randint(0, 16384) for _ in range(256)]
    return indices

In [90]:
def superbatch_generator(dataloader, num_tpus):
    iter_loader = iter(dataloader)
    for batch in iter_loader:
        superbatch = [batch.squeeze(1)]
        try:
            for b in range(num_tpus-1):
                batch = next(iter_loader)
                if batch is None:
                    break
                # Skip incomplete last batch
                if batch.shape[0] == dataloader.batch_size:
                    superbatch.append(batch.squeeze(1))
        except StopIteration:
            pass
        superbatch = torch.stack(superbatch, axis=0)
        yield superbatch

In [93]:
import os
import jax

def encode_captioned_dataset(dataset, output_jsonl, batch_size=32, num_workers=16):
    if os.path.isfile(output_jsonl):
        print(f"Destination file {output_jsonl} already exists, please move away.")
        return
    
    num_tpus = jax.device_count()
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    superbatches = superbatch_generator(dataloader, num_tpus=num_tpus)
    
    p_encoder = pmap(lambda batch: encode(model, batch))

    # We save each superbatch to avoid reallocation of buffers as we process them.
    # We keep the file open to prevent excessive file seeks.
    with open(output_jsonl, "w") as file:
        iterations = len(dataset) // (batch_size * num_tpus)
        for n in tqdm(range(iterations)):
            superbatch = next(superbatches)
            encoded = p_encoder(superbatch.numpy())
            encoded = encoded.reshape(-1, encoded.shape[-1])

            # Extract fields from the dataset internal `captions` property, and save to disk
            start_index = n * batch_size * num_tpus
            end_index = (n+1) * batch_size * num_tpus
            paths = dataset.captions["image_file"][start_index:end_index].values
            captions = dataset.captions["caption"][start_index:end_index].values
            encoded_as_string = list(map(lambda item: np.array2string(item, separator=',', max_line_width=50000, formatter={'int':lambda x: str(x)}), encoded))
            batch_df = pd.DataFrame.from_dict({"image_file": paths, "caption": captions, "encoding": encoded_as_string})
            batch_df = batch_df.dropna()
            batch_df.to_json(file, orient='records', lines=True, index=None)
            

In [94]:
encode_captioned_dataset(dataset, yfc100m_output, batch_size=64, num_workers=16)

  0%|                                                                                        | 0/78 [00:00<?, ?it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 287, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/home/pedro/miniconda3/envs/hf_jax/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 86, in default_collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'NoneType'>


----