|
from torchvision import transforms |
|
from torch.utils.data import Subset, DataLoader |
|
import numpy as np |
|
import torch |
|
import torchvision |
|
import matplotlib.pyplot as plt |
|
from sklearn.model_selection import train_test_split |
|
|
|
def load_transformed_dataset(IMG_SIZE=64):
    """Return CIFAR-10 wrapped in the preprocessing pipeline used for training.

    Each image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, randomly flipped
    horizontally, converted to a tensor and then mapped through
    ``t * 2 - 1``.  The dataset is downloaded into the current directory
    (``root="."``) if it is not already there.

    Args:
        IMG_SIZE: Edge length, in pixels, that images are resized to.

    Returns:
        A ``ConcatDataset`` made of an 80 % subset followed by the
        complementary 20 % subset of the same CIFAR-10 dataset object.
    """
    # presumably ToTensor yields values in [0, 1], so the final lambda
    # rescales to [-1, 1] -- confirm against the torchvision docs.
    pipeline = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1),
    ])

    cifar10 = torchvision.datasets.CIFAR10(root=".", download=True, transform=pipeline)

    # Fixed seed keeps the 80/20 partition identical from run to run.
    train_idx, test_idx = train_test_split(
        list(range(len(cifar10))), test_size=0.2, random_state=42
    )

    # NOTE(review): both subsets are concatenated straight back together, so
    # the result holds every sample of `cifar10`, just in the order produced
    # by the split (first subset, then second).
    return torch.utils.data.ConcatDataset(
        [Subset(cifar10, train_idx), Subset(cifar10, test_idx)]
    )
|
|
|
def load_dataloader(combined_dataset, batch_size=64):
    """Wrap *combined_dataset* in a shuffling ``DataLoader``.

    Args:
        combined_dataset: Dataset object to draw samples from.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` created with ``shuffle=True`` and ``drop_last=True``,
        so a final batch smaller than ``batch_size`` is discarded.
    """
    return DataLoader(
        combined_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )