import warnings | |
import functools | |
import os | |
import time | |
import sys | |
import json | |
import numpy as np | |
import torch | |
import torch.utils.tensorboard | |
from torch import nn | |
import torchvision | |
try: | |
import apex | |
from apex import amp | |
except ImportError: | |
pass | |
from . import models, utils, loss_fns | |
class Trainer: | |
""" | |
Class that handles training and logging for stylegan2. | |
    For distributed training, the arguments `rank`, `world_size`,
    `master_addr` and `master_port` can all be given as environment
    variables (the only difference is that the keys should be uppercase).
    If available, environment variables override the Python values
    passed for the same arguments.
Arguments: | |
G (Generator): The generator model. | |
D (Discriminator): The discriminator model. | |
latent_size (int): The size of the latent inputs. | |
dataset (indexable object): The dataset. Has to implement | |
'__getitem__' and '__len__'. If `label_size` > 0, this | |
dataset object has to return both a data entry and its | |
label when calling '__getitem__'. | |
device (str, int, list, torch.device): The device to run training on. | |
Can be a list of integers for parallel training in the same | |
process. Parallel training can also be achieved by spawning | |
            separate processes and using the `rank` argument for each
process. In that case, only one device should be specified | |
per process. | |
Gs (Generator, optional): A generator copy with the current | |
moving average of the training generator. If not specified, | |
a copy of the generator is made for the moving average of | |
weights. | |
Gs_beta (float): The beta value for the moving average weights. | |
            Default value is 0.5 ** (32 / 10000).
Gs_device (str, int, torch.device, optional): The device to store | |
the moving average weights on. If using a different device | |
than what is specified for the `device` argument, updating | |
the moving average weights will take longer as the data | |
            will have to be transferred between devices. If
this argument is not specified, the same device is used | |
as specified in the `device` argument. | |
batch_size (int): The total batch size to average gradients | |
over. This should be the combined batch size of all used | |
devices (it is later divided by world size for distributed | |
training). | |
Example: We want to average gradients over 32 data | |
entries. To do this we just set `batch_size=32`. | |
Even if we train on 8 GPUs we still use the same | |
batch size (each GPU will take 4 data entries per | |
batch). | |
Default value is 32. | |
device_batch_size (int): The number of data entries that can | |
fit on the specified device at a time. | |
Example: We want to average gradients over 32 data | |
entries. To do this we just set `batch_size=32`. | |
However, our device can only handle a batch of | |
4 at a time before running out of memory. We | |
therefor set `device_batch_size=4`. With a | |
single device (no distributed training), each | |
batch is split into 32 / 4 parts and gradients | |
are averaged over all these parts. | |
Default value is 4. | |
label_size (int, optional): Number of possible class labels. | |
This is required for conditioning the GAN with labels. | |
If not specified it is assumed that no labels are used. | |
data_workers (int): The number of spawned processes that | |
handle data loading. Default value is 4. | |
G_loss (str, callable): The loss function to use | |
for the generator. If string, it can be one of the | |
following: 'logistic', 'logistic_ns' or 'wgan'. | |
If not a string, the callable has to follow | |
the format of functions found in `stylegan2.loss`. | |
Default value is 'logistic_ns' (non-saturating logistic). | |
D_loss (str, callable): The loss function to use | |
for the discriminator. If string, it can be one of the | |
following: 'logistic' or 'wgan'. | |
If not a string, same restriction follows as for `G_loss`. | |
Default value is 'logistic'. | |
G_reg (str, callable, None): The regularizer function to use | |
for the generator. If string, it can only be 'pathreg' | |
(pathlength regularization). A weight for the regularizer | |
can be passed after the string name like the following: | |
G_reg='pathreg:5' | |
This will assign a weight of 5 to the regularization loss. | |
            If set to None, no generator regularization is performed.
Default value is 'pathreg:2'. | |
G_reg_interval (int): The interval at which to regularize the | |
generator. If set to 0, the regularization and loss gradients | |
are combined in a single optimization step every iteration. | |
If set to 1, the gradients for the regularization and loss | |
are used separately for two optimization steps. Any value | |
higher than 1 indicates that regularization should only | |
be performed at this interval (lazy regularization). | |
Default value is 4. | |
G_opt_class (str, class): The optimizer class for the generator. | |
Default value is 'Adam'. | |
G_opt_kwargs (dict): Keyword arguments for the generator optimizer | |
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}. | |
G_reg_batch_size (int): Same as `batch_size` but only for | |
the regularization loss of the generator. Default value | |
is 16. | |
G_reg_device_batch_size (int): Same as `device_batch_size` | |
but only for the regularization loss of the generator. | |
Default value is 2. | |
D_reg (str, callable, None): The regularizer function to use | |
for the discriminator. If string, the following values | |
can be used: 'r1', 'r2', 'gp'. See doc for `G_reg` for | |
rest of info on regularizer format. | |
Default value is 'r1:10'. | |
        D_reg_interval (int): Same as `G_reg_interval` but for the
            discriminator. Default value is 16.
D_opt_class (str, class): The optimizer class for the discriminator. | |
Default value is 'Adam'. | |
D_opt_kwargs (dict): Keyword arguments for the discriminator optimizer | |
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}. | |
style_mix_prob (float): The probability of passing 2 latents instead of 1 | |
to the generator during training. Default value is 0.9. | |
G_iter (int): Number of generator iterations for every full training | |
iteration. Default value is 1. | |
D_iter (int): Number of discriminator iterations for every full training | |
iteration. Default value is 1. | |
pl_avg (float, torch.Tensor): The average pathlength starting value for | |
pathlength regularization of the generator. Default value is 0. | |
tensorboard_log_dir (str, optional): A path to a directory to log training values | |
in for tensorboard. Only used without distributed training or when | |
distributed training is enabled and the rank of this trainer is 0. | |
checkpoint_dir (str, optional): A path to a directory to save training | |
            checkpoints to. If not specified, no checkpoints are automatically
saved during training. | |
checkpoint_interval (int): The interval at which to save training checkpoints. | |
Default value is 10000. | |
seen (int): The number of previously trained iterations. Used for logging. | |
Default value is 0. | |
half (bool): Use mixed precision training. Default value is False. | |
rank (int, optional): If set, use distributed training. Expects that | |
this object has been constructed with the same arguments except | |
for `rank` in different processes. | |
world_size (int, optional): If using distributed training, this specifies | |
the number of nodes in the training. | |
master_addr (str): The master address for distributed training. | |
Default value is '127.0.0.1'. | |
master_port (str): The master port for distributed training. | |
Default value is '23456'. | |
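    Example:
        A minimal usage sketch. The `models.Generator`, `models.Discriminator`
        and `ImageDataset` constructors below are illustrative assumptions
        about the surrounding package, not exact signatures:

            G = models.Generator(latent_size=512)
            D = models.Discriminator()
            dataset = ImageDataset('path/to/images')
            trainer = Trainer(
                G=G,
                D=D,
                latent_size=512,
                dataset=dataset,
                device=0,
                batch_size=32,
                device_batch_size=4
            )
            trainer.train(iterations=25000)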
""" | |
def __init__(self, | |
G, | |
D, | |
latent_size, | |
dataset, | |
device, | |
Gs=None, | |
Gs_beta=0.5 ** (32 / 10000), | |
Gs_device=None, | |
batch_size=32, | |
device_batch_size=4, | |
label_size=0, | |
data_workers=4, | |
G_loss='logistic_ns', | |
D_loss='logistic', | |
G_reg='pathreg:2', | |
G_reg_interval=4, | |
G_opt_class='Adam', | |
G_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)}, | |
G_reg_batch_size=None, | |
G_reg_device_batch_size=None, | |
D_reg='r1:10', | |
D_reg_interval=16, | |
D_opt_class='Adam', | |
D_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)}, | |
style_mix_prob=0.9, | |
G_iter=1, | |
D_iter=1, | |
pl_avg=0., | |
tensorboard_log_dir=None, | |
checkpoint_dir=None, | |
checkpoint_interval=10000, | |
seen=0, | |
half=False, | |
rank=None, | |
world_size=None, | |
master_addr='127.0.0.1', | |
master_port='23456'): | |
assert not isinstance(G, nn.parallel.DistributedDataParallel) and \ | |
not isinstance(D, nn.parallel.DistributedDataParallel), \ | |
'Encountered a model wrapped in `DistributedDataParallel`. ' + \ | |
'Distributed parallelism is handled by this class and can ' + \ | |
'not be initialized before.' | |
# We store the training settings in a dict that can be saved as a json file. | |
kwargs = locals() | |
# First we remove the arguments that can not be turned into json. | |
kwargs.pop('self') | |
kwargs.pop('G') | |
kwargs.pop('D') | |
kwargs.pop('Gs') | |
kwargs.pop('dataset') | |
# Some arguments may have to be turned into strings to be compatible with json. | |
kwargs.update(pl_avg=float(pl_avg)) | |
if isinstance(device, torch.device): | |
kwargs.update(device=str(device)) | |
        if isinstance(Gs_device, torch.device):
            kwargs.update(Gs_device=str(Gs_device))
self.kwargs = kwargs | |
if device or device == 0: | |
if isinstance(device, (tuple, list)): | |
self.device = torch.device(device[0]) | |
else: | |
self.device = torch.device(device) | |
else: | |
self.device = torch.device('cpu') | |
if self.device.index is not None: | |
torch.cuda.set_device(self.device.index) | |
else: | |
assert not half, 'Mixed precision training only available ' + \ | |
'for CUDA devices.' | |
# Set up the models | |
self.G = G.train().to(self.device) | |
self.D = D.train().to(self.device) | |
if isinstance(device, (tuple, list)) and len(device) > 1: | |
assert all(isinstance(dev, int) for dev in device), \ | |
'Multiple devices have to be specified as a list ' + \ | |
'or tuple of integers corresponding to device indices.' | |
# TODO: Look into bug with torch.autograd.grad and nn.DataParallel | |
# In the meanwhile just prohibit its use together. | |
assert G_reg is None and D_reg is None, 'Regularization ' + \ | |
'currently not supported for multi-gpu training in single process. ' + \ | |
'Please use distributed training with one device per process instead.' | |
device_batch_size *= len(device) | |
def to_data_parallel(model): | |
if not isinstance(model, nn.DataParallel): | |
return nn.DataParallel(model, device_ids=device) | |
return model | |
self.G = to_data_parallel(self.G) | |
self.D = to_data_parallel(self.D) | |
# Default generator reg batch size is the global batch size | |
# unless it has been specified otherwise. | |
G_reg_batch_size = G_reg_batch_size or batch_size | |
G_reg_device_batch_size = G_reg_device_batch_size or device_batch_size | |
# Set up distributed training | |
rank = os.environ.get('RANK', rank) | |
if rank is not None: | |
rank = int(rank) | |
addr = os.environ.get('MASTER_ADDR', master_addr) | |
port = os.environ.get('MASTER_PORT', master_port) | |
world_size = os.environ.get('WORLD_SIZE', world_size) | |
assert world_size is not None, 'Distributed training ' + \ | |
'requires specifying world size.' | |
world_size = int(world_size) | |
assert self.device.index is not None, \ | |
'Distributed training is only supported for CUDA.' | |
assert batch_size % world_size == 0, 'Batch size has to be ' + \ | |
'evenly divisible by world size.' | |
assert G_reg_batch_size % world_size == 0, 'G reg batch size has to be ' + \ | |
'evenly divisible by world size.' | |
batch_size = batch_size // world_size | |
G_reg_batch_size = G_reg_batch_size // world_size | |
init_method = 'tcp://{}:{}'.format(addr, port) | |
torch.distributed.init_process_group( | |
backend='nccl', init_method=init_method, rank=rank, world_size=world_size) | |
else: | |
world_size = 1 | |
self.rank = rank | |
self.world_size = world_size | |
# Set up variable to keep track of moving average of path lengths | |
self.pl_avg = torch.tensor( | |
pl_avg, dtype=torch.float16 if half else torch.float32, device=self.device) | |
# Broadcast parameters from rank 0 if running distributed | |
self._sync_distributed(G=self.G, D=self.D, broadcast_weights=True) | |
# Set up moving average of generator | |
# Only for non-distributed training or | |
# if rank is 0 | |
if not self.rank: | |
# Values for `rank`: None -> not distributed, 0 -> distributed and 'main' node | |
self.Gs = Gs | |
if not isinstance(Gs, utils.MovingAverageModule): | |
self.Gs = utils.MovingAverageModule( | |
from_module=self.G, | |
to_module=Gs, | |
param_beta=Gs_beta, | |
device=self.device if Gs_device is None else Gs_device | |
) | |
else: | |
self.Gs = None | |
# Set up loss and regularization functions | |
self.G_loss = get_loss_fn('G', G_loss) | |
self.D_loss = get_loss_fn('D', D_loss) | |
self.G_reg = get_reg_fn('G', G_reg, pl_avg=self.pl_avg) | |
self.D_reg = get_reg_fn('D', D_reg) | |
self.G_reg_interval = G_reg_interval | |
self.D_reg_interval = D_reg_interval | |
self.G_iter = G_iter | |
self.D_iter = D_iter | |
# Set up optimizers (adjust hyperparameters if lazy regularization is active) | |
self.G_opt = build_opt(self.G, G_opt_class, G_opt_kwargs, self.G_reg, self.G_reg_interval) | |
self.D_opt = build_opt(self.D, D_opt_class, D_opt_kwargs, self.D_reg, self.D_reg_interval) | |
# Set up mixed precision training | |
if half: | |
assert 'apex' in sys.modules, 'Can not run mixed precision ' + \ | |
'training (`half=True`) without the apex module.' | |
(self.G, self.D), (self.G_opt, self.D_opt) = amp.initialize( | |
[self.G, self.D], [self.G_opt, self.D_opt], opt_level='O1') | |
self.half = half | |
# Data | |
sampler = None | |
if self.rank is not None: | |
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True) | |
self.dataloader = torch.utils.data.DataLoader( | |
dataset, | |
batch_size=device_batch_size, | |
num_workers=data_workers, | |
shuffle=sampler is None, | |
pin_memory=self.device.index is not None, | |
drop_last=True, | |
sampler=sampler | |
) | |
self.dataloader_iter = None | |
self.prior_generator = utils.PriorGenerator( | |
latent_size=latent_size, | |
label_size=label_size, | |
batch_size=device_batch_size, | |
device=self.device | |
) | |
assert batch_size % device_batch_size == 0, \ | |
'Batch size has to be evenly divisible by the product of ' + \ | |
'device batch size and world size.' | |
self.subdivisions = batch_size // device_batch_size | |
assert G_reg_batch_size % G_reg_device_batch_size == 0, \ | |
'G reg batch size has to be evenly divisible by the product of ' + \ | |
'G reg device batch size and world size.' | |
self.G_reg_subdivisions = G_reg_batch_size // G_reg_device_batch_size | |
self.G_reg_device_batch_size = G_reg_device_batch_size | |
self.tb_writer = None | |
if tensorboard_log_dir and not self.rank: | |
self.tb_writer = torch.utils.tensorboard.SummaryWriter(tensorboard_log_dir) | |
self.label_size = label_size | |
self.style_mix_prob = style_mix_prob | |
self.checkpoint_dir = checkpoint_dir | |
self.checkpoint_interval = checkpoint_interval | |
self.seen = seen | |
self.metrics = {} | |
self.callbacks = [] | |
def _get_batch(self): | |
""" | |
Fetch a batch and its labels. If no labels are | |
available the returned labels will be `None`. | |
Returns: | |
data | |
labels | |
""" | |
if self.dataloader_iter is None: | |
self.dataloader_iter = iter(self.dataloader) | |
try: | |
batch = next(self.dataloader_iter) | |
except StopIteration: | |
self.dataloader_iter = None | |
return self._get_batch() | |
if isinstance(batch, (tuple, list)): | |
if len(batch) > 1: | |
data, label = batch[:2] | |
else: | |
data, label = batch[0], None | |
else: | |
data, label = batch, None | |
if not self.label_size: | |
label = None | |
if torch.is_tensor(data): | |
data = data.to(self.device) | |
if torch.is_tensor(label): | |
label = label.to(self.device) | |
return data, label | |
def _sync_distributed(self, G=None, D=None, broadcast_weights=False): | |
""" | |
        Sync the gradients (and optionally the weights) of
        the specified networks over the distributed training
        nodes. Varying buffers are broadcasted from rank 0.
        If distributed training is not enabled, no action
        is taken and this is a no-op function.
Arguments: | |
G (Generator, optional) | |
D (Discriminator, optional) | |
broadcast_weights (bool): Broadcast the weights from | |
node of rank 0 to all other ranks. Default | |
value is False. | |
""" | |
if self.rank is None: | |
return | |
for net in [G, D]: | |
if net is None: | |
continue | |
for p in net.parameters(): | |
if p.grad is not None: | |
torch.distributed.all_reduce(p.grad, async_op=True) | |
if broadcast_weights: | |
torch.distributed.broadcast(p.data, src=0, async_op=True) | |
if G is not None: | |
if G.dlatent_avg is not None: | |
torch.distributed.broadcast(G.dlatent_avg, src=0, async_op=True) | |
if self.pl_avg is not None: | |
torch.distributed.broadcast(self.pl_avg, src=0, async_op=True) | |
if G is not None or D is not None: | |
torch.distributed.barrier(async_op=False) | |
def _backward(self, loss, opt, mul=1, subdivisions=None): | |
""" | |
Reduce loss by world size and subdivisions before | |
calling backward for the loss. Loss scaling is | |
performed when mixed precision training is | |
enabled. | |
Arguments: | |
loss (torch.Tensor) | |
opt (torch.optim.Optimizer) | |
mul (float): Loss weight. Default value is 1. | |
            subdivisions (int, optional): The number of
                subdivisions to divide by. If not specified,
                the subdivisions computed from the batch and
                device batch sizes given at construction are used.
Returns: | |
loss (torch.Tensor): The loss scaled by mul | |
and subdivisions but not by world size. | |
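        Example (illustrative numbers, not defaults): with
            `subdivisions=8` and `world_size=2`, a call with `mul=1`
            backpropagates `loss / 16` and returns `loss.item() / 8`.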
""" | |
if loss is None: | |
return 0 | |
mul /= subdivisions or self.subdivisions | |
mul /= self.world_size or 1 | |
if mul != 1: | |
loss *= mul | |
if self.half: | |
with amp.scale_loss(loss, opt) as scaled_loss: | |
scaled_loss.backward() | |
else: | |
loss.backward() | |
        # Return the scalar loss value, undoing only the world size scaling
        # (the value is still scaled by `mul` and the subdivisions).
return loss.item() * (self.world_size or 1) | |
def train(self, iterations, callbacks=None, verbose=True): | |
""" | |
Train the models for a specific number of iterations. | |
Arguments: | |
iterations (int): Number of iterations to train for. | |
callbacks (callable, list, optional): One | |
or more callbacks to call at the end of each | |
iteration. The function is given the total | |
number of batches that have been processed since | |
this trainer object was initialized (not reset | |
when loading a saved checkpoint). | |
Default value is None (unused). | |
verbose (bool): Write progress to stdout. | |
Default value is True. | |
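        Example:
            A sketch of a training run with a simple callback; the
            callback receives the running `seen` counter described above:

                def on_iteration(seen):
                    if (seen + 1) % 1000 == 0:
                        print('trained on {} iterations'.format(seen + 1))

                trainer.train(iterations=10000, callbacks=on_iteration)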
""" | |
evaluated_metrics = {} | |
if self.rank: | |
verbose=False | |
if verbose: | |
progress = utils.ProgressWriter(iterations) | |
value_tracker = utils.ValueTracker() | |
for _ in range(iterations): | |
            # Figure out whether G and/or D should be
            # regularized this iteration
G_reg = self.G_reg is not None | |
if self.G_reg_interval and G_reg: | |
G_reg = self.seen % self.G_reg_interval == 0 | |
D_reg = self.D_reg is not None | |
if self.D_reg_interval and D_reg: | |
D_reg = self.seen % self.D_reg_interval == 0 | |
# -----| Train G |----- # | |
# Disable gradients for D while training G | |
self.D.requires_grad_(False) | |
for _ in range(self.G_iter): | |
self.G_opt.zero_grad() | |
G_loss = 0 | |
for i in range(self.subdivisions): | |
latents, latent_labels = self.prior_generator( | |
multi_latent_prob=self.style_mix_prob) | |
loss, _ = self.G_loss( | |
G=self.G, | |
D=self.D, | |
latents=latents, | |
latent_labels=latent_labels | |
) | |
G_loss += self._backward(loss, self.G_opt) | |
if G_reg: | |
if self.G_reg_interval: | |
                        # For lazy regularization, even if the interval
                        # is set to 1, the optimization step is taken
                        # before the gradients of the regularization are gathered.
self._sync_distributed(G=self.G) | |
self.G_opt.step() | |
self.G_opt.zero_grad() | |
G_reg_loss = 0 | |
# Pathreg is expensive to compute which | |
# is why G regularization has its own settings | |
# for subdivisions and batch size. | |
for i in range(self.G_reg_subdivisions): | |
latents, latent_labels = self.prior_generator( | |
batch_size=self.G_reg_device_batch_size, | |
multi_latent_prob=self.style_mix_prob | |
) | |
_, reg_loss = self.G_reg( | |
G=self.G, | |
latents=latents, | |
latent_labels=latent_labels | |
) | |
G_reg_loss += self._backward( | |
reg_loss, | |
self.G_opt, mul=self.G_reg_interval or 1, | |
subdivisions=self.G_reg_subdivisions | |
) | |
self._sync_distributed(G=self.G) | |
self.G_opt.step() | |
# Update moving average of weights after | |
# each G training subiteration | |
if self.Gs is not None: | |
self.Gs.update() | |
# Re-enable gradients for D | |
self.D.requires_grad_(True) | |
# -----| Train D |----- # | |
# Disable gradients for G while training D | |
self.G.requires_grad_(False) | |
for _ in range(self.D_iter): | |
self.D_opt.zero_grad() | |
D_loss = 0 | |
for i in range(self.subdivisions): | |
latents, latent_labels = self.prior_generator( | |
multi_latent_prob=self.style_mix_prob) | |
reals, real_labels = self._get_batch() | |
loss, _ = self.D_loss( | |
G=self.G, | |
D=self.D, | |
latents=latents, | |
latent_labels=latent_labels, | |
reals=reals, | |
real_labels=real_labels | |
) | |
D_loss += self._backward(loss, self.D_opt) | |
if D_reg: | |
if self.D_reg_interval: | |
                        # For lazy regularization, even if the interval
                        # is set to 1, the optimization step is taken
                        # before the gradients of the regularization are gathered.
self._sync_distributed(D=self.D) | |
self.D_opt.step() | |
self.D_opt.zero_grad() | |
D_reg_loss = 0 | |
for i in range(self.subdivisions): | |
latents, latent_labels = self.prior_generator( | |
multi_latent_prob=self.style_mix_prob) | |
reals, real_labels = self._get_batch() | |
_, reg_loss = self.D_reg( | |
G=self.G, | |
D=self.D, | |
latents=latents, | |
latent_labels=latent_labels, | |
reals=reals, | |
real_labels=real_labels | |
) | |
D_reg_loss += self._backward( | |
reg_loss, self.D_opt, mul=self.D_reg_interval or 1) | |
self._sync_distributed(D=self.D) | |
self.D_opt.step() | |
# Re-enable grads for G | |
self.G.requires_grad_(True) | |
if self.tb_writer is not None or verbose: | |
                # When verbose is enabled and/or tensorboard logging is active,
                # we calculate the grad norms here so it is only done once and
                # happens before any metric evaluation that might zero the grads.
G_grad_norm = utils.get_grad_norm_from_optimizer(self.G_opt) | |
D_grad_norm = utils.get_grad_norm_from_optimizer(self.D_opt) | |
for name, metric in self.metrics.items(): | |
if not metric['interval'] or self.seen % metric['interval'] == 0: | |
evaluated_metrics[name] = metric['eval_fn']() | |
# Printing and logging | |
# Tensorboard logging | |
if self.tb_writer is not None: | |
self.tb_writer.add_scalar('Loss/G_loss', G_loss, self.seen) | |
if G_reg: | |
self.tb_writer.add_scalar('Loss/G_reg', G_reg_loss, self.seen) | |
self.tb_writer.add_scalar('Grad_norm/G_reg', G_grad_norm, self.seen) | |
self.tb_writer.add_scalar('Params/pl_avg', self.pl_avg, self.seen) | |
else: | |
self.tb_writer.add_scalar('Grad_norm/G_loss', G_grad_norm, self.seen) | |
self.tb_writer.add_scalar('Loss/D_loss', D_loss, self.seen) | |
if D_reg: | |
self.tb_writer.add_scalar('Loss/D_reg', D_reg_loss, self.seen) | |
self.tb_writer.add_scalar('Grad_norm/D_reg', D_grad_norm, self.seen) | |
else: | |
self.tb_writer.add_scalar('Grad_norm/D_loss', D_grad_norm, self.seen) | |
for name, value in evaluated_metrics.items(): | |
self.tb_writer.add_scalar('Metrics/{}'.format(name), value, self.seen) | |
# Printing | |
if verbose: | |
value_tracker.add('seen', self.seen + 1, beta=0) | |
value_tracker.add('G_lr', self.G_opt.param_groups[0]['lr'], beta=0) | |
value_tracker.add('G_loss', G_loss) | |
if G_reg: | |
value_tracker.add('G_reg', G_reg_loss) | |
value_tracker.add('G_reg_grad_norm', G_grad_norm) | |
value_tracker.add('pl_avg', self.pl_avg, beta=0) | |
else: | |
value_tracker.add('G_loss_grad_norm', G_grad_norm) | |
value_tracker.add('D_lr', self.D_opt.param_groups[0]['lr'], beta=0) | |
value_tracker.add('D_loss', D_loss) | |
if D_reg: | |
value_tracker.add('D_reg', D_reg_loss) | |
value_tracker.add('D_reg_grad_norm', D_grad_norm) | |
else: | |
value_tracker.add('D_loss_grad_norm', D_grad_norm) | |
for name, value in evaluated_metrics.items(): | |
value_tracker.add(name, value, beta=0) | |
progress.write(str(value_tracker)) | |
# Callback | |
for callback in utils.to_list(callbacks) + self.callbacks: | |
callback(self.seen) | |
self.seen += 1 | |
# clear cache | |
torch.cuda.empty_cache() | |
# Handle checkpointing | |
if not self.rank and self.checkpoint_dir and self.checkpoint_interval: | |
if self.seen % self.checkpoint_interval == 0: | |
checkpoint_path = os.path.join( | |
self.checkpoint_dir, | |
'{}_{}'.format(self.seen, time.strftime('%Y-%m-%d_%H-%M-%S')) | |
) | |
self.save_checkpoint(checkpoint_path) | |
if verbose: | |
progress.close() | |
def register_metric(self, name, eval_fn, interval): | |
""" | |
Add a metric. This will be evaluated every `interval` | |
training iteration. Used by tensorboard and progress | |
updates written to stdout while training. | |
Arguments: | |
name (str): A name for the metric. If a metric with | |
this name already exists it will be overwritten. | |
eval_fn (callable): A function that evaluates the metric | |
and returns a python number. | |
interval (int): The interval to evaluate at. | |
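        Example:
            A sketch using a hypothetical `compute_fid` function that
            returns a python float:

                trainer.register_metric(
                    'FID', eval_fn=lambda: compute_fid(trainer.Gs), interval=1000)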
""" | |
self.metrics[name] = {'eval_fn': eval_fn, 'interval': interval} | |
def remove_metric(self, name): | |
""" | |
Remove a metric that was previously registered. | |
Arguments: | |
name (str): Name of the metric. | |
""" | |
if name in self.metrics: | |
del self.metrics[name] | |
else: | |
warnings.warn( | |
'Attempting to remove metric {} '.format(name) + \ | |
'which does not exist.' | |
) | |
def generate_images(self, | |
num_images, | |
seed=None, | |
truncation_psi=None, | |
truncation_cutoff=None, | |
label=None, | |
pixel_min=-1, | |
pixel_max=1): | |
""" | |
Generate some images with the generator and transform them into PIL | |
images and return them as a list. | |
Arguments: | |
num_images (int): Number of images to generate. | |
seed (int, optional): The seed for the random generation | |
of input latent values. | |
truncation_psi (float): See stylegan2.model.Generator.set_truncation() | |
Default value is None. | |
truncation_cutoff (int): See stylegan2.model.Generator.set_truncation() | |
label (int, list, optional): Label to condition all generated images with | |
or multiple labels, one for each generated image. | |
pixel_min (float): The min value in the pixel range of the generator. | |
Default value is -1. | |
            pixel_max (float): The max value in the pixel range of the generator.
Default value is 1. | |
Returns: | |
images (list): List of PIL images. | |
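        Example:
            A sketch of generating and saving a few samples:

                images = trainer.generate_images(
                    num_images=4, seed=0, truncation_psi=0.7)
                images[0].save('sample_0.png')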
""" | |
if seed is None: | |
seed = int(10000 * time.time()) | |
latents, latent_labels = self.prior_generator(num_images, seed=seed) | |
        if label is not None:
assert latent_labels is not None, 'Can not specify label when no labels ' + \ | |
'are used by this model.' | |
label = utils.to_list(label) | |
assert all(isinstance(l, int) for l in label), '`label` can only consist of ' + \ | |
'one or more python integers.' | |
assert len(label) == 1 or len(label) == num_images, '`label` can either ' + \ | |
'specify one label to use for all images or a list of labels of the ' + \ | |
'same length as number of images. Received {} labels '.format(len(label)) + \ | |
'but {} images are to be generated.'.format(num_images) | |
if len(label) == 1: | |
latent_labels.fill_(label[0]) | |
else: | |
latent_labels = torch.tensor(label).to(latent_labels) | |
self.Gs.set_truncation( | |
truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff) | |
with torch.no_grad(): | |
generated = self.Gs(latents=latents, labels=latent_labels) | |
        assert generated.dim() == 4, 'Can only generate images when using a ' + \
            'network built for 2-dimensional data.'
return utils.tensor_to_PIL(generated, pixel_min=pixel_min, pixel_max=pixel_max) | |
def log_images_tensorboard(self, images, name, resize=256): | |
""" | |
Log a list of images to tensorboard by first turning | |
them into a grid. Can not be performed if rank > 0 | |
or tensorboard_log_dir was not given at construction. | |
Arguments: | |
images (list): List of PIL images. | |
name (str): The name to log images for. | |
resize (int, tuple): The height and width to use for | |
each image in the grid. Default value is 256. | |
""" | |
assert self.tb_writer is not None, \ | |
'No tensorboard log dir was specified ' + \ | |
'when constructing this object.' | |
image = utils.stack_images_PIL(images, individual_img_size=resize) | |
image = torchvision.transforms.ToTensor()(image) | |
self.tb_writer.add_image(name, image, self.seen) | |
def add_tensorboard_image_logging(self, | |
name, | |
interval, | |
num_images, | |
resize=256, | |
seed=None, | |
truncation_psi=None, | |
truncation_cutoff=None, | |
label=None, | |
pixel_min=-1, | |
pixel_max=1): | |
""" | |
Set up tensorboard logging of generated images to be performed | |
at a certain training interval. If distributed training is set up | |
and this object does not have the rank 0, no logging will be performed | |
by this object. | |
All arguments except the ones mentioned below have their description | |
in the docstring of `generate_images()` and `log_images_tensorboard()`. | |
Arguments: | |
interval (int): The interval at which to log generated images. | |
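        Example:
            A sketch that logs 4 truncated samples every 1000 iterations
            (assuming `tensorboard_log_dir` was given at construction):

                trainer.add_tensorboard_image_logging(
                    'generated', interval=1000, num_images=4, truncation_psi=0.7)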
""" | |
if self.rank: | |
return | |
def callback(seen): | |
if seen % interval == 0: | |
images = self.generate_images( | |
num_images=num_images, | |
seed=seed, | |
truncation_psi=truncation_psi, | |
truncation_cutoff=truncation_cutoff, | |
label=label, | |
pixel_min=pixel_min, | |
pixel_max=pixel_max | |
) | |
self.log_images_tensorboard( | |
images=images, | |
name=name, | |
resize=resize | |
) | |
self.callbacks.append(callback) | |
def save_checkpoint(self, dir_path): | |
""" | |
Save the current state of this trainer as a checkpoint. | |
NOTE: The dataset can not be serialized and saved so this | |
has to be reconstructed and given when loading this checkpoint. | |
Arguments: | |
dir_path (str): The checkpoint path. | |
""" | |
if not os.path.exists(dir_path): | |
os.makedirs(dir_path) | |
else: | |
assert os.path.isdir(dir_path), '`dir_path` points to a file.' | |
kwargs = self.kwargs.copy() | |
# Update arguments that may have changed since construction | |
kwargs.update( | |
seen=self.seen, | |
pl_avg=float(self.pl_avg) | |
) | |
with open(os.path.join(dir_path, 'kwargs.json'), 'w') as fp: | |
json.dump(kwargs, fp) | |
torch.save(self.G_opt.state_dict(), os.path.join(dir_path, 'G_opt.pth')) | |
torch.save(self.D_opt.state_dict(), os.path.join(dir_path, 'D_opt.pth')) | |
models.save(self.G, os.path.join(dir_path, 'G.pth')) | |
models.save(self.D, os.path.join(dir_path, 'D.pth')) | |
if self.Gs is not None: | |
models.save(self.Gs, os.path.join(dir_path, 'Gs.pth')) | |
    @classmethod
    def load_checkpoint(cls, checkpoint_path, dataset, **kwargs):
""" | |
Load a checkpoint into a new Trainer object and return that | |
object. If the path specified points at a folder containing | |
multiple checkpoints, the latest one will be used. | |
The dataset can not be serialized and saved so it is required | |
to be explicitly given when loading a checkpoint. | |
Arguments: | |
checkpoint_path (str): Path to a checkpoint or to a folder | |
containing one or more checkpoints. | |
dataset (indexable): The dataset to use. | |
**kwargs (keyword arguments): Any other arguments to override | |
the ones saved in the checkpoint. Useful for when training | |
is continued on a different device or when distributed training | |
is changed. | |
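        Example:
            A sketch of resuming from the latest checkpoint in a
            directory, overriding the training device:

                trainer = Trainer.load_checkpoint(
                    'checkpoints/', dataset=dataset, device=0)
                trainer.train(iterations=10000)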
""" | |
checkpoint_path = _find_checkpoint(checkpoint_path) | |
_is_checkpoint(checkpoint_path, enforce=True) | |
with open(os.path.join(checkpoint_path, 'kwargs.json'), 'r') as fp: | |
loaded_kwargs = json.load(fp) | |
loaded_kwargs.update(**kwargs) | |
device = torch.device('cpu') | |
if isinstance(loaded_kwargs['device'], (list, tuple)): | |
device = torch.device(loaded_kwargs['device'][0]) | |
for name in ['G', 'D']: | |
fpath = os.path.join(checkpoint_path, name + '.pth') | |
loaded_kwargs[name] = models.load(fpath, map_location=device) | |
if os.path.exists(os.path.join(checkpoint_path, 'Gs.pth')): | |
loaded_kwargs['Gs'] = models.load( | |
os.path.join(checkpoint_path, 'Gs.pth'), | |
map_location=device if loaded_kwargs['Gs_device'] is None \ | |
else torch.device(loaded_kwargs['Gs_device']) | |
) | |
obj = cls(dataset=dataset, **loaded_kwargs) | |
for name in ['G_opt', 'D_opt']: | |
fpath = os.path.join(checkpoint_path, name + '.pth') | |
state_dict = torch.load(fpath, map_location=device) | |
getattr(obj, name).load_state_dict(state_dict) | |
return obj | |
#---------------------------------------------------------------------------- | |
# Checkpoint helper functions | |
def _is_checkpoint(dir_path, enforce=False): | |
if not dir_path: | |
if enforce: | |
raise ValueError('Not a checkpoint.') | |
return False | |
if not os.path.exists(dir_path): | |
if enforce: | |
raise FileNotFoundError('{} could not be found.'.format(dir_path)) | |
return False | |
if not os.path.isdir(dir_path): | |
if enforce: | |
raise NotADirectoryError('{} is not a directory.'.format(dir_path)) | |
return False | |
fnames = os.listdir(dir_path) | |
for fname in ['G.pth', 'D.pth', 'G_opt.pth', 'D_opt.pth', 'kwargs.json']: | |
if fname not in fnames: | |
if enforce: | |
raise FileNotFoundError( | |
'Could not find {} in {}.'.format(fname, dir_path)) | |
return False | |
return True | |
def _find_checkpoint(dir_path): | |
if not dir_path: | |
return None | |
if not os.path.exists(dir_path) or not os.path.isdir(dir_path): | |
return None | |
if _is_checkpoint(dir_path): | |
return dir_path | |
checkpoint_names = [] | |
for name in os.listdir(dir_path): | |
if _is_checkpoint(os.path.join(dir_path, name)): | |
checkpoint_names.append(name) | |
if not checkpoint_names: | |
return None | |
def get_iteration(name): | |
return int(name.split('_')[0]) | |
def get_timestamp(name): | |
return '_'.join(name.split('_')[1:]) | |
    # Python's sort is stable, so ties in the second (timestamp) sort
    # keep the relative order produced by the first (iteration) sort.
checkpoint_names = sorted( | |
sorted(checkpoint_names, key=get_iteration), key=get_timestamp) | |
return os.path.join(dir_path, checkpoint_names[-1]) | |
#---------------------------------------------------------------------------- | |
# Optimizer builder
def build_opt(net, opt_class, opt_kwargs, reg, reg_interval): | |
    opt_kwargs = dict(opt_kwargs)  # avoid mutating the caller's (possibly shared) dict
    opt_kwargs['lr'] = opt_kwargs.get('lr', 1e-3)
if reg not in [None, False] and reg_interval: | |
mb_ratio = reg_interval / (reg_interval + 1.) | |
opt_kwargs['lr'] *= mb_ratio | |
if 'momentum' in opt_kwargs: | |
opt_kwargs['momentum'] = opt_kwargs['momentum'] ** mb_ratio | |
if 'betas' in opt_kwargs: | |
betas = opt_kwargs['betas'] | |
opt_kwargs['betas'] = (betas[0] ** mb_ratio, betas[1] ** mb_ratio) | |
if isinstance(opt_class, str): | |
opt_class = getattr(torch.optim, opt_class.title()) | |
return opt_class(net.parameters(), **opt_kwargs) | |
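# Worked example of the lazy-regularization adjustment above (illustrative
# numbers): with reg_interval=16 and Adam settings lr=2e-3, betas=(0, 0.99),
# mb_ratio = 16 / 17 ~= 0.941, so the effective hyperparameters become
# lr ~= 1.88e-3 and betas ~= (0 ** 0.941, 0.99 ** 0.941) ~= (0, 0.9906).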
#---------------------------------------------------------------------------- | |
# Reg and loss function fetchers | |
_LOSS_FNS = { | |
'G': { | |
'logistic': loss_fns.G_logistic, | |
'logistic_ns': loss_fns.G_logistic_ns, | |
'wgan': loss_fns.G_wgan | |
}, | |
'D': { | |
'logistic': loss_fns.D_logistic, | |
'wgan': loss_fns.D_wgan | |
} | |
} | |
def get_loss_fn(net, loss): | |
if callable(loss): | |
return loss | |
net = net.upper() | |
assert net in ['G', 'D'], 'Unknown net type {}'.format(net) | |
loss = loss.lower() | |
for name in _LOSS_FNS[net].keys(): | |
if loss == name: | |
return _LOSS_FNS[net][name] | |
    raise ValueError('Unknown {} loss {}'.format(net, loss))
_REG_FNS = { | |
'G': { | |
'pathreg': loss_fns.G_pathreg | |
}, | |
'D': { | |
'r1': loss_fns.D_r1, | |
'r2': loss_fns.D_r2, | |
'gp': loss_fns.D_gp, | |
} | |
} | |
def get_reg_fn(net, reg, **kwargs): | |
    if reg is None:
        return None
    if callable(reg):
        return functools.partial(reg, **kwargs)
    net = net.upper()
    assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
    reg = reg.lower()
    for name in _REG_FNS[net].keys():
        if reg.startswith(name):
            # An optional weight can be appended to the name, e.g. 'r1:10'.
            gamma_chars = [c for c in reg.replace(name, '') if c.isdigit() or c == '.']
            if gamma_chars:
                kwargs.update(gamma=float(''.join(gamma_chars)))
            return functools.partial(_REG_FNS[net][name], **kwargs)
    raise ValueError('Unknown regularizer {}'.format(reg))