File size: 1,405 Bytes
74c7bcf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import numpy as np
import cv2
from torchvision import datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .generic import MyDataSet


class AlbumentationsCIFAR10(datasets.CIFAR10):
    """torchvision CIFAR-10 dataset with an optional Albumentations pipeline.

    Each sample produced by the base class is converted to a NumPy array
    and passed to the Albumentations callable, whose ``'image'`` output
    replaces the original image. The label is returned unchanged.
    """

    def __init__(self, root, alb_transform=None, **kwargs):
        """Build the dataset.

        Args:
            root: Dataset root directory, forwarded to ``datasets.CIFAR10``.
            alb_transform: Optional callable invoked as
                ``alb_transform(image=array)``; it must return a mapping with
                an ``'image'`` key. ``None`` disables the extra transform.
            **kwargs: Forwarded unchanged to ``datasets.CIFAR10``.
        """
        super().__init__(root, **kwargs)
        self.alb_transform = alb_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, augmented if a pipeline is set."""
        sample, target = super().__getitem__(index)
        if self.alb_transform is None:
            return sample, target
        # Albumentations operates on NumPy arrays, not PIL images.
        augmented = self.alb_transform(image=np.array(sample))
        return augmented['image'], target


class cifar10_dataset(MyDataSet):
    """CIFAR-10 configuration for ``MyDataSet``.

    Supplies the dataset class, per-channel normalisation statistics and
    the default list of Albumentations training transforms.
    """

    DataSet = AlbumentationsCIFAR10
    # Per-channel (R, G, B) statistics, expressed on a 0-1 scale.
    mean = (0.49139968, 0.48215827, 0.44653124)
    std = (0.24703233, 0.24348505, 0.26158768)
    default_alb_transforms = [
        # Mirror roughly half of the images. The previous p=1.0 flipped every
        # image, which adds no diversity and just trains on a mirrored dataset.
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=(-0.2, 0.2),
            scale_limit=(-0.2, 0.2),
            rotate_limit=(-15, 15),
            p=0.5,
        ),
        # Pad to at least 36x36, then take a random 32x32 window: a small
        # random translation (assumes 32x32 CIFAR-10 inputs). A CenterCrop(32, 32)
        # that used to follow RandomCrop was a no-op on a 32x32 image and
        # has been removed.
        A.PadIfNeeded(min_height=36, min_width=36, p=1.0),
        A.RandomCrop(32, 32, p=1.0),
        # Cut out a single 8x8 patch with probability 0.5.
        # NOTE(review): fill_value is on a 0-1 scale. If this pipeline runs on
        # the uint8 array built by AlbumentationsCIFAR10 (before any Normalize),
        # these values truncate to 0 and the patch is black; if it runs after
        # Normalize, the dataset mean would be 0 instead. Confirm the transform
        # order in MyDataSet and adjust (e.g. tuple(255 * m for m in mean)).
        A.CoarseDropout(
            max_holes=1, max_height=8, max_width=8,
            min_holes=1, min_height=8, min_width=8,
            fill_value=(0.4914, 0.4822, 0.4465),
            p=0.5,
        ),
    ]