# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

"""Discriminator architectures from the paper
"Efficient Geometry-aware 3D Generative Adversarial Networks"."""

import numpy as np
import torch
import torch.nn as nn
# Required: `upfirdn2d.setup_filter` is called by every dual discriminator's
# __init__ and `upfirdn2d.{up,down}sample2d` by filtered_resizing('classic').
from modules.eg3ds.torch_utils.ops import upfirdn2d
from modules.eg3ds.models.networks_stylegan2 import DiscriminatorBlock, MappingNetwork, DiscriminatorEpilogue
from einops import rearrange
from utils.commons.hparams import hparams


class SingleDiscriminator(torch.nn.Module):
    """Camera-conditioned StyleGAN2-style discriminator over a single image.

    Only ``img['image']`` is read. The 25-dim camera vector is embedded by a
    MappingNetwork and used as the projection condition of the epilogue.
    When ``hparams['disc_cond_mode'] == 'idexp_lm3d_normalized'`` an extra
    204-dim condition ``cond`` is additionally fed to the mapping network as
    its z input (same scheme as DualDiscriminator).
    """
    def __init__(self,
        img_resolution,                 # Input resolution.
        img_channels        = 3,        # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        sr_upsample_factor  = 1,        # Ignored for SingleDiscriminator
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        self.camera_dim = 25
        # .get() so an unset key means "no extra condition", consistent with DualDiscriminator.
        if hparams.get('disc_cond_mode', 'none') == 'idexp_lm3d_normalized':
            self.cond_dim = 204
        else:
            self.cond_dim = 0
        c_dim = self.camera_dim
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        # Blocks run from full resolution down to 8; the 4x4 stage is the epilogue.
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            # The optional extra condition is embedded as the mapping network's
            # z input (z_dim == 0 when no extra condition is used). Previously
            # forward() called a `self.cond_encoder` that was never created.
            self.mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, camera, cond=None, update_emas=False, **block_kwargs):
        """Score ``img['image']`` conditioned on ``camera`` (and ``cond`` in cond mode).

        Args:
            img: dict; only the 'image' entry is used.
            camera: camera conditioning, presumably [B, 25] -- TODO confirm.
            cond: extra condition; required when ``self.cond_dim > 0``
                (presumably [B, 204] -- TODO confirm against the caller).
            update_emas: unused, kept for interface compatibility.
        Returns:
            Output of the DiscriminatorEpilogue.
        Raises:
            ValueError: if the model was built with an extra condition but ``cond`` is None.
        """
        img = img['image']
        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        if self.cond_dim > 0:
            if cond is None:
                raise ValueError("SingleDiscriminator was built with disc_cond_mode='idexp_lm3d_normalized' and requires `cond`.")
            z = cond
        else:
            z = None
        cmap = self.mapping(z, camera)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------

def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    """Resize images to ``size x size`` with a selectable anti-aliasing scheme.

    Args:
        image_orig_tensor: 4-D [N, C, H, W] tensor, or 5-D [N, C, T, H, W]
            tensor whose frames are resized independently.
        size: target side length; the output is always square.
        f: filter used only by the 'classic' mode (e.g. the output of
            ``upfirdn2d.setup_filter``); ignored otherwise.
        filter_mode: 'antialiased' (bilinear with antialias), 'classic'
            (upfirdn2d up x2 -> bilinear -> filtered down x2), 'none' (plain
            bilinear), or a float in (0, 1) that blends plain bilinear with
            antialiased bilinear (weight of the antialiased result).
    Returns:
        Tensor with the same layout (4-D or 5-D) as the input.
    Raises:
        ValueError: for an unknown ``filter_mode`` or a float outside (0, 1).
    """
    is_bcthw = image_orig_tensor.ndim == 5
    if is_bcthw:
        # [N, C, T, H, W] -> fold time into the batch so 2-D resizing applies per frame.
        n, _, t = image_orig_tensor.shape[:3]
        image_orig_tensor = rearrange(image_orig_tensor, "n c t h w -> (n t) c h w")

    if filter_mode == 'antialiased':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
    elif filter_mode == 'classic':
        ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
        ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
    elif filter_mode == 'none':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
    elif isinstance(filter_mode, float):
        if not 0 < filter_mode < 1:
            raise ValueError(f"float filter_mode must be in (0, 1), got {filter_mode}")
        filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=True)
        aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False, antialias=False)
        ada_filtered_64 = (1 - filter_mode) * aliased + filter_mode * filtered
    else:
        # Previously fell through and raised UnboundLocalError.
        raise ValueError(f"Unknown filter_mode: {filter_mode!r}")

    if is_bcthw:
        ada_filtered_64 = rearrange(ada_filtered_64, "(n t) c h w -> n c t h w", n=n, t=t)
    return ada_filtered_64

#----------------------------------------------------------------------------

