|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
import torch |
|
from pytorch3d.implicitron.models.generic_model import GenericModel |
|
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( |
|
SRNHyperNetImplicitFunction, |
|
SRNImplicitFunction, |
|
SRNPixelGenerator, |
|
) |
|
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle |
|
from pytorch3d.implicitron.tools.config import get_default_args |
|
from pytorch3d.renderer import PerspectiveCameras |
|
|
|
from tests.common_testing import TestCaseMixin |
|
|
|
# Leading dimension of every tensor in the random ray bundles built by
# TestSRN._get_bundle (and of the hypernet global_code).
_BATCH_SIZE: int = 3

# Number of rays per batch element in the random ray bundles.
_N_RAYS: int = 100

# Number of sampled lengths (points) along each ray in the random ray bundles.
_N_POINTS_ON_RAY: int = 10
|
|
|
|
|
class TestSRN(TestCaseMixin, unittest.TestCase):
    """Smoke and output-shape tests for the SRN implicit functions and the
    LSTM renderer path of GenericModel."""

    def setUp(self) -> None:
        torch.manual_seed(42)
        # Called only for their side effect; the returned defaults are unused.
        # NOTE(review): presumably this prepares the configurable classes so
        # they can be constructed directly in the tests below — confirm.
        get_default_args(SRNHyperNetImplicitFunction)
        get_default_args(SRNImplicitFunction)

    def test_pixel_generator(self):
        # Smoke test: construction with default arguments must not raise.
        SRNPixelGenerator()

    def _get_bundle(self, *, device: torch.device) -> ImplicitronRayBundle:
        """Return a random ray bundle allocated on ``device``.

        origins/directions have shape (_BATCH_SIZE, _N_RAYS, 3) and lengths
        has shape (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY); xys, camera_ids
        and camera_counts are left as None.
        """
        origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
        lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
        return ImplicitronRayBundle(
            lengths=lengths,
            origins=origins,
            directions=directions,
            xys=None,
            camera_ids=None,
            camera_counts=None,
        )

    @staticmethod
    def _render(gm: GenericModel, camera: PerspectiveCameras) -> torch.Tensor:
        """Run ``gm`` for ``camera`` with no image/mask/depth inputs and
        return the ``"images_render"`` entry of its output."""
        return gm.forward(
            camera=camera,
            image_rgb=None,
            fg_probability=None,
            sequence_name="",
            mask_crop=None,
            depth_map=None,
        )["images_render"]

    def test_srn_implicit_function(self):
        implicit_function = SRNImplicitFunction()
        device = torch.device("cpu")
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
        out_features = implicit_function.raymarch_function.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function(self):
        latent_dim_hypernet = 39
        # Prefer the GPU, but fall back to the CPU so this test does not
        # error out on machines without CUDA.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        implicit_function = SRNHyperNetImplicitFunction(
            latent_dim_hypernet=latent_dim_hypernet
        )
        implicit_function.to(device)
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
        bundle = self._get_bundle(device=device)
        rays_densities, rays_colors = implicit_function(
            ray_bundle=bundle, global_code=global_code
        )
        out_features = implicit_function.hypernet.out_features
        self.assertEqual(
            rays_densities.shape,
            (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
        )
        self.assertIsNone(rays_colors)

    @torch.no_grad()
    def test_lstm(self):
        args = get_default_args(GenericModel)
        args.render_image_height = 80
        args.render_image_width = 80
        args.implicit_function_class_type = "SRNImplicitFunction"
        args.renderer_class_type = "LSTMRenderer"
        args.raysampler_class_type = "NearFarRaySampler"
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_training = 1
        args.raysampler_NearFarRaySampler_args.n_pts_per_ray_evaluation = 1
        args.renderer_LSTMRenderer_args.bg_color = [0.4, 0.4, 0.2]
        gm = GenericModel(**args)

        camera = PerspectiveCameras()
        image = self._render(gm, camera)
        self.assertEqual(image.shape, (1, 3, 80, 80))
        # .item() so a failure reports a plain number rather than a 0-d tensor.
        self.assertGreater(image.max().item(), 0.8)

        # Zero the density weights and push the bias very negative. The
        # assertions below expect every pixel to then equal bg_color, i.e.
        # this presumably forces the whole image to background.
        pixel_generator = gm._implicit_functions[0]._fn.pixel_generator
        pixel_generator._density_layer.weight.zero_()
        pixel_generator._density_layer.bias.fill_(-1.0e6)

        image = self._render(gm, camera)
        # Channels 0-1 should equal bg_color[0:2]; channel 2 should equal bg_color[2].
        self.assertConstant(image[:, :2], 0.4)
        self.assertConstant(image[:, 2], 0.2)
|
|