import torch
import numpy as np

from torch_utils.ops import bias_act
from torch_utils import misc


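# Rescale x so that its second moment along `dim` is 1 (pixel-norm style normalization,
# as used in the StyleGAN2 mapping network).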
def normalize_2nd_moment(x, dim=1, eps=1e-8):
    return x * (x.square().mean(dim=dim, keepdim=True) + eps).rsqrt()


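# Baseline fully-connected layer: a thin wrapper around torch.nn.Linear with a
# configurable constant bias initialization.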
class FullyConnectedLayer_normal(torch.nn.Module):
    def __init__(self,
        in_features,
        out_features,
        bias      = True,
        bias_init = 0,
    ):
        super().__init__()
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)
        if bias:
            with torch.no_grad():
                self.fc.bias.fill_(bias_init)

    def forward(self, x):
        output = self.fc(x)
        return output


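# Baseline mapping network: a stack of num_layers Linear + LeakyReLU(0.2) blocks,
# with optional second-moment normalization of the input.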
class MappingNetwork_normal(torch.nn.Module):
    def __init__(self,
        in_features,
        int_dim,
        num_layers            = 8,
        mapping_normalization = False,
    ):
        super().__init__()
        layers = [torch.nn.Linear(in_features, int_dim), torch.nn.LeakyReLU(0.2)]
        for i in range(1, num_layers):
            layers.append(torch.nn.Linear(int_dim, int_dim))
            layers.append(torch.nn.LeakyReLU(0.2))

        self.net = torch.nn.Sequential(*layers)
        self.normalization = mapping_normalization

    def forward(self, x):
        if self.normalization:
            x = normalize_2nd_moment(x)
        output = self.net(x)
        return output


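# MLP decoder: L2-normalizes its input, then applies (num_layers - 1) Linear + ReLU
# blocks followed by a final Linear projection to out_dim.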
class DecodingNetwork(torch.nn.Module):
    def __init__(self,
        in_features,
        out_dim,
        num_layers = 8,
    ):
        super().__init__()
        layers = []
        for i in range(num_layers - 1):
            layers.append(torch.nn.Linear(in_features, in_features))
            layers.append(torch.nn.ReLU())

        layers.append(torch.nn.Linear(in_features, out_dim))

        self.net = torch.nn.Sequential(*layers)

    def forward(self, x):
        x = torch.nn.functional.normalize(x, dim=1)
        output = self.net(x)
        return output


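# The two modules below follow the StyleGAN2-ADA "equalized learning rate" convention:
# weights are stored scaled down by lr_multiplier and rescaled at runtime by
# lr_multiplier / sqrt(in_features), so the effective learning rate of each layer can
# be tuned (e.g. lr_multiplier = 0.01 in the mapping network) without changing the
# initialization.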
class FullyConnectedLayer(torch.nn.Module):
    def __init__(self,
        in_features,                # Number of input features.
        out_features,               # Number of output features.
        bias            = True,     # Apply additive bias before the activation function?
        activation      = 'linear', # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 1,        # Learning rate multiplier.
        bias_init       = 0,        # Initial value for the additive bias.
    ):
        super().__init__()
        self.activation = activation
        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
        self.weight_gain = lr_multiplier / np.sqrt(in_features)
        self.bias_gain = lr_multiplier

    def forward(self, x):
        w = self.weight.to(x.dtype) * self.weight_gain
        b = self.bias
        if b is not None:
            b = b.to(x.dtype)
            if self.bias_gain != 1:
                b = b * self.bias_gain

        if self.activation == 'linear' and b is not None:
            # Fused matmul + bias for the common linear case.
            x = torch.addmm(b.unsqueeze(0), x, w.t())
        else:
            x = x.matmul(w.t())
            x = bias_act.bias_act(x, b, act=self.activation)
        return x


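# Mapping network adapted from StyleGAN2-ADA: maps a latent code z to an intermediate
# latent w, optionally normalizing z to unit second moment first, tracks a moving
# average of w for truncation, and broadcasts w to num_ws copies for the synthesis
# layers. Class conditioning (c_dim > 0) is not supported and raises in forward().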
class MappingNetwork(torch.nn.Module):
    def __init__(self,
        z_dim,                      # Input latent (Z) dimensionality, 0 = no latent.
        c_dim,                      # Conditioning label (C) dimensionality, 0 = no label.
        w_dim,                      # Intermediate latent (W) dimensionality.
        num_ws,                     # Number of intermediate latents to output, None = do not broadcast.
        num_layers      = 8,        # Number of mapping layers.
        embed_features  = None,     # Label embedding dimensionality, None = same as w_dim.
        layer_features  = None,     # Number of intermediate features in the mapping layers, None = same as w_dim.
        activation      = 'lrelu',  # Activation function: 'relu', 'lrelu', etc.
        lr_multiplier   = 0.01,     # Learning rate multiplier for the mapping layers.
        w_avg_beta      = 0.995,    # Decay for tracking the moving average of W during training, None = do not track.
        normalization   = None,     # Normalize z to unit second moment before the mapping layers?
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        self.normalization = normalization

        if embed_features is None:
            embed_features = w_dim
        if c_dim == 0:
            embed_features = 0
        if layer_features is None:
            layer_features = w_dim
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        if c_dim > 0:
            self.embed = FullyConnectedLayer(c_dim, embed_features)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = FullyConnectedLayer(in_features, out_features, activation=activation, lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)

        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False):
        # Prepare the input latent.
        x = None
        with torch.autograd.profiler.record_function('input'):
            if self.z_dim > 0:
                misc.assert_shape(z, [None, self.z_dim])
                if self.normalization:
                    x = normalize_2nd_moment(z.to(torch.float32))
                else:
                    x = z.to(torch.float32)
            if self.c_dim > 0:
                # Class conditioning is not supported; the path below is unreachable.
                raise ValueError("This implementation does not need class index")
                misc.assert_shape(c, [None, self.c_dim])
                y = self.embed(c.to(torch.float32))
                x = torch.cat([x, y], dim=1) if x is not None else y

        # Main mapping layers.
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # Update the moving average of W.
        if self.w_avg_beta is not None and self.training and not skip_w_avg_update:
            with torch.autograd.profiler.record_function('update_w_avg'):
                self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))

        # Broadcast to one W per synthesis layer.
        if self.num_ws is not None:
            with torch.autograd.profiler.record_function('broadcast'):
                x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # Apply truncation toward the moving average.
        if truncation_psi != 1:
            with torch.autograd.profiler.record_function('truncate'):
                assert self.w_avg_beta is not None
                if self.num_ws is None or truncation_cutoff is None:
                    x = self.w_avg.lerp(x, truncation_psi)
                else:
                    x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x
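
# Minimal usage sketch (illustrative only): assumes the StyleGAN2-ADA `torch_utils`
# package is importable, since `bias_act` and `misc` are required above. The dimensions
# below are arbitrary and only serve to check output shapes.
if __name__ == '__main__':
    # Unconditional mapping from z to broadcast, truncated w latents.
    mapping = MappingNetwork(z_dim=512, c_dim=0, w_dim=512, num_ws=14)
    z = torch.randn(4, 512)
    w = mapping(z, truncation_psi=0.7)      # -> [4, 14, 512]
    print('MappingNetwork output:', tuple(w.shape))

    # Baseline mapper followed by the decoder.
    mapper = MappingNetwork_normal(in_features=512, int_dim=256, num_layers=4)
    decoder = DecodingNetwork(in_features=256, out_dim=128, num_layers=4)
    h = mapper(torch.randn(4, 512))         # -> [4, 256]
    out = decoder(h)                        # -> [4, 128]
    print('DecodingNetwork output:', tuple(out.shape))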