Spaces:
Runtime error
Runtime error
File size: 45,089 Bytes
480bfbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 |
import warnings
import functools
import os
import time
import sys
import json
import numpy as np
import torch
import torch.utils.tensorboard
from torch import nn
import torchvision
try:
import apex
from apex import amp
except ImportError:
pass
from . import models, utils, loss_fns
class Trainer:
"""
Class that handles training and logging for stylegan2.
For distributed training, the arguments `rank`, `world_size`,
`master_addr`, `master_port` can all be given as environmnet variables
(only difference is that the keys should be capital cased).
Environment variables if available will override any python
value for the same argument.
Arguments:
G (Generator): The generator model.
D (Discriminator): The discriminator model.
latent_size (int): The size of the latent inputs.
dataset (indexable object): The dataset. Has to implement
'__getitem__' and '__len__'. If `label_size` > 0, this
dataset object has to return both a data entry and its
label when calling '__getitem__'.
device (str, int, list, torch.device): The device to run training on.
Can be a list of integers for parallel training in the same
process. Parallel training can also be achieved by spawning
seperate processes and using the `rank` argument for each
process. In that case, only one device should be specified
per process.
Gs (Generator, optional): A generator copy with the current
moving average of the training generator. If not specified,
a copy of the generator is made for the moving average of
weights.
Gs_beta (float): The beta value for the moving average weights.
Default value is 1 / (2 ^(32 / 10000)).
Gs_device (str, int, torch.device, optional): The device to store
the moving average weights on. If using a different device
than what is specified for the `device` argument, updating
the moving average weights will take longer as the data
will have to be transfered over different devices. If
this argument is not specified, the same device is used
as specified in the `device` argument.
batch_size (int): The total batch size to average gradients
over. This should be the combined batch size of all used
devices (it is later divided by world size for distributed
training).
Example: We want to average gradients over 32 data
entries. To do this we just set `batch_size=32`.
Even if we train on 8 GPUs we still use the same
batch size (each GPU will take 4 data entries per
batch).
Default value is 32.
device_batch_size (int): The number of data entries that can
fit on the specified device at a time.
Example: We want to average gradients over 32 data
entries. To do this we just set `batch_size=32`.
However, our device can only handle a batch of
4 at a time before running out of memory. We
therefor set `device_batch_size=4`. With a
single device (no distributed training), each
batch is split into 32 / 4 parts and gradients
are averaged over all these parts.
Default value is 4.
label_size (int, optional): Number of possible class labels.
This is required for conditioning the GAN with labels.
If not specified it is assumed that no labels are used.
data_workers (int): The number of spawned processes that
handle data loading. Default value is 4.
G_loss (str, callable): The loss function to use
for the generator. If string, it can be one of the
following: 'logistic', 'logistic_ns' or 'wgan'.
If not a string, the callable has to follow
the format of functions found in `stylegan2.loss`.
Default value is 'logistic_ns' (non-saturating logistic).
D_loss (str, callable): The loss function to use
for the discriminator. If string, it can be one of the
following: 'logistic' or 'wgan'.
If not a string, same restriction follows as for `G_loss`.
Default value is 'logistic'.
G_reg (str, callable, None): The regularizer function to use
for the generator. If string, it can only be 'pathreg'
(pathlength regularization). A weight for the regularizer
can be passed after the string name like the following:
G_reg='pathreg:5'
This will assign a weight of 5 to the regularization loss.
If set to None, no geenerator regularization is performed.
Default value is 'pathreg:2'.
G_reg_interval (int): The interval at which to regularize the
generator. If set to 0, the regularization and loss gradients
are combined in a single optimization step every iteration.
If set to 1, the gradients for the regularization and loss
are used separately for two optimization steps. Any value
higher than 1 indicates that regularization should only
be performed at this interval (lazy regularization).
Default value is 4.
G_opt_class (str, class): The optimizer class for the generator.
Default value is 'Adam'.
G_opt_kwargs (dict): Keyword arguments for the generator optimizer
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
G_reg_batch_size (int): Same as `batch_size` but only for
the regularization loss of the generator. Default value
is 16.
G_reg_device_batch_size (int): Same as `device_batch_size`
but only for the regularization loss of the generator.
Default value is 2.
D_reg (str, callable, None): The regularizer function to use
for the discriminator. If string, the following values
can be used: 'r1', 'r2', 'gp'. See doc for `G_reg` for
rest of info on regularizer format.
Default value is 'r1:10'.
D_reg_interval (int): Same as `D_reg_interval` but for the
discriminator. Default value is 16.
D_opt_class (str, class): The optimizer class for the discriminator.
Default value is 'Adam'.
D_opt_kwargs (dict): Keyword arguments for the discriminator optimizer
constructor. Default value is {'lr': 2e-3, 'betas': (0, 0.99)}.
style_mix_prob (float): The probability of passing 2 latents instead of 1
to the generator during training. Default value is 0.9.
G_iter (int): Number of generator iterations for every full training
iteration. Default value is 1.
D_iter (int): Number of discriminator iterations for every full training
iteration. Default value is 1.
pl_avg (float, torch.Tensor): The average pathlength starting value for
pathlength regularization of the generator. Default value is 0.
tensorboard_log_dir (str, optional): A path to a directory to log training values
in for tensorboard. Only used without distributed training or when
distributed training is enabled and the rank of this trainer is 0.
checkpoint_dir (str, optional): A path to a directory to save training
checkpoints to. If not specified, not checkpoints are automatically
saved during training.
checkpoint_interval (int): The interval at which to save training checkpoints.
Default value is 10000.
seen (int): The number of previously trained iterations. Used for logging.
Default value is 0.
half (bool): Use mixed precision training. Default value is False.
rank (int, optional): If set, use distributed training. Expects that
this object has been constructed with the same arguments except
for `rank` in different processes.
world_size (int, optional): If using distributed training, this specifies
the number of nodes in the training.
master_addr (str): The master address for distributed training.
Default value is '127.0.0.1'.
master_port (str): The master port for distributed training.
Default value is '23456'.
"""
def __init__(self,
G,
D,
latent_size,
dataset,
device,
Gs=None,
Gs_beta=0.5 ** (32 / 10000),
Gs_device=None,
batch_size=32,
device_batch_size=4,
label_size=0,
data_workers=4,
G_loss='logistic_ns',
D_loss='logistic',
G_reg='pathreg:2',
G_reg_interval=4,
G_opt_class='Adam',
G_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
G_reg_batch_size=None,
G_reg_device_batch_size=None,
D_reg='r1:10',
D_reg_interval=16,
D_opt_class='Adam',
D_opt_kwargs={'lr': 2e-3, 'betas': (0, 0.99)},
style_mix_prob=0.9,
G_iter=1,
D_iter=1,
pl_avg=0.,
tensorboard_log_dir=None,
checkpoint_dir=None,
checkpoint_interval=10000,
seen=0,
half=False,
rank=None,
world_size=None,
master_addr='127.0.0.1',
master_port='23456'):
assert not isinstance(G, nn.parallel.DistributedDataParallel) and \
not isinstance(D, nn.parallel.DistributedDataParallel), \
'Encountered a model wrapped in `DistributedDataParallel`. ' + \
'Distributed parallelism is handled by this class and can ' + \
'not be initialized before.'
# We store the training settings in a dict that can be saved as a json file.
kwargs = locals()
# First we remove the arguments that can not be turned into json.
kwargs.pop('self')
kwargs.pop('G')
kwargs.pop('D')
kwargs.pop('Gs')
kwargs.pop('dataset')
# Some arguments may have to be turned into strings to be compatible with json.
kwargs.update(pl_avg=float(pl_avg))
if isinstance(device, torch.device):
kwargs.update(device=str(device))
if isinstance(Gs_device, torch.device):
kwargs.update(device=str(Gs_device))
self.kwargs = kwargs
if device or device == 0:
if isinstance(device, (tuple, list)):
self.device = torch.device(device[0])
else:
self.device = torch.device(device)
else:
self.device = torch.device('cpu')
if self.device.index is not None:
torch.cuda.set_device(self.device.index)
else:
assert not half, 'Mixed precision training only available ' + \
'for CUDA devices.'
# Set up the models
self.G = G.train().to(self.device)
self.D = D.train().to(self.device)
if isinstance(device, (tuple, list)) and len(device) > 1:
assert all(isinstance(dev, int) for dev in device), \
'Multiple devices have to be specified as a list ' + \
'or tuple of integers corresponding to device indices.'
# TODO: Look into bug with torch.autograd.grad and nn.DataParallel
# In the meanwhile just prohibit its use together.
assert G_reg is None and D_reg is None, 'Regularization ' + \
'currently not supported for multi-gpu training in single process. ' + \
'Please use distributed training with one device per process instead.'
device_batch_size *= len(device)
def to_data_parallel(model):
if not isinstance(model, nn.DataParallel):
return nn.DataParallel(model, device_ids=device)
return model
self.G = to_data_parallel(self.G)
self.D = to_data_parallel(self.D)
# Default generator reg batch size is the global batch size
# unless it has been specified otherwise.
G_reg_batch_size = G_reg_batch_size or batch_size
G_reg_device_batch_size = G_reg_device_batch_size or device_batch_size
# Set up distributed training
rank = os.environ.get('RANK', rank)
if rank is not None:
rank = int(rank)
addr = os.environ.get('MASTER_ADDR', master_addr)
port = os.environ.get('MASTER_PORT', master_port)
world_size = os.environ.get('WORLD_SIZE', world_size)
assert world_size is not None, 'Distributed training ' + \
'requires specifying world size.'
world_size = int(world_size)
assert self.device.index is not None, \
'Distributed training is only supported for CUDA.'
assert batch_size % world_size == 0, 'Batch size has to be ' + \
'evenly divisible by world size.'
assert G_reg_batch_size % world_size == 0, 'G reg batch size has to be ' + \
'evenly divisible by world size.'
batch_size = batch_size // world_size
G_reg_batch_size = G_reg_batch_size // world_size
init_method = 'tcp://{}:{}'.format(addr, port)
torch.distributed.init_process_group(
backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
else:
world_size = 1
self.rank = rank
self.world_size = world_size
# Set up variable to keep track of moving average of path lengths
self.pl_avg = torch.tensor(
pl_avg, dtype=torch.float16 if half else torch.float32, device=self.device)
# Broadcast parameters from rank 0 if running distributed
self._sync_distributed(G=self.G, D=self.D, broadcast_weights=True)
# Set up moving average of generator
# Only for non-distributed training or
# if rank is 0
if not self.rank:
# Values for `rank`: None -> not distributed, 0 -> distributed and 'main' node
self.Gs = Gs
if not isinstance(Gs, utils.MovingAverageModule):
self.Gs = utils.MovingAverageModule(
from_module=self.G,
to_module=Gs,
param_beta=Gs_beta,
device=self.device if Gs_device is None else Gs_device
)
else:
self.Gs = None
# Set up loss and regularization functions
self.G_loss = get_loss_fn('G', G_loss)
self.D_loss = get_loss_fn('D', D_loss)
self.G_reg = get_reg_fn('G', G_reg, pl_avg=self.pl_avg)
self.D_reg = get_reg_fn('D', D_reg)
self.G_reg_interval = G_reg_interval
self.D_reg_interval = D_reg_interval
self.G_iter = G_iter
self.D_iter = D_iter
# Set up optimizers (adjust hyperparameters if lazy regularization is active)
self.G_opt = build_opt(self.G, G_opt_class, G_opt_kwargs, self.G_reg, self.G_reg_interval)
self.D_opt = build_opt(self.D, D_opt_class, D_opt_kwargs, self.D_reg, self.D_reg_interval)
# Set up mixed precision training
if half:
assert 'apex' in sys.modules, 'Can not run mixed precision ' + \
'training (`half=True`) without the apex module.'
(self.G, self.D), (self.G_opt, self.D_opt) = amp.initialize(
[self.G, self.D], [self.G_opt, self.D_opt], opt_level='O1')
self.half = half
# Data
sampler = None
if self.rank is not None:
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
self.dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=device_batch_size,
num_workers=data_workers,
shuffle=sampler is None,
pin_memory=self.device.index is not None,
drop_last=True,
sampler=sampler
)
self.dataloader_iter = None
self.prior_generator = utils.PriorGenerator(
latent_size=latent_size,
label_size=label_size,
batch_size=device_batch_size,
device=self.device
)
assert batch_size % device_batch_size == 0, \
'Batch size has to be evenly divisible by the product of ' + \
'device batch size and world size.'
self.subdivisions = batch_size // device_batch_size
assert G_reg_batch_size % G_reg_device_batch_size == 0, \
'G reg batch size has to be evenly divisible by the product of ' + \
'G reg device batch size and world size.'
self.G_reg_subdivisions = G_reg_batch_size // G_reg_device_batch_size
self.G_reg_device_batch_size = G_reg_device_batch_size
self.tb_writer = None
if tensorboard_log_dir and not self.rank:
self.tb_writer = torch.utils.tensorboard.SummaryWriter(tensorboard_log_dir)
self.label_size = label_size
self.style_mix_prob = style_mix_prob
self.checkpoint_dir = checkpoint_dir
self.checkpoint_interval = checkpoint_interval
self.seen = seen
self.metrics = {}
self.callbacks = []
def _get_batch(self):
"""
Fetch a batch and its labels. If no labels are
available the returned labels will be `None`.
Returns:
data
labels
"""
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
try:
batch = next(self.dataloader_iter)
except StopIteration:
self.dataloader_iter = None
return self._get_batch()
if isinstance(batch, (tuple, list)):
if len(batch) > 1:
data, label = batch[:2]
else:
data, label = batch[0], None
else:
data, label = batch, None
if not self.label_size:
label = None
if torch.is_tensor(data):
data = data.to(self.device)
if torch.is_tensor(label):
label = label.to(self.device)
return data, label
def _sync_distributed(self, G=None, D=None, broadcast_weights=False):
"""
Sync the gradients (and alternatively the weights) of
the specified networks over the distributed training
nodes. Varying buffers are broadcasted from rank 0.
If no distributed training is not enabled, no action
is taken and this is a no-op function.
Arguments:
G (Generator, optional)
D (Discriminator, optional)
broadcast_weights (bool): Broadcast the weights from
node of rank 0 to all other ranks. Default
value is False.
"""
if self.rank is None:
return
for net in [G, D]:
if net is None:
continue
for p in net.parameters():
if p.grad is not None:
torch.distributed.all_reduce(p.grad, async_op=True)
if broadcast_weights:
torch.distributed.broadcast(p.data, src=0, async_op=True)
if G is not None:
if G.dlatent_avg is not None:
torch.distributed.broadcast(G.dlatent_avg, src=0, async_op=True)
if self.pl_avg is not None:
torch.distributed.broadcast(self.pl_avg, src=0, async_op=True)
if G is not None or D is not None:
torch.distributed.barrier(async_op=False)
def _backward(self, loss, opt, mul=1, subdivisions=None):
"""
Reduce loss by world size and subdivisions before
calling backward for the loss. Loss scaling is
performed when mixed precision training is
enabled.
Arguments:
loss (torch.Tensor)
opt (torch.optim.Optimizer)
mul (float): Loss weight. Default value is 1.
subdivisions (int, optional): The number of
subdivisions to divide by. If this is
not specified, the subdvisions from
the specified batch and device size
at construction is used.
Returns:
loss (torch.Tensor): The loss scaled by mul
and subdivisions but not by world size.
"""
if loss is None:
return 0
mul /= subdivisions or self.subdivisions
mul /= self.world_size or 1
if mul != 1:
loss *= mul
if self.half:
with amp.scale_loss(loss, opt) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
#get the scalar only
return loss.item() * (self.world_size or 1)
def train(self, iterations, callbacks=None, verbose=True):
"""
Train the models for a specific number of iterations.
Arguments:
iterations (int): Number of iterations to train for.
callbacks (callable, list, optional): One
or more callbacks to call at the end of each
iteration. The function is given the total
number of batches that have been processed since
this trainer object was initialized (not reset
when loading a saved checkpoint).
Default value is None (unused).
verbose (bool): Write progress to stdout.
Default value is True.
"""
evaluated_metrics = {}
if self.rank:
verbose=False
if verbose:
progress = utils.ProgressWriter(iterations)
value_tracker = utils.ValueTracker()
for _ in range(iterations):
# Figure out if G and/or D be
# regularized this iteration
G_reg = self.G_reg is not None
if self.G_reg_interval and G_reg:
G_reg = self.seen % self.G_reg_interval == 0
D_reg = self.D_reg is not None
if self.D_reg_interval and D_reg:
D_reg = self.seen % self.D_reg_interval == 0
# -----| Train G |----- #
# Disable gradients for D while training G
self.D.requires_grad_(False)
for _ in range(self.G_iter):
self.G_opt.zero_grad()
G_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
loss, _ = self.G_loss(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels
)
G_loss += self._backward(loss, self.G_opt)
if G_reg:
if self.G_reg_interval:
# For lazy regularization, even if the interval
# is set to 1, the optimization step is taken
# before the gradients of the regularization is gathered.
self._sync_distributed(G=self.G)
self.G_opt.step()
self.G_opt.zero_grad()
G_reg_loss = 0
# Pathreg is expensive to compute which
# is why G regularization has its own settings
# for subdivisions and batch size.
for i in range(self.G_reg_subdivisions):
latents, latent_labels = self.prior_generator(
batch_size=self.G_reg_device_batch_size,
multi_latent_prob=self.style_mix_prob
)
_, reg_loss = self.G_reg(
G=self.G,
latents=latents,
latent_labels=latent_labels
)
G_reg_loss += self._backward(
reg_loss,
self.G_opt, mul=self.G_reg_interval or 1,
subdivisions=self.G_reg_subdivisions
)
self._sync_distributed(G=self.G)
self.G_opt.step()
# Update moving average of weights after
# each G training subiteration
if self.Gs is not None:
self.Gs.update()
# Re-enable gradients for D
self.D.requires_grad_(True)
# -----| Train D |----- #
# Disable gradients for G while training D
self.G.requires_grad_(False)
for _ in range(self.D_iter):
self.D_opt.zero_grad()
D_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
reals, real_labels = self._get_batch()
loss, _ = self.D_loss(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels,
reals=reals,
real_labels=real_labels
)
D_loss += self._backward(loss, self.D_opt)
if D_reg:
if self.D_reg_interval:
# For lazy regularization, even if the interval
# is set to 1, the optimization step is taken
# before the gradients of the regularization is gathered.
self._sync_distributed(D=self.D)
self.D_opt.step()
self.D_opt.zero_grad()
D_reg_loss = 0
for i in range(self.subdivisions):
latents, latent_labels = self.prior_generator(
multi_latent_prob=self.style_mix_prob)
reals, real_labels = self._get_batch()
_, reg_loss = self.D_reg(
G=self.G,
D=self.D,
latents=latents,
latent_labels=latent_labels,
reals=reals,
real_labels=real_labels
)
D_reg_loss += self._backward(
reg_loss, self.D_opt, mul=self.D_reg_interval or 1)
self._sync_distributed(D=self.D)
self.D_opt.step()
# Re-enable grads for G
self.G.requires_grad_(True)
if self.tb_writer is not None or verbose:
# In case verbose is true and tensorboard logging enabled
# we calculate grad norm here to only do it once as well
# as making sure we do it before any metrics that may
# possibly zero the grads.
G_grad_norm = utils.get_grad_norm_from_optimizer(self.G_opt)
D_grad_norm = utils.get_grad_norm_from_optimizer(self.D_opt)
for name, metric in self.metrics.items():
if not metric['interval'] or self.seen % metric['interval'] == 0:
evaluated_metrics[name] = metric['eval_fn']()
# Printing and logging
# Tensorboard logging
if self.tb_writer is not None:
self.tb_writer.add_scalar('Loss/G_loss', G_loss, self.seen)
if G_reg:
self.tb_writer.add_scalar('Loss/G_reg', G_reg_loss, self.seen)
self.tb_writer.add_scalar('Grad_norm/G_reg', G_grad_norm, self.seen)
self.tb_writer.add_scalar('Params/pl_avg', self.pl_avg, self.seen)
else:
self.tb_writer.add_scalar('Grad_norm/G_loss', G_grad_norm, self.seen)
self.tb_writer.add_scalar('Loss/D_loss', D_loss, self.seen)
if D_reg:
self.tb_writer.add_scalar('Loss/D_reg', D_reg_loss, self.seen)
self.tb_writer.add_scalar('Grad_norm/D_reg', D_grad_norm, self.seen)
else:
self.tb_writer.add_scalar('Grad_norm/D_loss', D_grad_norm, self.seen)
for name, value in evaluated_metrics.items():
self.tb_writer.add_scalar('Metrics/{}'.format(name), value, self.seen)
# Printing
if verbose:
value_tracker.add('seen', self.seen + 1, beta=0)
value_tracker.add('G_lr', self.G_opt.param_groups[0]['lr'], beta=0)
value_tracker.add('G_loss', G_loss)
if G_reg:
value_tracker.add('G_reg', G_reg_loss)
value_tracker.add('G_reg_grad_norm', G_grad_norm)
value_tracker.add('pl_avg', self.pl_avg, beta=0)
else:
value_tracker.add('G_loss_grad_norm', G_grad_norm)
value_tracker.add('D_lr', self.D_opt.param_groups[0]['lr'], beta=0)
value_tracker.add('D_loss', D_loss)
if D_reg:
value_tracker.add('D_reg', D_reg_loss)
value_tracker.add('D_reg_grad_norm', D_grad_norm)
else:
value_tracker.add('D_loss_grad_norm', D_grad_norm)
for name, value in evaluated_metrics.items():
value_tracker.add(name, value, beta=0)
progress.write(str(value_tracker))
# Callback
for callback in utils.to_list(callbacks) + self.callbacks:
callback(self.seen)
self.seen += 1
# clear cache
torch.cuda.empty_cache()
# Handle checkpointing
if not self.rank and self.checkpoint_dir and self.checkpoint_interval:
if self.seen % self.checkpoint_interval == 0:
checkpoint_path = os.path.join(
self.checkpoint_dir,
'{}_{}'.format(self.seen, time.strftime('%Y-%m-%d_%H-%M-%S'))
)
self.save_checkpoint(checkpoint_path)
if verbose:
progress.close()
def register_metric(self, name, eval_fn, interval):
"""
Add a metric. This will be evaluated every `interval`
training iteration. Used by tensorboard and progress
updates written to stdout while training.
Arguments:
name (str): A name for the metric. If a metric with
this name already exists it will be overwritten.
eval_fn (callable): A function that evaluates the metric
and returns a python number.
interval (int): The interval to evaluate at.
"""
self.metrics[name] = {'eval_fn': eval_fn, 'interval': interval}
def remove_metric(self, name):
"""
Remove a metric that was previously registered.
Arguments:
name (str): Name of the metric.
"""
if name in self.metrics:
del self.metrics[name]
else:
warnings.warn(
'Attempting to remove metric {} '.format(name) + \
'which does not exist.'
)
def generate_images(self,
num_images,
seed=None,
truncation_psi=None,
truncation_cutoff=None,
label=None,
pixel_min=-1,
pixel_max=1):
"""
Generate some images with the generator and transform them into PIL
images and return them as a list.
Arguments:
num_images (int): Number of images to generate.
seed (int, optional): The seed for the random generation
of input latent values.
truncation_psi (float): See stylegan2.model.Generator.set_truncation()
Default value is None.
truncation_cutoff (int): See stylegan2.model.Generator.set_truncation()
label (int, list, optional): Label to condition all generated images with
or multiple labels, one for each generated image.
pixel_min (float): The min value in the pixel range of the generator.
Default value is -1.
pixel_min (float): The max value in the pixel range of the generator.
Default value is 1.
Returns:
images (list): List of PIL images.
"""
if seed is None:
seed = int(10000 * time.time())
latents, latent_labels = self.prior_generator(num_images, seed=seed)
if label:
assert latent_labels is not None, 'Can not specify label when no labels ' + \
'are used by this model.'
label = utils.to_list(label)
assert all(isinstance(l, int) for l in label), '`label` can only consist of ' + \
'one or more python integers.'
assert len(label) == 1 or len(label) == num_images, '`label` can either ' + \
'specify one label to use for all images or a list of labels of the ' + \
'same length as number of images. Received {} labels '.format(len(label)) + \
'but {} images are to be generated.'.format(num_images)
if len(label) == 1:
latent_labels.fill_(label[0])
else:
latent_labels = torch.tensor(label).to(latent_labels)
self.Gs.set_truncation(
truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff)
with torch.no_grad():
generated = self.Gs(latents=latents, labels=latent_labels)
assert generated.dim() - 2 == 2, 'Can only generate images when using a ' + \
'network built for 2-dimensional data.'
assert generated.dim() == 4, 'Only generators that produce 2d data ' + \
'can be used to generate images.'
return utils.tensor_to_PIL(generated, pixel_min=pixel_min, pixel_max=pixel_max)
def log_images_tensorboard(self, images, name, resize=256):
"""
Log a list of images to tensorboard by first turning
them into a grid. Can not be performed if rank > 0
or tensorboard_log_dir was not given at construction.
Arguments:
images (list): List of PIL images.
name (str): The name to log images for.
resize (int, tuple): The height and width to use for
each image in the grid. Default value is 256.
"""
assert self.tb_writer is not None, \
'No tensorboard log dir was specified ' + \
'when constructing this object.'
image = utils.stack_images_PIL(images, individual_img_size=resize)
image = torchvision.transforms.ToTensor()(image)
self.tb_writer.add_image(name, image, self.seen)
def add_tensorboard_image_logging(self,
name,
interval,
num_images,
resize=256,
seed=None,
truncation_psi=None,
truncation_cutoff=None,
label=None,
pixel_min=-1,
pixel_max=1):
"""
Set up tensorboard logging of generated images to be performed
at a certain training interval. If distributed training is set up
and this object does not have the rank 0, no logging will be performed
by this object.
All arguments except the ones mentioned below have their description
in the docstring of `generate_images()` and `log_images_tensorboard()`.
Arguments:
interval (int): The interval at which to log generated images.
"""
if self.rank:
return
def callback(seen):
if seen % interval == 0:
images = self.generate_images(
num_images=num_images,
seed=seed,
truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff,
label=label,
pixel_min=pixel_min,
pixel_max=pixel_max
)
self.log_images_tensorboard(
images=images,
name=name,
resize=resize
)
self.callbacks.append(callback)
def save_checkpoint(self, dir_path):
"""
Save the current state of this trainer as a checkpoint.
NOTE: The dataset can not be serialized and saved so this
has to be reconstructed and given when loading this checkpoint.
Arguments:
dir_path (str): The checkpoint path.
"""
if not os.path.exists(dir_path):
os.makedirs(dir_path)
else:
assert os.path.isdir(dir_path), '`dir_path` points to a file.'
kwargs = self.kwargs.copy()
# Update arguments that may have changed since construction
kwargs.update(
seen=self.seen,
pl_avg=float(self.pl_avg)
)
with open(os.path.join(dir_path, 'kwargs.json'), 'w') as fp:
json.dump(kwargs, fp)
torch.save(self.G_opt.state_dict(), os.path.join(dir_path, 'G_opt.pth'))
torch.save(self.D_opt.state_dict(), os.path.join(dir_path, 'D_opt.pth'))
models.save(self.G, os.path.join(dir_path, 'G.pth'))
models.save(self.D, os.path.join(dir_path, 'D.pth'))
if self.Gs is not None:
models.save(self.Gs, os.path.join(dir_path, 'Gs.pth'))
@classmethod
def load_checkpoint(cls, checkpoint_path, dataset, **kwargs):
    """
    Load a checkpoint into a new Trainer object and return that
    object. If the path specified points at a folder containing
    multiple checkpoints, the latest one will be used.
    The dataset can not be serialized and saved so it is required
    to be explicitly given when loading a checkpoint.
    `G`, `D` and the optimizer states are mapped onto the first entry of
    the `device` argument when it is a list/tuple, otherwise onto the CPU.
    `Gs` (if saved) uses `Gs_device` when that is not None.
    Arguments:
        checkpoint_path (str): Path to a checkpoint or to a folder
            containing one or more checkpoints.
        dataset (indexable): The dataset to use.
        **kwargs (keyword arguments): Any other arguments to override
            the ones saved in the checkpoint. Useful for when training
            is continued on a different device or when distributed training
            is changed.
    Returns:
        trainer: A new instance of `cls` with restored optimizer states.
    Raises:
        ValueError: If no checkpoint is found at `checkpoint_path`.
    """
    checkpoint_path = _find_checkpoint(checkpoint_path)
    _is_checkpoint(checkpoint_path, enforce=True)

    def source(fname):
        return os.path.join(checkpoint_path, fname)

    with open(source('kwargs.json'), 'r') as fp:
        loaded_kwargs = json.load(fp)
    # Explicitly passed arguments take precedence over the saved ones.
    loaded_kwargs.update(kwargs)

    saved_device = loaded_kwargs['device']
    if isinstance(saved_device, (list, tuple)):
        device = torch.device(saved_device[0])
    else:
        device = torch.device('cpu')

    for net_name in ('G', 'D'):
        loaded_kwargs[net_name] = models.load(
            source(net_name + '.pth'), map_location=device)

    gs_path = source('Gs.pth')
    if os.path.exists(gs_path):
        gs_device = loaded_kwargs['Gs_device']
        loaded_kwargs['Gs'] = models.load(
            gs_path,
            map_location=device if gs_device is None else torch.device(gs_device)
        )

    trainer = cls(dataset=dataset, **loaded_kwargs)
    # Optimizers only exist once the trainer is constructed, so their
    # states are restored afterwards.
    for opt_name in ('G_opt', 'D_opt'):
        opt_state = torch.load(source(opt_name + '.pth'), map_location=device)
        getattr(trainer, opt_name).load_state_dict(opt_state)
    return trainer
