|
import queue |
|
import random |
|
from threading import Thread |
|
import time |
|
|
|
import pyarrow as pa |
|
import torch.multiprocessing as multiprocessing |
|
|
|
import torch |
|
from copy import deepcopy |
|
|
|
# String-like leaf types that pin_memory/default_collate pass through unchanged.
string_classes = (str, bytes)
|
import collections.abc as container_abcs |
|
import re |
|
|
|
def pin_memory(data):
    """Recursively pin the memory of every tensor contained in ``data``.

    Tensors are replaced by pinned copies; mappings, namedtuples and
    sequences are rebuilt with pinned contents; strings and objects without
    a ``pin_memory`` method are returned untouched.
    """
    if isinstance(data, torch.Tensor):
        return data.pin_memory()
    if isinstance(data, string_classes):
        # strings are sequences too -- keep them intact
        return data
    if isinstance(data, container_abcs.Mapping):
        return {key: pin_memory(value) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, '_fields'):
        # namedtuple: rebuild an instance of the same type
        return type(data)(*map(pin_memory, data))
    if isinstance(data, container_abcs.Sequence):
        return [pin_memory(item) for item in data]
    if hasattr(data, "pin_memory"):
        # custom object exposing its own pin_memory hook
        return data.pin_memory()
    return data
