#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
import gzip
import io
import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset

try:
    from PIL import UnidentifiedImageError

    unidentified_error_available = True
except ImportError:
    # UnidentifiedImageError isn't available in older versions of PIL
    unidentified_error_available = False

class DiskTarDataset(Dataset):
    def __init__(self,
        tarfile_path='dataset/imagenet/ImageNet-21k/metadata/tar_files.npy',
        tar_index_dir='dataset/imagenet/ImageNet-21k/metadata/tarindex_npy',
        preload=False,
        num_synsets="all"):
        """
        - preload (bool): if True, memory-map every tar file at construction.
            Recommended to leave this False so each dataloader worker opens
            its own memmap lazily on first access.
        - num_synsets (int or the string "all"): set to a small integer for
            debugging; only the first num_synsets tar files will be loaded.
        """
        # array of tar file paths, one per synset
        tar_files = np.load(tarfile_path)

        chunk_datasets = []
        dataset_lens = []
        if isinstance(num_synsets, int):
            assert num_synsets < len(tar_files), \
                f"num_synsets={num_synsets} exceeds the {len(tar_files)} available tar files"
            tar_files = tar_files[:num_synsets]
        for tar_file in tar_files:
            dataset = _TarDataset(tar_file, tar_index_dir, preload=preload)
            chunk_datasets.append(dataset)
            dataset_lens.append(len(dataset))

        self.chunk_datasets = chunk_datasets
        self.dataset_lens = np.array(dataset_lens).astype(np.int32)
        self.dataset_cumsums = np.cumsum(self.dataset_lens)
        self.num_samples = sum(self.dataset_lens)
        # the label of each sample is the index of the synset (tar chunk) it lives in
        labels = np.zeros(self.num_samples, dtype=np.int64)
        sI = 0
        for k in range(len(self.dataset_lens)):
            assert (sI + self.dataset_lens[k]) <= len(labels), f"{k} {sI + self.dataset_lens[k]} vs. {len(labels)}"
            labels[sI:(sI + self.dataset_lens[k])] = k
            sI += self.dataset_lens[k]
        self.labels = labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        assert 0 <= index < len(self)
        # find the chunk dataset this global index falls into; side='right'
        # maps an index equal to a cumulative sum into the next chunk
        d_index = np.searchsorted(self.dataset_cumsums, index, side='right')

        assert d_index == self.labels[index], f"{d_index} vs. {self.labels[index]} mismatch for {index}"

        # change index to local dataset index
        if d_index == 0:
            local_index = index
        else:
            local_index = index - self.dataset_cumsums[d_index - 1]
        data_bytes = self.chunk_datasets[d_index][local_index]
        exception_to_catch = UnidentifiedImageError if unidentified_error_available else Exception
        try:
            image = Image.open(data_bytes).convert("RGB")
        except exception_to_catch:
            # unreadable image: return a gray 224x224 placeholder and flag it
            # with label -1 so it can be filtered out downstream
            image = Image.fromarray(np.ones((224, 224, 3), dtype=np.uint8) * 128)
            d_index = -1

        # the label is the dataset (synset) we indexed into; -1 on decode failure
        return image, d_index, index

    def __repr__(self):
        return f"DiskTarDataset(subdatasets={len(self.dataset_lens)}, samples={self.num_samples})"
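
# Usage sketch (the default metadata paths above must exist on disk):
#
#     ds = DiskTarDataset(num_synsets=5)       # small subset for debugging
#     image, synset_id, global_index = ds[0]   # PIL image, int, int
#     # synset_id is -1 when the underlying image failed to decode
#
# Note: the default DataLoader collate function cannot batch PIL images, so
# convert them to tensors (e.g. via a torchvision transform) before batching.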

class _TarDataset(object):

    def __init__(self, filename, npy_index_dir, preload=False):
        # translated from
        # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua
        self.filename = filename
        self.npy_index_dir = npy_index_dir
        names, offsets = self.load_index()
        self.num_samples = len(names)
        if preload:
            # eagerly memory-map the whole tar archive as raw bytes
            self.data = np.memmap(filename, mode='r', dtype='uint8')
            self.offsets = offsets
        else:
            # defer the memmap so each dataloader worker opens its own handle
            self.data = None
            self.offsets = None

    def __len__(self):
        return self.num_samples

    def load_index(self):
        # the index lives beside the tar as <basename>_names.npy and
        # <basename>_offsets.npy; offsets count 512-byte tar blocks
        basename = os.path.basename(self.filename)
        basename = os.path.splitext(basename)[0]
        names = np.load(os.path.join(self.npy_index_dir, f"{basename}_names.npy"))
        offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy"))
        return names, offsets

    def __getitem__(self, idx):
        if self.data is None:
            # lazy open: runs once per dataloader worker, so memmap handles
            # are not shared across processes
            self.data = np.memmap(self.filename, mode='r', dtype='uint8')
            _, self.offsets = self.load_index()

        # offsets are in 512-byte tar blocks; slice header plus padded payload
        ofs = self.offsets[idx] * 512
        fsize = 512 * (self.offsets[idx + 1] - self.offsets[idx])
        data = self.data[ofs:ofs + fsize]

        # skip the tar header: GNU long-name entries start with a
        # '././@LongLink' pseudo-header and occupy three blocks (pseudo-header,
        # name block, real header); ordinary entries occupy one
        if data[:13].tobytes() == b'././@LongLink':
            data = data[3 * 512:]
        else:
            data = data[512:]

        # just to make it more fun, a few JPEGs are gzip compressed;
        # detect the gzip magic bytes and decompress transparently
        if tuple(data[:2]) == (0x1f, 0x8b):
            sdata = gzip.decompress(data.tobytes())
        else:
            sdata = data.tobytes()
        return io.BytesIO(sdata)
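

# The *_names.npy / *_offsets.npy index files are assumed to be produced by a
# separate preprocessing step that is not part of this file. The sketch below
# is NOT the original indexing script, only a minimal reconstruction using the
# stdlib tarfile module: each offset is the 512-byte block index of a member's
# first header (tarfile reports the '././@LongLink' pseudo-header position for
# GNU long names, which matches the reader above), plus a final sentinel entry
# so that offsets[idx + 1] is valid for the last member. It assumes the
# archive contains only regular image files.
def build_tar_index(tar_path, out_dir):
    import tarfile

    names, offsets = [], []
    end_block = 0
    with tarfile.open(tar_path) as tf:
        for member in tf:
            names.append(member.name)
            offsets.append(member.offset // 512)  # block index of the header
            # end of this member's zero-padded payload, in 512-byte blocks
            end_block = (member.offset_data + 512 * ((member.size + 511) // 512)) // 512
    offsets.append(end_block)  # sentinel delimiting the last member
    base = os.path.splitext(os.path.basename(tar_path))[0]
    np.save(os.path.join(out_dir, f"{base}_names.npy"), np.array(names))
    np.save(os.path.join(out_dir, f"{base}_offsets.npy"),
            np.array(offsets, dtype=np.int64))


if __name__ == "__main__":
    # minimal smoke test; assumes the default metadata paths above exist
    ds = DiskTarDataset(num_synsets=2)
    print(ds)
    image, synset_id, global_index = ds[0]
    print(image.size, synset_id, global_index)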