# Source snapshot: commit 744eb4e, 7,664 bytes.
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from torch_utils import persistence
from training.networks_stylegan2 import Generator as StyleGAN2Backbone
# from training.volumetric_rendering.renderer import ImportanceRenderer
# from training.volumetric_rendering.ray_sampler import RaySampler
import dnnlib
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
    """Tri-plane feature generator built on a StyleGAN2 backbone.

    The backbone synthesizes a single 256x256 image with
    ``_N_PLANES * _PLANE_CHANNELS`` channels, which is reshaped into three
    32-channel feature planes. The ray sampler, volume renderer and
    super-resolution module are not constructed in this variant, so
    ``synthesis`` and ``forward`` return the feature planes directly.
    """

    _N_PLANES = 3          # Number of feature planes produced per sample.
    _PLANE_CHANNELS = 32   # Feature channels in each plane.

    def __init__(self,
        z_dim,                  # Input latent (Z) dimensionality.
        c_dim,                  # Conditioning label (C) dimensionality.
        w_dim,                  # Intermediate latent (W) dimensionality.
        mapping_kwargs = None,  # Arguments for MappingNetwork (None -> {}).
        **synthesis_kwargs,     # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        # Avoid a shared mutable default argument; None means "no overrides".
        if mapping_kwargs is None:
            mapping_kwargs = {}
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        # All planes are generated as one multi-channel image and split later.
        self.backbone = StyleGAN2Backbone(
            z_dim, c_dim, w_dim,
            img_resolution=256,
            img_channels=self._N_PLANES * self._PLANE_CHANNELS,
            mapping_kwargs=mapping_kwargs,
            **synthesis_kwargs,
        )
        self.decoder = OSGDecoder(self._PLANE_CHANNELS, {'decoder_output_dim': 0})
        # Raw backbone output stored by the last call made with cache_backbone=True.
        self._last_planes = None

    def _reshape_planes(self, planes):
        """Split the backbone's channel axis into (N, n_planes, plane_channels, H, W)."""
        return planes.view(len(planes), self._N_PLANES, self._PLANE_CHANNELS, planes.shape[-2], planes.shape[-1])

    def mapping(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        """Run the backbone's mapping network on ``z`` and conditioning ``c``.

        Conditioning is passed through unchanged (no zeroing or scaling).
        """
        return self.backbone.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

    def synthesis(self, ws, c=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        """Synthesize tri-plane features from latents ``ws``.

        Args:
            ws: Latents forwarded to ``self.backbone.synthesis``.
            c: Accepted for interface compatibility; ignored.
            neural_rendering_resolution: Accepted for interface compatibility; ignored.
            update_emas: Forwarded to the backbone.
            cache_backbone: If true, store the raw backbone output (fresh or
                reused) in ``self._last_planes``.
            use_cached_backbone: If true and a cached output exists, reuse it
                instead of running the backbone.
            **synthesis_kwargs: Forwarded to the backbone.

        Returns:
            Tensor of shape ``(N, 3, 32, H, W)``, where ``N`` is the batch size
            and ``H``/``W`` are the backbone output's spatial dimensions.
        """
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes
        return self._reshape_planes(planes)

    def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Evaluate the decoder at arbitrary 3D ``coordinates`` for latents ``z``.

        Mostly used for extracting shapes.

        NOTE(review): ``self.renderer`` and ``self.rendering_kwargs`` are not
        assigned in ``__init__``; unless a caller attaches them, the final
        ``self.renderer`` lookup raises AttributeError after the backbone runs.
        """
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        planes = self._reshape_planes(self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs))
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        """Like ``sample``, but takes latents ``ws`` instead of Gaussian noise ``z``.

        ``truncation_psi`` and ``truncation_cutoff`` are accepted for interface
        compatibility and ignored. The same NOTE(review) as ``sample`` applies:
        ``self.renderer``/``self.rendering_kwargs`` must be attached externally.
        """
        planes = self._reshape_planes(self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs))
        return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
        """Map ``z``/``c`` to latents and return the synthesized tri-planes."""
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        return self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
# Absolute import, consistent with the StyleGAN2Backbone import at the top of
# the file; a relative import fails whenever this file is not loaded as part
# of a package.
from training.networks_stylegan2 import FullyConnectedLayer
class OSGDecoder(torch.nn.Module):
    """Small MLP that decodes averaged tri-plane features point by point.

    Args:
        n_features: Channel count ``C`` of each sampled feature vector.
        options: Dict providing ``'decoder_output_dim'``; the MLP emits
            ``1 + options['decoder_output_dim']`` values per point.
    """

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64
        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features, self.hidden_dim),
            torch.nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim']),
        )

    def forward(self, sampled_features, ray_directions=None):
        """Average features over dim 1 and run the MLP on every point.

        Args:
            sampled_features: 4-D tensor. Dim 1 is averaged away and the
                remaining dims are read as ``(N, M, C)``. Presumably
                ``(batch, n_planes, n_points, C)``; TODO confirm with caller.
            ray_directions: Unused; kept for interface compatibility.

        Returns:
            The raw output of ``self.net`` reshaped to ``(N, M, out_dim)``,
            with ``out_dim = 1 + decoder_output_dim``. No sigmoid or
            rgb/sigma split is applied.
        """
        # Aggregate features across dim 1 (presumably the plane axis).
        x = sampled_features.mean(1)
        n, m, c = x.shape
        # The MLP works on 2-D input, so flatten points, decode, and restore.
        out = self.net(x.view(n * m, c))
        return out.view(n, m, -1)