Spaces:
No application file
No application file
import torch | |
import torchvision | |
import torchvision.transforms as transforms | |
import numpy as np | |
def get_datasets(batch_size=1, num_workers=1, root='./data'):
    """Build the CIFAR-10 datasets and their data loaders.

    Both splits are downloaded into ``root`` if they are not already there,
    and share one preprocessing pipeline: ``ToTensor`` followed by a
    per-channel ``Normalize`` with mean 0.5 and standard deviation 0.5.

    Args:
        batch_size: Number of samples per batch for both loaders. Defaults
            to 1, the value that was previously hard-coded.
        num_workers: Number of loader worker processes for both loaders.
            Defaults to 1, the value that was previously hard-coded.
        root: Directory where the dataset is stored or downloaded to.
            Defaults to ``'./data'``, the previously hard-coded path.

    Returns:
        A 4-tuple ``(trainset, trainloader, validset, validloader)``.
        ``trainset`` is the ``train=True`` split and ``validset`` is the
        ``train=False`` split.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True, transform=transform)
    # Only the training loader shuffles; the order is re-drawn each epoch.
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)

    validset = torchvision.datasets.CIFAR10(
        root=root, train=False, download=True, transform=transform)
    # Held-out split is iterated in a fixed order so runs are comparable.
    validloader = torch.utils.data.DataLoader(
        validset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers)

    return trainset, trainloader, validset, validloader