import collections.abc as container_abcs
import queue
import random
import re
from copy import deepcopy
from threading import Thread

import pyarrow as pa
import torch

# local replacement for the removed torch._six.string_classes
string_classes = (str, bytes)

def pin_memory(data):
    """Recursively pin the host memory of tensors nested in common containers."""
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    elif isinstance(data, string_classes):
        return data
    elif isinstance(data, container_abcs.Mapping):
        return {k: pin_memory(sample) for k, sample in data.items()}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return type(data)(*(pin_memory(sample) for sample in data))
    elif isinstance(data, container_abcs.Sequence):
        return [pin_memory(sample) for sample in data]
    elif hasattr(data, "pin_memory"):
        return data.pin_memory()
    else:
        return data
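
# Illustrative example (not part of the original module): pin_memory preserves
# container structure, so a whole collated batch can be pinned in one call
# (pinning requires a CUDA-capable torch build):
#
#   batch = {"image": torch.zeros(4, 3, 32, 32), "label": torch.zeros(4)}
#   pinned = pin_memory(batch)
#   assert pinned["image"].is_pinned() and pinned["label"].is_pinned()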


np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, string_classes):
        return batch
    elif isinstance(elem, container_abcs.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, container_abcs.Sequence):
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
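
# Illustrative example (not part of the original module): default_collate stacks
# per-sample fields into batched tensors while preserving dict structure:
#
#   samples = [{"x": torch.ones(3), "y": 1}, {"x": torch.zeros(3), "y": 0}]
#   batch = default_collate(samples)
#   # batch["x"].shape == (2, 3); batch["y"] equals tensor([1, 0])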


class CircularCachedInputIterator(object):
    """
    chunk: a serialized List[Dict] in the apache arrow format,
        could be sequentially loaded into memory with minimum deserialization cost(<1ms)
    shard: a part of dataset which is allocated to a specific rank(process) in the world,
        generally contains multiple chunks

    main thread:
        - populate chunk_index_queue
        - swap new chunk and old chunk
    prefetch threads:
        - fetch chunk_index_queue
        - populate loaded_chunk_queue

    chunk_index_queue: main -> prefetch, used for shuffling chunk order per epoch
    loaded_chunk_queue: preftch -> main, a limited-size channel for prefetching worker to send back result
    """
    def __init__(self,
                 input_map,
                 batch_size,
                 chunk_path_list,
                 num_data_point,
                 num_shards,
                 shard_id,
                 random_shuffle,
                 num_prefetch_chunk=4,
                 num_worker=4):
        self.input_map = input_map
        self.batch_size = batch_size
        self.num_shards = num_shards
        self.shard_id = shard_id
        self.random_shuffle = random_shuffle
        self.num_data_point = num_data_point
        self.chunk_filename_list = chunk_path_list
        self.chunk = None
        self.indexes = None  # per-chunk data-point indexes, filled in get_index
        self.next_chunk_queue = queue.Queue(num_prefetch_chunk)
        self.index_queue = queue.Queue()
        self.chunk_index_queue = queue.Queue()
        self.num_chunk_in_shard = None
        self.chunk_indexes = None
        self.worker = None
        self.num_worker = num_worker
        self.setup_shard()
        self.warmup_cache()

    def setup_shard(self):
        # ensure each shard sees the same number of chunks per epoch
        # (this might not be strictly necessary)
        self.num_chunk_in_shard = len(self.chunk_filename_list) // self.num_shards
        # [start, end)
        shard_start = self.num_chunk_in_shard * self.shard_id
        shard_end = len(self.chunk_filename_list) if self.shard_id == self.num_shards - 1 \
            else self.num_chunk_in_shard * (self.shard_id + 1)
        self.chunk_indexes = list(range(shard_start, shard_end))
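
    # Worked example: with 10 chunk files and 3 shards, num_chunk_in_shard = 3
    # and the shard ranges are [0, 3), [3, 6), [6, 10). The last shard's extra
    # chunk is dropped by the [:num_chunk_in_shard] slice in get_chunk_index,
    # so every shard cycles through exactly 3 chunks per epoch.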

    def _chunk_prefetch_worker(self):
        # runs forever in a daemon thread: pull a chunk index, load that chunk
        # from disk, and hand it to the main thread via next_chunk_queue
        while True:
            chunk_index = self.get_chunk_index()
            chunk_filename = self.chunk_filename_list[chunk_index]
            with open(chunk_filename, "rb") as fin:
                # legacy pyarrow serialization API (deprecated and later
                # removed upstream); kept here to match the chunk format
                chunk = pa.deserialize_from(fin, None)
            self.next_chunk_queue.put(chunk)

    def warmup_cache(self):
        # start the daemon prefetch threads so chunks are already buffered
        # before the first call to __next__
        self.worker = [Thread(target=self._chunk_prefetch_worker) for _ in range(self.num_worker)]
        for worker in self.worker:
            worker.daemon = True
            worker.start()

    def get_chunk_index(self):
        # refill the queue with one (possibly reshuffled) epoch of chunk
        # indexes; note the empty-check plus refill is not atomic, so
        # concurrent prefetch threads may occasionally refill twice
        if self.chunk_index_queue.empty():
            if self.random_shuffle:
                random.shuffle(self.chunk_indexes)
            for ind in self.chunk_indexes[:self.num_chunk_in_shard]:
                self.chunk_index_queue.put(ind)
        return self.chunk_index_queue.get()

    def get_index(self):
        if self.index_queue.empty():
            if self.chunk is not None:
                del self.chunk  # drop the old chunk before blocking, to release memory
            self.chunk = self.next_chunk_queue.get()
            self.indexes = list(range(len(self.chunk)))
            if self.random_shuffle:
                random.shuffle(self.indexes)
            # enqueue every index of the freshly loaded chunk
            for ind in self.indexes:
                self.index_queue.put(ind)
        return self.index_queue.get()

    def epoch_size(self):
        return self.num_data_point // self.num_shards

    def __iter__(self):
        return self

    def __next__(self):
        datas = tuple([] for _ in self.input_map)
        for _ in range(self.batch_size):
            ind = self.get_index()
            data = self.chunk[ind]
            for i, k in enumerate(self.input_map):
                # deepcopy so consumers never alias (or reuse) the chunk's
                # buffer, which is released when the chunk is swapped out
                datas[i].append(deepcopy(data[k]))
        return datas

    next = __next__  # Python 2 compatibility alias
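

# A minimal usage sketch (illustrative; the paths, field names, and sizes below
# are assumptions, not part of the original module). Chunk files are assumed to
# be written with the matching legacy pyarrow API, e.g.:
#
#   chunk = [{"jpeg": jpeg_bytes, "label": 3}, ...]  # List[Dict]
#   with open("chunk_000.arrow", "wb") as fout:
#       pa.serialize_to(chunk, fout, None)
#
if __name__ == "__main__":
    iterator = CircularCachedInputIterator(
        input_map=["jpeg", "label"],
        batch_size=32,
        chunk_path_list=["chunk_000.arrow", "chunk_001.arrow"],
        num_data_point=20000,
        num_shards=1,
        shard_id=0,
        random_shuffle=True,
    )
    jpegs, labels = next(iterator)  # two lists, each of length batch_size
    print(len(jpegs), len(labels))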