"""This file contains functions to prepare dataloaders in the way Lightning expects."""
import pytorch_lightning as pl
import torch
import torchvision.datasets as datasets
from lightning_fabric.utilities.seed import seed_everything
from torch.utils.data import DataLoader, random_split

from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations
class CIFARDataModule(pl.LightningDataModule):
    """Lightning DataModule for the CIFAR10 dataset.

    Args:
        data_path: Root directory handed to ``torchvision.datasets.CIFAR10``.
        batch_size: Batch size shared by the train/val/test dataloaders.
        seed: Seed for the train/validation split and for ``_init_fn``.
            ``None`` leaves the split on torch's global RNG.
        val_split: Fraction of the training set held out for validation.
            ``0`` (default) disables the split; the CIFAR10 test set is then
            used for validation as well.
        num_workers: Number of worker processes per dataloader.
    """

    def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
        super().__init__()

        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers

        # Keyword arguments common to every DataLoader; ``shuffle`` is supplied
        # separately by each ``*_dataloader`` method.
        self.dataloader_dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            # persistent_workers is only valid when worker processes exist
            "persistent_workers": self.num_workers > 0,
        }

        self.prepare_data_per_node = False

        # Declared here so the attributes exist before ``setup`` runs
        self.training_dataset = None
        self.validation_dataset = None
        self.testing_dataset = None

    def _split_train_val(self, dataset):
        """Split ``dataset`` into train and validation subsets.

        The split is drawn from a generator seeded with ``self.seed`` so it is
        reproducible and identical in every process that calls ``setup``.

        Returns:
            A ``(train_dataset, val_dataset)`` tuple.

        Raises:
            ValueError: If ``self.val_split`` is not strictly between 0 and 1.
        """
        if not 0 < self.val_split < 1:
            raise ValueError("Validation split must be between 0 and 1")

        # Calculate lengths of each dataset
        total_length = len(dataset)
        train_length = int((1 - self.val_split) * total_length)
        val_length = total_length - train_length

        # Use a dedicated generator rather than the global RNG: the global
        # state depends on everything that ran before, which makes the split
        # irreproducible and possibly different per process.
        split_kwargs = {}
        if self.seed is not None:
            split_kwargs["generator"] = torch.Generator().manual_seed(int(self.seed))

        train_dataset, val_dataset = random_split(dataset, [train_length, val_length], **split_kwargs)
        return train_dataset, val_dataset

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
    def prepare_data(self):
        """Download the CIFAR10 train and test sets if they are not on disk."""
        datasets.CIFAR10(self.data_path, train=True, download=True)
        datasets.CIFAR10(self.data_path, train=False, download=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.DataHooks.html#lightning.pytorch.core.hooks.DataHooks.setup
    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None``. ``None``
                builds every dataset. ``"fit"`` and ``"validate"`` build the
                training and validation datasets; ``"test"`` builds the test
                dataset. Other values build nothing.
        """
        # Validation reuses the test-time transformations
        train_transforms, test_transforms = apply_cifar_image_transformations()
        val_transforms = test_transforms

        # "validate" is included so that running validation on its own also
        # gets a populated validation dataset.
        if stage in ("fit", "validate", None):
            if self.val_split != 0:
                # Split the training data into training and validation sets
                data_train, data_val = self._split_train_val(datasets.CIFAR10(self.data_path, train=True))
                # Wrap each subset with its own transformations
                self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
                self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
            else:
                # No split requested: the full training set is used for training
                self.training_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=True), train_transforms
                )
                # Validation will be the same as test
                self.validation_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=False), val_transforms
                )

        if stage in ("test", None):
            # Assign the test split for use in dataloaders
            self.testing_dataset = CIFAR10Transforms(datasets.CIFAR10(self.data_path, train=False), test_transforms)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
    def val_dataloader(self):
        """Return the validation dataloader (no shuffling)."""
        return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
    def test_dataloader(self):
        """Return the test dataloader (no shuffling)."""
        return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)

    def _init_fn(self, worker_id):
        """Seed a dataloader worker with ``self.seed + worker_id``.

        Not passed to the dataloaders built by this class; it has the
        ``worker_init_fn`` signature should a caller want to use it.
        """
        seed_everything(int(self.seed) + worker_id)