# Linly-Talker / pytorch3d / tests / implicitron / test_ray_point_refiner.py
# Uploaded by linxianzhong0128 via huggingface_hub (commit 7088d16, verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from itertools import product
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
apply_blurpool_on_weights,
RayPointRefiner,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from tests.common_testing import TestCaseMixin
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
    """Tests for ``RayPointRefiner`` and ``apply_blurpool_on_weights``."""

    def setUp(self) -> None:
        super().setUp()
        # The random-sampling refiners below draw from torch's global RNG;
        # seed it so every run of these tests sees the same samples.
        torch.manual_seed(42)

    def test_simple(self):
        """
        Refine a ray bundle described by ``lengths`` using uniform weights,
        for every combination of ``add_input_samples`` and ``blurpool_weights``.

        Deterministic sampling must return ``n_pts_per_ray`` equispaced lengths
        between 0.5 and ``length - 1.5`` (merged and sorted with the input
        lengths when ``add_input_samples`` is set). Random sampling must return
        a tensor of the same shape whose values are strictly increasing along
        each ray and, without input samples, lie strictly inside that range.
        """
        length = 15
        n_pts_per_ray = 10
        for add_input_samples, use_blurpool in product([False, True], [False, True]):
            ray_point_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
            bundle = ImplicitronRayBundle(
                lengths=lengths,
                origins=None,
                directions=None,
                xys=None,
                camera_ids=None,
                camera_counts=None,
            )
            # Uniform weights: the expected values below do not depend on
            # use_blurpool (blurring a constant signal leaves it constant).
            weights = torch.ones(3, 25, length)
            refined = ray_point_refiner(bundle, weights)

            # Fields that were None on the input stay None on the output.
            self.assertIsNone(refined.directions)
            self.assertIsNone(refined.origins)
            self.assertIsNone(refined.xys)

            expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
            expected = expected.expand(3, 25, n_pts_per_ray)
            if add_input_samples:
                full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
            else:
                full_expected = expected
            self.assertClose(refined.lengths, full_expected)

            ray_point_refiner_random = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            refined_random = ray_point_refiner_random(bundle, weights)
            lengths_random = refined_random.lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                self.assertGreater(lengths_random.min().item(), 0.5)
                self.assertLess(lengths_random.max().item(), length - 1.5)
            # Refined lengths must be strictly increasing along each ray.
            self.assertTrue(
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
            )

    def test_simple_use_bins(self):
        """
        Same spirit as ``test_simple``, but the input ``ImplicitronRayBundle``
        is described by ``bins`` instead of ``lengths``.

        It is deliberately duplicated rather than merged into ``test_simple``
        so that neither test is cluttered with bins-vs-lengths branches.
        """
        length = 15
        n_pts_per_ray = 10
        for add_input_samples, use_blurpool in product([False, True], [False, True]):
            # blurpool_weights is configured on the refiner itself, exactly as
            # in test_simple, so the blurpool branch is actually exercised.
            ray_point_refiner = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=False,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            bundle = ImplicitronRayBundle(
                lengths=None,
                bins=torch.arange(length + 1, dtype=torch.float32).expand(
                    3, 25, length + 1
                ),
                origins=None,
                directions=None,
                xys=None,
                camera_ids=None,
                camera_counts=None,
            )
            # One weight per bin interval: `length + 1` edges -> `length` weights.
            weights = torch.ones(3, 25, length)
            refined = ray_point_refiner(bundle, weights)

            self.assertIsNone(refined.directions)
            self.assertIsNone(refined.origins)
            self.assertIsNone(refined.xys)

            expected_bins = torch.linspace(0, length, n_pts_per_ray + 1)
            expected_bins = expected_bins.expand(3, 25, n_pts_per_ray + 1)
            if add_input_samples:
                expected_bins = torch.cat((bundle.bins, expected_bins), dim=-1)
                expected_bins = expected_bins.sort()[0]
            # The expected lengths are the midpoints of consecutive bin edges.
            full_expected = torch.lerp(
                expected_bins[..., :-1], expected_bins[..., 1:], 0.5
            )
            self.assertClose(refined.lengths, full_expected)

            ray_point_refiner_random = RayPointRefiner(
                n_pts_per_ray=n_pts_per_ray,
                random_sampling=True,
                add_input_samples=add_input_samples,
                blurpool_weights=use_blurpool,
            )
            refined_random = ray_point_refiner_random(bundle, weights)
            lengths_random = refined_random.lengths
            self.assertEqual(lengths_random.shape, full_expected.shape)
            if not add_input_samples:
                self.assertGreater(lengths_random.min().item(), 0)
                self.assertLess(lengths_random.max().item(), length)
            # Refined lengths must be strictly increasing along each ray.
            self.assertTrue(
                (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
            )

    def test_apply_blurpool_on_weights(self):
        """Check ``apply_blurpool_on_weights`` against hand-computed values."""
        weights = torch.tensor(
            [
                [0.5, 0.6, 0.7],
                [0.5, 0.3, 0.9],
            ]
        )
        # Each expected entry is the mean of two terms. The numbers are
        # consistent with an edge-replicated 2-wide max followed by a 2-wide
        # average (NOTE(review): inferred from these values, not the source).
        expected_weights = 0.5 * torch.tensor(
            [
                [0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
                [0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
            ]
        )
        out_weights = apply_blurpool_on_weights(weights)
        self.assertClose(out_weights, expected_weights)

    def test_shapes_apply_blurpool_on_weights(self):
        """``apply_blurpool_on_weights`` must preserve arbitrary input shapes."""
        weights = torch.randn((5, 4, 3, 2, 1))
        out_weights = apply_blurpool_on_weights(weights)
        self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)