from torch.utils.data import Dataset


class SingleInputDataset(Dataset):
    """Dataset that wraps exactly one input sample.

    Useful for pushing a single example through code that expects a
    ``Dataset`` (e.g. a ``DataLoader``-based inference loop).

    Args:
        input_single: The sample to expose. It is returned unchanged
            (not copied) by ``__getitem__``.
    """

    def __init__(self, input_single):
        self.sample = input_single

    def __len__(self) -> int:
        """Return 1: the dataset always holds exactly one sample."""
        return 1

    def __getitem__(self, index):
        """Return the wrapped sample for index ``0`` (or ``-1``).

        Raises:
            IndexError: For any other index. This matters because
                ``Dataset`` defines no ``__iter__``, so plain iteration
                (``for x in ds`` / ``list(ds)``) falls back to calling
                ``__getitem__`` with 0, 1, 2, ... until ``IndexError``;
                without it, iteration would never terminate.
        """
        if index not in (0, -1):
            raise IndexError(
                f"SingleInputDataset index out of range: {index!r}"
            )
        return self.sample