Spaces:
Paused
Paused
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).

These **need** to be in global scope since Py2 doesn't support serializing
static methods.

`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""
import torch | |
import re | |
import collections | |
from torch._six import string_classes | |
# Searched against a NumPy array's `dtype.str`: a hit on 'S', 'a', 'U' or 'O'
# marks an array of string classes or objects. `default_convert` returns such
# arrays unchanged and `default_collate` rejects them with a TypeError.
np_str_obj_array_pattern = re.compile(r'[SaUO]')
def default_convert(data):
    r"""
    Function that converts each NumPy array element into a :class:`torch.Tensor`. If the input is a `Sequence`,
    `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
    If the input is not a NumPy array, it is left unchanged.
    This is used as the default function for collation when both `batch_sampler` and
    `batch_size` are NOT defined in :class:`~torch.utils.data.DataLoader`.

    The general input type to output type mapping is similar to that
    of :func:`~torch.utils.data.default_collate`. See the description there for more details.

    Args:
        data: a single data point to be converted

    Returns:
        * the same object for a :class:`torch.Tensor`, for a NumPy ``ndarray`` of
          string/object dtype, and for any type not listed below;
        * a :class:`torch.Tensor` for any other NumPy value;
        * for a `Mapping`, a mapping of the same type with converted values (a plain
          ``dict`` if the type cannot be built from a ``dict``);
        * for a namedtuple, a namedtuple of the same type with converted fields;
        * for a plain ``tuple``, a ``list`` of converted elements (backwards compatibility);
        * for any other non-string `Sequence`, a sequence of the same type with converted
          elements (a plain ``list`` if the type cannot be built from a ``list``).

    Examples:
        >>> # Example with `int`
        >>> default_convert(0)
        0
        >>> # Example with NumPy array
        >>> # xdoctest: +SKIP
        >>> default_convert(np.array([0, 1]))
        tensor([0, 1])
        >>> # Example with NamedTuple
        >>> from collections import namedtuple
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_convert(Point(0, 0))
        Point(x=0, y=0)
        >>> default_convert(Point(np.array(0), np.array(0)))
        Point(x=tensor(0), y=tensor(0))
        >>> # Example with List
        >>> default_convert([np.array([0, 1]), np.array([2, 3])])
        [tensor([0, 1]), tensor([2, 3])]
    """
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        # Convert the values exactly once, outside the `try`: only the container
        # constructor is allowed to fail here. Guarding the recursion as well
        # would redo the whole conversion whenever a TypeError surfaced.
        converted = {key: default_convert(data[key]) for key in data}
        try:
            return elem_type(converted)
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return converted
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, tuple):
        return [default_convert(d) for d in data]  # Backwards compatibility.
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
        # `str`/`bytes` are sequences too, but must be returned as-is (handled below).
        converted = [default_convert(d) for d in data]
        try:
            return elem_type(converted)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return converted
    else:
        return data
# Template for the TypeError message raised by `default_collate`; the `{}`
# placeholder is filled with the offending element type or NumPy dtype.
default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found {}"
)
def default_collate(batch):
    r"""
    Function that takes in a batch of data and puts the elements within the batch
    into a tensor with an additional outer dimension - batch size. The exact output type can be
    a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
    Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
    This is used as the default function for collation when
    `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.

    Here is the general input type (based on the type of the element within the batch) to output type mapping:

        * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
        * NumPy Arrays -> :class:`torch.Tensor`
        * `float` -> :class:`torch.Tensor`
        * `int` -> :class:`torch.Tensor`
        * `str` -> `str` (unchanged)
        * `bytes` -> `bytes` (unchanged)
        * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
        * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`
        * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
          default_collate([V2_1, V2_2, ...]), ...]`

    The dispatch is decided by the type of the first element, ``batch[0]``.

    Args:
        batch: a single batch to be collated

    Returns:
        The collated batch, following the type mapping above. Plain tuples are
        collated into a ``list`` (backwards compatibility); a `Mapping` or `Sequence`
        type that cannot be rebuilt from a ``dict`` / ``list`` falls back to a plain
        ``dict`` / ``list``.

    Raises:
        TypeError: if the element type is not supported, or a NumPy array has a
            string or object dtype.
        RuntimeError: if the elements are sequences of unequal length.

    Examples:
        >>> # Example with a batch of `int`s:
        >>> default_collate([0, 1, 2, 3])
        tensor([0, 1, 2, 3])
        >>> # Example with a batch of `str`s:
        >>> default_collate(['a', 'b', 'c'])
        ['a', 'b', 'c']
        >>> # Example with `Map` inside the batch:
        >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
        {'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
        >>> # Example with `NamedTuple` inside the batch:
        >>> # xdoctest: +SKIP
        >>> from collections import namedtuple
        >>> Point = namedtuple('Point', ['x', 'y'])
        >>> default_collate([Point(0, 0), Point(1, 1)])
        Point(x=tensor([0, 1]), y=tensor([0, 1]))
        >>> # Example with `Tuple` inside the batch:
        >>> default_collate([(0, 1), (2, 3)])
        [tensor([0, 2]), tensor([1, 3])]
        >>> # Example with `List` inside the batch:
        >>> default_collate([[0, 1], [2, 3]])
        [tensor([0, 2]), tensor([1, 3])]
    """
    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel, device=elem.device)
            out = elem.new(storage).resize_(len(batch), *list(elem.size()))
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
        if elem_type.__name__ in ('ndarray', 'memmap'):
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))
            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
        # Any other NumPy type falls through to the TypeError at the end.
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        # Collate exactly once, outside the `try`: only the container constructor
        # is allowed to fail here. Guarding the recursion as well would redo the
        # whole collation whenever a TypeError surfaced from nested data.
        collated = {key: default_collate([d[key] for d in batch]) for key in elem}
        try:
            return elem_type(collated)
        except TypeError:
            # The mapping type may not support `__init__(iterable)`.
            return collated
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(sample) == elem_size for sample in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        # Transpose so that the i-th entries of every sample are collated together.
        collated = [default_collate(samples) for samples in zip(*batch)]
        if isinstance(elem, tuple):
            return collated  # Backwards compatibility.
        try:
            return elem_type(collated)
        except TypeError:
            # The sequence type may not support `__init__(iterable)` (e.g., `range`).
            return collated

    raise TypeError(default_collate_err_msg_format.format(elem_type))