"""This file contains functions to prepare dataloader in the way lightning expects"""
import pytorch_lightning as pl
import torch
import torchvision.datasets as datasets
from lightning_fabric.utilities.seed import seed_everything
from torch.utils.data import DataLoader, random_split

from modules.dataset import CIFAR10Transforms, apply_cifar_image_transformations

class CIFARDataModule(pl.LightningDataModule):
    """Lightning DataModule for the CIFAR10 dataset."""

    def __init__(self, data_path, batch_size, seed, val_split=0, num_workers=0):
        super().__init__()
        self.data_path = data_path
        self.batch_size = batch_size
        self.seed = seed
        self.val_split = val_split
        self.num_workers = num_workers
        # Shared DataLoader arguments; `shuffle` is set per dataloader below
        self.dataloader_dict = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": True,
            # "worker_init_fn": self._init_fn,
            "persistent_workers": self.num_workers > 0,
        }
        self.prepare_data_per_node = False
        # Fixes "attribute defined outside __init__" warning
        self.training_dataset = None
        self.validation_dataset = None
        self.testing_dataset = None

    def _split_train_val(self, dataset):
        """Split the dataset into train and validation sets."""
        # The validation split must be a fraction strictly between 0 and 1
        if not 0 < self.val_split < 1:
            raise ValueError("Validation split must be between 0 and 1")
        # Calculate the length of each split
        total_length = len(dataset)
        train_length = int((1 - self.val_split) * total_length)
        val_length = total_length - train_length
        # Seed the split so train/val membership is reproducible across runs
        train_dataset, val_dataset = random_split(
            dataset,
            [train_length, val_length],
            generator=torch.Generator().manual_seed(int(self.seed)),
        )
        return train_dataset, val_dataset

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data
    def prepare_data(self):
        # Download the CIFAR10 dataset if it doesn't exist
        datasets.CIFAR10(self.data_path, train=True, download=True)
        datasets.CIFAR10(self.data_path, train=False, download=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.hooks.DataHooks.html#lightning.pytorch.core.hooks.DataHooks.setup
    def setup(self, stage=None):
        # Define the data transformations
        train_transforms, test_transforms = apply_cifar_image_transformations()
        val_transforms = test_transforms

        # Create train and validation datasets
        if stage == "fit" or stage is None:
            if self.val_split != 0:
                # Split the training data into training and validation sets
                data_train, data_val = self._split_train_val(
                    datasets.CIFAR10(self.data_path, train=True)
                )
                # Apply transformations
                self.training_dataset = CIFAR10Transforms(data_train, train_transforms)
                self.validation_dataset = CIFAR10Transforms(data_val, val_transforms)
            else:
                # Use the full training data for training
                self.training_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=True), train_transforms
                )
                # Validation will be the same as the test set
                self.validation_dataset = CIFAR10Transforms(
                    datasets.CIFAR10(self.data_path, train=False), val_transforms
                )

        # Create the test dataset
        if stage == "test" or stage is None:
            # Assign the test split for use in the dataloaders
            self.testing_dataset = CIFAR10Transforms(
                datasets.CIFAR10(self.data_path, train=False), test_transforms
            )

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#train-dataloader
    def train_dataloader(self):
        return DataLoader(self.training_dataset, **self.dataloader_dict, shuffle=True)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#val-dataloader
    def val_dataloader(self):
        return DataLoader(self.validation_dataset, **self.dataloader_dict, shuffle=False)

    # https://lightning.ai/docs/pytorch/stable/data/datamodule.html#test-dataloader
    def test_dataloader(self):
        return DataLoader(self.testing_dataset, **self.dataloader_dict, shuffle=False)

    def _init_fn(self, worker_id):
        # Optional worker_init_fn (see the commented entry in dataloader_dict):
        # gives each DataLoader worker a distinct, deterministic seed
        seed_everything(int(self.seed) + worker_id)
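
# A minimal usage sketch (not part of the original module): it drives the
# DataModule by hand, the way Lightning's Trainer would. The argument values
# ("data", 64, 42, 0.1) are illustrative assumptions, not project defaults.
if __name__ == "__main__":
    datamodule = CIFARDataModule(data_path="data", batch_size=64, seed=42, val_split=0.1)
    datamodule.prepare_data()  # download CIFAR10 if it is missing
    datamodule.setup(stage="fit")
    # Pull one batch to sanity-check shapes; expect images of shape [64, 3, 32, 32]
    images, labels = next(iter(datamodule.train_dataloader()))
    print(images.shape, labels.shape)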