# Source: SparseAGS / liegroups/tests/torch/test_utils_torch.py
# Committed by qitaoz in 26ce2a9 ("init commit", verified).
import torch
from liegroups.torch import utils
def test_isclose():
    """Check that utils.isclose flags entries lying within ``tol`` of a scalar.

    The expected mask treats an entry exactly ``tol`` away from the target
    as *not* close, i.e. the test expects a strict comparison.
    """
    tol = 1e-6
    mat = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])
    # Expected element-wise mask. A bool tensor replaces the legacy
    # ``torch.ByteTensor``; uint8 masks are superseded by bool in PyTorch.
    ans = torch.tensor([True, False, False, False, True])
    result = utils.isclose(mat, 0., tol=tol)
    # ``==`` broadcasts, so pin the shape explicitly; otherwise a result of
    # the wrong shape could still compare equal element-wise.
    assert result.shape == ans.shape
    assert (result == ans).all()
def test_allclose():
    """Check that utils.allclose is True only when every entry is close to zero."""
    tol = 1e-6
    zero = 0.
    # Every entry here is strictly smaller than tol in magnitude.
    near_zero = torch.Tensor([0.1 * tol, 0.01 * tol, 0, 0, 0])
    # At least one entry (e.g. 1) is nowhere near zero.
    far_from_zero = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])

    assert utils.allclose(near_zero, zero, tol=tol)
    assert not utils.allclose(far_from_zero, zero, tol=tol)
def test_outer():
    """Check utils.outer against an explicit column-times-row product.

    Covers a single pair of 1-D vectors (compared with ``torch.mm``) and a
    batch of vector pairs with a leading batch dimension (compared with
    ``torch.bmm``).
    """
    vec1 = torch.Tensor([1, 2, 3])
    vec2 = torch.Tensor([0, 1, 2])
    expected = torch.mm(vec1.unsqueeze(dim=1), vec2.unsqueeze(dim=0))
    result = utils.outer(vec1, vec2)
    # Element-wise ``==`` broadcasts, so check the shape first; a result of
    # the wrong shape could otherwise compare equal.
    assert result.shape == expected.shape
    assert (result == expected).all()

    vecs1 = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    vecs2 = torch.Tensor([[0, 1, 2], [3, 4, 5]])
    expected_batch = torch.bmm(vecs1.unsqueeze(dim=2), vecs2.unsqueeze(dim=1))
    result_batch = utils.outer(vecs1, vecs2)
    # A (3, 3) result would broadcast against (2, 3, 3), hence the shape check.
    assert result_batch.shape == expected_batch.shape
    assert (result_batch == expected_batch).all()
def test_trace():
    """Check that utils.trace agrees with torch.trace for one matrix and for a batch."""
    single = torch.arange(1, 10).view(3, 3)
    # The result is indexed with [0], i.e. the test expects it to carry a
    # leading batch dimension even for a lone matrix.
    assert utils.trace(single)[0] == torch.trace(single)

    batch = torch.cat(
        [
            torch.arange(1, 10).view(1, 3, 3),
            torch.arange(11, 20).view(1, 3, 3),
        ],
        dim=0,
    )
    traces = utils.trace(batch)
    assert len(traces) == 2
    # Compare each batched trace with the reference trace of its matrix.
    for got, matrix in zip(traces, batch):
        assert got == torch.trace(matrix)