Spaces:
Runtime error
Runtime error
File size: 1,316 Bytes
7fab858 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile
def CreateDataset(opt):
    """Build and initialize the dataset selected by ``opt.training_dataset``.

    Selection rules:
      * ``'domain_A'`` or ``'domain_B'`` -> ``dts_ray_bigfile.UnPairOldPhotos_SR``
      * ``'mapping'`` with a truthy ``opt.random_hole``
        -> ``dts_ray_bigfile.PairOldPhotos_with_hole``
      * ``'mapping'`` otherwise -> ``dts_ray_bigfile.PairOldPhotos``

    Args:
        opt: options object; this function reads ``opt.training_dataset`` and,
            for the ``'mapping'`` case, ``opt.random_hole``. The whole object
            is then forwarded to ``dataset.initialize``.

    Returns:
        The created dataset, after ``initialize(opt)`` has been called on it.

    Raises:
        ValueError: if ``opt.training_dataset`` is not one of the supported
            values listed above.
    """
    dataset = None
    if opt.training_dataset in ('domain_A', 'domain_B'):
        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    elif opt.training_dataset == 'mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()

    # Fail with an explicit message instead of the opaque
    # "'NoneType' object has no attribute 'name'" that used to surface below.
    if dataset is None:
        raise ValueError(
            "unknown training_dataset %r; expected 'domain_A', 'domain_B' "
            "or 'mapping'" % (opt.training_dataset,))

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset
class CustomDatasetDataLoader(BaseDataLoader):
    """Data loader that wraps the dataset from ``CreateDataset`` in a
    ``torch.utils.data.DataLoader``."""

    def name(self):
        """Return the identifying name of this loader."""
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        """Create the dataset and the torch ``DataLoader`` around it.

        Args:
            opt: options object. Read here: ``batchSize``, ``serial_batches``
                and ``nThreads``. It is also passed to
                ``BaseDataLoader.initialize`` and ``CreateDataset``.
        """
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)

        batch_size = opt.batchSize
        # Shuffle unless batches were requested in serial order.
        shuffle_samples = not opt.serial_batches
        worker_count = int(opt.nThreads)

        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=batch_size,
            shuffle=shuffle_samples,
            num_workers=worker_count,
            # Incomplete trailing batches are always dropped.
            drop_last=True,
        )

    def load_data(self):
        """Return the ``DataLoader`` built by ``initialize``."""
        return self.dataloader

    def __len__(self):
        """Return the dataset length, capped at ``opt.max_dataset_size``."""
        # NOTE(review): self.opt is not assigned in this class; presumably
        # BaseDataLoader.initialize stores it -- confirm.
        dataset_size = len(self.dataset)
        size_cap = self.opt.max_dataset_size
        return min(dataset_size, size_cap)
|