NimaBoscarino's picture
WIP: Substra orchestrator
04a30fc
raw
history blame
759 Bytes
import torch
from torch.utils import data
import torch.nn.functional as F
import numpy as np
class TorchDataset(data.Dataset):
    """Map-style PyTorch dataset over an ``"image"`` / ``"label"`` sample collection.

    Each item's image is converted to a float32 tensor, scaled by 1/255 and
    given an extra leading axis. Outside inference mode the label is returned
    too, one-hot encoded as float32.

    Args:
        datasamples: Indexable container exposing an ``"image"`` column and,
            unless ``is_inference`` is true, a ``"label"`` column. Each
            ``datasamples["image"][i]`` must be convertible with ``np.array``;
            presumably 8-bit pixel data given the /255 scaling — TODO confirm
            against the data opener.
        is_inference: When true, ``__getitem__`` returns only the image tensor
            and the ``"label"`` column may be absent.
        num_classes: Length of the one-hot label vector. Defaults to 10.

    Raises:
        KeyError: If ``"image"`` is missing, or ``"label"`` is missing while
            ``is_inference`` is false.
    """

    def __init__(self, datasamples, is_inference: bool, *, num_classes: int = 10):
        self.x = datasamples["image"]
        self.is_inference = is_inference
        self.num_classes = num_classes
        try:
            self.y = datasamples["label"]
        except KeyError:
            # Unlabelled data is only acceptable when labels are never returned.
            if not is_inference:
                raise
            self.y = None

    def _image(self, idx) -> torch.Tensor:
        """Return sample ``idx`` as a float32 tensor scaled by 1/255 with a new leading axis."""
        # The added leading axis is presumably the channel dimension — confirm with the model.
        pixels = np.array(self.x[idx])[None, ...]
        return torch.as_tensor(pixels, dtype=torch.float32) / 255

    def __getitem__(self, idx):
        """Return ``x`` in inference mode, otherwise ``(x, one_hot_label)``."""
        x = self._image(idx)
        if self.is_inference:
            return x
        # one_hot requires an int64 tensor; the loss side expects float targets.
        label = torch.as_tensor(self.y[idx], dtype=torch.int64)
        y = F.one_hot(label, self.num_classes).to(torch.float32)
        return x, y

    def __len__(self):
        """Return the number of samples, i.e. the length of the ``"image"`` column."""
        return len(self.x)