|
""" |
|
Author: Paul-Edouard Sarlin (skydes) |
|
""" |
|
|
|
import collections.abc as collections |
|
import functools |
|
import inspect |
|
from typing import Callable, List, Tuple |
|
|
|
import numpy as np |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
# Replacement for the former torch._six.string_classes, removed from PyTorch.
string_classes = (str, bytes)
|
|
|
|
|
def autocast(func: Callable) -> Callable: |
|
"""Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays. |
|
|
|
Use the device and dtype of the wrapper. |
|
|
|
Args: |
|
func (Callable): Method of a TensorWrapper class. |
|
|
|
Returns: |
|
Callable: Wrapped method. |
|
""" |
|
|
|
@functools.wraps(func) |
|
def wrap(self, *args): |
|
device = torch.device("cpu") |
|
dtype = None |
|
if isinstance(self, TensorWrapper): |
|
if self._data is not None: |
|
device = self.device |
|
dtype = self.dtype |
|
elif not inspect.isclass(self) or not issubclass(self, TensorWrapper): |
|
            raise ValueError(f"Called autocast on an invalid object: {self!r}")
|
|
|
cast_args = [] |
|
for arg in args: |
|
if isinstance(arg, np.ndarray): |
|
arg = torch.from_numpy(arg) |
|
arg = arg.to(device=device, dtype=dtype) |
|
cast_args.append(arg) |
|
return func(self, *cast_args) |
|
|
|
return wrap |
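
# Usage sketch (hypothetical `Points` subclass, for illustration only): a
# method decorated with @autocast receives numpy arguments as tensors on the
# wrapper's device and dtype:
#
#     class Points(TensorWrapper):
#         @autocast
#         def shift(self, delta: torch.Tensor) -> "Points":
#             return Points(self._data + delta)
#
#     pts = Points(torch.zeros(4, 2))
#     shifted = pts.shift(np.ones(2))  # the numpy array is cast to a float tensor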
|
|
|
|
|
class TensorWrapper: |
|
"""Wrapper for PyTorch tensors.""" |
|
|
|
_data = None |
|
|
|
@autocast |
|
def __init__(self, data: torch.Tensor): |
|
"""Wrapper for PyTorch tensors.""" |
|
self._data = data |
|
|
|
@property |
|
def shape(self) -> torch.Size: |
|
"""Shape of the underlying tensor.""" |
|
return self._data.shape[:-1] |
|
|
|
@property |
|
def device(self) -> torch.device: |
|
"""Get the device of the underlying tensor.""" |
|
return self._data.device |
|
|
|
@property |
|
def dtype(self) -> torch.dtype: |
|
"""Get the dtype of the underlying tensor.""" |
|
return self._data.dtype |
|
|
|
def __getitem__(self, index) -> torch.Tensor: |
|
"""Get the underlying tensor.""" |
|
return self.__class__(self._data[index]) |
|
|
|
def __setitem__(self, index, item): |
|
"""Set the underlying tensor.""" |
|
        self._data[index] = item._data if isinstance(item, TensorWrapper) else item
|
|
|
def to(self, *args, **kwargs): |
|
"""Move the underlying tensor to a new device.""" |
|
return self.__class__(self._data.to(*args, **kwargs)) |
|
|
|
def cpu(self): |
|
"""Move the underlying tensor to the CPU.""" |
|
return self.__class__(self._data.cpu()) |
|
|
|
def cuda(self): |
|
"""Move the underlying tensor to the GPU.""" |
|
return self.__class__(self._data.cuda()) |
|
|
|
def pin_memory(self): |
|
"""Pin the underlying tensor to memory.""" |
|
return self.__class__(self._data.pin_memory()) |
|
|
|
def float(self): |
|
"""Cast the underlying tensor to float.""" |
|
return self.__class__(self._data.float()) |
|
|
|
def double(self): |
|
"""Cast the underlying tensor to double.""" |
|
return self.__class__(self._data.double()) |
|
|
|
def detach(self): |
|
"""Detach the underlying tensor.""" |
|
return self.__class__(self._data.detach()) |
|
|
|
def numpy(self): |
|
"""Convert the underlying tensor to a numpy array.""" |
|
return self._data.detach().cpu().numpy() |
|
|
|
def new_tensor(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self._data.new_tensor(*args, **kwargs) |
|
|
|
def new_zeros(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self._data.new_zeros(*args, **kwargs) |
|
|
|
def new_ones(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self._data.new_ones(*args, **kwargs) |
|
|
|
def new_full(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self._data.new_full(*args, **kwargs) |
|
|
|
def new_empty(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self._data.new_empty(*args, **kwargs) |
|
|
|
def unsqueeze(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self.__class__(self._data.unsqueeze(*args, **kwargs)) |
|
|
|
def squeeze(self, *args, **kwargs): |
|
"""Create a new tensor of the same type and device.""" |
|
return self.__class__(self._data.squeeze(*args, **kwargs)) |
|
|
|
@classmethod |
|
    def stack(cls, objects: List["TensorWrapper"], dim: int = 0, *, out=None):
|
"""Stack a list of objects with the same type and shape.""" |
|
data = torch.stack([obj._data for obj in objects], dim=dim, out=out) |
|
return cls(data) |
|
|
|
@classmethod |
|
def __torch_function__(cls, func, types, args=(), kwargs=None): |
|
"""Support torch functions.""" |
|
if kwargs is None: |
|
kwargs = {} |
|
return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented |
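
# Note: via __torch_function__, torch.stack on a list of wrappers returns a
# wrapped result. A minimal sketch:
#
#     a = TensorWrapper(torch.zeros(2, 3))
#     b = TensorWrapper(torch.ones(2, 3))
#     c = torch.stack([a, b], dim=0)  # TensorWrapper around a (2, 2, 3) tensor
#     assert isinstance(c, TensorWrapper)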
|
|
|
|
|
def map_tensor(input_, func):
    """Recursively apply func to all tensors in a nested dict/list structure."""
|
if isinstance(input_, string_classes): |
|
return input_ |
|
elif isinstance(input_, collections.Mapping): |
|
return {k: map_tensor(sample, func) for k, sample in input_.items()} |
|
elif isinstance(input_, collections.Sequence): |
|
return [map_tensor(sample, func) for sample in input_] |
|
elif input_ is None: |
|
return None |
|
else: |
|
return func(input_) |
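
# Example: map_tensor traverses nested mappings and sequences, applying func to
# the leaves while passing strings and None through unchanged:
#
#     data = {"img": torch.zeros(3), "names": ["a", "b"], "meta": None}
#     data = map_tensor(data, lambda t: t.to("cpu"))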
|
|
|
|
|
def batch_to_numpy(batch):
    """Recursively convert all tensors in a batch to numpy arrays."""
|
    return map_tensor(batch, lambda tensor: tensor.detach().cpu().numpy())
|
|
|
|
|
def batch_to_device(batch, device, non_blocking=True, detach=False):
    """Recursively move all tensors in a batch to the given device."""
|
def _func(tensor): |
|
        t = tensor.to(device=device, non_blocking=non_blocking)
|
return t.detach() if detach else t |
|
|
|
return map_tensor(batch, _func) |
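
# Example: move a nested batch onto an accelerator inside a training loop,
# keeping the computation graph intact (detach=False):
#
#     batch = {"image": torch.rand(1, 3, 64, 64), "idx": torch.tensor([0])}
#     batch = batch_to_device(batch, "cuda", non_blocking=True)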
|
|
|
|
|
def remove_batch_dim(data: dict) -> dict: |
|
"""Remove batch dimension from elements in data""" |
|
return { |
|
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items() |
|
} |
|
|
|
|
|
def add_batch_dim(data: dict) -> dict: |
|
"""Add batch dimension to elements in data""" |
|
return { |
|
        k: [v] if isinstance(v, list) else v[None] if isinstance(v, (torch.Tensor, np.ndarray)) else v
|
for k, v in data.items() |
|
} |
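
# Example: the two helpers are inverses on tensor and list values:
#
#     sample = {"kpts": torch.rand(100, 2), "name": "img0"}
#     batched = add_batch_dim(sample)       # kpts: (1, 100, 2), name unchanged
#     restored = remove_batch_dim(batched)  # kpts: (100, 2)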
|
|
|
|
|
def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False): |
|
"""Get padding to make the image size a multiple of the given number. |
|
|
|
Args: |
|
x (torch.Tensor): Input tensor. |
|
        multiple (int): Multiple to fit the image size to.

        mode (str, optional): Where to add the padding, "center" or "left". Defaults to "center".

        crop (bool, optional): Whether to crop instead of pad. Defaults to False.
|
|
|
Returns: |
|
        Tuple[int, int, int, int]: Padding (left, right, top, bottom).
|
""" |
|
h, w = x.shape[-2:] |
|
|
|
if crop: |
|
pad_w = (w // multiple) * multiple - w |
|
pad_h = (h // multiple) * multiple - h |
|
else: |
|
pad_w = (multiple - w % multiple) % multiple |
|
pad_h = (multiple - h % multiple) % multiple |
|
|
|
if mode == "center": |
|
pad_l = pad_w // 2 |
|
pad_r = pad_w - pad_l |
|
pad_t = pad_h // 2 |
|
pad_b = pad_h - pad_t |
|
elif mode == "left": |
|
pad_l = 0 |
|
pad_r = pad_w |
|
pad_t = 0 |
|
pad_b = pad_h |
|
else: |
|
raise ValueError(f"Unknown mode {mode}") |
|
|
|
return (pad_l, pad_r, pad_t, pad_b) |
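
# Worked example: for a 100x130 input with multiple=32 and mode="center":
#     pad_w = (32 - 130 % 32) % 32 = 30 -> pad_l = 15, pad_r = 15
#     pad_h = (32 - 100 % 32) % 32 = 28 -> pad_t = 14, pad_b = 14
# With crop=True the pads are negative ((-1, -1, -2, -2) here), so that
# torch.nn.functional.pad crops the input down to 96x128 instead.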
|
|
|
|
|
def fit_features_to_multiple( |
|
features: torch.Tensor, multiple: int = 32, crop: bool = False |
|
) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
|
"""Pad image to a multiple of the given number. |
|
|
|
Args: |
|
features (torch.Tensor): Input features. |
|
multiple (int, optional): Multiple. Defaults to 32. |
|
crop (bool, optional): Whether to crop or pad. Defaults to False. |
|
|
|
Returns: |
|
        Tuple[torch.Tensor, Tuple[int, int, int, int]]: Padded features and
            padding (left, right, top, bottom).
|
""" |
|
pad = fit_to_multiple(features, multiple, crop=crop) |
|
return torch.nn.functional.pad(features, pad, mode="reflect"), pad |
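
# Example: pad a feature map to a multiple of 32, then undo the padding after
# processing by slicing with the returned pad values:
#
#     feats = torch.rand(1, 64, 100, 130)
#     padded, pad = fit_features_to_multiple(feats, 32)  # -> (1, 64, 128, 160)
#     h, w = padded.shape[-2:]
#     restored = padded[..., pad[2] : h - pad[3], pad[0] : w - pad[1]]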
|
|