|
|
|
|
|
# numpy dtype kind codes that torch.as_tensor cannot convert:
# zero-terminated bytes (S/a), unicode (U) and object (O) arrays.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")
|
|
|
|
|
def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    first = batch[0]
    first_type = type(first)
    if isinstance(first, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # Inside a DataLoader worker: stack directly into a shared-memory
            # tensor so the batch reaches the main process without a copy.
            total = sum(t.numel() for t in batch)
            shared = first.storage()._new_shared(total)
            out = first.new(shared)
        return torch.stack(batch, 0, out=out)
    elif first_type.__module__ == 'numpy' and first_type.__name__ not in ('str_', 'string_'):
        if first_type.__name__ == 'ndarray':
            # string/object dtypes have no tensor equivalent
            if np_str_obj_array_pattern.search(first.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(first.dtype))
            return default_collate([torch.as_tensor(arr) for arr in batch])
        elif first.shape == ():
            # numpy scalar
            return torch.as_tensor(batch)
        # any other numpy subtype falls through to the TypeError below
    elif isinstance(first, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(first, int):
        return torch.tensor(batch)
    elif isinstance(first, string_classes):
        return batch
    elif isinstance(first, container_abcs.Mapping):
        return {key: default_collate([sample[key] for sample in batch]) for key in first}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        return first_type(*(default_collate(column) for column in zip(*batch)))
    elif isinstance(first, container_abcs.Sequence):
        return [default_collate(column) for column in zip(*batch)]

    raise TypeError(default_collate_err_msg_format.format(first_type))
|
|
|
|
|
class CircularCachedInputIterator(object):
    """
    Endless, sharded, chunk-prefetching batch iterator.

    chunk: a serialized List[Dict] in the apache arrow format,
    could be sequentially loaded into memory with minimum deserialization cost(<1ms)
    shard: a part of dataset which is allocated to a specific rank(process) in the world,
    generally contains multiple chunks

    main thread:
    - populate chunk_index_queue
    - swap new chunk and old chunk
    prefetch threads:
    - fetch chunk_index_queue
    - populate loaded_chunk_queue

    chunk_index_queue: main -> prefetch, used for shuffling chunk order per epoch
    loaded_chunk_queue: prefetch -> main, a limited-size channel for prefetching worker to send back result
    """
    def __init__(self,
                 input_map,
                 batch_size,
                 chunk_path_list,
                 num_data_point,
                 num_shards,
                 shard_id,
                 random_shuffle,
                 num_prefetch_chunk=4,
                 num_worker=4):
        """
        Args:
            input_map: ordered field names; __next__ yields one list per field.
            batch_size: number of samples per batch.
            chunk_path_list: paths of every chunk file of the whole dataset.
            num_data_point: total number of samples across all shards.
            num_shards: number of ranks the dataset is split across.
            shard_id: 0-based id of the shard served by this iterator.
            random_shuffle: shuffle chunk order and intra-chunk sample order.
            num_prefetch_chunk: max deserialized chunks buffered in memory.
            num_worker: number of background prefetch threads.
        """
        import threading  # local import: file level only imports Thread

        self.input_map = input_map
        self.batch_size = batch_size
        # NOTE(review): attribute keeps the historical misspelling ("shareds")
        # for backward compatibility with any external readers.
        self.num_shareds = num_shards
        self.shard_id = shard_id
        self.random_shuffle = random_shuffle
        self.num_data_point = num_data_point
        self.chunk_filename_list = chunk_path_list
        self.chunk = None  # chunk currently consumed by __next__ (main thread)
        # prefetch -> consumer channel; bounded so at most num_prefetch_chunk
        # deserialized chunks live in memory at once
        self.next_chunk_queue = queue.Queue(num_prefetch_chunk)
        self.index_queue = queue.Queue()        # in-chunk sample order (main thread only)
        self.chunk_index_queue = queue.Queue()  # chunk order, consumed by workers
        # Guards the check-then-refill of chunk_index_queue: it runs on every
        # prefetch thread concurrently, and without a lock two threads could
        # both observe "empty" and refill twice, or shuffle chunk_indexes
        # while another thread is slicing it.
        self._refill_lock = threading.Lock()
        self.num_chunk_in_shard = None
        self.chunk_indexes = None
        self.worker = None
        self.num_worker = num_worker
        self.setup_shard()
        self.warmup_cache()

    def setup_shard(self):
        """Compute the [shard_start, shard_end) chunk range owned by this rank."""
        self.num_chunk_in_shard = len(self.chunk_filename_list) // self.num_shareds

        shard_start = self.num_chunk_in_shard * self.shard_id
        # the last shard absorbs the remainder chunks
        shard_end = len(self.chunk_filename_list) if self.shard_id == self.num_shareds - 1 else self.num_chunk_in_shard * (self.shard_id + 1)
        self.chunk_indexes = list(range(shard_start, shard_end))

    def _chunk_prefetch_worker(self):
        """Daemon loop: deserialize this shard's chunks into next_chunk_queue."""
        while True:
            chunk_index = self.get_chunk_index()
            chunk_filename = self.chunk_filename_list[chunk_index]
            with open(chunk_filename, "rb") as fin:
                # NOTE(review): pa.deserialize_from was deprecated in
                # pyarrow 2.0 and later removed -- pin pyarrow < 2.0 or
                # migrate to the Arrow IPC API.
                chunk = pa.deserialize_from(fin, None)
            # blocks when the bounded buffer is full (backpressure)
            self.next_chunk_queue.put(chunk)

    def warmup_cache(self):
        """Start the daemon prefetch threads that keep next_chunk_queue full."""
        self.worker = [Thread(target=self._chunk_prefetch_worker, args=[]) for _ in range(self.num_worker)]
        for worker in self.worker:
            worker.daemon = True  # do not block interpreter shutdown
            worker.start()

    def get_chunk_index(self):
        """Return the next chunk index for this shard, refilling (and
        optionally reshuffling) the per-epoch order when it runs out.

        Thread-safe: called concurrently from every prefetch thread, hence
        the lock around the empty-check/refill/get sequence.
        """
        with self._refill_lock:
            if self.chunk_index_queue.empty():
                if self.random_shuffle:
                    random.shuffle(self.chunk_indexes)
                for ind in self.chunk_indexes[:self.num_chunk_in_shard]:
                    self.chunk_index_queue.put(ind)
            return self.chunk_index_queue.get()

    def get_index(self):
        """Return the next in-chunk sample index, swapping in a freshly
        prefetched chunk (and a new index order) when the current one is
        exhausted. Main thread only -- no locking needed here."""
        if self.index_queue.empty():
            if self.chunk is not None:
                # drop the consumed chunk before blocking on the next one
                del self.chunk
            self.chunk = self.next_chunk_queue.get()
            self.indexes = list(range(len(self.chunk)))
            if self.random_shuffle:
                random.shuffle(self.indexes)

            for ind in self.indexes:
                self.index_queue.put(ind)
        return self.index_queue.get()

    def epoch_size(self):
        """Number of data points served per shard (floor division)."""
        return self.num_data_point // self.num_shareds

    def __iter__(self):
        return self

    def __next__(self):
        """Build one batch: a tuple of lists, one list per input_map field."""
        datas = tuple([] for _ in self.input_map)
        for _ in range(self.batch_size):
            ind = self.get_index()
            data = self.chunk[ind]
            for i, k in enumerate(self.input_map):
                # deepcopy presumably guards the cached chunk against
                # downstream mutation -- confirm before removing
                datas[i].append(deepcopy(data[k]))
        return datas

    next = __next__
|
|
|
|