|
import importlib |
|
import torch.utils.data |
|
from data.base_data_loader import BaseDataLoader |
|
from data.base_dataset import BaseDataset |
|
|
|
|
|
def find_dataset_using_name(dataset_name):
    """Return the dataset class registered under ``dataset_name``.

    Imports the module ``data.<dataset_name>_dataset`` and searches its
    namespace for a class that (a) subclasses ``BaseDataset`` and (b) has a
    name equal, case-insensitively, to ``dataset_name`` with underscores
    removed plus the suffix ``dataset`` (so ``"foo_bar"`` matches a class
    named ``FooBarDataset``).

    Args:
        dataset_name: dataset mode string (e.g. the value of
            ``opt.dataset_mode``).

    Returns:
        The matching class itself, not an instance. If several attributes
        match, the last one seen in the module namespace wins.

    Raises:
        ImportError: (e.g. ModuleNotFoundError) if ``data.<dataset_name>_dataset``
            cannot be imported.
        NotImplementedError: if the module defines no matching
            ``BaseDataset`` subclass.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # isinstance(cls, type) guard: issubclass() raises TypeError when given
        # a non-class attribute (function, constant, submodule) whose name
        # happens to match the target.
        if (name.lower() == target_dataset_name.lower()
                and isinstance(cls, type)
                and issubclass(cls, BaseDataset)):
            dataset = cls

    if dataset is None:
        # Raise rather than exit(0): a zero exit status reported success for a
        # configuration error, and the `exit` builtin is injected by `site`
        # for interactive use (missing under `python -S`).
        raise NotImplementedError(
            "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
            % (dataset_filename, target_dataset_name))

    return dataset
|
|
|
|
|
def get_option_setter(dataset_name):
    """Look up the ``modify_commandline_options`` attribute of a dataset class.

    The dataset class is resolved with ``find_dataset_using_name`` and the
    attribute is read from the class itself (no instance is created).
    Presumably this is a static/class method that registers dataset-specific
    command-line options — NOTE(review): confirm against BaseDataset.

    Args:
        dataset_name: dataset mode string used to locate the class.

    Returns:
        Whatever the resolved class exposes as ``modify_commandline_options``.
    """
    return find_dataset_using_name(dataset_name).modify_commandline_options
|
|
|
|
|
def create_dataset(opt):
    """Build and initialize the dataset selected by ``opt.dataset_mode``.

    The class is resolved via ``find_dataset_using_name``, constructed with
    no arguments, and then configured by calling ``initialize(opt)`` on it.
    A confirmation line containing the dataset's ``name()`` is printed.

    Args:
        opt: option object; must provide ``dataset_mode`` plus whatever the
            chosen dataset's ``initialize`` reads.

    Returns:
        The initialized dataset instance.
    """
    dataset_instance = find_dataset_using_name(opt.dataset_mode)()
    dataset_instance.initialize(opt)
    print("dataset [%s] was created" % dataset_instance.name())
    return dataset_instance
|
|
|
|
|
def CreateDataLoader(opt):
    """Return a ``CustomDatasetDataLoader`` already initialized with ``opt``.

    Args:
        opt: option object forwarded unchanged to
            ``CustomDatasetDataLoader.initialize``.

    Returns:
        The initialized loader wrapper.
    """
    loader = CustomDatasetDataLoader()
    loader.initialize(opt)
    return loader
|
|
|
|
|
|
|
|
|
class CustomDatasetDataLoader(BaseDataLoader):
    """Wrap the dataset from ``create_dataset`` in a torch ``DataLoader``.

    Iterating the wrapper yields batches from the underlying DataLoader and
    stops once (batches already yielded) * ``opt.batch_size`` reaches
    ``opt.max_dataset_size``.
    """

    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the torch DataLoader that batches it.

        Args:
            opt: option object; reads ``batch_size``, ``serial_batches``
                (True disables shuffling) and ``num_threads`` (worker count)
                here, plus whatever ``BaseDataLoader.initialize`` and
                ``create_dataset`` read.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = create_dataset(opt)
        shuffle_batches = not opt.serial_batches
        worker_count = int(opt.num_threads)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batch_size,
            shuffle=shuffle_batches,
            num_workers=worker_count,
        )

    def load_data(self):
        # The wrapper itself is the iterable handed to callers.
        return self

    def __len__(self):
        # Counts samples (dataset items), not batches, capped at
        # opt.max_dataset_size. self.opt is presumably set by
        # BaseDataLoader.initialize — verify.
        item_count = len(self.dataset)
        cap = self.opt.max_dataset_size
        return min(item_count, cap)

    def __iter__(self):
        # The cap is checked before each batch is yielded, so the final batch
        # may carry the total past max_dataset_size.
        yielded = 0
        for batch in self.dataloader:
            if yielded * self.opt.batch_size >= self.opt.max_dataset_size:
                return
            yielded += 1
            yield batch
|
|