Spaces:
Running
Running
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the BSD-style license found in the | |
# LICENSE file in the root directory of this source tree. | |
import unittest | |
import torch | |
from pytorch3d.implicitron.models.generic_model import GenericModel | |
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( | |
SRNHyperNetImplicitFunction, | |
SRNImplicitFunction, | |
SRNPixelGenerator, | |
) | |
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle | |
from pytorch3d.implicitron.tools.config import get_default_args | |
from pytorch3d.renderer import PerspectiveCameras | |
from tests.common_testing import TestCaseMixin | |
_BATCH_SIZE: int = 3 | |
_N_RAYS: int = 100 | |
_N_POINTS_ON_RAY: int = 10 | |
class TestSRN(TestCaseMixin, unittest.TestCase): | |
def setUp(self) -> None: | |
torch.manual_seed(42) | |
get_default_args(SRNHyperNetImplicitFunction) | |
get_default_args(SRNImplicitFunction) | |
def test_pixel_generator(self): | |
SRNPixelGenerator() | |
def _get_bundle(self, *, device) -> ImplicitronRayBundle: | |
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) | |
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) | |
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device) | |
bundle = ImplicitronRayBundle( | |
lengths=lengths, | |
origins=origins, | |
directions=directions, | |
xys=None, | |
camera_ids=None, | |
camera_counts=None, | |
) | |
return bundle | |
def test_srn_implicit_function(self): | |
implicit_function = SRNImplicitFunction() | |
device = torch.device("cpu") | |
bundle = self._get_bundle(device=device) | |
rays_densities, rays_colors = implicit_function(ray_bundle=bundle) | |
out_features = implicit_function.raymarch_function.out_features | |
self.assertEqual( | |
rays_densities.shape, | |
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features), | |
) | |
self.assertIsNone(rays_colors) | |
def test_srn_hypernet_implicit_function(self): | |
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core? | |
latent_dim_hypernet = 39 | |
device = torch.device("cuda:0") | |
implicit_function = SRNHyperNetImplicitFunction( | |
latent_dim_hypernet=latent_dim_hypernet | |
) | |
implicit_function.to(device) | |
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) | |
bundle = self._get_bundle(device=device) | |
rays_densities, rays_colors = implicit_function( | |
ray_bundle=bundle, global_code=global_code | |
) | |
out_features = implicit_function.hypernet.out_features | |
self.assertEqual( | |
rays_densities.shape, | |
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features), | |
) | |
self.assertIsNone(rays_colors) | |
def test_lstm(self): | |
args = get_default_args(GenericModel) | |
args.render_image_height = 80 | |
args.render_image_width = 80 | |
args.implicit_function_class_type = "SRNImplicitFunction" | |
args.renderer_class_type = "LSTMRenderer" | |
args.raysampler_class_type = "NearFarRaySampler" | |
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1 | |
args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1 | |
args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2] | |
gm = GenericModel(**args) | |
camera = PerspectiveCameras() | |
image = gm.forward( | |
camera=camera, | |
image_rgb=None, | |
fg_probability=None, | |
sequence_name="", | |
mask_crop=None, | |
depth_map=None, | |
)["images_render"] | |
self.assertEqual(image.shape, (1, 3, 80, 80)) | |
self.assertGreater(image.max(), 0.8) | |
# Force everything to be background | |
pixel_generator = gm._implicit_functions[0]._fn.pixel_generator | |
pixel_generator._density_layer.weight.zero_() | |
pixel_generator._density_layer.bias.fill_(-1.0e6) | |
image = gm.forward( | |
camera=camera, | |
image_rgb=None, | |
fg_probability=None, | |
sequence_name="", | |
mask_crop=None, | |
depth_map=None, | |
)["images_render"] | |
self.assertConstant(image[:, :2], 0.4) | |
self.assertConstant(image[:, 2], 0.2) | |