#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This example demonstrates scene optimization with the PyTorch3D
pulsar interface. For this, a reference image has been pre-generated
(you can find it at
`../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png`,
relative to this file). The scene is initialized with random spheres.
Gradient-based optimization is used to converge towards a faithful
scene representation.
"""
import logging
import math
from pathlib import Path

import cv2
import imageio
import numpy as np
import torch
from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.renderer.points import (
    PointsRasterizationSettings,
    PointsRasterizer,
    PulsarPointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from torch import nn, optim


LOGGER = logging.getLogger(__name__)
N_POINTS = 10_000
WIDTH = 1_000
HEIGHT = 1_000
DEVICE = torch.device("cuda")
# Resolved relative to this script (not the current working directory) so the
# example can be launched from anywhere.
REFERENCE_PATH = (
    Path(__file__).resolve().parent
    / "../../tests/pulsar/reference/examples_TestRenderer_test_smallopt.png"
)


class SceneModel(nn.Module):
    """
    A simple scene model to demonstrate use of pulsar in PyTorch modules.

    The scene model is parameterized with sphere locations (vert_pos),
    channel content (vert_col) and radiuses (vert_rad). The camera is fixed:
    a PyTorch3D ``PerspectiveCameras`` with identity rotation, zero
    translation and focal length 5.0, created once in ``__init__``.

    The forward method of the model renders this scene description. Any
    of these parameters could instead be passed as inputs to the forward
    method and come from a different model.
    """

    def __init__(self):
        super().__init__()
        self.gamma = 1.0
        # Points.
        torch.manual_seed(1)
        vert_pos = torch.rand(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 10.0
        vert_pos[:, 2] += 25.0
        vert_pos[:, :2] -= 5.0
        self.register_parameter("vert_pos", nn.Parameter(vert_pos, requires_grad=True))
        self.register_parameter(
            "vert_col",
            nn.Parameter(
                torch.ones(N_POINTS, 3, dtype=torch.float32, device=DEVICE) * 0.5,
                requires_grad=True,
            ),
        )
        # Created directly on DEVICE, like vert_pos/vert_col. The rasterizer
        # settings below hold a reference to this exact Parameter object, so
        # it must not rely on ``Module.to()`` preserving parameter identity
        # (which is not guaranteed, see
        # ``torch.__future__.set_overwrite_module_params_on_conversion``).
        self.register_parameter(
            "vert_rad",
            nn.Parameter(
                torch.ones(N_POINTS, dtype=torch.float32, device=DEVICE) * 0.3,
                requires_grad=True,
            ),
        )
        # Pulsar-style camera parameter vector. It is not read anywhere in
        # this file (the camera below is configured independently); kept so
        # the module's buffers / state_dict keys stay unchanged.
        self.register_buffer(
            "cam_params",
            torch.tensor(
                [0.0, 0.0, 0.0, 0.0, math.pi, 0.0, 5.0, 2.0], dtype=torch.float32
            ),
        )
        self.cameras = PerspectiveCameras(
            # The focal length must be double the size for PyTorch3D because of the NDC
            # coordinates spanning a range of two - and they must be normalized by the
            # sensor width (see the pulsar example). This means we need here
            # 5.0 * 2.0 / 2.0 to get the equivalent results as in pulsar.
            focal_length=5.0,
            R=torch.eye(3, dtype=torch.float32, device=DEVICE)[None, ...],
            T=torch.zeros((1, 3), dtype=torch.float32, device=DEVICE),
            image_size=((HEIGHT, WIDTH),),
            device=DEVICE,
        )
        raster_settings = PointsRasterizationSettings(
            image_size=(HEIGHT, WIDTH),
            radius=self.vert_rad,
        )
        rasterizer = PointsRasterizer(
            cameras=self.cameras, raster_settings=raster_settings
        )
        self.renderer = PulsarPointsRenderer(rasterizer=rasterizer, n_track=32)

    def forward(self):
        """Render the current scene and return the first (only) image."""
        # The Pointclouds object creates copies of its arguments - that's why
        # we have to create a new object in every forward step.
        pcl = Pointclouds(
            points=self.vert_pos[None, ...], features=self.vert_col[None, ...]
        )
        return self.renderer(
            pcl,
            gamma=(self.gamma,),
            zfar=(45.0,),
            znear=(1.0,),
            radius_world=True,
            bg_col=torch.ones((3,), dtype=torch.float32, device=DEVICE),
        )[0]


def _load_reference():
    """Load the reference image as a float32 tensor scaled to [0, 1] on DEVICE.

    The image is mirrored along its column axis before being returned; the
    result is compared pixel-wise against the render in ``cli``.
    """
    image = imageio.imread(str(REFERENCE_PATH))
    # Mirror left-right; copy() turns the negative-stride view into a
    # regular array that torch.from_numpy accepts.
    image = image[:, ::-1, :].copy()
    return (torch.from_numpy(image).to(torch.float32) / 255.0).to(DEVICE)


def _show_progress(result, ref, step):
    """Show the current render and a 50/50 blend with the reference via OpenCV."""
    # Visualization only: keep it off the autograd graph.
    result = result.detach()
    result_im = (result.cpu().numpy() * 255).astype(np.uint8)
    # Reverse the channel axis (RGB -> BGR) for OpenCV.
    cv2.imshow("opt", result_im[:, :, ::-1])
    blend = ((result * 0.5 + ref * 0.5).cpu().numpy() * 255).astype(np.uint8)
    # OpenCV drawing functions need a contiguous buffer, not a reversed view.
    overlay_img = np.ascontiguousarray(blend[:, :, ::-1])
    overlay_img = cv2.putText(
        overlay_img,
        f"Step {step}",
        (10, 40),
        cv2.FONT_HERSHEY_SIMPLEX,
        1,
        (0, 0, 0),
        2,
        cv2.LINE_AA,
        False,
    )
    cv2.imshow("overlay", overlay_img)
    cv2.waitKey(1)


def _prune_points(model):
    """Clamp colors to [0, 1] and move degenerate spheres to (-1000, -1000, -1000).

    A sphere is treated as degenerate when its radius dropped below 0.001
    (its radius is then reset to 0.0001) or when its color is within an L1
    distance of 0.2 of white, the background color used in ``forward``.
    The target position is presumably outside the rendered depth range
    (znear=1, zfar=45 in ``forward``).
    """
    with torch.no_grad():
        model.vert_col.data = torch.clamp(model.vert_col.data, 0.0, 1.0)
        # Remove points whose radius has (nearly) vanished, keeping the
        # radius strictly positive.
        tiny = model.vert_rad < 0.001
        model.vert_pos.data[tiny, :] = -1000.0
        model.vert_rad.data[tiny] = 0.0001
        # Remove points indistinguishable from the white background.
        white = torch.ones(3, dtype=torch.float32, device=DEVICE)
        vd = (model.vert_col - white).abs().sum(dim=1)
        model.vert_pos.data[vd <= 0.2] = -1000.0


def cli():
    """
    Scene optimization example using pulsar and the unified PyTorch3D interface.
    """
    LOGGER.info("Loading reference...")
    # Load reference.
    ref = _load_reference()
    # Set up model.
    model = SceneModel().to(DEVICE)
    # Optimizer.
    optimizer = optim.SGD(
        [
            {"params": [model.vert_col], "lr": 1e0},
            {"params": [model.vert_rad], "lr": 5e-3},
            {"params": [model.vert_pos], "lr": 1e-2},
        ]
    )
    LOGGER.info("Optimizing...")
    # Optimize.
    for i in range(500):
        optimizer.zero_grad()
        result = model()
        # Visualize.
        _show_progress(result, ref, i)
        # Update.
        loss = ((result - ref) ** 2).sum()
        LOGGER.info("loss %d: %f", i, loss.item())
        loss.backward()
        optimizer.step()
        # Cleanup.
        _prune_points(model)
    LOGGER.info("Done.")


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    cli()