from abc import abstractmethod

import albumentations
import numpy as np
from PIL import Image
from torch.utils.data import Dataset, IterableDataset


class Txt2ImgIterableBaseDataset(IterableDataset):
    '''
    Defines an interface that makes IterableDatasets for text2img data chainable.
    '''
    def __init__(self, num_records=0, valid_ids=None, size=256):
        super().__init__()
        self.num_records = num_records
        self.valid_ids = valid_ids
        self.sample_ids = valid_ids
        self.size = size

        print(f'{self.__class__.__name__} dataset contains {len(self)} examples.')

    def __len__(self):
        return self.num_records

    @abstractmethod
    def __iter__(self):
        pass
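

# A minimal sketch of a concrete subclass, assuming records are held in an
# in-memory list; the class name and the `records` argument are hypothetical,
# shown only to illustrate how `__iter__` is expected to be implemented.
class ListBackedTxt2ImgDataset(Txt2ImgIterableBaseDataset):
    def __init__(self, records, size=256):
        super().__init__(num_records=len(records), size=size)
        self.records = records

    def __iter__(self):
        # Yield one example per record; instances remain chainable
        # via torch.utils.data.ChainDataset.
        for record in self.records:
            yield record
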

class ImagePaths(Dataset):
    def __init__(self, paths, size=None, random_crop=False, horizontalflip=False,
                 random_contrast=False, shiftrotate=False, labels=None,
                 unique_skipped_labels=None):
        self.size = size
        self.random_crop = random_crop

        self.labels = dict() if labels is None else labels
        self.labels["file_path_"] = paths
        self._length = len(paths)

        # Optionally drop every example whose 'class' label appears in
        # unique_skipped_labels, filtering all label lists consistently.
        self.labels_without_skipped = None
        if unique_skipped_labels:
            self.labels_without_skipped = dict()
            for key in self.labels.keys():
                self.labels_without_skipped[key] = [
                    value for idx, value in enumerate(self.labels[key])
                    if self.labels['class'][idx] not in unique_skipped_labels
                ]
            self._length = len(self.labels_without_skipped['class'])

        if self.size is not None and self.size > 0:
            # Rescale the shorter side to `size`, then crop to a square.
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            transforms = [self.rescaler]
            if not self.random_crop:
                self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
            else:
                self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
            transforms.append(self.cropper)
            # Optional augmentations, each applied with a fixed probability.
            if horizontalflip:
                transforms.append(albumentations.HorizontalFlip(p=0.2))
            if shiftrotate:
                transforms.append(albumentations.ShiftScaleRotate(
                    shift_limit=0.2, scale_limit=0.2, rotate_limit=45, border_mode=0,
                    value=(int(0.485 * 255), int(0.456 * 255), int(0.406 * 255)), p=0.3))
            if random_contrast:
                transforms.append(albumentations.RandomBrightnessContrast(p=0.3))
            self.preprocessor = albumentations.Compose(transforms)
        else:
            # No resizing requested: pass images through unchanged.
            self.preprocessor = lambda **kwargs: kwargs

    def __len__(self):
        return self._length

    def preprocess_image(self, image_path):
        image = Image.open(image_path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image = np.array(image).astype(np.uint8)
        image = self.preprocessor(image=image)["image"]
        # Map uint8 [0, 255] to float32 [-1, 1].
        image = (image / 127.5 - 1.0).astype(np.float32)
        return image
    
    def __getitem__(self, i):
        # Use the filtered label dict when skipped classes were removed in __init__.
        labels = self.labels if self.labels_without_skipped is None else self.labels_without_skipped
        example = dict()
        example["image"] = self.preprocess_image(labels["file_path_"][i])
        for k in labels:
            example[k] = labels[k][i]
        return example
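

# A minimal usage sketch, assuming a couple of image files exist on disk;
# the paths and the 'class' labels below are hypothetical placeholders.
if __name__ == "__main__":
    dataset = ImagePaths(
        paths=["images/cat.png", "images/dog.png"],
        size=256,
        random_crop=True,
        horizontalflip=True,
        labels={"class": [0, 1]},
    )
    example = dataset[0]
    print(example["image"].shape, example["class"])  # e.g. (256, 256, 3) 0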