Spaces:
Runtime error
Runtime error
File size: 3,194 Bytes
6709fc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
from torch import nn
from torch.nn import Module
from torch.nn import functional as F
import math
def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    """Add a bias to ``input``, apply leaky ReLU, and multiply by ``scale``.

    The bias is reshaped so that it broadcasts against ``input``:

    * when ``input`` has 3 dimensions, the bias runs along the last dimension;
    * for any other rank, the bias runs along dimension 1.

    Args:
        input: Tensor to activate.
        bias: Tensor whose first dimension matches the biased dimension of
            ``input`` (presumably 1-D; only ``bias.shape[0]`` is used when
            reshaping), or ``None`` to skip the bias addition.
        negative_slope: Slope of the leaky ReLU for negative values.
        scale: Factor the activated tensor is multiplied by.

    Returns:
        ``leaky_relu(input + bias) * scale`` (without the bias term when
        ``bias`` is ``None``).
    """
    if bias is None:
        # No bias to broadcast: activation and rescale only.
        return F.leaky_relu(input, negative_slope=negative_slope) * scale

    # Singleton dimensions needed to pad the bias up to the rank of ``input``.
    rest_dim = [1] * (input.ndim - bias.ndim - 1)
    if input.ndim == 3:
        bias_shape = (1, *rest_dim, bias.shape[0])
    else:
        bias_shape = (1, bias.shape[0], *rest_dim)

    biased = input + bias.view(bias_shape)
    return F.leaky_relu(biased, negative_slope=negative_slope) * scale
class PixelNorm(nn.Module):
    """Rescale the input by the reciprocal root-mean-square taken over dim 1."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` divided by ``sqrt(mean(input ** 2, dim=1) + 1e-8)``."""
        mean_square = torch.mean(input ** 2, dim=1, keepdim=True)
        # The small constant keeps rsqrt finite when the mean square is zero.
        inv_rms = torch.rsqrt(mean_square + 1e-8)
        return input * inv_rms
class EqualLinear(nn.Module):
    """Linear layer whose weight and bias are rescaled on every forward pass.

    The weight is initialised from a standard normal distribution divided by
    ``lr_mul`` and multiplied at call time by ``lr_mul / sqrt(in_dim)``; the
    bias, when present, is multiplied at call time by ``lr_mul``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        bias: Whether to create a learnable bias. When false, ``self.bias``
            is ``None`` and no bias is applied.
        bias_init: Value every bias element is initialised to.
        lr_mul: Multiplier applied to the weight scale and to the bias.
        activation: Any truthy value enables ``fused_leaky_relu`` after the
            linear map; the value itself is not inspected.
    """

    def __init__(
        self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None
        self.activation = activation
        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        """Apply the scaled linear map and, if enabled, the fused activation."""
        weight = self.weight * self.scale
        # Guard the bias-free configuration: ``None * lr_mul`` would raise.
        bias = None if self.bias is None else self.bias * self.lr_mul

        if not self.activation:
            return F.linear(input, weight, bias=bias)

        out = F.linear(input, weight)
        if bias is None:
            # A zero bias leaves the activation input unchanged while still
            # satisfying fused_leaky_relu's need for a bias tensor.
            bias = out.new_zeros(out.shape[-1])
        return fused_leaky_relu(out, bias)
class Mapper(Module):
    """``PixelNorm`` followed by four ``EqualLinear`` layers of equal width.

    Args:
        latent_dim: Input and output width of every linear layer.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        # Four identical square layers, each with lr_mul=0.01 and the
        # activation branch of EqualLinear enabled.
        linear_stack = [
            EqualLinear(latent_dim, latent_dim, lr_mul=0.01, activation='fused_lrelu')
            for _ in range(4)
        ]
        self.mapping = nn.Sequential(PixelNorm(), *linear_stack)

    def forward(self, x):
        """Run ``x`` through the normalisation and the linear stack."""
        return self.mapping(x)
class LevelsMapper(Module):
    """Apply a separate ``Mapper`` to three consecutive slices along dim 1.

    The input is split along dimension 1 into indices ``[:4]`` (coarse),
    ``[4:8]`` (medium) and ``[8:]`` (fine). Each slice is passed through its
    own ``Mapper`` unless the matching ``opts.no_*_mapper`` flag is truthy,
    in which case that slice is replaced by zeros. The results are
    concatenated back along dimension 1.

    Args:
        opts: Object exposing ``no_coarse_mapper``, ``no_medium_mapper`` and
            ``no_fine_mapper`` attributes.
    """

    # (opts flag disabling the level, attribute holding its Mapper, dim-1 slice)
    # The "course_mapping" spelling is kept on purpose: it is the registered
    # submodule name, so renaming it would change the module's parameter names.
    _LEVELS = (
        ("no_coarse_mapper", "course_mapping", slice(None, 4)),
        ("no_medium_mapper", "medium_mapping", slice(4, 8)),
        ("no_fine_mapper", "fine_mapping", slice(8, None)),
    )

    def __init__(self, opts):
        super().__init__()
        self.opts = opts
        for disable_flag, attr_name, _ in self._LEVELS:
            if not getattr(opts, disable_flag):
                setattr(self, attr_name, Mapper())

    def forward(self, x):
        """Map each level of ``x`` (or zero it) and concatenate along dim 1."""
        pieces = []
        for disable_flag, attr_name, span in self._LEVELS:
            chunk = x[:, span, :]
            if getattr(self.opts, disable_flag):
                pieces.append(torch.zeros_like(chunk))
            else:
                pieces.append(getattr(self, attr_name)(chunk))
        return torch.cat(pieces, dim=1)