#----------------------------------------------------------------------------
# Checkpoint helper functions
def _is_checkpoint(dir_path, enforce=False):
if not dir_path:
if enforce:
raise ValueError('Not a checkpoint.')
return False
if not os.path.exists(dir_path):
if enforce:
raise FileNotFoundError('{} could not be found.'.format(dir_path))
return False
if not os.path.isdir(dir_path):
if enforce:
raise NotADirectoryError('{} is not a directory.'.format(dir_path))
return False
fnames = os.listdir(dir_path)
for fname in ['G.pth', 'D.pth', 'G_opt.pth', 'D_opt.pth', 'kwargs.json']:
if fname not in fnames:
if enforce:
raise FileNotFoundError(
'Could not find {} in {}.'.format(fname, dir_path))
return False
return True
def _find_checkpoint(dir_path):
if not dir_path:
return None
if not os.path.exists(dir_path) or not os.path.isdir(dir_path):
return None
if _is_checkpoint(dir_path):
return dir_path
checkpoint_names = []
for name in os.listdir(dir_path):
if _is_checkpoint(os.path.join(dir_path, name)):
checkpoint_names.append(name)
if not checkpoint_names:
return None
def get_iteration(name):
return int(name.split('_')[0])
def get_timestamp(name):
return '_'.join(name.split('_')[1:])
# Python sort is stable, meaning that this sort operation
# will guarantee that the order of values after the first
# sort will stay for a set of values that have the same
# key value.
checkpoint_names = sorted(
sorted(checkpoint_names, key=get_iteration), key=get_timestamp)
return os.path.join(dir_path, checkpoint_names[-1])
#----------------------------------------------------------------------------
# Optimizer builder
def build_opt(net, opt_class, opt_kwargs, reg, reg_interval):
    """
    Construct an optimizer for the parameters of `net`.

    When a regularizer is used (`reg` is not None/False) with a non-zero
    `reg_interval`, the learning rate is multiplied by
    `mb_ratio = reg_interval / (reg_interval + 1)` and `momentum` / `betas`
    are raised to the power `mb_ratio`. NOTE(review): this appears to
    mirror StyleGAN2's lazy-regularization hyperparameter correction.
    Arguments:
        net (torch.nn.Module): Network whose `parameters()` are optimized.
        opt_class (str or type): An optimizer class, or the name of a
            `torch.optim` optimizer (case-insensitive, e.g. 'adam',
            'sgd', 'rmsprop').
        opt_kwargs (dict or None): Optimizer keyword arguments. Not
            modified; `lr` defaults to 1e-3.
        reg: The regularizer setting; None/False means no regularization.
        reg_interval (int): Steps between regularization updates.
    Returns:
        optimizer (torch.optim.Optimizer)
    Raises:
        ValueError: If `opt_class` is a string naming no `torch.optim`
            optimizer.
    """
    # Work on a copy: rescaling the caller's dict in place would compound
    # the lr/momentum correction whenever that dict is reused.
    opt_kwargs = dict(opt_kwargs or {})
    opt_kwargs.setdefault('lr', 1e-3)
    if reg not in [None, False] and reg_interval:
        mb_ratio = reg_interval / (reg_interval + 1.)
        opt_kwargs['lr'] *= mb_ratio
        if 'momentum' in opt_kwargs:
            opt_kwargs['momentum'] = opt_kwargs['momentum'] ** mb_ratio
        if 'betas' in opt_kwargs:
            betas = opt_kwargs['betas']
            opt_kwargs['betas'] = (betas[0] ** mb_ratio, betas[1] ** mb_ratio)
    if isinstance(opt_class, str):
        opt_class = _resolve_optimizer_class(opt_class)
    return opt_class(net.parameters(), **opt_kwargs)


