# Source snapshot: author yyk19, commit 0902a5f ("first trial"), 2 kB.
from abc import abstractmethod
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
from PIL import Image, ImageFile
from pathlib import Path
from functools import partial
from torchvision import transforms as T, utils
from torch import nn
def exists(val):
    """Return True when *val* is anything other than ``None``.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only ``None`` is treated as absent.
    """
    return val is not None
def cycle(dl):
    """Yield items from *dl* endlessly, re-iterating it after each pass.

    Unlike ``itertools.cycle``, items are not cached: ``iter(dl)`` is called
    again for every pass, so an iterable that changes per pass (for example
    a shuffling data loader) produces fresh data each time.

    If a full pass yields nothing (an empty iterable, or a one-shot iterator
    that is already exhausted), the generator stops, matching
    ``itertools.cycle``. Without this check it would loop forever without
    yielding and the caller's ``next()`` would hang.
    """
    while True:
        produced = False
        for data in dl:
            produced = True
            yield data
        if not produced:
            # Nothing came out of this pass; another pass cannot make progress.
            return
def convert_image_to(img_type, image):
    """Return *image* in PIL mode *img_type*, converting only when needed.

    When ``image.mode`` already equals *img_type*, the very same object is
    returned. Otherwise the result of ``image.convert(img_type)`` is returned.
    """
    if image.mode == img_type:
        return image
    return image.convert(img_type)
class Txt2ImgIterableBaseDataset(IterableDataset):
    """Abstract base for chainable text-to-image ``IterableDataset`` classes.

    Subclasses must implement ``__iter__``. The constructor only stores
    bookkeeping values for subclasses to use.

    Args:
        num_records: Stored as ``self.num_records`` (default ``0``).
        valid_ids: Stored as both ``self.valid_ids`` and ``self.sample_ids``.
            Both attributes refer to the same object; no copy is made.
        size: Stored as ``self.size`` (default ``256``). Presumably this is
            the target image resolution; confirm against subclasses.
    """

    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        # Both names alias the same object. Presumably subclasses rebind
        # sample_ids independently; verify against callers.
        self.valid_ids = self.sample_ids = valid_ids
        self.size = size
        # NOTE: no __len__ is defined on this base class. An implementation
        # returning self.num_records existed upstream but was commented out.

    @abstractmethod
    def __iter__(self):
        """Yield samples; subclasses must override this."""
class BaseDataset(Dataset):
    """Map-style dataset over every matching image file found under *folder*.

    Files are discovered recursively (``**/*.{ext}``) for each extension in
    *exts*. Each item is opened with PIL and then passed through, in order:
      1. an optional mode conversion (``convert_image_to_type``),
      2. ``T.Resize(image_size)``,
      3. ``T.RandomHorizontalFlip()``,
      4. ``T.CenterCrop(image_size)``,
      5. ``T.ToTensor()``.

    Args:
        folder: Root directory to search. Anything accepted by ``Path``
            after f-string formatting works, such as a ``str`` or ``Path``.
        image_size: Passed to both ``T.Resize`` and ``T.CenterCrop``.
        exts: File extensions to match. Matching uses the platform's glob
            case rules, so on case-sensitive filesystems ``JPG`` is not
            matched by ``jpg``.
        convert_image_to_type: Optional PIL mode (e.g. ``'RGB'``). When
            given, images whose mode differs are converted before the
            other transforms run.
    """

    def __init__(
        self,
        folder,
        image_size,
        exts=('jpg', 'jpeg', 'png', 'tiff'),  # tuple: immutable default
        convert_image_to_type=None,
    ):
        super().__init__()
        self.folder = folder
        self.image_size = image_size
        # Sort so the index -> file mapping does not depend on the order in
        # which the filesystem happens to enumerate directory entries.
        self.paths = sorted(
            p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
        )

        # nn.Identity is used as a picklable no-op when no conversion is
        # requested, so the transform can be sent to DataLoader workers.
        convert_fn = (
            partial(convert_image_to, convert_image_to_type)
            if exists(convert_image_to_type)
            else nn.Identity()
        )

        self.transform = T.Compose([
            T.Lambda(convert_fn),
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.CenterCrop(image_size),
            T.ToTensor(),
        ])

    def __len__(self):
        """Return the number of image files discovered."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load, transform and return the image at position *index*."""
        path = self.paths[index]
        # PIL opens files lazily. The transform forces the pixel data to
        # load, so closing the file once it has run is safe and avoids
        # leaking a file handle per item.
        with Image.open(path) as img:
            return self.transform(img)