File size: 2,916 Bytes
7b0a1ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
import torchvision
import numpy as np
import os.path as osp
from PIL import Image
import torchvision
import torchvision.transforms as TF

def pair(t):
    """Return ``t`` as-is when it is already a tuple; otherwise duplicate it.

    Handy for normalizing a scalar size argument into a ``(h, w)`` pair.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

def center_crop_arr(pil_image, image_size):
    """Resize a PIL image and center-crop it to ``image_size`` x ``image_size``.

    Port of the ADM center-crop from guided-diffusion:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    Args:
        pil_image: input PIL image.
        image_size: side length, in pixels, of the square output.

    Returns:
        A new PIL image built from the central ``image_size`` x ``image_size``
        window of the resized pixels.
    """
    img = pil_image

    # Repeatedly halve with a box filter while the short side is still at
    # least twice the target, before the single bicubic resize below.
    while min(img.size) >= 2 * image_size:
        halved = tuple(dim // 2 for dim in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Bicubic resize so the short side lands on image_size.
    ratio = image_size / min(img.size)
    target = tuple(round(dim * ratio) for dim in img.size)
    img = img.resize(target, resample=Image.BICUBIC)

    pixels = np.array(img)
    height, width = pixels.shape[:2]
    top = (height - image_size) // 2
    left = (width - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)

def vae_transforms(split, aug='randcrop', img_size=256):
    """Build the PIL-image -> tensor preprocessing pipeline for the VAE.

    Args:
        split: dataset split name. ``'train'`` enables augmentation; any
            other value gets a deterministic ADM center crop.
        aug: training crop strategy, used only when ``split == 'train'``:
            ``'randcrop'`` resizes with ``TF.Resize(img_size)`` (bicubic,
            antialiased) and then takes a random ``img_size`` crop;
            ``'centercrop'`` applies :func:`center_crop_arr`.
        img_size: side length of the square output crop.

    Returns:
        A ``TF.Compose`` whose last step is ``TF.ToTensor()``. Training
        pipelines also include a random horizontal flip (p=0.5).

    Raises:
        ValueError: if ``split == 'train'`` and ``aug`` is not recognized.
    """
    import functools

    # functools.partial over a module-level function is picklable, unlike a
    # lambda defined in this function. That lets the transform be sent to
    # DataLoader workers started with the 'spawn' method.
    def _center_crop():
        return TF.Lambda(functools.partial(center_crop_arr, image_size=img_size))

    t = []
    if split == 'train':
        if aug == 'randcrop':
            t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
            t.append(TF.RandomCrop(img_size))
        elif aug == 'centercrop':
            t.append(_center_crop())
        else:
            raise ValueError(f"Invalid augmentation: {aug}")
        t.append(TF.RandomHorizontalFlip(p=0.5))
    else:
        t.append(_center_crop())

    t.append(TF.ToTensor())

    return TF.Compose(t)


def _hflip_pair_to_tensor(img):
    """Stack ``img`` and its horizontal flip into one tensor along a new dim 0."""
    to_tensor = TF.functional.to_tensor
    return torch.stack([to_tensor(img), to_tensor(TF.functional.hflip(img))])


def _center_crops(img, crop_sizes):
    """Return one :func:`center_crop_arr` crop of ``img`` per size in ``crop_sizes``."""
    return [center_crop_arr(img, size) for size in crop_sizes]


def _ten_crops(crops, size):
    """Ten-crop every image in ``crops`` at ``size`` and flatten into one list."""
    return [view for crop in crops for view in TF.functional.ten_crop(crop, size)]


def _stack_to_tensor(crops):
    """Convert each PIL crop with ``to_tensor`` and stack them along a new dim 0."""
    return torch.stack([TF.functional.to_tensor(crop) for crop in crops])


def cached_transforms(aug='tencrop', img_size=256, crop_ranges=(1.05, 1.10)):
    """Build a deterministic multi-view pipeline, e.g. for pre-computing caches.

    Args:
        aug: augmentation name. If it contains ``'centercrop'``, the output
            stacks the center crop and its horizontal flip (2 views). Otherwise,
            if it contains ``'tencrop'``, the image is first center-cropped at
            each ``int(img_size * r)`` for ``r`` in ``crop_ranges``; each of
            those is ten-cropped at ``img_size``, giving
            ``10 * len(crop_ranges)`` views.
        img_size: side length of every output view.
        crop_ranges: scale factors (relative to ``img_size``) for the
            intermediate center crops used by ``'tencrop'``. Any iterable of
            numbers is accepted; the default is a tuple so it cannot be
            mutated across calls.

    Returns:
        A ``TF.Compose`` mapping a PIL image to a tensor of stacked views
        along dim 0.

    Raises:
        ValueError: if ``aug`` contains neither ``'centercrop'`` nor ``'tencrop'``.
    """
    import functools

    # Every step is a module-level function or a functools.partial of one,
    # so the pipeline stays picklable (required for 'spawn' DataLoader
    # workers). Local lambdas would not be.
    t = []
    if 'centercrop' in aug:
        t.append(TF.Lambda(functools.partial(center_crop_arr, image_size=img_size)))
        t.append(TF.Lambda(_hflip_pair_to_tensor))
    elif 'tencrop' in aug:
        crop_sizes = tuple(int(img_size * crop_range) for crop_range in crop_ranges)
        t.append(TF.Lambda(functools.partial(_center_crops, crop_sizes=crop_sizes)))
        t.append(TF.Lambda(functools.partial(_ten_crops, size=img_size)))
        t.append(TF.Lambda(_stack_to_tensor))
    else:
        raise ValueError(f"Invalid augmentation: {aug}")

    return TF.Compose(t)

class ImageNet(torchvision.datasets.ImageFolder):
    """``ImageFolder`` over ``root/split`` with this module's preprocessing.

    Args:
        root: dataset root directory; images are read from
            ``os.path.join(root, split)``.
        split: sub-directory to load (e.g. ``'train'``). For non-cached
            ``aug`` values it also selects training vs. evaluation transforms.
        aug: augmentation name. Values containing ``'cache'`` use
            :func:`cached_transforms`, which must also contain ``'centercrop'``
            or ``'tencrop'`` and ignores ``split``. All other values go to
            :func:`vae_transforms`.
        img_size: side length of the square output crops.
    """

    def __init__(self, root, split='train', aug='randcrop', img_size=256):
        super().__init__(osp.join(root, split))
        if 'cache' in aug:
            # NOTE(review): split does not influence the cached pipeline —
            # confirm this is intended for non-train caches.
            self.transform = cached_transforms(aug=aug, img_size=img_size)
        else:
            self.transform = vae_transforms(split, aug=aug, img_size=img_size)