Spaces:
Sleeping
Sleeping
File size: 2,638 Bytes
17191f4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
#based on https://github.com/CompVis/taming-transformers
import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS
class CustomBase(Dataset):
    """Minimal ``Dataset`` that forwards length and indexing to ``self.data``.

    This base class leaves ``self.data`` as ``None``; subclasses are expected
    to assign a sized, indexable object to it before the dataset is used.
    Any positional or keyword arguments given to the constructor are ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Placeholder until a subclass assigns the real backing collection.
        self.data = None

    def __len__(self):
        """Return ``len(self.data)``."""
        return len(self.data)

    def __getitem__(self, i):
        """Return ``self.data[i]`` unchanged."""
        return self.data[i]
class CustomTrain(CustomBase):
    """Image dataset built from a text file that lists one image path per line.

    Args:
        size: Image size, forwarded to ``ImagePaths``.
        training_images_list_file: Text file with one image path per line.
            Blank lines are ignored; the remaining paths are sorted.
        horizontalflip: Forwarded to ``ImagePaths``.
        random_contrast: Forwarded to ``ImagePaths``.
        shiftrotate: Forwarded to ``ImagePaths``.
        add_labels: If true, derive a class name for every image from its
            parent directory (the second-to-last ``/``-separated component
            of its path) and pass a labels dict to ``ImagePaths``.
        unique_skipped_labels: Forwarded to ``ImagePaths``; ``None`` (the
            default) means a fresh empty list for this instance.
        class_to_node: Optional path to a pickle file containing a dict keyed
            by class name; only read when ``add_labels`` is true.

    Attributes set only when ``add_labels`` is true:
        labels_to_idx: class name -> integer index, assigned in sorted order.
        indx_to_label: inverse mapping of ``labels_to_idx``.
    """

    def __init__(self, size, training_images_list_file, horizontalflip=False,
                 random_contrast=False, shiftrotate=False, add_labels=False,
                 unique_skipped_labels=None, class_to_node=None):
        super().__init__()
        if unique_skipped_labels is None:
            # A fresh list per instance: a shared ``[]`` default would leak
            # state between datasets if anything downstream mutates it.
            unique_skipped_labels = []

        paths = self._read_paths(training_images_list_file)
        labels = self._build_labels(paths, class_to_node) if add_labels else None

        self.data = ImagePaths(paths=paths, size=size, random_crop=False,
                               horizontalflip=horizontalflip,
                               random_contrast=random_contrast,
                               shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)

    @staticmethod
    def _read_paths(list_file):
        """Return the non-blank lines of *list_file*, sorted."""
        with open(list_file, "r") as f:
            # Skip blank/whitespace-only lines so they never become bogus
            # image paths (they would also break the parent-dir label split).
            return sorted(line for line in f.read().splitlines() if line.strip())

    def _build_labels(self, paths, class_to_node):
        """Build the per-image labels dict and set ``labels_to_idx``/``indx_to_label``."""
        # Class name of an image = name of the directory that contains it.
        labels_per_file = [path.split('/')[-2] for path in paths]
        self.labels_to_idx = {
            label_name: i for i, label_name in enumerate(sorted(set(labels_per_file)))
        }
        self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        labels = {
            CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[name] for name in labels_per_file],
            CONSTANTS.DATASET_CLASSNAME: labels_per_file,
        }
        if class_to_node:
            # SECURITY NOTE: pickle.load can execute arbitrary code while
            # loading; only pass class_to_node files from a trusted source.
            with open(class_to_node, 'rb') as pickle_file:
                class_to_node_dict = pickle.load(pickle_file)
            labels['class_to_node'] = [class_to_node_dict[name] for name in labels_per_file]
        return labels
class CustomTest(CustomTrain):
    """Test-split counterpart of ``CustomTrain``.

    The augmentation flags (``horizontalflip``, ``random_contrast``,
    ``shiftrotate``) are not forwarded, so ``CustomTrain``'s defaults (all
    ``False``) apply.

    Args:
        size: Image size, forwarded to ``CustomTrain``.
        test_images_list_file: Text file listing one image path per line.
        add_labels: Forwarded to ``CustomTrain``.
        unique_skipped_labels: Forwarded to ``CustomTrain``; ``None`` (the
            default) means a fresh empty list for this instance.
        class_to_node: Forwarded to ``CustomTrain``.
    """

    def __init__(self, size, test_images_list_file, add_labels=False,
                 unique_skipped_labels=None, class_to_node=None):
        if unique_skipped_labels is None:
            # Avoid a shared mutable ``[]`` default across instances.
            unique_skipped_labels = []
        super().__init__(size, test_images_list_file, add_labels=add_labels,
                         unique_skipped_labels=unique_skipped_labels,
                         class_to_node=class_to_node)
|