import numpy as np
import torch
import torch.nn.functional as F
from torch.utils import data


class TorchDataset(data.Dataset):
    """Map-style dataset yielding scaled image tensors and one-hot labels.

    Each image is converted to a ``float32`` tensor, given a new leading
    axis, and divided by 255. In training mode the matching label is
    returned as a ``float32`` one-hot vector of length ``num_classes``.

    Args:
        datasamples: Mapping-like container indexable by ``"image"`` and
            ``"label"``, each giving an indexable sequence. The ``"label"``
            entry may be absent when ``is_inference`` is true.
        is_inference: When true, ``__getitem__`` returns only the image
            tensor; otherwise it returns an ``(image, one_hot_label)`` pair.
        num_classes: Length of the one-hot label vector. Defaults to 10,
            the value that was previously hard-coded.

    Raises:
        KeyError: If ``"image"`` is missing, or if ``"label"`` is missing
            while ``is_inference`` is false.
    """

    def __init__(self, datasamples, is_inference: bool, num_classes: int = 10):
        super().__init__()
        self.x = datasamples["image"]
        try:
            self.y = datasamples["label"]
        except KeyError:
            # Unlabeled data is acceptable for inference only; training
            # without labels should keep failing loudly at construction.
            if not is_inference:
                raise
            self.y = None
        self.is_inference = is_inference
        self.num_classes = num_classes

    def _load_image(self, idx) -> torch.Tensor:
        """Return sample ``idx`` as a float32 tensor with a leading axis, scaled by 1/255."""
        # NOTE(review): presumably images are 2-D 8-bit grayscale, making the
        # new leading axis a channel dimension -- confirm against the caller.
        pixels = np.array(self.x[idx])[None, ...]
        return torch.tensor(pixels, dtype=torch.float32) / 255

    def __getitem__(self, idx):
        """Return the image tensor, plus its one-hot label unless in inference mode."""
        image = self._load_image(idx)
        if self.is_inference:
            return image

        # F.one_hot requires an int64 index tensor.
        label = torch.as_tensor(self.y[idx], dtype=torch.int64)
        one_hot = F.one_hot(label, self.num_classes).type(torch.float32)
        return image, one_hot

    def __len__(self) -> int:
        """Return the number of samples (the length of the image sequence)."""
        return len(self.x)