# demo/generic.py
# Source: uploaded by SahithiR ("Upload 2 files", commit 74c7bcf), 3.53 kB.
import os
from abc import ABC
from functools import cached_property
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
try:
    # Prefer the helper from the `epoch` package; presumably the fallback covers
    # running this file next to a local `utils` module instead — confirm.
    from epoch.utils import plot_examples
except ModuleNotFoundError:
    from utils import plot_examples
class MyDataSet(ABC):
    """Base wrapper pairing a dataset class with albumentations pipelines and DataLoaders.

    Subclasses configure it through the class attributes below. The train/test
    datasets, their loaders and the example iterator are built lazily on first
    access and cached on the instance (``functools.cached_property``).
    """

    # Dataset class; called as DataSet(root, train=..., download=..., alb_transform=...)
    # and expected to expose a ``classes`` sequence. Must be set by subclasses.
    DataSet = None
    # Normalization statistics passed to A.Normalize. denormalise() zips them with
    # the image channels, so a per-channel sequence is assumed.
    mean = None
    std = None
    # {index: class name}, filled from the first dataset that is loaded.
    classes = None
    # Augmentations used when the constructor receives alb_transforms=None.
    default_alb_transforms = None
    # Root directory handed to DataSet; override in a subclass to load from elsewhere.
    data_dir = '../data'

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        """Store loader and transform options.

        Args:
            batch_size: samples per batch, used by both loaders.
            normalize: include A.Normalize(mean, std) in every pipeline.
            shuffle: shuffle the training loader (the test loader never shuffles).
            augment: include the augmentations in the training pipeline.
            alb_transforms: augmentations to use; ``None`` falls back to
                ``default_alb_transforms``. An explicit empty list is kept as-is.
        """
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        # `is None` instead of `or`, so an explicit empty list is not replaced by the defaults.
        self.alb_transforms = self.default_alb_transforms if alb_transforms is None else alb_transforms
        self.loader_kwargs = {
            'batch_size': batch_size,
            # os.cpu_count() returns None when undetermined, which DataLoader
            # rejects; fall back to loading in the main process.
            'num_workers': os.cpu_count() or 0,
            'pin_memory': True,
        }

    @classmethod
    def set_classes(cls, data):
        """Record ``data.classes`` as an {index: name} dict on the class, if not already set.

        Invoked through an instance, ``cls`` is the concrete subclass, so the
        mapping is stored there and shared by all of its instances.
        """
        if cls.classes is None:
            cls.classes = dict(enumerate(data.classes))

    def _build_dataset(self, train, transform):
        """Create one split of DataSet under data_dir and record its class names."""
        data = self.DataSet(self.data_dir, train=train, download=True, alb_transform=transform)
        self.set_classes(data)
        return data

    @cached_property
    def train_data(self):
        """Training split, built with get_train_transforms()."""
        return self._build_dataset(True, self.get_train_transforms())

    @cached_property
    def test_data(self):
        """Test split, built with get_test_transforms()."""
        return self._build_dataset(False, self.get_test_transforms())

    @cached_property
    def train_loader(self):
        """DataLoader over train_data; shuffling follows ``self.shuffle``."""
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        """DataLoader over test_data, never shuffled."""
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        """Iterator over train_loader consumed by show_examples()."""
        return iter(self.train_loader)

    def _next_example_batch(self):
        """Return the next training batch, restarting the iterator once it is exhausted."""
        try:
            return next(self.example_iter)
        except StopIteration:
            # cached_property stores its value in the instance __dict__; dropping it
            # makes the next access build a fresh iterator over train_loader.
            self.__dict__.pop('example_iter', None)
            return next(self.example_iter)

    def _compose(self, extra=()):
        """Build A.Compose([A.Normalize (if enabled), *extra, ToTensorV2()])."""
        steps = []
        if self.normalize:
            steps.append(A.Normalize(self.mean, self.std))
        steps.extend(extra)
        # ToTensorV2 goes last: the other albumentations steps operate on NumPy images.
        steps.append(ToTensorV2())
        return A.Compose(steps)

    def get_train_transforms(self):
        """Training pipeline: optional normalization, then augmentations when enabled."""
        augmentations = self.alb_transforms if self.augment and self.alb_transforms is not None else ()
        return self._compose(augmentations)

    def get_test_transforms(self):
        """Evaluation pipeline: optional normalization only, no augmentation."""
        return self._compose()

    def download(self):
        """Instantiate both splits with download=True and no transform (e.g. to pre-fetch data)."""
        self.DataSet(self.data_dir, train=True, download=True)
        self.DataSet(self.data_dir, train=False, download=True)

    def denormalise(self, tensor):
        """Return a detached copy of ``tensor`` with the mean/std normalization undone.

        Each channel c becomes ``x * std[c] + mean[c]``. The input tensor is not
        modified. A.Normalize also divides by max_pixel_value (255 by default),
        so the result is on that scaled range, not raw pixel values.
        Assumes channel-first layout with one mean/std entry per channel.
        """
        result = tensor.detach().clone()
        if self.normalize:
            for channel, m, s in zip(result, self.mean, self.std):
                channel.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        """Prepare one channel-first image tensor for plotting.

        Undoes normalization when enabled. Returns (H, W, 3) for a 3-channel
        image; otherwise drops the leading channel dimension (e.g. (1, H, W) -> (H, W)).
        """
        if self.normalize:
            img = self.denormalise(img)
        # Use the tensor's own channel count rather than len(self.mean), so this
        # also works with normalize=False and mean left as None.
        if img.dim() == 3 and img.shape[0] == 3:
            return img.permute(1, 2, 0)
        return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        """Plot one training batch, labelled 'index:name' when class names are known."""
        batch_data, batch_label = self._next_example_batch()
        images = [self.show_transform(image) for image in batch_data]
        labels = []
        for target in batch_label[:len(batch_data)]:
            label = target.item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'
            labels.append(label)
        plot_examples(images, labels, figsize=figsize)