Colorization / dataset.py
ChiKyi's picture
update file
4d92358
raw
history blame
1.55 kB
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
# Target side length (pixels) for the square (SIZE x SIZE) resize applied by
# ColorizationDataset before the RGB -> L*a*b* conversion.
SIZE = 256
class ColorizationDataset(Dataset):
    """Dataset yielding normalized L*a*b* channel pairs for image colorization.

    Each item is read from disk, converted to RGB, resized to a square
    ``size x size`` image (plus a random horizontal flip for the ``'train'``
    split), converted to L*a*b* with ``skimage.color.rgb2lab`` and returned as
    a dict of two float32 tensors:

    * ``'L'``  -- shape (1, size, size), lightness rescaled as ``L / 50 - 1``;
    * ``'ab'`` -- shape (2, size, size), colour channels rescaled as ``ab / 110``.

    Args:
        paths: Sequence of image file paths; ``__getitem__(idx)`` loads
            ``paths[idx]``.
        split: ``'train'`` (resize + random horizontal flip) or ``'val'``
            (resize only).
        size: Side length of the resized image. ``None`` means use the
            module-level ``SIZE`` constant.

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train', size=None):
        # Validate up front: any other value would leave self.transforms unset
        # and only fail later, inside __getitem__.
        if split not in ('train', 'val'):
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        if size is None:
            size = SIZE
        resize = transforms.Resize(
            (size, size), interpolation=transforms.InterpolationMode.BICUBIC
        )
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        else:
            self.transforms = resize
        self.split = split
        self.size = size
        self.paths = paths

    @staticmethod
    def _lab_to_sample(img_lab):
        """Split an (H, W, 3) L*a*b* array into normalized ``L`` / ``ab`` tensors."""
        # HWC numpy -> CHW float32 tensor (same values ToTensor gives for floats).
        lab = torch.from_numpy(
            np.ascontiguousarray(img_lab, dtype=np.float32)
        ).permute(2, 0, 1)
        L = lab[[0], ...] / 50. - 1.  # rgb2lab L is in [0, 100] -> [-1, 1]
        ab = lab[[1, 2], ...] / 110.  # a*, b* scaled to roughly [-1, 1]
        return {'L': L, 'ab': ab}

    def __getitem__(self, idx):
        # Context manager guarantees the file handle is released; convert()
        # forces a full load into an independent image first.
        with Image.open(self.paths[idx]) as im:
            img = im.convert("RGB")
        img = self.transforms(img)
        img_lab = rgb2lab(np.array(img))  # Converting RGB to L*a*b
        return self._lab_to_sample(img_lab)

    def __len__(self):
        return len(self.paths)
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, shuffle=False, **kwargs):  # A handy function to make our dataloaders
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    Args:
        batch_size: Samples per batch.
        n_workers: Forwarded as ``DataLoader(num_workers=...)``.
        pin_memory: Forwarded as ``DataLoader(pin_memory=...)``.
        shuffle: Forwarded as ``DataLoader(shuffle=...)``. Defaults to
            ``False``; pass ``True`` for training loaders so batches are not
            always drawn in the fixed order of ``paths``.
        **kwargs: Passed unchanged to ``ColorizationDataset`` (e.g. ``paths``,
            ``split``).

    Returns:
        The configured ``DataLoader``.
    """
    dataset = ColorizationDataset(**kwargs)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=n_workers,
        pin_memory=pin_memory,
        shuffle=shuffle,
    )