# Spaces:
# Runtime error
# Runtime error
import torch
import torch.nn.functional as F

import lietorch_backends
class GroupOp(torch.autograd.Function):
    """Base class for group operations dispatched to backend functions.

    Subclasses set two class attributes:

    * ``forward_op(group_id, *inputs)`` -- computes the forward output.
    * ``backward_op(group_id, grad, *inputs)`` -- returns one gradient per
      tensor in ``inputs``, or is ``None`` when no backward is available.
    """

    # Defaults so a subclass that omits either attribute fails with a clear
    # message rather than an AttributeError.
    forward_op = None
    backward_op = None

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """Apply ``cls.forward_op`` and save ``inputs`` for the backward pass.

        Args:
            ctx: autograd context supplied by ``Function.apply``.
            group_id: passed through to the backend unchanged; it receives no
                gradient (``backward`` returns ``None`` for it).
            *inputs: tensors forwarded to ``forward_op`` and saved on ``ctx``.

        Returns:
            Whatever ``cls.forward_op(group_id, *inputs)`` returns.
        """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return cls.forward_op(group_id, *inputs)

    @classmethod
    def backward(cls, ctx, grad):
        """Apply ``cls.backward_op`` to the incoming gradient.

        Returns:
            ``None`` for ``group_id`` followed by the gradients produced by
            ``backward_op`` for each saved input.

        Raises:
            AssertionError: if the subclass defines no ``backward_op``.
        """
        if cls.backward_op is None:
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is kept for existing callers.
            raise AssertionError(
                "Backward operation not implemented for {}".format(cls))

        inputs = ctx.saved_tensors
        # NOTE(review): presumably the backend kernels need contiguous
        # memory -- confirm against lietorch_backends.
        grad = grad.contiguous()
        grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
        return (None, ) + tuple(grad_inputs)
class Exp(GroupOp):
    """Exponential map, backed by ``lietorch_backends.expm``."""

    forward_op = lietorch_backends.expm
    backward_op = lietorch_backends.expm_backward
class Log(GroupOp):
    """Logarithm map, backed by ``lietorch_backends.logm``."""

    forward_op = lietorch_backends.logm
    backward_op = lietorch_backends.logm_backward
class Inv(GroupOp):
    """Group inverse, backed by ``lietorch_backends.inv``."""

    forward_op = lietorch_backends.inv
    backward_op = lietorch_backends.inv_backward
class Mul(GroupOp):
    """Group multiplication, backed by ``lietorch_backends.mul``."""

    forward_op = lietorch_backends.mul
    backward_op = lietorch_backends.mul_backward
class Adj(GroupOp):
    """Adjoint operator, backed by ``lietorch_backends.adj``."""

    forward_op = lietorch_backends.adj
    backward_op = lietorch_backends.adj_backward
class AdjT(GroupOp):
    """Adjoint-transpose operator, backed by ``lietorch_backends.adjT``.

    NOTE(review): "transpose" is inferred from the ``adjT`` backend name --
    confirm against the backend implementation.
    """

    forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward
class Act3(GroupOp):
    """Group action on points, backed by ``lietorch_backends.act``."""

    forward_op = lietorch_backends.act
    backward_op = lietorch_backends.act_backward
class Act4(GroupOp):
    """Group action on points, backed by ``lietorch_backends.act4``.

    NOTE(review): presumably the 4-component (homogeneous) point variant of
    ``Act3``, per the backend name -- confirm.
    """

    forward_op = lietorch_backends.act4
    backward_op = lietorch_backends.act4_backward
class Jinv(GroupOp):
    """Forward-only operator backed by ``lietorch_backends.Jinv``.

    ``backward_op`` is ``None``, so ``GroupOp.backward`` raises instead of
    differentiating through this op.

    NOTE(review): presumably an inverse Jacobian, per the backend name --
    confirm.
    """

    forward_op, backward_op = lietorch_backends.Jinv, None
class ToMatrix(GroupOp):
    """Matrix representation of a group element (forward only).

    Backed by ``lietorch_backends.as_matrix``; no backward is provided.
    """

    forward_op = lietorch_backends.as_matrix
    backward_op = None
### Conversion operations to/from Euclidean embeddings ###
class FromVec(torch.autograd.Function):
    """Convert a vector into a group object.

    The forward pass returns ``inputs[0]`` unchanged. The backward pass maps
    the incoming gradient through the pseudo-inverse of the matrix returned
    by ``lietorch_backends.projector``.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """Identity on ``inputs[0]``; saves ``group_id`` and ``inputs`` for backward."""
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return inputs[0]

    @classmethod
    def backward(cls, ctx, grad):
        """Return ``(None, grad @ pinv(J))``, where ``J`` is the backend projector.

        ``group_id`` receives no gradient.
        """
        inputs = ctx.saved_tensors
        J = lietorch_backends.projector(ctx.group_id, *inputs)
        # Treat grad's last dim as a row vector so matmul batches over the
        # leading dims, then drop the inserted axis again.
        return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
class ToVec(torch.autograd.Function):
    """Convert a group object into a vector.

    The forward pass returns ``inputs[0]`` unchanged. The backward pass
    multiplies the incoming gradient by the matrix returned from
    ``lietorch_backends.projector``.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """Identity on ``inputs[0]``; saves ``group_id`` and ``inputs`` for backward."""
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return inputs[0]

    @classmethod
    def backward(cls, ctx, grad):
        """Return ``(None, grad @ J)``, where ``J`` is the backend projector.

        ``group_id`` receives no gradient.
        """
        inputs = ctx.saved_tensors
        J = lietorch_backends.projector(ctx.group_id, *inputs)
        # Treat grad's last dim as a row vector so matmul batches over the
        # leading dims, then drop the inserted axis again.
        return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)