Files changed (2) hide show
  1. cifar10.py +37 -0
  2. generic.py +111 -0
cifar10.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ from torchvision import datasets
4
+ import albumentations as A
5
+ from albumentations.pytorch import ToTensorV2
6
+
7
+ from .generic import MyDataSet
8
+
9
+
10
+ class AlbumentationsCIFAR10(datasets.CIFAR10):
11
+ def __init__(self, root, alb_transform=None, **kwargs):
12
+ super(AlbumentationsCIFAR10, self).__init__(root, **kwargs)
13
+ self.alb_transform = alb_transform
14
+
15
+ def __getitem__(self, index):
16
+ image, label = super(AlbumentationsCIFAR10, self).__getitem__(index)
17
+ if self.alb_transform is not None:
18
+ image = self.alb_transform(image=np.array(image))['image']
19
+ return image, label
20
+
21
+
22
+ class cifar10_dataset(MyDataSet):
23
+ DataSet = AlbumentationsCIFAR10
24
+ mean = (0.49139968, 0.48215827, 0.44653124)
25
+ std = (0.24703233, 0.24348505, 0.26158768)
26
+ default_alb_transforms = [
27
+ A.HorizontalFlip(p=1.0),
28
+ A.ShiftScaleRotate(shift_limit=(-0.2, 0.2), scale_limit=(-0.2, 0.2), rotate_limit=(-15, 15), p=0.5),
29
+ A.PadIfNeeded(min_height=36, min_width=36, p=1.0),
30
+ A.RandomCrop (32, 32, always_apply=False, p=1.0),
31
+ A.CenterCrop(32, 32, always_apply=False, p=1.0),
32
+ A.CoarseDropout(max_holes = 1, max_height=8, max_width=8, min_holes = 1,
33
+ min_height=8,min_width=8,
34
+ fill_value=(0.4914, 0.4822, 0.4465), always_apply=False,p=0.5),
35
+
36
+
37
+ ]
generic.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC
3
+ from functools import cached_property
4
+
5
+ import torch
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+
9
+ try:
10
+ from epoch.utils import plot_examples
11
+ except ModuleNotFoundError:
12
+ from utils import plot_examples
13
+
14
+
15
+ class MyDataSet(ABC):
16
+ DataSet = None
17
+ mean = None
18
+ std = None
19
+ classes = None
20
+ default_alb_transforms = None
21
+
22
+ def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
23
+ self.batch_size = batch_size
24
+ self.normalize = normalize
25
+ self.shuffle = shuffle
26
+ self.augment = augment
27
+ self.alb_transforms = alb_transforms or self.default_alb_transforms
28
+
29
+ self.loader_kwargs = {'batch_size': batch_size, 'num_workers': os.cpu_count(), 'pin_memory': True}
30
+
31
+ @classmethod
32
+ def set_classes(cls, data):
33
+ if cls.classes is None:
34
+ cls.classes = {i: c for i, c in enumerate(data.classes)}
35
+
36
+ @cached_property
37
+ def train_data(self):
38
+ res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
39
+ self.set_classes(res)
40
+ return res
41
+
42
+ @cached_property
43
+ def test_data(self):
44
+ res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
45
+ self.set_classes(res)
46
+ return res
47
+
48
+ @cached_property
49
+ def train_loader(self):
50
+ return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)
51
+
52
+ @cached_property
53
+ def test_loader(self):
54
+ return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)
55
+
56
+ @cached_property
57
+ def example_iter(self):
58
+ return iter(self.train_loader)
59
+
60
+ def get_train_transforms(self):
61
+ all_transforms = list()
62
+ if self.normalize:
63
+ all_transforms.append(A.Normalize(self.mean, self.std))
64
+ if self.augment and self.alb_transforms is not None:
65
+ all_transforms.extend(self.alb_transforms)
66
+ all_transforms.append(ToTensorV2())
67
+ return A.Compose(all_transforms)
68
+
69
+ def get_test_transforms(self):
70
+ all_transforms = list()
71
+ if self.normalize:
72
+ all_transforms.append(A.Normalize(self.mean, self.std))
73
+ all_transforms.append(ToTensorV2())
74
+ return A.Compose(all_transforms)
75
+
76
+ def download(self):
77
+ self.DataSet('../data', train=True, download=True)
78
+ self.DataSet('../data', train=False, download=True)
79
+
80
+ def denormalise(self, tensor):
81
+ result = tensor.clone().detach().requires_grad_(False)
82
+ if self.normalize:
83
+ for t, m, s in zip(result, self.mean, self.std):
84
+ t.mul_(s).add_(m)
85
+ return result
86
+
87
+ def show_transform(self, img):
88
+ if self.normalize:
89
+ img = self.denormalise(img)
90
+ if len(self.mean) == 3:
91
+ return img.permute(1, 2, 0)
92
+ else:
93
+ return img.squeeze(0)
94
+
95
+ def show_examples(self, figsize=(8, 6)):
96
+ batch_data, batch_label = next(self.example_iter)
97
+ images = list()
98
+ labels = list()
99
+
100
+ for i in range(len(batch_data)):
101
+ image = batch_data[i]
102
+ image = self.show_transform(image)
103
+
104
+ label = batch_label[i].item()
105
+ if self.classes is not None:
106
+ label = f'{label}:{self.classes[label]}'
107
+
108
+ images.append(image)
109
+ labels.append(label)
110
+
111
+ plot_examples(images, labels, figsize=figsize)