File size: 1,578 Bytes
fa7be76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import h5py
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor


class MNIST(Dataset):
    """MNIST-style images and labels loaded eagerly from an HDF5 file.

    The file must contain one dataset per class, keyed by the class index as
    a string (``'0'``, ``'1'``, ...). Every image stored under key ``str(i)``
    receives label ``i``. The entire file is read into memory in
    ``__init__``, and the file handle is closed before ``__init__`` returns.

    Args:
        h5_file: Path (or anything ``h5py.File`` accepts) of the HDF5 file.
        transform: Callable applied to each image in ``__getitem__``. The
            default is a single ``ToTensor()`` instance created when the
            class is defined and shared by every dataset that uses the
            default. Pass a falsy value (e.g. ``None``) to get raw arrays.
        num_classes: Number of per-class datasets to read, i.e. keys
            ``'0'`` .. ``str(num_classes - 1)``. Defaults to 10.

    Attributes:
        data: All images concatenated along the first axis, grouped by
            class in key order.
        labels: Class index for each entry of ``data`` (same length).
    """

    def __init__(self, h5_file, transform=ToTensor(), num_classes=10):
        self.h5_file = h5_file
        self.transform = transform
        # Read the HDF5 file: one full array per class key.
        with h5py.File(self.h5_file, 'r') as file:
            groups = [file[str(i)][()] for i in range(num_classes)]
        # Join whole per-class arrays in one call, rather than appending
        # images one at a time to a Python list and converting afterwards.
        # np.concatenate needs at least one array, so handle num_classes == 0.
        self.data = np.concatenate(groups) if groups else np.array([])
        # np.arange produces the same default integer dtype that np.array
        # gives for a list of Python ints, so the label dtype is unchanged.
        self.labels = np.repeat(np.arange(num_classes), [len(g) for g in groups])

    def __len__(self):
        """Return the total number of images across all classes."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at position ``idx``.

        The image is passed through ``self.transform`` when a truthy
        transform is set; otherwise the raw array slice is returned.
        """
        image = self.data[idx]
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label


if __name__ == '__main__':
    mnist_h5_dataset = MNIST('data/mnist.h5')

    # Sanity check: the file is expected to hold 70,000 images in total.
    # Raise explicitly; an `assert` would be skipped under `python -O`.
    expected_len = 70000
    if len(mnist_h5_dataset) != expected_len:
        raise RuntimeError(
            f"expected {expected_len} samples, got {len(mnist_h5_dataset)}"
        )

    # Display the first 10 images of each digit, along with their labels,
    # in a 10x10 grid: one row per digit.
    n_rows, n_cols = 10, 10
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(10, 10))
    for digit in range(n_rows):
        images = mnist_h5_dataset.data[mnist_h5_dataset.labels == digit]
        for col in range(n_cols):
            ax = axs[digit, col]
            # Leave the tile blank instead of raising IndexError when a
            # digit has fewer than n_cols images.
            if col < len(images):
                ax.imshow(images[col], cmap='gray')
                ax.set_title(digit)
            ax.axis('off')
    plt.tight_layout()
    plt.savefig("mnist_h5_dataset.png")
    # Release the figure's resources once it has been written to disk.
    plt.close(fig)