import collections.abc as container_abcs
import queue
import random
import re
from copy import deepcopy
from threading import Lock, Thread

import pyarrow as pa
import torch

string_classes = (str, bytes)
def pin_memory(data):
    """Recursively move tensors nested in common containers into pinned
    (page-locked) memory so host-to-GPU copies can run asynchronously."""
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, container_abcs.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, container_abcs.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
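# A minimal sketch of pin_memory on a nested container (the batch layout is
# an assumption for illustration, and pinning requires a CUDA-enabled build):
#
#   batch = {"image": torch.zeros(2, 3, 8, 8), "label": torch.tensor([0, 1])}
#   pinned = pin_memory(batch)
#   assert pinned["image"].is_pinned() and pinned["label"].is_pinned()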
# numpy dtype kind codes for (byte-)string and object arrays, which cannot be
# converted to tensors and are therefore rejected by default_collate
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
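# A small illustration of default_collate (the sample dicts are assumptions
# for the example, not data produced by this module):
#
#   samples = [{"x": torch.ones(3), "y": 0}, {"x": torch.zeros(3), "y": 1}]
#   batch = default_collate(samples)
#   # batch["x"] has shape (2, 3); batch["y"] is tensor([0, 1])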
class CircularCachedInputIterator(object):
"""
chunk: a serialized List[Dict] in the apache arrow format,
could be sequentially loaded into memory with minimum deserialization cost(<1ms)
shard: a part of dataset which is allocated to a specific rank(process) in the world,
generally contains multiple chunks
main thread:
- populate chunk_index_queue
- swap new chunk and old chunk
prefetch threads:
- fetch chunk_index_queue
- populate loaded_chunk_queue
chunk_index_queue: main -> prefetch, used for shuffling chunk order per epoch
loaded_chunk_queue: preftch -> main, a limited-size channel for prefetching worker to send back result
"""
def __init__(self,
input_map,
batch_size,
chunk_path_list,
num_data_point,
num_shards,
shard_id,
random_shuffle,
num_prefetch_chunk=4,
num_worker=4):
self.input_map = input_map
self.batch_size = batch_size
        self.num_shards = num_shards
self.shard_id = shard_id
self.random_shuffle = random_shuffle
self.num_data_point = num_data_point
self.chunk_filename_list = chunk_path_list
self.chunk = None
self.next_chunk_queue = queue.Queue(num_prefetch_chunk)
self.index_queue = queue.Queue()
self.chunk_index_queue = queue.Queue()
        self.num_chunk_in_shard = None
        self.chunk_indexes = None
        self.indexes = None
        # guards the check-then-refill of chunk_index_queue across prefetch threads
        self.chunk_index_lock = Lock()
self.worker = None
self.num_worker = num_worker
self.setup_shard()
self.warmup_cache()
    def setup_shard(self):
        # ensure each shard gets the same number of chunks per epoch;
        # this might not be necessary
        self.num_chunk_in_shard = len(self.chunk_filename_list) // self.num_shards
        # [start, end)
        shard_start = self.num_chunk_in_shard * self.shard_id
        shard_end = len(self.chunk_filename_list) if self.shard_id == self.num_shards - 1 else self.num_chunk_in_shard * (self.shard_id + 1)
        self.chunk_indexes = list(range(shard_start, shard_end))
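    # Worked example of the shard arithmetic above (numbers are hypothetical):
    # with 10 chunks and num_shards=3, num_chunk_in_shard = 3, so shard 0 gets
    # chunks [0, 3), shard 1 gets [3, 6), and shard 2 gets [6, 10); the slice
    # in get_chunk_index then truncates shard 2 back to 3 chunks per epoch so
    # all shards stay the same size.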
def _chunk_prefetch_worker(self):
while True:
chunk_index = self.get_chunk_index()
chunk_filename = self.chunk_filename_list[chunk_index]
            with open(chunk_filename, "rb") as fin:
                # pyarrow's built-in serialization (deprecated in recent
                # pyarrow releases); the second argument is the base object
                # used for zero-copy reads
                chunk = pa.deserialize_from(fin, None)
self.next_chunk_queue.put(chunk)
    def warmup_cache(self):
        # start daemon threads that keep next_chunk_queue topped up
        self.worker = [Thread(target=self._chunk_prefetch_worker) for _ in range(self.num_worker)]
for worker in self.worker:
worker.daemon = True
worker.start()
    def get_chunk_index(self):
        # multiple prefetch threads call this concurrently, so the
        # check-then-refill must happen under a lock to avoid double refills
        with self.chunk_index_lock:
            if self.chunk_index_queue.empty():
                if self.random_shuffle:
                    random.shuffle(self.chunk_indexes)
                # truncate so every shard sees the same number of chunks
                for ind in self.chunk_indexes[:self.num_chunk_in_shard]:
                    self.chunk_index_queue.put(ind)
            return self.chunk_index_queue.get()
    def get_index(self):
        # only called from the consumer thread, so no locking is needed here
        if self.index_queue.empty():
            if self.chunk is not None:
                del self.chunk  # drop the reference so the old chunk can be freed
            self.chunk = self.next_chunk_queue.get()
            self.indexes = list(range(len(self.chunk)))
            if self.random_shuffle:
                random.shuffle(self.indexes)
            for ind in self.indexes:
                self.index_queue.put(ind)
        return self.index_queue.get()
def epoch_size(self):
        return self.num_data_point // self.num_shards
def __iter__(self):
return self
    def __next__(self):
        datas = tuple([] for _ in self.input_map)
        for _ in range(self.batch_size):
            ind = self.get_index()
            data = self.chunk[ind]
            for i, k in enumerate(self.input_map):
                # deepcopy: do NOT reuse the chunk's underlying buffers, since
                # the chunk may be released while the batch is still in use
                datas[i].append(deepcopy(data[k]))
        return datas
    next = __next__  # Python 2 iterator-protocol compatibility
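
if __name__ == "__main__":
    # A hedged usage sketch: the chunk paths, record keys, and counts below
    # are assumptions for illustration; real chunks must be pyarrow-serialized
    # List[Dict] files produced by a separate preprocessing step.
    chunk_paths = ["/data/chunks/chunk_%03d.arrow" % i for i in range(8)]  # hypothetical
    it = CircularCachedInputIterator(
        input_map=["jpeg", "label"],  # keys assumed to exist in every record
        batch_size=32,
        chunk_path_list=chunk_paths,
        num_data_point=100000,        # total records across all chunks (assumed)
        num_shards=1,
        shard_id=0,
        random_shuffle=True,
    )
    jpegs, labels = next(it)  # one list per input_map key, each of length batch_size
    print(len(jpegs), len(labels))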