# GeoCalib/siclib/utils/tensor.py
"""
Author: Paul-Edouard Sarlin (skydes)
"""
import collections.abc as collections
import functools
import inspect
from typing import Callable, List, Tuple
import numpy as np
import torch
# flake8: noqa
# mypy: ignore-errors
string_classes = (str, bytes)
def autocast(func: Callable) -> Callable:
"""Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
Use the device and dtype of the wrapper.
Args:
func (Callable): Method of a TensorWrapper class.
Returns:
Callable: Wrapped method.
"""
@functools.wraps(func)
def wrap(self, *args):
device = torch.device("cpu")
dtype = None
if isinstance(self, TensorWrapper):
if self._data is not None:
device = self.device
dtype = self.dtype
elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
raise ValueError(self)
cast_args = []
for arg in args:
if isinstance(arg, np.ndarray):
arg = torch.from_numpy(arg)
arg = arg.to(device=device, dtype=dtype)
cast_args.append(arg)
return func(self, *cast_args)
return wrap
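
# Illustrative sketch (not part of the original module): `Points2D` below is a
# hypothetical TensorWrapper subclass whose method is decorated with @autocast,
# so numpy inputs are converted to tensors on the wrapper's device and dtype.
#
#   class Points2D(TensorWrapper):
#       @autocast
#       def shift(self, offset: torch.Tensor) -> "Points2D":
#           return Points2D(self._data + offset)
#
#   pts = Points2D(torch.zeros(4, 2))
#   shifted = pts.shift(np.array([1.0, 2.0]))  # the numpy offset is cast automatically
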
class TensorWrapper:
"""Wrapper for PyTorch tensors."""
_data = None
@autocast
def __init__(self, data: torch.Tensor):
"""Wrapper for PyTorch tensors."""
self._data = data
@property
def shape(self) -> torch.Size:
"""Shape of the underlying tensor."""
return self._data.shape[:-1]
@property
def device(self) -> torch.device:
"""Get the device of the underlying tensor."""
return self._data.device
@property
def dtype(self) -> torch.dtype:
"""Get the dtype of the underlying tensor."""
return self._data.dtype
    def __getitem__(self, index) -> "TensorWrapper":
        """Index the underlying tensor and wrap the result."""
        return self.__class__(self._data[index])

    def __setitem__(self, index, item):
        """Set a slice of the underlying tensor."""
        # Accept either another TensorWrapper or a raw tensor as the value.
        self._data[index] = item._data if isinstance(item, TensorWrapper) else item
def to(self, *args, **kwargs):
"""Move the underlying tensor to a new device."""
return self.__class__(self._data.to(*args, **kwargs))
def cpu(self):
"""Move the underlying tensor to the CPU."""
return self.__class__(self._data.cpu())
def cuda(self):
"""Move the underlying tensor to the GPU."""
return self.__class__(self._data.cuda())
def pin_memory(self):
"""Pin the underlying tensor to memory."""
return self.__class__(self._data.pin_memory())
def float(self):
"""Cast the underlying tensor to float."""
return self.__class__(self._data.float())
def double(self):
"""Cast the underlying tensor to double."""
return self.__class__(self._data.double())
def detach(self):
"""Detach the underlying tensor."""
return self.__class__(self._data.detach())
def numpy(self):
"""Convert the underlying tensor to a numpy array."""
return self._data.detach().cpu().numpy()
def new_tensor(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_tensor(*args, **kwargs)
def new_zeros(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_zeros(*args, **kwargs)
def new_ones(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_ones(*args, **kwargs)
def new_full(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_full(*args, **kwargs)
def new_empty(self, *args, **kwargs):
"""Create a new tensor of the same type and device."""
return self._data.new_empty(*args, **kwargs)
    def unsqueeze(self, *args, **kwargs):
        """Insert a dimension of size one, wrapping the result."""
        return self.__class__(self._data.unsqueeze(*args, **kwargs))

    def squeeze(self, *args, **kwargs):
        """Remove dimensions of size one, wrapping the result."""
        return self.__class__(self._data.squeeze(*args, **kwargs))
@classmethod
def stack(cls, objects: List, dim=0, *, out=None):
"""Stack a list of objects with the same type and shape."""
data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
return cls(data)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
"""Support torch functions."""
if kwargs is None:
kwargs = {}
return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
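
# Illustrative sketch (not part of the original module): because
# __torch_function__ dispatches torch.stack to TensorWrapper.stack, wrappers of
# the same subclass can be stacked directly (Points2D is the hypothetical
# subclass sketched above).
#
#   a, b = Points2D(torch.zeros(4, 2)), Points2D(torch.ones(4, 2))
#   batched = torch.stack([a, b], dim=0)  # a Points2D wrapping a (2, 4, 2) tensor
#   batched.shape                         # torch.Size([2, 4]), the last dim is hidden
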
def map_tensor(input_, func):
    """Recursively apply func to the leaf elements of a nested dict/list structure."""
if isinstance(input_, string_classes):
return input_
elif isinstance(input_, collections.Mapping):
return {k: map_tensor(sample, func) for k, sample in input_.items()}
elif isinstance(input_, collections.Sequence):
return [map_tensor(sample, func) for sample in input_]
elif input_ is None:
return None
else:
return func(input_)
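
# Illustrative sketch (not part of the original module): map_tensor recurses
# through nested mappings and sequences, applies func to the leaves, and leaves
# strings and None untouched.
#
#   batch = {"image": torch.rand(3, 8, 8), "name": "img0", "kpts": [torch.rand(5, 2)]}
#   shapes = map_tensor(batch, lambda t: tuple(t.shape))
#   # {"image": (3, 8, 8), "name": "img0", "kpts": [(5, 2)]}
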
def batch_to_numpy(batch):
return map_tensor(batch, lambda tensor: tensor.cpu().numpy())
def batch_to_device(batch, device, non_blocking=True, detach=False):
def _func(tensor):
t = tensor.to(device=device, non_blocking=non_blocking, dtype=torch.float32)
return t.detach() if detach else t
return map_tensor(batch, _func)
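
# Illustrative usage sketch (not part of the original module): move a nested
# batch to a device (note that batch_to_device also casts tensors to float32),
# then convert it back to numpy arrays.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   batch = {"image": torch.rand(1, 3, 8, 8), "name": "scene_0"}
#   batch = batch_to_device(batch, device, non_blocking=False)
#   arrays = batch_to_numpy(batch)  # nested structure of numpy arrays, strings untouched
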
def remove_batch_dim(data: dict) -> dict:
"""Remove batch dimension from elements in data"""
return {
k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()
}
def add_batch_dim(data: dict) -> dict:
    """Add batch dimension to elements in data"""
    return {
        # Tensors and arrays get a leading axis; lists are wrapped in an outer list.
        k: v[None]
        if isinstance(v, (torch.Tensor, np.ndarray))
        else ([v] if isinstance(v, list) else v)
        for k, v in data.items()
    }
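
# Illustrative sketch (not part of the original module): add_batch_dim and
# remove_batch_dim act as rough inverses for a batch of size one.
#
#   sample = {"image": torch.rand(3, 8, 8), "name": "img0"}
#   batched = add_batch_dim(sample)       # image -> (1, 3, 8, 8), name untouched
#   restored = remove_batch_dim(batched)  # image -> (3, 8, 8)
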
def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
    """Get the padding (or crop) needed to make the image size a multiple of the given number.

    Args:
        x (torch.Tensor): Input tensor with spatial dimensions last (..., H, W).
        multiple (int): Multiple to fit to.
        mode (str, optional): Padding placement, "center" or "left". Defaults to "center".
        crop (bool, optional): Whether to crop instead of pad. Defaults to False.

    Returns:
        Tuple[int, int, int, int]: Padding (left, right, top, bottom); the values are
            non-positive when crop is True.
    """
h, w = x.shape[-2:]
if crop:
pad_w = (w // multiple) * multiple - w
pad_h = (h // multiple) * multiple - h
else:
pad_w = (multiple - w % multiple) % multiple
pad_h = (multiple - h % multiple) % multiple
if mode == "center":
pad_l = pad_w // 2
pad_r = pad_w - pad_l
pad_t = pad_h // 2
pad_b = pad_h - pad_t
elif mode == "left":
pad_l = 0
pad_r = pad_w
pad_t = 0
pad_b = pad_h
else:
raise ValueError(f"Unknown mode {mode}")
return (pad_l, pad_r, pad_t, pad_b)
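
# Worked example (illustrative, not part of the original module): for a 75x100
# image and multiple=32, the next multiples are 96 (height) and 128 (width).
#
#   x = torch.rand(3, 75, 100)
#   fit_to_multiple(x, 32)             # (14, 14, 10, 11): pad to 96x128
#   fit_to_multiple(x, 32, crop=True)  # (-2, -2, -6, -5): crop to 64x96
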
def fit_features_to_multiple(
    features: torch.Tensor, multiple: int = 32, crop: bool = False
) -> Tuple[torch.Tensor, Tuple[int, int, int, int]]:
    """Pad (or crop) the features so that their spatial size is a multiple of the given number.

    Args:
        features (torch.Tensor): Input features.
        multiple (int, optional): Multiple. Defaults to 32.
        crop (bool, optional): Whether to crop instead of pad. Defaults to False.

    Returns:
        Tuple[torch.Tensor, Tuple[int, int, int, int]]: Padded features and the applied padding.
    """
pad = fit_to_multiple(features, multiple, crop=crop)
return torch.nn.functional.pad(features, pad, mode="reflect"), pad
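
# Illustrative usage sketch (not part of the original module): pad features to a
# multiple of 32, then undo the padding with the returned values.
#
#   feats = torch.rand(1, 16, 75, 100)
#   padded, (pad_l, pad_r, pad_t, pad_b) = fit_features_to_multiple(feats, 32)
#   # padded.shape == (1, 16, 96, 128)
#   h, w = padded.shape[-2:]
#   unpadded = padded[..., pad_t : h - pad_b, pad_l : w - pad_r]
#   # unpadded.shape == (1, 16, 75, 100)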