import warnings

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from . import models, utils
from .external_models import lpips


class Projector(nn.Module):
    """
    Projects data to latent space and noise tensors.
    Arguments:
        G (Generator)
        dlatent_avg_samples (int): Number of dlatent samples
            to collect to find the mean and std.
            Default value is 10 000.
        dlatent_avg_label (int, torch.Tensor, optional): The label to
            use when gathering dlatent statistics.
        dlatent_device (int, str, torch.device, optional): Device to use
            for gathering statistics of dlatents. By default uses
            the same device that the parameters of `G` reside on.
        dlatent_batch_size (int): The batch size to sample
            dlatents with. Default value is 1024.
        lpips_model (nn.Module): A model that returns the feature distance
            between two inputs. Default value is the LPIPS VGG16 model.
        lpips_size (int, optional): Resize any data fed to `lpips_model` by scaling
            the data so that its smallest side is the same size as this
            argument. Only has a default value of 256 if `lpips_model` is unspecified.
        verbose (bool): Write progress of dlatent statistics gathering to stdout.
            Default value is True.
    """
    def __init__(self,
                 G,
                 dlatent_avg_samples=10000,
                 dlatent_avg_label=None,
                 dlatent_device=None,
                 dlatent_batch_size=1024,
                 lpips_model=None,
                 lpips_size=None,
                 verbose=True):
        super(Projector, self).__init__()
        assert isinstance(G, models.Generator)
        G.eval().requires_grad_(False)
        self.G_synthesis = G.G_synthesis
        G_mapping = G.G_mapping
        dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)
        if dlatent_device is None:
            # `device` is an attribute of the parameter tensor, not a method.
            dlatent_device = next(G_mapping.parameters()).device
        else:
            dlatent_device = torch.device(dlatent_device)
        G_mapping.to(dlatent_device)
        latents = torch.empty(
            dlatent_avg_samples, G_mapping.latent_size).normal_()
        dlatents = []
        labels = None
        if dlatent_avg_label is not None:
            labels = torch.tensor(dlatent_avg_label).to(dlatent_device) \
                .long().view(-1).repeat(dlatent_batch_size)
        if verbose:
            progress = utils.ProgressWriter(
                np.ceil(dlatent_avg_samples / dlatent_batch_size))
            progress.write('Gathering dlatents...', step=False)
        for i in range(0, dlatent_avg_samples, dlatent_batch_size):
            batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
            batch_labels = None
            if labels is not None:
                batch_labels = labels[:len(batch_latents)]
            with torch.no_grad():
                dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
            if verbose:
                progress.step()
        if verbose:
            progress.write('Done!', step=False)
            progress.close()
        dlatents = torch.cat(dlatents, dim=0)
        self.register_buffer(
            '_dlatent_avg',
            dlatents.mean(dim=0).view(1, 1, -1)
        )
        self.register_buffer(
            '_dlatent_std',
            torch.sqrt(
                torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
            ).view(1, 1, 1)
        )
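        # NOTE: `_dlatent_avg` and `_dlatent_std` are registered as buffers so
        # that they follow the module between devices. The average is used in
        # `start()` to initialize the trainable dlatent parameter, and the
        # (scalar) standard deviation scales the random perturbation that
        # `step()` adds to the dlatents during optimization.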
        if lpips_model is None:
            warnings.warn(
                'Using default LPIPS distance metric based on VGG 16. '
                'This metric will only work on image data where values are in '
                'the range [-1, 1]; please specify an lpips module if you want '
                'to use other kinds of data formats.'
            )
            lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
            lpips_size = 256
        self.lpips_model = lpips_model.eval().requires_grad_(False)
        self.lpips_size = lpips_size
        # No projection job exists until `start()` has been called.
        self._job = None
        self.to(dlatent_device)

    def _scale_for_lpips(self, data):
        if not self.lpips_size:
            return data
        scale_factor = self.lpips_size / min(data.size()[2:])
        if scale_factor == 1:
            return data
        # Use area (average) interpolation when downscaling and nearest
        # neighbor interpolation when upscaling.
        mode = 'nearest'
        if scale_factor < 1:
            mode = 'area'
        return F.interpolate(data, scale_factor=scale_factor, mode=mode)

    def _check_job(self):
        assert self._job is not None, 'Call `start()` first to set up target.'
        # The device of the dlatent param will not change with the rest of the
        # models and buffers of this class, as it was never registered as a
        # buffer or parameter. The same goes for the optimizer. Make sure both
        # are on the correct device.
        if self._job.dlatent_param.device != self._dlatent_avg.device:
            self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
            self._job.opt.load_state_dict(
                utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])

    def generate(self):
        """
        Generate an output with the current dlatent and noise values.
        Returns:
            output (torch.Tensor)
        """
        self._check_job()
        with torch.no_grad():
            return self.G_synthesis(self._job.dlatent_param)

    def get_dlatent(self):
        """
        Get a copy of the current dlatent values.
        Returns:
            dlatents (torch.Tensor)
        """
        self._check_job()
        return self._job.dlatent_param.data.clone()

    def get_noise(self):
        """
        Get a copy of the current noise values.
        Returns:
            noise_tensors (list)
        """
        self._check_job()
        return [noise.data.clone() for noise in self._job.noise_params]

    def start(self,
              target,
              num_steps=1000,
              initial_learning_rate=0.1,
              initial_noise_factor=0.05,
              lr_rampdown_length=0.25,
              lr_rampup_length=0.05,
              noise_ramp_length=0.75,
              regularize_noise_weight=1e5,
              verbose=True,
              verbose_prefix=''):
        """
        Set up a target and its projection parameters.
        Arguments:
            target (torch.Tensor): The data target. This should
                already be preprocessed (scaled to the correct value range).
            num_steps (int): Number of optimization steps. Default
                value is 1000.
            initial_learning_rate (float): Default value is 0.1.
            initial_noise_factor (float): Default value is 0.05.
            lr_rampdown_length (float): Default value is 0.25.
            lr_rampup_length (float): Default value is 0.05.
            noise_ramp_length (float): Default value is 0.75.
            regularize_noise_weight (float): Default value is 1e5.
            verbose (bool): Write progress to stdout every time
                `step()` is called. Default value is True.
            verbose_prefix (str, optional): This is written before
                any other output to stdout.
        """
        if target.dim() == self.G_synthesis.dim + 1:
            # Add a batch dimension for a single data point.
            target = target.unsqueeze(0)
        assert target.dim() == self.G_synthesis.dim + 2, \
            'Number of dimensions of target data is incorrect.'
        target = target.to(self._dlatent_avg)
        target_scaled = self._scale_for_lpips(target)
        # Initialize the trainable dlatents at the average dlatent,
        # one dlatent per synthesis layer.
        dlatent_param = nn.Parameter(
            self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
        noise_params = self.G_synthesis.static_noise(trainable=True)
        params = [dlatent_param] + noise_params
        opt = torch.optim.Adam(params)
        noise_tensor = torch.empty_like(dlatent_param)
        if verbose:
            progress = utils.ProgressWriter(num_steps)
            value_tracker = utils.ValueTracker()
        # Capture all local variables (target, parameters, optimizer,
        # hyperparameters) as the state of this projection job.
        self._job = utils.AttributeDict(**locals())
        self._job.current_step = 0
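
    # The optimization schedule implemented in `step()` below: the learning
    # rate ramps up linearly over the first `lr_rampup_length` fraction of the
    # run and follows a cosine rampdown over the final `lr_rampdown_length`
    # fraction, while random noise (scaled by the gathered dlatent std) is
    # added to the dlatents and decayed to zero over the first
    # `noise_ramp_length` fraction of the run. The loss is the LPIPS distance
    # to the target plus a noise autocorrelation regularizer weighted by
    # `regularize_noise_weight`.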
    def step(self, steps=1):
        """
        Take a projection step.
        Arguments:
            steps (int): Number of steps to take. If this
                exceeds the remaining steps of the projection,
                that number of steps is taken instead. Default
                value is 1.
        """
        self._check_job()
        remaining_steps = self._job.num_steps - self._job.current_step
        if not remaining_steps > 0:
            warnings.warn(
                'Trying to take a projection step after the '
                'final projection iteration has been completed.'
            )
        if steps < 0:
            steps = remaining_steps
        steps = min(remaining_steps, steps)
        if not steps > 0:
            return
        for _ in range(steps):
            if self._job.current_step >= self._job.num_steps:
                break

            # Hyperparameters.
            t = self._job.current_step / self._job.num_steps
            noise_strength = self._dlatent_std * self._job.initial_noise_factor \
                * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
            lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
            lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
            lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
            learning_rate = self._job.initial_learning_rate * lr_ramp
            for param_group in self._job.opt.param_groups:
                param_group['lr'] = learning_rate

            dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
            output = self.G_synthesis(dlatents)
            assert output.size() == self._job.target.size(), \
                'target size {} does not fit output size {} of generator'.format(
                    self._job.target.size(), output.size())
            output_scaled = self._scale_for_lpips(output)

            # Main loss: LPIPS distance between output and target.
            lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled))

            # Noise regularization loss: penalize spatial autocorrelation of
            # the noise tensors at multiple scales.
            reg_loss = 0
            for p in self._job.noise_params:
                size = min(p.size()[2:])
                dim = p.dim() - 2
                while True:
                    reg_loss += torch.mean(
                        (p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
                    if size <= 8:
                        break
                    p = F.interpolate(p, scale_factor=0.5, mode='area')
                    size = size // 2

            # Combine losses, backpropagate and update params.
            loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
            self._job.opt.zero_grad()
            loss.backward()
            self._job.opt.step()

            # Normalize noise values to zero mean and unit variance.
            for p in self._job.noise_params:
                with torch.no_grad():
                    p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
                    p_rstd = torch.rsqrt(
                        torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
                    p.data = (p.data - p_mean) * p_rstd

            self._job.current_step += 1

            if self._job.verbose:
                self._job.value_tracker.add('loss', float(loss))
                self._job.value_tracker.add('lpips_distance', float(lpips_distance))
                self._job.value_tracker.add('noise_reg', float(reg_loss))
                self._job.value_tracker.add('lr', learning_rate, beta=0)
                self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker))

        # `progress` only exists when the job was started with `verbose=True`.
        if self._job.verbose and self._job.current_step >= self._job.num_steps:
            self._job.progress.close()
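

# A minimal usage sketch (assumptions: `G` is a trained `models.Generator`
# loaded elsewhere, and `target` is an image tensor already scaled to
# [-1, 1]; both names are hypothetical here):
#
#     projector = Projector(G)
#     projector.start(target, num_steps=1000)
#     projector.step(steps=-1)             # steps < 0 runs all remaining steps
#     result = projector.generate()        # reconstruction of `target`
#     dlatents = projector.get_dlatent()   # projected dlatent codes
#     noise = projector.get_noise()        # optimized noise tensors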