class DualDiscriminator(torch.nn.Module):
    """Discriminator over the final image concatenated with the resized raw image.

    All hyper-parameters are read from the global ``hparams``: 'base_channel',
    'max_channel', 'group_size_for_mini_batch_std', 'final_resolution',
    'num_fp16_layers_in_discriminator', 'disc_c_noise' (in forward) and,
    optionally, 'disc_cond_mode'. When 'disc_cond_mode' is set to anything
    other than 'none', a 204-dim condition is embedded as the mapping
    network's z input alongside the 25-dim camera vector.
    """
    def __init__(self):
        super().__init__()
        channel_base = hparams['base_channel']
        channel_max = hparams['max_channel']
        conv_clamp = 256
        cmap_dim = None
        block_kwargs = {'freeze_layers': 0}
        mapping_kwargs = {}
        epilogue_kwargs = {'mbstd_group_size': hparams['group_size_for_mini_batch_std']}
        architecture = 'resnet' # Architecture: 'orig', 'skip', 'resnet'.
        img_channels = 3
        img_channels *= 2  # forward() concatenates 'image' and the resized 'image_raw' channel-wise.

        self.camera_dim = 25
        c_dim = self.camera_dim
        self.c_dim = c_dim  # read by extra_repr(); previously never set.
        self.img_resolution = hparams['final_resolution']
        self.img_resolution_log2 = int(np.log2(self.img_resolution))
        self.img_channels = 3
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        self.num_fp16_res = hparams['num_fp16_layers_in_discriminator']
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - self.num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < self.img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers

        if hparams.get("disc_cond_mode", 'none') != 'none':
            # For the discriminator, embedding cond with the mapping network works well.
            self.cond_dim = 204
        else:
            self.cond_dim = 0
        # Built once (previously a z_dim=0 network was built and then discarded in cond mode).
        self.mapping = MappingNetwork(z_dim=self.cond_dim, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))

    def forward(self, img, camera, cond=None, update_emas=False, feature_maps=None, **block_kwargs):
        """Score an (image, raw image) pair conditioned on the camera.

        Args:
            img: dict with 'image' and 'image_raw'; 'image_raw' is resized to
                the spatial size of 'image' before channel concatenation.
            camera: camera conditioning, presumably [B, 25] -- TODO confirm. It is
                cloned, so the caller's tensor is never modified.
            cond: extra condition fed to the mapping network when cond mode is on.
            update_emas: unused.
            feature_maps: optional list; each block's output feature map is appended.
        Returns:
            Output of the DiscriminatorEpilogue.
        """
        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
        img = torch.cat([img['image'], image_raw], 1)
        # add by yerfor
        img = torch.clamp(img, min=-1, max=1)

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)
            if feature_maps is not None:
                feature_maps.append(x)

        c = camera.clone() # prevent inplace modification in sample!
        if hparams['disc_c_noise'] > 0:
            if len(c) > 1:
                c_std = c.std(0)
            else:
                # std over a batch of one is NaN, so fall back to fixed per-entry
                # std values (entries 12-15, 17-19 and 21-24 have zero std).
                c_std = c.new_tensor([0.0664, 0.0295, 0.2720, 0.6971, 0.0279, 0.0178, 0.1280, 0.3284, 0.2721, 0.1274, 0.0679, 0.1642, 0.0000, 0.0000, 0.0000, 0.0000, 0.0079, 0.0000, 0.0000, 0.0000, 0.0079, 0.0000, 0.0000, 0.0000, 0.0000])
            c += torch.randn_like(c) * c_std * hparams['disc_c_noise']

        # x: [B, 512, 4, 4], img: None, cmap: [B, 512]
        if self.cond_dim > 0:
            cmap = self.mapping(cond, c)
        else:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------

class DummyDualDiscriminator(torch.nn.Module):
    """Dual discriminator whose raw-image input is linearly faded out during training.

    Each forward() call lowers ``raw_fade`` by 1/(500000/32), clamped at 0,
    and multiplies the resized raw image by it. The fade therefore reaches
    zero after 15625 calls.
    """
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        img_channels *= 2  # 'image' and resized 'image_raw' are concatenated.

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))

        self.raw_fade = 1

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score 'image' + faded, resized 'image_raw'; advances the fade schedule (side effect)."""
        # 500000/32: presumably 500k training images at batch size 32 -- TODO confirm.
        self.raw_fade = max(0, self.raw_fade - 1/(500000/32))

        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
        img = torch.cat([img['image'], image_raw], 1)

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------

# Tri-discriminator: upsampled image, super-resolved image, and segmentation mask
# V2: first concatenate imgs and seg mask, using only one conv block
class MaskDualDiscriminatorV2(torch.nn.Module):
    """Discriminator over 'image', resized 'image_raw' and the resized mask in [-1, 1]."""
    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        seg_resolution      = 128,      # Input resolution.
        seg_channels        = 1,        # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 32768,    # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 4,        # Use FP16 for the N highest resolutions.
        conv_clamp          = 256,      # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        disc_c_noise        = 0,        # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
    ):
        super().__init__()
        img_channels = img_channels * 2 + seg_channels  # image + raw image + mask, concatenated.

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.seg_resolution = seg_resolution
        self.seg_channels = seg_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1,3,3,1]))
        self.disc_c_noise = disc_c_noise

    def forward(self, img, c, update_emas=False, **block_kwargs):
        """Score 'image' + resized 'image_raw' + resized 'image_mask' conditioned on ``c``.

        ``c`` is never modified in place. Camera noise is only applied when the
        batch has more than one sample, because the per-entry std of a single
        sample is NaN.
        """
        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
        seg = filtered_resizing(img['image_mask'], size=img['image'].shape[-1], f=self.resample_filter)
        seg = 2 * seg - 1 # normalize to [-1,1]
        img = torch.cat([img['image'], image_raw, seg], 1)

        _ = update_emas # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0 and c.shape[0] > 1:
                # Out-of-place: the previous `c += ...` corrupted the caller's tensor.
                c = c + torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return ' '.join([
            f'c_dim={self.c_dim:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'seg_resolution={self.seg_resolution:d}, seg_channels={self.seg_channels:d}'])