Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import functools | |
from typing import Tuple, Optional | |
########################## | |
#### from pytorch3d #### | |
########################## | |
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for a rotation about a single coordinate
    axis, one matrix for each value of ``angle``.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if ``axis`` is not one of "X", "Y" or "Z".
    """
    cos = torch.cos(angle)
    sin = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # Row-major entries of the 3x3 rotation matrix for the requested axis.
    if axis == "X":
        R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
    elif axis == "Y":
        R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
    elif axis == "Z":
        R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
    else:
        # Without this branch an unknown letter surfaced as UnboundLocalError.
        raise ValueError("letter must be either X, Y or Z.")

    # Stack the 9 entries on a new last dim, then fold it into (3, 3).
    return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles, convention: str): | |
""" | |
Convert rotations given as Euler angles in radians to rotation matrices. | |
Args: | |
euler_angles: Euler angles in radians as tensor of shape (..., 3). | |
convention: Convention string of three uppercase letters from | |
{"X", "Y", and "Z"}. | |
Returns: | |
Rotation matrices as tensor of shape (..., 3, 3). | |
""" | |
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: | |
raise ValueError("Invalid input euler angles.") | |
if len(convention) != 3: | |
raise ValueError("Convention must have 3 letters.") | |
if convention[1] in (convention[0], convention[2]): | |
raise ValueError(f"Invalid convention {convention}.") | |
for letter in convention: | |
if letter not in ("X", "Y", "Z"): | |
raise ValueError(f"Invalid letter {letter} in convention string.") | |
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)) | |
return functools.reduce(torch.matmul, matrices) | |
########################### | |
#### from torchgeometry ####
########################### | |
def get_perspective_transform(src, dst):
    r"""Calculates a perspective transform from four pairs of the corresponding
    points.

    The function calculates the matrix of a perspective transform so that:

    .. math ::

        \begin{bmatrix}
        t_{i}x_{i}^{'} \\
        t_{i}y_{i}^{'} \\
        t_{i} \\
        \end{bmatrix}
        =
        \textbf{map_matrix} \cdot
        \begin{bmatrix}
        x_{i} \\
        y_{i} \\
        1 \\
        \end{bmatrix}

    where

    .. math ::
        dst(i) = (x_{i}^{'},y_{i}^{'}), src(i) = (x_{i}, y_{i}), i = 0,1,2,3

    Args:
        src (Tensor): coordinates of quadrangle vertices in the source image.
        dst (Tensor): coordinates of the corresponding quadrangle vertices in
            the destination image.

    Returns:
        Tensor: the perspective transformation, with the bottom-right entry
        fixed to 1.

    Raises:
        TypeError: if ``src`` or ``dst`` is not a tensor.
        ValueError: if ``src`` is not Bx4x2 or ``dst`` differs in shape.

    Shape:
        - Input: :math:`(B, 4, 2)` and :math:`(B, 4, 2)`
        - Output: :math:`(B, 3, 3)`
    """
    if not torch.is_tensor(src):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(src)))
    if not torch.is_tensor(dst):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(dst)))
    # The indexing below assumes an explicit batch dim; a bare 4x2 input
    # previously failed later with an opaque IndexError.
    if src.dim() != 3 or src.shape[-2:] != (4, 2):
        raise ValueError("Inputs must be a Bx4x2 tensor. Got {}"
                         .format(src.shape))
    # Equal shapes also implies equal batch size, so no separate check.
    if src.shape != dst.shape:
        raise ValueError("Inputs must have the same shape. Got {} and {}"
                         .format(src.shape, dst.shape))

    def ax(p, q):
        # Row of A for the x' equation of one correspondence (p -> q), Bx8.
        ones = torch.ones_like(p)[..., 0:1]
        zeros = torch.zeros_like(p)[..., 0:1]
        return torch.cat(
            [p[:, 0:1], p[:, 1:2], ones, zeros, zeros, zeros,
             -p[:, 0:1] * q[:, 0:1], -p[:, 1:2] * q[:, 0:1]
             ], dim=1)

    def ay(p, q):
        # Row of A for the y' equation of one correspondence (p -> q), Bx8.
        ones = torch.ones_like(p)[..., 0:1]
        zeros = torch.zeros_like(p)[..., 0:1]
        return torch.cat(
            [zeros, zeros, zeros, p[:, 0:1], p[:, 1:2], ones,
             -p[:, 0:1] * q[:, 1:2], -p[:, 1:2] * q[:, 1:2]], dim=1)

    # Each of the 4 correspondences contributes an x' row and a y' row, so A
    # is square (Bx8x8) and the system is solved exactly (not least squares).
    rows = [row_fn(src[:, i], dst[:, i])
            for i in range(4) for row_fn in (ax, ay)]
    A = torch.stack(rows, dim=1)  # Bx8x8

    # b interleaves x'_0, y'_0, x'_1, y'_1, ... matching the row order of A.
    b = dst.reshape(dst.shape[0], 8, 1)  # Bx8x1

    X = torch.linalg.solve(A, b)  # Bx8x1

    # The 9th homography entry is fixed to 1 (the 8 unknowns are the rest).
    batch_size = src.shape[0]
    M = torch.ones(batch_size, 9, device=src.device, dtype=src.dtype)
    M[..., :8] = torch.squeeze(X, dim=-1)
    return M.view(-1, 3, 3)  # Bx3x3
def warp_perspective(src, M, dsize, flags='bilinear', border_mode=None,
                     border_value=0):
    r"""Applies a perspective transformation to an image.

    The function warp_perspective transforms the source image using
    the specified matrix:

    .. math::
        \text{dst} (x, y) = \text{src} \left(
        \frac{M_{11} x + M_{12} y + M_{13}}{M_{31} x + M_{32} y + M_{33}} ,
        \frac{M_{21} x + M_{22} y + M_{23}}{M_{31} x + M_{32} y + M_{33}}
        \right )

    Args:
        src (torch.Tensor): input image.
        M (Tensor): pixel-space transformation matrix. Validation only
            requires the last two dims to be 3x3.
        dsize (tuple): size of the output image (height, width).
        flags (str): accepted for API compatibility but currently ignored;
            the warp is performed with ``homography_warp``'s default mode.
        border_mode: accepted for API compatibility but currently ignored.
        border_value: accepted for API compatibility but currently ignored.

    Returns:
        Tensor: the warped input image.

    Raises:
        TypeError: if ``src`` or ``M`` is not a tensor.
        ValueError: if ``src`` is not 4-D or ``M`` is not 3x3 in its last
            two dims.

    Shape:
        - Input: :math:`(B, C, H, W)` and :math:`(B, 3, 3)`
        - Output: :math:`(B, C, H, W)`

    .. note::
       See a working example `here <https://github.com/arraiy/torchgeometry/
       blob/master/examples/warp_perspective.ipynb>`_.
    """
    if not torch.is_tensor(src):
        raise TypeError("Input src type is not a torch.Tensor. Got {}"
                        .format(type(src)))
    if not torch.is_tensor(M):
        raise TypeError("Input M type is not a torch.Tensor. Got {}"
                        .format(type(M)))
    if not len(src.shape) == 4:
        raise ValueError("Input src must be a BxCxHxW tensor. Got {}"
                         .format(src.shape))
    # Require trailing 3x3 dims. The previous `len == 3 or ...` test let any
    # 3-D tensor (e.g. Bx2x3) through to fail obscurely inside matmul.
    if M.shape[-2:] != (3, 3):
        raise ValueError("Input M must be a Bx3x3 tensor. Got {}"
                         .format(M.shape))
    # launches the warper
    return transform_warp_impl(src, M, (src.shape[-2:]), dsize)
def transform_warp_impl(src, dst_pix_trans_src_pix, dsize_src, dsize_dst):
    """Compute the transform in normalized coordinates and perform the warping.

    The pixel-space matrix is first re-expressed in normalized coordinates
    (source-normalized -> destination-normalized). It is then inverted,
    because the warper pulls each destination sample from the source.
    """
    norm_src_to_dst = dst_norm_to_dst_norm(
        dst_pix_trans_src_pix, dsize_src, dsize_dst)
    norm_dst_to_src = torch.inverse(norm_src_to_dst)
    return homography_warp(src, norm_dst_to_src, dsize_dst)
def dst_norm_to_dst_norm(dst_pix_trans_src_pix, dsize_src, dsize_dst):
    """Re-express a pixel-space transform in normalized coordinates.

    Despite the name, the returned matrix maps *source* normalized
    coordinates to *destination* normalized coordinates. It is the chain

        dst_norm <- dst_pix <- src_pix <- src_norm

    i.e. ``N_dst @ (M @ inverse(N_src))``. Here ``N_*`` comes from
    ``normal_transform_pixel`` for each image size, and ``M`` is
    ``dst_pix_trans_src_pix``.

    Args:
        dst_pix_trans_src_pix: pixel-space transform from source to
            destination; its device and dtype are used for the
            normalization matrices.
        dsize_src: (height, width) of the source image.
        dsize_dst: (height, width) of the destination image.

    Returns:
        The normalized-space transform.
    """
    (src_h, src_w), (dst_h, dst_w) = dsize_src, dsize_dst
    placement = dict(device=dst_pix_trans_src_pix.device,
                     dtype=dst_pix_trans_src_pix.dtype)

    src_pix_trans_src_norm = torch.inverse(
        normal_transform_pixel(src_h, src_w).to(**placement))
    dst_norm_trans_dst_pix = normal_transform_pixel(
        dst_h, dst_w).to(**placement)

    pix_chain = dst_pix_trans_src_pix @ src_pix_trans_src_norm
    return dst_norm_trans_dst_pix @ pix_chain
def normal_transform_pixel(height, width, eps=1e-14):
    """Return the matrix mapping pixel coordinates to normalized ones.

    Pixel ``x`` in ``[0, width - 1]`` is mapped to ``2 * x / (width - 1) - 1``
    (and likewise for ``y``/``height``). The first and last pixels therefore
    land on -1 and 1.

    Args:
        height: image height in pixels.
        width: image width in pixels.
        eps: denominator used instead of ``size - 1`` when a dimension is 1.
            This keeps the matrix finite and invertible; dividing by zero
            would give ``inf`` and then NaNs after ``torch.inverse``.

    Returns:
        A tensor of shape (1, 3, 3) in the default floating dtype.
    """
    # torch.tensor with Python floats uses the default dtype, like the legacy
    # torch.Tensor constructor did.
    tr_mat = torch.tensor([[1.0, 0.0, -1.0],
                           [0.0, 1.0, -1.0],
                           [0.0, 0.0, 1.0]])  # 3x3
    width_denom = eps if width == 1 else width - 1.0
    height_denom = eps if height == 1 else height - 1.0
    tr_mat[0, 0] = tr_mat[0, 0] * 2.0 / width_denom
    tr_mat[1, 1] = tr_mat[1, 1] * 2.0 / height_denom
    return tr_mat.unsqueeze(0)  # 1x3x3
def homography_warp(patch_src: torch.Tensor,
                    dst_homo_src: torch.Tensor,
                    dsize: Tuple[int, int],
                    mode: Optional[str] = 'bilinear',
                    padding_mode: Optional[str] = 'zeros') -> torch.Tensor:
    r"""Warp image patches or tensors by homographies.

    Convenience wrapper that builds a :class:`HomographyWarper` for the
    requested output size and applies it once.

    Args:
        patch_src (torch.Tensor): the image or tensor to warp, taken from the
            source frame, of shape :math:`(N, C, H, W)`.
        dst_homo_src (torch.Tensor): homography or stack of homographies from
            source to destination, of shape :math:`(N, 3, 3)`.
        dsize (Tuple[int, int]): (height, width) of the output.
        mode (Optional[str]): interpolation mode, 'bilinear' | 'nearest'.
            Default: 'bilinear'.
        padding_mode (Optional[str]): how out-of-range samples are filled,
            'zeros' | 'border' | 'reflection'. Default: 'zeros'.

    Return:
        torch.Tensor: the patch resampled into the destination frame.

    Example:
        >>> input = torch.rand(1, 3, 32, 32)
        >>> homography = torch.eye(3).view(1, 3, 3)
        >>> output = homography_warp(input, homography, (32, 32))  # NxCxHxW
    """
    out_height, out_width = dsize
    warp = HomographyWarper(out_height, out_width, mode, padding_mode)
    return warp(patch_src, dst_homo_src)
class HomographyWarper(nn.Module):
    r"""Warps image patches or tensors by homographies.

    .. math::

        X_{dst} = H_{src}^{\{dst\}} * X_{src}

    Args:
        height (int): The height of the image to warp.
        width (int): The width of the image to warp.
        mode (Optional[str]): interpolation mode to calculate output values
          'bilinear' | 'nearest'. Default: 'bilinear'.
        padding_mode (Optional[str]): padding mode for outside grid values
          'zeros' | 'border' | 'reflection'. Default: 'zeros'.
        normalized_coordinates (Optional[bool]): whether to use a grid with
          normalized coordinates.
        align_corners (bool): forwarded to
          :func:`torch.nn.functional.grid_sample`. Defaults to ``True``
          because the base grid from ``create_meshgrid`` places its first and
          last samples exactly at -1 and 1, i.e. on the centres of the corner
          pixels. With ``align_corners=False`` (PyTorch's default) every warp,
          including the identity, would be shifted by half a pixel.
    """

    def __init__(
            self,
            height: int,
            width: int,
            mode: Optional[str] = 'bilinear',
            padding_mode: Optional[str] = 'zeros',
            normalized_coordinates: Optional[bool] = True,
            align_corners: bool = True) -> None:
        super(HomographyWarper, self).__init__()
        self.width: int = width
        self.height: int = height
        self.mode: Optional[str] = mode
        self.padding_mode: Optional[str] = padding_mode
        self.normalized_coordinates: Optional[bool] = normalized_coordinates
        self.align_corners: bool = align_corners
        # Base sampling grid (1xHxWx2). It is a plain attribute, not a buffer,
        # so it is not in state_dict; warp_grid moves it to the input's
        # device/dtype on every call.
        self.grid: torch.Tensor = create_meshgrid(
            height, width, normalized_coordinates=normalized_coordinates)

    def warp_grid(self, dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Computes the grid to warp the coordinates grid by an homography.

        Args:
            dst_homo_src (torch.Tensor): Homography or homographies (stacked) to
                              transform all points in the grid. Shape of the
                              homography has to be :math:`(N, 3, 3)`.

        Returns:
            torch.Tensor: the transformed grid of shape :math:`(N, H, W, 2)`.
        """
        batch_size: int = dst_homo_src.shape[0]
        device: torch.device = dst_homo_src.device
        dtype: torch.dtype = dst_homo_src.dtype
        # expand the 1xHxWx2 base grid to the batch size (a view, no copy)
        grid: torch.Tensor = self.grid.expand(batch_size, -1, -1, -1)  # NxHxWx2
        if len(dst_homo_src.shape) == 3:  # one homography per batch element
            # Nx3x3 -> Nx1x3x3 so it broadcasts over the grid points
            dst_homo_src = dst_homo_src.view(batch_size, 1, 3, 3)
        # perform the actual grid transformation,
        # the grid is copied to input device and casted to the same type
        flow: torch.Tensor = transform_points(
            dst_homo_src, grid.to(device).to(dtype))  # NxHxWx2
        return flow.view(batch_size, self.height, self.width, 2)  # NxHxWx2

    def forward(
            self,
            patch_src: torch.Tensor,
            dst_homo_src: torch.Tensor) -> torch.Tensor:
        r"""Warps an image or tensor from source into reference frame.

        Args:
            patch_src (torch.Tensor): The image or tensor to warp.
                                      Should be from source.
            dst_homo_src (torch.Tensor): The homography or stack of homographies
             from source to destination. The homography assumes normalized
             coordinates [-1, 1].

        Return:
            torch.Tensor: Patch sampled at locations from source to destination.

        Raises:
            TypeError: if the patch and the homography are on different devices.

        Shape:
            - Input: :math:`(N, C, H, W)` and :math:`(N, 3, 3)`
            - Output: :math:`(N, C, H, W)`

        Example:
            >>> input = torch.rand(1, 3, 32, 32)
            >>> homography = torch.eye(3).view(1, 3, 3)
            >>> warper = HomographyWarper(32, 32)
            >>> output = warper(input, homography)  # NxCxHxW
        """
        if dst_homo_src.device != patch_src.device:
            raise TypeError(
                "Patch and homography must be on the same device. "
                "Got patch.device: {} dst_H_src.device: {}."
                .format(patch_src.device, dst_homo_src.device))
        return F.grid_sample(patch_src, self.warp_grid(dst_homo_src),
                             mode=self.mode, padding_mode=self.padding_mode,
                             align_corners=self.align_corners)
def create_meshgrid(
        height: int,
        width: int,
        normalized_coordinates: Optional[bool] = True) -> torch.Tensor:
    """Generates a coordinate grid for an image.

    When the flag `normalized_coordinates` is set to True, the grid is
    normalized to be in the range [-1,1] to be consistent with the pytorch
    function grid_sample.
    http://pytorch.org/docs/master/nn.html#torch.nn.functional.grid_sample

    Args:
        height (int): the image height (rows).
        width (int): the image width (cols).
        normalized_coordinates (Optional[bool]): whether to normalize
          coordinates in the range [-1, 1] in order to be consistent with the
          PyTorch function grid_sample. Otherwise pixel indices are used.

    Return:
        torch.Tensor: a grid tensor of shape :math:`(1, H, W, 2)` where
        ``grid[0, y, x] == (x_coord, y_coord)``.
    """
    if normalized_coordinates:
        xs = torch.linspace(-1, 1, width)
        ys = torch.linspace(-1, 1, height)
    else:
        xs = torch.linspace(0, width - 1, width)
        ys = torch.linspace(0, height - 1, height)
    # Build the grid by broadcasting rather than torch.meshgrid. Newer PyTorch
    # warns when meshgrid is called without `indexing=`, while old versions
    # reject that argument; broadcasting works everywhere and gives the same
    # values.
    grid_x = xs.view(1, width).expand(height, width)   # x varies along columns
    grid_y = ys.view(height, 1).expand(height, width)  # y varies along rows
    base_grid = torch.stack([grid_x, grid_y], dim=-1)  # HxWx2
    return base_grid.unsqueeze(0)  # 1xHxWx2
def transform_points(trans_01: torch.Tensor,
                     points_1: torch.Tensor) -> torch.Tensor:
    r"""Apply (projective) transformations to a set of points.

    The points are lifted to homogeneous coordinates and multiplied by the
    transform. The result is then projected back by dividing by the last
    coordinate. ``trans_01`` is unsqueezed at dim 1 before the matmul, so
    higher-rank inputs work whenever ``torch.matmul`` can broadcast them. For
    example, ``HomographyWarper.warp_grid`` passes a (B, 1, 3, 3) transform
    with (B, H, W, 2) points.

    Args:
        trans_01 (torch.Tensor): tensor for transformations of shape
          :math:`(B, D+1, D+1)`.
        points_1 (torch.Tensor): tensor of points of shape :math:`(B, N, D)`.

    Returns:
        torch.Tensor: tensor of N-dimensional points.

    Raises:
        TypeError: if an input is not a tensor or they are on different devices.
        ValueError: on mismatched batch size or point dimensionality.

    Shape:
        - Output: :math:`(B, N, D)`

    Examples:
        >>> points_1 = torch.rand(2, 4, 3)  # BxNx3
        >>> trans_01 = torch.eye(4).view(1, 4, 4)  # Bx4x4
        >>> points_0 = transform_points(trans_01, points_1)  # BxNx3
    """
    if not (torch.is_tensor(trans_01) and torch.is_tensor(points_1)):
        raise TypeError("Input type is not a torch.Tensor")
    if trans_01.device != points_1.device:
        raise TypeError("Tensor must be in the same device")
    if trans_01.shape[0] != points_1.shape[0]:
        raise ValueError("Input batch size must be the same for both tensors")
    if trans_01.shape[-1] != points_1.shape[-1] + 1:
        raise ValueError("Last input dimensions must differe by one unit")

    homogeneous_1 = convert_points_to_homogeneous(points_1)  # Bx...xD+1
    # (..., D+1, D+1) @ (..., D+1, 1) -> (..., D+1, 1) -> (..., D+1)
    homogeneous_0 = (trans_01.unsqueeze(1) @ homogeneous_1.unsqueeze(-1)).squeeze(-1)
    return convert_points_from_homogeneous(homogeneous_0)  # Bx...xD
def convert_points_to_homogeneous(points):
    r"""Convert points from Euclidean to homogeneous space.

    A trailing coordinate equal to 1 is appended to every point.

    Args:
        points: tensor of shape (..., D) with at least 2 dimensions.

    Returns:
        Tensor of shape (..., D + 1).

    Raises:
        TypeError: if ``points`` is not a tensor.
        ValueError: if ``points`` has fewer than 2 dimensions.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = convert_points_to_homogeneous(input)  # BxNx4
    """
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))
    # Pad only the last dim: zero elements before, one element (value 1) after.
    return F.pad(points, (0, 1), "constant", 1.0)
def convert_points_from_homogeneous(points):
    r"""Convert points from homogeneous to Euclidean space.

    Every point is divided by its last coordinate, which is then dropped.
    No guard is applied, so a zero last coordinate yields inf/NaN.

    Args:
        points: tensor of shape (..., D + 1) with at least 2 dimensions.

    Returns:
        Tensor of shape (..., D).

    Raises:
        TypeError: if ``points`` is not a tensor.
        ValueError: if ``points`` has fewer than 2 dimensions.

    Examples::

        >>> input = torch.rand(2, 4, 3)  # BxNx3
        >>> output = convert_points_from_homogeneous(input)  # BxNx2
    """
    if not torch.is_tensor(points):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(points)))
    if points.dim() < 2:
        raise ValueError("Input must be at least a 2D tensor. Got {}".format(
            points.shape))
    euclidean_part = points[..., :-1]
    scale = points[..., -1:]  # keep the dim so it broadcasts over D
    return euclidean_part / scale