Spaces:
Sleeping
Sleeping
File size: 1,545 Bytes
4d92358 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import numpy as np
from PIL import Image
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
# Side length in pixels: ColorizationDataset resizes every image to SIZE x SIZE
# before converting it to L*a*b*.
SIZE = 256
class ColorizationDataset(Dataset):
    """Dataset yielding lightness / colour channel pairs in L*a*b* space.

    Each item is the RGB image at ``paths[idx]``, resized to ``SIZE x SIZE``
    (and randomly flipped horizontally for the ``'train'`` split), converted
    to L*a*b*, and rescaled: the ``L`` channel maps 0..100 to -1..1 and the
    ``a``/``b`` channels are divided by 110.

    Args:
        paths: Indexable sequence of image file paths.
        split: ``'train'`` (resize + random horizontal flip) or ``'val'``
            (resize only).

    Raises:
        ValueError: If ``split`` is neither ``'train'`` nor ``'val'``.
    """

    def __init__(self, paths, split='train'):
        resize = transforms.Resize((SIZE, SIZE), Image.BICUBIC)
        if split == 'train':
            self.transforms = transforms.Compose([
                resize,
                transforms.RandomHorizontalFlip(),  # A little data augmentation!
            ])
        elif split == 'val':
            self.transforms = resize
        else:
            # Without this, ``self.transforms`` would stay unset and the error
            # would only surface later as an AttributeError in __getitem__.
            raise ValueError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        """Return ``{'L': (1, SIZE, SIZE) tensor, 'ab': (2, SIZE, SIZE) tensor}``."""
        # ``with`` closes the underlying file handle deterministically;
        # ``convert`` returns an independent, fully loaded image.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")  # Converting RGB to L*a*b
        # ToTensor reorders an HxWxC array to CxHxW; float input is not rescaled.
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.  # L in 0..100 -> between -1 and 1
        # Roughly between -1 and 1 (assumes |a|, |b| <= 110 for sRGB input).
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}

    def __len__(self):
        """Return the number of image paths."""
        return len(self.paths)
def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, *, shuffle=False, **kwargs):  # A handy function to make our dataloaders
    """Build a ``DataLoader`` over a ``ColorizationDataset``.

    Args:
        batch_size: Number of samples per batch.
        n_workers: Forwarded to ``DataLoader(num_workers=...)``.
        pin_memory: Forwarded to ``DataLoader(pin_memory=...)``.
        shuffle: Forwarded to ``DataLoader(shuffle=...)``; ``True`` reshuffles
            the data every epoch. Defaults to ``False`` (the original
            behaviour); pass ``True`` for the training split.
        **kwargs: Forwarded to ``ColorizationDataset`` (e.g. ``paths``, ``split``).

    Returns:
        A ``DataLoader`` over the constructed dataset.
    """
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=n_workers,
        pin_memory=pin_memory,
    )
    return dataloader