Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
# group operations implemented in cuda | |
from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ToVec, FromVec | |
from .broadcasting import broadcast_inputs | |
class LieGroupParameter(torch.Tensor):
    """ Wrapper class for LieGroup

    The tensor itself is a zero-initialised tangent vector with shape
    ``group.tangent_shape`` (same device/dtype as ``group.data``); the wrapped
    group element is kept in ``self.group``. Most methods first retract the
    tangent vector onto the group (``self.group.retr(self)``) and delegate to
    the resulting element. ``add_`` folds an update into ``self.group`` by
    left-multiplying ``exp(alpha * update)``; the tensor values are untouched.
    """

    from torch._C import _disabled_torch_function_impl
    # Opt out of the __torch_function__ subclass protocol for this tensor type.
    __torch_function__ = _disabled_torch_function_impl

    def __new__(cls, group, requires_grad=True):
        data = torch.zeros(group.tangent_shape,
                           device=group.data.device,
                           dtype=group.data.dtype)
        # _make_subclass applies the requested requires_grad flag itself.
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, group, requires_grad=True):
        # ``requires_grad`` is consumed by __new__; it must also be accepted
        # here, otherwise LieGroupParameter(group, requires_grad=...) raises.
        self.group = group

    def retr(self):
        """Retract this tangent vector onto the wrapped group element."""
        return self.group.retr(self)

    def log(self):
        return self.retr().log()

    def inv(self):
        return self.retr().inv()

    def adj(self, a):
        return self.retr().adj(a)

    def __mul__(self, other):
        if isinstance(other, LieGroupParameter):
            return self.retr() * other.retr()
        return self.retr() * other

    def add_(self, update, alpha=1):
        """Update the wrapped group: ``group <- exp(alpha * update) * group``.

        ``alpha`` defaults to 1 like ``torch.Tensor.add_``. Returns ``self``,
        following the in-place tensor-op convention.
        """
        self.group = self.group.exp(alpha * update) * self.group
        return self

    def __getitem__(self, index):
        return self.retr().__getitem__(index)
