# unit_test/uniperceiver/datasets/circular_cached_loader.py
import collections.abc as container_abcs
import queue
import random
import re
import time
from copy import deepcopy
from threading import Thread

import pyarrow as pa
import torch
import torch.multiprocessing as multiprocessing

# string types that pin_memory / default_collate pass through unchanged
string_classes = (str, bytes)


def pin_memory(data):
    """Recursively pin (page-lock) every tensor found in a nested container."""
if isinstance(data, torch.Tensor):
return data.pin_memory()
elif isinstance(data, string_classes):
return data
elif isinstance(data, container_abcs.Mapping):
return {k: pin_memory(sample) for k, sample in data.items()}
elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
return type(data)(*(pin_memory(sample) for sample in data))
elif isinstance(data, container_abcs.Sequence):
return [pin_memory(sample) for sample in data]
elif hasattr(data, "pin_memory"):
return data.pin_memory()
else:
return data
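

# Minimal usage sketch for pin_memory (illustrative only): it walks a nested
# container and page-locks every tensor it finds, leaving non-tensor leaves
# untouched. Assumes a CUDA build of PyTorch, since Tensor.pin_memory()
# allocates pinned host memory through the CUDA allocator; the field names
# below are made up for the example.
def _example_pin_memory():
    sample = {"image": torch.zeros(3, 8, 8), "meta": [torch.tensor([1, 2]), "id_0"]}
    pinned = pin_memory(sample)          # same structure, tensors page-locked
    return pinned["image"].is_pinned()   # True on a CUDA-enabled machine
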
np_str_obj_array_pattern = re.compile(r'[SaUO]')
default_collate_err_msg_format = (
"default_collate: batch must contain tensors, numpy arrays, numbers, "
"dicts or lists; found {}")
def default_collate(batch):
r"""Puts each data field into a tensor with outer dimension batch size"""
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, torch.Tensor):
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum([x.numel() for x in batch])
storage = elem.storage()._new_shared(numel)
out = elem.new(storage)
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
and elem_type.__name__ != 'string_':
elem = batch[0]
if elem_type.__name__ == 'ndarray':
# array of string classes and object
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
return default_collate([torch.as_tensor(b) for b in batch])
elif elem.shape == (): # scalars
return torch.as_tensor(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float64)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, string_classes):
return batch
elif isinstance(elem, container_abcs.Mapping):
return {key: default_collate([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
return elem_type(*(default_collate(samples) for samples in zip(*batch)))
elif isinstance(elem, container_abcs.Sequence):
transposed = zip(*batch)
return [default_collate(samples) for samples in transposed]
raise TypeError(default_collate_err_msg_format.format(elem_type))
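

# Illustrative sketch of how default_collate above behaves on a batch of dict
# samples, mirroring what torch.utils.data.DataLoader does with its default
# collate function. The "image"/"label" field names are hypothetical.
def _example_default_collate():
    batch = [
        {"image": torch.zeros(3, 8, 8), "label": 0},
        {"image": torch.ones(3, 8, 8), "label": 1},
    ]
    collated = default_collate(batch)
    # collated["image"] has shape (2, 3, 8, 8); collated["label"] == tensor([0, 1])
    return collated

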
class CircularCachedInputIterator(object):
"""
chunk: a serialized List[Dict] in the apache arrow format,
could be sequentially loaded into memory with minimum deserialization cost(<1ms)
shard: a part of dataset which is allocated to a specific rank(process) in the world,
generally contains multiple chunks
main thread:
- populate chunk_index_queue
- swap new chunk and old chunk
prefetch threads:
- fetch chunk_index_queue
- populate loaded_chunk_queue
chunk_index_queue: main -> prefetch, used for shuffling chunk order per epoch
loaded_chunk_queue: preftch -> main, a limited-size channel for prefetching worker to send back result
"""
def __init__(self,
input_map,
batch_size,
chunk_path_list,
num_data_point,
num_shards,
shard_id,
random_shuffle,
num_prefetch_chunk=4,
num_worker=4):
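        """
        Args:
            input_map: names of the fields to read from each record when assembling a batch
            batch_size: number of samples returned per batch
            chunk_path_list: paths of the serialized chunk files making up the dataset
            num_data_point: total number of samples across all chunks (used by epoch_size)
            num_shards: number of shards (ranks) the chunk list is divided over
            shard_id: index of the shard served by this iterator
            random_shuffle: shuffle chunk order and in-chunk sample order every epoch
            num_prefetch_chunk: capacity of the queue of chunks loaded ahead of time
            num_worker: number of background threads prefetching chunks
        """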
self.input_map = input_map
self.batch_size = batch_size
        self.num_shards = num_shards
self.shard_id = shard_id
self.random_shuffle = random_shuffle
self.num_data_point = num_data_point
self.chunk_filename_list = chunk_path_list
self.chunk = None
self.next_chunk_queue = queue.Queue(num_prefetch_chunk)
self.index_queue = queue.Queue()
self.chunk_index_queue = queue.Queue()
self.num_chunk_in_shard = None
self.chunk_indexes = None
self.worker = None
self.num_worker = num_worker
self.setup_shard()
self.warmup_cache()
    def setup_shard(self):
        # ensure each shard gets the same number of chunks per epoch
        # (this might not be strictly necessary)
        self.num_chunk_in_shard = len(self.chunk_filename_list) // self.num_shards
        # this shard covers chunk indexes [shard_start, shard_end)
        shard_start = self.num_chunk_in_shard * self.shard_id
        shard_end = len(self.chunk_filename_list) if self.shard_id == self.num_shards - 1 else self.num_chunk_in_shard * (self.shard_id + 1)
        self.chunk_indexes = list(range(shard_start, shard_end))
    def _chunk_prefetch_worker(self):
        while True:
            chunk_index = self.get_chunk_index()
            chunk_filename = self.chunk_filename_list[chunk_index]
            with open(chunk_filename, "rb") as fin:
                # pa.deserialize_from was deprecated in pyarrow 2.0 and removed in
                # later releases, so this code path requires an older pyarrow
                chunk = pa.deserialize_from(fin, None)
            self.next_chunk_queue.put(chunk)
    def warmup_cache(self):
        # start daemon threads that keep next_chunk_queue filled with prefetched chunks
        self.worker = [Thread(target=self._chunk_prefetch_worker, daemon=True) for _ in range(self.num_worker)]
        for worker in self.worker:
            worker.start()
    def get_chunk_index(self):
        if self.chunk_index_queue.empty():
            if self.random_shuffle:
                random.shuffle(self.chunk_indexes)
            # keep all shards of the same size: take at most num_chunk_in_shard chunks per pass
            for ind in self.chunk_indexes[:self.num_chunk_in_shard]:
                self.chunk_index_queue.put(ind)
        return self.chunk_index_queue.get()
    def get_index(self):
        if self.index_queue.empty():
            # the current chunk is exhausted: swap in the next prefetched chunk and
            # refill the sample-index queue (shuffled per chunk if requested)
            if self.chunk is not None:
                del self.chunk  # release memory
            self.chunk = self.next_chunk_queue.get()
            self.indexes = list(range(len(self.chunk)))
            if self.random_shuffle:
                random.shuffle(self.indexes)
            for ind in self.indexes:
                self.index_queue.put(ind)
        return self.index_queue.get()
def epoch_size(self):
        return self.num_data_point // self.num_shards
def __iter__(self):
return self
    def __next__(self):
        datas = tuple([] for _ in self.input_map)
        for _ in range(self.batch_size):
            ind = self.get_index()
            data = self.chunk[ind]
            for i, k in enumerate(self.input_map):
                # DO NOT reuse the chunk's buffer: deepcopy each field so the batch
                # holds no reference to memory released when the chunk is swapped out
                datas[i].append(deepcopy(data[k]))
        return datas

    next = __next__  # Python 2-style alias
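

# End-to-end sketch (illustrative, not part of the loader itself): write a couple
# of tiny chunk files in the List[Dict] layout the iterator expects, then draw one
# batch from a single-shard iterator. The "./chunks" directory and the "image" /
# "label" field names are assumptions for the example, and serializing this way
# needs an older pyarrow release where pa.serialize_to / pa.deserialize_from are
# still available (deprecated in pyarrow 2.0 and removed later).
def _example_build_and_iterate(chunk_dir="./chunks"):
    import os
    import numpy as np

    os.makedirs(chunk_dir, exist_ok=True)
    chunk_paths = []
    for chunk_id in range(2):
        # each chunk file holds a serialized List[Dict]
        records = [{"image": np.zeros((4, 4), dtype=np.float32), "label": i} for i in range(8)]
        path = os.path.join(chunk_dir, "chunk_%d.arrow" % chunk_id)
        with open(path, "wb") as fout:
            pa.serialize_to(records, fout)
        chunk_paths.append(path)

    it = CircularCachedInputIterator(
        input_map=["image", "label"],  # dict fields to extract per sample
        batch_size=4,
        chunk_path_list=chunk_paths,
        num_data_point=16,             # 2 chunks x 8 records
        num_shards=1,                  # single-process world
        shard_id=0,
        random_shuffle=True,
    )
    images, labels = next(it)          # per-field lists, each of length batch_size
    return images, labels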