import os
import time
import numpy as np
import tensorflow as tf

import config
import tfutil
import dataset
import misc
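
#----------------------------------------------------------------------------
# Choose the size and contents of the image snapshot grids that are exported
# periodically during training.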
def setup_snapshot_image_grid(G, training_set,
    size    = '1080p',      # '1080p' = grid sized for a 1080p display, '4k' = grid sized for a 4k display.
    layout  = 'random'):    # 'random' = grid contents are selected randomly, 'row_per_class' = each row corresponds to one class label.

    # Select grid dimensions based on the generator's output resolution.
    gw = 1; gh = 1
    if size == '1080p':
        gw = np.clip(1920 // G.output_shape[3], 3, 32)
        gh = np.clip(1080 // G.output_shape[2], 2, 32)
    if size == '4k':
        gw = np.clip(3840 // G.output_shape[3], 7, 32)
        gh = np.clip(2160 // G.output_shape[2], 4, 32)

    # Fill in reals and labels.
    reals = np.zeros([gw * gh] + training_set.shape, dtype=training_set.dtype)
    labels = np.zeros([gw * gh, training_set.label_size], dtype=training_set.label_dtype)
    for idx in range(gw * gh):
        x = idx % gw; y = idx // gw
        while True:
            real, label = training_set.get_minibatch_np(1)
            if layout == 'row_per_class' and training_set.label_size > 0:
                if label[0, y % training_set.label_size] == 0.0:
                    continue
            reals[idx] = real[0]
            labels[idx] = label[0]
            break

    # Generate latents.
    latents = misc.random_latents(gw * gh, G)
    return (gw, gh), reals, labels, latents
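
#----------------------------------------------------------------------------
# Just-in-time processing of training images before feeding them to the networks.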
def process_reals(x, lod, mirror_augment, drange_data, drange_net):
    with tf.name_scope('ProcessReals'):
        with tf.name_scope('DynamicRange'):
            x = tf.cast(x, tf.float32)
            x = misc.adjust_dynamic_range(x, drange_data, drange_net)
        if mirror_augment:
            with tf.name_scope('MirrorAugment'):
                s = tf.shape(x)
                mask = tf.random_uniform([s[0], 1, 1, 1], 0.0, 1.0)
                mask = tf.tile(mask, [1, s[1], s[2], s[3]])
                x = tf.where(mask < 0.5, x, tf.reverse(x, axis=[3]))
        with tf.name_scope('FadeLOD'): # Smooth crossfade between consecutive levels-of-detail.
            s = tf.shape(x)
            y = tf.reshape(x, [-1, s[1], s[2]//2, 2, s[3]//2, 2])
            y = tf.reduce_mean(y, axis=[3, 5], keepdims=True)
            y = tf.tile(y, [1, 1, 1, 2, 1, 2])
            y = tf.reshape(y, [-1, s[1], s[2], s[3]])
            x = tfutil.lerp(x, y, lod - tf.floor(lod))
        with tf.name_scope('UpscaleLOD'): # Upscale to match the expected input/output size of the networks.
            s = tf.shape(x)
            factor = tf.cast(2 ** tf.floor(lod), tf.int32)
            x = tf.reshape(x, [-1, s[1], s[2], 1, s[3], 1])
            x = tf.tile(x, [1, 1, 1, factor, 1, factor])
            x = tf.reshape(x, [-1, s[1], s[2] * factor, s[3] * factor])
        return x
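
#----------------------------------------------------------------------------
# Evaluate time-varying training parameters for the current point in training.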
class TrainingSchedule:
    def __init__(
        self,
        cur_nimg,                           # Current training progress, measured in number of real images shown.
        training_set,
        lod_initial_resolution = 4,         # Image resolution used at the beginning of training.
        lod_training_kimg      = 600,       # Thousands of real images to show before doubling the resolution.
        lod_transition_kimg    = 600,       # Thousands of real images to show when fading in new layers.
        minibatch_base         = 16,        # Default minibatch size, divided evenly among GPUs.
        minibatch_dict         = {},        # Resolution-specific overrides.
        max_minibatch_per_gpu  = {},        # Resolution-specific maximum minibatch size per GPU.
        G_lrate_base           = 0.001,     # Learning rate for the generator.
        G_lrate_dict           = {},        # Resolution-specific overrides.
        D_lrate_base           = 0.001,     # Learning rate for the discriminator.
        D_lrate_dict           = {},        # Resolution-specific overrides.
        tick_kimg_base         = 160,       # Default interval of progress snapshots, in thousands of real images.
        tick_kimg_dict         = {4: 160, 8: 140, 16: 120, 32: 100, 64: 80, 128: 60, 256: 40, 512: 20, 1024: 10}): # Resolution-specific overrides.

        # Training phase.
        self.kimg = cur_nimg / 1000.0
        phase_dur = lod_training_kimg + lod_transition_kimg
        phase_idx = int(np.floor(self.kimg / phase_dur)) if phase_dur > 0 else 0
        phase_kimg = self.kimg - phase_idx * phase_dur

        # Level-of-detail and resolution.
        self.lod = training_set.resolution_log2
        self.lod -= np.floor(np.log2(lod_initial_resolution))
        self.lod -= phase_idx
        if lod_transition_kimg > 0:
            self.lod -= max(phase_kimg - lod_training_kimg, 0.0) / lod_transition_kimg
        self.lod = max(self.lod, 0.0)
        self.resolution = 2 ** (training_set.resolution_log2 - int(np.floor(self.lod)))

        # Minibatch size.
        self.minibatch = minibatch_dict.get(self.resolution, minibatch_base)
        self.minibatch -= self.minibatch % config.num_gpus
        if self.resolution in max_minibatch_per_gpu:
            self.minibatch = min(self.minibatch, max_minibatch_per_gpu[self.resolution] * config.num_gpus)

        # Other parameters.
        self.G_lrate = G_lrate_dict.get(self.resolution, G_lrate_base)
        self.D_lrate = D_lrate_dict.get(self.resolution, D_lrate_base)
        self.tick_kimg = tick_kimg_dict.get(self.resolution, tick_kimg_base)
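
#----------------------------------------------------------------------------
# Main training function. Hyperparameters not listed in the signature are
# taken from config.py (config.dataset, config.G, config.D, config.sched,
# config.grid, and the optimizer/loss settings).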
def train_progressive_gan(
    G_smoothing             = 0.999,        # Decay of the exponential running average of generator weights (Gs).
    D_repeats               = 1,            # How many times the discriminator is trained per G iteration.
    minibatch_repeats       = 4,            # Number of minibatches to run before adjusting training parameters.
    reset_opt_for_new_lod   = True,         # Reset optimizer internal state (e.g. Adam moments) when new layers are introduced?
    total_kimg              = 15000,        # Total length of the training, measured in thousands of real images.
    mirror_augment          = False,        # Enable mirror augmentation?
    drange_net              = [-1,1],       # Dynamic range used when feeding image data to the networks.
    image_snapshot_ticks    = 1,            # How often to export image snapshots, in ticks.
    network_snapshot_ticks  = 10,           # How often to export network snapshots, in ticks.
    save_tf_graph           = False,        # Include full TensorFlow computation graph in the tfevents file?
    save_weight_histograms  = False,        # Include weight histograms in the tfevents file?
    resume_run_id           = None,         # Run ID or network pkl to resume training from, None = start from scratch.
    resume_snapshot         = None,         # Snapshot index to resume training from, None = autodetect.
    resume_kimg             = 0.0,          # Assumed training progress at the beginning, in thousands of real images.
    resume_time             = 0.0):         # Assumed wallclock time at the beginning, in seconds.

    maintenance_start_time = time.time()
    training_set = dataset.load_dataset(data_dir=config.data_dir, verbose=True, **config.dataset)

    # Construct networks, or resume from a previous snapshot.
    with tf.device('/gpu:0'):
        if resume_run_id is not None:
            network_pkl = misc.locate_network_pkl(resume_run_id, resume_snapshot)
            print('Loading networks from "%s"...' % network_pkl)
            G, D, Gs = misc.load_pkl(network_pkl)
        else:
            print('Constructing networks...')
            G = tfutil.Network('G', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.G)
            D = tfutil.Network('D', num_channels=training_set.shape[0], resolution=training_set.shape[1], label_size=training_set.label_size, **config.D)
            Gs = G.clone('Gs')
        Gs_update_op = Gs.setup_as_moving_average_of(G, beta=G_smoothing)
    G.print_layers(); D.print_layers()

    print('Building TensorFlow graph...')
    with tf.name_scope('Inputs'):
        lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[])
        lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[])
        minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[])
        minibatch_split = minibatch_in // config.num_gpus
        reals, labels = training_set.get_minibatch_tf()
        reals_split = tf.split(reals, config.num_gpus)
        labels_split = tf.split(labels, config.num_gpus)
    G_opt = tfutil.Optimizer(name='TrainG', learning_rate=lrate_in, **config.G_opt)
    D_opt = tfutil.Optimizer(name='TrainD', learning_rate=lrate_in, **config.D_opt)

    # Build a graph replica on each GPU and register its gradients with the shared optimizers.
    for gpu in range(config.num_gpus):
        with tf.name_scope('GPU%d' % gpu), tf.device('/gpu:%d' % gpu):
            G_gpu = G if gpu == 0 else G.clone(G.name + '_shadow')
            D_gpu = D if gpu == 0 else D.clone(D.name + '_shadow')
            lod_assign_ops = [tf.assign(G_gpu.find_var('lod'), lod_in), tf.assign(D_gpu.find_var('lod'), lod_in)]
            reals_gpu = process_reals(reals_split[gpu], lod_in, mirror_augment, training_set.dynamic_range, drange_net)
            labels_gpu = labels_split[gpu]
            with tf.name_scope('G_loss'), tf.control_dependencies(lod_assign_ops):
                G_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=G_opt, training_set=training_set, minibatch_size=minibatch_split, **config.G_loss)
            with tf.name_scope('D_loss'), tf.control_dependencies(lod_assign_ops):
                D_loss = tfutil.call_func_by_name(G=G_gpu, D=D_gpu, opt=D_opt, training_set=training_set, minibatch_size=minibatch_split, reals=reals_gpu, labels=labels_gpu, **config.D_loss)
            G_opt.register_gradients(tf.reduce_mean(G_loss), G_gpu.trainables)
            D_opt.register_gradients(tf.reduce_mean(D_loss), D_gpu.trainables)
    G_train_op = G_opt.apply_updates()
    D_train_op = D_opt.apply_updates()

    print('Setting up snapshot image grid...')
    grid_size, grid_reals, grid_labels, grid_latents = setup_snapshot_image_grid(G, training_set, **config.grid)
    sched = TrainingSchedule(total_kimg * 1000, training_set, **config.sched)
    grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)

    print('Setting up result dir...')
    result_subdir = misc.create_result_subdir(config.result_dir, config.desc)
    misc.save_image_grid(grid_reals, os.path.join(result_subdir, 'reals.png'), drange=training_set.dynamic_range, grid_size=grid_size)
    misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % 0), drange=drange_net, grid_size=grid_size)
    summary_log = tf.summary.FileWriter(result_subdir)
    if save_tf_graph:
        summary_log.add_graph(tf.get_default_graph())
    if save_weight_histograms:
        G.setup_weight_histograms(); D.setup_weight_histograms()

    print('Training...')
    cur_nimg = int(resume_kimg * 1000)
    cur_tick = 0
    tick_start_nimg = cur_nimg
    tick_start_time = time.time()
    train_start_time = tick_start_time - resume_time
    prev_lod = -1.0
    while cur_nimg < total_kimg * 1000:

        # Choose training parameters and configure training ops.
        sched = TrainingSchedule(cur_nimg, training_set, **config.sched)
        training_set.configure(sched.minibatch, sched.lod)
        if reset_opt_for_new_lod:
            if np.floor(sched.lod) != np.floor(prev_lod) or np.ceil(sched.lod) != np.ceil(prev_lod):
                G_opt.reset_optimizer_state(); D_opt.reset_optimizer_state()
        prev_lod = sched.lod

        # Run training ops.
        for repeat in range(minibatch_repeats):
            for _ in range(D_repeats):
                tfutil.run([D_train_op, Gs_update_op], {lod_in: sched.lod, lrate_in: sched.D_lrate, minibatch_in: sched.minibatch})
                cur_nimg += sched.minibatch
            tfutil.run([G_train_op], {lod_in: sched.lod, lrate_in: sched.G_lrate, minibatch_in: sched.minibatch})

        # Perform maintenance tasks once per tick.
        done = (cur_nimg >= total_kimg * 1000)
        if cur_nimg >= tick_start_nimg + sched.tick_kimg * 1000 or done:
            cur_tick += 1
            cur_time = time.time()
            tick_kimg = (cur_nimg - tick_start_nimg) / 1000.0
            tick_start_nimg = cur_nimg
            tick_time = cur_time - tick_start_time
            total_time = cur_time - train_start_time
            maintenance_time = tick_start_time - maintenance_start_time
            maintenance_start_time = cur_time

            # Report progress.
            print('tick %-5d kimg %-8.1f lod %-5.2f minibatch %-4d time %-12s sec/tick %-7.1f sec/kimg %-7.2f maintenance %.1f' % (
                tfutil.autosummary('Progress/tick', cur_tick),
                tfutil.autosummary('Progress/kimg', cur_nimg / 1000.0),
                tfutil.autosummary('Progress/lod', sched.lod),
                tfutil.autosummary('Progress/minibatch', sched.minibatch),
                misc.format_time(tfutil.autosummary('Timing/total_sec', total_time)),
                tfutil.autosummary('Timing/sec_per_tick', tick_time),
                tfutil.autosummary('Timing/sec_per_kimg', tick_time / tick_kimg),
                tfutil.autosummary('Timing/maintenance_sec', maintenance_time)))
            tfutil.autosummary('Timing/total_hours', total_time / (60.0 * 60.0))
            tfutil.autosummary('Timing/total_days', total_time / (24.0 * 60.0 * 60.0))
            tfutil.save_summaries(summary_log, cur_nimg)

            # Save snapshots.
            if cur_tick % image_snapshot_ticks == 0 or done:
                grid_fakes = Gs.run(grid_latents, grid_labels, minibatch_size=sched.minibatch//config.num_gpus)
                misc.save_image_grid(grid_fakes, os.path.join(result_subdir, 'fakes%06d.png' % (cur_nimg // 1000)), drange=drange_net, grid_size=grid_size)
            if cur_tick % network_snapshot_ticks == 0 or done:
                misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-snapshot-%06d.pkl' % (cur_nimg // 1000)))

            # Record start time of the next tick.
            tick_start_time = time.time()

    # Write final results.
    misc.save_pkl((G, D, Gs), os.path.join(result_subdir, 'network-final.pkl'))
    summary_log.close()
    open(os.path.join(result_subdir, '_training-done.txt'), 'wt').close()
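
#----------------------------------------------------------------------------
# Entry point. config.train is expected to name the training function and
# supply its keyword arguments; an illustrative (hypothetical) configuration
# could look like:
#   config.train = {'func': 'train.train_progressive_gan', 'total_kimg': 15000}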
if __name__ == "__main__":
    misc.init_output_logging()
    np.random.seed(config.random_seed)
    print('Initializing TensorFlow...')
    os.environ.update(config.env)
    tfutil.init_tf(config.tf_config)
    print('Running %s()...' % config.train['func'])
    tfutil.call_func_by_name(**config.train)
    print('Exiting...')