DmitrMakeev's picture
Upload 8 files
add457a
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import functools
import torch
from torch import nn
from torch.nn.utils import spectral_norm, weight_norm
from .conv import LinearBlock
class WeightDemodulation(nn.Module):
    r"""Weight demodulation in
    "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.

    The wrapped convolution's kernel is scaled per sample by ``gamma + 1``,
    where ``gamma`` is predicted from the conditional input, and is then
    optionally normalized to unit L2 norm per output channel.

    Args:
        conv (torch.nn.Modules): Convolutional layer. ``forward`` indexes its
            weight as a 4D tensor, so a 2D convolution with ``groups=1`` is
            expected.
        cond_dims (int): The number of channels in the conditional input.
        eps (float, optional, default=1e-8): a value added to the
            denominator for numerical stability.
        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
            predicts bias from the conditional input (the static bias of
            ``conv`` is removed in that case).
        demod (bool, optional, default=True): If ``True``, performs
            weight demodulation.
    """

    def __init__(self, conv, cond_dims, eps=1e-8,
                 adaptive_bias=False, demod=True):
        super().__init__()
        self.conv = conv
        self.adaptive_bias = adaptive_bias
        if adaptive_bias:
            # The bias is predicted from the condition instead, so drop the
            # static one owned by the convolution.
            self.conv.register_parameter('bias', None)
            self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
        self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
        self.eps = eps
        self.demod = demod
        self.conditional = True

    def _conv_per_sample(self, x, weight, groups):
        r"""Run the grouped convolution with the hyper-parameters of
        ``self.conv`` but an explicit weight and no bias.

        Calling the functional API directly avoids relying on the private
        ``Conv2d`` forward helper (whose name and signature changed across
        PyTorch releases) and avoids mutating ``self.conv.groups``.
        """
        conv = self.conv
        padding = conv.padding
        if getattr(conv, 'padding_mode', 'zeros') != 'zeros':
            # Non-zero padding modes have to be applied explicitly.
            pads = getattr(conv, '_reversed_padding_repeated_twice', None)
            if pads is None:
                # F.pad expects (left, right, top, bottom) for 4D input.
                pads = [p for p in reversed(conv.padding) for _ in range(2)]
            x = nn.functional.pad(x, pads, mode=conv.padding_mode)
            padding = 0
        return nn.functional.conv2d(
            x, weight, None, conv.stride, padding, conv.dilation, groups)

    def forward(self, x, y):
        r"""Weight demodulation forward.

        Args:
            x (tensor): Input of size ``(b, c, h, w)``.
            y (tensor): Conditional input fed to the linear layers; assumed
                to be ``(b, cond_dims)`` -- TODO confirm against LinearBlock.
        Returns:
            (tensor): Output of size ``(b, out_channels, h_out, w_out)``.
        """
        b, _, h, w = x.size()
        out_channels = self.conv.out_channels

        # One modulated kernel per sample: (b, out, in, kh, kw).
        gamma = self.fc_gamma(y)[:, None, :, None, None]
        weight = self.conv.weight[None, :, :, :, :] * (gamma + 1)

        if self.demod:
            d = torch.rsqrt(
                (weight ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weight = weight * d

        # Fold the batch into the channel axis and use one group per sample
        # so that every sample is convolved with its own kernel.
        x = x.reshape(1, -1, h, w)
        weight = weight.reshape(b * out_channels, *weight.shape[2:])
        x = self._conv_per_sample(x, weight, groups=b)

        # Use the actual output size; it differs from (h, w) whenever the
        # convolution is strided or not "same"-padded.
        x = x.reshape(b, out_channels, *x.shape[2:])

        if self.conv.bias is not None:
            # The static bias has ``out_channels`` entries, so it is added
            # after un-folding the batch rather than inside the grouped conv.
            x = x + self.conv.bias[None, :, None, None]
        if self.adaptive_bias:
            x = x + self.fc_beta(y)[:, :, None, None]
        return x
def weight_demod(conv, cond_dims=256, eps=1e-8, demod=True,
                 adaptive_bias=False):
    r"""Wrap a convolutional layer with weight demodulation.

    Args:
        conv (torch.nn.Modules): Convolutional layer to wrap.
        cond_dims (int, optional, default=256): The number of channels in
            the conditional input.
        eps (float, optional, default=1e-8): a value added to the
            denominator for numerical stability.
        demod (bool, optional, default=True): If ``True``, performs
            weight demodulation.
        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
            predicts bias from the conditional input.
    Returns:
        (WeightDemodulation): The wrapping module.
    """
    # Pass the flags by keyword: the constructor takes ``adaptive_bias``
    # before ``demod``, so a positional ``demod`` would land in the wrong
    # slot and silently toggle adaptive bias instead.
    return WeightDemodulation(
        conv, cond_dims, eps=eps, adaptive_bias=adaptive_bias, demod=demod)
def get_weight_norm_layer(norm_type, **norm_params):
    r"""Build a callable that applies the requested weight normalization.

    Args:
        norm_type (str):
            Which weight normalization to use: ``'none'`` (or the empty
            string), ``'spectral'``, ``'weight'`` or ``'weight_demod'``.
        norm_params: Arbitrary keyword arguments that are bound to the
            selected normalization wrapper.
    Returns:
        (callable): A one-argument callable taking the layer to wrap. For
        ``'none'`` / ``''`` it returns its argument unchanged.
    Raises:
        ValueError: If ``norm_type`` is not one of the supported names.
    """
    if norm_type in ('none', ''):
        # No normalization: hand the layer back untouched.
        def passthrough(x):
            return x
        return passthrough

    if norm_type == 'spectral':
        wrapper = spectral_norm
    elif norm_type == 'weight':
        wrapper = weight_norm
    elif norm_type == 'weight_demod':
        wrapper = weight_demod
    else:
        raise ValueError(
            'Weight norm layer %s is not recognized' % norm_type)

    return functools.partial(wrapper, **norm_params)