from torch.utils.data import Dataset | |
class CustomDataset(Dataset):
    """Map-style dataset over parallel, index-aligned ``images`` and ``labels``.

    Args:
        images: Indexable, sized collection of samples (list, tensor, array, ...).
        labels: Indexable, sized collection of targets aligned with ``images``;
            must have the same length.
        transform: Optional callable applied to each image on access.
        target_transform: Optional callable applied to each label on access.

    Raises:
        ValueError: If ``images`` and ``labels`` have different lengths.
    """

    def __init__(self, images, labels, transform=None, target_transform=None):
        # Fail fast: a mismatch would otherwise surface as an IndexError
        # mid-epoch (labels shorter) or silently drop labels (labels longer),
        # because __len__ only looks at ``images``.
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, transformed if configured."""
        image = self.images[idx]
        label = self.labels[idx]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. defines __len__ returning 0) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label