def _resolve_optimizer_class(name):
    """Return the `torch.optim` optimizer class matching `name`, ignoring case."""
    # str.title() cannot produce names such as 'SGD', 'RMSprop' or 'AdamW',
    # so match case-insensitively against the actual optimizer classes.
    optimizers = {
        attr.lower(): obj for attr, obj in vars(torch.optim).items()
        if isinstance(obj, type) and issubclass(obj, torch.optim.Optimizer)
    }
    try:
        return optimizers[name.lower()]
    except KeyError:
        raise ValueError('Unknown optimizer {}'.format(name)) from None
#----------------------------------------------------------------------------
# Reg and loss function fetchers
# Registry of adversarial loss functions, keyed first by network ('G' or
# 'D') and then by lowercase loss name; looked up by `get_loss_fn()`.
_LOSS_FNS = dict(
    G=dict(
        logistic=loss_fns.G_logistic,
        logistic_ns=loss_fns.G_logistic_ns,
        wgan=loss_fns.G_wgan,
    ),
    D=dict(
        logistic=loss_fns.D_logistic,
        wgan=loss_fns.D_wgan,
    ),
)
def get_loss_fn(net, loss):
    """
    Fetch the loss function for the generator or discriminator.

    Arguments:
        net (str): 'G' or 'D' (case-insensitive).
        loss (str or callable): A callable is returned unchanged; a string
            is looked up (case-insensitively) in `_LOSS_FNS[net]`.
    Returns:
        loss_fn (callable)
    Raises:
        ValueError: If `loss` is not a registered loss name for `net`.
    """
    if callable(loss):
        return loss
    net = net.upper()
    assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
    loss = loss.lower()
    loss_fn = _LOSS_FNS[net].get(loss)
    if loss_fn is None:
        raise ValueError('Unknown {} loss {}'.format(net, loss))
    return loss_fn
