Text-to-Image
Diffusers
English
learn_ddpm / dataloader /dataloader_cifar10.py
Harshit Agarwal
initial comm
eaefa93
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
def load_transformed_dataset(IMG_SIZE=64, root="."):
    """Return all of CIFAR-10 (train + test partitions) with diffusion-style preprocessing.

    Each image is resized to ``IMG_SIZE`` x ``IMG_SIZE``, randomly flipped
    horizontally, converted to a float tensor in [0, 1] and then rescaled to
    [-1, 1].

    Args:
        IMG_SIZE: Side length, in pixels, of the square output images.
        root: Directory where torchvision looks for the CIFAR-10 files; they
            are downloaded there if missing.

    Returns:
        A ``torch.utils.data.ConcatDataset`` made of torchvision's CIFAR-10
        training partition followed by its test partition.
    """
    data_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # Scales data into [0, 1]
        # (t - 0.5) / 0.5 == 2t - 1, i.e. maps [0, 1] onto [-1, 1]. Unlike a
        # local lambda, Normalize is picklable, so the dataset also works with
        # DataLoader(num_workers > 0) under the "spawn" start method.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # torchvision exposes CIFAR-10 as two partitions selected by ``train``;
    # the default (train=True) is only the training images, so the test
    # partition must be requested explicitly to use the whole dataset.
    splits = [
        torchvision.datasets.CIFAR10(
            root=root, train=is_train, download=True, transform=data_transform
        )
        for is_train in (True, False)
    ]
    return torch.utils.data.ConcatDataset(splits)
def load_dataloader(combined_dataset, batch_size=64):
    """Wrap ``combined_dataset`` in a shuffling ``DataLoader``.

    Args:
        combined_dataset: Any map-style dataset accepted by ``DataLoader``.
        batch_size: Number of samples per batch.

    Returns:
        A ``DataLoader`` that reshuffles every epoch and discards the final
        batch when it would be smaller than ``batch_size``.
    """
    loader_options = dict(batch_size=batch_size, shuffle=True, drop_last=True)
    return DataLoader(combined_dataset, **loader_options)