"""Tests for the isclose, allclose, outer and trace helpers in liegroups.torch.utils."""
import torch

from liegroups.torch import utils


def test_isclose():
    """Check the element-wise mask returned by utils.isclose against zero.

    With tol = 1e-6, the entries 0 and 0.1 * tol are expected to be flagged,
    while 1, tol and 10 * tol are expected not to be.
    """
    tol = 1e-6
    values = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])
    expected_mask = torch.ByteTensor([1, 0, 0, 0, 1])

    matches = utils.isclose(values, 0., tol=tol) == expected_mask
    assert matches.all()


def test_allclose():
    """Check utils.allclose against zero on an accepted and a rejected tensor.

    The accepted tensor only holds entries of magnitude 0.1 * tol or less;
    the rejected one also holds 1, tol and 10 * tol.
    """
    tol = 1e-6
    within_tol = torch.Tensor([0.1 * tol, 0.01 * tol, 0, 0, 0])
    outside_tol = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])

    assert utils.allclose(within_tol, 0., tol=tol)
    assert not utils.allclose(outside_tol, 0., tol=tol)


def test_outer():
    """Compare utils.outer with torch.mm / torch.bmm reference products."""
    # Single pair of vectors: reference is (3 x 1) @ (1 x 3).
    left = torch.Tensor([1, 2, 3])
    right = torch.Tensor([0, 1, 2])
    single_result = utils.outer(left, right)
    single_reference = torch.mm(left.unsqueeze(dim=1),
                                right.unsqueeze(dim=0))
    assert (single_result == single_reference).all()

    # Two pairs stacked along dim 0: reference is a batched matrix product.
    left_batch = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    right_batch = torch.Tensor([[0, 1, 2], [3, 4, 5]])
    batch_result = utils.outer(left_batch, right_batch)
    batch_reference = torch.bmm(left_batch.unsqueeze(dim=2),
                                right_batch.unsqueeze(dim=1))
    assert (batch_result == batch_reference).all()


def test_trace():
    """Compare utils.trace with torch.trace for one matrix and for two."""
    # Single 3 x 3 matrix; the result is indexed at 0 before comparing.
    single = torch.arange(1, 10).view(3, 3)
    assert utils.trace(single)[0] == torch.trace(single)

    # Two 3 x 3 matrices concatenated along a leading dimension.
    first = torch.arange(1, 10).view(1, 3, 3)
    second = torch.arange(11, 20).view(1, 3, 3)
    stacked = torch.cat([first, second], dim=0)

    batch_traces = utils.trace(stacked)
    assert len(batch_traces) == 2
    for index, matrix in enumerate(stacked):
        assert batch_traces[index] == torch.trace(matrix)