SDSC6001_HW3 / utils /CustomDataset.py
MingLi
Train better models on ResNet101
6bbd4ce
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Map-style dataset pairing in-memory images with their labels.

    Args:
        images: Indexable, sized collection of samples. Its length defines
            the length of the dataset.
        labels: Indexable collection of labels, looked up with the same
            index as ``images``.
        transform: Optional callable applied to each image when it is
            fetched. Labels are never transformed.
    """

    def __init__(self, images, labels, transform=None):
        super().__init__()
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples, i.e. ``len(images)``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``.

        The image is passed through ``transform`` first when one was given.
        Indexing errors raised by the underlying collections propagate
        unchanged.
        """
        image = self.images[idx]
        label = self.labels[idx]

        # Compare against None rather than relying on truthiness: a callable
        # transform may legitimately evaluate as falsy (for example an object
        # defining __len__ or __bool__) and must still be applied.
        if self.transform is not None:
            image = self.transform(image)

        return image, label