import torch import torchvision import numpy as np import os.path as osp from PIL import Image import torchvision import torchvision.transforms as TF def pair(t): return t if isinstance(t, tuple) else (t, t) def center_crop_arr(pil_image, image_size): """ Center cropping implementation from ADM. https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 """ while min(*pil_image.size) >= 2 * image_size: pil_image = pil_image.resize( tuple(x // 2 for x in pil_image.size), resample=Image.BOX ) scale = image_size / min(*pil_image.size) pil_image = pil_image.resize( tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC ) arr = np.array(pil_image) crop_y = (arr.shape[0] - image_size) // 2 crop_x = (arr.shape[1] - image_size) // 2 return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) def vae_transforms(split, aug='randcrop', img_size=256): t = [] if split == 'train': if aug == 'randcrop': t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True)) t.append(TF.RandomCrop(img_size)) elif aug == 'centercrop': t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size))) else: raise ValueError(f"Invalid augmentation: {aug}") t.append(TF.RandomHorizontalFlip(p=0.5)) else: t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size))) t.append(TF.ToTensor()) return TF.Compose(t) def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]): t = [] if 'centercrop' in aug: t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size))) t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))]))) elif 'tencrop' in aug: crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges] t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes])) t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple])) t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops]))) else: raise ValueError(f"Invalid augmentation: {aug}") return TF.Compose(t) class ImageNet(torchvision.datasets.ImageFolder): def __init__(self, root, split='train', aug='randcrop', img_size=256): super().__init__(osp.join(root, split)) if not 'cache' in aug: self.transform = vae_transforms(split, aug=aug, img_size=img_size) else: self.transform = cached_transforms(aug=aug, img_size=img_size)