File size: 4,021 Bytes
cacb27a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import glob
from omegaconf import DictConfig
from typing import Optional

import torch

from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.io import IO
from pytorch3d.renderer import (
    NDCMultinomialRaysampler,
    ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds


# Categories reserved for evaluation when holdout filtering is enabled:
# excluded from training and used only for validation (see
# get_all_dataset_maps). Set literal instead of set([...]) — same contents,
# idiomatic form (flake8-comprehensions C405).
HOLDOUT_CATEGORIES = {
    'apple',
    'baseballglove',
    'cup',
    'ball',
    'toyplane',
    'handbag',
    'book',
    'carrot',
    'suitcase',
    'bowl',
}

def get_dataset_map(
    dataset_root: str,
    category: str,
    subset_name: str,
) -> DatasetMap:
    """
    Build the train/val/test DatasetMap for one CO3Dv2 category/subset.

    Args:
        dataset_root: Root directory of the CO3Dv2 dataset.
        category: CO3Dv2 category name.
        subset_name: Name of the dataset subset to load.

    Returns:
        The DatasetMap produced by JsonIndexDatasetMapProviderV2.
    """
    # Populate the provider's auto-generated config fields before instantiation.
    expand_args_fields(JsonIndexDatasetMapProviderV2)
    provider_kwargs = dict(
        category=category,
        subset_name=subset_name,
        dataset_root=dataset_root,
        test_on_train=False,
        only_test_set=False,
        load_eval_batches=True,
        dataset_JsonIndexDataset_args=DictConfig(
            {"remove_empty_masks": False, "load_point_clouds": False}
        ),
    )
    provider = JsonIndexDatasetMapProviderV2(**provider_kwargs)
    return provider.get_dataset_map()


def _load_pointcloud(pcl_path, max_points):
    """
    Load a point cloud from disk, optionally subsampled.

    Args:
        pcl_path: Path to the point cloud file.
        max_points: Subsample to at most this many points; a non-positive
            value keeps the full cloud.

    Returns:
        The loaded (and possibly subsampled) point cloud.
    """
    cloud = IO().load_pointcloud(pcl_path)
    return cloud.subsample(max_points) if max_points > 0 else cloud


def get_all_dataset_maps(co3d_path, holdout_categories):
    """
    Load cached per-category train/val dataset maps for all CO3Dv2 categories.

    Args:
        co3d_path: Root of the CO3Dv2 download; every non-.json entry under
            it is treated as a category directory.
        holdout_categories: If True, categories in HOLDOUT_CATEGORIES are
            excluded from the train maps and are the only entries in the val
            maps; if False, every category appears in both.

    Returns:
        Tuple (train_dataset_maps, val_dataset_maps): dicts mapping category
        name to the dataset object loaded from dataset_cache/ (relative to
        the current working directory).
    """
    all_categories = [
        c.split('/')[-1]
        for c in glob.glob(co3d_path + '/*')
        if not c.endswith('.json')
    ]
    # NOTE(review): hash() of strings is randomized per interpreter run
    # (PYTHONHASHSEED), so this ordering is only deterministic within a single
    # run — presumably an intentional pseudo-shuffle; confirm before relying
    # on a stable cross-run order.
    all_categories = sorted(all_categories, key=hash)

    train_dataset_maps = {}
    val_dataset_maps = {}
    for category in all_categories:
        print(f'Loading dataset map ({category})')
        dataset_map = {
            'train': torch.load(f'dataset_cache/{category}_train.pt'),
            'val': torch.load(f'dataset_cache/{category}_val.pt'),
        }
        # Holdout categories are val-only; when filtering is off, every
        # category goes into both splits.
        if not holdout_categories or category not in HOLDOUT_CATEGORIES:
            train_dataset_maps[category] = dataset_map['train']
        if not holdout_categories or category in HOLDOUT_CATEGORIES:
            val_dataset_maps[category] = dataset_map['val']

    # Fixed typo in the original output ("categores").
    print(f'Loaded {len(train_dataset_maps)} categories for train')
    print(f'Loaded {len(val_dataset_maps)} categories for val')
    return train_dataset_maps, val_dataset_maps


def get_rgbd_points(
    imh: int, imw: int,
    camera: CamerasBase,
    depth_map: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    mask_thr: float = 0.5,
) -> torch.Tensor:
    """
    Unproject a depth map through `camera` into a per-pixel 3D point map.

    Args:
        imh, imw: Target image height/width; depth_map (and mask, if given)
            are bilinearly resized to this resolution before unprojection.
        camera: Camera used to cast one ray per pixel via an NDC grid sampler.
        depth_map: Depth tensor; assumes shape (batch, 1, H, W) — TODO confirm
            against callers. Channel 0 supplies the per-pixel ray length.
        mask: Optional foreground mask, resized like depth_map; pixels whose
            resized value is <= mask_thr are invalidated.
        mask_thr: Threshold applied to the resized mask.

    Returns:
        Tensor of 3D points, one per pixel, with invalid pixels (non-positive
        depth, or masked out) set to +inf. NOTE(review): the original
        annotation said Pointclouds, but ray_bundle_to_ray_points returns a
        plain tensor and only squeezes follow, so the annotation was
        corrected — confirm no caller depends on the old one.
    """
    # Resize depth to the requested output resolution.
    depth_map = torch.nn.functional.interpolate(
        depth_map,
        size=[imh, imw],
        mode="bilinear",
        align_corners=False,
    )
    # Cast one unit-length ray per pixel, then substitute the per-pixel depth
    # as the ray length so each ray point lands on the observed surface.
    # The trailing squeeze(3) drops the n_pts_per_ray=1 axis; [None] adds a
    # leading axis that is removed again before returning.
    pts_3d = ray_bundle_to_ray_points(
        NDCMultinomialRaysampler(
            image_width=imw,
            image_height=imh,
            n_pts_per_ray=1,
            min_depth=1.0,
            max_depth=1.0,
        )(camera)._replace(lengths=depth_map[:, 0, ..., None])
    ).squeeze(3)[None]

    # Pixels with non-positive depth carry no valid geometry.
    pts_mask = depth_map > 0.0
    if mask is not None:
        # Resize the mask to match, then drop background pixels.
        mask = torch.nn.functional.interpolate(
            mask,
            size=[imh, imw],
            mode="bilinear",
            align_corners=False,
        )
        pts_mask *= mask > mask_thr
    # Mark invalid points with +inf so downstream code can filter them out.
    # NOTE(review): the boolean index shape only lines up with pts_3d when
    # batch == 1 — presumably the intended usage; verify against callers.
    pts_3d[~pts_mask] = float('inf')
    return pts_3d.squeeze(0).squeeze(0)