import math

import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a per-channel bias, apply leaky ReLU, and rescale by `scale`.

    Pure-PyTorch stand-in for the fused CUDA op used in StyleGAN2 code.
    """
    # Number of trailing singleton dims needed so the bias broadcasts over the input.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    if input.ndim == 3:
        # (batch, seq, channels): the bias lives in the last dimension.
        return (
            F.leaky_relu(
                input + bias.view(1, *rest_dim, bias.shape[0]),
                negative_slope=negative_slope,
            )
            * scale
        )
    else:
        # (batch, channels, ...): the bias lives in the channel dimension.
        return (
            F.leaky_relu(
                input + bias.view(1, bias.shape[0], *rest_dim),
                negative_slope=negative_slope,
            )
            * scale
        )
class PixelNorm(nn.Module):
    """Normalize each latent vector to unit average magnitude along dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
class EqualLinear(nn.Module):
    """Linear layer with equalized learning rate (StyleGAN2-style).

    Weights are stored divided by `lr_mul` and rescaled at run time, so the
    layer's effective learning rate is `lr_mul` times the global one.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        # He-style scale applied to the stored weights at every forward pass.
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        bias = self.bias * self.lr_mul if self.bias is not None else None
        if self.activation:
            # The bias is folded into the fused activation rather than the matmul.
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, bias)
        else:
            out = F.linear(input, self.weight * self.scale, bias=bias)
        return out
class Mapper(Module):
    """Four-layer MLP of EqualLinear blocks (preceded by PixelNorm) for one latent group."""

    def __init__(self, latent_dim=512):
        super(Mapper, self).__init__()

        layers = [PixelNorm()]
        for i in range(4):
            layers.append(
                EqualLinear(
                    latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu'
                )
            )

        self.mapping = nn.Sequential(*layers)

    def forward(self, x):
        x = self.mapping(x)
        return x
class LevelsMapper(Module):
    """Runs a separate Mapper on the coarse, medium, and fine layers of a W+ latent."""

    def __init__(self, opts):
        super(LevelsMapper, self).__init__()

        self.opts = opts

        if not opts.no_coarse_mapper:
            # Spelled 'course_mapping' in the source; kept so saved state_dicts still load.
            self.course_mapping = Mapper()
        if not opts.no_medium_mapper:
            self.medium_mapping = Mapper()
        if not opts.no_fine_mapper:
            self.fine_mapping = Mapper()

    def forward(self, x):
        # Split the W+ latent (batch, num_layers, 512) into coarse / medium / fine groups.
        x_coarse = x[:, :4, :]
        x_medium = x[:, 4:8, :]
        x_fine = x[:, 8:, :]

        if not self.opts.no_coarse_mapper:
            x_coarse = self.course_mapping(x_coarse)
        else:
            x_coarse = torch.zeros_like(x_coarse)

        if not self.opts.no_medium_mapper:
            x_medium = self.medium_mapping(x_medium)
        else:
            x_medium = torch.zeros_like(x_medium)

        if not self.opts.no_fine_mapper:
            x_fine = self.fine_mapping(x_fine)
        else:
            x_fine = torch.zeros_like(x_fine)

        return torch.cat([x_coarse, x_medium, x_fine], dim=1)
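
# Minimal usage sketch. Assumptions (not defined in this file): an 18-layer W+
# latent as produced by StyleGAN2 at 1024x1024 resolution, and a
# types.SimpleNamespace standing in for the real `opts` object.
if __name__ == "__main__":
    from types import SimpleNamespace

    opts = SimpleNamespace(
        no_coarse_mapper=False, no_medium_mapper=False, no_fine_mapper=False
    )
    mapper = LevelsMapper(opts)

    w_plus = torch.randn(2, 18, 512)  # (batch, num_layers, latent_dim)
    delta = mapper(w_plus)            # per-layer offsets, same shape as the input
    print(delta.shape)                # torch.Size([2, 18, 512])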