from typing import Any, Callable, Optional

from torch.utils.data import Dataset


class CustomDataset(Dataset):
    """Map-style dataset pairing indexable ``images`` with ``labels``.

    Sample ``idx`` is ``(images[idx], labels[idx])``, with ``transform``
    applied to the image and ``target_transform`` applied to the label when
    they are provided.

    Args:
        images: Indexable, sized collection of inputs; its length defines
            the dataset length.
        labels: Indexable, sized collection of targets aligned element-wise
            with ``images``.
        transform: Optional callable applied to each image before it is
            returned.
        target_transform: Optional callable applied to each label before it
            is returned.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """

    def __init__(
        self,
        images: Any,
        labels: Any,
        transform: Optional[Callable[[Any], Any]] = None,
        target_transform: Optional[Callable[[Any], Any]] = None,
    ) -> None:
        # Fail fast: __len__ only consults images, so a mismatch would
        # otherwise surface as an IndexError mid-epoch (labels shorter) or
        # silently ignore trailing labels (labels longer).
        if len(images) != len(labels):
            raise ValueError(
                "images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of samples (the length of ``images``)."""
        return len(self.images)

    def __getitem__(self, idx: Any) -> tuple:
        """Return the ``(image, label)`` pair at ``idx``, transformed if configured."""
        image = self.images[idx]
        label = self.labels[idx]
        # Compare against None instead of relying on truthiness: a callable
        # transform object defining __len__/__bool__ could otherwise be
        # skipped silently.
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label