# Registry of regularization functions, keyed first by network ('G' or
# 'D') and then by the lowercase name prefix that `get_reg_fn()` matches.
_REG_FNS = dict(
    G=dict(
        pathreg=loss_fns.G_pathreg,
    ),
    D=dict(
        r1=loss_fns.D_r1,
        r2=loss_fns.D_r2,
        gp=loss_fns.D_gp,
    ),
)
def get_reg_fn(net, reg, **kwargs):
    """
    Fetch a regularization function for the generator or discriminator.

    Arguments:
        net (str): 'G' or 'D' (case-insensitive).
        reg (str, callable, None or False): None or False disables
            regularization (same convention as `build_opt`). A callable
            is returned as `functools.partial(reg, **kwargs)`. A string
            must start with a name registered in `_REG_FNS[net]`; the
            digits and dots left after removing that name are parsed as a
            float `gamma` keyword, e.g. 'r1_10' -> `D_r1` with gamma=10.0.
        **kwargs (keyword arguments): Bound to the returned function.
    Returns:
        reg_fn (callable or None)
    Raises:
        ValueError: If `reg` is not a registered regularizer for `net`.
    """
    if reg in (None, False):
        return None
    if callable(reg):
        # User-supplied regularizer: just bind the extra arguments.
        return functools.partial(reg, **kwargs)
    net = net.upper()
    assert net in ['G', 'D'], 'Unknown net type {}'.format(net)
    reg = reg.lower()
    for name, reg_fn in _REG_FNS[net].items():
        if reg.startswith(name):
            gamma_chars = [c for c in reg.replace(name, '') if c.isdigit() or c == '.']
            if gamma_chars:
                kwargs.update(gamma=float(''.join(gamma_chars)))
            return functools.partial(reg_fn, **kwargs)
    raise ValueError('Unknown regularizer {}'.format(reg))
|