# File size: 11,228 Bytes
# e34aada
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch
import torch.nn as nn
from modules.eg3ds.models.networks_stylegan2 import Generator as StyleGAN2Backbone
from modules.eg3ds.models.networks_stylegan2 import FullyConnectedLayer
from modules.eg3ds.volumetric_rendering.renderer import ImportanceRenderer
from modules.eg3ds.volumetric_rendering.ray_sampler import RaySampler
from modules.eg3ds.models.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X, SuperresolutionHybrid8X, SuperresolutionHybrid8XDC

import copy
from utils.commons.hparams import hparams


class TriPlaneGenerator(torch.nn.Module):
    """Tri-plane 3D-aware image generator.

    A StyleGAN2 backbone synthesizes a feature map that is reshaped into
    three planes; the planes are volume-rendered along camera rays by an
    ``ImportanceRenderer`` (decoded with ``OSGDecoder``) into a low-resolution
    feature image, which a super-resolution module turns into the final image.

    Configuration is read from ``hp`` (or the module-level ``hparams`` when
    ``hp`` is None); a shallow copy is stored on the instance as
    ``self.hparams`` and is what the instance methods consult.
    """

    def __init__(self, hp=None):
        """Build the backbone, renderer, decoder and super-resolution modules.

        hp: optional config mapping; when None the module-level ``hparams`` is
            used. Either way a shallow copy is stored in ``self.hparams``.
        """
        super().__init__()
        global hparams
        self.hparams = copy.copy(hparams) if hp is None else copy.copy(hp)
        # NOTE: the module-level `hparams` is rebound to this instance's copy.
        # This side effect is kept for backward compatibility, but the methods
        # below read `self.hparams` so that several generators built with
        # different configs do not see each other's settings.
        hparams = self.hparams
        cfg = self.hparams

        self.z_dim = cfg['z_dim']
        self.camera_dim = 25  # 16 (flattened 4x4 cam2world) + 9 (flattened 3x3 intrinsics)
        self.w_dim = cfg['w_dim']

        self.img_resolution = cfg['final_resolution']
        self.img_channels = 3
        self.renderer = ImportanceRenderer(hp=cfg)
        self.renderer.triplane_feature_type = 'triplane'
        self.ray_sampler = RaySampler()

        self.neural_rendering_resolution = cfg['neural_rendering_resolution']

        # Planes cached by `synthesis(cache_backbone=True)`. Must exist before
        # `synthesis(use_cached_backbone=True)` is ever called.
        self._last_planes = None

        mapping_kwargs = {'num_layers': cfg['mapping_network_depth']}
        synthesis_kwargs = {
            'channel_base': cfg['base_channel'],
            'channel_max': cfg['max_channel'],
            'fused_modconv_default': 'inference_only',
            'num_fp16_res': cfg['num_fp16_layers_in_generator'],
            'conv_clamp': None,
        }

        triplane_c_dim = self.camera_dim

        # if gen_cond_mode == 'mapping', add a cond_mapping in backbone
        self.backbone = StyleGAN2Backbone(self.z_dim, triplane_c_dim, self.w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
        self.decoder = OSGDecoder(32, {'decoder_lr_mul': 1, 'decoder_output_dim': 32})

        self.rendering_kwargs = {
            'image_resolution': cfg['final_resolution'],
            'disparity_space_sampling': False,
            'clamp_mode': 'softplus',
            'gpc_reg_prob': cfg['gpc_reg_prob'],
            'c_scale': 1.0,
            'superresolution_noise_mode': 'none',
            'density_reg': cfg['lambda_density_reg'],
            'density_reg_p_dist': cfg['density_reg_p_dist'],
            'reg_type': 'l1',
            'decoder_lr_mul': 1.0,
            'sr_antialias': True,
            'depth_resolution': cfg['num_samples_coarse'],
            'depth_resolution_importance': cfg['num_samples_fine'],
            'ray_start': cfg['ray_near'],
            'ray_end': cfg['ray_far'],
            'box_warp': cfg['box_warp'],
            # Only used by the inference-time pose sampler, where the camera
            # moves on a sphere of constant radius; this is the distance from
            # the camera to the origin of the world coordinate system.
            'avg_camera_radius': 2.7,
            # Only used by the inference-time pose sampler: the point the
            # camera looks at, which determines the view direction.
            # [0, 0, 0.2] is presumably the philtrum of the 3DMM face.
            'avg_camera_pivot': [0, 0, 0.2],
            # Consider enabling when the background is pure white: a world
            # without density is black by default, and this makes it white so
            # the network need not model a thin voxel layer for the background.
            'white_back': False,
        }

        sr_num_fp16_res = cfg['num_fp16_layers_in_super_resolution']
        sr_kwargs = {'channel_base': cfg['base_channel'], 'channel_max': cfg['max_channel'], 'fused_modconv_default': 'inference_only'}
        self.superresolution = SuperresolutionHybrid8XDC(channels=32, img_resolution=self.img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=True, **sr_kwargs)

    def mapping(self, z, camera, cond=None, truncation_psi=0.7, truncation_cutoff=None, update_emas=False):
        """
        Generate the latent codes ``ws`` by running the mapping network.

        z: latent sampled from N(0,1): [B, z_dim]
        camera: flattened extrinsic 4x4 matrix and intrinsic 3x3 matrix [B, c=16+9]
        cond: auxiliary condition, such as idexp_lm3d: [B, c=68*3]. Only used
            when ``gen_cond_mode == 'mapping'``, in which case it is required.
        truncation_psi: strength of the truncation trick; 1.0 means no effect,
            0.0 means ws is the mean ws, values in between interpolate linearly.
        truncation_cutoff: number of ws the truncation is applied to. None
            applies it to all ws; an int applies it to the first layers only.
        update_emas: forwarded unchanged to the backbone mapping network(s).

        Returns the ws tensor produced by the backbone; in 'mapping' mode it
        is the equal-weight average of the camera-conditioned ws and the
        cond-conditioned ws.
        """
        c = camera
        ws = self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        # Read the per-instance config, not the module global, so that another
        # generator constructed later cannot change this instance's behavior.
        if self.hparams.get("gen_cond_mode", 'none') == 'mapping':
            d_ws = self.backbone.cond_mapping(cond, None, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
            ws = ws * 0.5 + d_ws * 0.5
        return ws

    def synthesis(self, ws, camera, cond=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        """
        Run the backbone, renderer and super-resolution given ``ws`` from ``self.mapping``.

        ws: latent codes produced by ``self.mapping``.
        camera: [B, 25]; the first 16 values are viewed as a 4x4 cam2world
            matrix and the next 9 as a 3x3 intrinsics matrix.
        cond: accepted for interface symmetry; not used here.
        cache_backbone: store the backbone output on the instance for reuse.
        use_cached_backbone: reuse the stored backbone output if one exists;
            falls back to running the backbone when nothing is cached.
        synthesis_kwargs: forwarded to the backbone synthesis and (except
            'noise_mode') to the super-resolution module.

        Returns a dict with keys 'image' (super-resolved, clamped to [-1, 1]),
        'image_raw' (first 3 rendered channels, clamped to [-1, 1]),
        'image_depth', 'image_feature' (rendered channels from index 3 on)
        and 'plane' (the planes reshaped to [B, 3, C, H, W]).
        """
        cam2world_matrix = camera[:, :16].view(-1, 4, 4)
        intrinsics = camera[:, 16:25].view(-1, 3, 3)

        neural_rendering_resolution = self.neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape
        # getattr keeps this safe for instances that never ran this __init__
        # (e.g. restored from an older pickle without the attribute).
        cached_planes = getattr(self, '_last_planes', None)
        if use_cached_backbone and cached_planes is not None:
            planes = cached_planes
        else:
            planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes

        # Reshape output into three planes (channel count inferred)
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])

        # Perform volume rendering
        feature_samples, depth_samples, weights_samples, is_ray_valid = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last

        # Reshape into 'raw' neural-rendered image
        H = W = neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        if self.hparams.get("mask_invalid_rays", False):
            is_ray_valid_mask = is_ray_valid.reshape([feature_samples.shape[0], 1, H, W]) # [B, 1, H, W]
            feature_image[~is_ray_valid_mask.repeat([1, feature_image.shape[1], 1, 1])] = -1
            # Invalid rays get the smallest valid depth. Skip when no ray is
            # valid: min() over an empty tensor would raise.
            valid_depths = depth_image[is_ray_valid_mask]
            if valid_depths.numel() > 0:
                depth_image[~is_ray_valid_mask] = valid_depths.min().item()

        # Run superresolution to get final image
        rgb_image = feature_image[:, :3]
        ws_to_sr = ws
        if self.hparams['ones_ws_for_sr']:
            ws_to_sr = torch.ones_like(ws)
        # 'noise_mode' is fixed by rendering_kwargs, so drop any caller value.
        sr_extra_kwargs = {k: v for k, v in synthesis_kwargs.items() if k != 'noise_mode'}
        sr_image = self.superresolution(rgb_image, feature_image, ws_to_sr, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **sr_extra_kwargs)

        rgb_image = rgb_image.clamp(-1, 1)
        sr_image = sr_image.clamp(-1, 1)
        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'image_feature': feature_image[:, 3:],
            'plane': planes,
        }

    def sample(self, coordinates, directions, z, camera, cond=None, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """
        Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
        Not aggregated into pixels, but in the world coordinate.

        Runs ``self.mapping`` on (z, camera, cond), synthesizes the planes and
        returns whatever ``self.renderer.run_model`` returns for the given
        coordinates and directions.
        """
        ws = self.mapping(z, camera, cond=cond, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        # Channel count inferred, consistent with `synthesis`.
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """
        Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'.

        truncation_psi and truncation_cutoff are accepted for signature
        parity with ``sample`` but are not used, since no mapping is run.
        """
        planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        # Channel count inferred, consistent with `synthesis`.
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def forward(self, z, camera, cond=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        """
        Render a batch of generated images: ``mapping`` followed by ``synthesis``.

        NOTE: ``neural_rendering_resolution`` is accepted but currently
        ignored; rendering always uses ``self.neural_rendering_resolution``.

        Returns the dict produced by ``synthesis``.
        """
        ws = self.mapping(z, camera, cond=cond, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        return self.synthesis(ws, camera, cond=cond, update_emas=update_emas, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)


class OSGDecoder(torch.nn.Module):
    """Small MLP that decodes sampled features into color features and density.

    The network is two ``FullyConnectedLayer`` modules with a Softplus in
    between; it outputs ``1 + options['decoder_output_dim']`` values per
    point: one density value followed by the color/feature channels.
    """

    def __init__(self, n_features, options):
        """
        n_features: number of input channels per sampled point.
        options: dict providing 'decoder_lr_mul' (learning-rate multiplier for
            both layers) and 'decoder_output_dim' (number of color channels).
        """
        super().__init__()
        self.hidden_dim = 64

        lr_mul = options['decoder_lr_mul']
        out_dim = 1 + options['decoder_output_dim']
        layers = [
            FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=lr_mul),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim, out_dim, lr_multiplier=lr_mul),
        ]
        self.net = torch.nn.Sequential(*layers)

    def forward(self, sampled_features, ray_directions):
        """
        sampled_features: 4-D tensor; dim 1 is averaged away (presumably the
            three planes - confirm against the renderer), leaving [N, M, C].
        ray_directions: part of the interface but not used by this decoder.

        Returns {'rgb': [N, M, out-1], 'sigma': [N, M, 1]}.
        """
        # Aggregate features over dim 1
        pooled = sampled_features.mean(1)

        N, M, C = pooled.shape
        decoded = self.net(pooled.view(N*M, C)).view(N, M, -1)

        # Channel 0 is the density; the remaining channels are the colors.
        sigma = decoded[..., 0:1]
        # Uses sigmoid clamping from MipNeRF: widen the range slightly beyond (0, 1)
        rgb = torch.sigmoid(decoded[..., 1:])*(1 + 2*0.001) - 0.001
        return {'rgb': rgb, 'sigma': sigma}