import warnings
import functools
import os
import time
import sys
import json
import numpy as np
import torch
import torch.utils.tensorboard
from torch import nn
import torchvision
try:
import apex
from apex import amp
except ImportError:
pass
from . import models, utils, loss_fns
class Trainer:
"""
Class that handles training and logging for stylegan2.
For distributed training, the arguments `rank`, `world_size`,
    `master_addr`, `master_port` can all be given as environment variables
    (the only difference is that the keys should be upper case).
    Environment variables, if available, will override any python
    value given for the same argument.
Arguments:
G (Generator): The generator model.
D (Discriminator): The discriminator model.
latent_size (int): The size of the latent inputs.
dataset (indexable object): The dataset. Has to implement
'__getitem__' and '__len__'. If `label_size` > 0, this
dataset object has to return both a data entry and its
label when calling '__getitem__'.
device (str, int, list, torch.device): The device to run training on.
Can be a list of integers for parallel training in the same
process. Parallel training can also be achieved by spawning
            separate processes and using the `rank` argument for each
process. In that case, only one device should be specified
per process.
Gs (Generator, optional): A generator copy with the current
moving average of the training generator. If not specified,
a copy of the generator is made for the moving average of
weights.
Gs_beta (float): The beta value for the moving average weights.
            Default value is 0.5 ** (32 / 10000).
Gs_device (str, int, torch.device, optional): The device to store
the moving average weights on. If using a different device
than what is specified for the `device` argument, updating
the moving average weights will take longer as the data
            will have to be transferred between devices. If
this argument is not specified, the same device is used
as specified in the `device` argument.
batch_size (int): The total batch size to average gradients
over. This should be the combined batch size of all used
devices (it is later divided by world size for distributed
training).
Example: We want to average gradients over 32 data
entries. To do this we just set `batch_size=32`.
Even if we train on 8 GPUs we still use the same
batch size (each GPU will take 4 data entries per
batch).
Default value is 32.
device_batch_size (int): The number of data entries that can
fit on the specified device at a time.
Example: We want to average gradients over 32 data
entries. To do this we just set `batch_size=32`.
However, our device can only handle a batch of
4 at a time before running out of memory. We
                therefore set `device_batch_size=4`. With a
single device (no distributed training), each
batch is split into 32 / 4 parts and gradients
are averaged over all these parts.
Default value is 4.
label_size (int, optional): Number of possible class labels.
This is required for conditioning the GAN with labels.
If not specified it is assumed that no labels are used.
data_workers (int): The number of spawned processes that
handle data loading. Default value is 4.
G_loss (str, callable): The loss function to use
for the generator. If string, it can be one of the
following: 'logistic', 'logistic_ns' or 'wgan'.
If not a string, the callable has to follow
            the format of functions found in `stylegan2.loss_fns`.
Default value is 'logistic_ns' (non-saturating logistic).
D_loss (str, callable): The loss function to use
for the discriminator. If string, it can be one of the
following: 'logistic' or 'wgan'.
            If not a string, the same restriction applies as for `G_loss`.
Default value is 'logistic'.
G_reg (str, callable, None): The regularizer function to use
for the generator. If string, it can only be 'pathreg'
(pathlength regularization). A weight for the regularizer
can be passed after the string name like the following:
G_reg='pathreg:5'
This will assign a weight of 5 to the regularization loss.
            If set to None, no generator regularization is performed.
Default value is 'pathreg:2'.
G_reg_interval (int): The interval at which to regularize the
generator. If set to 0, the regularization and loss gradients
are combined in a single optimization step every iteration.
If set to 1, the gradients for the regularization and loss
are used separately for two optimization steps. Any value
higher than 1 indicates that regularization should only
be performed at this interval (lazy regularization).
Default value is 4.
G_opt_class (str, class): The optimizer class for the generator.
Default value is 'Adam'.
G_opt_kwargs (dict): Keyword arguments for the generator optimizer
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
        G_reg_batch_size (int): Same as `batch_size` but only for
            the regularization loss of the generator. Default value
            is the same as `batch_size`.
        G_reg_device_batch_size (int): Same as `device_batch_size`
            but only for the regularization loss of the generator.
            Default value is the same as `device_batch_size`.
D_reg (str, callable, None): The regularizer function to use
for the discriminator. If string, the following values
            can be used: 'r1', 'r2', 'gp'. See the doc for `G_reg` for
            the rest of the info on the regularizer format.
Default value is 'r1:10'.
        D_reg_interval (int): Same as `G_reg_interval` but for the
discriminator. Default value is 16.
D_opt_class (str, class): The optimizer class for the discriminator.
Default value is 'Adam'.
D_opt_kwargs (dict): Keyword arguments for the discriminator optimizer
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
style_mix_prob (float): The probability of passing 2 latents instead of 1
to the generator during training. Default value is 0.9.
G_iter (int): Number of generator iterations for every full training
iteration. Default value is 1.
D_iter (int): Number of discriminator iterations for every full training
iteration. Default value is 1.
pl_avg (float, torch.Tensor): The average pathlength starting value for
pathlength regularization of the generator. Default value is 0.
tensorboard_log_dir (str, optional): A path to a directory to log training values
in for tensorboard. Only used without distributed training or when
distributed training is enabled and the rank of this trainer is 0.
checkpoint_dir (str, optional): A path to a directory to save training
            checkpoints to. If not specified, no checkpoints are automatically
saved during training.
checkpoint_interval (int): The interval at which to save training checkpoints.
Default value is 10000.
seen (int): The number of previously trained iterations. Used for logging.
Default value is 0.
half (bool): Use mixed precision training. Default value is False.
rank (int, optional): If set, use distributed training. Expects that
this object has been constructed with the same arguments except
for `rank` in different processes.
world_size (int, optional): If using distributed training, this specifies
the number of nodes in the training.
master_addr (str): The master address for distributed training.
Default value is '127.0.0.1'.
master_port (str): The master port for distributed training.
Default value is '23456'.
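    Example:
        A minimal usage sketch (this assumes `G` and `D` have already been
        built with the `models` module and that `dataset` returns image
        tensors; the values shown are illustrative only):
            trainer = Trainer(
                G=G,
                D=D,
                latent_size=512,
                dataset=dataset,
                device=0,
                batch_size=32,
                device_batch_size=4
            )
            trainer.train(iterations=1000)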
"""
def __init__(self,
G,
D,
latent_size,
dataset,
device,
Gs=None,
Gs_beta=0.5 ** (32 / 10000),
Gs_device=None,
batch_size=32,
device_batch_size=4,
label_size=0,
data_workers=4,
G_loss='logistic_ns',
D_loss='logistic',
G_reg='pathreg:2',
G_reg_interval=4,
G_opt_class='Adam',
G_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
G_reg_batch_size=None,
G_reg_device_batch_size=None,
D_reg='r1:10',
D_reg_interval=16,
D_opt_class='Adam',
D_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
style_mix_prob=0.9,
G_iter=1,
D_iter=1,
pl_avg=0.,
tensorboard_log_dir=None,
checkpoint_dir=None,
checkpoint_interval=10000,
seen=0,
half=False,
rank=None,
world_size=None,
master_addr='127.0.0.1',
master_port='23456'):
assert not isinstance(G, nn.parallel.DistributedDataParallel) and \
not isinstance(D, nn.parallel.DistributedDataParallel), \
'Encountered a model wrapped in `DistributedDataParallel`. ' + \
'Distributed parallelism is handled by this class and can ' + \
'not be initialized before.'
# We store the training settings in a dict that can be saved as a json file.
kwargs = locals()
# First we remove the arguments that can not be turned into json.
kwargs.pop('self')
kwargs.pop('G')
kwargs.pop('D')
kwargs.pop('Gs')
kwargs.pop('dataset')
# Some arguments may have to be turned into strings to be compatible with json.
kwargs.update(pl_avg=float(pl_avg))
if isinstance(device, torch.device):
kwargs.update(device=str(device))
if isinstance(Gs_device, torch.device):
            kwargs.update(Gs_device=str(Gs_device))
self.kwargs = kwargs
if device or device == 0:
if isinstance(device, (tuple, list)):
self.device = torch.device(device[0])
else:
self.device = torch.device(device)
else:
self.device = torch.device('cpu')
if self.device.index is not None:
torch.cuda.set_device(self.device.index)
else:
assert not half, 'Mixed precision training only available ' + \
'for CUDA devices.'
# Set up the models
self.G = G.train().to(self.device)
self.D = D.train().to(self.device)
if isinstance(device, (tuple, list)) and len(device) > 1:
assert all(isinstance(dev, int) for dev in device), \
'Multiple devices have to be specified as a list ' + \
'or tuple of integers corresponding to device indices.'
# TODO: Look into bug with torch.autograd.grad and nn.DataParallel
            # In the meantime, just prohibit using them together.
assert G_reg is None and D_reg is None, 'Regularization ' + \
'currently not supported for multi-gpu training in single process. ' + \
'Please use distributed training with one device per process instead.'
device_batch_size *= len(device)
def to_data_parallel(model):
if not isinstance(model, nn.DataParallel):
return nn.DataParallel(model, device_ids=device)
return model
self.G = to_data_parallel(self.G)
self.D = to_data_parallel(self.D)
# Default generator reg batch size is the global batch size
# unless it has been specified otherwise.
G_reg_batch_size = G_reg_batch_size or batch_size
G_reg_device_batch_size = G_reg_device_batch_size or device_batch_size
# Set up distributed training
rank = os.environ.get('RANK', rank)
if rank is not None:
rank = int(rank)
addr = os.environ.get('MASTER_ADDR', master_addr)
port = os.environ.get('MASTER_PORT', master_port)
world_size = os.environ.get('WORLD_SIZE', world_size)
assert world_size is not None, 'Distributed training ' + \
'requires specifying world size.'
world_size = int(world_size)
assert self.device.index is not None, \
'Distributed training is only supported for CUDA.'
assert batch_size % world_size == 0, 'Batch size has to be ' + \
'evenly divisible by world size.'
assert G_reg_batch_size % world_size == 0, 'G reg batch size has to be ' + \
'evenly divisible by world size.'
batch_size = batch_size // world_size
G_reg_batch_size = G_reg_batch_size // world_size
init_method = 'tcp://{}:{}'.format(addr, port)
torch.distributed.init_process_group(
backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
else:
world_size = 1
self.rank = rank
self.world_size = world_size
# Set up variable to keep track of moving average of path lengths
self.pl_avg = torch.tensor(
pl_avg, dtype=torch.float16 if half else torch.float32, device=self.device)
# Broadcast parameters from rank 0 if running distributed
self._sync_distributed(G=self.G, D=self.D, broadcast_weights=True)
# Set up moving average of generator
# Only for non-distributed training or
# if rank is 0
if not self.rank:
# Values for `rank`: None -> not distributed, 0 -> distributed and 'main' node
self.Gs = Gs
if not isinstance(Gs, utils.MovingAverageModule):
self.Gs = utils.MovingAverageModule(
from_module=self.G,
to_module=Gs,
param_beta=Gs_beta,
device=self.device if Gs_device is None else Gs_device
)
else:
self.Gs = None
# Set up loss and regularization functions
self.G_loss = get_loss_fn('G', G_loss)
self.D_loss = get_loss_fn('D', D_loss)
self.G_reg = get_reg_fn('G', G_reg, pl_avg=self.pl_avg)
self.D_reg = get_reg_fn('D', D_reg)
self.G_reg_interval = G_reg_interval
self.D_reg_interval = D_reg_interval
self.G_iter = G_iter
self.D_iter = D_iter
# Set up optimizers (adjust hyperparameters if lazy regularization is active)
self.G_opt = build_opt(self.G, G_opt_class, G_opt_kwargs, self.G_reg, self.G_reg_interval)
self.D_opt = build_opt(self.D, D_opt_class, D_opt_kwargs, self.D_reg, self.D_reg_interval)
# Set up mixed precision training
if half:
assert 'apex' in sys.modules, 'Can not run mixed precision ' + \
'training (`half=True`) without the apex module.'
(self.G, self.D), (self.G_opt, self.D_opt) = amp.initialize(
[self.G, self.D], [self.G_opt, self.D_opt], opt_level='O1')
self.half = half
# Data
sampler = None
if self.rank is not None:
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
self.dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=device_batch_size,
num_workers=data_workers,
shuffle=sampler is None,
pin_memory=self.device.index is not None,
drop_last=True,
sampler=sampler
)
self.dataloader_iter = None
self.prior_generator = utils.PriorGenerator(
latent_size=latent_size,
label_size=label_size,
batch_size=device_batch_size,
device=self.device
)
assert batch_size % device_batch_size == 0, \
'Batch size has to be evenly divisible by the product of ' + \
'device batch size and world size.'
self.subdivisions = batch_size // device_batch_size
assert G_reg_batch_size % G_reg_device_batch_size == 0, \
'G reg batch size has to be evenly divisible by the product of ' + \
'G reg device batch size and world size.'
self.G_reg_subdivisions = G_reg_batch_size // G_reg_device_batch_size
self.G_reg_device_batch_size = G_reg_device_batch_size
self.tb_writer = None
if tensorboard_log_dir and not self.rank:
self.tb_writer = torch.utils.tensorboard.SummaryWriter(tensorboard_log_dir)
self.label_size = label_size
self.style_mix_prob = style_mix_prob
self.checkpoint_dir = checkpoint_dir
self.checkpoint_interval = checkpoint_interval
self.seen = seen
self.metrics = {}
self.callbacks = []
def _get_batch(self):
"""
Fetch a batch and its labels. If no labels are
available the returned labels will be `None`.
Returns:
data
labels
"""
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
try:
batch = next(self.dataloader_iter)
except StopIteration:
self.dataloader_iter = None
return self._get_batch()
if isinstance(batch, (tuple, list)):
if len(batch) > 1:
data, label = batch[:2]
else:
data, label = batch[0], None
else:
data, label = batch, None
if not self.label_size:
label = None
if torch.is_tensor(data):
data = data.to(self.device)
if torch.is_tensor(label):
label = label.to(self.device)
return data, label
def _sync_distributed(self, G=None, D=None, broadcast_weights=False):
"""
        Sync the gradients (and optionally the weights) of
        the specified networks over the distributed training
        nodes. Varying buffers are broadcast from rank 0.
        If distributed training is not enabled, no action
        is taken and this function is a no-op.
Arguments:
G (Generator, optional)
D (Discriminator, optional)
broadcast_weights (bool): Broadcast the weights from
node of rank 0 to all other ranks. Default
value is False.
"""
if self.rank is None:
return
for net in [G, D]:
if net is None:
continue
for p in net.parameters():
if p.grad is not None:
torch.distributed.all_reduce(p.grad, async_op=True)
if broadcast_weights:
torch.distributed.broadcast(p.data, src=0, async_op=True)
if G is not None:
if G.dlatent_avg is not None:
torch.distributed.broadcast(G.dlatent_avg, src=0, async_op=True)
if self.pl_avg is not None:
torch.distributed.broadcast(self.pl_avg, src=0, async_op=True)
if G is not None or D is not None:
torch.distributed.barrier(async_op=False)
def _backward(self, loss, opt, mul=1, subdivisions=None):
"""
Reduce loss by world size and subdivisions before
calling backward for the loss. Loss scaling is
performed when mixed precision training is
enabled.
Arguments:
loss (torch.Tensor)
opt (torch.optim.Optimizer)
mul (float): Loss weight. Default value is 1.
            subdivisions (int, optional): The number of
                subdivisions to divide by. If this is not
                specified, the number of subdivisions derived
                from the batch and device batch sizes given
                at construction is used.
        Returns:
            loss (float): The value of the loss scaled by `mul`
                and the number of subdivisions but not by world size.
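        Example:
            With the default `batch_size=32` and `device_batch_size=4`
            (8 subdivisions, single process), each subdivision loss is
            multiplied by 1 / 8 before `backward()`, so the accumulated
            gradients average over the full batch of 32 entries.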
"""
if loss is None:
return 0
mul /= subdivisions or self.subdivisions
mul /= self.world_size or 1
if mul != 1:
loss *= mul
if self.half:
with amp.scale_loss(loss, opt) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
        # Return only the scalar value (the world size scaling is undone here).
return loss.item() * (self.world_size or 1)
def train(self, iterations, callbacks=None, verbose=True):
"""
Train the models for a specific number of iterations.
Arguments:
iterations (int): Number of iterations to train for.
callbacks (callable, list, optional): One
or more callbacks to call at the end of each
                iteration. Each callback is passed the total
                number of training iterations that have been
                performed (this counter is not reset when
                loading a saved checkpoint).
Default value is None (unused).
verbose (bool): Write progress to stdout.
Default value is True.
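        Example:
            An illustrative call (the lambda callback is only an example):
                trainer.train(
                    iterations=10000,
                    callbacks=lambda seen: print('iterations seen:', seen)
                )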
"""
evaluated_metrics = {}
if self.rank:
            verbose = False
if verbose:
progress = utils.ProgressWriter(iterations)
value_tracker = utils.ValueTracker()
for _ in range(iterations):
            # Figure out if G and/or D should be
            # regularized this iteration
G_reg = self.G_reg is not None
if self.G_reg_interval and G_reg:
G_reg = self.seen % self.G_reg_interval == 0
D_reg = self.D_reg is not None
if self.D_reg_interval and D_reg:
D_reg = self.seen % self.D_reg_interval == 0
# -----| Train G |----- #
# Disable gradients for D while training G
self.D.requires_grad_(False)
for _ in range(self.G_iter):
self.G_opt.zero_grad()
G_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
loss, _ = self.G_loss(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels
)
G_loss += self._backward(loss, self.G_opt)
if G_reg:
if self.G_reg_interval:
                        # For lazy regularization, even if the interval
                        # is set to 1, the optimization step is taken
                        # before the regularization gradients are gathered.
self._sync_distributed(G=self.G)
self.G_opt.step()
self.G_opt.zero_grad()
G_reg_loss = 0
# Pathreg is expensive to compute which
# is why G regularization has its own settings
# for subdivisions and batch size.
for i in range(self.G_reg_subdivisions):
latents, latent_labels = self.prior_generator(
batch_size=self.G_reg_device_batch_size,
multi_latent_prob=self.style_mix_prob
)
_, reg_loss = self.G_reg(
G=self.G,
latents=latents,
latent_labels=latent_labels
)
G_reg_loss += self._backward(
reg_loss,
self.G_opt, mul=self.G_reg_interval or 1,
subdivisions=self.G_reg_subdivisions
)
self._sync_distributed(G=self.G)
self.G_opt.step()
# Update moving average of weights after
# each G training subiteration
if self.Gs is not None:
self.Gs.update()
# Re-enable gradients for D
self.D.requires_grad_(True)
# -----| Train D |----- #
# Disable gradients for G while training D
self.G.requires_grad_(False)
for _ in range(self.D_iter):
self.D_opt.zero_grad()
D_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
reals, real_labels = self._get_batch()
loss, _ = self.D_loss(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels,
reals=reals,
real_labels=real_labels
)
D_loss += self._backward(loss, self.D_opt)
if D_reg:
if self.D_reg_interval:
                        # For lazy regularization, even if the interval
                        # is set to 1, the optimization step is taken
                        # before the regularization gradients are gathered.
self._sync_distributed(D=self.D)
self.D_opt.step()
self.D_opt.zero_grad()
D_reg_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
reals, real_labels = self._get_batch()
_, reg_loss = self.D_reg(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels,
reals=reals,
real_labels=real_labels
)
D_reg_loss += self._backward(
reg_loss, self.D_opt, mul=self.D_reg_interval or 1)
self._sync_distributed(D=self.D)
self.D_opt.step()
# Re-enable grads for G
self.G.requires_grad_(True)
if self.tb_writer is not None or verbose:
                # If verbose is true or tensorboard logging is enabled,
                # we calculate the grad norms here so that it is only done
                # once, and before any metrics that may zero the grads.
G_grad_norm = utils.get_grad_norm_from_optimizer(self.G_opt)
D_grad_norm = utils.get_grad_norm_from_optimizer(self.D_opt)
for name, metric in self.metrics.items():
if not metric['interval'] or self.seen % metric['interval'] == 0:
evaluated_metrics[name] = metric['eval_fn']()
# Printing and logging
# Tensorboard logging
if self.tb_writer is not None:
self.tb_writer.add_scalar('Loss/G_loss', G_loss, self.seen)
if G_reg:
self.tb_writer.add_scalar('Loss/G_reg', G_reg_loss, self.seen)
self.tb_writer.add_scalar('Grad_norm/G_reg', G_grad_norm, self.seen)
self.tb_writer.add_scalar('Params/pl_avg', self.pl_avg, self.seen)
else:
self.tb_writer.add_scalar('Grad_norm/G_loss', G_grad_norm, self.seen)
self.tb_writer.add_scalar('Loss/D_loss', D_loss, self.seen)
if D_reg:
self.tb_writer.add_scalar('Loss/D_reg', D_reg_loss, self.seen)
self.tb_writer.add_scalar('Grad_norm/D_reg', D_grad_norm, self.seen)
else:
self.tb_writer.add_scalar('Grad_norm/D_loss', D_grad_norm, self.seen)
for name, value in evaluated_metrics.items():
self.tb_writer.add_scalar('Metrics/{}'.format(name), value, self.seen)
# Printing
if verbose:
value_tracker.add('seen', self.seen + 1, beta=0)
value_tracker.add('G_lr', self.G_opt.param_groups[0]['lr'], beta=0)
value_tracker.add('G_loss', G_loss)
if G_reg:
value_tracker.add('G_reg', G_reg_loss)
value_tracker.add('G_reg_grad_norm', G_grad_norm)
value_tracker.add('pl_avg', self.pl_avg, beta=0)
else:
value_tracker.add('G_loss_grad_norm', G_grad_norm)
value_tracker.add('D_lr', self.D_opt.param_groups[0]['lr'], beta=0)
value_tracker.add('D_loss', D_loss)
if D_reg:
value_tracker.add('D_reg', D_reg_loss)
value_tracker.add('D_reg_grad_norm', D_grad_norm)
else:
value_tracker.add('D_loss_grad_norm', D_grad_norm)
for name, value in evaluated_metrics.items():
value_tracker.add(name, value, beta=0)
progress.write(str(value_tracker))
# Callback
for callback in utils.to_list(callbacks) + self.callbacks:
callback(self.seen)
self.seen += 1
# clear cache
torch.cuda.empty_cache()
# Handle checkpointing
if not self.rank and self.checkpoint_dir and self.checkpoint_interval:
if self.seen % self.checkpoint_interval == 0:
checkpoint_path = os.path.join(
self.checkpoint_dir,
'{}_{}'.format(self.seen, time.strftime('%Y-%m-%d_%H-%M-%S'))
)
self.save_checkpoint(checkpoint_path)
if verbose:
progress.close()
def register_metric(self, name, eval_fn, interval):
"""
Add a metric. This will be evaluated every `interval`
training iteration. Used by tensorboard and progress
updates written to stdout while training.
Arguments:
name (str): A name for the metric. If a metric with
this name already exists it will be overwritten.
eval_fn (callable): A function that evaluates the metric
and returns a python number.
interval (int): The interval to evaluate at.
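        Example:
            A sketch with a hypothetical metric function (`compute_fid`
            is not part of this module):
                trainer.register_metric('FID', compute_fid, interval=1000)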
"""
self.metrics[name] = {'eval_fn': eval_fn, 'interval': interval}
def remove_metric(self, name):
"""
Remove a metric that was previously registered.
Arguments:
name (str): Name of the metric.
"""
if name in self.metrics:
del self.metrics[name]
else:
warnings.warn(
'Attempting to remove metric {} '.format(name) + \
'which does not exist.'
)
def generate_images(self,
num_images,
seed=None,
truncation_psi=None,
truncation_cutoff=None,
label=None,
pixel_min=-1,
pixel_max=1):
"""
Generate some images with the generator and transform them into PIL
images and return them as a list.
Arguments:
num_images (int): Number of images to generate.
seed (int, optional): The seed for the random generation
of input latent values.
truncation_psi (float): See stylegan2.model.Generator.set_truncation()
Default value is None.
truncation_cutoff (int): See stylegan2.model.Generator.set_truncation()
label (int, list, optional): Label to condition all generated images with
or multiple labels, one for each generated image.
pixel_min (float): The min value in the pixel range of the generator.
Default value is -1.
            pixel_max (float): The max value in the pixel range of the generator.
Default value is 1.
Returns:
images (list): List of PIL images.
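        Example:
            An illustrative call (argument values are arbitrary):
                images = trainer.generate_images(4, truncation_psi=0.7)
                images[0].save('sample.png')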
"""
if seed is None:
seed = int(10000 * time.time())
latents, latent_labels = self.prior_generator(num_images, seed=seed)
if label:
assert latent_labels is not None, 'Can not specify label when no labels ' + \
'are used by this model.'
label = utils.to_list(label)
assert all(isinstance(l, int) for l in label), '`label` can only consist of ' + \
'one or more python integers.'
assert len(label) == 1 or len(label) == num_images, '`label` can either ' + \
'specify one label to use for all images or a list of labels of the ' + \
'same length as number of images. Received {} labels '.format(len(label)) + \
'but {} images are to be generated.'.format(num_images)
if len(label) == 1:
latent_labels.fill_(label[0])
else:
latent_labels = torch.tensor(label).to(latent_labels)
self.Gs.set_truncation(
truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
with torch.no_grad():
generated = self.Gs(latents=latents, labels=latent_labels)
        assert generated.dim() == 4, 'Can only generate images when using a ' + \
            'network built for 2-dimensional data.'
return utils.tensor_to_PIL(generated, pixel_min=pixel_min, pixel_max=pixel_max)
def log_images_tensorboard(self, images, name, resize=256):
"""
Log a list of images to tensorboard by first turning
them into a grid. Can not be performed if rank > 0
or tensorboard_log_dir was not given at construction.
Arguments:
images (list): List of PIL images.
name (str): The name to log images for.
resize (int, tuple): The height and width to use for
each image in the grid. Default value is 256.
"""
assert self.tb_writer is not None, \
'No tensorboard log dir was specified ' + \
'when constructing this object.'
image = utils.stack_images_PIL(images, individual_img_size=resize)
image = torchvision.transforms.ToTensor()(image)
self.tb_writer.add_image(name, image, self.seen)
def add_tensorboard_image_logging(self,
name,
interval,
num_images,
resize=256,
seed=None,
truncation_psi=None,
truncation_cutoff=None,
label=None,
pixel_min=-1,
pixel_max=1):
"""
Set up tensorboard logging of generated images to be performed
at a certain training interval. If distributed training is set up
        and this object does not have rank 0, no logging will be performed
by this object.
All arguments except the ones mentioned below have their description
in the docstring of `generate_images()` and `log_images_tensorboard()`.
Arguments:
interval (int): The interval at which to log generated images.
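        Example:
            An illustrative setup that logs 4 generated images under the
            tag 'generated' every 1000 iterations:
                trainer.add_tensorboard_image_logging(
                    'generated', interval=1000, num_images=4)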
"""
if self.rank:
return
def callback(seen):
if seen % interval == 0:
images = self.generate_images(
num_images=num_images,
seed=seed,
truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
label=label,
pixel_min=pixel_min,
pixel_max=pixel_max
)
self.log_images_tensorboard(
images=images,
name=name,
resize=resize
)
self.callbacks.append(callback)
def save_checkpoint(self, dir_path):
"""
Save the current state of this trainer as a checkpoint.
        NOTE: The dataset can not be serialized and saved, so it
        has to be reconstructed and passed when loading this checkpoint.
Arguments:
dir_path (str): The checkpoint path.
"""
if not os.path.exists(dir_path):
os.makedirs(dir_path)
else:
assert os.path.isdir(dir_path), '`dir_path` points to a file.'
kwargs = self.kwargs.copy()
# Update arguments that may have changed since construction
kwargs.update(
seen=self.seen,
pl_avg=float(self.pl_avg)
)
with open(os.path.join(dir_path, 'kwargs.json'), 'w') as fp:
json.dump(kwargs, fp)
torch.save(self.G_opt.state_dict(), os.path.join(dir_path, 'G_opt.pth'))
torch.save(self.D_opt.state_dict(), os.path.join(dir_path, 'D_opt.pth'))
models.save(self.G, os.path.join(dir_path, 'G.pth'))
models.save(self.D, os.path.join(dir_path, 'D.pth'))
if self.Gs is not None:
models.save(self.Gs, os.path.join(dir_path, 'Gs.pth'))
@classmethod
def load_checkpoint(cls, checkpoint_path, dataset, **kwargs):
"""
Load a checkpoint into a new Trainer object and return that
object. If the path specified points at a folder containing
multiple checkpoints, the latest one will be used.
The dataset can not be serialized and saved so it is required
to be explicitly given when loading a checkpoint.
Arguments:
checkpoint_path (str): Path to a checkpoint or to a folder
containing one or more checkpoints.
dataset (indexable): The dataset to use.
**kwargs (keyword arguments): Any other arguments to override
                the ones saved in the checkpoint. Useful when training
                is continued on a different device or when the
                distributed training setup has changed.
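        Example:
            An illustrative call that resumes training on a single GPU
            (the checkpoint path is arbitrary):
                trainer = Trainer.load_checkpoint(
                    './checkpoints', dataset, device=0)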
"""
checkpoint_path = _find_checkpoint(checkpoint_path)
_is_checkpoint(checkpoint_path, enforce=True)
with open(os.path.join(checkpoint_path, 'kwargs.json'), 'r') as fp:
loaded_kwargs = json.load(fp)
loaded_kwargs.update(**kwargs)
device = torch.device('cpu')
if isinstance(loaded_kwargs['device'], (list, tuple)):
device = torch.device(loaded_kwargs['device'][0])
for name in ['G', 'D']:
fpath = os.path.join(checkpoint_path, name + '.pth')
loaded_kwargs[name] = models.load(fpath, map_location=device)
if os.path.exists(os.path.join(checkpoint_path, 'Gs.pth')):
loaded_kwargs['Gs'] = models.load(
os.path.join(checkpoint_path, 'Gs.pth'),
map_location=device if loaded_kwargs['Gs_device'] is None \
else torch.device(loaded_kwargs['Gs_device'])
)
obj = cls(dataset=dataset, **loaded_kwargs)
for name in ['G_opt', 'D_opt']:
fpath = os.path.join(checkpoint_path, name + '.pth')
state_dict = torch.load(fpath, map_location=device)
getattr(obj, name).load_state_dict(state_dict)
return obj
#----------------------------------------------------------------------------
# Checkpoint helper functions
def _is_checkpoint(dir_path, enforce=False):
if not dir_path:
if enforce:
raise ValueError('Not a checkpoint.')
return False
if not os.path.exists(dir_path):
if enforce:
raise FileNotFoundError('{} could not be found.'.format(dir_path))
return False
if not os.path.isdir(dir_path):
if enforce:
raise NotADirectoryError('{} is not a directory.'.format(dir_path))
return False
fnames = os.listdir(dir_path)
for fname in ['G.pth', 'D.pth', 'G_opt.pth', 'D_opt.pth', 'kwargs.json']:
if fname not in fnames:
if enforce:
raise FileNotFoundError(
'Could not find {} in {}.'.format(fname, dir_path))
return False
return True
def _find_checkpoint(dir_path):
if not dir_path:
return None
if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
return None
if _is_checkpoint(dir_path):
return dir_path
checkpoint_names = []
for name in os.listdir(dir_path):
if _is_checkpoint(os.path.join(dir_path, name)):
checkpoint_names.append(name)
if not checkpoint_names:
return None
def get_iteration(name):
return int(name.split('_')[0])
def get_timestamp(name):
return '_'.join(name.split('_')[1:])
    # Python's sort is stable, so entries that share the same
    # timestamp keep the relative order produced by the first
    # sort (by iteration number).
checkpoint_names = sorted(
sorted(checkpoint_names, key=get_iteration), key=get_timestamp)
return os.path.join(dir_path, checkpoint_names[-1])
#----------------------------------------------------------------------------
# Optimizer builder (adjusts hyperparameters for lazy regularization)
def build_opt(net, opt_class, opt_kwargs, reg, reg_interval):
opt_kwargs['lr'] = opt_kwargs.get('lr', 1e-3)
if reg not in [None, False] and reg_interval:
mb_ratio = reg_interval / (reg_interval + 1.)
opt_kwargs['lr'] *= mb_ratio
if 'momentum' in opt_kwargs:
opt_kwargs['momentum'] = opt_kwargs['momentum'] ** mb_ratio
if 'betas' in opt_kwargs:
betas = opt_kwargs['betas']
opt_kwargs['betas'] = (betas[0] ** mb_ratio, betas[1] ** mb_ratio)
if isinstance(opt_class, str):
opt_class = getattr(torch.optim, opt_class.title())
return opt_class(net.parameters(), **opt_kwargs)
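# Illustrative numbers (not computed here): with the default Adam settings
# and D_reg_interval=16, mb_ratio = 16 / 17 ≈ 0.941, so the discriminator
# optimizer is built with lr ≈ 1.88e-3 and betas ≈ (0.0, 0.9906) to
# compensate for the lazy regularization interval.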
#----------------------------------------------------------------------------
# Reg and loss function fetchers
_LOSS_FNS = {
'G': {
'logistic': loss_fns.G_logistic,
'logistic_ns': loss_fns.G_logistic_ns,
'wgan': loss_fns.G_wgan
},
'D': {
'logistic': loss_fns.D_logistic,
'wgan': loss_fns.D_wgan
}
}
def get_loss_fn(net, loss):
if callable(loss):
return loss
net = net.upper()
assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
loss = loss.lower()
for name in _LOSS_FNS[net].keys():
if loss == name:
return _LOSS_FNS[net][name]
    raise ValueError('Unknown {} loss {}'.format(net, loss))
_REG_FNS = {
'G': {
'pathreg': loss_fns.G_pathreg
},
'D': {
'r1': loss_fns.D_r1,
'r2': loss_fns.D_r2,
'gp': loss_fns.D_gp,
}
}
def get_reg_fn(net, reg, **kwargs):
if reg is None:
return None
if callable(reg):
        return functools.partial(reg, **kwargs)
net = net.upper()
assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
reg = reg.lower()
gamma = None
for name in _REG_FNS[net].keys():
if reg.startswith(name):
gamma_chars = [c for c in reg.replace(name, '') if c.isdigit() or c == '.']
if gamma_chars:
kwargs.update(gamma=float(''.join(gamma_chars)))
return functools.partial(_REG_FNS[net][name], **kwargs)
    raise ValueError('Unknown regularizer {}'.format(reg))
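# Illustrative behaviour (not executed here): get_reg_fn('D', 'r1:10') returns
# functools.partial(loss_fns.D_r1, gamma=10.0), and get_reg_fn('G', 'pathreg:2',
# pl_avg=pl_avg) returns functools.partial(loss_fns.G_pathreg, gamma=2.0,
# pl_avg=pl_avg).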