import os
import gzip
import io

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
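
# Data layout assumed by this module (inferred from the loading code below):
# one tar file per synset, listed in tar_files.npy, plus per-tar index files
# <basename>_names.npy / <basename>_offsets.npy built offline, with offsets
# measured in 512-byte tar blocks. See the _build_tar_index sketch below.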


# Pillow >= 7.0 raises UnidentifiedImageError for images it cannot parse;
# older versions lack it, so fall back to catching any Exception.
try:
    from PIL import UnidentifiedImageError

    unidentified_error_available = True
except ImportError:
    unidentified_error_available = False


class DiskTarDataset(Dataset):
    def __init__(self,
                 tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
                 tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
                 preload=False,
                 num_synsets="all"):
        """
        - preload (bool): if True, memory-map every tar file at construction
          time. Recommended to leave this False (the default), so each tar
          file is mapped lazily on first access.
        - num_synsets (int or the string "all"): set to a small integer for
          debugging; only the first num_synsets tar files will be loaded.
        """
        tar_files = np.load(tarfile_path)

        chunk_datasets = []
        dataset_lens = []
        if isinstance(num_synsets, int):
            assert num_synsets < len(tar_files)
            tar_files = tar_files[:num_synsets]
        for tar_file in tar_files:
            dataset = _TarDataset(tar_file, tar_index_dir, preload=preload)
            chunk_datasets.append(dataset)
            dataset_lens.append(len(dataset))

        self.chunk_datasets = chunk_datasets
        self.dataset_lens = np.array(dataset_lens).astype(np.int32)
        self.dataset_cumsums = np.cumsum(self.dataset_lens)
        self.num_samples = int(self.dataset_lens.sum())

        # Label each sample with the index of the tar file (synset) it
        # belongs to.
        labels = np.zeros(self.num_samples, dtype=np.int64)
        sI = 0
        for k in range(len(self.dataset_lens)):
            assert (sI + self.dataset_lens[k]) <= len(labels), \
                f"{k} {sI + self.dataset_lens[k]} vs. {len(labels)}"
            labels[sI:(sI + self.dataset_lens[k])] = k
            sI += self.dataset_lens[k]
        self.labels = labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        assert 0 <= index < len(self)

        # dataset_cumsums[k] is the exclusive end of chunk k, so the chunk
        # containing `index` is the first one whose cumsum exceeds it.
        # side='right' maps index == dataset_cumsums[k] to chunk k + 1 and
        # also skips over empty chunks (duplicate cumsum values).
        d_index = int(np.searchsorted(self.dataset_cumsums, index, side='right'))
        assert d_index == self.labels[index], \
            f"{d_index} vs. {self.labels[index]} mismatch for {index}"

        if d_index == 0:
            local_index = index
        else:
            local_index = index - self.dataset_cumsums[d_index - 1]
        data_bytes = self.chunk_datasets[d_index][local_index]

        exception_to_catch = (
            UnidentifiedImageError if unidentified_error_available else Exception
        )
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            # Unreadable image: substitute a uniform gray image and mark the
            # label as -1 so downstream code can filter these samples out.
            image = Image.fromarray(np.ones((224, 224, 3), dtype=np.uint8) * 128)
            d_index = -1

        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
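

# The *_names.npy / *_offsets.npy files consumed by _TarDataset.load_index
# are expected to be built offline. The helper below is a minimal sketch of
# such a builder, not part of the original pipeline: the function name is
# hypothetical, and it relies on the real but lightly documented
# TarInfo.offset / TarFile.offset attributes. It assumes the archive itself
# is uncompressed, since block offsets are only meaningful in that case.
def _build_tar_index(tar_path, out_dir):
    """Write <basename>_names.npy and <basename>_offsets.npy for one tar."""
    import tarfile

    names, offsets = [], []
    with tarfile.open(tar_path) as tf:
        for member in tf:
            names.append(member.name)
            # TarInfo.offset is the byte offset of the member's header;
            # _TarDataset expects offsets in 512-byte tar blocks.
            offsets.append(member.offset // 512)
        # One trailing entry so member i spans blocks
        # [offsets[i], offsets[i + 1]).
        offsets.append(tf.offset // 512)
    basename = os.path.splitext(os.path.basename(tar_path))[0]
    np.save(os.path.join(out_dir, f"{basename}_names.npy"), np.array(names))
    np.save(os.path.join(out_dir, f"{basename}_offsets.npy"),
            np.array(offsets, dtype=np.int64))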


class _TarDataset(object):
    def __init__(self, filename, npy_index_dir, preload=False):
        self.filename = filename
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()
        self.num_samples = len(names)
        if preload:
            # Map the whole tar file into memory up front.
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            # Defer mapping until the first __getitem__ call.
            self.data = None
            self.offsets = None

    def __len__(self):
        return self.num_samples

    def load_index(self):
        # Index files are named after the tar file: <basename>_names.npy and
        # <basename>_offsets.npy. Offsets are in 512-byte tar blocks, with
        # one more entry than names so offsets[i + 1] marks the end of
        # member i.
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        if self.data is None:
            # Lazy initialization (preload=False): map the tar file and load
            # the offset index on first access.
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()

        # Member idx occupies blocks [offsets[idx], offsets[idx + 1]),
        # including its header block(s); offsets are 512-byte tar blocks.
        ofs = self.offsets[idx] * 512
        fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
        data = self.data[ofs:ofs + fsize]

        if data[:13].tobytes() == b'././@LongLink':
            # GNU tar long-name entry: skip the LongLink header, the name
            # block, and the real header (assumes the name fits in one block).
            data = data[3 * 512:]
        else:
            # Skip the single 512-byte tar header block.
            data = data[512:]

        # Members may be stored gzip-compressed; check for the gzip magic.
        if tuple(data[:2]) == (0x1f, 0x8b):
            with gzip.GzipFile(fileobj=io.BytesIO(data.tobytes()), mode='r') as g:
                sdata = g.read()
        else:
            sdata = data.tobytes()
        return io.BytesIO(sdata)
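

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: it assumes the default metadata
    # layout from the constructor exists on disk, and num_synsets=2 restricts
    # loading to the first two tar files for a quick check.
    dataset = DiskTarDataset(num_synsets=2)
    print(dataset)
    image, label, index = dataset[0]
    print(image.size, label, index)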