File size: 2,638 Bytes
17191f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#based on https://github.com/CompVis/taming-transformers

import pickle
from torch.utils.data import Dataset
from ldm.data.base import ImagePaths
import ldm.data.constants as CONSTANTS



class CustomBase(Dataset):
    """Minimal ``Dataset`` that delegates sizing and indexing to ``self.data``.

    ``self.data`` starts out as ``None``; it has to be replaced by a sized,
    indexable object before ``len()`` or item access can succeed.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used here.
        super().__init__()
        self.data = None

    def __len__(self):
        """Return the number of examples held by ``self.data``."""
        return len(self.data)

    def __getitem__(self, i):
        """Return the example stored at position ``i`` of ``self.data``, unchanged."""
        return self.data[i]

class CustomTrain(CustomBase):
    """Dataset built from a text file that lists one image path per line.

    The listed paths are sorted and handed to ``ImagePaths`` (with
    ``random_crop=False``). When ``add_labels`` is set, each image is labelled
    with the name of its parent directory (``path.split('/')[-2]``).

    Attributes:
        labels_to_idx: Mapping from label name to integer index, assigned in
            sorted label-name order. Empty when ``add_labels`` is false.
        indx_to_label: Inverse of ``labels_to_idx``.
        data: The ``ImagePaths`` instance that serves the examples.
    """

    def __init__(self, size, training_images_list_file, horizontalflip=False, random_contrast=False, shiftrotate=False, add_labels=False, unique_skipped_labels=None, class_to_node=None):
        """Read the image list and build the underlying ``ImagePaths`` dataset.

        Args:
            size: Forwarded unchanged to ``ImagePaths``.
            training_images_list_file: Path to a text file with one image path
                per line.
            horizontalflip: Forwarded unchanged to ``ImagePaths``.
            random_contrast: Forwarded unchanged to ``ImagePaths``.
            shiftrotate: Forwarded unchanged to ``ImagePaths``.
            add_labels: If true, derive a label for every image from its
                parent directory name. Every path must then contain at least
                one ``/``.
            unique_skipped_labels: Forwarded to ``ImagePaths``; ``None`` (the
                default) is replaced by a fresh empty list.
            class_to_node: Optional path to a pickle file holding a mapping
                from label name to node. Only consulted when ``add_labels``
                is true; adds a ``'class_to_node'`` entry to the labels.

        Raises:
            OSError: If a list or pickle file cannot be opened.
            IndexError: If ``add_labels`` is true and a path has no ``/``.
            KeyError: If a label name is missing from the pickled mapping.
        """
        super().__init__()

        # Use a None sentinel instead of a shared mutable default list.
        if unique_skipped_labels is None:
            unique_skipped_labels = []

        with open(training_images_list_file, "r") as f:
            paths = sorted(f.read().splitlines())

        # Always define the mapping so that the inverse below can be built
        # even when no labels are requested (previously an AttributeError).
        self.labels_to_idx = {}
        labels = None
        if add_labels:
            labels_per_file = [path.split('/')[-2] for path in paths]
            self.labels_to_idx = {
                label_name: i
                for i, label_name in enumerate(sorted(set(labels_per_file)))
            }
            labels = {
                CONSTANTS.DISENTANGLER_CLASS_OUTPUT: [self.labels_to_idx[label_name] for label_name in labels_per_file],
                CONSTANTS.DATASET_CLASSNAME: labels_per_file,
            }

            if class_to_node:
                # NOTE: pickle.load can execute arbitrary code; only point
                # class_to_node at files from a trusted source.
                with open(class_to_node, 'rb') as pickle_file:
                    class_to_node_dict = pickle.load(pickle_file)
                labels['class_to_node'] = [class_to_node_dict[label_name] for label_name in labels_per_file]

        self.indx_to_label = {v: k for k, v in self.labels_to_idx.items()}

        self.data = ImagePaths(paths=paths, size=size, random_crop=False, horizontalflip=horizontalflip,
                               random_contrast=random_contrast, shiftrotate=shiftrotate, labels=labels,
                               unique_skipped_labels=unique_skipped_labels)


class CustomTest(CustomTrain):
    """Test-split counterpart of ``CustomTrain``.

    Reads the image list from ``test_images_list_file`` and forwards only the
    label-related options; ``horizontalflip``, ``random_contrast`` and
    ``shiftrotate`` are not passed, so they keep ``CustomTrain``'s defaults.
    """

    def __init__(self, size, test_images_list_file, add_labels=False, unique_skipped_labels=[], class_to_node=None):
        # Gather the label-related options once, then delegate to the parent.
        label_options = {
            "add_labels": add_labels,
            "unique_skipped_labels": unique_skipped_labels,
            "class_to_node": class_to_node,
        }
        super().__init__(size, test_images_list_file, **label_options)