import os
from abc import ABC
from functools import cached_property

import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Allow importing both as part of the `epoch` package and as a standalone script.
try:
    from epoch.utils import plot_examples
except ModuleNotFoundError:
    from utils import plot_examples


class MyDataSet(ABC):
    """Abstract wrapper around an image dataset.

    Bundles Albumentations transforms, normalization, DataLoader construction
    and example plotting. Subclasses must set the class attributes below.
    """

    DataSet = None                  # dataset class accepting an `alb_transform` keyword
    mean = None                     # per-channel mean used by A.Normalize
    std = None                      # per-channel std used by A.Normalize
    classes = None                  # {index: name} mapping, filled lazily from the data
    default_alb_transforms = None   # augmentation transforms applied during training

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        # Fall back to the class-level defaults only when no transforms were passed in.
        self.alb_transforms = alb_transforms if alb_transforms is not None else self.default_alb_transforms

        self.loader_kwargs = {'batch_size': batch_size, 'num_workers': os.cpu_count(), 'pin_memory': True}

    @classmethod
    def set_classes(cls, data):
        # Populate the {index: name} mapping once, from the first dataset instantiated.
        if cls.classes is None:
            cls.classes = {i: c for i, c in enumerate(data.classes)}

    @cached_property
    def train_data(self):
        res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def test_data(self):
        res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def train_loader(self):
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        return iter(self.train_loader)

    def get_train_transforms(self):
        all_transforms = list()
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        if self.augment and self.alb_transforms is not None:
            all_transforms.extend(self.alb_transforms)
        # ToTensorV2 goes last: it converts the HWC numpy image into a CHW tensor.
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def get_test_transforms(self):
        all_transforms = list()
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def download(self):
        self.DataSet('../data', train=True, download=True)
        self.DataSet('../data', train=False, download=True)

    def denormalise(self, tensor):
        # Invert A.Normalize channel-wise so the image can be displayed.
        result = tensor.detach().clone()
        if self.normalize:
            for t, m, s in zip(result, self.mean, self.std):
                t.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        # Undo normalization and rearrange to HWC (RGB) or HW (grayscale) for plotting.
        if self.normalize:
            img = self.denormalise(img)
        if len(self.mean) == 3:
            return img.permute(1, 2, 0)
        else:
            return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        batch_data, batch_label = next(self.example_iter)
        images = list()
        labels = list()

        for i in range(len(batch_data)):
            image = batch_data[i]
            image = self.show_transform(image)

            label = batch_label[i].item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'

            images.append(image)
            labels.append(label)

        plot_examples(images, labels, figsize=figsize)
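

# ------------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original API): `MyDataSet` expects
# `DataSet` to be a dataset class that accepts an `alb_transform` keyword and
# applies it per sample. `AlbCIFAR10` and `CIFAR10DataSet` below are assumed
# names showing how a concrete subclass might be wired up against torchvision.
import numpy as np
from torchvision import datasets


class AlbCIFAR10(datasets.CIFAR10):
    """torchvision CIFAR10 adapted to take an Albumentations transform."""

    def __init__(self, root, alb_transform=None, **kwargs):
        super().__init__(root, **kwargs)
        self.alb_transform = alb_transform

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]
        if self.alb_transform is not None:
            # Albumentations expects an HWC numpy array and returns a dict.
            image = self.alb_transform(image=np.asarray(image))['image']
        return image, label


class CIFAR10DataSet(MyDataSet):
    DataSet = AlbCIFAR10
    mean = (0.4914, 0.4822, 0.4465)   # commonly quoted CIFAR-10 channel statistics
    std = (0.2470, 0.2435, 0.2616)
    default_alb_transforms = [
        A.HorizontalFlip(p=0.5),
        A.Rotate(limit=10, p=0.5),
    ]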