Spaces:
Runtime error
Runtime error
File size: 3,526 Bytes
74c7bcf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
import os
from abc import ABC
from functools import cached_property
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
try:
from epoch.utils import plot_examples
except ModuleNotFoundError:
from utils import plot_examples
class MyDataSet(ABC):
    """Base class bundling a dataset with its transforms and data loaders.

    Subclasses are expected to fill in the class attributes:

    * ``DataSet`` -- dataset class, called here as
      ``DataSet('../data', train=..., download=True[, alb_transform=...])``
      (presumably a torchvision-style dataset wrapper that accepts an
      Albumentations pipeline -- confirm against the subclasses).
    * ``mean`` / ``std`` -- per-channel statistics passed to ``A.Normalize``
      and used by :meth:`denormalise`.
    * ``classes`` -- ``{index: name}`` mapping; if left ``None`` it is filled
      from the dataset's ``classes`` attribute by :meth:`set_classes`.
    * ``default_alb_transforms`` -- training augmentations used when none are
      passed to the constructor.
    """

    DataSet = None
    mean = None
    std = None
    classes = None
    default_alb_transforms = None

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        """Store loader and transform options.

        Args:
            batch_size: Batch size used by both the train and test loaders.
            normalize: Whether to prepend ``A.Normalize(mean, std)``.
            shuffle: Whether the training loader shuffles.
            augment: Whether ``alb_transforms`` are applied to training data.
            alb_transforms: Training augmentations. ``None`` means use
                ``default_alb_transforms``; an explicit empty list is kept
                as-is (i.e. no augmentation).
        """
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        # Test against None rather than truthiness so an explicit [] is honoured
        # instead of being silently replaced by the defaults.
        self.alb_transforms = (
            self.default_alb_transforms if alb_transforms is None else alb_transforms
        )
        self.loader_kwargs = {
            'batch_size': batch_size,
            # os.cpu_count() can return None; DataLoader needs an int.
            'num_workers': os.cpu_count() or 0,
            'pin_memory': True,
        }

    @classmethod
    def set_classes(cls, data):
        """Populate ``cls.classes`` from ``data.classes`` if not already set."""
        if cls.classes is None:
            cls.classes = dict(enumerate(data.classes))

    @cached_property
    def train_data(self):
        """Training split built with :meth:`get_train_transforms` (cached)."""
        res = self.DataSet('../data', train=True, download=True, alb_transform=self.get_train_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def test_data(self):
        """Test split built with :meth:`get_test_transforms` (cached)."""
        res = self.DataSet('../data', train=False, download=True, alb_transform=self.get_test_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def train_loader(self):
        """DataLoader over :attr:`train_data`; shuffles if ``self.shuffle``."""
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        """Non-shuffling DataLoader over :attr:`test_data`."""
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        """Cached iterator over :attr:`train_loader` used by :meth:`show_examples`.

        Deleting the attribute (``del self.example_iter``) discards it so the
        next access creates a fresh iterator.
        """
        return iter(self.train_loader)

    def get_train_transforms(self):
        """Return the training pipeline: [Normalize] + [augmentations] + ToTensorV2.

        Normalize is placed before the augmentations, so the augmentations
        operate on already-normalised images.
        """
        all_transforms = []
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        if self.augment and self.alb_transforms is not None:
            all_transforms.extend(self.alb_transforms)
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def get_test_transforms(self):
        """Return the evaluation pipeline: [Normalize] + ToTensorV2 (no augmentation)."""
        all_transforms = []
        if self.normalize:
            all_transforms.append(A.Normalize(self.mean, self.std))
        all_transforms.append(ToTensorV2())
        return A.Compose(all_transforms)

    def download(self):
        """Instantiate both splits with ``download=True`` (no transforms)."""
        self.DataSet('../data', train=True, download=True)
        self.DataSet('../data', train=False, download=True)

    def denormalise(self, tensor):
        """Return a detached copy of *tensor* with normalisation undone.

        Applies ``x * std + mean`` channel by channel along the first dimension
        (assumes a single CHW image -- TODO confirm for other layouts). When
        ``normalize`` is False the copy is returned unchanged.
        """
        result = tensor.clone().detach().requires_grad_(False)
        if self.normalize:
            for channel, m, s in zip(result, self.mean, self.std):
                channel.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        """Convert a CHW image tensor into a plottable layout.

        Three-channel images become HWC; anything else has its leading
        (channel) dimension squeezed away.
        """
        if self.normalize:
            img = self.denormalise(img)
        # Decide the layout from the tensor itself rather than len(self.mean),
        # so this also works when no normalisation statistics are defined.
        if img.shape[0] == 3:
            return img.permute(1, 2, 0)
        return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        """Plot the next training batch with its labels.

        Labels are rendered as ``"<index>:<name>"`` when ``classes`` is known.
        When the cached iterator is exhausted it is recreated, so repeated
        calls keep working past the end of an epoch.
        """
        try:
            batch_data, batch_label = next(self.example_iter)
        except StopIteration:
            # Discard the exhausted cached iterator and start a new pass.
            del self.example_iter
            batch_data, batch_label = next(self.example_iter)
        images = []
        labels = []
        for image, target in zip(batch_data, batch_label):
            label = target.item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'
            images.append(self.show_transform(image))
            labels.append(label)
        plot_examples(images, labels, figsize=figsize)
|