"""This script is the training script for Deep3DFaceRecon_pytorch
"""
import os
import time
import numpy as np
import torch
from options.train_options import TrainOptions
from data import create_dataset
from deep_3drecon_models import create_model
from util.visualizer import MyVisualizer
from util.util import genvalconf
import torch.multiprocessing as mp
import torch.distributed as dist
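
# setup()/cleanup() bracket the lifetime of the torch.distributed process
# group. Rendezvous happens on localhost with the "gloo" backend; the port is
# taken from train_opt.ddp_port (see main() below).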
def setup(rank, world_size, port):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = port

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()
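
# main() is the per-process entry point: `rank` indexes this worker (and the
# CUDA device it pins), `world_size` is the total number of workers. Under DDP
# every rank executes this whole function; rank 0 additionally handles logging,
# visualization, and checkpointing.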
def main(rank, world_size, train_opt):
    val_opt = genvalconf(train_opt, isTrain=False)  # derive an eval-mode config from the training options
    device = torch.device(rank)
    torch.cuda.set_device(device)
    use_ddp = train_opt.use_ddp

    if use_ddp:
        setup(rank, world_size, train_opt.ddp_port)
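
    # Each rank builds its own dataset handles; create_dataset() receives the
    # rank so it can (presumably) shard or seed sampling per process.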
    train_dataset, val_dataset = create_dataset(train_opt, rank=rank), create_dataset(val_opt, rank=rank)
    train_dataset_batches, val_dataset_batches = \
        len(train_dataset) // train_opt.batch_size, len(val_dataset) // val_opt.batch_size

    model = create_model(train_opt)  # create a model given train_opt.model and other options
    model.setup(train_opt)
    model.device = device
    model.parallelize()
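
    # Only rank 0 prints network info and owns the visualizer; every later use
    # of `visualizer` is correspondingly guarded by a rank == 0 check.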
    if rank == 0:
        print('The number of training batches = %d, the number of validation batches = %d'
              % (train_dataset_batches, val_dataset_batches))
        model.print_networks(train_opt.verbose)
        visualizer = MyVisualizer(train_opt)  # create a visualizer that displays/saves images and plots

    total_iters = train_dataset_batches * (train_opt.epoch_count - 1)  # iterations already completed when resuming from epoch_count
    t_data = 0
    t_val = 0
    optimize_time = 0.1  # exponentially smoothed optimization time, seeded at 0.1 s
    batch_size = 1 if train_opt.display_per_batch else train_opt.batch_size
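    # total_iters and epoch_iter advance by `batch_size` per step, so the
    # *_freq options count batches when display_per_batch is set (step of 1)
    # and samples otherwise.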

    if use_ddp:
        dist.barrier()  # make sure every rank finished setup before training starts

    times = []
    for epoch in range(train_opt.epoch_count, train_opt.n_epochs + 1):  # outer loop over epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
        epoch_start_time = time.time()  # timer for the entire epoch
        iter_data_time = time.time()    # timer for data loading per iteration
        epoch_iter = 0                  # number of training iterations in the current epoch, reset every epoch

        train_dataset.set_epoch(epoch)  # re-seed per-epoch shuffling (DistributedSampler-style)
        for i, train_data in enumerate(train_dataset):  # inner loop within one epoch
            iter_start_time = time.time()  # timer for computation per iteration
            if total_iters % train_opt.print_freq == 0:
                t_data = iter_start_time - iter_data_time  # time spent waiting on data loading

            total_iters += batch_size
            epoch_iter += batch_size

            torch.cuda.synchronize()  # flush pending GPU work so the timer is accurate
            optimize_start_time = time.time()
            model.set_input(train_data)   # unpack train_data from dataset and apply preprocessing
            model.optimize_parameters()   # calculate loss functions, get gradients, update network weights
            torch.cuda.synchronize()
            optimize_time = (time.time() - optimize_start_time) / batch_size * 0.005 + 0.995 * optimize_time
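            # The line above is an exponential moving average with smoothing
            # factor 0.005: new_avg = 0.005 * latest + 0.995 * old_avg, so the
            # logged timing is stable rather than a noisy per-iteration value.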

            if use_ddp:
                dist.barrier()

            if rank == 0 and (total_iters == batch_size or total_iters % train_opt.display_freq == 0):  # display images on visdom and save images to an HTML file
                model.compute_visuals()
                visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
                                                   save_results=True,
                                                   add_image=train_opt.add_image)

            if rank == 0 and (total_iters == batch_size or total_iters % train_opt.print_freq == 0):  # print training losses and save logging information to disk
                losses = model.get_current_losses()
                visualizer.print_current_losses(epoch, epoch_iter, losses, optimize_time, t_data)
                visualizer.plot_current_losses(total_iters, losses)
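
            # Periodic validation: run the val set under no_grad, accumulate
            # per-batch losses, average them, and log under the 'val' tag.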
            if total_iters == batch_size or total_iters % train_opt.evaluation_freq == 0:
                with torch.no_grad():
                    torch.cuda.synchronize()
                    val_start_time = time.time()
                    losses_avg = {}
                    model.eval()
                    for j, val_data in enumerate(val_dataset):
                        model.set_input(val_data)
                        model.optimize_parameters(isTrain=False)  # isTrain=False: compute losses without a weight update
                        if rank == 0 and j < train_opt.vis_batch_nums:
                            model.compute_visuals()
                            visualizer.display_current_results(model.get_current_visuals(), total_iters, epoch,
                                                               dataset='val', save_results=True, count=j * val_opt.batch_size,
                                                               add_image=train_opt.add_image)
                        if j < train_opt.eval_batch_nums:
                            losses = model.get_current_losses()
                            for key, value in losses.items():
                                losses_avg[key] = losses_avg.get(key, 0) + value
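
                    # Average the accumulated losses; eval_batch_nums caps how
                    # many val batches contribute, but the dataset may provide
                    # fewer, hence the min().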
                    for key, value in losses_avg.items():
                        losses_avg[key] = value / min(train_opt.eval_batch_nums, val_dataset_batches)

                    torch.cuda.synchronize()
                    eval_time = time.time() - val_start_time

                    if rank == 0:
                        visualizer.print_current_losses(epoch, epoch_iter, losses_avg, eval_time, t_data, dataset='val')  # log validation results
                        visualizer.plot_current_losses(total_iters, losses_avg, dataset='val')
                model.train()  # restore training mode after evaluation

            if use_ddp:
                dist.barrier()

            if rank == 0 and (total_iters == batch_size or total_iters % train_opt.save_latest_freq == 0):  # cache the latest model every <save_latest_freq> iterations
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                print(train_opt.name)  # it's useful to occasionally show the experiment name on the console
                save_suffix = 'iter_%d' % total_iters if train_opt.save_by_iter else 'latest'
                model.save_networks(save_suffix)

            if use_ddp:
                dist.barrier()

            iter_data_time = time.time()

        print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, train_opt.n_epochs, time.time() - epoch_start_time))
        model.update_learning_rate()  # update learning rates at the end of every epoch

        if rank == 0 and epoch % train_opt.save_epoch_freq == 0:  # cache the model every <save_epoch_freq> epochs
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)

        if use_ddp:
            dist.barrier()
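
# Entry point: options are parsed once in the parent process. With use_ddp set,
# mp.spawn forks world_size workers and prepends each one's rank to `args`, so
# every worker calls main(rank, world_size, train_opt); otherwise the script
# runs single-process as rank 0.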
if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")

    train_opt = TrainOptions().parse()  # get training options
    world_size = train_opt.world_size

    if train_opt.use_ddp:
        mp.spawn(main, args=(world_size, train_opt), nprocs=world_size, join=True)
    else:
        main(0, world_size, train_opt)
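
# NOTE: cleanup() is defined above but never called; a run that needs explicit
# process-group teardown could invoke it (i.e. dist.destroy_process_group())
# at the end of main() on each rank.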