caixiaoshun's picture
使用huggingface hub尝试更新
fa7be76 verified
import h5py
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
class MNIST(Dataset):
    """MNIST-style digit images loaded from an HDF5 file grouped by label.

    The HDF5 file must contain one dataset per digit class, keyed by the
    strings ``"0"`` through ``"9"``. Every image stored under key ``"k"`` is
    given label ``k``. All images are read fully into memory once, at
    construction time.

    Args:
        h5_file: Path to the HDF5 file.
        transform: Callable applied to each image in ``__getitem__``. Pass
            ``None`` to get the raw NumPy image back. Defaults to ``ToTensor()``.

    Attributes:
        data: All images stacked along axis 0, class by class (all "0" images
            first, then "1", ...).
        labels: Integer label for each row of ``data``, in the same order.
    """

    def __init__(self, h5_file, transform=ToTensor()):
        self.h5_file = h5_file
        self.transform = transform
        # Read the HDF5 file: `[()]` loads each per-class dataset fully into an
        # ndarray, so the file can be closed immediately afterwards.
        with h5py.File(self.h5_file, 'r') as file:
            per_class = [file[str(digit)][()] for digit in range(10)]
        # One concatenate instead of appending every image to a Python list and
        # converting back to an array.
        self.data = np.concatenate(per_class, axis=0)
        # Build labels in the same class-by-class order as `self.data`.
        self.labels = np.concatenate(
            [np.full(len(images), digit, dtype=np.int64)
             for digit, images in enumerate(per_class)]
        )

    def __len__(self):
        """Return the total number of images across all classes."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` at ``idx``.

        The image is passed through ``self.transform`` unless it is ``None``.
        """
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label
if __name__ == '__main__':
    mnist_h5_dataset = MNIST('data/mnist.h5')
    # Sanity-check the total image count. An explicit raise is used rather than
    # `assert`, which is stripped when Python runs with -O.
    expected_size = 70000
    if len(mnist_h5_dataset) != expected_size:
        raise RuntimeError(
            f"expected {expected_size} images in the dataset, "
            f"found {len(mnist_h5_dataset)}"
        )

    # Display the first 10 images of each digit, along with their labels,
    # in a 10x10 grid (one row per digit).
    fig, axs = plt.subplots(10, 10, figsize=(10, 10))
    for digit in range(10):
        images = mnist_h5_dataset.data[mnist_h5_dataset.labels == digit]
        for col, ax in enumerate(axs[digit]):
            # Leave the cell blank instead of raising IndexError when a digit
            # has fewer samples than grid columns.
            if col < len(images):
                ax.imshow(images[col], cmap='gray')
                ax.set_title(str(digit))
            ax.axis('off')
    plt.tight_layout()
    plt.savefig("mnist_h5_dataset.png")
    # Release the figure's memory now that it has been written to disk.
    plt.close(fig)