File size: 759 Bytes
04a30fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from torch.utils import data
import torch.nn.functional as F
import numpy as np


class TorchDataset(data.Dataset):
    """Map-style dataset returning float image tensors and, outside
    inference, one-hot float labels.

    Args:
        datasamples: Indexable container with an ``"image"`` entry and, when
            ``is_inference`` is False, a ``"label"`` entry. Each
            ``datasamples["image"][i]`` must be convertible with ``np.array``
            and each ``datasamples["label"][i]`` must be an integer class id
            in ``[0, num_classes)``.
        is_inference: When True, ``__getitem__`` returns only the image tensor
            and the ``"label"`` entry may be absent.
        num_classes: Length of the one-hot label vector (default 10).
    """

    def __init__(self, datasamples, is_inference: bool, num_classes: int = 10):
        self.x = datasamples["image"]
        if is_inference:
            # Labels are never read in inference mode, so tolerate their absence
            # instead of failing at construction time.
            try:
                self.y = datasamples["label"]
            except KeyError:
                self.y = None
        else:
            self.y = datasamples["label"]
        self.is_inference = is_inference
        self.num_classes = num_classes

    def _load_image(self, idx) -> torch.Tensor:
        """Return sample ``idx`` as a float32 tensor with a new leading axis, divided by 255."""
        # `[None, ...]` prepends an axis (presumably a channel axis for 2-D
        # grayscale images — confirm); dividing by 255 presumably maps 8-bit
        # pixel values into [0, 1].
        pixels = np.array(self.x[idx])[None, ...]
        return torch.tensor(pixels, dtype=torch.float32) / 255

    def _encode_label(self, idx) -> torch.Tensor:
        """Return label ``idx`` as a float32 one-hot vector of length ``num_classes``."""
        # F.one_hot requires an int64 index tensor.
        label = torch.tensor(self.y[idx]).type(torch.int64)
        return F.one_hot(label, self.num_classes).type(torch.float32)

    def __getitem__(self, idx):
        """Return ``x`` in inference mode, otherwise the pair ``(x, y_one_hot)``."""
        x = self._load_image(idx)
        if self.is_inference:
            return x
        return x, self._encode_label(idx)

    def __len__(self):
        """Return the number of images."""
        return len(self.x)