dattarij's picture
adding ContraCLIP folder
8c212a5
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
import os
import time
import numpy as np
import tensorflow as tf
import config
import tfutil
import dataset
import misc
#----------------------------------------------------------------------------
# Choose the size and contents of the image snapshot grids that are exported
# periodically during training.
def setup_snapshot_image_grid(G, training_set,
size = '1080p', # '1080p' = to be viewed on 1080p display, '4k' = to be viewed on 4k display.
layout = 'random'): # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.
# Select size.
gw = 1; gh = 1
if size == '1080p':
gw = np.clip(1920 // G.output_shape[3], 3, 32)
gh = np.clip(1080 // G.output_shape[2], 2, 32)
if size == '4k':
gw = np.clip(3840 // G.output_shape[3], 7, 32)
gh = np.clip(2160 // G.output_shape[2], 4, 32)
# Fill in reals and labels.
reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
for idx in range(gw * gh):
x = idx % gw; y = idx // gw
while True:
real, label = training_set.get_minibatch_np(1)
if layout == 'row_per_class' and training_set.label_size > 0:
if label[0, y % training_set.label_size] == 0.0:
continue
reals[idx] = real[0]
labels[idx] = label[0]
break
# Generate latents.
latents = misc.random_latents(gw * gh, G)
return (gw, gh), reals, labels, latents
#----------------------------------------------------------------------------
# Just-in-time processing of training images before feeding them to the networks.
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
with tf.name_scope('ProcessReals'):
with tf.name_scope('DynamicRange'):
x = tf.cast(x, tf.float32)
x = misc.adjust_dynamic_range(x, drange_data, drange_net)
if mirror_augment:
with tf.name_scope('MirrorAugment'):
s = tf.shape(x)
mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
mask = tf.tile(mask, [1, s[1], s[2], s[3]])
x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3]))
with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
s = tf.shape(x)
y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
y = tf.reduce_mean(y, axis=[3, 5], keepdims=True)
y = tf.tile(y, [1, 1, 1, 2, 1, 2])
y = tf.reshape(y, [-1, s[1], s[2], s[3]])
x = tfutil.lerp(x, y, lod - tf.floor(lod))
with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks.
s = tf.shape(x)
factor = tf.cast(2 ** tf.floor(lod), tf.int32)
x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
x = tf.tile(x, [1, 1, 1, factor, 1, factor])
x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
return x
#----------------------------------------------------------------------------
# Class for evaluating and storing the values of time-varying training parameters.
class TrainingSchedule:
def __init__(
self,
cur_nimg,
training_set,
lod_initial_resolution = 4, # Image resolution used at the beginning.
lod_training_kimg = 600, # Thousands of real images to show before doubling the resolution.
lod_transition_kimg = 600, # Thousands of real images to show when fading in new layers.
minibatch_base = 16, # Maximum minibatch size, divided evenly among GPUs.
minibatch_dict = {}, # Resolution-specific overrides.
max_minibatch_per_gpu = {}, # Resolution-specific maximum minibatch size per GPU.
G_lrate_base = 0.001, # Learning rate for the generator.
G_lrate_dict = {}, # Resolution-specific overrides.
D_lrate_base = 0.001, # Learning rate for the discriminator.
D_lrate_dict = {}, # Resolution-specific overrides.
tick_kimg_base = 160, # Default interval of progress snapshots.
tick_kimg_dict = {4: 160, 8:140, 16:120, 32:100, 64:80, 128:60, 256:40, 512:20, 1024:10}): # Resolution-specific overrides.
# Training phase.
self.kimg = cur_nimg / 1000.0
phase_dur = lod_training_kimg + lod_transition_kimg
phase_idx = int(np.floor(self.kimg / phase_dur)) if phase_dur > 0 else 0
phase_kimg = self.kimg - phase_idx * phase_dur
# Level-of-detail and resolution.
self.lod = training_set.resolution_log2
self.lod -= np.floor(np.log2(lod_initial_resolution))
self.lod -= phase_idx
if lod_transition_kimg > 0:
self.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
self.lod = max(self.lod, 0.0)
self.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(self.lod)))
# Minibatch size.
self.minibatch = minibatch_dict.get(self.resolution, minibatch_base)
self.minibatch -= self.minibatch % config.num_gpus
if self.resolution in max_minibatch_per_gpu:
self.minibatch = min(self.minibatch, max_minibatch_per_gpu[self.resolution] * config.num_gpus)
# Other parameters.
self.G_lrate = G_lrate_dict.get(self.resolution, G_lrate_base)
self.D_lrate = D_lrate_dict.get(self.resolution, D_lrate_base)
self.tick_kimg = tick_kimg_dict.get(self.resolution, tick_kimg_base)
#----------------------------------------------------------------------------
# Main training script.
# To run, comment/uncomment appropriate lines in config.py and launch train.py.
def train_progressive_gan(
G_smoothing = 0.999, # Exponential running average of generator weights.
D_repeats = 1, # How many times the discriminator is trained per G iteration.
minibatch_repeats = 4, # Number of minibatches to run before adjusting training parameters.
reset_opt_for_new_lod = True, # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
total_kimg = 15000, # Total length of the training, measured in thousands of real images.
mirror_augment = False, # Enable mirror augment?
drange_net = [-1,1], # Dynamic range used when feeding image data to the networks.
image_snapshot_ticks = 1, # How often to export image snapshots?
network_snapshot_ticks = 10, # How often to export network snapshots?
save_tf_graph = False, # Include full TensorFlow computation graph in the tfevents file?
save_weight_histograms = False, # Include weight histograms in the tfevents file?
resume_run_id = None, # Run ID or network pkl to resume training from, None = start from scratch.
resume_snapshot = None, # Snapshot index to resume training from, None = autodetect.
resume_kimg = 0.0, # Assumed training progress at the beginning. Affects reporting and training schedule.
resume_time = 0.0): # Assumed wallclock time at the beginning. Affects reporting.
maintenance_start_time = time.time()
training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)
# Construct networks.
with tf.device('/gpu:0'):
if resume_run_id is not None:
network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
print('Loading networks from "%s"...' % network_pkl)
G, D, Gs = misc.load_pkl(network_pkl)
else:
print('Constructing networks...')
G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G)
D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D)
Gs = G.clone('Gs')
Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
G.print_layers(); D.print_layers()
print('Building TensorFlow graph...')
with tf.name_scope('Inputs'):
lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
minibatch_split = minibatch_in // config.num_gpus
reals, labels = training_set.get_minibatch_tf()
reals_split = tf.split(reals, config.num_gpus)
labels_split = tf.split(labels, config.num_gpus)
G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)
for gpu in range(config.num_gpus):
with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment, training_set.dynamic_range, drange_net)
labels_gpu = labels_split[gpu]
with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
G_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss)
with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
D_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss)
G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
G_train_op = G_opt.apply_updates()
D_train_op = D_opt.apply_updates()
print('Setting up snapshot image grid...')
grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set, **config.grid)
sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
print('Setting up result dir...')
result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size)
summary_log = tf.summary.FileWriter(result_subdir)
if save_tf_graph:
summary_log.add_graph(tf.get_default_graph())
if save_weight_histograms:
G.setup_weight_histograms(); D.setup_weight_histograms()
print('Training...')
cur_nimg = int(resume_kimg * 1000)
cur_tick = 0
tick_start_nimg = cur_nimg
tick_start_time = time.time()
train_start_time = tick_start_time - resume_time
prev_lod = -1.0
while cur_nimg < total_kimg * 1000:
# Choose training parameters and configure training ops.
sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
training_set.configure(sched.minibatch, sched.lod)
if reset_opt_for_new_lod:
if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
prev_lod = sched.lod
# Run training ops.
for repeat in range(minibatch_repeats):
for _ in range(D_repeats):
tfutil.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
cur_nimg += sched.minibatch
tfutil.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})
# Perform maintenance tasks once per tick.
done = (cur_nimg >= total_kimg * 1000)
if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
cur_tick += 1
cur_time = time.time()
tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
tick_start_nimg = cur_nimg
tick_time = cur_time - tick_start_time
total_time = cur_time - train_start_time
maintenance_time = tick_start_time - maintenance_start_time
maintenance_start_time = cur_time
# Report progress.
print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (
tfutil.autosummary('Progress/tick', cur_tick),
tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
tfutil.autosummary('Progress/lod', sched.lod),
tfutil.autosummary('Progress/minibatch', sched.minibatch),
misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
tfutil.autosummary('Timing/sec_per_tick', tick_time),
tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
tfutil.autosummary('Timing/maintenance_sec', maintenance_time)))
tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
tfutil.save_summaries(summary_log, cur_nimg)
# Save snapshots.
if cur_tick % image_snapshot_ticks == 0 or done:
grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
if cur_tick % network_snapshot_ticks == 0 or done:
misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))
# Record start time of the next tick.
tick_start_time = time.time()
# Write final results.
misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
summary_log.close()
open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
#----------------------------------------------------------------------------
# Main entry point.
# Calls the function indicated in config.py.
if __name__ == "__main__":
misc.init_output_logging()
np.random.seed(config.random_seed)
print('Initializing TensorFlow...')
os.environ.update(config.env)
tfutil.init_tf(config.tf_config)
print('Running %s()...' % config.train['func'])
tfutil.call_func_by_name(**config.train)
print('Exiting...')
#----------------------------------------------------------------------------