class LieGroup:
    """ Base class for Lie Group

    Each instance wraps a tensor ``data`` whose last dimension holds the
    embedded representation of one group element (``embedded_dim`` values);
    all leading dimensions form the batch shape. Concrete subclasses define
    the class attributes ``group_name``, ``group_id`` (forwarded to the group
    operators by ``apply_op``), ``manifold_dim`` (tangent-vector size),
    ``embedded_dim`` and ``id_elem`` (the identity element's data).
    """

    # Offset of the 4 quaternion entries inside ``data`` for the groups in
    # this module: SO3 / RxSO3 store the quaternion first, SE3 / Sim3 store
    # it after a 3-value translation.
    _QUAT_OFFSET = {'SO3': 0, 'RxSO3': 0, 'SE3': 3, 'Sim3': 3}

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "{}: size={}, device={}, dtype={}".format(
            self.group_name, self.shape, self.device, self.dtype)

    @property
    def shape(self):
        """Batch shape: ``data.shape`` without the trailing embedding dim."""
        return self.data.shape[:-1]

    @property
    def device(self):
        """Device of the underlying ``data`` tensor."""
        return self.data.device

    @property
    def dtype(self):
        """Dtype of the underlying ``data`` tensor."""
        return self.data.dtype

    def vec(self):
        """Apply the ``ToVec`` group operator to ``data``."""
        return self.apply_op(ToVec, self.data)

    @property
    def tangent_shape(self):
        """Shape of a matching tangent batch: batch shape + (manifold_dim,)."""
        return self.data.shape[:-1] + (self.manifold_dim,)

    @staticmethod
    def _normalize_batch_shape(batch_shape):
        """Turn varargs ``(2, 3)``, ``((2, 3),)`` or ``([2, 3],)`` into ``(2, 3)``."""
        if batch_shape and isinstance(batch_shape[0], (tuple, list)):
            return tuple(batch_shape[0])
        return tuple(batch_shape)

    @classmethod
    def Identity(cls, *batch_shape, **kwargs):
        """ Construct identity element with batch shape

        ``batch_shape`` may be given as ints, a tuple or a list; an empty
        batch shape yields a single element. Optional ``device`` / ``dtype``
        keyword arguments are applied to the result.
        """
        batch_shape = cls._normalize_batch_shape(batch_shape)
        # int(): np.prod(()) is the float 1.0, which repeat() would reject
        numel = int(np.prod(batch_shape))
        data = cls.id_elem.reshape(1, -1)
        if 'device' in kwargs:
            data = data.to(kwargs['device'])
        if 'dtype' in kwargs:
            data = data.type(kwargs['dtype'])
        data = data.repeat(numel, 1)
        return cls(data).view(batch_shape)

    @classmethod
    def IdentityLike(cls, G):
        """Identity with the same batch shape, device and dtype as ``G``."""
        return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)

    @classmethod
    def InitFromVec(cls, data):
        """Build an element by applying the ``FromVec`` operator to ``data``."""
        return cls(cls.apply_op(FromVec, data))

    @classmethod
    def Random(cls, *batch_shape, sigma=1.0, **kwargs):
        """ Construct random element with batch_shape by random sampling in tangent space

        Draws ``xi`` with ``torch.randn`` (extra keyword arguments are passed
        through) of shape batch_shape + (manifold_dim,) and returns
        ``exp(sigma * xi)``.
        """
        batch_shape = cls._normalize_batch_shape(batch_shape)
        tangent_shape = batch_shape + (cls.manifold_dim,)
        xi = torch.randn(tangent_shape, **kwargs)
        return cls.exp(sigma * xi)

    @classmethod
    def apply_op(cls, op, x, y=None):
        """ Apply group operator

        Inputs are broadcast with ``broadcast_inputs``; ``op.apply`` receives
        ``cls.group_id`` followed by the broadcast inputs, and its result is
        reshaped to the broadcast batch shape with a flattened last dim.
        """
        inputs, out_shape = broadcast_inputs(x, y)
        data = op.apply(cls.group_id, *inputs)
        return data.view(out_shape + (-1,))

    @classmethod
    def exp(cls, x):
        """ exponential map: x -> X """
        return cls(cls.apply_op(Exp, x))

    def quaternion(self):
        """ extract quaternion

        Returns a view of the 4 quaternion entries stored in ``data``.

        Raises:
            NotImplementedError: for a group whose layout is not known here.
        """
        # No quaternion group operator is imported, so read the stored
        # layout directly.
        offset = self._QUAT_OFFSET.get(self.group_name)
        if offset is None:
            raise NotImplementedError(
                "quaternion() is not defined for {}".format(self.group_name))
        return self.data[..., offset:offset + 4]

    def log(self):
        """ logarithm map """
        return self.apply_op(Log, self.data)

    def inv(self):
        """ group inverse """
        return self.__class__(self.apply_op(Inv, self.data))

    def mul(self, other):
        """ group multiplication """
        return self.__class__(self.apply_op(Mul, self.data, other.data))

    def retr(self, a):
        """ retraction: Exp(a) * X """
        dX = self.__class__.apply_op(Exp, a)
        return self.__class__(self.apply_op(Mul, dX, self.data))

    def adj(self, a):
        """ adjoint operator: b = A(X) * a """
        return self.apply_op(Adj, self.data, a)

    def adjT(self, a):
        """ transposed adjoint operator: b = a * A(X) """
        return self.apply_op(AdjT, self.data, a)

    def Jinv(self, a):
        """Apply the ``Jinv`` group operator to ``(data, a)``."""
        return self.apply_op(Jinv, self.data, a)

    def act(self, p):
        """ action on a point cloud

        ``p`` with last dim 3 is treated as points (``Act3``), last dim 4 as
        homogeneous points (``Act4``).

        Raises:
            ValueError: if the last dimension of ``p`` is neither 3 nor 4.
        """
        if p.shape[-1] == 3:
            return self.apply_op(Act3, self.data, p)
        if p.shape[-1] == 4:
            return self.apply_op(Act4, self.data, p)
        raise ValueError(
            "act() expects points with last dimension 3 or 4, got {}".format(
                p.shape[-1]))

    def matrix(self):
        """ convert element to 4x4 matrix """
        I = torch.eye(4, dtype=self.dtype, device=self.device)
        I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
        # act on the 4 homogeneous basis vectors (rows of I), then transpose
        # so the transformed vectors become columns
        return self.__class__(self.data[..., None, :]).act(I).transpose(-1, -2)

    def translation(self):
        """ extract translation component (image of the homogeneous origin) """
        p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
        p = p.view([1] * (len(self.data.shape) - 1) + [4])
        return self.apply_op(Act4, self.data, p)

    def detach(self):
        return self.__class__(self.data.detach())

    def view(self, dims):
        """Reshape the batch dimensions; ``dims`` is an int or a sequence."""
        if isinstance(dims, int):
            dims = (dims,)
        data_reshaped = self.data.view(tuple(dims) + (self.embedded_dim,))
        return self.__class__(data_reshaped)

    def __mul__(self, other):
        # group multiplication
        if isinstance(other, LieGroup):
            return self.mul(other)
        # action on point
        if isinstance(other, torch.Tensor):
            return self.act(other)
        # let Python raise TypeError instead of silently returning None
        return NotImplemented

    def __getitem__(self, index):
        return self.__class__(self.data[index])

    def __setitem__(self, index, item):
        self.data[index] = item.data

    def to(self, *args, **kwargs):
        return self.__class__(self.data.to(*args, **kwargs))

    def cpu(self):
        return self.__class__(self.data.cpu())

    def cuda(self):
        return self.__class__(self.data.cuda())

    def float(self, device=None):
        # ``device`` is unused; kept optional for backward compatibility
        return self.__class__(self.data.float())

    def double(self, device=None):
        # ``device`` is unused; kept optional for backward compatibility
        return self.__class__(self.data.double())

    def unbind(self, dim=0):
        return [self.__class__(x) for x in self.data.unbind(dim=dim)]
