Upload 2 files
#1 by SahithiR - opened
- cifar10.py +37 -0
- generic.py +111 -0
cifar10.py
ADDED
@@ -0,0 +1,37 @@
import numpy as np
import cv2
from torchvision import datasets
import albumentations as A
from albumentations.pytorch import ToTensorV2

from .generic import MyDataSet


class AlbumentationsCIFAR10(datasets.CIFAR10):
    """CIFAR-10 dataset that applies an Albumentations pipeline in __getitem__."""

    def __init__(self, root, alb_transform=None, **kwargs):
        super(AlbumentationsCIFAR10, self).__init__(root, **kwargs)
        self.alb_transform = alb_transform

    def __getitem__(self, index):
        image, label = super(AlbumentationsCIFAR10, self).__getitem__(index)
        if self.alb_transform is not None:
            # Albumentations expects a numpy array and returns a dict holding the image.
            image = self.alb_transform(image=np.array(image))['image']
        return image, label


class cifar10_dataset(MyDataSet):
    DataSet = AlbumentationsCIFAR10
    # CIFAR-10 per-channel statistics used for normalisation/denormalisation.
    mean = (0.49139968, 0.48215827, 0.44653124)
    std = (0.24703233, 0.24348505, 0.26158768)
    default_alb_transforms = [
        A.HorizontalFlip(p=1.0),
        A.ShiftScaleRotate(shift_limit=(-0.2, 0.2), scale_limit=(-0.2, 0.2), rotate_limit=(-15, 15), p=0.5),
        A.PadIfNeeded(min_height=36, min_width=36, p=1.0),
        A.RandomCrop(32, 32, always_apply=False, p=1.0),
        A.CenterCrop(32, 32, always_apply=False, p=1.0),
        A.CoarseDropout(max_holes=1, max_height=8, max_width=8, min_holes=1,
                        min_height=8, min_width=8,
                        fill_value=(0.4914, 0.4822, 0.4465), always_apply=False, p=0.5),
    ]
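A minimal usage sketch, not part of this upload: it assumes cifar10.py and generic.py live together in an importable package (here hypothetically named datasets) and uses only the attributes defined above.

    from datasets.cifar10 import cifar10_dataset  # hypothetical package/module path

    # Build loaders with the default augmentations and CIFAR-10 normalisation.
    data = cifar10_dataset(batch_size=64, normalize=True, shuffle=True, augment=True)
    train_loader = data.train_loader    # cached DataLoader over the augmented train split
    test_loader = data.test_loader      # cached DataLoader over the un-augmented test split
    data.show_examples(figsize=(8, 6))  # preview one augmented batch via plot_examples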
generic.py
ADDED
@@ -0,0 +1,111 @@
import os
from abc import ABC
from functools import cached_property

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

try:
    from epoch.utils import plot_examples
except ModuleNotFoundError:
    from utils import plot_examples


class MyDataSet(ABC):
    """Base class wiring a torchvision-style dataset to Albumentations transforms,
    DataLoaders and simple visualisation helpers."""

    DataSet = None
    mean = None
    std = None
    classes = None
    default_alb_transforms = None

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        self.alb_transforms = alb_transforms or self.default_alb_transforms

        self.loader_kwargs = {'batch_size': batch_size, 'num_workers': os.cpu_count(), 'pin_memory': True}

    @classmethod
    def set_classes(cls, data):
        # Cache the index -> class-name mapping the first time a dataset is built.
        if cls.classes is None:
            cls.classes = {i: c for i, c in enumerate(data.classes)}

    @cached_property
    def train_data(self):
        res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def test_data(self):
        res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def train_loader(self):
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        return iter(self.train_loader)

    def get_train_transforms(self):
        # Normalise (optional), apply augmentations (optional), then convert to tensor.
        all_transforms = list()
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        if self.augment and self.alb_transforms is not None:
            all_transforms.extend(self.alb_transforms)
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def get_test_transforms(self):
        all_transforms = list()
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def download(self):
        self.DataSet('../data', train=True, download=True)
        self.DataSet('../data', train=False, download=True)

    def denormalise(self, tensor):
        # Undo A.Normalize channel by channel so images can be displayed.
        result = tensor.clone().detach().requires_grad_(False)
        if self.normalize:
            for t, m, s in zip(result, self.mean, self.std):
                t.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        if self.normalize:
            img = self.denormalise(img)
        if len(self.mean) == 3:
            # CHW -> HWC for plotting RGB images.
            return img.permute(1, 2, 0)
        else:
            return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        batch_data, batch_label = next(self.example_iter)
        images = list()
        labels = list()

        for i in range(len(batch_data)):
            image = batch_data[i]
            image = self.show_transform(image)

            label = batch_label[i].item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'

            images.append(image)
            labels.append(label)

        plot_examples(images, labels, figsize=figsize)
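plot_examples is imported from epoch.utils (falling back to utils) but is not included in this upload. A hypothetical stand-in, sketched only to show the call signature show_examples relies on, could look like this:

    import math
    import matplotlib.pyplot as plt

    def plot_examples(images, labels, figsize=(8, 6)):
        # Hypothetical stand-in: grid-plot a batch of HWC images with their labels.
        cols = 4
        rows = math.ceil(len(images) / cols)
        fig = plt.figure(figsize=figsize)
        for i, (image, label) in enumerate(zip(images, labels)):
            ax = fig.add_subplot(rows, cols, i + 1)
            ax.imshow(image)
            ax.set_title(str(label))
            ax.axis('off')
        plt.tight_layout()
        plt.show()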