Spaces:
Running
Running
File size: 1,367 Bytes
c64fb9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import os
import torch
import torch.utils.data as data
import numpy as np
from torchvision.datasets import ImageNet
from PIL import Image, ImageFilter
import h5py
from glob import glob
class ImagenetSegmentation(data.Dataset):
    """Segmentation dataset backed by a single HDF5 file.

    Each item is read through HDF5 object references stored in two datasets:

    * ``/value/img`` -- entry ``[index, 0]`` references the image array.
    * ``/value/gt`` -- entry ``[index, 0]`` references an intermediate
      dataset whose ``[0, 0]`` entry references the ground-truth mask.

    NOTE(review): this layout (cell-style object references, axes stored in
    reverse order) looks like a MATLAB v7.3 ``.mat`` file -- confirm against
    whatever produced the data.

    The HDF5 handle is opened lazily on the first ``__getitem__`` call. It is
    dropped when the dataset is pickled, so an unpickled copy (e.g. in a
    spawned DataLoader worker) opens its own handle.
    """

    # Number of segmentation classes; presumably foreground/background --
    # TODO confirm.
    CLASSES = 2

    def __init__(self,
                 path,
                 transform=None,
                 target_transform=None):
        """
        Args:
            path: Path to the HDF5 file.
            transform: Optional callable applied to the RGB ``PIL.Image``.
            target_transform: Optional callable applied to the mask
                ``PIL.Image``; its result is converted to an int64 tensor.
        """
        self.path = path
        self.transform = transform
        self.target_transform = target_transform
        # Lazily-opened h5py.File handle; None until the first __getitem__.
        self.h5py = None
        # Only the item count is needed up front. The `with` block guarantees
        # the probe handle is closed even if the lookup raises.
        with h5py.File(path, 'r') as probe:
            self.data_length = len(probe['/value/img'])

    def __getitem__(self, index):
        """Return the ``(img, target)`` pair at ``index``.

        ``img`` is an RGB ``PIL.Image``, or whatever ``transform`` returns.
        ``target`` is a ``PIL.Image`` when ``target_transform`` is None;
        otherwise it is ``target_transform``'s output converted through
        ``np.array(...).astype('int32')`` into an int64 ``torch.Tensor``.
        """
        if self.h5py is None:
            # Opened here rather than in __init__, presumably so the handle
            # is created in the process that actually reads the data.
            self.h5py = h5py.File(self.path, 'r')
        f = self.h5py

        # One level of indirection: /value/img[index, 0] -> image dataset.
        img_ref = f['/value/img'][index, 0]
        # Reverse the stored axis order; assumes the stored image is
        # (C, W, H), so this yields (H, W, C) as PIL expects -- TODO confirm.
        img = np.array(f[img_ref]).transpose((2, 1, 0))

        # Two levels of indirection: /value/gt[index, 0] -> intermediate
        # dataset, whose [0, 0] entry -> mask dataset.
        gt_ref = f[f['/value/gt'][index, 0]][0, 0]
        target = np.array(f[gt_ref]).transpose((1, 0))

        img = Image.fromarray(img).convert('RGB')
        target = Image.fromarray(target)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = np.array(self.target_transform(target)).astype('int32')
            target = torch.from_numpy(target).long()

        return img, target

    def __len__(self):
        """Number of items, as counted in ``/value/img`` at construction."""
        return self.data_length

    def close(self):
        """Close the lazily-opened HDF5 handle, if any. Safe to call twice."""
        if self.h5py is not None:
            self.h5py.close()
            self.h5py = None

    def __getstate__(self):
        """Return the pickle state without the open HDF5 handle.

        h5py File objects are not picklable. Without this hook, a dataset
        that has already served an item could not be sent to spawned
        DataLoader workers. The copy starts with ``h5py = None`` and reopens
        the file lazily on its first ``__getitem__``.
        """
        state = self.__dict__.copy()
        state['h5py'] = None
        return state
|