akhaliq3
New message for the combined commit
833ef7e
raw
history blame contribute delete
998 Bytes
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class CustomBase(Dataset):
    """Thin Dataset wrapper that delegates length and indexing to ``self.data``.

    ``data`` starts out as ``None``; the subclasses in this module assign it
    in their own ``__init__``. Until it is assigned, ``len()`` and indexing
    raise ``TypeError`` because they operate directly on ``None``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the dataset; positional and keyword arguments are ignored."""
        super().__init__()
        # Placeholder for the wrapped dataset object, filled in by subclasses.
        self.data = None

    def __len__(self):
        """Return the number of items reported by the wrapped ``data`` object."""
        return len(self.data)

    def __getitem__(self, i):
        """Return item ``i`` of the wrapped ``data`` object, unmodified."""
        return self.data[i]
class CustomTrain(CustomBase):
    """Training dataset built from a text file that lists image paths."""

    def __init__(self, size, training_images_list_file):
        """Read the path list and wrap it in an ``ImagePaths`` dataset.

        Args:
            size: Forwarded unchanged to ``ImagePaths`` as ``size``.
            training_images_list_file: Path of a text file whose lines are
                used, one per entry, as the ``paths`` given to ``ImagePaths``.
        """
        super().__init__()
        with open(training_images_list_file, "r") as list_file:
            listing = list_file.read()
        # splitlines() drops the line terminators, leaving one entry per line.
        image_paths = listing.splitlines()
        self.data = ImagePaths(paths=image_paths, size=size, random_crop=False)
class CustomTest(CustomBase):
    """Test dataset built from a text file that lists image paths."""

    def __init__(self, size, test_images_list_file):
        """Read the path list and wrap it in an ``ImagePaths`` dataset.

        Args:
            size: Forwarded unchanged to ``ImagePaths`` as ``size``.
            test_images_list_file: Path of a text file whose lines are used,
                one per entry, as the ``paths`` given to ``ImagePaths``.
        """
        super().__init__()
        with open(test_images_list_file, "r") as list_file:
            listing = list_file.read()
        # splitlines() drops the line terminators, leaving one entry per line.
        image_paths = listing.splitlines()
        self.data = ImagePaths(paths=image_paths, size=size, random_crop=False)