import pytorch_lightning as pl
from sklearn.model_selection import train_test_split
from timm.data import ImageDataset
from timm.data.transforms_factory import create_transform
from torch.utils.data import DataLoader, Subset
from torchvision import transforms

from constants import INPUT_IMAGE_SIZE

timm_transform = create_transform(224, scale=(0.7, 1.0), is_training=True, auto_augment='rand-mstd0.5')
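# The timm pipeline above is roughly: random resized crop (scale 0.7-1.0),
# RandAugment with magnitude std 0.5 (from the 'rand-mstd0.5' policy string),
# ToTensor and ImageNet normalization. It is an alternative to the torchvision
# augmentation defined in BirdsDataModule below and is not wired in here.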
NUM_WORKERS = 0
BATCH_SIZE = 40

# ImageNet channel-wise mean and std used for input normalization.
IMAGENET_STATS = ([0.485, 0.456, 0.406],
                  [0.229, 0.224, 0.225])

# Deterministic transforms for validation/test and inference.
inference_transforms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=INPUT_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(*IMAGENET_STATS)
])
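# Quick sanity check (commented out): a sketch assuming Pillow is installed
# and INPUT_IMAGE_SIZE is the final square crop size (e.g. 224).
# from PIL import Image
# _probe = inference_transforms(Image.new('RGB', (500, 375)))
# assert _probe.shape == (3, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)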
class BirdsDataModule(pl.LightningDataModule):
    def __init__(self, data_dir='data', batch_size=BATCH_SIZE):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Training-time augmentation.
        self.augmentation = transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=INPUT_IMAGE_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(*IMAGENET_STATS)
        ])
        # Deterministic transforms for validation/test.
        self.transform = inference_transforms
    def prepare_data(self):
        # The image folder is expected to already exist on disk.
        pass

    def setup(self, stage=None):
        # Build two views of the same image folder: one with training
        # augmentation and one with the deterministic inference transforms.
        # (Assigning .transform on a torch.utils.data.Subset has no effect,
        # since Subset delegates __getitem__ to the wrapped dataset, so two
        # ImageDataset instances are used instead. timm scans the folder
        # deterministically, so indices line up between the two.)
        train_ds = ImageDataset(self.data_dir, transform=self.augmentation)
        val_ds = ImageDataset(self.data_dir, transform=self.transform)
        # index_to_name = {v: k for k, v in train_ds.parser.class_to_idx.items()}
        # import json
        # with open('index_to_name.json', 'w') as f:
        #     json.dump(index_to_name, f)

        # Stratified split so train/val keep the same class distribution.
        targets = [c for (f, c) in train_ds.parser.samples]
        train_indices, val_indices = train_test_split(list(range(len(targets))), test_size=0.13,
                                                      stratify=targets, shuffle=True)
        self.train_dataset = Subset(train_ds, train_indices)
        self.val_dataset = Subset(val_ds, val_indices)
    # We define a separate DataLoader for each of train/val/test.
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=NUM_WORKERS)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS)

    def test_dataloader(self):
        # No held-out test split; reuse the validation set.
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=NUM_WORKERS)
birds = BirdsDataModule()
birds.prepare_data()
birds.setup()
samples = next(iter(birds.val_dataloader()))
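# The val loader yields (images, labels) batches from timm's ImageDataset.
images, labels = samples
print(images.shape, labels.shape)  # e.g. torch.Size([40, 3, 224, 224]) torch.Size([40])

# Training sketch (commented out): `BirdsClassifier` is a hypothetical
# LightningModule defined elsewhere, not part of this file.
# model = BirdsClassifier(num_classes=len(birds.train_dataset.dataset.parser.class_to_idx))
# trainer = pl.Trainer(max_epochs=5, accelerator='auto')
# trainer.fit(model, datamodule=birds)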