Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
import os | |
import gzip | |
import numpy as np | |
import io | |
from PIL import Image | |
from torch.utils.data import Dataset | |
# Feature probe: UnidentifiedImageError isn't available in older versions
# of PIL, so record whether the import worked instead of failing here.
try:
    from PIL import UnidentifiedImageError
except ImportError:
    unidentified_error_available = False
else:
    unidentified_error_available = True
class DiskTarDataset(Dataset):
    """Map-style dataset that concatenates several tar-backed chunks.

    Each entry of the ``tarfile_path`` array becomes one ``_TarDataset``
    chunk, and the chunk's position in that array is used as the label of
    every sample it contains.
    """

    def __init__(self,
                 tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
                 tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
                 preload=False,
                 num_synsets="all"):
        """
        - tarfile_path (str): ``.npy`` file holding the list of tar file paths
        - tar_index_dir (str): directory passed to every chunk as the location
          of its index files
        - preload (bool): forwarded unchanged to every ``_TarDataset`` chunk
        - num_synsets (integer or string "all"): set to small number for debugging
          will load subset of dataset (only the first ``num_synsets`` tar files)

        Raises ValueError when an integer ``num_synsets`` exceeds the number
        of tar files listed in ``tarfile_path``.
        """
        tar_files = np.load(tarfile_path)
        if isinstance(num_synsets, int):
            if num_synsets > len(tar_files):
                raise ValueError(
                    f"num_synsets={num_synsets} exceeds the {len(tar_files)} available tar files"
                )
            tar_files = tar_files[:num_synsets]

        self.chunk_datasets = [
            _TarDataset(tar_file, tar_index_dir, preload=preload)
            for tar_file in tar_files
        ]
        # int64 so that neither the cumulative sum nor the total can wrap
        # around for very large collections.
        self.dataset_lens = np.array(
            [len(chunk) for chunk in self.chunk_datasets]
        ).astype(np.int64)
        # dataset_cumsums[k] is the global index one past the end of chunk k.
        self.dataset_cumsums = np.cumsum(self.dataset_lens)
        self.num_samples = int(self.dataset_lens.sum())
        # labels[i] is the chunk index that global sample i belongs to.
        self.labels = np.repeat(
            np.arange(len(self.dataset_lens), dtype=np.int64), self.dataset_lens
        )

    def __len__(self):
        """Return the total number of samples over all chunks."""
        return self.num_samples

    def _locate(self, index):
        """Map a global sample index to ``(chunk index, index inside chunk)``."""
        # side='right' counts the chunks that end at or before `index`; that
        # count is the owning chunk, and it skips over empty chunks (repeated
        # cumulative sums) correctly.
        d_index = int(np.searchsorted(self.dataset_cumsums, index, side='right'))
        if d_index == 0:
            return d_index, index
        return d_index, index - int(self.dataset_cumsums[d_index - 1])

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for the global sample ``index``.

        ``image`` is an RGB PIL image and ``label`` is the index of the chunk
        the sample came from. If the stored bytes cannot be decoded, a uniform
        gray 224x224 image is returned and ``label`` is set to -1.

        Raises IndexError when ``index`` is outside ``[0, len(self))``;
        negative indices are not supported.
        """
        if not 0 <= index < len(self):
            raise IndexError(
                f"index {index} is out of range for dataset with {len(self)} samples"
            )
        # find the dataset file we need to go to
        d_index, local_index = self._locate(index)
        assert d_index == self.labels[index], f"{d_index} vs. {self.labels[index]} mismatch for {index}"
        data_bytes = self.chunk_datasets[d_index][local_index]
        # Without UnidentifiedImageError (older PIL) fall back to catching
        # any exception raised while decoding.
        exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            image = Image.fromarray(np.full((224, 224, 3), 128, dtype=np.uint8))
            d_index = -1
        # label is the dataset (synset) we indexed into
        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
class _TarDataset(object):
    """Random access to the members of one tar file through a block index.

    The index lives in ``<npy_index_dir>/<basename>_names.npy`` and
    ``<npy_index_dir>/<basename>_offsets.npy``. Offsets are counted in
    512-byte blocks; member ``i`` spans blocks ``offsets[i]`` up to
    ``offsets[i + 1]``, so the offsets array is assumed to hold one more
    entry than there are members (TODO confirm against the index writer).
    """

    def __init__(self, filename, npy_index_dir, preload=False):
        """
        - filename (str): path of the tar file
        - npy_index_dir (str): directory containing the two index files
        - preload (bool): when True, memory-map the tar file and keep the
          offsets right away; when False both are set up on first access
        """
        # translated from
        # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
        self.filename = filename
        self.names = []
        self.offsets = []
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()
        self.num_samples = len(names)
        if preload:
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            # Deferred until the first __getitem__ call.
            self.data = None

    def __len__(self):
        """Return the number of members listed in the names index."""
        return self.num_samples

    def load_index(self):
        """Load and return ``(names, offsets)`` for this tar file."""
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        """Return the contents of member ``idx`` as an ``io.BytesIO``.

        The returned bytes run to the end of the member's last 512-byte
        block, so they may include trailing padding. Gzip-compressed
        payloads are decompressed transparently.
        """
        if self.data is None:
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()

        # Convert to Python ints before scaling so the byte offsets cannot
        # overflow a narrow integer dtype used by the index file.
        start = int(self.offsets[idx]) * 512
        end = int(self.offsets[idx + 1]) * 512
        data = self.data[start:end]

        # Compare as bytes: tobytes() returns bytes, which never equal a str.
        if data[:13].tobytes() == b'././@LongLink':
            # Skip three blocks; presumably the long-name header, one block
            # holding the name, and the member's real header.
            data = data[3 * 512:]
        else:
            # Skip the single header block.
            data = data[512:]

        # just to make it more fun a few JPEGs are GZIP compressed...
        # catch this case
        if tuple(data[:2]) == (0x1f, 0x8b):
            with gzip.GzipFile(fileobj=io.BytesIO(data.tobytes()), mode='rb') as g:
                sdata = g.read()
        else:
            sdata = data.tobytes()

        return io.BytesIO(sdata)