Spaces:
Runtime error
Runtime error
import os | |
from abc import ABC | |
from functools import cached_property | |
import torch | |
import albumentations as A | |
from albumentations.pytorch import ToTensorV2 | |
try: | |
from epoch.utils import plot_examples | |
except ModuleNotFoundError: | |
from utils import plot_examples | |
class MyDataSet(ABC):
    """Base class bundling a dataset with its albumentations pipelines and DataLoaders.

    Subclasses fill in the class attributes below. ``DataSet`` is instantiated as
    ``DataSet(root, train=..., download=..., alb_transform=...)`` and the object it
    returns must expose a ``classes`` sequence (read by :meth:`set_classes`).
    Datasets, loaders and the example iterator are built lazily on first access
    and cached per instance.
    """

    # Dataset class to instantiate (set by subclasses).
    DataSet = None
    # Per-channel statistics handed to ``A.Normalize`` and reversed by ``denormalise``.
    mean = None
    std = None
    # {index: class name}; populated once from the first split built, stored on the class.
    classes = None
    # Training augmentations used when ``alb_transforms`` is not supplied.
    default_alb_transforms = None
    # Root directory passed to ``DataSet``; override in a subclass to relocate data.
    data_dir = '../data'

    def __init__(self, batch_size=1, normalize=True, shuffle=True, augment=True, alb_transforms=None):
        """Store loader and transform options.

        Args:
            batch_size: Samples per batch for both loaders.
            normalize: Prepend ``A.Normalize(mean, std)`` to every pipeline and undo
                it in :meth:`denormalise`.
            shuffle: Shuffle the training loader (the test loader never shuffles).
            augment: Apply the augmentations to the training pipeline.
            alb_transforms: Training augmentations; a falsy value (``None`` or an
                empty sequence) falls back to ``default_alb_transforms``. Use
                ``augment=False`` to disable augmentation entirely.
        """
        self.batch_size = batch_size
        self.normalize = normalize
        self.shuffle = shuffle
        self.augment = augment
        self.alb_transforms = alb_transforms or self.default_alb_transforms
        # os.cpu_count() returns None when the CPU count cannot be determined,
        # which DataLoader rejects; fall back to loading in the main process.
        self.loader_kwargs = {
            'batch_size': batch_size,
            'num_workers': os.cpu_count() or 0,
            'pin_memory': True,
        }

    @classmethod
    def set_classes(cls, data):
        """Record ``{index: name}`` from ``data.classes`` on the class, only once."""
        if cls.classes is None:
            cls.classes = dict(enumerate(data.classes))

    @cached_property
    def train_data(self):
        """Training split built with the train pipeline (cached per instance)."""
        res = self.DataSet(self.data_dir, train=True, download=True, alb_transform=self.get_train_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def test_data(self):
        """Test split built with the test pipeline (cached per instance)."""
        res = self.DataSet(self.data_dir, train=False, download=True, alb_transform=self.get_test_transforms())
        self.set_classes(res)
        return res

    @cached_property
    def train_loader(self):
        """DataLoader over :attr:`train_data`, shuffled when ``self.shuffle`` is set."""
        return torch.utils.data.DataLoader(self.train_data, shuffle=self.shuffle, **self.loader_kwargs)

    @cached_property
    def test_loader(self):
        """DataLoader over :attr:`test_data` in fixed (unshuffled) order."""
        return torch.utils.data.DataLoader(self.test_data, shuffle=False, **self.loader_kwargs)

    @cached_property
    def example_iter(self):
        """Persistent iterator over :attr:`train_loader`, consumed by :meth:`show_examples`.

        Cached so consecutive ``show_examples`` calls draw consecutive batches
        instead of restarting the loader each time.
        """
        return iter(self.train_loader)

    def _build_transforms(self, augmentations=()):
        """Compose [Normalize (if enabled)] + augmentations + ToTensorV2."""
        steps = []
        if self.normalize:
            steps.append(A.Normalize(self.mean, self.std))
        steps.extend(augmentations)
        steps.append(ToTensorV2())
        return A.Compose(steps)

    def get_train_transforms(self):
        """Training pipeline: optional Normalize, augmentations (if enabled), ToTensorV2."""
        use_augmentations = self.augment and self.alb_transforms is not None
        return self._build_transforms(self.alb_transforms if use_augmentations else ())

    def get_test_transforms(self):
        """Evaluation pipeline: optional Normalize then ToTensorV2, no augmentation."""
        return self._build_transforms()

    def download(self):
        """Instantiate both splits with ``download=True`` and no transform."""
        self.DataSet(self.data_dir, train=True, download=True)
        self.DataSet(self.data_dir, train=False, download=True)

    def denormalise(self, tensor):
        """Return a detached copy of ``tensor`` with normalisation undone (``x * std + mean``).

        The first dimension is zipped with ``mean``/``std``, so this expects a single
        channel-first image (C, H, W) — TODO confirm for other layouts.
        NOTE(review): ``A.Normalize`` also scales by ``max_pixel_value`` (255 by default),
        so the result is presumably in [0, 1] rather than raw pixel units — confirm.
        When ``normalize`` is False the copy is returned unchanged.
        """
        result = tensor.clone().detach().requires_grad_(False)
        if self.normalize:
            for channel, m, s in zip(result, self.mean, self.std):
                channel.mul_(s).add_(m)
        return result

    def show_transform(self, img):
        """Turn a channel-first image tensor into the layout handed to ``plot_examples``.

        Undoes normalisation when enabled, then returns (H, W, 3) for 3-channel
        images and drops the leading channel axis (``squeeze(0)``) otherwise.
        """
        if self.normalize:
            img = self.denormalise(img)
        # Decide from the tensor itself rather than len(self.mean), which fails
        # when mean is None and could disagree with the actual data.
        if img.shape[0] == 3:
            return img.permute(1, 2, 0)
        return img.squeeze(0)

    def show_examples(self, figsize=(8, 6)):
        """Plot the next batch from :attr:`example_iter` via ``plot_examples``.

        Labels are rendered as ``"<index>:<name>"`` when :attr:`classes` is known,
        otherwise as the bare integer index.
        """
        batch_data, batch_label = next(self.example_iter)
        images = []
        labels = []
        for image, target in zip(batch_data, batch_label):
            label = target.item()
            if self.classes is not None:
                label = f'{label}:{self.classes[label]}'
            images.append(self.show_transform(image))
            labels.append(label)
        plot_examples(images, labels, figsize=figsize)