import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

SIZE = 256


class ColorizationDataset(Dataset):
    """Image dataset yielding normalized L*a*b* tensors for colorization.

    Each item is a dict with:
      * ``'L'``  -- lightness channel, shape (1, SIZE, SIZE), scaled to [-1, 1]
      * ``'ab'`` -- a/b colour channels, shape (2, SIZE, SIZE), divided by 110
        so that they fall roughly within [-1, 1]

    Args:
        paths: sequence of image file paths; item ``idx`` loads ``paths[idx]``.
        split: ``'train'`` (resize + random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: if ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        # torchvision's Resize expects an InterpolationMode; passing the PIL
        # constant (Image.BICUBIC) is deprecated and emits a warning. For PIL
        # inputs both select the same bicubic resampling.
        resize = transforms.Resize((SIZE, SIZE), transforms.InterpolationMode.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Fail fast: otherwise self.transforms would be unset and the error
            # would only surface later, inside __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Load image ``idx`` and return its normalized ``{'L', 'ab'}`` tensors."""
        # The context manager releases the file handle deterministically;
        # convert() fully loads the pixels, so `img` stays valid after the block.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        img = self.transforms(img)
        img_lab = rgb2lab(np.array(img)).astype("float32")  # Converting RGB to L*a*b
        # ToTensor turns the HWC ndarray into a CHW tensor; float input is not rescaled.
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # L in [0, 100] -> [-1, 1]
        ab = img_lab[[1, 2], ...] / 110.  # a, b scaled to roughly [-1, 1]
        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)


def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, *, shuffle=False, **kwargs):
    """Build a DataLoader over a ColorizationDataset.

    Args:
        batch_size: samples per batch, forwarded to DataLoader.
        n_workers: forwarded to DataLoader as ``num_workers``.
        pin_memory: forwarded to DataLoader.
        shuffle: forwarded to DataLoader. Defaults to False (sequential order,
            as before); pass True for a training loader.
        **kwargs: forwarded to ColorizationDataset (e.g. ``paths=...``, ``split=...``).

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ColorizationDataset(**kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )