File size: 947 Bytes
bcf646b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import torch
from torch.utils.data import Dataset
import numpy as np
class load_ECG_Dataset(Dataset):
    """Map-style PyTorch dataset over in-memory ECG samples and optional labels.

    ``dataset`` must be a mapping with a ``"samples"`` entry and, optionally,
    a ``"labels"`` entry. Each may be a NumPy array, a torch tensor, or any
    array-like accepted by ``torch.as_tensor`` (e.g. nested lists). The first
    axis indexes examples.

    NOTE(review): the per-sample layout is not fixed by this class; the
    ``squeeze(-1)`` in ``__getitem__`` suggests samples may be stored with a
    trailing singleton axis — confirm against the data producer.
    """

    # Initialize dataset
    def __init__(self, dataset):
        """Convert ``dataset["samples"]`` / ``dataset["labels"]`` to tensors.

        Samples are cast to float32 and labels (if present) to int64.

        Raises:
            KeyError: if ``dataset`` has no ``"samples"`` entry.
            ValueError: if labels are given but their count differs from
                the number of samples.
        """
        # as_tensor shares memory with NumPy inputs where possible (like the
        # previous torch.from_numpy), returns tensors unchanged, and also
        # accepts plain lists, which previously failed on .float()/.long().
        x_data = torch.as_tensor(dataset["samples"])

        # Labels are optional: unlabeled datasets are allowed.
        y_data = dataset.get("labels")
        if y_data is not None:
            y_data = torch.as_tensor(y_data)
            # Catch sample/label misalignment at construction time instead
            # of as an out-of-range index deep inside a DataLoader.
            if y_data.shape[0] != x_data.shape[0]:
                raise ValueError(
                    f"samples and labels differ in length: "
                    f"{x_data.shape[0]} != {y_data.shape[0]}"
                )

        self.x_data = x_data.float()
        self.y_data = y_data.long() if y_data is not None else None
        self.len = x_data.shape[0]

    def get_labels(self):
        """Return the int64 label tensor, or ``None`` if no labels were given."""
        return self.y_data

    def __getitem__(self, idx):
        """Return ``{'samples': tensor, 'labels': int}`` for example ``idx``.

        The sample has its last axis squeezed (a no-op unless that axis has
        size 1). The ``'labels'`` key is present only when the dataset was
        built with labels; previously an unlabeled dataset raised
        ``TypeError`` here on every access.
        """
        sample = {'samples': self.x_data[idx].squeeze(-1)}
        if self.y_data is not None:
            sample['labels'] = int(self.y_data[idx])
        return sample

    def __len__(self):
        """Return the number of examples (size of the first samples axis)."""
        return self.len
|