class SO3(LieGroup):
    """SO(3) elements stored as unit quaternions (4 values per element)."""

    group_name = 'SO3'
    group_id = 1
    manifold_dim = 3
    embedded_dim = 4

    # unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        """Wrap a quaternion tensor, take the rotation of an SE3, or copy an SO3."""
        if isinstance(data, SE3):
            # SE3 layout: [translation (3), quaternion (4)]
            data = data.data[..., 3:7]
        elif isinstance(data, SO3):
            data = data.data
        super(SO3, self).__init__(data)
class RxSO3(LieGroup):
    """Rotation-and-scale elements stored as [unit quaternion (4), scale (1)]."""

    group_name = 'RxSO3'
    group_id = 2
    manifold_dim = 4
    embedded_dim = 5

    # unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        """Wrap a tensor, take the quaternion+scale of a Sim3, or copy an RxSO3."""
        if isinstance(data, Sim3):
            # Sim3 layout: [translation (3), quaternion (4), scale (1)]
            data = data.data[..., 3:8]
        elif isinstance(data, RxSO3):
            data = data.data
        super(RxSO3, self).__init__(data)
class SE3(LieGroup):
    """Rigid-motion elements stored as [translation (3), unit quaternion (4)]."""

    group_name = 'SE3'
    group_id = 3
    manifold_dim = 6
    embedded_dim = 7

    # translation, unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        """Wrap a tensor, lift an SO3 (zero translation), or copy an SE3."""
        if isinstance(data, SO3):
            translation = torch.zeros_like(data.data[..., :3])
            data = torch.cat([translation, data.data], -1)
        elif isinstance(data, SE3):
            data = data.data
        super(SE3, self).__init__(data)

    def scale(self, s):
        """Return a new SE3 whose translation is multiplied by ``s``.

        ``s`` may be a tensor broadcastable against the batch shape or a
        Python number; the quaternion part is left unchanged.
        """
        t, q = self.data.split([3, 4], -1)
        if not isinstance(s, torch.Tensor):
            s = torch.as_tensor(s, dtype=t.dtype, device=t.device)
        t = t * s.unsqueeze(-1)
        return SE3(torch.cat([t, q], dim=-1))
class Sim3(LieGroup):
    """Similarity elements stored as [translation (3), unit quaternion (4), scale (1)]."""

    group_name = 'Sim3'
    group_id = 4
    manifold_dim = 7
    embedded_dim = 8

    # translation, unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        """Wrap a tensor or lift an SO3 / RxSO3 / SE3, or copy a Sim3.

        Missing components are filled with the identity: zero translation
        and unit scale.
        """
        if isinstance(data, SO3):
            # read the instance's tensor (the class itself has no ``data``)
            quat = data.data
            scale = torch.ones_like(quat[..., :1])
            translation = torch.zeros_like(quat[..., :3])
            data = torch.cat([translation, quat, scale], -1)
        elif isinstance(data, RxSO3):
            # RxSO3 already holds [quaternion, scale]
            translation = torch.zeros_like(data.data[..., :3])
            data = torch.cat([translation, data.data], -1)
        elif isinstance(data, SE3):
            scale = torch.ones_like(data.data[..., :1])
            data = torch.cat([data.data, scale], -1)
        elif isinstance(data, Sim3):
            data = data.data
        super(Sim3, self).__init__(data)
def cat(group_list, dim=0):
    """ Concatenate groups along dimension

    Args:
        group_list: sequence of elements of the same group class.
        dim: batch dimension to concatenate along. A negative value counts
            from the end of the *batch* shape, not of ``data`` (whose last
            axis is the group embedding).

    Returns:
        A new element of the class of ``group_list[0]``.
    """
    if dim < 0:
        # skip the trailing embedding axis carried by ``data``
        dim -= 1
    data = torch.cat([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)
def stack(group_list, dim=0):
    """ Stack groups along a new batch dimension

    Args:
        group_list: sequence of elements of the same group class and batch
            shape.
        dim: position of the new batch dimension. A negative value counts
            from the end of the resulting *batch* shape, not of ``data``
            (whose last axis is the group embedding).

    Returns:
        A new element of the class of ``group_list[0]``.
    """
    if dim < 0:
        # skip the trailing embedding axis carried by ``data``
        dim -= 1
    data = torch.stack([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)