import warnings
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from . import models, utils
from .external_models import lpips
class Projector(nn.Module):
"""
Projects data to latent space and noise tensors.
Arguments:
G (Generator)
        dlatent_avg_samples (int): Number of dlatent samples
            to draw when estimating the dlatent mean and std.
            Default value is 10,000.
dlatent_avg_label (int, torch.Tensor, optional): The label to
use when gathering dlatent statistics.
dlatent_device (int, str, torch.device, optional): Device to use
for gathering statistics of dlatents. By default uses
the same device as parameters of `G` reside on.
dlatent_batch_size (int): The batch size to sample
dlatents with. Default value is 1024.
        lpips_model (nn.Module): A model that returns the feature
            distance between two inputs. Default value is the
            LPIPS VGG16 model.
        lpips_size (int, optional): Resize any data fed to `lpips_model`
            by scaling it so that its smallest side matches this
            argument. Defaults to 256 only if `lpips_model`
            is unspecified.
verbose (bool): Write progress of dlatent statistics gathering to stdout.
Default value is True.
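
    Example (a minimal sketch, assuming a pre-trained `models.Generator`
    instance `G` and a target image tensor `target` already scaled to
    the value range expected by the LPIPS model, [-1, 1] by default):
        >>> projector = Projector(G)
        >>> projector.start(target, num_steps=1000)
        >>> projector.step(1000)
        >>> result = projector.generate()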
"""
def __init__(self,
G,
dlatent_avg_samples=10000,
dlatent_avg_label=None,
dlatent_device=None,
dlatent_batch_size=1024,
lpips_model=None,
lpips_size=None,
verbose=True):
super(Projector, self).__init__()
assert isinstance(G, models.Generator)
G.eval().requires_grad_(False)
self.G_synthesis = G.G_synthesis
G_mapping = G.G_mapping
dlatent_batch_size = min(dlatent_batch_size, dlatent_avg_samples)
if dlatent_device is None:
            dlatent_device = next(G_mapping.parameters()).device
else:
dlatent_device = torch.device(dlatent_device)
G_mapping.to(dlatent_device)
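        # Estimate the center and spread of W space: sample random
        # latents, map them through `G_mapping` in batches, and compute
        # the mean and std of the resulting dlatents. These statistics
        # initialize the projection and scale its per-step noise.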
latents = torch.empty(
dlatent_avg_samples, G_mapping.latent_size).normal_()
dlatents = []
labels = None
if dlatent_avg_label is not None:
labels = torch.tensor(dlatent_avg_label).to(dlatent_device).long().view(-1).repeat(dlatent_batch_size)
if verbose:
progress = utils.ProgressWriter(np.ceil(dlatent_avg_samples / dlatent_batch_size))
progress.write('Gathering dlatents...', step=False)
for i in range(0, dlatent_avg_samples, dlatent_batch_size):
batch_latents = latents[i: i + dlatent_batch_size].to(dlatent_device)
batch_labels = None
if labels is not None:
batch_labels = labels[:len(batch_latents)]
with torch.no_grad():
dlatents.append(G_mapping(batch_latents, labels=batch_labels).cpu())
if verbose:
progress.step()
if verbose:
progress.write('Done!', step=False)
progress.close()
dlatents = torch.cat(dlatents, dim=0)
self.register_buffer(
'_dlatent_avg',
dlatents.mean(dim=0).view(1, 1, -1)
)
self.register_buffer(
'_dlatent_std',
torch.sqrt(
torch.sum((dlatents - self._dlatent_avg) ** 2) / dlatent_avg_samples + 1e-8
).view(1, 1, 1)
)
if lpips_model is None:
warnings.warn(
                'Using default LPIPS distance metric based on VGG16. ' + \
                'This metric only works on image data with values in ' + \
                'the range [-1, 1]. Please specify an `lpips_model` if ' + \
                'you want to use other kinds of data formats.'
)
lpips_model = lpips.LPIPS_VGG16(pixel_min=-1, pixel_max=1)
lpips_size = 256
self.lpips_model = lpips_model.eval().requires_grad_(False)
self.lpips_size = lpips_size
self.to(dlatent_device)
def _scale_for_lpips(self, data):
if not self.lpips_size:
return data
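        # Rescale so that the smallest spatial side matches
        # `self.lpips_size`, using 'area' interpolation when shrinking
        # and 'nearest' when enlarging, keeping the LPIPS input size
        # roughly constant regardless of the generator resolution.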
scale_factor = self.lpips_size / min(data.size()[2:])
if scale_factor == 1:
return data
mode = 'nearest'
if scale_factor < 1:
mode = 'area'
return F.interpolate(data, scale_factor=scale_factor, mode=mode)
def _check_job(self):
assert self._job is not None, 'Call `start()` first to set up target.'
        # The dlatent parameter was never registered as a parameter or
        # buffer of this module, so it does not move with the rest of
        # the models and buffers of this class. The same goes for the
        # optimizer state. Make sure both are on the correct device.
if self._job.dlatent_param.device != self._dlatent_avg.device:
self._job.dlatent_param = self._job.dlatent_param.to(self._dlatent_avg)
self._job.opt.load_state_dict(
utils.move_to_device(self._job.opt.state_dict(), self._dlatent_avg.device)[0])
def generate(self):
"""
Generate an output with the current dlatent and noise values.
Returns:
output (torch.Tensor)
"""
self._check_job()
with torch.no_grad():
return self.G_synthesis(self._job.dlatent_param)
def get_dlatent(self):
"""
Get a copy of the current dlatent values.
Returns:
dlatents (torch.Tensor)
"""
self._check_job()
return self._job.dlatent_param.data.clone()
def get_noise(self):
"""
Get a copy of the current noise values.
Returns:
noise_tensors (list)
"""
self._check_job()
return [noise.data.clone() for noise in self._job.noise_params]
def start(self,
target,
num_steps=1000,
initial_learning_rate=0.1,
initial_noise_factor=0.05,
lr_rampdown_length=0.25,
lr_rampup_length=0.05,
noise_ramp_length=0.75,
regularize_noise_weight=1e5,
verbose=True,
verbose_prefix=''):
"""
Set up a target and its projection parameters.
Arguments:
target (torch.Tensor): The data target. This should
already be preprocessed (scaled to correct value range).
num_steps (int): Number of optimization steps. Default
value is 1000.
initial_learning_rate (float): Default value is 0.1.
initial_noise_factor (float): Default value is 0.05.
lr_rampdown_length (float): Default value is 0.25.
lr_rampup_length (float): Default value is 0.05.
noise_ramp_length (float): Default value is 0.75.
regularize_noise_weight (float): Default value is 1e5.
verbose (bool): Write progress to stdout every time
`step()` is called.
verbose_prefix (str, optional): This is written before
any other output to stdout.
"""
if target.dim() == self.G_synthesis.dim + 1:
target = target.unsqueeze(0)
assert target.dim() == self.G_synthesis.dim + 2, \
'Number of dimensions of target data is incorrect.'
target = target.to(self._dlatent_avg)
target_scaled = self._scale_for_lpips(target)
dlatent_param = nn.Parameter(
self._dlatent_avg.clone().repeat(target.size(0), len(self.G_synthesis), 1))
noise_params = self.G_synthesis.static_noise(trainable=True)
params = [dlatent_param] + noise_params
opt = torch.optim.Adam(params)
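        # Pre-allocated tensor that is refilled with fresh gaussian
        # noise on every step and added to the dlatent parameter while
        # the noise strength ramps down (see `step()`).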
noise_tensor = torch.empty_like(dlatent_param)
if verbose:
progress = utils.ProgressWriter(num_steps)
value_tracker = utils.ValueTracker()
self._job = utils.AttributeDict(**locals())
self._job.current_step = 0
def step(self, steps=1):
"""
Take a projection step.
Arguments:
            steps (int): Number of steps to take. If this
                exceeds the number of remaining steps of the
                projection, only the remaining steps are taken.
                A negative value runs all remaining steps.
                Default value is 1.
"""
self._check_job()
remaining_steps = self._job.num_steps - self._job.current_step
if not remaining_steps > 0:
warnings.warn(
'Trying to take a projection step after the ' + \
'final projection iteration has been completed.'
)
if steps < 0:
steps = remaining_steps
steps = min(remaining_steps, steps)
if not steps > 0:
return
for _ in range(steps):
if self._job.current_step >= self._job.num_steps:
break
# Hyperparameters.
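            # The added dlatent noise anneals quadratically to zero over
            # the first `noise_ramp_length` fraction of the run. The
            # learning rate ramps up linearly over `lr_rampup_length`
            # and decays with a cosine ramp over the final
            # `lr_rampdown_length` fraction.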
t = self._job.current_step / self._job.num_steps
noise_strength = self._dlatent_std * self._job.initial_noise_factor \
* max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
learning_rate = self._job.initial_learning_rate * lr_ramp
for param_group in self._job.opt.param_groups:
param_group['lr'] = learning_rate
dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_()
output = self.G_synthesis(dlatents)
            assert output.size() == self._job.target.size(), \
                'target size {} does not fit output size {} of generator'.format(
                    self._job.target.size(), output.size())
output_scaled = self._scale_for_lpips(output)
# Main loss: LPIPS distance of output and target
lpips_distance = torch.mean(self.lpips_model(output_scaled, self._job.target_scaled))
# Calculate noise regularization loss
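            # Penalize spatial autocorrelation of the noise maps: each
            # map is multiplied elementwise with a rolled copy of itself,
            # then repeatedly downsampled by a factor of 2 (until the
            # smallest side is 8) so that correlation is punished at
            # every scale. This pushes the learned noise towards i.i.d.
            # gaussian noise instead of encoding image content.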
reg_loss = 0
for p in self._job.noise_params:
size = min(p.size()[2:])
dim = p.dim() - 2
while True:
reg_loss += torch.mean(
(p * p.roll(shifts=[1] * dim, dims=list(range(2, 2 + dim)))) ** 2)
if size <= 8:
break
p = F.interpolate(p, scale_factor=0.5, mode='area')
size = size // 2
# Combine loss, backward and update params
loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
self._job.opt.zero_grad()
loss.backward()
self._job.opt.step()
# Normalize noise values
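            # Re-standardize each noise map to zero mean and unit
            # variance after the update, matching the noise distribution
            # the synthesis network saw during training.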
for p in self._job.noise_params:
with torch.no_grad():
p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
p_rstd = torch.rsqrt(
torch.mean((p - p_mean) ** 2, dim=list(range(1, p.dim())), keepdim=True) + 1e-8)
p.data = (p.data - p_mean) * p_rstd
self._job.current_step += 1
if self._job.verbose:
self._job.value_tracker.add('loss', float(loss))
self._job.value_tracker.add('lpips_distance', float(lpips_distance))
self._job.value_tracker.add('noise_reg', float(reg_loss))
self._job.value_tracker.add('lr', learning_rate, beta=0)
self._job.progress.write(self._job.verbose_prefix, str(self._job.value_tracker))
if self._job.current_step >= self._job.num_steps:
self._job.progress.close()