Spaces:
Runtime error
Runtime error
File size: 759 Bytes
04a30fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import torch
from torch.utils import data
import torch.nn.functional as F
import numpy as np
class TorchDataset(data.Dataset):
    """Map-style PyTorch dataset over paired ``"image"`` / ``"label"`` columns.

    Each image is converted to a float32 tensor with a leading singleton
    axis prepended and its values divided by 255. In training mode the
    integer label is also returned, encoded as a one-hot float32 vector.

    Args:
        datasamples: Object indexable by the column names ``"image"`` and
            ``"label"``; each column must support ``len()`` and integer
            indexing. NOTE(review): presumably a Hugging Face ``datasets``
            split or a dict of lists — confirm against callers.
        is_inference: When True, ``__getitem__`` returns only the image
            tensor and the ``"label"`` column may be absent.
        num_classes: Length of the one-hot label vector. Defaults to 10,
            the value that was previously hard-coded.
    """

    def __init__(
        self, datasamples, is_inference: bool, *, num_classes: int = 10
    ):
        self.x = datasamples["image"]
        if is_inference:
            # Labels are never returned in inference mode, so an unlabelled
            # dataset is valid here; fall back to None instead of raising.
            try:
                self.y = datasamples["label"]
            except KeyError:
                self.y = None
        else:
            self.y = datasamples["label"]
        self.is_inference = is_inference
        self.num_classes = num_classes

    def _image_tensor(self, idx) -> torch.Tensor:
        """Return image ``idx`` as float32 with a new leading axis, divided by 255."""
        # dtype=float32 forces a fresh writable copy, so from_numpy never
        # shares or warns about a read-only source buffer.
        pixels = np.array(self.x[idx], dtype=np.float32)
        # NOTE(review): dividing by 255 assumes 8-bit pixel values — confirm
        # the upstream image encoding.
        return torch.from_numpy(pixels).unsqueeze(0) / 255

    def __getitem__(self, idx):
        """Return ``x`` in inference mode, otherwise ``(x, one_hot_label)``."""
        x = self._image_tensor(idx)
        if self.is_inference:
            return x
        label = torch.as_tensor(self.y[idx], dtype=torch.int64)
        y = F.one_hot(label, self.num_classes).to(torch.float32)
        return x, y

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.x)
|