taesiri's picture
Duplicate from taesiri/DeticChatGPT
f97cf44
raw
history blame
4.95 kB
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import os
import gzip
import numpy as np
import io
from PIL import Image
from torch.utils.data import Dataset
try:
from PIL import UnidentifiedImageError
unidentified_error_available = True
except ImportError:
# UnidentifiedImageError isn't available in older versions of PIL
unidentified_error_available = False
class DiskTarDataset(Dataset):
    """Map-style dataset that concatenates many tar-file chunks.

    Every entry of the array stored at ``tarfile_path`` is wrapped in a
    ``_TarDataset``.  The chunks are laid end to end in one flat index space
    and the position of a chunk in that list doubles as the label of every
    sample it contains.
    """

    def __init__(self,
                 tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
                 tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
                 preload=False,
                 num_synsets="all"):
        """
        - tarfile_path (str): ``.npy`` file (read with ``np.load``) holding
            one tar file path per chunk.
        - tar_index_dir (str): directory handed to every ``_TarDataset`` so
            it can find its ``*_names.npy`` / ``*_offsets.npy`` index files.
        - preload (bool): forwarded unchanged to every ``_TarDataset``.
            Upstream recommends leaving this set to False.
        - num_synsets (integer or string "all"): set to small number for
            debugging; only the first ``num_synsets`` tar files are used.
            Any non-integer value keeps all of them.
        """
        tar_files = np.load(tarfile_path)
        if isinstance(num_synsets, int):
            assert num_synsets < len(tar_files)
            tar_files = tar_files[:num_synsets]

        self.chunk_datasets = [
            _TarDataset(tar_file, tar_index_dir, preload=preload)
            for tar_file in tar_files
        ]
        self.dataset_lens = np.array(
            [len(chunk) for chunk in self.chunk_datasets], dtype=np.int32)
        # Accumulate in 64 bits so the running totals cannot wrap even though
        # the per-chunk lengths are stored as int32.
        self.dataset_cumsums = np.cumsum(self.dataset_lens, dtype=np.int64)
        # Plain Python int so that ``len(self)`` never sees a NumPy scalar.
        self.num_samples = int(self.dataset_lens.sum(dtype=np.int64))
        # labels[i] is the chunk index that flat sample i belongs to.
        self.labels = np.repeat(
            np.arange(len(self.dataset_lens), dtype=np.int64),
            self.dataset_lens)

    def __len__(self):
        """Total number of samples across all chunks."""
        return self.num_samples

    def _locate(self, index):
        """Map a flat sample index to ``(chunk_index, index_inside_chunk)``."""
        # side='right' returns the first chunk whose cumulative end is
        # strictly greater than ``index``.  That moves boundary indices to the
        # next chunk and also steps over empty chunks, whose cumulative ends
        # repeat the previous value.
        d_index = int(np.searchsorted(self.dataset_cumsums, index, side='right'))
        if d_index == 0:
            return d_index, int(index)
        return d_index, int(index - self.dataset_cumsums[d_index - 1])

    def __getitem__(self, index):
        """Return ``(image, label, index)`` for flat sample ``index``.

        ``image`` is an RGB PIL image and ``label`` is the chunk index.  If
        the stored bytes cannot be decoded, a uniform gray 224x224 image is
        returned with label ``-1`` instead.
        """
        assert index >= 0 and index < len(self)
        # find the dataset file we need to go to, and the local index in it
        d_index, local_index = self._locate(index)
        assert d_index == self.labels[index], f"{d_index} vs. {self.labels[index]} mismatch for {index}"

        data_bytes = self.chunk_datasets[d_index][local_index]
        # Older PIL versions lack UnidentifiedImageError; fall back to a
        # broad catch there.
        exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            image = Image.fromarray(np.ones((224, 224, 3), dtype=np.uint8) * 128)
            d_index = -1
        # label is the dataset (synset) we indexed into
        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})"
class _TarDataset(object):
    """Random access to the members of a single tar file via a block index.

    The index lives in ``npy_index_dir`` as ``<basename>_names.npy`` and
    ``<basename>_offsets.npy``.  Offsets are counted in 512-byte blocks, and
    member ``i`` spans blocks ``offsets[i]`` to ``offsets[i + 1]``, so the
    offsets array is expected to hold one more entry than there are names.
    """

    def __init__(self, filename, npy_index_dir, preload=False):
        """
        - filename (str): path of the tar file.
        - npy_index_dir (str): directory containing the two index arrays.
        - preload (bool): when True the tar file is memory-mapped and the
            offsets are kept right away; otherwise both are deferred until
            the first ``__getitem__`` call.
        """
        # translated from
        # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
        self.filename = filename
        self.names = []
        self.offsets = []
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()

        self.num_samples = len(names)

        if preload:
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            self.data = None

    def __len__(self):
        """Number of members listed in the names index."""
        return self.num_samples

    def load_index(self):
        """Load and return ``(names, offsets)`` for this tar file."""
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        """Return the payload of member ``idx`` as an ``io.BytesIO``.

        The payload keeps whatever tar block padding follows the file data,
        except for gzip-compressed members, which are returned decompressed.
        """
        if self.data is None:
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()

        # Convert to Python ints before scaling so a narrow integer dtype in
        # the index cannot overflow when turned into a byte position.
        start = int(self.offsets[idx]) * 512
        stop = int(self.offsets[idx + 1]) * 512
        data = self.data[start:stop]

        # tobytes() yields bytes, so the marker must be a bytes literal; a
        # str literal never compares equal on Python 3.
        if data[:13].tobytes() == b'././@LongLink':
            # Long-name entry: skip three blocks -- presumably the LongLink
            # header, the block holding the long name and the real header.
            data = data[3 * 512:]
        else:
            # Regular entry: skip the single header block.
            data = data[512:]

        # just to make it more fun a few JPEGs are GZIP compressed...
        # catch this case (gzip magic number 0x1f 0x8b)
        if data[:2].tobytes() == b'\x1f\x8b':
            with gzip.GzipFile(fileobj=io.BytesIO(data.tobytes()), mode='rb') as g:
                sdata = g.read()
        else:
            sdata = data.tobytes()

        return io.BytesIO(sdata)