Spaces:
Sleeping
Sleeping
import h5py | |
import matplotlib.pyplot as plt | |
import numpy as np | |
from torch.utils.data import Dataset | |
from torchvision.transforms import ToTensor | |
class MNIST(Dataset):
    """Map-style dataset of digit images stored in an HDF5 file.

    The file is expected to contain ten datasets named ``'0'`` through
    ``'9'``; each one holds the images for that digit stacked along its
    first axis. Everything is read into memory when the object is built.

    Attributes:
        h5_file: Path of the HDF5 file the data was read from.
        transform: Callable applied to each image in ``__getitem__``, or a
            falsy value to return the raw array.
        data: NumPy array with all images, ordered digit 0 first, digit 9 last.
        labels: NumPy array of integer labels aligned with ``data``.
    """

    def __init__(self, h5_file, transform=ToTensor()):
        """Load all ten digit groups from ``h5_file`` into memory.

        Args:
            h5_file: Path to the HDF5 file to open read-only.
            transform: Optional callable applied to every image on access.
                Note the default ``ToTensor()`` instance is created once,
                when the class is defined, and shared by every dataset
                that does not pass its own transform.

        Raises:
            KeyError: If one of the datasets ``'0'``..``'9'`` is missing.
        """
        self.h5_file = h5_file
        self.transform = transform
        # Read the HDF5 file. ``[()]`` copies each dataset into a NumPy
        # array, so nothing refers to the file once the ``with`` exits.
        with h5py.File(self.h5_file, 'r') as file:
            per_digit = [file[str(digit)][()] for digit in range(10)]
        # Join the per-digit arrays in one vectorised call instead of
        # appending image by image in a Python loop.
        self.data = np.concatenate(per_digit)
        # Label ``d`` repeated once for every image in group ``d``.
        self.labels = np.repeat(
            np.arange(10), [len(images) for images in per_digit]
        )

    def __len__(self):
        """Return the total number of images across all digits."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        The image is passed through ``self.transform`` when one is set;
        otherwise the stored array is returned unchanged.
        """
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
if __name__ == '__main__':
    # Smoke check: load the dataset and confirm the expected sample count.
    mnist_h5_dataset = MNIST('data/mnist.h5')
    assert len(mnist_h5_dataset) == 70000

    # Draw a 10x10 grid: row r shows the first ten stored images whose
    # label is r, each titled with that label, then save it to disk.
    figure, axes_grid = plt.subplots(10, 10, figsize=(10, 10))
    for digit, axes_row in enumerate(axes_grid):
        is_digit = mnist_h5_dataset.labels == digit
        digit_images = mnist_h5_dataset.data[is_digit]
        for column, axis in enumerate(axes_row):
            axis.imshow(digit_images[column], cmap='gray')
            axis.axis('off')
            axis.set_title(digit)
    plt.tight_layout()
    plt.savefig("mnist_h5_dataset.png")