# File size: 1,643 Bytes
# f831146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import DatasetFolder
import multiprocessing


class FBanksCrossEntropyDataset(Dataset):
    """Dataset of 64x64 single-channel filter-bank (fbank) features stored as
    ``.npy`` files in a ``DatasetFolder`` layout (one sub-directory per
    class/speaker). ``__getitem__`` yields ``(tensor, class_index)`` pairs
    where the tensor has shape ``(1, 64, 64)``.
    """

    def __init__(self, root):
        """
        :param root: directory whose sub-directories each contain the
            ``.npy`` samples of one class.
        """
        # torchvision's DatasetFolder documents `extensions` as a tuple of
        # allowed suffixes, not a bare string.
        self.dataset_folder = DatasetFolder(root=root,
                                            loader=FBanksCrossEntropyDataset._npy_loader,
                                            extensions=('.npy',))
        self.len_ = len(self.dataset_folder.samples)

        # DatasetFolder lists samples grouped by class, so each class occupies
        # a contiguous index range [start, end). Record those ranges so that
        # callers (e.g. samplers) can pick indices from a specific class.
        bin_counts = np.bincount(self.dataset_folder.targets)
        self.num_classes = len(self.dataset_folder.classes)
        self.label_to_index_range = {}
        start = 0
        for i in range(self.num_classes):
            self.label_to_index_range[i] = (start, start + bin_counts[i])
            start = start + bin_counts[i]

    @staticmethod
    def _npy_loader(path):
        """Load one fbank sample and return it as a float32 tensor of shape
        ``(1, 64, 64)``.

        :param path: path to a ``.npy`` file holding a 64x64x1 array.
        :raises ValueError: if the stored array is not of shape (64, 64, 1).
        """
        sample = np.load(path)
        # Validate explicitly: `assert` statements are stripped under
        # `python -O`, so they must not guard input data.
        if sample.shape != (64, 64, 1):
            raise ValueError('expected sample of shape (64, 64, 1), '
                             'got {}'.format(sample.shape))

        # HxWxC -> CxHxW: pytorch expects channels-first layout.
        sample = np.moveaxis(sample, 2, 0)
        sample = torch.from_numpy(sample).float()

        return sample

    def __getitem__(self, index):
        # Delegate to DatasetFolder: returns (sample_tensor, class_index).
        return self.dataset_folder[index]

    def __len__(self):
        return self.len_






if __name__ == '__main__':
    # Worker processes and pinned memory only pay off when feeding a GPU.
    use_cuda = False
    if use_cuda:
        loader_kwargs = {'num_workers': multiprocessing.cpu_count(),
                         'pin_memory': True}
    else:
        loader_kwargs = {}

    # Smoke-test the dataset: show the per-class index ranges and the shape
    # of a single loaded batch.
    dataset = FBanksCrossEntropyDataset('./dataset-speaker-csf/fbanks-test')
    print(dataset.label_to_index_range)
    loader = DataLoader(dataset, batch_size=1, shuffle=True, **loader_kwargs)
    first_batch = next(iter(loader))
    print(first_batch[0].shape)