Spaces:
Running
Running
import os | |
import torch | |
import torch.utils.data as data | |
import numpy as np | |
from torchvision.datasets import ImageNet | |
from PIL import Image, ImageFilter | |
import h5py | |
from glob import glob | |
class ImagenetSegmentation(data.Dataset):
    """ImageNet segmentation dataset backed by a single HDF5 file.

    Samples are addressed through HDF5 object references: row ``index`` of
    ``/value/img`` references the image array, and row ``index`` of
    ``/value/gt`` references a dataset whose ``[0, 0]`` entry in turn
    references the ground-truth mask.

    NOTE(review): this reference layout and the axis reversal done in
    ``__getitem__`` look like a MATLAB v7.3 export -- confirm against the
    data source.
    """

    # Number of segmentation classes.
    # NOTE(review): presumably foreground/background -- confirm.
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        """Record the configuration and count the samples stored in ``path``.

        Args:
            path: Path to the HDF5 file.
            transform: Optional callable applied to the RGB PIL image.
            target_transform: Optional callable applied to the PIL mask; its
                result is cast to ``int32`` and returned as a ``torch.long``
                tensor.
        """
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # The read handle is opened lazily on the first __getitem__ call
        # (presumably so each DataLoader worker opens its own handle rather
        # than inheriting one from the parent process).
        self.h5py = None
        # Open the file only to count samples; the context manager closes the
        # handle even if the dataset lookup or len() raises.
        with h5py.File(path, 'r') as h5_file:
            self.data_length = len(h5_file['/value/img'])

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair stored at ``index``.

        ``img`` is the RGB PIL image, passed through ``transform`` if one was
        given. ``target`` is the mask as a PIL image, or -- when
        ``target_transform`` is set -- the transformed mask converted to a
        ``torch.long`` tensor.
        """
        if self.h5py is None:
            self.h5py = h5py.File(self.path, 'r')
        h5_file = self.h5py

        # Dereference the stored object references. The transposes reverse
        # the axis order of the stored arrays (presumably column-major data
        # written by MATLAB -- verify).
        img_ref = h5_file['/value/img'][index, 0]
        img = np.array(h5_file[img_ref]).transpose((2, 1, 0))
        gt_ref = h5_file[h5_file['/value/gt'][index, 0]][0, 0]
        target = np.array(h5_file[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        """Return the number of samples counted at construction time."""
        return self.data_length