File size: 4,021 Bytes
cacb27a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import glob
from omegaconf import DictConfig
from typing import Optional
import torch
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
JsonIndexDatasetMapProviderV2
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
NDCMultinomialRaysampler,
ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
# Categories reserved for held-out evaluation: when holdout filtering is
# enabled (see get_all_dataset_maps), these appear only in the val maps
# and are excluded from the train maps.
HOLDOUT_CATEGORIES = {
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
}
def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Build the train/val/test dataset map for one CO3Dv2 category.

    Args:
        dataset_root: Root directory of the CO3Dv2 dataset.
        category: Name of the object category to load.
        subset_name: Name of the dataset subset to use.

    Returns:
        A DatasetMap containing the train/val/test dataset objects.
    """
    # Populate the configurable fields of the provider class before
    # instantiating it (required by the implicitron config system).
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    inner_dataset_args = DictConfig(
        {"remove_empty_masks": False, "load_point_clouds": False}
    )
    provider = JsonIndexDatasetMapProviderV2(
        category=category,
        subset_name=subset_name,
        dataset_root=dataset_root,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=inner_dataset_args,
    )
    return provider.get_dataset_map()
def _load_pointcloud(pcl_path, max_points):
    """
    Load a point cloud from disk, optionally subsampling it.

    Args:
        pcl_path: Path to the point-cloud file.
        max_points: If positive, subsample the cloud to at most this many
            points; non-positive values disable subsampling.

    Returns:
        The loaded (and possibly subsampled) point cloud.
    """
    pointcloud = IO().load_pointcloud(pcl_path)
    if max_points <= 0:
        return pointcloud
    return pointcloud.subsample(max_points)
def get_all_dataset_maps(co3d_path, holdout_categories):
    """
    Load cached per-category train/val dataset maps for CO3Dv2.

    Args:
        co3d_path: Root directory of the CO3Dv2 dataset. Every entry
            directly under it that does not end in ``.json`` is treated as
            a category name.
        holdout_categories: If truthy, categories in HOLDOUT_CATEGORIES go
            only into the val maps and all remaining categories only into
            the train maps; if falsy, every category appears in both.

    Returns:
        Tuple ``(train_dataset_maps, val_dataset_maps)``, each mapping a
        category name to the object loaded from
        ``dataset_cache/<category>_{train,val}.pt``.
    """
    # glob.glob already returns a list, so no extra list() wrapper needed.
    all_categories = [
        c.split('/')[-1]
        for c in glob.glob(co3d_path + '/*')
        if not c.endswith('.json')
    ]
    # NOTE(review): sorting by hash() of a string is not stable across
    # interpreter runs (PYTHONHASHSEED randomization), so this ordering is
    # only a per-process pseudo-shuffle — confirm whether a reproducible
    # order is required here.
    all_categories = sorted(all_categories, key=lambda x: hash(x))
    train_dataset_maps = {}
    val_dataset_maps = {}
    for category in all_categories:
        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt'),
        }
        # Holdout categories feed only the val maps; the rest only train.
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']
    # Fixed typo in the output messages: "categores" -> "categories".
    print('Loaded', len(train_dataset_maps), 'categories for train')
    print('Loaded', len(val_dataset_maps), 'categories for val')
    return train_dataset_maps, val_dataset_maps
def get_rgbd_points(
    imh, imw,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> torch.Tensor:
    """
    Unproject a depth map into per-pixel 3D points using the given camera.

    Args:
        imh: Image height the depth map (and mask) are resized to.
        imw: Image width the depth map (and mask) are resized to.
        camera: Camera(s) used by the ray sampler to cast unprojection rays.
        depth_map: Depth tensor; assumed to be (batch, 1, H, W) so that
            ``depth_map[:, 0]`` yields per-pixel depths — TODO confirm.
        mask: Optional mask tensor; pixels whose interpolated mask value is
            <= mask_thr, or whose depth is <= 0, have their 3D point set to
            inf.
        mask_thr: Threshold applied to the bilinearly interpolated mask.

    Returns:
        Tensor of per-pixel 3D points with last dimension 3; invalid points
        are set to inf. (The original ``-> Pointclouds`` annotation was
        wrong: the function returns a raw tensor.)
    """
    # Resize the depth to the requested output resolution.
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    # convert the depth maps to point clouds using the grid ray sampler
    # One sample per ray at min_depth == max_depth == 1.0 yields ray points
    # at unit length; overwriting `lengths` with the per-pixel depth values
    # rescales each ray to its depth, i.e. unprojects the depth map.
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]  # drop the n_pts_per_ray dim, add a leading dim
    # A point is valid only where depth is strictly positive.
    pts_mask = depth_map > 0.0
    if mask is not None:
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        # In-place multiply of boolean tensors acts as a logical AND.
        pts_mask *= mask > mask_thr
    # NOTE(review): pts_3d is (1, batch, H, W, 3) while pts_mask is
    # (batch, 1, H, W); this boolean indexing looks like it relies on
    # batch size 1 — confirm against callers.
    pts_3d[~pts_mask] = float('inf')
    # Remove the added leading dim and (for batch size 1) the batch dim.
    return pts_3d.squeeze(0).squeeze(0)
|