# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict
from unittest.mock import patch

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

DEVICE = torch.device("cuda:0")
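# NOTE: these tests assume a CUDA-capable device is available; all models
# and fake inputs below are placed on cuda:0.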


def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
    R, T = look_at_view_transform(azim=torch.rand(N) * 360)
    return {
        "camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
        "fg_probability": torch.randint(
            high=2, size=(N, 1, H, W), device=DEVICE
        ).float(),
        "depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
        "mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
        "sequence_name": ["sequence"] * N,
        "image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
    }


def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Return non deterministic indexes to mock safe_multinomial

    Args:
        input: tensor of shape [B, n] containing non-negative values;
                rows are interpreted as unnormalized event probabilities
                in categorical distributions.
        num_samples: number of samples to take.

    Returns:
        Tensor of shape [B, num_samples]
    """
    batch_size = input.shape[0]
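    # Every row selects indices 0..num_samples-1, so ray selection carries
    # no randomness and the two models under test sample identical rays.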
    return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)


class TestOverfitModel(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_overfit_model_vs_generic_model_with_batch_size_one(self):
        """In this test we compare OverfitModel to GenericModel behavior.

        We use a Nerf setup (2 rendering passes).

        OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
        they should provide the exact same results.
        """
        expand_args_fields(OverfitModel)
        expand_args_fields(GenericModel)
        batch_size, image_height, image_width = 1, 80, 80
        assert batch_size == 1
        overfit_model = OverfitModel(
            render_image_height=image_height,
            render_image_width=image_width,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            # Disable stratified point sampling during training so the two
            # models' outputs can be compared without sampling randomness.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )
        generic_model = GenericModel(
            render_image_height=image_height,
            render_image_width=image_width,
            n_train_target_views=batch_size,
            num_passes=2,
            # Disable stratified point sampling during training so the two
            # models' outputs can be compared without sampling randomness.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )

        # Check that both models have the same number of parameters
        num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
        num_params_gm = sum(p.numel() for p in generic_model.parameters())
        self.assertEqual(num_params_mvm, num_params_gm)

        # Rename GenericModel's state dict keys to OverfitModel's layout:
        # GenericModel stores its passes as _implicit_functions.{0,1}._fn,
        # while OverfitModel names them coarse_implicit_function and
        # implicit_function.
        mapping_om_from_gm = {
            key.replace(
                "_implicit_functions.0._fn", "coarse_implicit_function"
            ).replace("_implicit_functions.1._fn", "implicit_function"): val
            for key, val in generic_model.state_dict().items()
        }
        # Copy parameters from generic_model to overfit_model
        overfit_model.load_state_dict(mapping_om_from_gm)

        overfit_model.to(DEVICE)
        generic_model.to(DEVICE)
        inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)

        # training forward pass
        overfit_model.train()
        generic_model.train()

        with patch(
            "pytorch3d.renderer.implicit.raysampling._safe_multinomial",
            side_effect=mock_safe_multinomial,
        ):
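            # With _safe_multinomial mocked, both models select identical
            # rays, so their training outputs are directly comparable.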
            train_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )
            train_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )

        self.assertEqual(len(train_preds_om), len(train_preds_gm))

        self.assertTrue(train_preds_om["objective"].isfinite().item())
        # With randomization disabled and identical weights, the objectives
        # should match exactly.
        self.assertTrue(
            torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
        )

        # Check that evaluation runs and both models produce matching outputs
        overfit_model.eval()
        generic_model.eval()
        with torch.no_grad():
            eval_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )
            eval_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )

        self.assertEqual(
            eval_preds_om["images_render"].shape,
            (batch_size, 3, image_height, image_width),
        )
        self.assertTrue(
            torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
        )
        self.assertTrue(
            torch.allclose(
                eval_preds_om["images_render"], eval_preds_gm["images_render"]
            )
        )

    def test_overfit_model_check_share_weights(self):
        model = OverfitModel(share_implicit_function_across_passes=True)
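        # With share_implicit_function_across_passes=True, both passes should
        # reference the very same Parameter objects (identical ids).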
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_check_no_share_weights(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
                "transformer_dim_down_factor": 1.0,
                "n_hidden_neurons_xyz": 256,
                "n_layers_xyz": 8,
                "append_xyz": (5,),
            },
        )
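        # With separate implicit functions, corresponding parameters must be
        # distinct objects.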
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertNotEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_coarse_implicit_function_is_none(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
        )
        self.assertIsNone(model.coarse_implicit_function)
        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
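

# Convenience entry point: lets this file be run directly
# (`python test_overfit_model.py`) as well as via a test runner.
if __name__ == "__main__":
    unittest.main()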