import numpy as np
import cv2
from torchvision import datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .generic import MyDataSet


class AlbumentationsCIFAR10(datasets.CIFAR10):
    """torchvision ``CIFAR10`` that can additionally run an Albumentations pipeline.

    Args:
        root: Dataset root directory, forwarded to ``datasets.CIFAR10``.
        alb_transform: Optional callable invoked as ``alb_transform(image=arr)``
            on the sample converted with ``np.array``; the ``'image'`` entry of
            its result replaces the sample. ``None`` skips this step.
        **kwargs: Forwarded unchanged to ``datasets.CIFAR10``.
    """

    def __init__(self, root, alb_transform=None, **kwargs):
        super().__init__(root, **kwargs)
        self.alb_transform = alb_transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*.

        The base class fetches the sample (applying any torchvision
        ``transform``/``target_transform`` it was constructed with); the image
        is then converted to a numpy array and passed through
        ``alb_transform`` when one is set.
        """
        image, label = super().__getitem__(index)
        if self.alb_transform is not None:
            image = self.alb_transform(image=np.array(image))["image"]
        return image, label


class cifar10_dataset(MyDataSet):
    """CIFAR-10 configuration for ``MyDataSet`` with default Albumentations augmentations."""

    # Dataset class for this configuration; presumably instantiated by
    # MyDataSet (not visible here) -- TODO confirm.
    DataSet = AlbumentationsCIFAR10
    # Per-channel (RGB) statistics on the [0, 1] pixel scale; presumably used
    # for normalization by MyDataSet -- TODO confirm.
    mean = (0.49139968, 0.48215827, 0.44653124)
    std = (0.24703233, 0.24348505, 0.26158768)
    default_alb_transforms = [
        # Mirror a random half of the samples. p=1.0 flipped every image
        # deterministically, which adds no augmentation diversity.
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=(-0.2, 0.2),
            scale_limit=(-0.2, 0.2),
            rotate_limit=(-15, 15),
            p=0.5,
        ),
        # Pad to 36x36, then take a random 32x32 crop: a random translation of
        # up to 4 px. (A trailing CenterCrop(32, 32) was a no-op and is gone.)
        A.PadIfNeeded(min_height=36, min_width=36, p=1.0),
        A.RandomCrop(32, 32, p=1.0),
        # Cut out a single 8x8 patch filled with the dataset mean. The image
        # reaching alb_transform comes from np.array(PIL image), i.e. uint8 on
        # the 0-255 scale, so the fill must be on that scale too; the former
        # (0.4914, ...) fill truncated to 0 (black). Assumes these transforms
        # run before any Normalize step -- TODO confirm against MyDataSet.
        A.CoarseDropout(
            max_holes=1,
            max_height=8,
            max_width=8,
            min_holes=1,
            min_height=8,
            min_width=8,
            fill_value=tuple(round(255 * m) for m in mean),
            p=0.5,
        ),
    ]