# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import os
import unittest
from typing import Tuple

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.renderer.cameras import CamerasBase
from tests.common_testing import interactive_testing_requested
from visdom import Visdom

from .common_resources import get_skateboard_data


class TestModelVisualize(unittest.TestCase):
    """Interactive test that runs `render_flyaround` on one dataset sequence."""

    def test_flyaround_one_sequence(
        self,
        image_size: int = 256,
    ):
        """
        Run `render_flyaround` for the first sequence of the skateboard data,
        using a point-cloud rendering model in place of a trained model.

        The body is executed only when interactive testing is requested;
        otherwise the test returns immediately.

        Args:
            image_size: Passed to the dataset as both `image_height` and
                `image_width`.
        """
        if not interactive_testing_requested():
            return

        category = "skateboard"

        # Keep the data context open for the whole test; it is closed on cleanup.
        resources = contextlib.ExitStack()
        dataset_root, path_manager = resources.enter_context(get_skateboard_data())
        self.addCleanup(resources.close)

        def category_path(filename):
            # All annotation files live under <dataset_root>/<category>/.
            return os.path.join(dataset_root, category, filename)

        expand_args_fields(JsonIndexDataset)
        train_dataset = JsonIndexDataset(
            frame_annotations_file=category_path("frame_annotations.jgz"),
            sequence_annotations_file=category_path("sequence_annotations.jgz"),
            subset_lists_file=category_path("set_lists.json"),
            dataset_root=dataset_root,
            image_height=image_size,
            image_width=image_size,
            box_crop=True,
            load_point_clouds=True,
            path_manager=path_manager,
            subsets=["train_known"],
        )

        # Visualize only the first sequence listed in the dataset annotations.
        show_sequence_name = list(train_dataset.seq_annots.keys())[0]

        # Outputs are written next to this test file.
        output_dir = os.path.dirname(os.path.abspath(__file__))

        visdom_show_preds = Visdom().check_connection()

        # Settings shared by every `render_flyaround` call below.
        flyaround_options = {
            "n_flyaround_poses": 10,
            "fps": 5,
            "max_angle": 2 * math.pi,
            "trajectory_type": "circular_lsq_fit",
            "trajectory_scale": 1.1,
            "scene_center": (0.0, 0.0, 0.0),
            "up": (0.0, 1.0, 0.0),
            "traj_offset": 1.0,
            "n_source_views": 1,
            "visdom_show_preds": visdom_show_preds,
            "visdom_environment": "test_model_visalize",
            "visdom_server": "http://127.0.0.1",
            "visdom_port": 8097,
            "num_workers": 10,
            "seed": None,
            "video_resize": None,
        }
        preds_keys = (
            "images_render",
            "depths_render",
            "masks_render",
            "_all_source_images",
        )

        # Exercise both point-cloud sources: loaded from the dataset, and not.
        for load_dataset_pointcloud in (True, False):
            model = _PointcloudRenderingModel(
                train_dataset,
                show_sequence_name,
                device="cuda:0",
                load_dataset_pointcloud=load_dataset_pointcloud,
            )
            video_path = os.path.join(
                output_dir,
                f"load_pcl_{load_dataset_pointcloud}",
            )
            os.makedirs(output_dir, exist_ok=True)

            # Run once without a frames directory and once with `video_path`.
            for output_video_frames_dir in (None, video_path):
                render_flyaround(
                    train_dataset,
                    show_sequence_name,
                    model,
                    video_path,
                    # A fresh list is handed to every call.
                    visualize_preds_keys=list(preds_keys),
                    output_video_frames_dir=output_video_frames_dir,
                    **flyaround_options,
                )


class _PointcloudRenderingModel(torch.nn.Module):
    """
    Module whose forward pass renders a fixed point cloud of one sequence.

    The point cloud is obtained once, at construction time, via
    `get_implicitron_sequence_pointcloud` and moved to `device`.
    """

    def __init__(
        self,
        train_dataset: JsonIndexDataset,
        sequence_name: str,
        render_size: Tuple[int, int] = (400, 400),
        device=None,
        load_dataset_pointcloud: bool = False,
        max_frames: int = 30,
        num_workers: int = 10,
    ):
        """
        Args:
            train_dataset: Dataset the sequence point cloud is obtained from.
            sequence_name: Name of the sequence to get the point cloud for.
            render_size: Forwarded as `render_size` to the point cloud renderer.
            device: Device the point cloud is moved to.
            load_dataset_pointcloud: Forwarded as `load_dataset_point_cloud`
                to `get_implicitron_sequence_pointcloud`.
            max_frames: Forwarded to `get_implicitron_sequence_pointcloud`.
            num_workers: Forwarded to `get_implicitron_sequence_pointcloud`.
        """
        super().__init__()
        self._render_size = render_size
        # Only the first returned value (the point cloud) is used.
        sequence_pointcloud = get_implicitron_sequence_pointcloud(
            train_dataset,
            sequence_name=sequence_name,
            mask_points=True,
            max_frames=max_frames,
            num_workers=num_workers,
            load_dataset_point_cloud=load_dataset_pointcloud,
        )[0]
        self._point_cloud = sequence_pointcloud.to(device)

    def forward(
        self,
        camera: CamerasBase,
        **kwargs,
    ):
        """
        Render the stored point cloud with `camera[0]`.

        Args:
            camera: Cameras object; only its element at index 0 is used.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with keys "images_render" (the rendered image clamped to
            [0, 1]), "masks_render" and "depths_render".
        """
        rendered = render_point_cloud_pytorch3d(
            camera[0],
            self._point_cloud,
            render_size=self._render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        image_render, mask_render, depth_render = rendered
        return {
            "images_render": image_render.clamp(0.0, 1.0),
            "masks_render": mask_render,
            "depths_render": depth_render,
        }