# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np
import torch
from pytorch3d.ops import utils as oputil

from .common_testing import TestCaseMixin


class TestOpsUtils(TestCaseMixin, unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)

    def test_wmean(self):
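        # Fall back to CPU so the test can also run on machines without CUDA.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")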
        n_points = 20

        x = torch.rand(n_points, 3, device=device)
        weight = torch.rand(n_points, device=device)
        x_np = x.cpu().data.numpy()
        weight_np = weight.cpu().data.numpy()

        # test unweighted
        mean = oputil.wmean(x, keepdim=False)
        mean_gt = np.average(x_np, axis=-2)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

        # test weighted
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

        # test keepdim
        mean = oputil.wmean(x, weight=weight, keepdim=True)
        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)

        # test binary weights
        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

        # test broadcasting
        x = torch.rand(10, n_points, 3, device=device)
        x_np = x.cpu().data.numpy()
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

        weight = weight[None, None, :].repeat(3, 1, 1)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)

        # test failing broadcasting
        weight = torch.rand(x.shape[0], device=device)
        with self.assertRaises(ValueError) as context:
            oputil.wmean(x, weight=weight, keepdim=False)
        self.assertIn("weights are not compatible", str(context.exception))

        # test dim
        weight = torch.rand(x.shape[0], n_points, device=device)
        weight_np = np.tile(
            weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1])
        )
        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=0, weights=weight_np)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

        # test dim tuple
        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
        self.assertClose(mean.cpu().data.numpy(), mean_gt)

    def test_masked_gather_errors(self):
        idx = torch.randint(0, 10, size=(5, 10, 4, 2))
        points = torch.randn(size=(5, 10, 3))
        with self.assertRaisesRegex(ValueError, "format is not supported"):
            oputil.masked_gather(points, idx)

        points = torch.randn(size=(2, 10, 3))
        with self.assertRaisesRegex(ValueError, "same batch dimension"):
            oputil.masked_gather(points, idx)