File size: 4,696 Bytes
ef198e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
#
# Modified by Jiale Xu
# The modifications are subject to the same license as the original.
#
# Modified by Yuze He
# The modifications are subject to the same license as the original.


"""
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors, depths, and semantics.
Based on the implementation in MipNeRF (this one does not do any cone tracing, though).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class MipRayMarcher2(nn.Module):
    """Alpha-composite per-sample colors, depths and semantics along rays.

    Values are averaged between consecutive samples (midpoints) and combined
    with standard volume-rendering weights. Optional per-batch-item semantic
    "levels" (from ``rendering_options['levels']``) make some semantic
    channels transparent before compositing.
    """

    def __init__(self, activation_factory):
        """
        Args:
            activation_factory: callable that receives ``rendering_options``
                and returns the activation applied to midpoint densities.
        """
        super().__init__()
        self.activation_factory = activation_factory

    @staticmethod
    def _apply_level(alpha, semantics, level):
        """Return ``(alpha, semantics)`` of ONE batch item after applying ``level``.

        Everything is computed out of place: writing back into slices of a
        tensor that autograd has already saved (as an earlier in-place version
        did) makes ``backward()`` fail with an in-place modification error.

        level 1: channel 0 (hair) becomes transparent -- alpha is scaled by
            ``1 - s[..., 0]``, channel 0 is zeroed and channels ``1:`` are
            renormalized to sum to ~1.
        level 2: channels 0 (hair) and 3 (cloth) become transparent -- alpha
            is scaled by ``1 - s[..., 0] - s[..., 3]``, both are zeroed,
            channels ``1:3`` are renormalized and channels ``4:`` are kept.

        Raises:
            NotImplementedError: for any other level.
        """
        hair = semantics[..., 0:1]
        if level == 1:
            kept = semantics[..., 1:]
            kept = (kept + 1e-6) / (torch.sum(kept, dim=-1, keepdim=True) + 1e-6)
            new_semantics = torch.cat([torch.zeros_like(hair), kept], dim=-1)
            return (alpha * (1 - hair)).to(alpha.dtype), new_semantics
        if level == 2:
            cloth = semantics[..., 3:4]
            kept = semantics[..., 1:3]
            kept = (kept + 1e-6) / (torch.sum(kept, dim=-1, keepdim=True) + 1e-6)
            new_semantics = torch.cat(
                [torch.zeros_like(hair), kept, torch.zeros_like(cloth), semantics[..., 4:]], dim=-1
            )
            return (alpha * (1 - hair - cloth)).to(alpha.dtype), new_semantics
        raise NotImplementedError("Only 0, 1, 2 levels are supported")

    def run_forward(self, colors, densities, depths, semantics, rendering_options, normals=None):
        """Composite midpoint samples along each ray.

        Args:
            colors, densities, depths, semantics: per-sample tensors whose
                samples axis is dim 2 (also addressed as dim -2), so they are
                presumably shaped (batch, rays, samples, channels) -- TODO
                confirm against the caller.
            rendering_options: dict passed to ``activation_factory``.
                ``'levels'`` (required) holds one level (0, 1 or 2) per
                leading batch item; items past its length are left unchanged.
                ``'white_back'`` (optional) adds ``1 - weight_total`` to rgb.
            normals: unused.

        Returns:
            ``(composite_rgb, composite_depth, composite_semantics, weights)``.

        Raises:
            NotImplementedError: if a level other than 0, 1 or 2 is requested.
        """
        dtype = colors.dtype
        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
        densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2
        depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
        semantics_mid = (semantics[:, :, :-1] + semantics[:, :, 1:]) / 2

        # The density activation is selected per call from the rendering options.
        densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype)
        alpha = 1 - torch.exp(-(densities_mid * deltas)).to(dtype)

        levels = rendering_options['levels']
        if any(level != 0 for level in levels):
            # Rebuild alpha/semantics row by row instead of mutating slices,
            # so gradients can flow through the level adjustments.
            alpha_rows = list(alpha.unbind(0))
            semantic_rows = list(semantics_mid.unbind(0))
            for bid, level in enumerate(levels):
                if level == 0:
                    continue
                alpha_rows[bid], semantic_rows[bid] = self._apply_level(
                    alpha_rows[bid], semantic_rows[bid], level
                )
            alpha = torch.stack(alpha_rows, dim=0)
            semantics_mid = torch.stack(semantic_rows, dim=0)

        # Exclusive cumulative product of (1 - alpha) = transmittance before
        # each sample; the tiny epsilon presumably guards against an exact zero.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = (alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]).to(dtype)

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)
        # Depth is deliberately NOT normalized by weight_total.
        composite_depth = torch.sum(weights * depths_mid, -2)
        composite_semantics = torch.sum(weights * semantics_mid, -2)

        # NaN depths become +inf, then everything is clipped to the sampled depth range.
        composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype)
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        if rendering_options.get('white_back', False):
            composite_rgb = composite_rgb + 1 - weight_total

        # Rendered values keep their 0-1 scale; MipNeRF's rescaling to (-1, 1)
        # is intentionally not applied.
        return composite_rgb, composite_depth, composite_semantics, weights

    def forward(self, colors, densities, depths, semantics, rendering_options, normals=None):
        """Render rays; see ``run_forward`` for arguments and return values.

        Raises:
            NotImplementedError: if ``normals`` is given (not supported yet).
        """
        if normals is not None:
            raise NotImplementedError("Normals are not supported in the ray marcher yet.")
        return self.run_forward(colors, densities, depths, semantics, rendering_options)