Shadhil's picture
Upload 425 files
1a79cb6 verified
import torch
import os
from PIL import Image
import random
import numpy as np
import pickle
import torchvision.transforms as transforms
from .celeba import CelebADataset
def create_dataloader(opt):
    """Build a ``DataLoader`` wrapper and initialize it from *opt*.

    Args:
        opt: Options object; it is passed unchanged to
            ``DataLoader.initialize``.

    Returns:
        The initialized ``DataLoader`` instance.
    """
    loader = DataLoader()
    loader.initialize(opt)
    return loader
class DataLoader:
    """Wrap a dataset in a ``torch.utils.data.DataLoader`` with a size cap.

    The instance is configured through :meth:`initialize`, which stores the
    options object, builds the dataset and creates the underlying torch
    loader.  Iteration and ``len()`` both honor ``opt.max_dataset_size``.
    """

    def name(self):
        """Return the wrapped dataset's name with a ``_Loader`` suffix."""
        return self.dataset.name() + "_Loader"

    @staticmethod
    def _dataset_key(data_root):
        """Return the lower-cased last path component of *data_root*."""
        # normpath drops trailing separators and "." components for the
        # current OS, so "datasets/CelebA/" still yields "celeba" instead of
        # an empty basename.
        return os.path.basename(os.path.normpath(data_root)).lower()

    def create_datase(self):
        """Instantiate and initialize the dataset selected by ``opt.data_root``.

        The last component of ``opt.data_root`` (case-insensitive) picks the
        dataset class: names containing ``'celeba'`` or ``'emotion'`` use
        ``CelebADataset``; anything else uses ``BaseDataset``.

        The misspelled method name is kept as-is because it is part of the
        class's existing interface (``initialize`` calls it by this name).

        Returns:
            The dataset instance after ``dataset.initialize(self.opt)``.
        """
        # specify which dataset to load here
        dataset_key = self._dataset_key(self.opt.data_root)
        if 'celeba' in dataset_key or 'emotion' in dataset_key:
            dataset = CelebADataset()
        else:
            # NOTE(review): BaseDataset is not imported in the visible import
            # block of this file -- confirm it is in scope, otherwise this
            # branch raises NameError.
            dataset = BaseDataset()
        dataset.initialize(self.opt)
        return dataset

    def initialize(self, opt):
        """Store *opt*, build the dataset and create the torch loader.

        Args:
            opt: Options object.  This class reads ``data_root``,
                ``batch_size``, ``serial_batches``, ``n_threads`` and
                ``max_dataset_size`` from it.
        """
        self.opt = opt
        self.dataset = self.create_datase()
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            # serial_batches means "keep dataset order", i.e. no shuffling.
            shuffle=not opt.serial_batches,
            num_workers=int(opt.n_threads),
        )

    def __len__(self):
        """Return the number of samples, capped at ``opt.max_dataset_size``."""
        # len() must return an int; coerce in case the cap is a finite float.
        # An infinite cap never wins the min(), so int() is always safe here.
        return int(min(len(self.dataset), self.opt.max_dataset_size))

    def __iter__(self):
        """Yield batches until ``opt.max_dataset_size`` samples are covered.

        The check happens before each batch is yielded, so the final batch
        may extend past the cap when the cap is not a multiple of
        ``opt.batch_size``.
        """
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data