|
|
|
|
|
|
|
|
|
import functools |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn.utils import spectral_norm, weight_norm |
|
|
|
from .conv import LinearBlock |
|
|
|
|
|
class WeightDemodulation(nn.Module):
    r"""Weight demodulation in
    "Analyzing and Improving the Image Quality of StyleGAN", Karras et al.

    Wraps a convolutional layer so that its weights are modulated per-sample
    by a conditional input ``y`` and (optionally) demodulated to unit norm.

    Args:
        conv (torch.nn.Modules): Convolutional layer.
        cond_dims (int): The number of channels in the conditional input.
        eps (float, optional, default=1e-8): a value added to the
            denominator for numerical stability.
        adaptive_bias (bool, optional, default=False): If ``True``, adaptively
            predicts bias from the conditional input.
        demod (bool, optional, default=True): If ``True``, performs
            weight demodulation.
    """

    def __init__(self, conv, cond_dims, eps=1e-8,
                 adaptive_bias=False, demod=True):
        super().__init__()
        self.conv = conv
        self.adaptive_bias = adaptive_bias
        if adaptive_bias:
            # Remove the conv's own bias parameter; the bias is instead
            # predicted from the conditional input by ``fc_beta``.
            self.conv.register_parameter('bias', None)
            self.fc_beta = LinearBlock(cond_dims, self.conv.out_channels)
        # Maps the conditional input to one modulation factor per input
        # channel of the wrapped conv.
        self.fc_gamma = LinearBlock(cond_dims, self.conv.in_channels)
        self.eps = eps
        self.demod = demod
        # Flag read by the surrounding framework to mark this layer as
        # conditional (takes a second input in forward).
        self.conditional = True

    def forward(self, x, y):
        r"""Weight demodulation forward

        Args:
            x (tensor): Input activations, shape (B, C_in, H, W).
            y (tensor): Conditional input fed to ``fc_gamma``/``fc_beta``.
        Returns:
            (tensor): Convolved output, reshaped to (B, out_channels, H, W).
        """
        b, c, h, w = x.size()
        # Each batch sample gets its own modulated weight copy; using a
        # grouped conv with groups == batch size applies all of them in a
        # single convolution call.
        self.conv.groups = b
        gamma = self.fc_gamma(y)
        # (B, C_in) -> (B, 1, C_in, 1, 1) to broadcast over the weight.
        gamma = gamma[:, None, :, None, None]
        # Modulate: scale each input channel of the shared weight per sample.
        # (gamma + 1) keeps the identity mapping at gamma == 0.
        weight = self.conv.weight[None, :, :, :, :] * (gamma + 1)

        if self.demod:
            # Demodulate: rescale each output filter to unit L2 norm;
            # eps keeps rsqrt finite for all-zero filters.
            d = torch.rsqrt(
                (weight ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weight = weight * d

        # Fold the batch into the channel dimension to match groups == b.
        x = x.reshape(1, -1, h, w)
        _, _, *ws = weight.shape
        weight = weight.reshape(b * self.conv.out_channels, *ws)
        # NOTE(review): ``conv2d_forward`` is the pre-1.5 PyTorch name for
        # nn.Conv2d's weight-parameterized forward (later ``_conv_forward``)
        # — confirm the torch version or that ``conv`` provides this method.
        x = self.conv.conv2d_forward(x, weight)

        # Unfold the batch dimension. NOTE(review): reshaping back to
        # (h, w) assumes the conv preserves spatial size — verify
        # stride/padding of the wrapped conv.
        x = x.reshape(-1, self.conv.out_channels, h, w)
        if self.adaptive_bias:
            # Add the per-sample predicted bias, broadcast over H and W.
            x += self.fc_beta(y)[:, :, None, None]
        return x
|
|
|
|
|
def weight_demod(conv, cond_dims=256, eps=1e-8, demod=True):
    r"""Weight demodulation.

    Wraps ``conv`` in a :class:`WeightDemodulation` layer.

    Args:
        conv (torch.nn.Modules): Convolutional layer to wrap.
        cond_dims (int, optional, default=256): Number of channels in the
            conditional input.
        eps (float, optional, default=1e-8): Numerical stability constant.
        demod (bool, optional, default=True): If ``True``, performs weight
            demodulation.
    """
    # ``demod`` must be passed by keyword: positionally it would land in
    # the ``adaptive_bias`` slot of WeightDemodulation.__init__
    # (conv, cond_dims, eps, adaptive_bias, demod), silently enabling
    # adaptive bias and leaving demodulation at its default.
    return WeightDemodulation(conv, cond_dims, eps, demod=demod)
|
|
|
|
|
def get_weight_norm_layer(norm_type, **norm_params):
    r"""Return weight normalization.

    Args:
        norm_type (str):
            Type of weight normalization.
            ``'none'``, ``'spectral'``, ``'weight'``
            or ``'weight_demod'``.
        norm_params: Arbitrary keyword arguments that will be used to
            initialize the weight normalization.
    """
    # Guard-clause dispatch: return as soon as the type is matched.
    if norm_type in ('none', ''):
        # No normalization requested; hand back an identity wrapper.
        return lambda module: module
    if norm_type == 'spectral':
        return functools.partial(spectral_norm, **norm_params)
    if norm_type == 'weight':
        return functools.partial(weight_norm, **norm_params)
    if norm_type == 'weight_demod':
        return functools.partial(weight_demod, **norm_params)
    raise ValueError(
        'Weight norm layer %s is not recognized' % norm_type)
|
|