Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest | |
from itertools import product | |
import torch | |
from pytorch3d.implicitron.models.renderer.ray_point_refiner import ( | |
apply_blurpool_on_weights, | |
RayPointRefiner, | |
) | |
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle | |
from tests.common_testing import TestCaseMixin | |
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
    def test_simple(self):
        """
        Refine rays described by ``lengths`` with uniform weights.

        Every combination of ``add_input_samples`` and ``blurpool_weights``
        is checked for both the deterministic and the random refiner.
        """
        length = 15
        n_pts_per_ray = 10

        for add_input_samples, use_blurpool in product([False, True], [False, True]):
            ray_point_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
            bundle = ImplicitronRayBundle(
                lengths=lengths,
                origins=None,
                directions=None,
                xys=None,
                camera_ids=None,
                camera_counts=None,
            )
            # Uniform weights: the same expected samples are asserted whether
            # or not blurpool is enabled.
            weights = torch.ones(3, 25, length)
            refined = ray_point_refiner(bundle, weights)
            self.assertIsNone(refined.directions)
            self.assertIsNone(refined.origins)
            self.assertIsNone(refined.xys)

            # Deterministic refinement is expected to produce evenly spaced
            # samples in [0.5, length - 1.5].
            expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
            expected = expected.expand(3, 25, n_pts_per_ray)
            if add_input_samples:
                # The input lengths are merged with the new samples, in order.
                full_expected = torch.cat((lengths, expected), dim=-1).sort().values
            else:
                full_expected = expected
            self.assertClose(refined.lengths, full_expected)

            ray_point_refiner_random = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            refined_random = ray_point_refiner_random(bundle, weights)
            lengths_random = refined_random.lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                # Random samples must stay within the deterministic range.
                self.assertGreater(lengths_random.min().item(), 0.5)
                self.assertLess(lengths_random.max().item(), length - 1.5)

            # Check sorted: samples must be strictly increasing along each ray.
            self.assertTrue(
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
            )

    def test_simple_use_bins(self):
        """
        Same spirit as test_simple, but the rays are described by ``bins`` in
        the ImplicitronRayBundle instead of ``lengths``. It is kept as a
        separate test to avoid the cognitive overload of piling more if/else
        branches into test_simple.
        """
        length = 15
        n_pts_per_ray = 10

        for add_input_samples, use_blurpool in product([False, True], [False, True]):
            # Blurpool is configured on the module, as in test_simple, so the
            # blurpool branch is exercised for bin-based bundles too.
            ray_point_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            bundle = ImplicitronRayBundle(
                lengths=None,
                # length + 1 bin edges delimit `length` intervals, one per weight.
                bins=torch.arange(length + 1, dtype=torch.float32).expand(
                    3, 25, length + 1
                ),
                origins=None,
                directions=None,
                xys=None,
                camera_ids=None,
                camera_counts=None,
            )
            weights = torch.ones(3, 25, length)
            refined = ray_point_refiner(bundle, weights)
            self.assertIsNone(refined.directions)
            self.assertIsNone(refined.origins)
            self.assertIsNone(refined.xys)

            # Expected output bin edges are evenly spaced over [0, length];
            # the refined lengths are the midpoints of consecutive edges.
            expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
            expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
            if add_input_samples:
                expected_bins = (
                    torch.cat((bundle.bins, expected_bins), dim=-1).sort().values
                )
            full_expected = torch.lerp(
                expected_bins[..., :-1], expected_bins[..., 1:], 0.5
            )
            self.assertClose(refined.lengths, full_expected)

            ray_point_refiner_random = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            refined_random = ray_point_refiner_random(bundle, weights)
            lengths_random = refined_random.lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                # Random samples must lie strictly inside the outer bin edges.
                self.assertGreater(lengths_random.min().item(), 0)
                self.assertLess(lengths_random.max().item(), length)

            # Check sorted: samples must be strictly increasing along each ray.
            self.assertTrue(
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
            )

    def test_apply_blurpool_on_weights(self):
        """
        Check apply_blurpool_on_weights against hand-computed values.

        Each expected value equals
        0.5 * (max(w[i-1], w[i]) + max(w[i], w[i+1])), with the first and
        last weights replicated past the edges.
        """
        weights = torch.tensor(
            [
                [0.5, 0.6, 0.7],
                [0.5, 0.3, 0.9],
            ]
        )
        expected_weights = 0.5 * torch.tensor(
            [
                [0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
                [0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
            ]
        )
        out_weights = apply_blurpool_on_weights(weights)
        # assertClose (same default tolerances as torch.allclose) reports the
        # mismatch on failure, unlike assertTrue(torch.allclose(...)).
        self.assertClose(out_weights, expected_weights)

    def test_shapes_apply_blurpool_on_weights(self):
        """apply_blurpool_on_weights preserves the shape of a rank-5 input."""
        weights = torch.randn((5, 4, 3, 2, 1))
        out_weights = apply_blurpool_on_weights(weights)
        self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)