File size: 3,964 Bytes
205a7af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
"""Implementation of manifolds."""
import logging
import torch
# Module-level logger; SphericalManifold.householder_vector emits a warning through it when the
# off-pivot part of its input has squared norm below 1e-7.
logger = logging.getLogger(__name__)
class EuclideanManifold:
    """Flat (Euclidean) manifold: the plus operator is ordinary vector addition."""

    @staticmethod
    def J_plus(x: torch.Tensor) -> torch.Tensor:
        """Return the Jacobian of ``plus`` with respect to ``delta``, i.e. the identity.

        Args:
            x (torch.Tensor): [..., n] tensor; only its last dimension, dtype and device are used.

        Returns:
            torch.Tensor: [n, n] identity matrix with the dtype and device of ``x``. It is not
            expanded over the leading batch dimensions of ``x``.
        """
        dim = x.shape[-1]
        return torch.eye(dim, dtype=x.dtype, device=x.device)

    @staticmethod
    def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Return ``x`` moved by ``delta``, which in flat space is simply their (broadcast) sum.

        Args:
            x (torch.Tensor): point.
            delta (torch.Tensor): increment, broadcastable against ``x``.

        Returns:
            torch.Tensor: ``x + delta``.
        """
        shifted = x + delta
        return shifted
class SphericalManifold:
    """Implementation of the spherical manifold.

    Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
    Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
    Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
    Golub et al.

    Points are [..., n] tensors and tangent vectors (``delta``) are [..., n - 1] tensors. The last
    element of a point is used as the Householder pivot, whereas the references use the first.
    """

    @staticmethod
    def householder_vector(x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return the Householder vector and beta.

        Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
        in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
        first.

        This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
        orthogonal and H * x = ||x||_2 * e_n.

        If the squared norm of the off-pivot part ``x[..., :-1]`` is below 1e-7, it is raised by
        1e-7 and a warning is logged. Without this, a point on the positive pivot axis would give
        ``vpiv == 0`` and a 0 / 0 when normalizing ``v``.

        Args:
            x (torch.Tensor): [..., n] tensor.

        Returns:
            torch.Tensor: v of shape [..., n]
            torch.Tensor: beta of shape [...]
        """
        sigma = torch.sum(x[..., :-1] ** 2, -1)
        xpiv = x[..., -1]
        norm = torch.norm(x, dim=-1)
        if torch.any(sigma < 1e-7):
            sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
            logger.warning("sigma < 1e-7")
        negative_pivot = xpiv < 0
        # torch.where back-propagates through both branches. For x = -||x|| * e_n the unselected
        # branch would divide by xpiv + norm == 0, and its zero upstream gradient becomes
        # 0 / 0 = NaN. Substitute a harmless denominator wherever that branch is not selected;
        # forward values are unchanged.
        safe_denom = torch.where(negative_pivot, torch.ones_like(xpiv), xpiv + norm)
        # Golub's cancellation-free choice of the pivot entry of v.
        vpiv = torch.where(negative_pivot, xpiv - norm, -sigma / safe_denom)
        beta = 2 * vpiv**2 / (sigma + vpiv**2)
        # Normalize so that v(n) == 1.
        v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
        return v, beta

    @staticmethod
    def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Apply Householder transformation.

        Computes ``H y = y - beta * v * (v^T y)`` without forming H explicitly.

        Args:
            y (torch.Tensor): Vector to transform of shape [..., n].
            v (torch.Tensor): Householder vector of shape [..., n].
            beta (torch.Tensor): Householder beta of shape [...].

        Returns:
            torch.Tensor: Transformed vector of shape [..., n].
        """
        return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]

    @classmethod
    def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
        """Plus operator Jacobian.

        Returns the first n - 1 columns of the Householder matrix H built from ``x``. H is
        symmetric and orthogonal with H * x = ||x||_2 * e_n, so these columns are orthonormal and
        orthogonal to ``x``.

        NOTE(review): ``plus`` scales its result by ||x||, so its derivative with respect to
        ``delta`` at 0 is ||x|| times this matrix. The two agree only for unit-norm ``x``; confirm
        this is intended by callers.

        Args:
            x (torch.Tensor): [..., n] point on the manifold.

        Returns:
            torch.Tensor: [..., n, n - 1] Jacobian.
        """
        v, beta = cls.householder_vector(x)
        H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
        # Allocate the identity directly on H's device/dtype instead of copying it from the CPU.
        H = H + torch.eye(H.shape[-1], dtype=H.dtype, device=H.device)
        return H[..., :-1]  # J

    @classmethod
    def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        """Plus operator.

        Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
        Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
        element of the input vector as pivot instead of first.

        Args:
            x: [..., n] point on the manifold.
            delta: [..., n - 1] tangent vector.

        Returns:
            torch.Tensor: [..., n] point with, up to rounding, the same norm as ``x``.
        """
        eps = 1e-7
        # Preserve the norm of x, which need not be 1.
        nx = torch.norm(x, dim=-1, keepdim=True)
        nd = torch.norm(delta, dim=-1, keepdim=True)
        # torch.where computes gradients for both branches, so keep sin(nd) / nd finite even where
        # the sinc limit of 1 is selected.
        nd_ = torch.where(nd < eps, nd + eps, nd)
        sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
        # Exponential map of delta; cos goes in the last element (the pivot) instead of the first.
        exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
        v, beta = cls.householder_vector(x)
        return nx * cls.apply_householder(exp_delta, v, beta)
|