r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to | |
collate samples fetched from dataset into Tensor(s). | |
These **needs** to be in global scope since Py2 doesn't support serializing | |
static methods. | |
""" | |
import torch
import re
import collections

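# Matches dtype strings of numpy arrays whose elements torch.as_tensor cannot
# convert: bytes ('S', legacy 'a'), unicode strings ('U'), and objects ('O').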
np_str_obj_array_pattern = re.compile(r'[SaUO]')


def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [default_convert(d) for d in data]
    else:
        return data
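# A minimal illustration of `default_convert` recursing through containers
# (doctest-style sketch; the exact tensor reprs depend on your torch/numpy
# versions and default dtypes):
#
#   >>> import numpy as np
#   >>> default_convert({'x': np.array([1, 2]), 'y': [np.float64(0.5)]})
#   {'x': tensor([1, 2]), 'y': [tensor(0.5000, dtype=torch.float64)]}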


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
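

# A minimal illustration of `default_collate` on a batch of dict samples: each
# key is collated independently, tensors are stacked along a new batch
# dimension, and ints become a single int tensor (doctest-style sketch; exact
# reprs may vary by version):
#
#   >>> batch = [{'img': torch.zeros(3, 2), 'label': 1},
#   ...          {'img': torch.ones(3, 2), 'label': 0}]
#   >>> out = default_collate(batch)
#   >>> out['img'].shape, out['label']
#   (torch.Size([2, 3, 2]), tensor([1, 0]))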